# swp2.py

import matplotlib.pyplot as plt
import os
import sys
import numpy as np
import math
import json

def log_facto(k):
    """
    Return log(k!), using Stirling's approximation for large k.
    """
    k = int(k)
    if k > 1e6:
        return k * np.log(k) - k + np.log(2*math.pi*k)/2
    val = 0
    for i in range(2, k+1):
        val += np.log(i)
    return val
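
# Minimal sanity-check sketch (not part of the original pipeline): for small k
# the exact sum of logs is used above, so log_facto(k) should match log(k!),
# i.e. math.lgamma(k+1). The helper name is hypothetical and is never called
# by this script.
def _check_log_facto(k=20):
    """Return True if log_facto(k) agrees with math.lgamma(k+1) within 1e-6."""
    return abs(log_facto(k) - math.lgamma(k + 1)) < 1e-6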

def parse_stwp_theta_file(stwp_theta_file, breaks, mu, tgen, relative_theta_scale = False):
    with open(stwp_theta_file, "r") as swp_file:
        # Read the first line
        line = swp_file.readline()
        L = float(line.split()[2])
        rands = swp_file.readline()
        line = swp_file.readline()
        # skip empty lines before the SFS
        while line == "\n":
            line = swp_file.readline()
        sfs = np.array(line.split()).astype(float)
        # Process lines until the end of the file
        while line:
            # check at each line
            if line.startswith("dim"):
                dim = int(line.split()[1])
                if dim == breaks+1:
                    likelihood = line.split()[5]
                    groups = line.split()[6:6+dim]
                    theta_site = line.split()[6+dim:6+dim+1+dim]
                elif dim < breaks+1:
                    line = swp_file.readline()
                    continue
                elif dim > breaks+1:
                    break
                    #return 0,0,0
            # Read the next line
            line = swp_file.readline()
    #### END of parsing
    # quit this file if the number of dimensions is incorrect
    if dim < breaks+1:
        return 0,0,0,0,0,0
    # get n, the last bin of the last group
    # reverse the list of groups, as the most recent times correspond
    # to the closest (last) leaves of the coalescent tree
    groups = groups[::-1]
    theta_site = theta_site[::-1]
    # store thetas for later use
    grps = groups.copy()
    thetas = {}
    for i in range(len(groups)):
        grps[i] = grps[i].split(',')
        thetas[i] = [float(theta_site[i]), grps[i], likelihood]
    # initialize the dict of times
    t = {}
    # list of thetas
    theta_L = []
    sum_t = 0
    for group_nb, group in enumerate(groups):
        ###print(group_nb, group, theta_site[group_nb], len(theta_site))
        # store all the thetas one by one, with one theta per group
        theta_L.append(float(theta_site[group_nb]))
        # if the group is of size 1
        if len(group.split(',')) == 1:
            i = int(group)
        # if the group size is >1, take the first elem of the group;
        # i is the first bin of each group, straight after a breakpoint
        else:
            i = int(group.split(",")[0])
            j = int(group.split(",")[-1])
        t[i] = 0
        #t =
        if len(group.split(',')) == 1:
            k = i
            if relative_theta_scale:
                t[i] += ((theta_L[group_nb] ) / (k*(k-1)))
            else:
                t[i] += ((theta_L[group_nb] ) / (k*(k-1)) * tgen) / mu
        else:
            for k in range(j, i-1, -1 ):
                if relative_theta_scale:
                    t[i] += ((theta_L[group_nb] ) / (k*(k-1)))
                else:
                    t[i] += ((theta_L[group_nb] ) / (k*(k-1)) * tgen) / mu
        # we add the cumulative times at the end
        t[i] += sum_t
        sum_t = t[i]
    # build the y axis (sizes)
    y = []
    for theta in theta_L:
        if relative_theta_scale:
            size = theta
        else:
            # with size N = theta/(4*mu)
            size = theta / (4*mu)
        y.append(size)
        y.append(size)
    # build the time x axis
    x = [0]
    for time in range(0, len(t.values())-1):
        x.append(list(t.values())[time])
        x.append(list(t.values())[time])
    x.append(list(t.values())[len(t.values())-1])
    return x,y,likelihood,thetas,sfs,L
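
# Usage sketch (hypothetical file name; SWP2 output format assumed, as parsed
# above): recover the stair-step trajectory for a given number of breaks and
# print it. This helper is illustrative only and is not called elsewhere.
def _example_parse_one_file(theta_file="example.theta", breaks=2,
                            mu=1e-8, tgen=1.0):
    x, y, lik, thetas, sfs, L = parse_stwp_theta_file(theta_file, breaks=breaks,
                                                      mu=mu, tgen=tgen,
                                                      relative_theta_scale=False)
    if x == 0:
        # the file did not contain a model with breaks+1 dimensions
        return None
    print("-log(L) =", lik, "| segments:", list(zip(x, y)))
    return x, y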

def plot_k_epochs_thetafolder(folder_path, mu, tgen, breaks = 2, title = "Title", theta_scale = True):
    scenari = {}
    cpt = 0
    for file_name in os.listdir(folder_path):
        if os.path.isfile(os.path.join(folder_path, file_name)):
            # Perform actions on each file
            x, y, likelihood, theta, sfs, L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
                                                                    tgen = tgen,
                                                                    mu = mu, relative_theta_scale = theta_scale)
            if x == 0 or y == 0:
                continue
            cpt += 1
            scenari[likelihood] = x,y
    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
    print(cpt, "theta file(s) have been scanned.")
    # sort starting by the smallest -log(Likelihood)
    print(scenari)
    best10_scenari = (sorted(list(scenari.keys())))[:10]
    print("10 greatest Likelihoods", best10_scenari)
    greatest_likelihood = best10_scenari[0]
    x, y = scenari[greatest_likelihood]
    my_dpi = 300
    plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
    plt.plot(x, y, 'r-', lw=2, label = 'Lik='+greatest_likelihood)
    #plt.yscale('log')
    plt.xscale('log')
    plt.grid(True, which="both", linestyle='--', alpha = 0.3)
    for scenario in best10_scenari[1:]:
        x,y = scenari[scenario]
        #print("\n---- Lik:",scenario,"\n\nt=", x,"\n\nN=",y, "\n\n")
        plt.plot(x, y, '--', lw=1, label = 'Lik='+scenario)
    if theta_scale:
        plt.xlabel("Coal. time")
        plt.ylabel("Pop. size scaled by N0")
        recent_scale_lower_bound = y[0] * 0.01
        recent_scale_upper_bound = y[0] * 0.1
        plt.axvline(x=recent_scale_lower_bound)
        plt.axvline(x=recent_scale_upper_bound)
    else:
        # years
        plt.xlabel("Time (years)")
        plt.ylabel("Individuals (N)")
    plt.legend(loc='upper right')
    plt.title(title)
    plt.savefig(title+'_b'+str(breaks)+'.pdf')

def plot_straight_x_y(x,y):
    x_1 = [x[0]]
    y_1 = []
    for i in range(0, len(y)-1):
        x_1.append(x[i])
        x_1.append(x[i])
        y_1.append(y[i])
        y_1.append(y[i])
    y_1 = y_1+[y[-1],y[-1]]
    x_1.append(x[-1])
    return x_1, y_1
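
# Illustrative sketch (hypothetical values): plot_straight_x_y duplicates the
# coordinates of a piecewise-constant trajectory so that matplotlib draws it as
# a staircase. For example, x=[0, 1, 3] and y=[5, 2, 4] become
# x_1=[0, 0, 0, 1, 1, 3] and y_1=[5, 5, 2, 2, 4, 4]. Not called by the script.
def _example_staircase():
    x_1, y_1 = plot_straight_x_y([0, 1, 3], [5, 2, 4])
    return x_1, y_1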

def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title",
                                theta_scale = True, ax = None, input = None, output = None):
    # NOTE: a second function with the same name is defined further below and
    # overrides this one at import time.
    #scenari = {}
    cpt = 0
    epochs = {}
    for file_name in os.listdir(folder_path):
        breaks = 0
        cpt += 1
        if os.path.isfile(os.path.join(folder_path, file_name)):
            x, y, likelihood, theta, sfs, L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
                                                                    tgen = tgen,
                                                                    mu = mu, relative_theta_scale = theta_scale)
            SFS_stored = sfs
            L_stored = L
            while not (x == 0 and y == 0):
                if breaks not in epochs.keys():
                    epochs[breaks] = {}
                epochs[breaks][likelihood] = x,y
                breaks += 1
                x,y,likelihood,theta,sfs,L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
                                                                   tgen = tgen,
                                                                   mu = mu, relative_theta_scale = theta_scale)
                if x == 0:
                    # the last break did not work, so revert to breaks-1
                    breaks -= 1
    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
    print(cpt, "theta file(s) have been scanned.")
    my_dpi = 300
    if ax is None:
        # initialize figure
        my_dpi = 300
        fnt_size = 18
        # plt.rcParams['font.size'] = fnt_size
        fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
    else:
        fnt_size = 12
        # plt.rcParams['font.size'] = fnt_size
        ax1 = ax[1][0,0]
    ax1.set_yscale('log')
    ax1.set_xscale('log')
    ax1.grid(True, which="both", linestyle='--', alpha = 0.3)
    brkpt_lik = []
    top_plots = {}
    for epoch, scenari in epochs.items():
        # sort starting by the smallest -log(Likelihood)
        best10_scenari = (sorted(list(scenari.keys())))[:10]
        greatest_likelihood = best10_scenari[0]
        # store the tuple (breakpoints, likelihood) for a later plot
        brkpt_lik.append((epoch, greatest_likelihood))
        x, y = scenari[greatest_likelihood]
        # without breakpoint
        if epoch == 0:
            # keep the theta without breakpoints as the reference N0
            N0 = y[0]
            #continue
        for i in range(len(y)):
            # divide by N0
            y[i] = y[i]/N0
            x[i] = x[i]/N0
        top_plots[greatest_likelihood] = x,y,epoch
    plots_likelihoods = list(top_plots.keys())
    for i in range(len(plots_likelihoods)):
        plots_likelihoods[i] = float(plots_likelihoods[i])
    best10_plots = sorted(plots_likelihoods)[:10]
    top_plot_lik = str(best10_plots[0])
    plot_handles = []
    # plt.rcParams['font.size'] = fnt_size
    p0, = ax1.plot(top_plots[top_plot_lik][0], top_plots[top_plot_lik][1], 'o', linestyle = "-",
                   alpha=1, lw=2, label = str(top_plots[top_plot_lik][2])+' brks | Lik='+top_plot_lik)
    plot_handles.append(p0)
    for k, plot_Lk in enumerate(best10_plots[1:]):
        plot_Lk = str(plot_Lk)
        # plt.rcParams['font.size'] = fnt_size
        p, = ax1.plot(top_plots[plot_Lk][0], top_plots[plot_Lk][1], 'o', linestyle = "--",
                      alpha=1/(k+1), lw=1.5, label = str(top_plots[plot_Lk][2])+' brks | Lik='+plot_Lk)
        plot_handles.append(p)
    if theta_scale:
        ax1.set_xlabel("Coal. time", fontsize=fnt_size)
        ax1.set_ylabel("Pop. size scaled by N0", fontsize=fnt_size)
        # recent_scale_lower_bound = 0.01
        # recent_scale_upper_bound = 0.1
        # ax1.axvline(x=recent_scale_lower_bound)
        # ax1.axvline(x=recent_scale_upper_bound)
    else:
        # years
        ax1.set_xlabel("Time (years)", fontsize=fnt_size)
        ax1.set_ylabel("Individuals (N)", fontsize=fnt_size)
    # plt.rcParams['font.size'] = fnt_size
    # print(fnt_size, "rcParam font.size=", plt.rcParams['font.size'])
    ax1.legend(handles = plot_handles, loc='best', fontsize = fnt_size*0.5)
    ax1.set_title(title)
    if ax is None:
        plt.savefig(title+'_b'+str(breaks)+'.pdf')
    # plot likelihood against nb of breakpoints
    # best possible likelihood from SFS
    # Segregating sites
    S = sum(SFS_stored)
    # Number of kept sites from which the SFS is computed
    L = L_stored
    # number of monomorphic sites
    S0 = L-S
    # print("SFS", SFS_stored)
    # print("S", S, "L", L, "S0=", S0)
    # compute Ln
    Ln = log_facto(S+S0) - log_facto(S0) + np.log(float(S0)/(S+S0)) * S0
    for xi in range(0, len(SFS_stored)):
        p_i = SFS_stored[xi] / float(S+S0)
        Ln += np.log(p_i) * SFS_stored[xi] - log_facto(SFS_stored[xi])
    # basic plot likelihood
    if ax is None:
        fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
        # plt.rcParams['font.size'] = fnt_size
    else:
        #plt.rcParams['font.size'] = fnt_size
        ax2 = ax[0][0,1]
    ax2.plot(np.array(brkpt_lik)[:, 0], np.array(brkpt_lik)[:, 1].astype(float), 'o', linestyle = "dotted", lw=2)
    ax2.axhline(y=-Ln, linestyle = "-.", color = "red", label = r"$-\log\mathcal{L}$ = "+str(round(-Ln, 2)))
    ax2.set_yscale('log')
    ax2.set_xlabel("# breakpoints", fontsize=fnt_size)
    ax2.set_ylabel(r"$-\log\mathcal{L}$", fontsize=fnt_size)
    ax2.legend(loc='best', fontsize = fnt_size*0.5)
    ax2.set_title(title+" Likelihood gain from # breakpoints")
    if ax is None:
        plt.savefig(title+'_Breakpts_Likelihood.pdf')
    # AIC
    if ax is None:
        fig, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
        # plt.rcParams['font.size'] = '18'
    else:
        #plt.rcParams['font.size'] = fnt_size
        ax3 = ax[1][0,1]
    AIC = []
    for brk in np.array(brkpt_lik)[:, 0]:
        brk = int(brk)
        AIC.append((2*brk+1)+2*np.array(brkpt_lik)[brk, 1].astype(float))
    ax3.plot(np.array(brkpt_lik)[:, 0], AIC, 'o', linestyle = "dotted", lw=2)
    # AIC = 2*k - 2ln(L) ; where k is the number of parameters, here brks+1
    AIC_ln = 2*(len(brkpt_lik)+1) - 2*Ln
    ax3.axhline(y=AIC_ln, linestyle = "-.", color = "red",
                label = "Min. AIC = "+str(round(AIC_ln, 2)))
    selected_brks_nb = AIC.index(min(AIC))
    ax3.set_yscale('log')
    ax3.set_xlabel("# breakpoints", fontsize=fnt_size)
    ax3.set_ylabel("AIC")
    ax3.legend(loc='best', fontsize = fnt_size*0.5)
    ax3.set_title(title+" AIC")
    if ax is None:
        plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
    print("S", S)
    # return plots
    if ax is None:
        return None
    return ax[0], ax[1]

def save_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_scale = True, input = None, output = None):
    #scenari = {}
    cpt = 0
    epochs = {}
    plots = {}
    # store ['best'], and [0] for epoch 0 etc...
    for file_name in os.listdir(folder_path):
        breaks = 0
        cpt += 1
        if os.path.isfile(os.path.join(folder_path, file_name)):
            x, y, likelihood, theta, sfs, L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
                                                                    tgen = tgen,
                                                                    mu = mu, relative_theta_scale = theta_scale)
            SFS_stored = sfs
            L_stored = L
            while not (x == 0 and y == 0):
                if breaks not in epochs.keys():
                    epochs[breaks] = {}
                epochs[breaks][likelihood] = x,y
                breaks += 1
                x,y,likelihood,theta,sfs,L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
                                                                   tgen = tgen,
                                                                   mu = mu, relative_theta_scale = theta_scale)
                if x == 0:
                    # the last break did not work, so revert to breaks-1
                    breaks -= 1
    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
    print(cpt, "theta file(s) have been scanned.")
    brkpt_lik = []
    top_plots = {}
    for epoch, scenari in epochs.items():
        # sort starting by the smallest -log(Likelihood)
        best10_scenari = (sorted(list(scenari.keys())))[:10]
        greatest_likelihood = best10_scenari[0]
        # store the tuple (breakpoints, likelihood) for a later plot
        brkpt_lik.append((epoch, greatest_likelihood))
        x, y = scenari[greatest_likelihood]
        # without breakpoint
        if epoch == 0:
            # keep the theta without breakpoints as the reference N0
            N0 = y[0]
            #continue
        for i in range(len(y)):
            # divide by N0
            y[i] = y[i]/N0
            x[i] = x[i]/N0
        top_plots[greatest_likelihood] = x,y,epoch
    plots_likelihoods = list(top_plots.keys())
    for i in range(len(plots_likelihoods)):
        plots_likelihoods[i] = float(plots_likelihoods[i])
    best10_plots = sorted(plots_likelihoods)[:10]
    top_plot_lik = str(best10_plots[0])
    # store x,y,brks,likelihood
    plots['best'] = (top_plots[top_plot_lik][0], top_plots[top_plot_lik][1], str(top_plots[top_plot_lik][2]), top_plot_lik)
    for k, plot_Lk in enumerate(best10_plots[1:]):
        plot_Lk = str(plot_Lk)
        plots[str(top_plots[plot_Lk][2])] = (top_plots[plot_Lk][0], top_plots[plot_Lk][1], str(top_plots[plot_Lk][2]), plot_Lk)
    # plot likelihood against nb of breakpoints
    # best possible likelihood from SFS
    # Segregating sites
    S = sum(SFS_stored)
    # Number of kept sites from which the SFS is computed
    L = L_stored
    # number of monomorphic sites
    S0 = L-S
    # print("SFS", SFS_stored)
    # print("S", S, "L", L, "S0=", S0)
    # compute Ln
    Ln = log_facto(S+S0) - log_facto(S0) + np.log(float(S0)/(S+S0)) * S0
    for xi in range(0, len(SFS_stored)):
        p_i = SFS_stored[xi] / float(S+S0)
        Ln += np.log(p_i) * SFS_stored[xi] - log_facto(SFS_stored[xi])
    # basic plot likelihood
    Ln_Brks = [list(np.array(brkpt_lik)[:, 0]), list(np.array(brkpt_lik)[:, 1].astype(float))]
    best_Ln = -Ln
    AIC = []
    for brk in np.array(brkpt_lik)[:, 0]:
        brk = int(brk)
        AIC.append((2*brk+1)+2*np.array(brkpt_lik)[brk, 1].astype(float))
    AIC_Brks = [list(np.array(brkpt_lik)[:, 0]), AIC]
    # AIC = 2*k - 2ln(L) ; where k is the number of parameters, here brks+1
    AIC_ln = 2*(len(brkpt_lik)+1) - 2*Ln
    best_AIC = AIC_ln
    # to return : plots ; Ln_Brks ; AIC_Brks ; best_Ln ; best_AIC
    # 'plots' dict keys: 'best', {epochs}('0', '1',...)
    if input is None:
        saved_plots = {"all_epochs":plots, "Ln_Brks":Ln_Brks,
                       "AIC_Brks":AIC_Brks, "best_Ln":best_Ln,
                       "best_AIC":best_AIC}
    else:
        # if the dict has to be loaded from input
        with open(input, 'r') as json_file:
            saved_plots = json.load(json_file)
        saved_plots["all_epochs"] = plots
        saved_plots["Ln_Brks"] = Ln_Brks
        saved_plots["AIC_Brks"] = AIC_Brks
        saved_plots["best_Ln"] = best_Ln
        saved_plots["best_AIC"] = best_AIC
    if output is None:
        output = title+"_plotdata.json"
    with open(output, 'w') as json_file:
        json.dump(saved_plots, json_file)
    return saved_plots
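
# Worked sketch (hypothetical numbers, not part of the pipeline) of the
# saturation likelihood and AIC used above: given an SFS, the number of
# analysed sites L and S0 = L - S monomorphic sites, Ln is the multinomial
# log-likelihood of the observed SFS, and AIC = 2*k - 2*ln(L) with k
# parameters (here, number of breakpoints + 1).
def _example_sfs_lnL_and_AIC(sfs=(120.0, 60.0, 40.0), L=10000.0, n_breaks=2):
    S = sum(sfs)
    S0 = L - S
    Ln = log_facto(S + S0) - log_facto(S0) + np.log(S0 / (S + S0)) * S0
    for count in sfs:
        p_i = count / (S + S0)
        Ln += np.log(p_i) * count - log_facto(count)
    aic = 2 * (n_breaks + 1) - 2 * Ln
    return Ln, aic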

def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_scale = True, ax = None):
    # NOTE: this redefinition overrides the earlier plot_all_epochs_thetafolder
    # above; Python keeps the last definition of the name.
    #scenari = {}
    cpt = 0
    epochs = {}
    for file_name in os.listdir(folder_path):
        breaks = 0
        cpt += 1
        if os.path.isfile(os.path.join(folder_path, file_name)):
            x, y, likelihood, theta, sfs, L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
                                                                    tgen = tgen,
                                                                    mu = mu, relative_theta_scale = theta_scale)
            SFS_stored = sfs
            L_stored = L
            while not (x == 0 and y == 0):
                if breaks not in epochs.keys():
                    epochs[breaks] = {}
                epochs[breaks][likelihood] = x,y
                breaks += 1
                x,y,likelihood,theta,sfs,L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
                                                                   tgen = tgen,
                                                                   mu = mu, relative_theta_scale = theta_scale)
                if x == 0:
                    # the last break did not work, so revert to breaks-1
                    breaks -= 1
    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
    print(cpt, "theta file(s) have been scanned.")
    my_dpi = 300
    if ax is None:
        # initialize figure
        my_dpi = 300
        fnt_size = 18
        # plt.rcParams['font.size'] = fnt_size
        fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
    else:
        fnt_size = 12
        # plt.rcParams['font.size'] = fnt_size
        ax1 = ax[1][0,0]
    ax1.set_yscale('log')
    ax1.set_xscale('log')
    ax1.grid(True, which="both", linestyle='--', alpha = 0.3)
    brkpt_lik = []
    top_plots = {}
    for epoch, scenari in epochs.items():
        # sort starting by the smallest -log(Likelihood)
        best10_scenari = (sorted(list(scenari.keys())))[:10]
        greatest_likelihood = best10_scenari[0]
        # store the tuple (breakpoints, likelihood) for a later plot
        brkpt_lik.append((epoch, greatest_likelihood))
        x, y = scenari[greatest_likelihood]
        # without breakpoint
        if epoch == 0:
            # keep the theta without breakpoints as the reference N0
            N0 = y[0]
            #continue
        for i in range(len(y)):
            # divide by N0
            y[i] = y[i]/N0
            x[i] = x[i]/N0
        top_plots[greatest_likelihood] = x,y,epoch
    plots_likelihoods = list(top_plots.keys())
    for i in range(len(plots_likelihoods)):
        plots_likelihoods[i] = float(plots_likelihoods[i])
    best10_plots = sorted(plots_likelihoods)[:10]
    top_plot_lik = str(best10_plots[0])
    plot_handles = []
    # plt.rcParams['font.size'] = fnt_size
    p0, = ax1.plot(top_plots[top_plot_lik][0], top_plots[top_plot_lik][1], 'o', linestyle = "-",
                   alpha=1, lw=2, label = str(top_plots[top_plot_lik][2])+' brks | Lik='+top_plot_lik)
    plot_handles.append(p0)
    for k, plot_Lk in enumerate(best10_plots[1:]):
        plot_Lk = str(plot_Lk)
        # plt.rcParams['font.size'] = fnt_size
        p, = ax1.plot(top_plots[plot_Lk][0], top_plots[plot_Lk][1], 'o', linestyle = "--",
                      alpha=1/(k+1), lw=1.5, label = str(top_plots[plot_Lk][2])+' brks | Lik='+plot_Lk)
        plot_handles.append(p)
    if theta_scale:
        ax1.set_xlabel("Coal. time", fontsize=fnt_size)
        ax1.set_ylabel("Pop. size scaled by N0", fontsize=fnt_size)
        # recent_scale_lower_bound = 0.01
        # recent_scale_upper_bound = 0.1
        # ax1.axvline(x=recent_scale_lower_bound)
        # ax1.axvline(x=recent_scale_upper_bound)
    else:
        # years
        ax1.set_xlabel("Time (years)", fontsize=fnt_size)
        ax1.set_ylabel("Individuals (N)", fontsize=fnt_size)
    # plt.rcParams['font.size'] = fnt_size
    # print(fnt_size, "rcParam font.size=", plt.rcParams['font.size'])
    ax1.legend(handles = plot_handles, loc='best', fontsize = fnt_size*0.5)
    ax1.set_title(title)
    if ax is None:
        plt.savefig(title+'_b'+str(breaks)+'.pdf')
    # plot likelihood against nb of breakpoints
    # best possible likelihood from SFS
    # Segregating sites
    S = sum(SFS_stored)
    # Number of kept sites from which the SFS is computed
    L = L_stored
    # number of monomorphic sites
    S0 = L-S
    # print("SFS", SFS_stored)
    # print("S", S, "L", L, "S0=", S0)
    # compute Ln
    Ln = log_facto(S+S0) - log_facto(S0) + np.log(float(S0)/(S+S0)) * S0
    for xi in range(0, len(SFS_stored)):
        p_i = SFS_stored[xi] / float(S+S0)
        Ln += np.log(p_i) * SFS_stored[xi] - log_facto(SFS_stored[xi])
    # basic plot likelihood
    if ax is None:
        fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
        # plt.rcParams['font.size'] = fnt_size
    else:
        #plt.rcParams['font.size'] = fnt_size
        ax2 = ax[0][0,1]
    ax2.plot(np.array(brkpt_lik)[:, 0], np.array(brkpt_lik)[:, 1].astype(float), 'o', linestyle = "dotted", lw=2)
    ax2.axhline(y=-Ln, linestyle = "-.", color = "red", label = r"$-\log\mathcal{L}$ = "+str(round(-Ln, 2)))
    ax2.set_yscale('log')
    ax2.set_xlabel("# breakpoints", fontsize=fnt_size)
    ax2.set_ylabel(r"$-\log\mathcal{L}$", fontsize=fnt_size)
    ax2.legend(loc='best', fontsize = fnt_size*0.5)
    ax2.set_title(title+" Likelihood gain from # breakpoints")
    if ax is None:
        plt.savefig(title+'_Breakpts_Likelihood.pdf')
    # AIC
    if ax is None:
        fig, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
        # plt.rcParams['font.size'] = '18'
    else:
        #plt.rcParams['font.size'] = fnt_size
        ax3 = ax[1][0,1]
    AIC = []
    for brk in np.array(brkpt_lik)[:, 0]:
        brk = int(brk)
        AIC.append((2*brk+1)+2*np.array(brkpt_lik)[brk, 1].astype(float))
    ax3.plot(np.array(brkpt_lik)[:, 0], AIC, 'o', linestyle = "dotted", lw=2)
    # AIC = 2*k - 2ln(L) ; where k is the number of parameters, here brks+1
    AIC_ln = 2*(len(brkpt_lik)+1) - 2*Ln
    ax3.axhline(y=AIC_ln, linestyle = "-.", color = "red",
                label = "Min. AIC = "+str(round(AIC_ln, 2)))
    selected_brks_nb = AIC.index(min(AIC))
    ax3.set_yscale('log')
    ax3.set_xlabel("# breakpoints", fontsize=fnt_size)
    ax3.set_ylabel("AIC")
    ax3.legend(loc='best', fontsize = fnt_size*0.5)
    ax3.set_title(title+" AIC")
    if ax is None:
        plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
    print("S", S)
    # return plots
    if ax is None:
        return None
    return ax[0], ax[1]

def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
                 breaks_max = 10, input = None, output = None):
    """
    Save theta values as is to do basic plots.
    """
    cpt = 0
    epochs = {}
    len_sfs = 0
    for file_name in os.listdir(folder_path):
        cpt += 1
        if os.path.isfile(os.path.join(folder_path, file_name)):
            for k in range(breaks_max):
                x,y,likelihood,thetas,sfs,L = parse_stwp_theta_file(folder_path+file_name, breaks = k,
                                                                    tgen = tgen,
                                                                    mu = mu, relative_theta_scale = theta_scale)
                if thetas == 0:
                    continue
                if len(thetas)-1 != k:
                    continue
                if k not in epochs.keys():
                    epochs[k] = {}
                likelihood = str(eval(thetas[k][2]))
                epochs[k][likelihood] = thetas
                #epochs[k] = thetas
    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(k)+"\n*******\n")
    print(cpt, "theta file(s) have been scanned.")
    plots = []
    best_epochs = {}
    for epoch in epochs:
        likelihoods = []
        for key in epochs[epoch].keys():
            likelihoods.append(key)
        likelihoods.sort()
        minLogLn = str(likelihoods[0])
        best_epochs[epoch] = epochs[epoch][minLogLn]
    for epoch, theta in best_epochs.items():
        groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
        x = []
        y = []
        thetas = np.array(list(theta.values()), dtype=object)[:, 0]
        for i,group in enumerate(groups):
            x += group[::-1]
            y += list(np.repeat(thetas[i], len(group)))
        if epoch == 0:
            N0 = y[0]
            # compute the proportion of information used at each bin of the SFS
            sum_theta_i = 0
            for i in range(2, len(y)+2):
                sum_theta_i += y[i-2] / (i-1)
            prop = []
            for k in range(2, len(y)+2):
                prop.append(y[k-2] / (k - 1) / sum_theta_i)
            prop = prop[::-1]
        # normalise to N0 (N0 of epoch 0)
        for i in range(len(y)):
            y[i] = y[i]/N0
        # x_plot, y_plot = plot_straight_x_y(x, y)
        p = x, y
        # add plot to the list of all plots to superimpose
        plots.append(p)
    cumul = 0
    prop_cumul = []
    for val in prop:
        prop_cumul.append(val+cumul)
        cumul = val+cumul
    prop = prop_cumul
    lines_fig2 = []
    for epoch, theta in best_epochs.items():
        groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
        x = []
        y = []
        thetas = np.array(list(theta.values()), dtype=object)[:, 0]
        for i,group in enumerate(groups):
            x += group[::-1]
            y += list(np.repeat(thetas[i], len(group)))
        if epoch == 0:
            N0 = y[0]
        for i in range(len(y)):
            y[i] = y[i]/N0
        x_2 = []
        T = 0
        for i in range(len(x)):
            x[i] = int(x[i])
        # compute the times as: theta_k / (k*(k-1))
        for i in range(0, len(x)):
            T += y[i] / (x[i]*(x[i]-1))
            x_2.append(T)
        # Save plotting (fig 2)
        x_2 = [0]+x_2
        y = [y[0]]+y
        # x2_plot, y2_plot = plot_straight_x_y(x_2, y)
        p2 = x_2, y
        lines_fig2.append(p2)
    if input is None:
        saved_plots = {"raw_stairs":plots, "scaled_stairs":lines_fig2,
                       "prop":prop}
    else:
        # if the dict has to be loaded from input
        with open(input, 'r') as json_file:
            saved_plots = json.load(json_file)
        saved_plots["raw_stairs"] = plots
        saved_plots["scaled_stairs"] = lines_fig2
        saved_plots["prop"] = prop
    if output is None:
        output = title+"_plotdata.json"
    with open(output, 'w') as json_file:
        json.dump(saved_plots, json_file)
    return saved_plots
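
# Sketch of the rescaling used above (hypothetical values, not called by the
# script): each SFS bin k contributes a coalescent-time increment of
# theta_k / (k*(k-1)), and the cumulative sum gives the x axis of the scaled
# stairs saved as "scaled_stairs".
def _example_theta_to_times(bins=(5, 4, 3, 2), thetas=(1.0, 1.0, 0.5, 0.5)):
    T = 0
    times = []
    for k, theta_k in zip(bins, thetas):
        T += theta_k / (k * (k - 1))
        times.append(T)
    return times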

def plot_scaled_theta(plot_lines, prop, title, ax = None, n_ticks = 10):
    # fig 2 & 3
    if ax is None:
        my_dpi = 300
        fnt_size = 18
        fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
        fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
    else:
        # plt.rcParams['font.size'] = fnt_size
        fnt_size = 12
        # place of plots on the grid
        ax2 = ax[1,0]
        ax3 = ax[1,1]
    lines_fig2 = []
    lines_fig3 = []
    #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
    for epoch, plot in enumerate(plot_lines):
        x,y = plot
        x2_plot, y2_plot = plot_straight_x_y(x,y)
        p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
        lines_fig2.append(p2)
        # Plotting (fig 3), which is the same but with a log scale for x
        p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
        lines_fig3.append(p3)
    ax2.set_xlabel("Relative scale", fontsize=fnt_size)
    ax2.set_ylabel("theta", fontsize=fnt_size)
    ax2.set_title(title, fontsize=fnt_size)
    ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
    if ax is None:
        # the number of plot_lines is the number of epochs stored (len(plot_lines) = #breaks+1)
        fig2.savefig(title+'_plot2_'+str(len(plot_lines))+'.pdf')
        # close fig2 to save memory
        plt.close(fig2)
    ax3.set_xscale('log')
    ax3.set_yscale('log')
    ax3.set_xlabel("log Relative scale", fontsize=fnt_size)
    ax3.set_ylabel("theta", fontsize=fnt_size)
    ax3.set_title(title, fontsize=fnt_size)
    ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
    if ax is None:
        # the number of plot_lines is the number of epochs stored (len(plot_lines) = #breaks+1)
        fig3.savefig(title+'_plot3_'+str(len(plot_lines))+'_log.pdf')
        # close fig3 to save memory
        plt.close(fig3)
    return ax

def plot_raw_stairs(plot_lines, prop, title, ax = None, n_ticks = 10):
    # multiple fig
    if ax is None:
        # initialize figure 1
        my_dpi = 300
        fnt_size = 18
        # plt.rcParams['font.size'] = fnt_size
        fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
    else:
        fnt_size = 12
        # plt.rcParams['font.size'] = fnt_size
        ax1 = ax[0, 0]
        plt.subplots_adjust(wspace=0.3, hspace=0.3)
    plots = []
    for epoch, plot in enumerate(plot_lines):
        x,y = plot
        x_plot, y_plot = plot_straight_x_y(x,y)
        p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
        # add plot to the list of all plots to superimpose
        plots.append(p)
        x_ticks = x
    # print(x_ticks)
    #print(prop, "\n", sum(prop))
    #ax.legend(handles=[p0]+plots)
    ax1.set_xlabel("# bin & cumul. prop. of sites", fontsize=fnt_size)
    # Set the x-axis locator to reduce the number of ticks to 10
    ax1.set_ylabel("theta", fontsize=fnt_size)
    ax1.set_title(title, fontsize=fnt_size)
    ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
    ax1.set_xticks(x_ticks)
    step = len(x_ticks)//(n_ticks-1)
    values = x_ticks[::step]
    new_prop = []
    for val in values:
        new_prop.append(prop[int(val)-2])
    new_prop = new_prop[::-1]
    ax1.set_xticks(values)
    ax1.set_xticklabels([f'{values[k]}\n{val:.2f}' for k, val in enumerate(new_prop)], fontsize = fnt_size*0.8)
    if ax is None:
        # the number of plot_lines is the number of epochs stored (len(plot_lines) = #breaks+1)
        plt.savefig(title+'_raw'+str(len(plot_lines))+'.pdf')
        plt.close(fig)
    # return plots
    return ax
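
# Usage sketch (hypothetical JSON file name): plot_raw_stairs and
# plot_scaled_theta consume the dictionary written by save_k_theta, as done in
# combined_plot below. This helper is illustrative only and is never called.
def _example_plot_from_json(json_path="Title_plotdata.json"):
    with open(json_path, 'r') as json_file:
        data = json.load(json_file)
    plot_raw_stairs(plot_lines=data['raw_stairs'], prop=data['prop'],
                    title="example", ax=None)
    plot_scaled_theta(plot_lines=data['scaled_stairs'], prop=data['prop'],
                      title="example", ax=None)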

def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True, breaks_max = 10, ax = None, n_ticks = 10):
    """
    Use theta values as is to do basic plots.
    """
    cpt = 0
    epochs = {}
    len_sfs = 0
    for file_name in os.listdir(folder_path):
        cpt += 1
        if os.path.isfile(os.path.join(folder_path, file_name)):
            for k in range(breaks_max):
                x, y, likelihood, thetas, sfs, L = parse_stwp_theta_file(folder_path+file_name, breaks = k,
                                                                         tgen = tgen,
                                                                         mu = mu, relative_theta_scale = theta_scale)
                if thetas == 0:
                    continue
                if len(thetas)-1 != k:
                    continue
                if k not in epochs.keys():
                    epochs[k] = {}
                likelihood = str(eval(thetas[k][2]))
                epochs[k][likelihood] = thetas
                #epochs[k] = thetas
    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(k)+"\n*******\n")
    print(cpt, "theta file(s) have been scanned.")
    # multiple fig
    if ax is None:
        # initialize figure 1
        my_dpi = 300
        fnt_size = 18
        # plt.rcParams['font.size'] = fnt_size
        fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
    else:
        fnt_size = 12
        # plt.rcParams['font.size'] = fnt_size
        ax1 = ax[0, 1]
        plt.subplots_adjust(wspace=0.3, hspace=0.3)
    plots = []
    best_epochs = {}
    for epoch in epochs:
        likelihoods = []
        for key in epochs[epoch].keys():
            likelihoods.append(key)
        likelihoods.sort()
        minLogLn = str(likelihoods[0])
        best_epochs[epoch] = epochs[epoch][minLogLn]
    for epoch, theta in best_epochs.items():
        groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
        x = []
        y = []
        thetas = np.array(list(theta.values()), dtype=object)[:, 0]
        for i,group in enumerate(groups):
            x += group[::-1]
            y += list(np.repeat(thetas[i], len(group)))
        if epoch == 0:
            N0 = y[0]
            # compute the proportion of information used at each bin of the SFS
            sum_theta_i = 0
            for i in range(2, len(y)+2):
                sum_theta_i += y[i-2] / (i-1)
            prop = []
            for k in range(2, len(y)+2):
                prop.append(y[k-2] / (k - 1) / sum_theta_i)
            prop = prop[::-1]
            # print(prop, "\n", sum(prop))
        # normalise to N0 (N0 of epoch 0)
        x_ticks = ax1.get_xticks()
        for i in range(len(y)):
            y[i] = y[i]/N0
        # plot
        x_plot, y_plot = plot_straight_x_y(x, y)
        #plt.plot(x, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
        p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
        # add plot to the list of all plots to superimpose
        plots.append(p)
    #print(prop, "\n", sum(prop))
    #ax.legend(handles=[p0]+plots)
    ax1.set_xlabel("# bin", fontsize=fnt_size)
    # Set the x-axis locator to reduce the number of ticks to 10
    ax1.set_ylabel("theta", fontsize=fnt_size)
    ax1.set_title(title, fontsize=fnt_size)
    ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
    ax1.set_xticks(x_ticks)
    if len(prop) >= 18:
        ax1.locator_params(nbins=n_ticks)
        # new scale of ticks if too many values
        cumul = 0
        prop_cumul = []
        for val in prop:
            prop_cumul.append(val+cumul)
            cumul = val+cumul
        ax1.set_xticklabels([f'{x[k]}\n{val:.2f}' for k, val in enumerate(prop_cumul)])
    if ax is None:
        plt.savefig(title+'_raw'+str(k)+'.pdf')
    # fig 2 & 3
    if ax is None:
        fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
        fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
    else:
        # plt.rcParams['font.size'] = fnt_size
        # place of plots on the grid
        ax2 = ax[1,0]
        ax3 = ax[1,1]
    lines_fig2 = []
    lines_fig3 = []
    #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
    for epoch, theta in best_epochs.items():
        groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
        x = []
        y = []
        thetas = np.array(list(theta.values()), dtype=object)[:, 0]
        for i,group in enumerate(groups):
            x += group[::-1]
            y += list(np.repeat(thetas[i], len(group)))
        if epoch == 0:
            N0 = y[0]
        for i in range(len(y)):
            y[i] = y[i]/N0
        x_2 = []
        T = 0
        for i in range(len(x)):
            x[i] = int(x[i])
        # compute the times as: theta_k / (k*(k-1))
        for i in range(0, len(x)):
            T += y[i] / (x[i]*(x[i]-1))
            x_2.append(T)
        # Plotting (fig 2)
        x_2 = [0]+x_2
        y = [y[0]]+y
        x2_plot, y2_plot = plot_straight_x_y(x_2, y)
        p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
        lines_fig2.append(p2)
        # Plotting (fig 3), which is the same but with a log scale for x
        p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
        lines_fig3.append(p3)
    ax2.set_xlabel("Relative scale", fontsize=fnt_size)
    ax2.set_ylabel("theta", fontsize=fnt_size)
    ax2.set_title(title, fontsize=fnt_size)
    ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
    if ax is None:
        fig2.savefig(title+'_plot2_'+str(k)+'.pdf')
    ax3.set_xscale('log')
    ax3.set_yscale('log')
    ax3.set_xlabel("log Relative scale", fontsize=fnt_size)
    ax3.set_ylabel("theta", fontsize=fnt_size)
    ax3.set_title(title, fontsize=fnt_size)
    ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
    if ax is None:
        fig3.savefig(title+'_plot3_'+str(k)+'_log.pdf')
        plt.clf()
    # return plots
    return ax

def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale = True):
    my_dpi = 300
    # # Add some extra space for the second axis at the bottom
    # #plt.rcParams['font.size'] = 18
    # fig, axs = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
    # #plt.rcParams['font.size'] = 12
    # ax = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = axs)
    # ax = plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = axs)
    # # Adjust layout to prevent clipping of titles
    #
    # # Save the entire grid as a single figure
    # plt.savefig(title+'_combined.pdf')
    # plt.clf()
    # # # second call for individual plots
    # # plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = None)
    # # plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = None)
    # # plt.clf()
    save_k_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, output = title+"_plotdata.json")
    with open(title+"_plotdata.json", 'r') as json_file:
        loaded_data = json.load(json_file)
    # plot page 1 of summary
    fig1, ax1 = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
    # fig1.tight_layout()
    # Adjust absolute space between the top and bottom rows
    fig1.subplots_adjust(hspace=0.35)  # Adjust this value based on your requirement
    # plot page 2 of summary
    fig2, ax2 = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
    # fig2.tight_layout()
    ax1 = plot_raw_stairs(plot_lines = loaded_data['raw_stairs'],
                          prop = loaded_data['prop'], title = title, ax = ax1)
    ax1 = plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'],
                            prop = loaded_data['prop'], title = title, ax = ax1)
    ax1, ax2 = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = [ax1, ax2])
    save_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, input = title+"_plotdata.json", output = title+"_plotdata.json")
    fig1.savefig(title+'_combined_p1.pdf')
    fig2.savefig(title+'_combined_p2.pdf')
    plot_raw_stairs(plot_lines = loaded_data['raw_stairs'],
                    prop = loaded_data['prop'], title = title, ax = None)
    plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'],
                      prop = loaded_data['prop'], title = title, ax = None)
    plt.close(fig1)
    plt.close(fig2)

if __name__ == "__main__":
    if len(sys.argv) != 4:
        print("Need 3 args: ThetaFolder MutationRate GenerationTime")
        sys.exit(0)
    folder_path = sys.argv[1]
    mu = float(sys.argv[2])
    tgen = float(sys.argv[3])
    plot_all_epochs_thetafolder(folder_path, mu, tgen)
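
# Example invocation (hypothetical folder name and values):
#   python swp2.py theta_runs/ 1e-8 1
# which scans every theta file in theta_runs/ with mu = 1e-8 and a generation
# time of 1, then writes the summary figures produced by
# plot_all_epochs_thetafolder.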