# swp2.py
  1. import matplotlib.pyplot as plt
  2. import os
  3. import numpy as np
  4. import math
  5. import json
  6. def log_facto(k):
  7. """
  8. Using the Stirling's approximation
  9. """
  10. k = int(k)
  11. if k > 1e6:
  12. return k * np.log(k) - k + np.log(2*math.pi*k)/2
  13. val = 0
  14. for i in range(2, k+1):
  15. val += np.log(i)
  16. return val
def parse_stwp_theta_file(stwp_theta_file, breaks, mu, tgen, relative_theta_scale = False):
    """Parse one stairway-plot-style theta file for a given number of breaks.

    The file is scanned for the ``dim`` record whose dimension equals
    ``breaks + 1``; from it the likelihood, the bin groups and the
    per-site thetas are extracted and converted into a step curve of
    times (x) and population sizes (y).

    Parameters
    ----------
    stwp_theta_file : str
        Path of the theta file to parse.
    breaks : int
        Number of breakpoints; the record with ``dim == breaks + 1`` is used.
    mu : float
        Mutation rate, used to rescale times/sizes (ignored when
        ``relative_theta_scale`` is True).
    tgen : float
        Generation time, used to rescale coalescent times to years.
    relative_theta_scale : bool
        If True, keep times and sizes in relative theta units.

    Returns
    -------
    tuple
        ``(x, y, likelihood, thetas, sfs, L)`` on success, or
        ``(0, 0, 0, 0, 0, 0)`` when no record with enough dimensions exists.
    """
    with open(stwp_theta_file, "r") as swp_file:
        # Read the first line; its third whitespace-separated field is L,
        # the number of sites.
        line = swp_file.readline()
        L = float(line.split()[2])
        # Second line is read but never used.  NOTE(review): presumably the
        # random seeds line -- confirm against the file format.
        rands = swp_file.readline()
        line = swp_file.readline()
        # skip empty lines before SFS
        while line == "\n":
            line = swp_file.readline()
        # This non-empty line holds the SFS as whitespace-separated floats.
        sfs = np.array(line.split()).astype(float)
        # Process lines until the end of the file
        while line:
            # check at each line
            if line.startswith("dim") :
                dim = int(line.split()[1])
                if dim == breaks+1:
                    # Field layout: [5] likelihood, [6:6+dim] groups,
                    # then the per-site thetas.
                    # NOTE(review): the theta slice spans dim+1 fields
                    # (6+dim .. 6+2*dim) -- confirm this matches the format.
                    likelihood = line.split()[5]
                    groups = line.split()[6:6+dim]
                    theta_site = line.split()[6+dim:6+dim+1+dim]
                elif dim < breaks+1:
                    # Not yet the record we want: keep scanning.
                    line = swp_file.readline()
                    continue
                elif dim > breaks+1:
                    # Records appear in increasing dim order; stop early.
                    break
                    #return 0,0,0
            # Read the next line
            line = swp_file.readline()
    #### END of parsing
    # quit this file if the number of dimensions is incorrect
    # NOTE(review): if the file contains no "dim" line at all, `dim` is
    # unbound here and this raises NameError -- confirm inputs always have one.
    if dim < breaks+1:
        return 0,0,0,0,0,0
    # get n, the last bin of the last group
    # revert the list of groups as the most recent times correspond
    # to the closest and last leafs of the coal. tree.
    groups = groups[::-1]
    theta_site = theta_site[::-1]
    # store thetas for later use: index -> [theta, bin list, likelihood]
    grps = groups.copy()
    thetas = {}
    for i in range(len(groups)):
        grps[i] = grps[i].split(',')
        thetas[i] = [float(theta_site[i]), grps[i], likelihood]
    # initiate the dict of times
    t = {}
    # list of thetas
    theta_L = []
    sum_t = 0
    for group_nb, group in enumerate(groups):
        ###print(group_nb, group, theta_site[group_nb], len(theta_site))
        # store all the thetas one by one, with one theta per group
        theta_L.append(float(theta_site[group_nb]))
        # if the group is of size 1
        if len(group.split(',')) == 1:
            i = int(group)
        # if the group size is >1, take the first elem of the group
        # i is the first bin of each group, straight after a breakpoint
        else:
            i = int(group.split(",")[0])
            j = int(group.split(",")[-1])
        t[i] = 0
        #t =
        if len(group.split(',')) == 1:
            k = i
            # coalescent time contribution of bin k: theta / (k*(k-1)),
            # optionally rescaled to years via tgen/mu
            if relative_theta_scale:
                t[i] += ((theta_L[group_nb] ) / (k*(k-1)))
            else:
                t[i] += ((theta_L[group_nb] ) / (k*(k-1)) * tgen) / mu
        else:
            # sum the contribution of every bin in the group, from j down to i
            for k in range(j, i-1, -1 ):
                if relative_theta_scale:
                    t[i] += ((theta_L[group_nb] ) / (k*(k-1)))
                else:
                    t[i] += ((theta_L[group_nb] ) / (k*(k-1)) * tgen) / mu
        # we add the cumulative times at the end
        t[i] += sum_t
        sum_t = t[i]
    # build the y axis (sizes); each size is appended twice to draw steps
    y = []
    for theta in theta_L:
        if relative_theta_scale:
            size = theta
        else:
            # with size N = theta/4mu
            size = theta / (4*mu)
        y.append(size)
        y.append(size)
    # build the time x axis; times are doubled (except the last) to pair
    # with the doubled sizes above
    x = [0]
    for time in range(0, len(t.values())-1):
        x.append(list(t.values())[time])
        x.append(list(t.values())[time])
    x.append(list(t.values())[len(t.values())-1])
    return x,y,likelihood,thetas,sfs,L
  111. def plot_straight_x_y(x,y):
  112. x_1 = [x[0]]
  113. y_1 = []
  114. for i in range(0, len(y)-1):
  115. x_1.append(x[i])
  116. x_1.append(x[i])
  117. y_1.append(y[i])
  118. y_1.append(y[i])
  119. y_1 = y_1+[y[-1],y[-1]]
  120. x_1.append(x[-1])
  121. return x_1, y_1
  122. def plot_all_epochs_thetafolder(full_dict, mu, tgen, title = "Title",
  123. theta_scale = True, ax = None, input = None, output = None):
  124. my_dpi = 500
  125. L = full_dict["L"]
  126. if ax is None:
  127. # intialize figure
  128. #my_dpi = 300
  129. fnt_size = 18
  130. # plt.rcParams['font.size'] = fnt_size
  131. fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  132. else:
  133. fnt_size = 12
  134. # plt.rcParams['font.size'] = fnt_size
  135. ax1 = ax[1][0,0]
  136. ax1.set_yscale('log')
  137. ax1.set_xscale('log')
  138. plot_handles = []
  139. best_plot = full_dict['all_epochs']['best']
  140. p0, = ax1.plot(best_plot[0], best_plot[1], linestyle = "-",
  141. alpha=1, lw=2, label = str(best_plot[2])+' brks | Lik='+best_plot[3])
  142. plot_handles.append(p0)
  143. #ax1.grid(True,which="both", linestyle='--', alpha = 0.3)
  144. for k, plot_Lk in enumerate(full_dict['all_epochs']['plots']):
  145. plot_Lk = str(full_dict['all_epochs']['plots'][k][3])
  146. # plt.rcParams['font.size'] = fnt_size
  147. p, = ax1.plot(full_dict['all_epochs']['plots'][k][0], full_dict['all_epochs']['plots'][k][1], linestyle = "-",
  148. alpha=1/(k+1), lw=1.5, label = str(full_dict['all_epochs']['plots'][k][2])+' brks | Lik='+plot_Lk)
  149. plot_handles.append(p)
  150. if theta_scale:
  151. ax1.set_xlabel("Coal. time", fontsize=fnt_size)
  152. ax1.set_ylabel("Pop. size scaled by N0", fontsize=fnt_size)
  153. # recent_scale_lower_bound = 0.01
  154. # recent_scale_upper_bound = 0.1
  155. # ax1.axvline(x=recent_scale_lower_bound)
  156. # ax1.axvline(x=recent_scale_upper_bound)
  157. else:
  158. # years
  159. if ax is not None:
  160. plt.set_xlabel("Time (years)", fontsize=fnt_size)
  161. plt.set_ylabel("Effective pop. size (Ne)", fontsize=fnt_size)
  162. else:
  163. plt.xlabel("Time (years)", fontsize=fnt_size)
  164. plt.ylabel("Effective pop. size (Ne)", fontsize=fnt_size)
  165. # x_ticks = ax1.get_xticks()
  166. # ax1.set_xticklabels([f'{k:.0e}\n{k/(mu):.0e}\n{k/(mu)*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
  167. # ax1.set_xticklabels([f'{k}\n{k/(mu)}\n{k/(mu)*tgen}' for k in x_ticks], fontsize = fnt_size*0.8)
  168. # plt.rcParams['font.size'] = fnt_size
  169. # print(fnt_size, "rcParam font.size=", plt.rcParams['font.size'])
  170. ax1.legend(handles = plot_handles, loc='best', fontsize = fnt_size*0.5)
  171. ax1.set_title(title)
  172. breaks = len(full_dict['all_epochs']['plots'])
  173. if ax is None:
  174. plt.savefig(title+'_best_'+str(breaks+1)+'_epochs.pdf')
  175. # plot likelihood against nb of breakpoints
  176. if ax is None:
  177. fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  178. # plt.rcParams['font.size'] = fnt_size
  179. else:
  180. #plt.rcParams['font.size'] = fnt_size
  181. ax2 = ax[0][0,1]
  182. # Retrieve the default color cycle from rcParams
  183. default_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
  184. # Create an array of colors from the default color cycle
  185. colors = [default_colors[i % len(default_colors)] for i in range(len(full_dict['Ln_Brks'][0]))]
  186. ax2.plot(full_dict['Ln_Brks'][0], full_dict['Ln_Brks'][1], "--", lw=1, color="black", zorder=1)
  187. ax2.scatter(full_dict['Ln_Brks'][0], full_dict['Ln_Brks'][1], s=50, c=colors, marker='o', zorder=2)
  188. ax2.axhline(y=full_dict['best_Ln'], linestyle = "-.", color = "red", label = "$-\log\mathcal{L}$ = "+str(round(full_dict['best_Ln'], 2)))
  189. ax2.set_yscale('log')
  190. ax2.set_xlabel("# breakpoints", fontsize=fnt_size)
  191. ax2.set_ylabel("$-\log\mathcal{L}$", fontsize=fnt_size)
  192. ax2.legend(loc='best', fontsize = fnt_size*0.5)
  193. ax2.set_title(title+" Likelihood gain from # breakpoints")
  194. if ax is None:
  195. plt.savefig(title+'_Breakpts_Likelihood.pdf')
  196. # AIC
  197. if ax is None:
  198. fig, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  199. # plt.rcParams['font.size'] = '18'
  200. else:
  201. #plt.rcParams['font.size'] = fnt_size
  202. ax3 = ax[1][0,1]
  203. AIC = full_dict['AIC_Brks']
  204. # ax3.plot(AIC[0], AIC[1], 'o', linestyle = "dotted", lw=2)
  205. ax3.plot(AIC[0], AIC[1], "--", lw=1, color="black", zorder=1)
  206. ax3.scatter(AIC[0], AIC[1], s=50, c=colors, marker='o', zorder=2)
  207. ax3.axhline(y=full_dict['best_AIC'], linestyle = "-.", color = "red",
  208. label = "Min. AIC = "+str(round(full_dict['best_AIC'], 2)))
  209. ax3.set_yscale('log')
  210. ax3.set_xlabel("# breakpoints", fontsize=fnt_size)
  211. ax3.set_ylabel("AIC")
  212. ax3.legend(loc='best', fontsize = fnt_size*0.5)
  213. ax3.set_title(title+" AIC")
  214. if ax is None:
  215. plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
  216. else:
  217. # return plots
  218. return ax[0], ax[1]
