# swp2.py
import json
import math
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
  6. def log_facto(k):
  7. """
  8. Using the Stirling's approximation
  9. """
  10. k = int(k)
  11. if k > 1e6:
  12. return k * np.log(k) - k + np.log(2*math.pi*k)/2
  13. val = 0
  14. for i in range(2, k+1):
  15. val += np.log(i)
  16. return val
  17. def return_x_y_from_stwp_theta_file(stwp_theta_file, breaks, mu, tgen, relative_theta_scale = False):
  18. with open(stwp_theta_file, "r") as swp_file:
  19. # Read the first line
  20. line = swp_file.readline()
  21. L = float(line.split()[2])
  22. rands = swp_file.readline()
  23. line = swp_file.readline()
  24. # skip empty lines before SFS
  25. while line == "\n":
  26. line = swp_file.readline()
  27. sfs = np.array(line.split()).astype(float)
  28. # Process lines until the end of the file
  29. while line:
  30. # check at each line
  31. if line.startswith("dim") :
  32. dim = int(line.split()[1])
  33. if dim == breaks+1:
  34. likelihood = line.split()[5]
  35. groups = line.split()[6:6+dim]
  36. theta_site = line.split()[6+dim:6+dim+1+dim]
  37. elif dim < breaks+1:
  38. line = swp_file.readline()
  39. continue
  40. elif dim > breaks+1:
  41. break
  42. #return 0,0,0
  43. # Read the next line
  44. line = swp_file.readline()
  45. #### END of parsing
  46. # quit this file if the number of dimensions is incorrect
  47. if dim < breaks+1:
  48. return 0,0,0,0,0,0
  49. # get n, the last bin of the last group
  50. # revert the list of groups as the most recent times correspond
  51. # to the closest and last leafs of the coal. tree.
  52. groups = groups[::-1]
  53. theta_site = theta_site[::-1]
  54. # store thetas for later use
  55. grps = groups.copy()
  56. thetas = {}
  57. for i in range(len(groups)):
  58. grps[i] = grps[i].split(',')
  59. thetas[i] = [float(theta_site[i]), grps[i], likelihood]
  60. # initiate the dict of times
  61. t = {}
  62. # list of thetas
  63. theta_L = []
  64. sum_t = 0
  65. for group_nb, group in enumerate(groups):
  66. ###print(group_nb, group, theta_site[group_nb], len(theta_site))
  67. # store all the thetas one by one, with one theta per group
  68. theta_L.append(float(theta_site[group_nb]))
  69. # if the group is of size 1
  70. if len(group.split(',')) == 1:
  71. i = int(group)
  72. # if the group size is >1, take the first elem of the group
  73. # i is the first bin of each group, straight after a breakpoint
  74. else:
  75. i = int(group.split(",")[0])
  76. j = int(group.split(",")[-1])
  77. t[i] = 0
  78. #t =
  79. if len(group.split(',')) == 1:
  80. k = i
  81. if relative_theta_scale:
  82. t[i] += ((theta_L[group_nb] ) / (k*(k-1)))
  83. else:
  84. t[i] += ((theta_L[group_nb] ) / (k*(k-1)) * tgen) / mu
  85. else:
  86. for k in range(j, i-1, -1 ):
  87. if relative_theta_scale:
  88. t[i] += ((theta_L[group_nb] ) / (k*(k-1)))
  89. else:
  90. t[i] += ((theta_L[group_nb] ) / (k*(k-1)) * tgen) / mu
  91. # we add the cumulative times at the end
  92. t[i] += sum_t
  93. sum_t = t[i]
  94. # build the y axis (sizes)
  95. y = []
  96. for theta in theta_L:
  97. if relative_theta_scale:
  98. size = theta
  99. else:
  100. # with size N = theta/4mu
  101. size = theta / (4*mu)
  102. y.append(size)
  103. y.append(size)
  104. # build the time x axis
  105. x = [0]
  106. for time in range(0, len(t.values())-1):
  107. x.append(list(t.values())[time])
  108. x.append(list(t.values())[time])
  109. x.append(list(t.values())[len(t.values())-1])
  110. # if relative_theta_scale:
  111. # # rescale
  112. # #N0 = y[0]
  113. # # for i in range(len(y)):
  114. # # # divide by N0
  115. # # y[i] = y[i]/N0
  116. # # x[i] = x[i]/N0
  117. return x,y,likelihood,thetas,sfs,L
  118. def plot_k_epochs_thetafolder(folder_path, mu, tgen, breaks = 2, title = "Title", theta_scale = True):
  119. scenari = {}
  120. cpt = 0
  121. for file_name in os.listdir(folder_path):
  122. if os.path.isfile(os.path.join(folder_path, file_name)):
  123. # Perform actions on each file
  124. x, y, likelihood, theta, sfs, L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
  125. tgen = tgen,
  126. mu = mu, relative_theta_scale = theta_scale)
  127. if x == 0 or y == 0:
  128. continue
  129. cpt +=1
  130. scenari[likelihood] = x,y
  131. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
  132. print(cpt, "theta file(s) have been scanned.")
  133. # sort starting by the smallest -log(Likelihood)
  134. print(scenari)
  135. best10_scenari = (sorted(list(scenari.keys())))[:10]
  136. print("10 greatest Likelihoods", best10_scenari)
  137. greatest_likelihood = best10_scenari[0]
  138. x, y = scenari[greatest_likelihood]
  139. my_dpi = 300
  140. plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  141. plt.plot(x, y, 'r-', lw=2, label = 'Lik='+greatest_likelihood)
  142. #plt.yscale('log')
  143. plt.xscale('log')
  144. plt.grid(True,which="both", linestyle='--', alpha = 0.3)
  145. for scenario in best10_scenari[1:]:
  146. x,y = scenari[scenario]
  147. #print("\n---- Lik:",scenario,"\n\nt=", x,"\n\nN=",y, "\n\n")
  148. plt.plot(x, y, '--', lw=1, label = 'Lik='+scenario)
  149. if theta_scale:
  150. plt.xlabel("Coal. time")
  151. plt.ylabel("Pop. size scaled by N0")
  152. recent_scale_lower_bound = y[0] * 0.01
  153. recent_scale_upper_bound = y[0] * 0.1
  154. plt.axvline(x=recent_scale_lower_bound)
  155. plt.axvline(x=recent_scale_upper_bound)
  156. else:
  157. # years
  158. plt.xlabel("Time (years)")
  159. plt.ylabel("Individuals (N)")
  160. plt.legend(loc='upper right')
  161. plt.title(title)
  162. plt.savefig(title+'_b'+str(breaks)+'.pdf')
  163. def plot_straight_x_y(x,y):
  164. x_1 = [x[0]]
  165. y_1 = []
  166. for i in range(0, len(y)-1):
  167. x_1.append(x[i])
  168. x_1.append(x[i])
  169. y_1.append(y[i])
  170. y_1.append(y[i])
  171. y_1 = y_1+[y[-1],y[-1]]
  172. x_1.append(x[-1])
  173. return x_1, y_1
  174. def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_scale = True, ax = None):
  175. #scenari = {}
  176. cpt = 0
  177. epochs = {}
  178. for file_name in os.listdir(folder_path):
  179. breaks = 0
  180. cpt +=1
  181. if os.path.isfile(os.path.join(folder_path, file_name)):
  182. x, y, likelihood, theta, sfs, L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
  183. tgen = tgen,
  184. mu = mu, relative_theta_scale = theta_scale)
  185. SFS_stored = sfs
  186. L_stored = L
  187. while not (x == 0 and y == 0):
  188. if breaks not in epochs.keys():
  189. epochs[breaks] = {}
  190. epochs[breaks][likelihood] = x,y
  191. breaks += 1
  192. x,y,likelihood,theta,sfs,L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
  193. tgen = tgen,
  194. mu = mu, relative_theta_scale = theta_scale)
  195. if x == 0:
  196. # last break did not work, then breaks = breaks-1
  197. breaks -= 1
  198. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
  199. print(cpt, "theta file(s) have been scanned.")
  200. my_dpi = 300
  201. if ax is None:
  202. # intialize figure
  203. my_dpi = 300
  204. fnt_size = 18
  205. # plt.rcParams['font.size'] = fnt_size
  206. fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  207. else:
  208. fnt_size = 12
  209. # plt.rcParams['font.size'] = fnt_size
  210. ax1 = ax[1][0,0]
  211. ax1.set_yscale('log')
  212. ax1.set_xscale('log')
  213. ax1.grid(True,which="both", linestyle='--', alpha = 0.3)
  214. brkpt_lik = []
  215. top_plots = {}
  216. for epoch, scenari in epochs.items():
  217. # sort starting by the smallest -log(Likelihood)
  218. best10_scenari = (sorted(list(scenari.keys())))[:10]
  219. greatest_likelihood = best10_scenari[0]
  220. # store the tuple breakpoints and likelihood for later plot
  221. brkpt_lik.append((epoch, greatest_likelihood))
  222. x, y = scenari[greatest_likelihood]
  223. #without breakpoint
  224. if epoch == 0:
  225. # do something with the theta without bp and skip the plotting
  226. N0 = y[0]
  227. #continue
  228. for i in range(len(y)):
  229. # divide by N0
  230. y[i] = y[i]/N0
  231. x[i] = x[i]/N0
  232. top_plots[greatest_likelihood] = x,y,epoch
  233. plots_likelihoods = list(top_plots.keys())
  234. for i in range(len(plots_likelihoods)):
  235. plots_likelihoods[i] = float(plots_likelihoods[i])
  236. best10_plots = sorted(plots_likelihoods)[:10]
  237. top_plot_lik = str(best10_plots[0])
  238. plot_handles = []
  239. # plt.rcParams['font.size'] = fnt_size
  240. p0, = ax1.plot(top_plots[top_plot_lik][0], top_plots[top_plot_lik][1], 'o', linestyle = "-",
  241. alpha=1, lw=2, label = str(top_plots[top_plot_lik][2])+' brks | Lik='+top_plot_lik)
  242. plot_handles.append(p0)
  243. for k, plot_Lk in enumerate(best10_plots[1:]):
  244. plot_Lk = str(plot_Lk)
  245. # plt.rcParams['font.size'] = fnt_size
  246. p, = ax1.plot(top_plots[plot_Lk][0], top_plots[plot_Lk][1], 'o', linestyle = "--",
  247. alpha=1/(k+1), lw=1.5, label = str(top_plots[plot_Lk][2])+' brks | Lik='+plot_Lk)
  248. plot_handles.append(p)
  249. if theta_scale:
  250. ax1.set_xlabel("Coal. time", fontsize=fnt_size)
  251. ax1.set_ylabel("Pop. size scaled by N0", fontsize=fnt_size)
  252. # recent_scale_lower_bound = 0.01
  253. # recent_scale_upper_bound = 0.1
  254. # ax1.axvline(x=recent_scale_lower_bound)
  255. # ax1.axvline(x=recent_scale_upper_bound)
  256. else:
  257. # years
  258. plt.set_xlabel("Time (years)", fontsize=fnt_size)
  259. plt.set_ylabel("Individuals (N)", fontsize=fnt_size)
  260. # plt.rcParams['font.size'] = fnt_size
  261. # print(fnt_size, "rcParam font.size=", plt.rcParams['font.size'])
  262. ax1.legend(handles = plot_handles, loc='best', fontsize = fnt_size*0.5)
  263. ax1.set_title(title)
  264. if ax is None:
  265. plt.savefig(title+'_b'+str(breaks)+'.pdf')
  266. # plot likelihood against nb of breakpoints
  267. # best possible likelihood from SFS
  268. # Segregating sites
  269. S = sum(SFS_stored)
  270. # Number of kept sites from which the SFS is computed
  271. L = L_stored
  272. # number of monomorphic sites
  273. S0 = L-S
  274. # print("SFS", SFS_stored)
  275. # print("S", S, "L", L, "S0=", S0)
  276. # compute Ln
  277. Ln = log_facto(S+S0) - log_facto(S0) + np.log(float(S0)/(S+S0)) * S0
  278. for xi in range(0, len(SFS_stored)):
  279. p_i = SFS_stored[xi] / float(S+S0)
  280. Ln += np.log(p_i) * SFS_stored[xi] - log_facto(SFS_stored[xi])
  281. # basic plot likelihood
  282. if ax is None:
  283. fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  284. # plt.rcParams['font.size'] = fnt_size
  285. else:
  286. #plt.rcParams['font.size'] = fnt_size
  287. ax2 = ax[0][0,1]
  288. ax2.plot(np.array(brkpt_lik)[:, 0], np.array(brkpt_lik)[:, 1].astype(float), 'o', linestyle = "dotted", lw=2)
  289. ax2.axhline(y=-Ln, linestyle = "-.", color = "red", label = "$-\log\mathcal{L}$ = "+str(round(-Ln, 2)))
  290. ax2.set_yscale('log')
  291. ax2.set_xlabel("# breakpoints", fontsize=fnt_size)
  292. ax2.set_ylabel("$-\log\mathcal{L}$", fontsize=fnt_size)
  293. ax2.legend(loc='best', fontsize = fnt_size*0.5)
  294. ax2.set_title(title+" Likelihood gain from # breakpoints")
  295. if ax is None:
  296. plt.savefig(title+'_Breakpts_Likelihood.pdf')
  297. # AIC
  298. if ax is None:
  299. fig, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  300. # plt.rcParams['font.size'] = '18'
  301. else:
  302. #plt.rcParams['font.size'] = fnt_size
  303. ax3 = ax[1][0,1]
  304. AIC = []
  305. for brk in np.array(brkpt_lik)[:, 0]:
  306. brk = int(brk)
  307. AIC.append((2*brk+1)+2*np.array(brkpt_lik)[brk, 1].astype(float))
  308. ax3.plot(np.array(brkpt_lik)[:, 0], AIC, 'o', linestyle = "dotted", lw=2)
  309. # AIC = 2*k - 2ln(L) ; where k is the number of parameters, here brks+1
  310. AIC_ln = 2*(len(brkpt_lik)+1) - 2*Ln
  311. ax3.axhline(y=AIC_ln, linestyle = "-.", color = "red",
  312. label = "Min. AIC = "+str(round(AIC_ln, 2)))
  313. selected_brks_nb = AIC.index(min(AIC))
  314. ax3.set_yscale('log')
  315. ax3.set_xlabel("# breakpoints", fontsize=fnt_size)
  316. ax3.set_ylabel("AIC")
  317. ax3.legend(loc='best', fontsize = fnt_size*0.5)
  318. ax3.set_title(title+" AIC")
  319. if ax is None:
  320. plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
  321. print("S", S)
  322. # return plots
  323. return ax[0], ax[1]
  324. def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
  325. breaks_max = 10, output = None):
  326. """
  327. Save theta values as is to do basic plots.
  328. """
  329. cpt = 0
  330. epochs = {}
  331. len_sfs = 0
  332. for file_name in os.listdir(folder_path):
  333. cpt +=1
  334. if os.path.isfile(os.path.join(folder_path, file_name)):
  335. for k in range(breaks_max):
  336. x,y,likelihood,thetas,sfs,L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = k,
  337. tgen = tgen,
  338. mu = mu, relative_theta_scale = theta_scale)
  339. if thetas == 0:
  340. continue
  341. if len(thetas)-1 != k:
  342. continue
  343. if k not in epochs.keys():
  344. epochs[k] = {}
  345. likelihood = str(eval(thetas[k][2]))
  346. epochs[k][likelihood] = thetas
  347. #epochs[k] = thetas
  348. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(k)+"\n*******\n")
  349. print(cpt, "theta file(s) have been scanned.")
  350. plots = []
  351. best_epochs = {}
  352. for epoch in epochs:
  353. likelihoods = []
  354. for key in epochs[epoch].keys():
  355. likelihoods.append(key)
  356. likelihoods.sort()
  357. minLogLn = str(likelihoods[0])
  358. best_epochs[epoch] = epochs[epoch][minLogLn]
  359. for epoch, theta in best_epochs.items():
  360. groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
  361. x = []
  362. y = []
  363. thetas = np.array(list(theta.values()), dtype=object)[:, 0]
  364. for i,group in enumerate(groups):
  365. x += group[::-1]
  366. y += list(np.repeat(thetas[i], len(group)))
  367. if epoch == 0:
  368. N0 = y[0]
  369. # compute the proportion of information used at each bin of the SFS
  370. sum_theta_i = 0
  371. for i in range(2, len(y)+2):
  372. sum_theta_i+=y[i-2] / (i-1)
  373. prop = []
  374. for k in range(2, len(y)+2):
  375. prop.append(y[k-2] / (k - 1) / sum_theta_i)
  376. prop = prop[::-1]
  377. # normalise to N0 (N0 of epoch1)
  378. for i in range(len(y)):
  379. y[i] = y[i]/N0
  380. # x_plot, y_plot = plot_straight_x_y(x, y)
  381. p = x, y
  382. # add plot to the list of all plots to superimpose
  383. plots.append(p)
  384. cumul = 0
  385. prop_cumul = []
  386. for val in prop:
  387. prop_cumul.append(val+cumul)
  388. cumul = val+cumul
  389. prop = prop_cumul
  390. lines_fig2 = []
  391. for epoch, theta in best_epochs.items():
  392. groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
  393. x = []
  394. y = []
  395. thetas = np.array(list(theta.values()), dtype=object)[:, 0]
  396. for i,group in enumerate(groups):
  397. x += group[::-1]
  398. y += list(np.repeat(thetas[i], len(group)))
  399. if epoch == 0:
  400. N0 = y[0]
  401. for i in range(len(y)):
  402. y[i] = y[i]/N0
  403. x_2 = []
  404. T = 0
  405. for i in range(len(x)):
  406. x[i] = int(x[i])
  407. # compute the times as: theta_k / (k*(k-1))
  408. for i in range(0, len(x)):
  409. T += y[i] / (x[i]*(x[i]-1))
  410. x_2.append(T)
  411. # Save plotting (fig 2)
  412. x_2 = [0]+x_2
  413. y = [y[0]]+y
  414. # x2_plot, y2_plot = plot_straight_x_y(x_2, y)
  415. p2 = x_2, y
  416. lines_fig2.append(p2)
  417. saved_plots = {"raw_stairs":plots, "scaled_stairs":lines_fig2,
  418. "prop":prop}
  419. if output == None:
  420. output = title+"_plotdata.json"
  421. with open(output, 'w') as json_file:
  422. json.dump(saved_plots, json_file)
  423. return saved_plots
  424. def plot_scaled_theta(plot_lines, prop, title, ax = None, n_ticks = 10):
  425. # fig 2 & 3
  426. if ax is None:
  427. my_dpi = 300
  428. fnt_size = 18
  429. fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  430. fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  431. else:
  432. # plt.rcParams['font.size'] = fnt_size
  433. fnt_size = 12
  434. # place of plots on the grid
  435. ax2 = ax[1,0]
  436. ax3 = ax[1,1]
  437. lines_fig2 = []
  438. lines_fig3 = []
  439. #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  440. for epoch, plot in enumerate(plot_lines):
  441. x,y=plot
  442. x2_plot, y2_plot = plot_straight_x_y(x,y)
  443. p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
  444. lines_fig2.append(p2)
  445. # Plotting (fig 3) which is the same but log scale for x
  446. p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
  447. lines_fig3.append(p3)
  448. ax2.set_xlabel("Relative scale", fontsize=fnt_size)
  449. ax2.set_ylabel("theta", fontsize=fnt_size)
  450. ax2.set_title(title, fontsize=fnt_size)
  451. ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
  452. if ax is None:
  453. # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
  454. plt.savefig(title+'_plot2_'+str(len(plot_lines))+'.pdf')
  455. # close fig2 to save memory
  456. plt.close(fig2)
  457. ax3.set_xscale('log')
  458. ax3.set_yscale('log')
  459. ax3.set_xlabel("log Relative scale", fontsize=fnt_size)
  460. ax3.set_ylabel("theta", fontsize=fnt_size)
  461. ax3.set_title(title, fontsize=fnt_size)
  462. ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
  463. if ax is None:
  464. # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
  465. plt.savefig(title+'_plot3_'+str(len(plot_lines))+'_log.pdf')
  466. # close fig3 to save memory
  467. plt.close(fig3)
  468. return ax
  469. def plot_raw_stairs(plot_lines, prop, title, ax = None, n_ticks = 10):
  470. # multiple fig
  471. if ax is None:
  472. # intialize figure 1
  473. my_dpi = 300
  474. fnt_size = 18
  475. # plt.rcParams['font.size'] = fnt_size
  476. fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  477. else:
  478. fnt_size = 12
  479. # plt.rcParams['font.size'] = fnt_size
  480. ax1 = ax[0, 0]
  481. plt.subplots_adjust(wspace=0.3, hspace=0.3)
  482. plots = []
  483. for epoch, plot in enumerate(plot_lines):
  484. x,y = plot
  485. x_plot, y_plot = plot_straight_x_y(x,y)
  486. p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
  487. # add plot to the list of all plots to superimpose
  488. plots.append(p)
  489. x_ticks = x
  490. # print(x_ticks)
  491. #print(prop, "\n", sum(prop))
  492. #ax.legend(handles=[p0]+plots)
  493. ax1.set_xlabel("# bin & cumul. prop. of sites", fontsize=fnt_size)
  494. # Set the x-axis locator to reduce the number of ticks to 10
  495. ax1.set_ylabel("theta", fontsize=fnt_size)
  496. ax1.set_title(title, fontsize=fnt_size)
  497. ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
  498. ax1.set_xticks(x_ticks)
  499. step = len(x_ticks)//(n_ticks-1)
  500. values = x_ticks[::step]
  501. new_prop = []
  502. for val in values:
  503. new_prop.append(prop[int(val)-2])
  504. new_prop = new_prop[::-1]
  505. ax1.set_xticks(values)
  506. ax1.set_xticklabels([f'{values[k]}\n{val:.2f}' for k, val in enumerate(new_prop)], fontsize = fnt_size*0.8)
  507. if ax is None:
  508. # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
  509. plt.savefig(title+'_raw'+str(len(plot_lines))+'.pdf')
  510. plt.close(fig)
  511. # return plots
  512. return ax
  513. def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True, breaks_max = 10, ax = None, n_ticks = 10):
  514. """
  515. Use theta values as is to do basic plots.
  516. """
  517. cpt = 0
  518. epochs = {}
  519. len_sfs = 0
  520. for file_name in os.listdir(folder_path):
  521. cpt +=1
  522. if os.path.isfile(os.path.join(folder_path, file_name)):
  523. for k in range(breaks_max):
  524. x, y, likelihood, theta, sfs, L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = k,
  525. tgen = tgen,
  526. mu = mu, relative_theta_scale = theta_scale)
  527. if thetas == 0:
  528. continue
  529. if len(thetas)-1 != k:
  530. continue
  531. if k not in epochs.keys():
  532. epochs[k] = {}
  533. likelihood = str(eval(thetas[k][2]))
  534. epochs[k][likelihood] = thetas
  535. #epochs[k] = thetas
  536. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(k)+"\n*******\n")
  537. print(cpt, "theta file(s) have been scanned.")
  538. # multiple fig
  539. if ax is None:
  540. # intialize figure 1
  541. my_dpi = 300
  542. fnt_size = 18
  543. # plt.rcParams['font.size'] = fnt_size
  544. fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  545. else:
  546. fnt_size = 12
  547. # plt.rcParams['font.size'] = fnt_size
  548. ax1 = ax[0, 1]
  549. plt.subplots_adjust(wspace=0.3, hspace=0.3)
  550. plots = []
  551. best_epochs = {}
  552. for epoch in epochs:
  553. likelihoods = []
  554. for key in epochs[epoch].keys():
  555. likelihoods.append(key)
  556. likelihoods.sort()
  557. minLogLn = str(likelihoods[0])
  558. best_epochs[epoch] = epochs[epoch][minLogLn]
  559. for epoch, theta in best_epochs.items():
  560. groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
  561. x = []
  562. y = []
  563. thetas = np.array(list(theta.values()), dtype=object)[:, 0]
  564. for i,group in enumerate(groups):
  565. x += group[::-1]
  566. y += list(np.repeat(thetas[i], len(group)))
  567. if epoch == 0:
  568. N0 = y[0]
  569. # compute the proportion of information used at each bin of the SFS
  570. sum_theta_i = 0
  571. for i in range(2, len(y)+2):
  572. sum_theta_i+=y[i-2] / (i-1)
  573. prop = []
  574. for k in range(2, len(y)+2):
  575. prop.append(y[k-2] / (k - 1) / sum_theta_i)
  576. prop = prop[::-1]
  577. # print(prop, "\n", sum(prop))
  578. # normalise to N0 (N0 of epoch1)
  579. x_ticks = ax1.get_xticks()
  580. for i in range(len(y)):
  581. y[i] = y[i]/N0
  582. # plot
  583. x_plot, y_plot = plot_straight_x_y(x, y)
  584. #plt.plot(x, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
  585. p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
  586. # add plot to the list of all plots to superimpose
  587. plots.append(p)
  588. #print(prop, "\n", sum(prop))
  589. #ax.legend(handles=[p0]+plots)
  590. ax1.set_xlabel("# bin", fontsize=fnt_size)
  591. # Set the x-axis locator to reduce the number of ticks to 10
  592. ax1.set_ylabel("theta", fontsize=fnt_size)
  593. ax1.set_title(title, fontsize=fnt_size)
  594. ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
  595. ax1.set_xticks(x_ticks)
  596. if len(prop) >= 18:
  597. ax1.locator_params(nbins=n_ticks)
  598. # new scale of ticks if too many values
  599. cumul = 0
  600. prop_cumul = []
  601. for val in prop:
  602. prop_cumul.append(val+cumul)
  603. cumul = val+cumul
  604. ax1.set_xticklabels([f'{x[k]}\n{val:.2f}' for k, val in enumerate(prop_cumul)])
  605. if ax is None:
  606. plt.savefig(title+'_raw'+str(k)+'.pdf')
  607. # fig 2 & 3
  608. if ax is None:
  609. fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  610. fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  611. else:
  612. # plt.rcParams['font.size'] = fnt_size
  613. # place of plots on the grid
  614. ax2 = ax[1,0]
  615. ax3 = ax[1,1]
  616. lines_fig2 = []
  617. lines_fig3 = []
  618. #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  619. for epoch, theta in best_epochs.items():
  620. groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
  621. x = []
  622. y = []
  623. thetas = np.array(list(theta.values()), dtype=object)[:, 0]
  624. for i,group in enumerate(groups):
  625. x += group[::-1]
  626. y += list(np.repeat(thetas[i], len(group)))
  627. if epoch == 0:
  628. N0 = y[0]
  629. for i in range(len(y)):
  630. y[i] = y[i]/N0
  631. x_2 = []
  632. T = 0
  633. for i in range(len(x)):
  634. x[i] = int(x[i])
  635. # compute the times as: theta_k / (k*(k-1))
  636. for i in range(0, len(x)):
  637. T += y[i] / (x[i]*(x[i]-1))
  638. x_2.append(T)
  639. # Plotting (fig 2)
  640. x_2 = [0]+x_2
  641. y = [y[0]]+y
  642. x2_plot, y2_plot = plot_straight_x_y(x_2, y)
  643. p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
  644. lines_fig2.append(p2)
  645. # Plotting (fig 3) which is the same but log scale for x
  646. p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
  647. lines_fig3.append(p3)
  648. ax2.set_xlabel("Relative scale", fontsize=fnt_size)
  649. ax2.set_ylabel("theta", fontsize=fnt_size)
  650. ax2.set_title(title, fontsize=fnt_size)
  651. ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
  652. if ax is None:
  653. plt.savefig(title+'_plot2_'+str(k)+'.pdf')
  654. ax3.set_xscale('log')
  655. ax3.set_yscale('log')
  656. ax3.set_xlabel("log Relative scale", fontsize=fnt_size)
  657. ax3.set_ylabel("theta", fontsize=fnt_size)
  658. ax3.set_title(title, fontsize=fnt_size)
  659. ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
  660. if ax is None:
  661. plt.savefig(title+'_plot3_'+str(k)+'_log.pdf')
  662. plt.clf()
  663. # return plots
  664. return ax
  665. def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale = True):
  666. my_dpi = 300
  667. # # Add some extra space for the second axis at the bottom
  668. # #plt.rcParams['font.size'] = 18
  669. # fig, axs = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
  670. # #plt.rcParams['font.size'] = 12
  671. # ax = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = axs)
  672. # ax = plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = axs)
  673. # # Adjust layout to prevent clipping of titles
  674. #
  675. # # Save the entire grid as a single figure
  676. # plt.savefig(title+'_combined.pdf')
  677. # plt.clf()
  678. # # # second call for individual plots
  679. # # plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = None)
  680. # # plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = None)
  681. # # plt.clf()
  682. save_k_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, output = title+"_plotdata.json")
  683. with open(title+"_plotdata.json", 'r') as json_file:
  684. loaded_data = json.load(json_file)
  685. # plot page 1 of summary
  686. fig1, ax1 = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
  687. # fig1.tight_layout()
  688. # Adjust absolute space between the top and bottom rows
  689. fig1.subplots_adjust(hspace=0.35) # Adjust this value based on your requirement
  690. # plot page 2 of summary
  691. fig2, ax2 = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
  692. # fig2.tight_layout()
  693. ax1 = plot_raw_stairs(plot_lines = loaded_data['raw_stairs'],
  694. prop = loaded_data['prop'], title = title, ax = ax1)
  695. ax1 = plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'],
  696. prop = loaded_data['prop'], title = title, ax = ax1)
  697. ax1, ax2 = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = [ax1, ax2])
  698. fig1.savefig(title+'_combined_p1.pdf')
  699. fig2.savefig(title+'_combined_p2.pdf')
  700. plot_raw_stairs(plot_lines = loaded_data['raw_stairs'],
  701. prop = loaded_data['prop'], title = title, ax = None)
  702. plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'],
  703. prop = loaded_data['prop'], title = title, ax = None)
  704. plt.close(fig1)
  705. plt.close(fig2)
  706. if __name__ == "__main__":
  707. if len(sys.argv) != 4:
  708. print("Need 3 args: ThetaFolder MutationRate GenerationTime")
  709. exit(0)
  710. folder_path = sys.argv[1]
  711. mu = sys.argv[2]
  712. tgen = sys.argv[3]
  713. plot_all_epochs_thetafolder(folder_path, mu, tgen)