# swp2.py
import io
import math
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import MultipleLocator
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from scipy.special import gammaln
  11. def log_facto(k):
  12. k = int(k)
  13. if k > 1e6:
  14. return k * np.log(k) - k + np.log(2*math.pi*k)/2
  15. val = 0
  16. for i in range(2, k+1):
  17. val += np.log(i)
  18. return val
  19. def log_facto_1(k):
  20. startf = 1 # start of factorial sequence
  21. stopf = int(k+1) # end of of factorial sequence
  22. q = gammaln(range(startf+1, stopf+1)) # n! = G(n+1)
  23. return q[-1]
  24. def return_x_y_from_stwp_theta_file(stwp_theta_file, breaks, mu, tgen, relative_theta_scale = False):
  25. with open(stwp_theta_file, "r") as swp_file:
  26. # Read the first line
  27. line = swp_file.readline()
  28. L = float(line.split()[2])
  29. rands = swp_file.readline()
  30. line = swp_file.readline()
  31. # skip empty lines before SFS
  32. while line == "\n":
  33. line = swp_file.readline()
  34. sfs = np.array(line.split()).astype(float)
  35. # Process lines until the end of the file
  36. while line:
  37. # check at each line
  38. if line.startswith("dim") :
  39. dim = int(line.split()[1])
  40. if dim == breaks+1:
  41. likelihood = line.split()[5]
  42. groups = line.split()[6:6+dim]
  43. theta_site = line.split()[6+dim:6+dim+1+dim]
  44. elif dim < breaks+1:
  45. line = swp_file.readline()
  46. continue
  47. elif dim > breaks+1:
  48. break
  49. #return 0,0,0
  50. # Read the next line
  51. line = swp_file.readline()
  52. #### END of parsing
  53. # quit this file if the number of dimensions is incorrect
  54. if dim < breaks+1:
  55. return 0,0,0,0,0,0
  56. # get n, the last bin of the last group
  57. # revert the list of groups as the most recent times correspond
  58. # to the closest and last leafs of the coal. tree.
  59. groups = groups[::-1]
  60. theta_site = theta_site[::-1]
  61. # store thetas for later use
  62. grps = groups.copy()
  63. thetas = {}
  64. for i in range(len(groups)):
  65. grps[i] = grps[i].split(',')
  66. thetas[i] = [float(theta_site[i]), grps[i], likelihood]
  67. # initiate the dict of times
  68. t = {}
  69. # list of thetas
  70. theta_L = []
  71. sum_t = 0
  72. for group_nb, group in enumerate(groups):
  73. ###print(group_nb, group, theta_site[group_nb], len(theta_site))
  74. # store all the thetas one by one, with one theta per group
  75. theta_L.append(float(theta_site[group_nb]))
  76. # if the group is of size 1
  77. if len(group.split(',')) == 1:
  78. i = int(group)
  79. # if the group size is >1, take the first elem of the group
  80. # i is the first bin of each group, straight after a breakpoint
  81. else:
  82. i = int(group.split(",")[0])
  83. j = int(group.split(",")[-1])
  84. t[i] = 0
  85. #t =
  86. if len(group.split(',')) == 1:
  87. k = i
  88. if relative_theta_scale:
  89. t[i] += ((theta_L[group_nb] ) / (k*(k-1)))
  90. else:
  91. t[i] += ((theta_L[group_nb] ) / (k*(k-1)) * tgen) / mu
  92. else:
  93. for k in range(j, i-1, -1 ):
  94. if relative_theta_scale:
  95. t[i] += ((theta_L[group_nb] ) / (k*(k-1)))
  96. else:
  97. t[i] += ((theta_L[group_nb] ) / (k*(k-1)) * tgen) / mu
  98. # we add the cumulative times at the end
  99. t[i] += sum_t
  100. sum_t = t[i]
  101. # build the y axis (sizes)
  102. y = []
  103. for theta in theta_L:
  104. if relative_theta_scale:
  105. size = theta
  106. else:
  107. # with size N = theta/4mu
  108. size = theta / (4*mu)
  109. y.append(size)
  110. y.append(size)
  111. # build the time x axis
  112. x = [0]
  113. for time in range(0, len(t.values())-1):
  114. x.append(list(t.values())[time])
  115. x.append(list(t.values())[time])
  116. x.append(list(t.values())[len(t.values())-1])
  117. # if relative_theta_scale:
  118. # # rescale
  119. # #N0 = y[0]
  120. # # for i in range(len(y)):
  121. # # # divide by N0
  122. # # y[i] = y[i]/N0
  123. # # x[i] = x[i]/N0
  124. return x,y,likelihood,thetas,sfs,L
  125. def return_x_y_from_stwp_theta_file_as_is(stwp_theta_file, breaks, mu, tgen, relative_theta_scale = False):
  126. with open(stwp_theta_file, "r") as swp_file:
  127. # Read the first line
  128. line = swp_file.readline()
  129. L = float(line.split()[2])
  130. rands = swp_file.readline()
  131. line = swp_file.readline()
  132. # skip empty lines before SFS
  133. while line == "\n":
  134. line = swp_file.readline()
  135. sfs = np.array(line.split()).astype(float)
  136. # Process lines until the end of the file
  137. while line:
  138. # check at each line
  139. if line.startswith("dim") :
  140. dim = int(line.split()[1])
  141. if dim == breaks+1:
  142. likelihood = line.split()[5]
  143. groups = line.split()[6:6+dim]
  144. theta_site = line.split()[6+dim:6+dim+1+dim]
  145. elif dim < breaks+1:
  146. line = swp_file.readline()
  147. continue
  148. elif dim > breaks+1:
  149. break
  150. #return 0,0,0
  151. # Read the next line
  152. line = swp_file.readline()
  153. #### END of parsing
  154. # quit this file if the number of dimensions is incorrect
  155. if dim < breaks+1:
  156. return 0,0
  157. # get n, the last bin of the last group
  158. # revert the list of groups as the most recent times correspond
  159. # to the closest and last leafs of the coal. tree.
  160. groups = groups[::-1]
  161. theta_site = theta_site[::-1]
  162. thetas = {}
  163. for i in range(len(groups)):
  164. groups[i] = groups[i].split(',')
  165. # print(groups[i], len(groups[i]))
  166. thetas[i] = [float(theta_site[i]), groups[i], likelihood]
  167. return thetas, sfs
  168. def plot_k_epochs_thetafolder(folder_path, mu, tgen, breaks = 2, title = "Title", theta_scale = True):
  169. scenari = {}
  170. cpt = 0
  171. for file_name in os.listdir(folder_path):
  172. if os.path.isfile(os.path.join(folder_path, file_name)):
  173. # Perform actions on each file
  174. x,y,likelihood,sfs,L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
  175. tgen = tgen,
  176. mu = mu, relative_theta_scale = theta_scale)
  177. if x == 0 or y == 0:
  178. continue
  179. cpt +=1
  180. scenari[likelihood] = x,y
  181. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
  182. print(cpt, "theta file(s) have been scanned.")
  183. # sort starting by the smallest -log(Likelihood)
  184. print(scenari)
  185. best10_scenari = (sorted(list(scenari.keys())))[:10]
  186. print("10 greatest Likelihoods", best10_scenari)
  187. greatest_likelihood = best10_scenari[0]
  188. x, y = scenari[greatest_likelihood]
  189. my_dpi = 300
  190. plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  191. plt.plot(x, y, 'r-', lw=2, label = 'Lik='+greatest_likelihood)
  192. plt.xlim(1e-3, 1)
  193. plt.ylim(0, 10)
  194. #plt.yscale('log')
  195. plt.xscale('log')
  196. plt.grid(True,which="both", linestyle='--', alpha = 0.3)
  197. for scenario in best10_scenari[1:]:
  198. x,y = scenari[scenario]
  199. #print("\n---- Lik:",scenario,"\n\nt=", x,"\n\nN=",y, "\n\n")
  200. plt.plot(x, y, '--', lw=1, label = 'Lik='+scenario)
  201. if theta_scale:
  202. plt.xlabel("Coal. time")
  203. plt.ylabel("Pop. size scaled by N0")
  204. recent_scale_lower_bound = y[0] * 0.01
  205. recent_scale_upper_bound = y[0] * 0.1
  206. plt.axvline(x=recent_scale_lower_bound)
  207. plt.axvline(x=recent_scale_upper_bound)
  208. else:
  209. # years
  210. plt.xlabel("Time (years)")
  211. plt.ylabel("Individuals (N)")
  212. plt.legend(loc='upper right')
  213. plt.title(title)
  214. plt.savefig(title+'_b'+str(breaks)+'.pdf')
  215. def plot_straight_x_y(x,y):
  216. x_1 = [x[0]]
  217. y_1 = []
  218. for i in range(0, len(y)-1):
  219. x_1.append(x[i])
  220. x_1.append(x[i])
  221. y_1.append(y[i])
  222. y_1.append(y[i])
  223. y_1 = y_1+[y[-1],y[-1]]
  224. x_1.append(x[-1])
  225. return x_1, y_1
  226. def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_scale = True, ax = None):
  227. #scenari = {}
  228. cpt = 0
  229. epochs = {}
  230. for file_name in os.listdir(folder_path):
  231. breaks = 0
  232. cpt +=1
  233. if os.path.isfile(os.path.join(folder_path, file_name)):
  234. x, y, likelihood, theta, sfs, L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
  235. tgen = tgen,
  236. mu = mu, relative_theta_scale = theta_scale)
  237. SFS_stored = sfs
  238. L_stored = L
  239. while not (x == 0 and y == 0):
  240. if breaks not in epochs.keys():
  241. epochs[breaks] = {}
  242. epochs[breaks][likelihood] = x,y
  243. breaks += 1
  244. x,y,likelihood,theta,sfs,L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
  245. tgen = tgen,
  246. mu = mu, relative_theta_scale = theta_scale)
  247. if x == 0:
  248. # last break did not work, then breaks = breaks-1
  249. breaks -= 1
  250. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
  251. print(cpt, "theta file(s) have been scanned.")
  252. my_dpi = 300
  253. if ax is None:
  254. # intialize figure
  255. my_dpi = 300
  256. fnt_size = 18
  257. # plt.rcParams['font.size'] = fnt_size
  258. fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  259. else:
  260. fnt_size = 12
  261. # plt.rcParams['font.size'] = fnt_size
  262. ax1 = ax[0,0]
  263. #ax1.set_xlim(1e-3, 1)
  264. ax1.set_yscale('log')
  265. ax1.set_xscale('log')
  266. ax1.grid(True,which="both", linestyle='--', alpha = 0.3)
  267. brkpt_lik = []
  268. top_plots = {}
  269. for epoch, scenari in epochs.items():
  270. # sort starting by the smallest -log(Likelihood)
  271. best10_scenari = (sorted(list(scenari.keys())))[:10]
  272. greatest_likelihood = best10_scenari[0]
  273. # store the tuple breakpoints and likelihood for later plot
  274. brkpt_lik.append((epoch, greatest_likelihood))
  275. x, y = scenari[greatest_likelihood]
  276. #without breakpoint
  277. if epoch == 0:
  278. # do something with the theta without bp and skip the plotting
  279. N0 = y[0]
  280. #continue
  281. for i in range(len(y)):
  282. # divide by N0
  283. y[i] = y[i]/N0
  284. x[i] = x[i]/N0
  285. top_plots[greatest_likelihood] = x,y,epoch
  286. plots_likelihoods = list(top_plots.keys())
  287. for i in range(len(plots_likelihoods)):
  288. plots_likelihoods[i] = float(plots_likelihoods[i])
  289. best10_plots = sorted(plots_likelihoods)[:10]
  290. top_plot_lik = str(best10_plots[0])
  291. plot_handles = []
  292. # plt.rcParams['font.size'] = fnt_size
  293. p0, = ax1.plot(top_plots[top_plot_lik][0], top_plots[top_plot_lik][1], 'o', linestyle = "-",
  294. alpha=1, lw=2, label = str(top_plots[top_plot_lik][2])+' brks | Lik='+top_plot_lik)
  295. plot_handles.append(p0)
  296. for k, plot_Lk in enumerate(best10_plots[1:]):
  297. plot_Lk = str(plot_Lk)
  298. # plt.rcParams['font.size'] = fnt_size
  299. p, = ax1.plot(top_plots[plot_Lk][0], top_plots[plot_Lk][1], 'o', linestyle = "--",
  300. alpha=1/(k+1), lw=1.5, label = str(top_plots[plot_Lk][2])+' brks | Lik='+plot_Lk)
  301. plot_handles.append(p)
  302. if theta_scale:
  303. ax1.set_xlabel("Coal. time", fontsize=fnt_size)
  304. ax1.set_ylabel("Pop. size scaled by N0", fontsize=fnt_size)
  305. # recent_scale_lower_bound = 0.01
  306. # recent_scale_upper_bound = 0.1
  307. # ax1.axvline(x=recent_scale_lower_bound)
  308. # ax1.axvline(x=recent_scale_upper_bound)
  309. else:
  310. # years
  311. plt.set_xlabel("Time (years)", fontsize=fnt_size)
  312. plt.set_ylabel("Individuals (N)", fontsize=fnt_size)
  313. ax1.set_xlim(1e-5, 1)
  314. # plt.rcParams['font.size'] = fnt_size
  315. # print(fnt_size, "rcParam font.size=", plt.rcParams['font.size'])
  316. ax1.legend(handles = plot_handles, loc='best', fontsize = fnt_size*0.5)
  317. ax1.set_title(title)
  318. if ax is None:
  319. plt.savefig(title+'_b'+str(breaks)+'.pdf')
  320. # plot likelihood against nb of breakpoints
  321. # best possible likelihood from SFS
  322. # Segregating sites
  323. S = sum(SFS_stored)
  324. # Number of kept sites from which the SFS is computed
  325. L = L_stored
  326. # number of monomorphic sites
  327. S0 = L-S
  328. # print("SFS", SFS_stored)
  329. # print("S", S, "L", L, "S0=", S0)
  330. # compute Ln
  331. Ln = log_facto(S+S0) - log_facto(S0) + np.log(float(S0)/(S+S0)) * S0
  332. for xi in range(0, len(SFS_stored)):
  333. p_i = SFS_stored[xi] / float(S+S0)
  334. Ln += np.log(p_i) * SFS_stored[xi] - log_facto(SFS_stored[xi])
  335. # basic plot likelihood
  336. if ax is None:
  337. fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  338. # plt.rcParams['font.size'] = fnt_size
  339. else:
  340. #plt.rcParams['font.size'] = fnt_size
  341. ax2 = ax[2,0]
  342. ax2.plot(np.array(brkpt_lik)[:, 0], np.array(brkpt_lik)[:, 1].astype(float), 'o', linestyle = "dotted", lw=2)
  343. ax2.axhline(y=-Ln, linestyle = "-.", color = "red", label = "$-\log\mathcal{L}$ = "+str(round(-Ln, 2)))
  344. ax2.set_yscale('log')
  345. ax2.set_xlabel("# breakpoints", fontsize=fnt_size)
  346. ax2.set_ylabel("$-\log\mathcal{L}$", fontsize=fnt_size)
  347. ax2.legend(loc='best', fontsize = fnt_size*0.5)
  348. ax2.set_title(title+" Likelihood gain from # breakpoints")
  349. if ax is None:
  350. plt.savefig(title+'_Breakpts_Likelihood.pdf')
  351. # AIC
  352. if ax is None:
  353. fig, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  354. # plt.rcParams['font.size'] = '18'
  355. else:
  356. #plt.rcParams['font.size'] = fnt_size
  357. ax3 = ax[2,1]
  358. AIC = 2*(len(brkpt_lik)+1)+2*np.array(brkpt_lik)[:, 1].astype(float)
  359. ax3.plot(np.array(brkpt_lik)[:, 0], AIC, 'o', linestyle = "dotted", lw=2)
  360. AIC_ln = 2*(len(brkpt_lik)+1)-2*Ln
  361. ax3.axhline(y=AIC_ln, linestyle = "-.", color = "red",
  362. label = "Min. AIC = "+str(round(AIC_ln, 2)))
  363. ax3.set_yscale('log')
  364. ax3.set_xlabel("# breakpoints", fontsize=fnt_size)
  365. ax3.set_ylabel("AIC")
  366. ax3.legend(loc='best', fontsize = fnt_size*0.5)
  367. ax3.set_title(title+" AIC")
  368. if ax is None:
  369. plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
  370. print("S", S)
  371. # return plots
  372. return ax
  373. def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True, breaks_max = 10, ax = None, n_ticks = 10):
  374. """
  375. Use theta values as is to do basic plots.
  376. """
  377. cpt = 0
  378. epochs = {}
  379. len_sfs = 0
  380. for file_name in os.listdir(folder_path):
  381. cpt +=1
  382. if os.path.isfile(os.path.join(folder_path, file_name)):
  383. for k in range(breaks_max):
  384. thetas,sfs = return_x_y_from_stwp_theta_file_as_is(folder_path+file_name, breaks = k,
  385. tgen = tgen,
  386. mu = mu, relative_theta_scale = theta_scale)
  387. if thetas == 0:
  388. continue
  389. if len(thetas)-1 != k:
  390. continue
  391. if k not in epochs.keys():
  392. epochs[k] = {}
  393. likelihood = thetas[k][2]
  394. epochs[k][likelihood] = thetas
  395. #epochs[k] = thetas
  396. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(k)+"\n*******\n")
  397. print(cpt, "theta file(s) have been scanned.")
  398. # multiple fig
  399. if ax is None:
  400. # intialize figure 1
  401. my_dpi = 300
  402. fnt_size = 18
  403. # plt.rcParams['font.size'] = fnt_size
  404. fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  405. else:
  406. fnt_size = 12
  407. # plt.rcParams['font.size'] = fnt_size
  408. ax1 = ax[0, 1]
  409. plt.subplots_adjust(wspace=0.3, hspace=0.3)
  410. plots = []
  411. best_epochs = {}
  412. for epoch in epochs:
  413. likelihoods = []
  414. for key in epochs[epoch].keys():
  415. likelihoods.append(float(key))
  416. likelihoods.sort()
  417. minLogLn = str(likelihoods[0])
  418. best_epochs[epoch] = epochs[epoch][minLogLn]
  419. for epoch, theta in best_epochs.items():
  420. groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
  421. x = []
  422. y = []
  423. thetas = np.array(list(theta.values()), dtype=object)[:, 0]
  424. for i,group in enumerate(groups):
  425. x += group[::-1]
  426. y += list(np.repeat(thetas[i], len(group)))
  427. if epoch == 0:
  428. N0 = y[0]
  429. # compute the proportion of information used at each bin of the SFS
  430. sum_theta_i = 0
  431. for i in range(2, len(y)+2):
  432. sum_theta_i+=y[i-2] / (i-1)
  433. prop = []
  434. for k in range(2, len(y)+2):
  435. prop.append(y[k-2] / (k - 1) / sum_theta_i)
  436. prop = prop[::-1]
  437. # print(prop, "\n", sum(prop))
  438. # normalise to N0 (N0 of epoch1)
  439. x_ticks = ax1.get_xticks()
  440. for i in range(len(y)):
  441. y[i] = y[i]/N0
  442. # plot
  443. x_plot, y_plot = plot_straight_x_y(x, y)
  444. #plt.plot(x, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
  445. p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
  446. # add plot to the list of all plots to superimpose
  447. plots.append(p)
  448. #print(prop, "\n", sum(prop))
  449. #ax.legend(handles=[p0]+plots)
  450. ax1.set_xlabel("# bin", fontsize=fnt_size)
  451. # Set the x-axis locator to reduce the number of ticks to 10
  452. ax1.set_ylabel("theta", fontsize=fnt_size)
  453. ax1.set_title("Title", fontsize=fnt_size)
  454. ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
  455. ax1.set_xticks(x_ticks)
  456. if len(prop) >= 18:
  457. ax1.locator_params(nbins=n_ticks)
  458. # new scale of ticks if too many values
  459. cumul = 0
  460. prop_cumul = []
  461. for val in prop:
  462. prop_cumul.append(val+cumul)
  463. cumul = val+cumul
  464. ax1.set_xticklabels([f'{x[k]}\n{val:.2f}' for k, val in enumerate(prop_cumul)])
  465. if ax is None:
  466. plt.savefig(title+'_raw'+str(k)+'.pdf')
  467. # fig 2 & 3
  468. if ax is None:
  469. fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  470. fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  471. else:
  472. # plt.rcParams['font.size'] = fnt_size
  473. # place of plots on the grid
  474. ax2 = ax[1,0]
  475. ax3 = ax[1,1]
  476. lines_fig2 = []
  477. lines_fig3 = []
  478. #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  479. for epoch, theta in best_epochs.items():
  480. groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
  481. x = []
  482. y = []
  483. thetas = np.array(list(theta.values()), dtype=object)[:, 0]
  484. for i,group in enumerate(groups):
  485. x += group[::-1]
  486. y += list(np.repeat(thetas[i], len(group)))
  487. if epoch == 0:
  488. N0 = y[0]
  489. for i in range(len(y)):
  490. y[i] = y[i]/N0
  491. x_2 = []
  492. T = 0
  493. for i in range(len(x)):
  494. x[i] = int(x[i])
  495. # compute the times as: theta_k / (k*(k-1))
  496. for i in range(0, len(x)):
  497. T += y[i] / (x[i]*(x[i]-1))
  498. x_2.append(T)
  499. # Plotting (fig 2)
  500. x_2 = [0]+x_2
  501. y = [y[0]]+y
  502. x2_plot, y2_plot = plot_straight_x_y(x_2, y)
  503. p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
  504. lines_fig2.append(p2)
  505. # Plotting (fig 3) which is the same but log scale for x
  506. p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
  507. lines_fig3.append(p3)
  508. ax2.set_xlabel("Relative scale", fontsize=fnt_size)
  509. ax2.set_ylabel("theta", fontsize=fnt_size)
  510. ax2.set_title("Title", fontsize=fnt_size)
  511. ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
  512. if ax is None:
  513. plt.savefig(title+'_plot2_'+str(k)+'.pdf')
  514. ax3.set_xscale('log')
  515. ax3.set_yscale('log')
  516. ax3.set_xlabel("log Relative scale", fontsize=fnt_size)
  517. ax3.set_ylabel("theta", fontsize=fnt_size)
  518. ax3.set_title("Title", fontsize=fnt_size)
  519. ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
  520. if ax is None:
  521. plt.savefig(title+'_plot3_'+str(k)+'_log.pdf')
  522. plt.clf()
  523. # return plots
  524. return ax
  525. def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale = True):
  526. my_dpi = 300
  527. # Add some extra space for the second axis at the bottom
  528. #plt.rcParams['font.size'] = 18
  529. fig, axs = plt.subplots(3, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
  530. #plt.rcParams['font.size'] = 12
  531. ax = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = axs)
  532. ax = plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = axs)
  533. # Adjust layout to prevent clipping of titles
  534. plt.tight_layout()
  535. # Adjust absolute space between the top and bottom rows
  536. #plt.subplots_adjust(hspace=0.7) # Adjust this value based on your requirement
  537. # Save the entire grid as a single figure
  538. plt.savefig(title+'_combined.pdf')
  539. plt.clf()
  540. # # second call for individual plots
  541. # plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = None)
  542. # plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = None)
  543. # plt.clf()
  544. if __name__ == "__main__":
  545. if len(sys.argv) != 4:
  546. print("Need 3 args: ThetaFolder MutationRate GenerationTime")
  547. exit(0)
  548. folder_path = sys.argv[1]
  549. mu = sys.argv[2]
  550. tgen = sys.argv[3]
  551. plot_all_epochs_thetafolder(folder_path, mu, tgen)