# swp2.py
  1. import matplotlib.pyplot as plt
  2. import os
  3. import numpy as np
  4. import math
  5. import json
  6. import io
  7. from scipy.special import gammaln
  8. from matplotlib.backends.backend_pdf import PdfPages
  9. from matplotlib.ticker import MaxNLocator
  10. from mpl_toolkits.axes_grid1.inset_locator import inset_axes
  11. from matplotlib.ticker import MultipleLocator
  12. def log_facto(k):
  13. k = int(k)
  14. if k > 1e6:
  15. return k * np.log(k) - k + np.log(2*math.pi*k)/2
  16. val = 0
  17. for i in range(2, k+1):
  18. val += np.log(i)
  19. return val
  20. def log_facto_1(k):
  21. startf = 1 # start of factorial sequence
  22. stopf = int(k+1) # end of of factorial sequence
  23. q = gammaln(range(startf+1, stopf+1)) # n! = G(n+1)
  24. return q[-1]
  25. def return_x_y_from_stwp_theta_file(stwp_theta_file, breaks, mu, tgen, relative_theta_scale = False):
  26. with open(stwp_theta_file, "r") as swp_file:
  27. # Read the first line
  28. line = swp_file.readline()
  29. L = float(line.split()[2])
  30. rands = swp_file.readline()
  31. line = swp_file.readline()
  32. # skip empty lines before SFS
  33. while line == "\n":
  34. line = swp_file.readline()
  35. sfs = np.array(line.split()).astype(float)
  36. # Process lines until the end of the file
  37. while line:
  38. # check at each line
  39. if line.startswith("dim") :
  40. dim = int(line.split()[1])
  41. if dim == breaks+1:
  42. likelihood = line.split()[5]
  43. groups = line.split()[6:6+dim]
  44. theta_site = line.split()[6+dim:6+dim+1+dim]
  45. elif dim < breaks+1:
  46. line = swp_file.readline()
  47. continue
  48. elif dim > breaks+1:
  49. break
  50. #return 0,0,0
  51. # Read the next line
  52. line = swp_file.readline()
  53. #### END of parsing
  54. # quit this file if the number of dimensions is incorrect
  55. if dim < breaks+1:
  56. return 0,0,0,0,0,0
  57. # get n, the last bin of the last group
  58. # revert the list of groups as the most recent times correspond
  59. # to the closest and last leafs of the coal. tree.
  60. groups = groups[::-1]
  61. theta_site = theta_site[::-1]
  62. # store thetas for later use
  63. grps = groups.copy()
  64. thetas = {}
  65. for i in range(len(groups)):
  66. grps[i] = grps[i].split(',')
  67. thetas[i] = [float(theta_site[i]), grps[i], likelihood]
  68. # initiate the dict of times
  69. t = {}
  70. # list of thetas
  71. theta_L = []
  72. sum_t = 0
  73. for group_nb, group in enumerate(groups):
  74. ###print(group_nb, group, theta_site[group_nb], len(theta_site))
  75. # store all the thetas one by one, with one theta per group
  76. theta_L.append(float(theta_site[group_nb]))
  77. # if the group is of size 1
  78. if len(group.split(',')) == 1:
  79. i = int(group)
  80. # if the group size is >1, take the first elem of the group
  81. # i is the first bin of each group, straight after a breakpoint
  82. else:
  83. i = int(group.split(",")[0])
  84. j = int(group.split(",")[-1])
  85. t[i] = 0
  86. #t =
  87. if len(group.split(',')) == 1:
  88. k = i
  89. if relative_theta_scale:
  90. t[i] += ((theta_L[group_nb] ) / (k*(k-1)))
  91. else:
  92. t[i] += ((theta_L[group_nb] ) / (k*(k-1)) * tgen) / mu
  93. else:
  94. for k in range(j, i-1, -1 ):
  95. if relative_theta_scale:
  96. t[i] += ((theta_L[group_nb] ) / (k*(k-1)))
  97. else:
  98. t[i] += ((theta_L[group_nb] ) / (k*(k-1)) * tgen) / mu
  99. # we add the cumulative times at the end
  100. t[i] += sum_t
  101. sum_t = t[i]
  102. # build the y axis (sizes)
  103. y = []
  104. for theta in theta_L:
  105. if relative_theta_scale:
  106. size = theta
  107. else:
  108. # with size N = theta/4mu
  109. size = theta / (4*mu)
  110. y.append(size)
  111. y.append(size)
  112. # build the time x axis
  113. x = [0]
  114. for time in range(0, len(t.values())-1):
  115. x.append(list(t.values())[time])
  116. x.append(list(t.values())[time])
  117. x.append(list(t.values())[len(t.values())-1])
  118. # if relative_theta_scale:
  119. # # rescale
  120. # #N0 = y[0]
  121. # # for i in range(len(y)):
  122. # # # divide by N0
  123. # # y[i] = y[i]/N0
  124. # # x[i] = x[i]/N0
  125. return x,y,likelihood,thetas,sfs,L
  126. def return_x_y_from_stwp_theta_file_as_is(stwp_theta_file, breaks, mu, tgen, relative_theta_scale = False):
  127. with open(stwp_theta_file, "r") as swp_file:
  128. # Read the first line
  129. line = swp_file.readline()
  130. L = float(line.split()[2])
  131. rands = swp_file.readline()
  132. line = swp_file.readline()
  133. # skip empty lines before SFS
  134. while line == "\n":
  135. line = swp_file.readline()
  136. sfs = np.array(line.split()).astype(float)
  137. # Process lines until the end of the file
  138. while line:
  139. # check at each line
  140. if line.startswith("dim") :
  141. dim = int(line.split()[1])
  142. if dim == breaks+1:
  143. likelihood = line.split()[5]
  144. groups = line.split()[6:6+dim]
  145. theta_site = line.split()[6+dim:6+dim+1+dim]
  146. elif dim < breaks+1:
  147. line = swp_file.readline()
  148. continue
  149. elif dim > breaks+1:
  150. break
  151. #return 0,0,0
  152. # Read the next line
  153. line = swp_file.readline()
  154. #### END of parsing
  155. # quit this file if the number of dimensions is incorrect
  156. if dim < breaks+1:
  157. return 0,0
  158. # get n, the last bin of the last group
  159. # revert the list of groups as the most recent times correspond
  160. # to the closest and last leafs of the coal. tree.
  161. groups = groups[::-1]
  162. theta_site = theta_site[::-1]
  163. thetas = {}
  164. for i in range(len(groups)):
  165. groups[i] = groups[i].split(',')
  166. # print(groups[i], len(groups[i]))
  167. thetas[i] = [float(theta_site[i]), groups[i], likelihood]
  168. return thetas, sfs
  169. def plot_k_epochs_thetafolder(folder_path, mu, tgen, breaks = 2, title = "Title", theta_scale = True):
  170. scenari = {}
  171. cpt = 0
  172. for file_name in os.listdir(folder_path):
  173. if os.path.isfile(os.path.join(folder_path, file_name)):
  174. # Perform actions on each file
  175. x,y,likelihood,sfs,L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
  176. tgen = tgen,
  177. mu = mu, relative_theta_scale = theta_scale)
  178. if x == 0 or y == 0:
  179. continue
  180. cpt +=1
  181. scenari[likelihood] = x,y
  182. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
  183. print(cpt, "theta file(s) have been scanned.")
  184. # sort starting by the smallest -log(Likelihood)
  185. print(scenari)
  186. best10_scenari = (sorted(list(scenari.keys())))[:10]
  187. print("10 greatest Likelihoods", best10_scenari)
  188. greatest_likelihood = best10_scenari[0]
  189. x, y = scenari[greatest_likelihood]
  190. my_dpi = 300
  191. plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  192. plt.plot(x, y, 'r-', lw=2, label = 'Lik='+greatest_likelihood)
  193. #plt.yscale('log')
  194. plt.xscale('log')
  195. plt.grid(True,which="both", linestyle='--', alpha = 0.3)
  196. for scenario in best10_scenari[1:]:
  197. x,y = scenari[scenario]
  198. #print("\n---- Lik:",scenario,"\n\nt=", x,"\n\nN=",y, "\n\n")
  199. plt.plot(x, y, '--', lw=1, label = 'Lik='+scenario)
  200. if theta_scale:
  201. plt.xlabel("Coal. time")
  202. plt.ylabel("Pop. size scaled by N0")
  203. recent_scale_lower_bound = y[0] * 0.01
  204. recent_scale_upper_bound = y[0] * 0.1
  205. plt.axvline(x=recent_scale_lower_bound)
  206. plt.axvline(x=recent_scale_upper_bound)
  207. else:
  208. # years
  209. plt.xlabel("Time (years)")
  210. plt.ylabel("Individuals (N)")
  211. plt.legend(loc='upper right')
  212. plt.title(title)
  213. plt.savefig(title+'_b'+str(breaks)+'.pdf')
  214. def plot_straight_x_y(x,y):
  215. x_1 = [x[0]]
  216. y_1 = []
  217. for i in range(0, len(y)-1):
  218. x_1.append(x[i])
  219. x_1.append(x[i])
  220. y_1.append(y[i])
  221. y_1.append(y[i])
  222. y_1 = y_1+[y[-1],y[-1]]
  223. x_1.append(x[-1])
  224. return x_1, y_1
  225. def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_scale = True, ax = None):
  226. #scenari = {}
  227. cpt = 0
  228. epochs = {}
  229. for file_name in os.listdir(folder_path):
  230. breaks = 0
  231. cpt +=1
  232. if os.path.isfile(os.path.join(folder_path, file_name)):
  233. x, y, likelihood, theta, sfs, L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
  234. tgen = tgen,
  235. mu = mu, relative_theta_scale = theta_scale)
  236. SFS_stored = sfs
  237. L_stored = L
  238. while not (x == 0 and y == 0):
  239. if breaks not in epochs.keys():
  240. epochs[breaks] = {}
  241. epochs[breaks][likelihood] = x,y
  242. breaks += 1
  243. x,y,likelihood,theta,sfs,L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
  244. tgen = tgen,
  245. mu = mu, relative_theta_scale = theta_scale)
  246. if x == 0:
  247. # last break did not work, then breaks = breaks-1
  248. breaks -= 1
  249. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
  250. print(cpt, "theta file(s) have been scanned.")
  251. my_dpi = 300
  252. if ax is None:
  253. # intialize figure
  254. my_dpi = 300
  255. fnt_size = 18
  256. # plt.rcParams['font.size'] = fnt_size
  257. fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  258. else:
  259. fnt_size = 12
  260. # plt.rcParams['font.size'] = fnt_size
  261. ax1 = ax[0,0]
  262. ax1.set_yscale('log')
  263. ax1.set_xscale('log')
  264. ax1.grid(True,which="both", linestyle='--', alpha = 0.3)
  265. brkpt_lik = []
  266. top_plots = {}
  267. for epoch, scenari in epochs.items():
  268. # sort starting by the smallest -log(Likelihood)
  269. best10_scenari = (sorted(list(scenari.keys())))[:10]
  270. greatest_likelihood = best10_scenari[0]
  271. # store the tuple breakpoints and likelihood for later plot
  272. brkpt_lik.append((epoch, greatest_likelihood))
  273. x, y = scenari[greatest_likelihood]
  274. #without breakpoint
  275. if epoch == 0:
  276. # do something with the theta without bp and skip the plotting
  277. N0 = y[0]
  278. #continue
  279. for i in range(len(y)):
  280. # divide by N0
  281. y[i] = y[i]/N0
  282. x[i] = x[i]/N0
  283. top_plots[greatest_likelihood] = x,y,epoch
  284. plots_likelihoods = list(top_plots.keys())
  285. for i in range(len(plots_likelihoods)):
  286. plots_likelihoods[i] = float(plots_likelihoods[i])
  287. best10_plots = sorted(plots_likelihoods)[:10]
  288. top_plot_lik = str(best10_plots[0])
  289. plot_handles = []
  290. # plt.rcParams['font.size'] = fnt_size
  291. p0, = ax1.plot(top_plots[top_plot_lik][0], top_plots[top_plot_lik][1], 'o', linestyle = "-",
  292. alpha=1, lw=2, label = str(top_plots[top_plot_lik][2])+' brks | Lik='+top_plot_lik)
  293. plot_handles.append(p0)
  294. for k, plot_Lk in enumerate(best10_plots[1:]):
  295. plot_Lk = str(plot_Lk)
  296. # plt.rcParams['font.size'] = fnt_size
  297. p, = ax1.plot(top_plots[plot_Lk][0], top_plots[plot_Lk][1], 'o', linestyle = "--",
  298. alpha=1/(k+1), lw=1.5, label = str(top_plots[plot_Lk][2])+' brks | Lik='+plot_Lk)
  299. plot_handles.append(p)
  300. if theta_scale:
  301. ax1.set_xlabel("Coal. time", fontsize=fnt_size)
  302. ax1.set_ylabel("Pop. size scaled by N0", fontsize=fnt_size)
  303. # recent_scale_lower_bound = 0.01
  304. # recent_scale_upper_bound = 0.1
  305. # ax1.axvline(x=recent_scale_lower_bound)
  306. # ax1.axvline(x=recent_scale_upper_bound)
  307. else:
  308. # years
  309. plt.set_xlabel("Time (years)", fontsize=fnt_size)
  310. plt.set_ylabel("Individuals (N)", fontsize=fnt_size)
  311. # plt.rcParams['font.size'] = fnt_size
  312. # print(fnt_size, "rcParam font.size=", plt.rcParams['font.size'])
  313. ax1.legend(handles = plot_handles, loc='best', fontsize = fnt_size*0.5)
  314. ax1.set_title(title)
  315. if ax is None:
  316. plt.savefig(title+'_b'+str(breaks)+'.pdf')
  317. # plot likelihood against nb of breakpoints
  318. # best possible likelihood from SFS
  319. # Segregating sites
  320. S = sum(SFS_stored)
  321. # Number of kept sites from which the SFS is computed
  322. L = L_stored
  323. # number of monomorphic sites
  324. S0 = L-S
  325. # print("SFS", SFS_stored)
  326. # print("S", S, "L", L, "S0=", S0)
  327. # compute Ln
  328. Ln = log_facto(S+S0) - log_facto(S0) + np.log(float(S0)/(S+S0)) * S0
  329. for xi in range(0, len(SFS_stored)):
  330. p_i = SFS_stored[xi] / float(S+S0)
  331. Ln += np.log(p_i) * SFS_stored[xi] - log_facto(SFS_stored[xi])
  332. # basic plot likelihood
  333. if ax is None:
  334. fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  335. # plt.rcParams['font.size'] = fnt_size
  336. else:
  337. #plt.rcParams['font.size'] = fnt_size
  338. ax2 = ax[2,0]
  339. ax2.plot(np.array(brkpt_lik)[:, 0], np.array(brkpt_lik)[:, 1].astype(float), 'o', linestyle = "dotted", lw=2)
  340. ax2.axhline(y=-Ln, linestyle = "-.", color = "red", label = "$-\log\mathcal{L}$ = "+str(round(-Ln, 2)))
  341. ax2.set_yscale('log')
  342. ax2.set_xlabel("# breakpoints", fontsize=fnt_size)
  343. ax2.set_ylabel("$-\log\mathcal{L}$", fontsize=fnt_size)
  344. ax2.legend(loc='best', fontsize = fnt_size*0.5)
  345. ax2.set_title(title+" Likelihood gain from # breakpoints")
  346. if ax is None:
  347. plt.savefig(title+'_Breakpts_Likelihood.pdf')
  348. # AIC
  349. if ax is None:
  350. fig, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  351. # plt.rcParams['font.size'] = '18'
  352. else:
  353. #plt.rcParams['font.size'] = fnt_size
  354. ax3 = ax[2,1]
  355. AIC = 2*(len(brkpt_lik)+1)+2*np.array(brkpt_lik)[:, 1].astype(float)
  356. ax3.plot(np.array(brkpt_lik)[:, 0], AIC, 'o', linestyle = "dotted", lw=2)
  357. AIC_ln = 2*(len(brkpt_lik)+1)-2*Ln
  358. ax3.axhline(y=AIC_ln, linestyle = "-.", color = "red",
  359. label = "Min. AIC = "+str(round(AIC_ln, 2)))
  360. ax3.set_yscale('log')
  361. ax3.set_xlabel("# breakpoints", fontsize=fnt_size)
  362. ax3.set_ylabel("AIC")
  363. ax3.legend(loc='best', fontsize = fnt_size*0.5)
  364. ax3.set_title(title+" AIC")
  365. if ax is None:
  366. plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
  367. print("S", S)
  368. # return plots
  369. return ax
  370. def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
  371. breaks_max = 10, output = None):
  372. """
  373. Save theta values as is to do basic plots.
  374. """
  375. cpt = 0
  376. epochs = {}
  377. len_sfs = 0
  378. for file_name in os.listdir(folder_path):
  379. cpt +=1
  380. if os.path.isfile(os.path.join(folder_path, file_name)):
  381. for k in range(breaks_max):
  382. thetas,sfs = return_x_y_from_stwp_theta_file_as_is(folder_path+file_name, breaks = k,
  383. tgen = tgen,
  384. mu = mu, relative_theta_scale = theta_scale)
  385. if thetas == 0:
  386. continue
  387. if len(thetas)-1 != k:
  388. continue
  389. if k not in epochs.keys():
  390. epochs[k] = {}
  391. likelihood = str(eval(thetas[k][2]))
  392. epochs[k][likelihood] = thetas
  393. #epochs[k] = thetas
  394. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(k)+"\n*******\n")
  395. print(cpt, "theta file(s) have been scanned.")
  396. plots = []
  397. best_epochs = {}
  398. for epoch in epochs:
  399. likelihoods = []
  400. for key in epochs[epoch].keys():
  401. likelihoods.append(key)
  402. likelihoods.sort()
  403. minLogLn = str(likelihoods[0])
  404. best_epochs[epoch] = epochs[epoch][minLogLn]
  405. for epoch, theta in best_epochs.items():
  406. groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
  407. x = []
  408. y = []
  409. thetas = np.array(list(theta.values()), dtype=object)[:, 0]
  410. for i,group in enumerate(groups):
  411. x += group[::-1]
  412. y += list(np.repeat(thetas[i], len(group)))
  413. if epoch == 0:
  414. N0 = y[0]
  415. # compute the proportion of information used at each bin of the SFS
  416. sum_theta_i = 0
  417. for i in range(2, len(y)+2):
  418. sum_theta_i+=y[i-2] / (i-1)
  419. prop = []
  420. for k in range(2, len(y)+2):
  421. prop.append(y[k-2] / (k - 1) / sum_theta_i)
  422. prop = prop[::-1]
  423. # normalise to N0 (N0 of epoch1)
  424. for i in range(len(y)):
  425. y[i] = y[i]/N0
  426. # x_plot, y_plot = plot_straight_x_y(x, y)
  427. p = x, y
  428. # add plot to the list of all plots to superimpose
  429. plots.append(p)
  430. cumul = 0
  431. prop_cumul = []
  432. for val in prop:
  433. prop_cumul.append(val+cumul)
  434. cumul = val+cumul
  435. prop = prop_cumul
  436. lines_fig2 = []
  437. for epoch, theta in best_epochs.items():
  438. groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
  439. x = []
  440. y = []
  441. thetas = np.array(list(theta.values()), dtype=object)[:, 0]
  442. for i,group in enumerate(groups):
  443. x += group[::-1]
  444. y += list(np.repeat(thetas[i], len(group)))
  445. if epoch == 0:
  446. N0 = y[0]
  447. for i in range(len(y)):
  448. y[i] = y[i]/N0
  449. x_2 = []
  450. T = 0
  451. for i in range(len(x)):
  452. x[i] = int(x[i])
  453. # compute the times as: theta_k / (k*(k-1))
  454. for i in range(0, len(x)):
  455. T += y[i] / (x[i]*(x[i]-1))
  456. x_2.append(T)
  457. # Save plotting (fig 2)
  458. x_2 = [0]+x_2
  459. y = [y[0]]+y
  460. # x2_plot, y2_plot = plot_straight_x_y(x_2, y)
  461. p2 = x_2, y
  462. lines_fig2.append(p2)
  463. saved_plots = {"raw_stairs":plots, "scaled_stairs":lines_fig2,
  464. "prop":prop}
  465. if output == None:
  466. output = title+"_plotdata.json"
  467. with open(output, 'w') as json_file:
  468. json.dump(saved_plots, json_file)
  469. return saved_plots
  470. def plot_raw_stairs(plot_lines, plot_lines2, prop, title, ax = None, n_ticks = 10):
  471. # multiple fig
  472. if ax is None:
  473. # intialize figure 1
  474. my_dpi = 300
  475. fnt_size = 18
  476. # plt.rcParams['font.size'] = fnt_size
  477. fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  478. else:
  479. fnt_size = 12
  480. # plt.rcParams['font.size'] = fnt_size
  481. ax1 = ax[0, 1]
  482. plt.subplots_adjust(wspace=0.3, hspace=0.3)
  483. plots = []
  484. for epoch, plot in enumerate(plot_lines):
  485. x,y = plot
  486. x_plot, y_plot = plot_straight_x_y(x,y)
  487. p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
  488. # add plot to the list of all plots to superimpose
  489. plots.append(p)
  490. x_ticks = x
  491. # print(x_ticks)
  492. #print(prop, "\n", sum(prop))
  493. #ax.legend(handles=[p0]+plots)
  494. ax1.set_xlabel("# bin", fontsize=fnt_size)
  495. # Set the x-axis locator to reduce the number of ticks to 10
  496. ax1.set_ylabel("theta", fontsize=fnt_size)
  497. ax1.set_title("Title", fontsize=fnt_size)
  498. ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
  499. ax1.set_xticks(x_ticks)
  500. step = len(x_ticks)//(n_ticks-1)
  501. values = x_ticks[::step]
  502. new_prop = []
  503. for val in values:
  504. new_prop.append(prop[int(val)-2])
  505. new_prop = new_prop[::-1]
  506. ax1.set_xticks(values)
  507. ax1.set_xticklabels([f'{values[k]}\n{val:.2f}' for k, val in enumerate(new_prop)], fontsize = fnt_size*0.8)
  508. if ax is None:
  509. plt.savefig(title+'_raw'+str(k)+'.pdf')
  510. # fig 2 & 3
  511. if ax is None:
  512. fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  513. fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  514. else:
  515. # plt.rcParams['font.size'] = fnt_size
  516. # place of plots on the grid
  517. ax2 = ax[1,0]
  518. ax3 = ax[1,1]
  519. lines_fig2 = []
  520. lines_fig3 = []
  521. #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  522. for epoch, plot in enumerate(plot_lines2):
  523. x,y=plot
  524. x2_plot, y2_plot = plot_straight_x_y(x,y)
  525. p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
  526. lines_fig2.append(p2)
  527. # Plotting (fig 3) which is the same but log scale for x
  528. p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
  529. lines_fig3.append(p3)
  530. ax2.set_xlabel("Relative scale", fontsize=fnt_size)
  531. ax2.set_ylabel("theta", fontsize=fnt_size)
  532. ax2.set_title("Title", fontsize=fnt_size)
  533. ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
  534. if ax is None:
  535. plt.savefig(title+'_plot2_'+str(k)+'.pdf')
  536. ax3.set_xscale('log')
  537. ax3.set_yscale('log')
  538. ax3.set_xlabel("log Relative scale", fontsize=fnt_size)
  539. ax3.set_ylabel("theta", fontsize=fnt_size)
  540. ax3.set_title("Title", fontsize=fnt_size)
  541. ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
  542. if ax is None:
  543. plt.savefig(title+'_plot3_'+str(k)+'_log.pdf')
  544. plt.clf()
  545. # return plots
  546. return ax