def save_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_scale = True, input = None, output = None):
    """Scan every theta file in ``folder_path``, keep the best scenario per
    number of breakpoints, compute the SFS-based best possible likelihood and
    per-breakpoint AIC, and dump everything to a JSON plot-data file.

    Parameters
    ----------
    folder_path : str
        Directory of theta files (note: joined by string concatenation, so
        it should end with a path separator).
    mu, tgen : float
        Mutation rate and generation time forwarded to the parser.
    title : str
        Used in log output and as default JSON filename prefix.
    theta_scale : bool
        Forwarded as ``relative_theta_scale``; also triggers N0 rescaling.
    input : str or None
        Existing JSON to update in place instead of creating a new dict.
    output : str or None
        JSON destination; defaults to ``title + "_plotdata.json"``.

    Returns
    -------
    dict
        The saved plot-data dictionary.
    """
    #scenari = {}
    cpt = 0
    epochs = {}
    plots = {}
    # store ['best'], and [0] for epoch 0 etc...
    for file_name in os.listdir(folder_path):
        breaks = 0
        cpt +=1
        if os.path.isfile(os.path.join(folder_path, file_name)):
            # parse with increasing breaks until the parser returns zeros
            x, y, likelihood, theta, sfs, L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
                                                                    tgen = tgen,
                                                                    mu = mu, relative_theta_scale = theta_scale)
            # NOTE(review): SFS_stored/L_stored keep the values of the last
            # file scanned; they are unbound if the folder has no file.
            SFS_stored = sfs
            L_stored = L
            while not (x == 0 and y == 0):
                if breaks not in epochs.keys():
                    epochs[breaks] = {}
                # scenarios keyed by likelihood string for this breaks count
                epochs[breaks][likelihood] = x,y
                breaks += 1
                x,y,likelihood,theta,sfs,L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
                                                                   tgen = tgen,
                                                                   mu = mu, relative_theta_scale = theta_scale)
            if x == 0:
                # last break did not work, then breaks = breaks-1
                breaks -= 1
    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
    print(cpt, "theta file(s) have been scanned.")
    brkpt_lik = []
    top_plots = {}
    best_scenario_for_epoch = {}
    for epoch, scenari in epochs.items():
        # sort starting by the smallest -log(Likelihood)
        # NOTE(review): keys are strings, so this sort is lexicographic --
        # confirm it matches numeric ordering of the likelihood values.
        best10_scenari = (sorted(list(scenari.keys())))[:10]
        greatest_likelihood = best10_scenari[0]
        # store the tuple breakpoints and likelihood for later plot
        brkpt_lik.append((epoch, greatest_likelihood))
        x, y = scenari[greatest_likelihood]
        #without breakpoint
        if epoch == 0:
            # do something with the theta without bp and skip the plotting
            N0 = y[0]
            #continue
        if theta_scale:
            for i in range(len(y)):
                # divide by N0
                y[i] = y[i]/N0
                x[i] = x[i]/N0
        top_plots[greatest_likelihood] = x,y,epoch
        best_scenario_for_epoch[epoch] = x,y,greatest_likelihood
    # rank all best-per-epoch scenarios numerically this time
    plots_likelihoods = list(top_plots.keys())
    for i in range(len(plots_likelihoods)):
        plots_likelihoods[i] = float(plots_likelihoods[i])
    best10_plots = sorted(plots_likelihoods)[:10]
    top_plot_lik = str(best10_plots[0])
    # store x,y,brks,likelihood
    plots['best'] = (top_plots[top_plot_lik][0], top_plots[top_plot_lik][1], str(top_plots[top_plot_lik][2]), top_plot_lik)
    plots['plots'] = []
    for k, epoch in enumerate(best_scenario_for_epoch.keys()):
        plot_Lk = str(best_scenario_for_epoch[epoch][2])
        x,y = best_scenario_for_epoch[epoch][0], best_scenario_for_epoch[epoch][1]
        plots['plots'].append([x, y, str(epoch), plot_Lk])
    # sort by likelihood and drop the first entry (already stored as 'best')
    plots['plots'] = sorted(plots['plots'], key=lambda x: float(x[3]))
    plots['plots'] = plots['plots'][1:]
    # Previous version. Was this correct????
    # for k, plot_Lk in enumerate(best10_plots[1:]):
    # plot_Lk = str(plot_Lk)
    # plots['plots'].append([top_plots[plot_Lk][0], top_plots[plot_Lk][1], str(top_plots[plot_Lk][2]), plot_Lk])
    # plot likelihood against nb of breakpoints
    # best possible likelihood from SFS
    # Segregating sites
    S = sum(SFS_stored)
    # Number of kept sites from which the SFS is computed
    L = L_stored
    # number of monomorphic sites
    S0 = L-S
    # print("SFS", SFS_stored)
    print("S", S, "L", L, "S0=", S0)
    my_n = len(SFS_stored)*2
    print("n=",my_n)
    # Watterson's harmonic-number denominator a_n = sum 1/i
    an = 1
    for i in range(2, my_n):
        an +=1.0/i
    print("an=", an, "theta_w", S/an, "theta_w_p_site", (S/an)/L)
    # compute Ln: multinomial log-likelihood of the observed SFS
    Ln = log_facto(S+S0) - log_facto(S0) + np.log(float(S0)/(S+S0)) * S0
    for xi in range(0, len(SFS_stored)):
        p_i = SFS_stored[xi] / float(S+S0)
        Ln += np.log(p_i) * SFS_stored[xi] - log_facto(SFS_stored[xi])
    # basic plot likelihood
    Ln_Brks = [list(np.array(brkpt_lik)[:, 0]), list(np.array(brkpt_lik)[:, 1].astype(float))]
    best_Ln = -Ln
    AIC = []
    for brk in np.array(brkpt_lik)[:, 0]:
        brk = int(brk)
        # AIC = 2k + 2*(-logL) with k = 2*brk+1 parameters
        AIC.append((2*brk+1)+2*np.array(brkpt_lik)[brk, 1].astype(float))
    AIC_Brks = [list(np.array(brkpt_lik)[:, 0]), AIC]
    # AIC = 2*k - 2ln(L) ; where k is the number of parameters, here brks+1
    AIC_ln = 2*(len(brkpt_lik)+1) - 2*Ln
    best_AIC = AIC_ln
    selected_brks_nb = AIC.index(min(AIC))
    # to return : plots ; Ln_Brks ; AIC_Brks ; best_Ln ; best_AIC
    # 'plots' dict keys: 'best', {epochs}('0', '1',...)
    if input == None:
        saved_plots = {"S":S, "S0":S0, "L":L, "mu":mu, "tgen":tgen,
                       "all_epochs":plots, "Ln_Brks":Ln_Brks,
                       "AIC_Brks":AIC_Brks, "best_Ln":best_Ln,
                       "best_AIC":best_AIC, "best_epoch_by_AIC":selected_brks_nb}
    else:
        # if the dict has to be loaded from input
        with open(input, 'r') as json_file:
            saved_plots = json.load(json_file)
        saved_plots["S"] = S
        saved_plots["S0"] = S0
        saved_plots["L"] = L
        saved_plots["mu"] = mu
        saved_plots["tgen"] = tgen
        saved_plots["all_epochs"] = plots
        saved_plots["Ln_Brks"] = Ln_Brks
        saved_plots["AIC_Brks"] = AIC_Brks
        saved_plots["best_Ln"] = best_Ln
        saved_plots["best_AIC"] = best_AIC
        saved_plots["best_epoch_by_AIC"] = selected_brks_nb
    if output == None:
        output = title+"_plotdata.json"
    with open(output, 'w') as json_file:
        json.dump(saved_plots, json_file)
    return saved_plots
