# swp2.py (31KB)
import io
import json
import math
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import MultipleLocator
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from scipy.special import gammaln
  12. def log_facto(k):
  13. k = int(k)
  14. if k > 1e6:
  15. return k * np.log(k) - k + np.log(2*math.pi*k)/2
  16. val = 0
  17. for i in range(2, k+1):
  18. val += np.log(i)
  19. return val
  20. def log_facto_1(k):
  21. startf = 1 # start of factorial sequence
  22. stopf = int(k+1) # end of of factorial sequence
  23. q = gammaln(range(startf+1, stopf+1)) # n! = G(n+1)
  24. return q[-1]
  25. def return_x_y_from_stwp_theta_file(stwp_theta_file, breaks, mu, tgen, relative_theta_scale = False):
  26. with open(stwp_theta_file, "r") as swp_file:
  27. # Read the first line
  28. line = swp_file.readline()
  29. L = float(line.split()[2])
  30. rands = swp_file.readline()
  31. line = swp_file.readline()
  32. # skip empty lines before SFS
  33. while line == "\n":
  34. line = swp_file.readline()
  35. sfs = np.array(line.split()).astype(float)
  36. # Process lines until the end of the file
  37. while line:
  38. # check at each line
  39. if line.startswith("dim") :
  40. dim = int(line.split()[1])
  41. if dim == breaks+1:
  42. likelihood = line.split()[5]
  43. groups = line.split()[6:6+dim]
  44. theta_site = line.split()[6+dim:6+dim+1+dim]
  45. elif dim < breaks+1:
  46. line = swp_file.readline()
  47. continue
  48. elif dim > breaks+1:
  49. break
  50. #return 0,0,0
  51. # Read the next line
  52. line = swp_file.readline()
  53. #### END of parsing
  54. # quit this file if the number of dimensions is incorrect
  55. if dim < breaks+1:
  56. return 0,0,0,0,0,0
  57. # get n, the last bin of the last group
  58. # revert the list of groups as the most recent times correspond
  59. # to the closest and last leafs of the coal. tree.
  60. groups = groups[::-1]
  61. theta_site = theta_site[::-1]
  62. # store thetas for later use
  63. grps = groups.copy()
  64. thetas = {}
  65. for i in range(len(groups)):
  66. grps[i] = grps[i].split(',')
  67. thetas[i] = [float(theta_site[i]), grps[i], likelihood]
  68. # initiate the dict of times
  69. t = {}
  70. # list of thetas
  71. theta_L = []
  72. sum_t = 0
  73. for group_nb, group in enumerate(groups):
  74. ###print(group_nb, group, theta_site[group_nb], len(theta_site))
  75. # store all the thetas one by one, with one theta per group
  76. theta_L.append(float(theta_site[group_nb]))
  77. # if the group is of size 1
  78. if len(group.split(',')) == 1:
  79. i = int(group)
  80. # if the group size is >1, take the first elem of the group
  81. # i is the first bin of each group, straight after a breakpoint
  82. else:
  83. i = int(group.split(",")[0])
  84. j = int(group.split(",")[-1])
  85. t[i] = 0
  86. #t =
  87. if len(group.split(',')) == 1:
  88. k = i
  89. if relative_theta_scale:
  90. t[i] += ((theta_L[group_nb] ) / (k*(k-1)))
  91. else:
  92. t[i] += ((theta_L[group_nb] ) / (k*(k-1)) * tgen) / mu
  93. else:
  94. for k in range(j, i-1, -1 ):
  95. if relative_theta_scale:
  96. t[i] += ((theta_L[group_nb] ) / (k*(k-1)))
  97. else:
  98. t[i] += ((theta_L[group_nb] ) / (k*(k-1)) * tgen) / mu
  99. # we add the cumulative times at the end
  100. t[i] += sum_t
  101. sum_t = t[i]
  102. # build the y axis (sizes)
  103. y = []
  104. for theta in theta_L:
  105. if relative_theta_scale:
  106. size = theta
  107. else:
  108. # with size N = theta/4mu
  109. size = theta / (4*mu)
  110. y.append(size)
  111. y.append(size)
  112. # build the time x axis
  113. x = [0]
  114. for time in range(0, len(t.values())-1):
  115. x.append(list(t.values())[time])
  116. x.append(list(t.values())[time])
  117. x.append(list(t.values())[len(t.values())-1])
  118. # if relative_theta_scale:
  119. # # rescale
  120. # #N0 = y[0]
  121. # # for i in range(len(y)):
  122. # # # divide by N0
  123. # # y[i] = y[i]/N0
  124. # # x[i] = x[i]/N0
  125. return x,y,likelihood,thetas,sfs,L
  126. def return_x_y_from_stwp_theta_file_as_is(stwp_theta_file, breaks, mu, tgen, relative_theta_scale = False):
  127. with open(stwp_theta_file, "r") as swp_file:
  128. # Read the first line
  129. line = swp_file.readline()
  130. L = float(line.split()[2])
  131. rands = swp_file.readline()
  132. line = swp_file.readline()
  133. # skip empty lines before SFS
  134. while line == "\n":
  135. line = swp_file.readline()
  136. sfs = np.array(line.split()).astype(float)
  137. # Process lines until the end of the file
  138. while line:
  139. # check at each line
  140. if line.startswith("dim") :
  141. dim = int(line.split()[1])
  142. if dim == breaks+1:
  143. likelihood = line.split()[5]
  144. groups = line.split()[6:6+dim]
  145. theta_site = line.split()[6+dim:6+dim+1+dim]
  146. elif dim < breaks+1:
  147. line = swp_file.readline()
  148. continue
  149. elif dim > breaks+1:
  150. break
  151. #return 0,0,0
  152. # Read the next line
  153. line = swp_file.readline()
  154. #### END of parsing
  155. # quit this file if the number of dimensions is incorrect
  156. if dim < breaks+1:
  157. return 0,0
  158. # get n, the last bin of the last group
  159. # revert the list of groups as the most recent times correspond
  160. # to the closest and last leafs of the coal. tree.
  161. groups = groups[::-1]
  162. theta_site = theta_site[::-1]
  163. thetas = {}
  164. for i in range(len(groups)):
  165. groups[i] = groups[i].split(',')
  166. # print(groups[i], len(groups[i]))
  167. thetas[i] = [float(theta_site[i]), groups[i], likelihood]
  168. return thetas, sfs
  169. def plot_k_epochs_thetafolder(folder_path, mu, tgen, breaks = 2, title = "Title", theta_scale = True):
  170. scenari = {}
  171. cpt = 0
  172. for file_name in os.listdir(folder_path):
  173. if os.path.isfile(os.path.join(folder_path, file_name)):
  174. # Perform actions on each file
  175. x,y,likelihood,sfs,L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
  176. tgen = tgen,
  177. mu = mu, relative_theta_scale = theta_scale)
  178. if x == 0 or y == 0:
  179. continue
  180. cpt +=1
  181. scenari[likelihood] = x,y
  182. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
  183. print(cpt, "theta file(s) have been scanned.")
  184. # sort starting by the smallest -log(Likelihood)
  185. print(scenari)
  186. best10_scenari = (sorted(list(scenari.keys())))[:10]
  187. print("10 greatest Likelihoods", best10_scenari)
  188. greatest_likelihood = best10_scenari[0]
  189. x, y = scenari[greatest_likelihood]
  190. my_dpi = 300
  191. plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  192. plt.plot(x, y, 'r-', lw=2, label = 'Lik='+greatest_likelihood)
  193. #plt.yscale('log')
  194. plt.xscale('log')
  195. plt.grid(True,which="both", linestyle='--', alpha = 0.3)
  196. for scenario in best10_scenari[1:]:
  197. x,y = scenari[scenario]
  198. #print("\n---- Lik:",scenario,"\n\nt=", x,"\n\nN=",y, "\n\n")
  199. plt.plot(x, y, '--', lw=1, label = 'Lik='+scenario)
  200. if theta_scale:
  201. plt.xlabel("Coal. time")
  202. plt.ylabel("Pop. size scaled by N0")
  203. recent_scale_lower_bound = y[0] * 0.01
  204. recent_scale_upper_bound = y[0] * 0.1
  205. plt.axvline(x=recent_scale_lower_bound)
  206. plt.axvline(x=recent_scale_upper_bound)
  207. else:
  208. # years
  209. plt.xlabel("Time (years)")
  210. plt.ylabel("Individuals (N)")
  211. plt.legend(loc='upper right')
  212. plt.title(title)
  213. plt.savefig(title+'_b'+str(breaks)+'.pdf')
  214. def plot_straight_x_y(x,y):
  215. x_1 = [x[0]]
  216. y_1 = []
  217. for i in range(0, len(y)-1):
  218. x_1.append(x[i])
  219. x_1.append(x[i])
  220. y_1.append(y[i])
  221. y_1.append(y[i])
  222. y_1 = y_1+[y[-1],y[-1]]
  223. x_1.append(x[-1])
  224. return x_1, y_1
  225. def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_scale = True, ax = None):
  226. #scenari = {}
  227. cpt = 0
  228. epochs = {}
  229. for file_name in os.listdir(folder_path):
  230. breaks = 0
  231. cpt +=1
  232. if os.path.isfile(os.path.join(folder_path, file_name)):
  233. x, y, likelihood, theta, sfs, L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
  234. tgen = tgen,
  235. mu = mu, relative_theta_scale = theta_scale)
  236. SFS_stored = sfs
  237. L_stored = L
  238. while not (x == 0 and y == 0):
  239. if breaks not in epochs.keys():
  240. epochs[breaks] = {}
  241. epochs[breaks][likelihood] = x,y
  242. breaks += 1
  243. x,y,likelihood,theta,sfs,L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
  244. tgen = tgen,
  245. mu = mu, relative_theta_scale = theta_scale)
  246. if x == 0:
  247. # last break did not work, then breaks = breaks-1
  248. breaks -= 1
  249. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
  250. print(cpt, "theta file(s) have been scanned.")
  251. my_dpi = 300
  252. if ax is None:
  253. # intialize figure
  254. my_dpi = 300
  255. fnt_size = 18
  256. # plt.rcParams['font.size'] = fnt_size
  257. fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  258. else:
  259. fnt_size = 12
  260. # plt.rcParams['font.size'] = fnt_size
  261. ax1 = ax[1][0,0]
  262. ax1.set_yscale('log')
  263. ax1.set_xscale('log')
  264. ax1.grid(True,which="both", linestyle='--', alpha = 0.3)
  265. brkpt_lik = []
  266. top_plots = {}
  267. for epoch, scenari in epochs.items():
  268. # sort starting by the smallest -log(Likelihood)
  269. best10_scenari = (sorted(list(scenari.keys())))[:10]
  270. greatest_likelihood = best10_scenari[0]
  271. # store the tuple breakpoints and likelihood for later plot
  272. brkpt_lik.append((epoch, greatest_likelihood))
  273. x, y = scenari[greatest_likelihood]
  274. #without breakpoint
  275. if epoch == 0:
  276. # do something with the theta without bp and skip the plotting
  277. N0 = y[0]
  278. #continue
  279. for i in range(len(y)):
  280. # divide by N0
  281. y[i] = y[i]/N0
  282. x[i] = x[i]/N0
  283. top_plots[greatest_likelihood] = x,y,epoch
  284. plots_likelihoods = list(top_plots.keys())
  285. for i in range(len(plots_likelihoods)):
  286. plots_likelihoods[i] = float(plots_likelihoods[i])
  287. best10_plots = sorted(plots_likelihoods)[:10]
  288. top_plot_lik = str(best10_plots[0])
  289. plot_handles = []
  290. # plt.rcParams['font.size'] = fnt_size
  291. p0, = ax1.plot(top_plots[top_plot_lik][0], top_plots[top_plot_lik][1], 'o', linestyle = "-",
  292. alpha=1, lw=2, label = str(top_plots[top_plot_lik][2])+' brks | Lik='+top_plot_lik)
  293. plot_handles.append(p0)
  294. for k, plot_Lk in enumerate(best10_plots[1:]):
  295. plot_Lk = str(plot_Lk)
  296. # plt.rcParams['font.size'] = fnt_size
  297. p, = ax1.plot(top_plots[plot_Lk][0], top_plots[plot_Lk][1], 'o', linestyle = "--",
  298. alpha=1/(k+1), lw=1.5, label = str(top_plots[plot_Lk][2])+' brks | Lik='+plot_Lk)
  299. plot_handles.append(p)
  300. if theta_scale:
  301. ax1.set_xlabel("Coal. time", fontsize=fnt_size)
  302. ax1.set_ylabel("Pop. size scaled by N0", fontsize=fnt_size)
  303. # recent_scale_lower_bound = 0.01
  304. # recent_scale_upper_bound = 0.1
  305. # ax1.axvline(x=recent_scale_lower_bound)
  306. # ax1.axvline(x=recent_scale_upper_bound)
  307. else:
  308. # years
  309. plt.set_xlabel("Time (years)", fontsize=fnt_size)
  310. plt.set_ylabel("Individuals (N)", fontsize=fnt_size)
  311. # plt.rcParams['font.size'] = fnt_size
  312. # print(fnt_size, "rcParam font.size=", plt.rcParams['font.size'])
  313. ax1.legend(handles = plot_handles, loc='best', fontsize = fnt_size*0.5)
  314. ax1.set_title(title)
  315. if ax is None:
  316. plt.savefig(title+'_b'+str(breaks)+'.pdf')
  317. # plot likelihood against nb of breakpoints
  318. # best possible likelihood from SFS
  319. # Segregating sites
  320. S = sum(SFS_stored)
  321. # Number of kept sites from which the SFS is computed
  322. L = L_stored
  323. # number of monomorphic sites
  324. S0 = L-S
  325. # print("SFS", SFS_stored)
  326. # print("S", S, "L", L, "S0=", S0)
  327. # compute Ln
  328. Ln = log_facto(S+S0) - log_facto(S0) + np.log(float(S0)/(S+S0)) * S0
  329. for xi in range(0, len(SFS_stored)):
  330. p_i = SFS_stored[xi] / float(S+S0)
  331. Ln += np.log(p_i) * SFS_stored[xi] - log_facto(SFS_stored[xi])
  332. # basic plot likelihood
  333. if ax is None:
  334. fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  335. # plt.rcParams['font.size'] = fnt_size
  336. else:
  337. #plt.rcParams['font.size'] = fnt_size
  338. ax2 = ax[0][0,1]
  339. ax2.plot(np.array(brkpt_lik)[:, 0], np.array(brkpt_lik)[:, 1].astype(float), 'o', linestyle = "dotted", lw=2)
  340. ax2.axhline(y=-Ln, linestyle = "-.", color = "red", label = "$-\log\mathcal{L}$ = "+str(round(-Ln, 2)))
  341. ax2.set_yscale('log')
  342. ax2.set_xlabel("# breakpoints", fontsize=fnt_size)
  343. ax2.set_ylabel("$-\log\mathcal{L}$", fontsize=fnt_size)
  344. ax2.legend(loc='best', fontsize = fnt_size*0.5)
  345. ax2.set_title(title+" Likelihood gain from # breakpoints")
  346. if ax is None:
  347. plt.savefig(title+'_Breakpts_Likelihood.pdf')
  348. # AIC
  349. if ax is None:
  350. fig, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  351. # plt.rcParams['font.size'] = '18'
  352. else:
  353. #plt.rcParams['font.size'] = fnt_size
  354. ax3 = ax[1][0,1]
  355. AIC = []
  356. for brk in np.array(brkpt_lik)[:, 0]:
  357. brk = int(brk)
  358. AIC.append((2*brk+1)+2*np.array(brkpt_lik)[brk, 1].astype(float))
  359. ax3.plot(np.array(brkpt_lik)[:, 0], AIC, 'o', linestyle = "dotted", lw=2)
  360. # AIC = 2*k - 2ln(L) ; where k is the number of parameters, here brks+1
  361. AIC_ln = 2*(len(brkpt_lik)+1) - 2*Ln
  362. ax3.axhline(y=AIC_ln, linestyle = "-.", color = "red",
  363. label = "Min. AIC = "+str(round(AIC_ln, 2)))
  364. selected_brks_nb = AIC.index(min(AIC))
  365. ax3.set_yscale('log')
  366. ax3.set_xlabel("# breakpoints", fontsize=fnt_size)
  367. ax3.set_ylabel("AIC")
  368. ax3.legend(loc='best', fontsize = fnt_size*0.5)
  369. ax3.set_title(title+" AIC")
  370. if ax is None:
  371. plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
  372. print("S", S)
  373. # return plots
  374. return ax[0], ax[1]
  375. def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
  376. breaks_max = 10, output = None):
  377. """
  378. Save theta values as is to do basic plots.
  379. """
  380. cpt = 0
  381. epochs = {}
  382. len_sfs = 0
  383. for file_name in os.listdir(folder_path):
  384. cpt +=1
  385. if os.path.isfile(os.path.join(folder_path, file_name)):
  386. for k in range(breaks_max):
  387. thetas,sfs = return_x_y_from_stwp_theta_file_as_is(folder_path+file_name, breaks = k,
  388. tgen = tgen,
  389. mu = mu, relative_theta_scale = theta_scale)
  390. if thetas == 0:
  391. continue
  392. if len(thetas)-1 != k:
  393. continue
  394. if k not in epochs.keys():
  395. epochs[k] = {}
  396. likelihood = str(eval(thetas[k][2]))
  397. epochs[k][likelihood] = thetas
  398. #epochs[k] = thetas
  399. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(k)+"\n*******\n")
  400. print(cpt, "theta file(s) have been scanned.")
  401. plots = []
  402. best_epochs = {}
  403. for epoch in epochs:
  404. likelihoods = []
  405. for key in epochs[epoch].keys():
  406. likelihoods.append(key)
  407. likelihoods.sort()
  408. minLogLn = str(likelihoods[0])
  409. best_epochs[epoch] = epochs[epoch][minLogLn]
  410. for epoch, theta in best_epochs.items():
  411. groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
  412. x = []
  413. y = []
  414. thetas = np.array(list(theta.values()), dtype=object)[:, 0]
  415. for i,group in enumerate(groups):
  416. x += group[::-1]
  417. y += list(np.repeat(thetas[i], len(group)))
  418. if epoch == 0:
  419. N0 = y[0]
  420. # compute the proportion of information used at each bin of the SFS
  421. sum_theta_i = 0
  422. for i in range(2, len(y)+2):
  423. sum_theta_i+=y[i-2] / (i-1)
  424. prop = []
  425. for k in range(2, len(y)+2):
  426. prop.append(y[k-2] / (k - 1) / sum_theta_i)
  427. prop = prop[::-1]
  428. # normalise to N0 (N0 of epoch1)
  429. for i in range(len(y)):
  430. y[i] = y[i]/N0
  431. # x_plot, y_plot = plot_straight_x_y(x, y)
  432. p = x, y
  433. # add plot to the list of all plots to superimpose
  434. plots.append(p)
  435. cumul = 0
  436. prop_cumul = []
  437. for val in prop:
  438. prop_cumul.append(val+cumul)
  439. cumul = val+cumul
  440. prop = prop_cumul
  441. lines_fig2 = []
  442. for epoch, theta in best_epochs.items():
  443. groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
  444. x = []
  445. y = []
  446. thetas = np.array(list(theta.values()), dtype=object)[:, 0]
  447. for i,group in enumerate(groups):
  448. x += group[::-1]
  449. y += list(np.repeat(thetas[i], len(group)))
  450. if epoch == 0:
  451. N0 = y[0]
  452. for i in range(len(y)):
  453. y[i] = y[i]/N0
  454. x_2 = []
  455. T = 0
  456. for i in range(len(x)):
  457. x[i] = int(x[i])
  458. # compute the times as: theta_k / (k*(k-1))
  459. for i in range(0, len(x)):
  460. T += y[i] / (x[i]*(x[i]-1))
  461. x_2.append(T)
  462. # Save plotting (fig 2)
  463. x_2 = [0]+x_2
  464. y = [y[0]]+y
  465. # x2_plot, y2_plot = plot_straight_x_y(x_2, y)
  466. p2 = x_2, y
  467. lines_fig2.append(p2)
  468. saved_plots = {"raw_stairs":plots, "scaled_stairs":lines_fig2,
  469. "prop":prop}
  470. if output == None:
  471. output = title+"_plotdata.json"
  472. with open(output, 'w') as json_file:
  473. json.dump(saved_plots, json_file)
  474. return saved_plots
  475. def plot_scaled_theta(plot_lines, prop, title, ax = None, n_ticks = 10):
  476. # fig 2 & 3
  477. if ax is None:
  478. my_dpi = 300
  479. fnt_size = 18
  480. fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  481. fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  482. else:
  483. # plt.rcParams['font.size'] = fnt_size
  484. fnt_size = 12
  485. # place of plots on the grid
  486. ax2 = ax[1,0]
  487. ax3 = ax[1,1]
  488. lines_fig2 = []
  489. lines_fig3 = []
  490. #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  491. for epoch, plot in enumerate(plot_lines):
  492. x,y=plot
  493. x2_plot, y2_plot = plot_straight_x_y(x,y)
  494. p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
  495. lines_fig2.append(p2)
  496. # Plotting (fig 3) which is the same but log scale for x
  497. p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
  498. lines_fig3.append(p3)
  499. ax2.set_xlabel("Relative scale", fontsize=fnt_size)
  500. ax2.set_ylabel("theta", fontsize=fnt_size)
  501. ax2.set_title(title, fontsize=fnt_size)
  502. ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
  503. if ax is None:
  504. # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
  505. plt.savefig(title+'_plot2_'+str(len(plot_lines))+'.pdf')
  506. # close fig2 to save memory
  507. plt.close(fig2)
  508. ax3.set_xscale('log')
  509. ax3.set_yscale('log')
  510. ax3.set_xlabel("log Relative scale", fontsize=fnt_size)
  511. ax3.set_ylabel("theta", fontsize=fnt_size)
  512. ax3.set_title(title, fontsize=fnt_size)
  513. ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
  514. if ax is None:
  515. # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
  516. plt.savefig(title+'_plot3_'+str(len(plot_lines))+'_log.pdf')
  517. # close fig3 to save memory
  518. plt.close(fig3)
  519. return ax
  520. def plot_raw_stairs(plot_lines, prop, title, ax = None, n_ticks = 10):
  521. # multiple fig
  522. if ax is None:
  523. # intialize figure 1
  524. my_dpi = 300
  525. fnt_size = 18
  526. # plt.rcParams['font.size'] = fnt_size
  527. fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  528. else:
  529. fnt_size = 12
  530. # plt.rcParams['font.size'] = fnt_size
  531. ax1 = ax[0, 0]
  532. plt.subplots_adjust(wspace=0.3, hspace=0.3)
  533. plots = []
  534. for epoch, plot in enumerate(plot_lines):
  535. x,y = plot
  536. x_plot, y_plot = plot_straight_x_y(x,y)
  537. p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
  538. # add plot to the list of all plots to superimpose
  539. plots.append(p)
  540. x_ticks = x
  541. # print(x_ticks)
  542. #print(prop, "\n", sum(prop))
  543. #ax.legend(handles=[p0]+plots)
  544. ax1.set_xlabel("# bin & cumul. prop. of sites", fontsize=fnt_size)
  545. # Set the x-axis locator to reduce the number of ticks to 10
  546. ax1.set_ylabel("theta", fontsize=fnt_size)
  547. ax1.set_title(title, fontsize=fnt_size)
  548. ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
  549. ax1.set_xticks(x_ticks)
  550. step = len(x_ticks)//(n_ticks-1)
  551. values = x_ticks[::step]
  552. new_prop = []
  553. for val in values:
  554. new_prop.append(prop[int(val)-2])
  555. new_prop = new_prop[::-1]
  556. ax1.set_xticks(values)
  557. ax1.set_xticklabels([f'{values[k]}\n{val:.2f}' for k, val in enumerate(new_prop)], fontsize = fnt_size*0.8)
  558. if ax is None:
  559. # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
  560. plt.savefig(title+'_raw'+str(len(plot_lines))+'.pdf')
  561. plt.close(fig)
  562. # return plots
  563. return ax
  564. def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True, breaks_max = 10, ax = None, n_ticks = 10):
  565. """
  566. Use theta values as is to do basic plots.
  567. """
  568. cpt = 0
  569. epochs = {}
  570. len_sfs = 0
  571. for file_name in os.listdir(folder_path):
  572. cpt +=1
  573. if os.path.isfile(os.path.join(folder_path, file_name)):
  574. for k in range(breaks_max):
  575. thetas,sfs = return_x_y_from_stwp_theta_file_as_is(folder_path+file_name, breaks = k,
  576. tgen = tgen,
  577. mu = mu, relative_theta_scale = theta_scale)
  578. if thetas == 0:
  579. continue
  580. if len(thetas)-1 != k:
  581. continue
  582. if k not in epochs.keys():
  583. epochs[k] = {}
  584. likelihood = str(eval(thetas[k][2]))
  585. epochs[k][likelihood] = thetas
  586. #epochs[k] = thetas
  587. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(k)+"\n*******\n")
  588. print(cpt, "theta file(s) have been scanned.")
  589. # multiple fig
  590. if ax is None:
  591. # intialize figure 1
  592. my_dpi = 300
  593. fnt_size = 18
  594. # plt.rcParams['font.size'] = fnt_size
  595. fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  596. else:
  597. fnt_size = 12
  598. # plt.rcParams['font.size'] = fnt_size
  599. ax1 = ax[0, 1]
  600. plt.subplots_adjust(wspace=0.3, hspace=0.3)
  601. plots = []
  602. best_epochs = {}
  603. for epoch in epochs:
  604. likelihoods = []
  605. for key in epochs[epoch].keys():
  606. likelihoods.append(key)
  607. likelihoods.sort()
  608. minLogLn = str(likelihoods[0])
  609. best_epochs[epoch] = epochs[epoch][minLogLn]
  610. for epoch, theta in best_epochs.items():
  611. groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
  612. x = []
  613. y = []
  614. thetas = np.array(list(theta.values()), dtype=object)[:, 0]
  615. for i,group in enumerate(groups):
  616. x += group[::-1]
  617. y += list(np.repeat(thetas[i], len(group)))
  618. if epoch == 0:
  619. N0 = y[0]
  620. # compute the proportion of information used at each bin of the SFS
  621. sum_theta_i = 0
  622. for i in range(2, len(y)+2):
  623. sum_theta_i+=y[i-2] / (i-1)
  624. prop = []
  625. for k in range(2, len(y)+2):
  626. prop.append(y[k-2] / (k - 1) / sum_theta_i)
  627. prop = prop[::-1]
  628. # print(prop, "\n", sum(prop))
  629. # normalise to N0 (N0 of epoch1)
  630. x_ticks = ax1.get_xticks()
  631. for i in range(len(y)):
  632. y[i] = y[i]/N0
  633. # plot
  634. x_plot, y_plot = plot_straight_x_y(x, y)
  635. #plt.plot(x, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
  636. p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
  637. # add plot to the list of all plots to superimpose
  638. plots.append(p)
  639. #print(prop, "\n", sum(prop))
  640. #ax.legend(handles=[p0]+plots)
  641. ax1.set_xlabel("# bin", fontsize=fnt_size)
  642. # Set the x-axis locator to reduce the number of ticks to 10
  643. ax1.set_ylabel("theta", fontsize=fnt_size)
  644. ax1.set_title(title, fontsize=fnt_size)
  645. ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
  646. ax1.set_xticks(x_ticks)
  647. if len(prop) >= 18:
  648. ax1.locator_params(nbins=n_ticks)
  649. # new scale of ticks if too many values
  650. cumul = 0
  651. prop_cumul = []
  652. for val in prop:
  653. prop_cumul.append(val+cumul)
  654. cumul = val+cumul
  655. ax1.set_xticklabels([f'{x[k]}\n{val:.2f}' for k, val in enumerate(prop_cumul)])
  656. if ax is None:
  657. plt.savefig(title+'_raw'+str(k)+'.pdf')
  658. # fig 2 & 3
  659. if ax is None:
  660. fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  661. fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  662. else:
  663. # plt.rcParams['font.size'] = fnt_size
  664. # place of plots on the grid
  665. ax2 = ax[1,0]
  666. ax3 = ax[1,1]
  667. lines_fig2 = []
  668. lines_fig3 = []
  669. #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  670. for epoch, theta in best_epochs.items():
  671. groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
  672. x = []
  673. y = []
  674. thetas = np.array(list(theta.values()), dtype=object)[:, 0]
  675. for i,group in enumerate(groups):
  676. x += group[::-1]
  677. y += list(np.repeat(thetas[i], len(group)))
  678. if epoch == 0:
  679. N0 = y[0]
  680. for i in range(len(y)):
  681. y[i] = y[i]/N0
  682. x_2 = []
  683. T = 0
  684. for i in range(len(x)):
  685. x[i] = int(x[i])
  686. # compute the times as: theta_k / (k*(k-1))
  687. for i in range(0, len(x)):
  688. T += y[i] / (x[i]*(x[i]-1))
  689. x_2.append(T)
  690. # Plotting (fig 2)
  691. x_2 = [0]+x_2
  692. y = [y[0]]+y
  693. x2_plot, y2_plot = plot_straight_x_y(x_2, y)
  694. p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
  695. lines_fig2.append(p2)
  696. # Plotting (fig 3) which is the same but log scale for x
  697. p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
  698. lines_fig3.append(p3)
  699. ax2.set_xlabel("Relative scale", fontsize=fnt_size)
  700. ax2.set_ylabel("theta", fontsize=fnt_size)
  701. ax2.set_title(title, fontsize=fnt_size)
  702. ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
  703. if ax is None:
  704. plt.savefig(title+'_plot2_'+str(k)+'.pdf')
  705. ax3.set_xscale('log')
  706. ax3.set_yscale('log')
  707. ax3.set_xlabel("log Relative scale", fontsize=fnt_size)
  708. ax3.set_ylabel("theta", fontsize=fnt_size)
  709. ax3.set_title(title, fontsize=fnt_size)
  710. ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
  711. if ax is None:
  712. plt.savefig(title+'_plot3_'+str(k)+'_log.pdf')
  713. plt.clf()
  714. # return plots
  715. return ax
  716. def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale = True):
  717. my_dpi = 300
  718. # # Add some extra space for the second axis at the bottom
  719. # #plt.rcParams['font.size'] = 18
  720. # fig, axs = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
  721. # #plt.rcParams['font.size'] = 12
  722. # ax = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = axs)
  723. # ax = plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = axs)
  724. # # Adjust layout to prevent clipping of titles
  725. #
  726. # # Save the entire grid as a single figure
  727. # plt.savefig(title+'_combined.pdf')
  728. # plt.clf()
  729. # # # second call for individual plots
  730. # # plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = None)
  731. # # plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = None)
  732. # # plt.clf()
  733. # save_k_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, output = title+"_plotdata.json")
  734. with open(title+"_plotdata.json", 'r') as json_file:
  735. loaded_data = json.load(json_file)
  736. # plot page 1 of summary
  737. fig1, ax1 = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
  738. # fig1.tight_layout()
  739. # Adjust absolute space between the top and bottom rows
  740. fig1.subplots_adjust(hspace=0.35) # Adjust this value based on your requirement
  741. # plot page 2 of summary
  742. fig2, ax2 = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
  743. # fig2.tight_layout()
  744. ax1 = plot_raw_stairs(plot_lines = loaded_data['raw_stairs'],
  745. prop = loaded_data['prop'], title = title, ax = ax1)
  746. ax1 = plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'],
  747. prop = loaded_data['prop'], title = title, ax = ax1)
  748. ax1, ax2 = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = [ax1, ax2])
  749. fig1.savefig(title+'_combined_p1.pdf')
  750. fig2.savefig(title+'_combined_p2.pdf')
  751. plot_raw_stairs(plot_lines = loaded_data['raw_stairs'],
  752. prop = loaded_data['prop'], title = title, ax = None)
  753. plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'],
  754. prop = loaded_data['prop'], title = title, ax = None)
  755. plt.close(fig1)
  756. plt.close(fig2)
  757. if __name__ == "__main__":
  758. if len(sys.argv) != 4:
  759. print("Need 3 args: ThetaFolder MutationRate GenerationTime")
  760. exit(0)
  761. folder_path = sys.argv[1]
  762. mu = sys.argv[2]
  763. tgen = sys.argv[3]
  764. plot_all_epochs_thetafolder(folder_path, mu, tgen)