def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True, breaks_max = 10, ax = None, n_ticks = 10):
    """
    Use theta values as is to do basic plots.

    Every regular file of ``folder_path`` is parsed with
    ``return_x_y_from_stwp_theta_file_as_is`` for 0 .. ``breaks_max - 1``
    breakpoints; for each breakpoint count one model is kept and drawn on
    three panels: theta per bin, theta against the cumulative
    ``theta / (k * (k - 1))`` scale, and the same on log-log axes.

    With ``ax`` given, the panels are ``ax[0, 1]``, ``ax[1, 0]`` and
    ``ax[1, 1]``; otherwise three figures are created and saved as PDF.
    Returns ``ax``.

    NOTE(review): nesting was reconstructed from a paste that had lost its
    indentation - confirm the loop structure against the original file.
    NOTE(review): ``folder_path+file_name`` needs ``folder_path`` to end with
    a path separator.
    """
    cpt = 0
    epochs = {}
    len_sfs = 0  # not used below
    for file_name in os.listdir(folder_path):
        cpt +=1
        if os.path.isfile(os.path.join(folder_path, file_name)):
            for k in range(breaks_max):
                thetas,sfs = return_x_y_from_stwp_theta_file_as_is(folder_path+file_name, breaks = k,
                                                                   tgen = tgen,
                                                                   mu = mu, relative_theta_scale = theta_scale)
                # the parser returns 0 when nothing matches this k
                if thetas == 0:
                    continue
                if len(thetas)-1 != k:
                    continue
                if k not in epochs.keys():
                    epochs[k] = {}
                # NOTE(review): eval() runs on a token read from the file;
                # float() would avoid executing file content.
                likelihood = str(eval(thetas[k][2]))
                epochs[k][likelihood] = thetas
                #epochs[k] = thetas
    # k is the last value of the loop above; it is unbound if no file was scanned
    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(k)+"\n*******\n")
    print(cpt, "theta file(s) have been scanned.")
    # multiple fig
    if ax is None:
        # initialize figure 1
        my_dpi = 300
        fnt_size = 18
        # plt.rcParams['font.size'] = fnt_size
        fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
    else:
        fnt_size = 12
        # plt.rcParams['font.size'] = fnt_size
        ax1 = ax[0, 1]
        plt.subplots_adjust(wspace=0.3, hspace=0.3)
    plots = []
    best_epochs = {}
    for epoch in epochs:
        likelihoods = []
        for key in epochs[epoch].keys():
            likelihoods.append(key)
        # NOTE(review): keys are strings, so this sort is lexicographic,
        # not numeric.
        likelihoods.sort()
        minLogLn = str(likelihoods[0])
        best_epochs[epoch] = epochs[epoch][minLogLn]
    for epoch, theta in best_epochs.items():
        # column 1 of each entry is the list of bins, column 0 the theta
        groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
        x = []
        y = []
        thetas = np.array(list(theta.values()), dtype=object)[:, 0]
        for i,group in enumerate(groups):
            # one (bin, theta) pair per bin of the group
            x += group[::-1]
            y += list(np.repeat(thetas[i], len(group)))
        if epoch == 0:
            # N0 and prop are only defined once epoch 0 has been visited
            N0 = y[0]
            # compute the proportion of information used at each bin of the SFS
            sum_theta_i = 0
            for i in range(2, len(y)+2):
                sum_theta_i+=y[i-2] / (i-1)
            prop = []
            for k in range(2, len(y)+2):
                prop.append(y[k-2] / (k - 1) / sum_theta_i)
            prop = prop[::-1]
            # print(prop, "\n", sum(prop))
        # normalise to N0 (N0 of epoch1)
        # ticks are read before this epoch's curve is drawn
        x_ticks = ax1.get_xticks()
        for i in range(len(y)):
            y[i] = y[i]/N0
        # plot
        x_plot, y_plot = plot_straight_x_y(x, y)
        #plt.plot(x, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
        p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
        # add plot to the list of all plots to superimpose
        plots.append(p)
    #print(prop, "\n", sum(prop))
    #ax.legend(handles=[p0]+plots)
    ax1.set_xlabel("# bin", fontsize=fnt_size)
    # Set the x-axis locator to reduce the number of ticks to 10
    ax1.set_ylabel("theta", fontsize=fnt_size)
    ax1.set_title("Title", fontsize=fnt_size)
    ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
    ax1.set_xticks(x_ticks)
    if len(prop) >= 18:
        ax1.locator_params(nbins=n_ticks)
    # new scale of ticks if too many values
    # running sum of prop
    cumul = 0
    prop_cumul = []
    for val in prop:
        prop_cumul.append(val+cumul)
        cumul = val+cumul
    # NOTE(review): one label is produced per entry of prop, while the ticks
    # come from x_ticks above - confirm both have the same length.
    ax1.set_xticklabels([f'{x[k]}\n{val:.2f}' for k, val in enumerate(prop_cumul)])
    if ax is None:
        plt.savefig(title+'_raw'+str(k)+'.pdf')
    # fig 2 & 3
    if ax is None:
        fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
        fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
    else:
        # plt.rcParams['font.size'] = fnt_size
        # place of plots on the grid
        ax2 = ax[1,0]
        ax3 = ax[1,1]
    lines_fig2 = []
    lines_fig3 = []
    #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
    for epoch, theta in best_epochs.items():
        groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
        x = []
        y = []
        thetas = np.array(list(theta.values()), dtype=object)[:, 0]
        for i,group in enumerate(groups):
            x += group[::-1]
            y += list(np.repeat(thetas[i], len(group)))
        if epoch == 0:
            N0 = y[0]
        for i in range(len(y)):
            y[i] = y[i]/N0
        x_2 = []
        T = 0
        # bins are read as strings; convert before the arithmetic below
        for i in range(len(x)):
            x[i] = int(x[i])
        # compute the times as: theta_k / (k*(k-1))
        for i in range(0, len(x)):
            T += y[i] / (x[i]*(x[i]-1))
            x_2.append(T)
        # Plotting (fig 2)
        # prepend the origin and repeat the first theta so lengths match
        x_2 = [0]+x_2
        y = [y[0]]+y
        x2_plot, y2_plot = plot_straight_x_y(x_2, y)
        p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
        lines_fig2.append(p2)
        # Plotting (fig 3) which is the same but log scale for x
        p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
        lines_fig3.append(p3)
    ax2.set_xlabel("Relative scale", fontsize=fnt_size)
    ax2.set_ylabel("theta", fontsize=fnt_size)
    ax2.set_title("Title", fontsize=fnt_size)
    ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
    if ax is None:
        # NOTE(review): plt.savefig writes the current figure, which is the
        # one created last (fig3) - confirm fig2 is not the intended target.
        plt.savefig(title+'_plot2_'+str(k)+'.pdf')
    ax3.set_xscale('log')
    ax3.set_yscale('log')
    ax3.set_xlabel("log Relative scale", fontsize=fnt_size)
    ax3.set_ylabel("theta", fontsize=fnt_size)
    ax3.set_title("Title", fontsize=fnt_size)
    ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
    if ax is None:
        plt.savefig(title+'_plot3_'+str(k)+'_log.pdf')
        plt.clf()
    # return plots
    return ax
  699. def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale = True):
  700. my_dpi = 300
  701. # # Add some extra space for the second axis at the bottom
  702. # #plt.rcParams['font.size'] = 18
  703. # fig, axs = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
  704. # #plt.rcParams['font.size'] = 12
  705. # ax = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = axs)
  706. # ax = plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = axs)
  707. # # Adjust layout to prevent clipping of titles
  708. # plt.tight_layout()
  709. # # Adjust absolute space between the top and bottom rows
  710. # #plt.subplots_adjust(hspace=0.7) # Adjust this value based on your requirement
  711. # # Save the entire grid as a single figure
  712. # plt.savefig(title+'_combined.pdf')
  713. # plt.clf()
  714. # # # second call for individual plots
  715. # # plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = None)
  716. # # plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = None)
  717. # # plt.clf()
  718. # save_k_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, output = title+"_plotdata.json")
  719. with open(title+"_plotdata.json", 'r') as json_file:
  720. loaded_data = json.load(json_file)
  721. fig1, ax1 = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
  722. # plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = ax1)
  723. ax1 = plot_raw_stairs(plot_lines = loaded_data['raw_stairs'], plot_lines2 = loaded_data['scaled_stairs'],
  724. prop = loaded_data['prop'], title = title, ax = ax1)
  725. plt.savefig(title+'_raw_scaled.pdf')
  726. fig1.clf()
  727. if __name__ == "__main__":
  728. if len(sys.argv) != 4:
  729. print("Need 3 args: ThetaFolder MutationRate GenerationTime")
  730. exit(0)
  731. folder_path = sys.argv[1]
  732. mu = sys.argv[2]
  733. tgen = sys.argv[3]
  734. plot_all_epochs_thetafolder(folder_path, mu, tgen)