def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
                 breaks_max = 10, input = None, output = None):
    """
    Save theta values as is to do basic plots.

    For each theta file and each breaks count k up to ``breaks_max``, keep
    the best-likelihood set of thetas, build raw per-bin staircases and
    time-rescaled staircases, and dump both (plus the per-bin information
    proportions) to a JSON plot-data file.

    Returns the saved plot-data dict.
    """
    cpt = 0
    epochs = {}
    len_sfs = 0  # NOTE(review): never updated or read afterwards
    for file_name in os.listdir(folder_path):
        cpt +=1
        if os.path.isfile(os.path.join(folder_path, file_name)):
            for k in range(breaks_max+1):
                x,y,likelihood,thetas,sfs,L = parse_stwp_theta_file(folder_path+file_name, breaks = k,
                                                                    tgen = tgen,
                                                                    mu = mu, relative_theta_scale = theta_scale)
                if thetas == 0:
                    # parser found no record for this k in this file
                    continue
                if len(thetas)-1 != k:
                    continue
                if k not in epochs.keys():
                    epochs[k] = {}
                # SECURITY NOTE(review): eval() on text read from a data file;
                # apparently used to normalize the numeric string -- consider
                # float() instead if the field is always a plain number.
                likelihood = str(eval(thetas[k][2]))
                epochs[k][likelihood] = thetas
                #epochs[k] = thetas
    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(k)+"\n*******\n")
    print(cpt, "theta file(s) have been scanned.")
    plots = []
    best_epochs = {}
    for epoch in epochs:
        likelihoods = []
        for key in epochs[epoch].keys():
            likelihoods.append(key)
        # NOTE(review): string sort -- lexicographic, not numeric; confirm
        # this selects the true minimum -log(L).
        likelihoods.sort()
        minLogLn = str(likelihoods[0])
        best_epochs[epoch] = epochs[epoch][minLogLn]
    # first pass: raw staircases (theta per SFS bin) per epoch
    for epoch, theta in best_epochs.items():
        groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
        x = []
        y = []
        thetas = np.array(list(theta.values()), dtype=object)[:, 0]
        for i,group in enumerate(groups):
            x += group[::-1]
            y += list(np.repeat(thetas[i], len(group)))
        if epoch == 0:
            N0 = y[0]
        # compute the proportion of information used at each bin of the SFS
        # NOTE(review): recomputed for every epoch, so the value kept after
        # the loop comes from the last epoch iterated -- confirm intended.
        sum_theta_i = 0
        for i in range(2, len(y)+2):
            sum_theta_i+=y[i-2] / (i-1)
        prop = []
        for k in range(2, len(y)+2):
            prop.append(y[k-2] / (k - 1) / sum_theta_i)
        prop = prop[::-1]
        if theta_scale :
            # normalise to N0 (N0 of epoch1)
            for i in range(len(y)):
                y[i] = y[i]/N0
        # x_plot, y_plot = plot_straight_x_y(x, y)
        p = x, y
        # add plot to the list of all plots to superimpose
        plots.append(p)
    # turn per-bin proportions into a cumulative distribution
    cumul = 0
    prop_cumul = []
    for val in prop:
        prop_cumul.append(val+cumul)
        cumul = val+cumul
    prop = prop_cumul
    # print("raw stairs", plots[3])
    # ###########
    # time = []
    # for k in plots[0][0]:
    # k = int(k)
    # dt = 2.0/(k*(k-1))
    # time.append(2.0/(k*(k-1)))
    # Ne = []
    # for values in plots:
    # Ne.append(np.array(values[1])/(4*mu))
    # print(time)
    # print(Ne[3])
    # second pass: staircases rescaled to time / effective size (fig 2)
    lines_fig2 = []
    for epoch, theta in best_epochs.items():
        groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
        x = []
        y = []
        thetas = np.array(list(theta.values()), dtype=object)[:, 0]
        for i,group in enumerate(groups):
            x += group[::-1]
            y += list(np.repeat(thetas[i], len(group)))
        if epoch == 0:
            # watterson theta
            theta_w = y[0]
        if theta_scale :
            for i in range(len(y)):
                y[i] = y[i]/N0
        # convert theta to effective size: N = theta / (4*mu)
        for i in range(len(y)):
            y[i] = y[i]/(4*mu)
        x_2 = []
        T = 0
        for i in range(len(x)):
            x[i] = int(x[i])
        # compute the times as: theta_k / (k*(k-1))
        for i in range(0, len(x)):
            T += y[i]*2 / (x[i]*(x[i]-1))
            x_2.append(T)
        # Save plotting (fig 2)
        # x_2 = [0]+x_2
        # y = [y[0]]+y
        # x2_plot, y2_plot = plot_straight_x_y(x_2, y)
        p2 = x_2, y
        lines_fig2.append(p2)
    # print("breaks=", epoch, "scaled_theta", lines_fig2[10])
    # print(lines_fig2[3][1][0]/(4*mu))
    # print(np.array(lines_fig2[3][1])/lines_fig2[3][1][0])
    # print("size list y=", len(lines_fig2[3][1]))
    #exit(0)
    if input == None:
        saved_plots = {"raw_stairs":plots, "scaled_stairs":lines_fig2,
                       "prop":prop}
    else:
        # if the dict has to be loaded from input
        with open(input, 'r') as json_file:
            saved_plots = json.load(json_file)
        saved_plots["raw_stairs"] = plots
        saved_plots["scaled_stairs"] = lines_fig2
        saved_plots["prop"] = prop
    if output == None:
        output = title+"_plotdata.json"
    with open(output, 'w') as json_file:
        json.dump(saved_plots, json_file)
    return saved_plots
def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax = None, n_ticks = 10, subset = None, theta_scale = False):
    """Plot time-rescaled staircases (linear fig. 2 and log-log fig. 3),
    optionally overlaying an external swp2 curve, and print a report on the
    population-size changes inside a recent time window.

    Parameters
    ----------
    plot_lines : list of (x, y) pairs, one per breaks count.
    prop : unused here; kept for interface symmetry with plot_raw_stairs.
    title : str used for titles and output file names.
    mu, tgen : mutation rate and generation time (scaling/labels).
    swp2_lines : optional [times, sizes] reference curve (times in years).
    ax : None for standalone figures, else a 2D grid of axes.
    n_ticks : unused here.
    subset : iterable of breaks counts to highlight; others are hidden.
    theta_scale : switch axis labels to theta units.

    Returns
    -------
    The ``ax`` argument (possibly None).
    """
    recent_limit_years = 100
    # recent limit in coal. time
    recent_limit = recent_limit_years/tgen
    # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
    nb_epochs = len(plot_lines)
    # fig 2 & 3
    if ax is None:
        my_dpi = 500
        fnt_size = 18
        fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
        fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
    else:
        fnt_size = 12
        # place of plots on the grid
        ax2 = ax[1,0]
        ax3 = ax[1,1]
    lines_fig2 = []
    lines_fig3 = []
    #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
    if swp2_lines:
        # convert swp2 times from years to generations, in place
        for k in range(len(swp2_lines[0])):
            swp2_lines[0][k] = swp2_lines[0][k]/tgen
        # NOTE(review): this loop rewrites each value with itself (no-op)
        for k in range(len(swp2_lines[1])):
            swp2_lines[1][k] = swp2_lines[1][k]
        # x2_plot, y2_plot = plot_straight_x_y(swp2_lines[0],swp2_lines[1])
        x2_plot, y2_plot = swp2_lines[0], swp2_lines[1]
        p2, = ax2.plot(x2_plot, y2_plot, linestyle="-", alpha=0.75, lw=2, label = 'swp2', color="black")
        lines_fig2.append(p2)
        # Plotting (fig 3) which is the same but log scale for x
        p3, = ax3.plot(x2_plot, y2_plot, linestyle="-", alpha=0.75, lw=2, label = 'swp2', color="black")
        lines_fig3.append(p3)
    min_x = 1
    min_y = 1
    max_x = 0
    max_y = 0
    for breaks, plot in enumerate(plot_lines):
        x,y=plot
        x2_plot, y2_plot = plot_straight_x_y(x,y)
        if subset is not None:
            if breaks in subset:
                masking_alpha = 0.75
                autoscale = True
                min_x = min(min_x, min(x2_plot))
                min_y = min(min_y, min(y2_plot))
                max_x = max(max_x, max(x2_plot))
                max_y = max(max_y, max(y2_plot))
                # skip the base 0 points x_plot[0:3]
                t_max_below_limit = 0
                t_min_below_limit = recent_limit
                recent_change = False
                for t in x[1:]:
                    if t <= recent_limit:
                        recent_change = True
                        t_max_below_limit = max(t_max_below_limit, t)
                        t_min_below_limit = min(t_min_below_limit, t)
                        Ne_max_below_limit = y[min(x.index(t_max_below_limit)+1, len(y)-1)]
                        Ne_min_below_limit = y[x.index(t_min_below_limit)]
                if recent_change:
                    print(f"\n{breaks} breaks ; This is below the recent limit of {recent_limit_years} years:\n",
                          f"t_min (most recent time point under the limit) : {t_min_below_limit/mu*tgen:.1f} t_max (most ancient time point under the limit) : {t_max_below_limit/mu*tgen:.1f}",
                          f"\nNe_min (effective size at t_min) : {Ne_min_below_limit/(4*mu):.1f} Ne_max (effective size at t_max) : {Ne_max_below_limit/(4*mu):.1f}",
                          f"\nNe_min/Ne_max = {(Ne_min_below_limit/(4*mu)) / (Ne_max_below_limit/(4*mu)):.1f}",
                          f"\nEvolution: {((Ne_min_below_limit/(4*mu)) - (Ne_max_below_limit/(4*mu)))/((Ne_max_below_limit/(4*mu)))*100:.1f}%")
                else:
                    print(f"Recent event under {recent_limit_years} years: NA")
                # need to compute the last change and when it occured
                tmin = x[1]
                tmin_plus_1 = x[2]
                Ne_min = y[1]
                Ne_min_plus_1 = y[2]
                print(f"Last was {tmin/mu*tgen:.1f} years ago. And was of {((Ne_min/(4*mu)) - (Ne_min_plus_1/(4*mu)))/(Ne_min_plus_1/(4*mu))*100:.1f}%")
            else:
                # curve not in the subset: draw it fully transparent
                masking_alpha = 0
                autoscale = False
                ax2.set_autoscale_on(autoscale)
                ax3.set_autoscale_on(autoscale)
        # NOTE(review): if subset is None, masking_alpha is unbound here and
        # this raises NameError -- confirm subset is always provided.
        p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=masking_alpha, lw=2, label = str(breaks)+' brks')
        # Plotting (fig 3) which is the same but log scale for x
        p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=masking_alpha, lw=2, label = str(breaks)+' brks')
        if subset is not None and breaks in subset:
            # store for legend
            lines_fig2.append(p2)
            lines_fig3.append(p3)
    # put the vertical line of the "recent" time limit
    ax3.axvline(x=recent_limit, linestyle="--")
    ax3.axvline(x=recent_limit/2, linestyle="--", color="green")
    if theta_scale:
        xlabel = "Theta scaled by N0"
        ylabel = "Theta scaled by N0"
    else:
        xlabel = "time"
        ylabel = "Effective pop. size (Ne)"
    if ax is None:
        # if not ax, then use the plt syntax, not ax...
        # NOTE(review): fig3 was created last, so these pyplot calls act on
        # fig3 (the "current" figure), not fig2 -- confirm intended.
        plt.xlabel(xlabel, fontsize=fnt_size)
        plt.ylabel(ylabel, fontsize=fnt_size)
        plt.gca().set_xlim(0, recent_limit * 3)
        # NOTE(review): recent_change/Ne_*_below_limit may be unbound if no
        # subset curve was processed above.
        if recent_change:
            plt.ylim(Ne_min_below_limit/3, Ne_max_below_limit *3)
        else:
            plt.ylim(y2_plot[0]/3, y2_plot[0])
        # plt.ylim(0, max(max_y+(max_y*0.05), max(swp2_lines[1])+(max(swp2_lines[1])*0.05)))
        #plt.xlim(0, recent_limit * 3)
        #xlim_val = plt.gca().get_xlim()
        x_ticks = list(plt.xticks())[0]
        # plt.xlim(min(min_x,min(swp2_lines[0])), max(max(swp2_lines[0]), max_x))
        # x_ticks = list(plt.gca().get_xticks())
        # plt.gca().set_xticks(x_ticks)
        # plt.xticks(x_ticks)
        # plt.gca().set_xlim(xlim_val)
        # plt.gca().set_xticklabels([f'{k:.0e}\n{k/(mu):.0e}\n{k/(mu)*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
        # two-row tick labels: generations and years
        plt.gca().set_xticklabels([f'{k:.1f}\n{k*tgen:.1f}' for k in x_ticks], fontsize = fnt_size*0.5)
        # rescale y to effective pop size
        # ylim_val = plt.gca().get_ylim()
        # plt.ylim(min(min_y,min(swp2_lines[1])), max(max_y+(max_y*0.05), max(swp2_lines[1])+(max(swp2_lines[1])*0.05)))
        # y_ticks = list(plt.yticks())[0]
        # plt.gca().set_yticks(y_ticks)
        # plt.gca().set_ylim(ylim_val)
        # plt.yticks(y_ticks)
        # plt.gca().set_yticklabels([f'{k/(4*mu):.0e}' for k in y_ticks], fontsize = fnt_size*0.5)
        # plt.title(title, fontsize=fnt_size)
        # plt.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
        # # plt.text(-0.13, -0.135, 'Coal. time\nGen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
        plt.text(-0.13, -0.135, 'Gen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
        plt.subplots_adjust(bottom=0.2) # Adjust the value as needed
        plt.savefig(title+'_plotB_'+str(nb_epochs)+'_epochs.pdf')
        # close fig2 to save memory
        plt.close(fig2)
    else:
        # when ax subplotting is used
        ax2.set_xlabel(xlabel, fontsize=fnt_size)
        ax2.set_ylabel(ylabel, fontsize=fnt_size)
        ax2.set_title(title, fontsize=fnt_size)
        ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
        ax3.set_xlabel(xlabel, fontsize=fnt_size)
        ax3.set_ylabel(ylabel, fontsize=fnt_size)
        ax3.set_title(title, fontsize=fnt_size)
        ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
    # fig 3 is the log-log version
    ax3.set_xscale('log')
    ax3.set_yscale('log')
    # Scale the x-axis
    # x_ticks = list(ax3.get_xticks())
    # ax3.set_xticks(x_ticks)
    # x_ticks = [i for i in range(0.1,max(max_x, max(swp2_lines[0]))), ]
    # ax3.set_xticks(x_ticks)
    # NOTE(review): raises TypeError if swp2_lines is None -- confirm this
    # function is always called with a swp2 curve.
    ax3.set_xlim(0.1, max(max_x, max(swp2_lines[0])))
    x_ticks = ax3.get_xticks()
    # ax3.set_xlim(min(min(x_ticks), min(swp2_lines[0])), max(max_x, max(swp2_lines[0])))
    # ax3.set_xlim(1, max(max_x, max(swp2_lines[0])))
    # ax3.set_xticklabels([f'{k:.0e}\n{k/(mu):.0e}\n{k/(mu)*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
    # ax3.set_xticklabels([f'{k/(mu):.0e}\n{k/(mu)*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
    ax3.set_xticklabels([f'{k:.0e}\n{k*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
    # rescale y to effective pop size
    # y_ticks = list(ax3.get_yticks())
    # ax3.set_yticks(y_ticks)
    # ax3.set_ylim(min(min(y_ticks), min(swp2_lines[1])), max(max_y+(max_y*0.5), max(swp2_lines[1])+(max(swp2_lines[1])*0.5)))
    # ax3.set_ylim(1, max(max_y, max(swp2_lines[1])))
    ax3.set_ylim(1, max(max_y+(max_y*0.5), max(swp2_lines[1])+(max(swp2_lines[1])*0.5)))
    # ax3.set_yticklabels([f'{k/(4*mu):.0e}' for k in y_ticks], fontsize = fnt_size*0.5)
    # plt.text(-0.13, -0.135, 'Coal. time\nGen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
    # plt.text(-0.13, -0.135, 'Gen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
    plt.text(-0.13, -0.085, 'Gen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
    plt.subplots_adjust(bottom=0.2) # Adjust the value as needed
    if ax is None:
        # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
        plt.savefig(title+'_plotC_'+str(nb_epochs)+'_epochs_log.pdf')
        # close fig3 to save memory
        plt.close(fig3)
    return ax
def plot_raw_stairs(plot_lines, prop, title, ax = None, n_ticks = 10, rescale = False, subset = None, max_breaks = None):
    """Plot the raw theta-per-bin staircases, one curve per breaks count,
    relabelling the x ticks with bin number plus cumulative site proportion.

    Parameters
    ----------
    plot_lines : list of (x, y) pairs, one per breaks count.
    prop : cumulative per-bin proportions; indexed as prop[bin-2].
    title : str used as title and output file prefix.
    ax : None for a standalone saved figure, else a 2D grid of axes
        (drawn into ax[0, 0]).
    n_ticks : target number of x ticks.
    rescale, subset : unused here.  NOTE(review): kept for interface
        symmetry with the other plot functions -- confirm.
    max_breaks : if set, curves with more breaks are skipped.

    Returns
    -------
    The ``ax`` argument (possibly None).
    """
    if max_breaks:
        nb_breaks = max_breaks
    else:
        nb_breaks = len(plot_lines)+1
    # multiple fig
    if ax is None:
        # intialize figure 1
        my_dpi = 500
        fnt_size = 18
        fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
        plt.subplots_adjust(bottom=0.2) # Adjust the value as needed
    else:
        fnt_size = 12
        ax1 = ax[0, 0]
        plt.subplots_adjust(wspace=0.3, hspace=0.3)
    plots = []
    for breaks, plot in enumerate(plot_lines):
        if max_breaks and breaks > max_breaks:
            # stop plotting if it exceeds the limit
            continue
        x,y = plot
        x_plot, y_plot = plot_straight_x_y(x,y)
        p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(breaks)+' brks')
        print("breaks=", breaks, "theta0", y[0])
        # add plot to the list of all plots to superimpose
        plots.append(p)
        # x_ticks keeps the bin positions of the last curve plotted
        x_ticks = x
    # print(x_ticks)
    #print(prop, "\n", sum(prop))
    #ax.legend(handles=[p0]+plots)
    ax1.set_xlabel("# bin & cumul. prop. of sites", fontsize=fnt_size)
    # Set the x-axis locator to reduce the number of ticks to 10
    ax1.set_ylabel(r'$\theta_k$', fontsize=fnt_size, rotation = 90)
    ax1.set_title(title, fontsize=fnt_size)
    ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
    ax1.set_xticks(x_ticks)
    # keep roughly n_ticks evenly spaced ticks
    step = len(x_ticks)//(n_ticks-1)
    values = x_ticks[::step]
    new_prop = []
    for val in values:
        # bins start at 2, hence the -2 offset into prop
        new_prop.append(prop[int(val)-2])
    new_prop = new_prop[::-1]
    ax1.set_xticks(values)
    # two-row labels: bin number over cumulative proportion
    ax1.set_xticklabels([f'{values[k]}\n{val:.2f}' for k, val in enumerate(new_prop)], fontsize = fnt_size*0.8)
    if ax is None:
        # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
        plt.savefig(title+'_raw_'+str(nb_breaks)+'_breaks.pdf')
        plt.close(fig)
    # return plots
    return ax
  701. def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale = False, selected_breaks = []):
  702. my_dpi = 300
  703. saved_plots_dict = save_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, output = title+"_plotdata.json")
  704. nb_of_epochs = len(saved_plots_dict["all_epochs"]["plots"])
  705. best_epoch = saved_plots_dict["best_epoch_by_AIC"]
  706. print("Best epoch based on AIC =", best_epoch)
  707. save_k_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = nb_of_epochs, input = title+"_plotdata.json", output = title+"_plotdata.json")
  708. with open(title+"_plotdata.json", 'r') as json_file:
  709. loaded_data = json.load(json_file)
  710. # START OF COMBINED PLOT CODE
  711. # # plot page 1 of summary
  712. # fig1, ax1 = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
  713. # # fig1.tight_layout()
  714. # # Adjust absolute space between the top and bottom rows
  715. # fig1.subplots_adjust(hspace=0.35) # Adjust this value based on your requirement
  716. # # plot page 2 of summary
  717. # fig2, ax2 = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
  718. # # fig2.tight_layout()
  719. # ax1 = plot_raw_stairs(plot_lines = loaded_data['raw_stairs'],
  720. # prop = loaded_data['prop'], title = title, ax = ax1)
  721. # ax1 = plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'],
  722. # prop = loaded_data['prop'], title = title, ax = ax1, subset=[loaded_data['best_epoch_by_AIC']]+selected_breaks)
  723. # ax2 = plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'],
  724. # prop = loaded_data['prop'], title = title, ax = ax2)
  725. # ax1, ax2 = plot_all_epochs_thetafolder(loaded_data, mu, tgen, title, theta_scale, ax = [ax1, ax2])
  726. # fig1.savefig(title+'_combined_p1.pdf')
  727. # print("Wrote", title+'_combined_p1.pdf')
  728. # fig2.savefig(title+'_combined_p2.pdf')
  729. # print("Wrote", title+'_combined_p2.pdf')
  730. # END OF COMBINED PLOT CODE
  731. # Start of Parsing real swp2 output
  732. folder_splitted = folder_path.split("/")
  733. swp2_summary = "/".join(folder_splitted[:-2])+'/'+folder_splitted[-3]+".final.summary"
  734. swp2_vals = parse_stairwayplot_output_summary(stwplt_out = swp2_summary)
  735. swp2_x, swp2_y = swp2_vals[0], swp2_vals[1]
  736. remove_back_and_forth_points(swp2_x, swp2_y)
  737. # End of Parsing real swp2 output
  738. plot_raw_stairs(plot_lines = loaded_data['raw_stairs'],
  739. prop = loaded_data['prop'], title = title, ax = None, max_breaks = breaks)
  740. plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'], mu = mu, tgen = tgen, subset=[loaded_data['best_epoch_by_AIC']]+selected_breaks,
  741. # plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'], subset=list(range(0,3))+[loaded_data['best_epoch_by_AIC']]+selected_breaks,
  742. prop = loaded_data['prop'], title = title, swp2_lines = [swp2_x, swp2_y], ax = None)
  743. plot_all_epochs_thetafolder(loaded_data, mu, tgen, title, theta_scale, ax = None)
  744. # plt.close(fig1)
  745. # plt.close(fig2)
  746. def remove_back_and_forth_points(x_values, y_values):
  747. # to deal with some weirdness of plotting that occur sometimes with the swp2 output
  748. # sometimes the line is going back and forth as x_k > x_(k+1), which is normally not possible
  749. i = 0
  750. while i < len(x_values) - 1:
  751. if x_values[i] >= x_values[i+1]:
  752. del x_values[i]
  753. del y_values[i]
  754. else:
  755. i += 1
  756. def parse_stairwayplot_output_summary(stwplt_out, xlim = None, ylim = None, title = "default title", plot = False):
  757. #col 5
  758. year = []
  759. # col 6
  760. ne_median = []
  761. ne_2_5 = []
  762. ne_97_5 = []
  763. ne_12_5 = []
  764. # col 10
  765. ne_87_5 = []
  766. with open(stwplt_out, "r") as stwplt_stream:
  767. for line in stwplt_stream:
  768. ## Line format
  769. # mutation_per_site n_estimation theta_per_site_median theta_per_site_2.5% theta_per_site_97.5% year Ne_median Ne_2.5% Ne_97.5% Ne_12.5% Ne_87.5%
  770. if not line.startswith("mutation_per_site"):
  771. #not header
  772. values = line.strip().split()
  773. year.append(float(values[5]))
  774. ne_median.append(float(values[6]))
  775. ne_2_5.append(float(values[7]))
  776. ne_97_5.append(float(values[8]))
  777. ne_12_5.append(float(values[9]))
  778. ne_87_5.append(float(values[10]))
  779. vals = [year, ne_median, ne_2_5, ne_97_5, ne_12_5, ne_87_5]
  780. if plot :
  781. # plot parsed data
  782. label = ["Ne median", "Ne 2.5%", "Ne 97.5%", "Ne 12.5%", "Ne 87.5%"]
  783. for i in range(1, 5):
  784. fig, = plt.plot(year, vals[i], '--', alpha = 0.4)
  785. fig.set_label(label[i])
  786. # # last plot is median
  787. fig, = plt.plot(year, ne_median, 'r-', lw=2)
  788. fig.set_label(label[0])
  789. plt.legend()
  790. plt.ylabel("Individuals (Ne)")
  791. plt.xlabel("Time (years)")
  792. if xlim:
  793. plt.xlim(xlim)
  794. if ylim:
  795. plt.ylim(ylim)
  796. plt.title(title)
  797. plt.show()
  798. plt.close()
  799. return vals
  800. if __name__ == "__main__":
  801. if len(sys.argv) != 4:
  802. print("Need 3 args: ThetaFolder MutationRate GenerationTime")
  803. exit(0)
  804. folder_path = sys.argv[1]
  805. mu = sys.argv[2]
  806. tgen = sys.argv[3]
  807. plot_all_epochs_thetafolder(folder_path, mu, tgen)