# swp2.py

import io
import math
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.ticker import MaxNLocator
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from scipy.special import gammaln
  10. def log_facto(k):
  11. k = int(k)
  12. if k > 1e6:
  13. return k * np.log(k) - k + np.log(2*math.pi*k)/2
  14. val = 0
  15. for i in range(2, k+1):
  16. val += np.log(i)
  17. return val
  18. def log_facto_1(k):
  19. startf = 1 # start of factorial sequence
  20. stopf = int(k+1) # end of of factorial sequence
  21. q = gammaln(range(startf+1, stopf+1)) # n! = G(n+1)
  22. return q[-1]
  23. def return_x_y_from_stwp_theta_file(stwp_theta_file, breaks, mu, tgen, relative_theta_scale = False):
  24. with open(stwp_theta_file, "r") as swp_file:
  25. # Read the first line
  26. line = swp_file.readline()
  27. L = float(line.split()[2])
  28. rands = swp_file.readline()
  29. line = swp_file.readline()
  30. # skip empty lines before SFS
  31. while line == "\n":
  32. line = swp_file.readline()
  33. sfs = np.array(line.split()).astype(float)
  34. # Process lines until the end of the file
  35. while line:
  36. # check at each line
  37. if line.startswith("dim") :
  38. dim = int(line.split()[1])
  39. if dim == breaks+1:
  40. likelihood = line.split()[5]
  41. groups = line.split()[6:6+dim]
  42. theta_site = line.split()[6+dim:6+dim+1+dim]
  43. elif dim < breaks+1:
  44. line = swp_file.readline()
  45. continue
  46. elif dim > breaks+1:
  47. break
  48. #return 0,0,0
  49. # Read the next line
  50. line = swp_file.readline()
  51. #### END of parsing
  52. # quit this file if the number of dimensions is incorrect
  53. if dim < breaks+1:
  54. return 0,0,0,0,0
  55. # get n, the last bin of the last group
  56. # revert the list of groups as the most recent times correspond
  57. # to the closest and last leafs of the coal. tree.
  58. groups = groups[::-1]
  59. theta_site = theta_site[::-1]
  60. # initiate the dict of times
  61. t = {}
  62. # list of thetas
  63. theta_L = []
  64. sum_t = 0
  65. for group_nb, group in enumerate(groups):
  66. ###print(group_nb, group, theta_site[group_nb], len(theta_site))
  67. # store all the thetas one by one, with one theta per group
  68. theta_L.append(float(theta_site[group_nb]))
  69. # if the group is of size 1
  70. if len(group.split(',')) == 1:
  71. i = int(group)
  72. # if the group size is >1, take the first elem of the group
  73. # i is the first bin of each group, straight after a breakpoint
  74. else:
  75. i = int(group.split(",")[0])
  76. j = int(group.split(",")[-1])
  77. t[i] = 0
  78. #t =
  79. if len(group.split(',')) == 1:
  80. k = i
  81. if relative_theta_scale:
  82. t[i] += ((theta_L[group_nb] ) / (k*(k-1)))
  83. else:
  84. t[i] += ((theta_L[group_nb] ) / (k*(k-1)) * tgen) / mu
  85. else:
  86. for k in range(j, i-1, -1 ):
  87. if relative_theta_scale:
  88. t[i] += ((theta_L[group_nb] ) / (k*(k-1)))
  89. else:
  90. t[i] += ((theta_L[group_nb] ) / (k*(k-1)) * tgen) / mu
  91. # we add the cumulative times at the end
  92. t[i] += sum_t
  93. sum_t = t[i]
  94. # build the y axis (sizes)
  95. y = []
  96. for theta in theta_L:
  97. if relative_theta_scale:
  98. size = theta
  99. else:
  100. # with size N = theta/4mu
  101. size = theta / (4*mu)
  102. y.append(size)
  103. y.append(size)
  104. # build the time x axis
  105. x = [0]
  106. for time in range(0, len(t.values())-1):
  107. x.append(list(t.values())[time])
  108. x.append(list(t.values())[time])
  109. x.append(list(t.values())[len(t.values())-1])
  110. # if relative_theta_scale:
  111. # # rescale
  112. # #N0 = y[0]
  113. # # for i in range(len(y)):
  114. # # # divide by N0
  115. # # y[i] = y[i]/N0
  116. # # x[i] = x[i]/N0
  117. return x,y,likelihood,sfs,L
  118. def return_x_y_from_stwp_theta_file_as_is(stwp_theta_file, breaks, mu, tgen, relative_theta_scale = False):
  119. with open(stwp_theta_file, "r") as swp_file:
  120. # Read the first line
  121. line = swp_file.readline()
  122. L = float(line.split()[2])
  123. rands = swp_file.readline()
  124. line = swp_file.readline()
  125. # skip empty lines before SFS
  126. while line == "\n":
  127. line = swp_file.readline()
  128. sfs = np.array(line.split()).astype(float)
  129. # Process lines until the end of the file
  130. while line:
  131. # check at each line
  132. if line.startswith("dim") :
  133. dim = int(line.split()[1])
  134. if dim == breaks+1:
  135. likelihood = line.split()[5]
  136. groups = line.split()[6:6+dim]
  137. theta_site = line.split()[6+dim:6+dim+1+dim]
  138. elif dim < breaks+1:
  139. line = swp_file.readline()
  140. continue
  141. elif dim > breaks+1:
  142. break
  143. #return 0,0,0
  144. # Read the next line
  145. line = swp_file.readline()
  146. #### END of parsing
  147. # quit this file if the number of dimensions is incorrect
  148. if dim < breaks+1:
  149. return 0,0
  150. # get n, the last bin of the last group
  151. # revert the list of groups as the most recent times correspond
  152. # to the closest and last leafs of the coal. tree.
  153. groups = groups[::-1]
  154. theta_site = theta_site[::-1]
  155. thetas = {}
  156. for i in range(len(groups)):
  157. groups[i] = groups[i].split(',')
  158. #print(groups[i], len(groups[i]))
  159. thetas[i] = [float(theta_site[i]), groups[i], likelihood]
  160. return thetas, sfs
  161. def plot_k_epochs_thetafolder(folder_path, mu, tgen, breaks = 2, title = "Title", theta_scale = True):
  162. scenari = {}
  163. cpt = 0
  164. for file_name in os.listdir(folder_path):
  165. if os.path.isfile(os.path.join(folder_path, file_name)):
  166. # Perform actions on each file
  167. x,y,likelihood,sfs,L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
  168. tgen = tgen,
  169. mu = mu, relative_theta_scale = theta_scale)
  170. if x == 0 or y == 0:
  171. continue
  172. cpt +=1
  173. scenari[likelihood] = x,y
  174. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
  175. print(cpt, "theta file(s) have been scanned.")
  176. # sort starting by the smallest -log(Likelihood)
  177. print(scenari)
  178. best10_scenari = (sorted(list(scenari.keys())))[:10]
  179. print("10 greatest Likelihoods", best10_scenari)
  180. greatest_likelihood = best10_scenari[0]
  181. x, y = scenari[greatest_likelihood]
  182. my_dpi = 300
  183. plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  184. plt.plot(x, y, 'r-', lw=2, label = 'Lik='+greatest_likelihood)
  185. plt.xlim(1e-3, 1)
  186. plt.ylim(0, 10)
  187. #plt.yscale('log')
  188. plt.xscale('log')
  189. plt.grid(True,which="both", linestyle='--', alpha = 0.3)
  190. for scenario in best10_scenari[1:]:
  191. x,y = scenari[scenario]
  192. #print("\n---- Lik:",scenario,"\n\nt=", x,"\n\nN=",y, "\n\n")
  193. plt.plot(x, y, '--', lw=1, label = 'Lik='+scenario)
  194. if theta_scale:
  195. plt.xlabel("Coal. time")
  196. plt.ylabel("Pop. size scaled by N0")
  197. recent_scale_lower_bound = y[0] * 0.01
  198. recent_scale_upper_bound = y[0] * 0.1
  199. plt.axvline(x=recent_scale_lower_bound)
  200. plt.axvline(x=recent_scale_upper_bound)
  201. else:
  202. # years
  203. plt.xlabel("Time (years)")
  204. plt.ylabel("Individuals (N)")
  205. plt.legend(loc='upper right')
  206. plt.title(title)
  207. plt.savefig(title+'_b'+str(breaks)+'.pdf')
  208. def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_scale = True):
  209. #scenari = {}
  210. cpt = 0
  211. epochs = {}
  212. for file_name in os.listdir(folder_path):
  213. breaks = 0
  214. cpt +=1
  215. if os.path.isfile(os.path.join(folder_path, file_name)):
  216. x, y, likelihood, sfs, L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
  217. tgen = tgen,
  218. mu = mu, relative_theta_scale = theta_scale)
  219. SFS_stored = sfs
  220. L_stored = L
  221. while not (x == 0 and y == 0):
  222. if breaks not in epochs.keys():
  223. epochs[breaks] = {}
  224. epochs[breaks][likelihood] = x,y
  225. breaks += 1
  226. x,y,likelihood,sfs,L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
  227. tgen = tgen,
  228. mu = mu, relative_theta_scale = theta_scale)
  229. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
  230. print(cpt, "theta file(s) have been scanned.")
  231. # intialize figure
  232. my_dpi = 300
  233. plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  234. plt.xlim(1e-3, 1)
  235. #plt.ylim(0, 10)
  236. plt.yscale('log')
  237. plt.xscale('log')
  238. plt.grid(True,which="both", linestyle='--', alpha = 0.3)
  239. brkpt_lik = []
  240. for epoch, scenari in epochs.items():
  241. # sort starting by the smallest -log(Likelihood)
  242. best10_scenari = (sorted(list(scenari.keys())))[:10]
  243. greatest_likelihood = best10_scenari[0]
  244. # store the tuple breakpoints and likelihood for later plot
  245. brkpt_lik.append((epoch, greatest_likelihood))
  246. x, y = scenari[greatest_likelihood]
  247. #without breakpoint
  248. if epoch == 0:
  249. # do something with the theta without bp and skip the plotting
  250. N0 = y[0]
  251. #continue
  252. for i in range(len(y)):
  253. # divide by N0
  254. y[i] = y[i]/N0
  255. x[i] = x[i]/N0
  256. plt.plot(x, y, 'o', linestyle = "-", alpha=0.75, lw=2, label = str(epoch)+' BrkPt | Lik='+greatest_likelihood)
  257. if theta_scale:
  258. plt.xlabel("Coal. time")
  259. plt.ylabel("Pop. size scaled by N0")
  260. recent_scale_lower_bound = 0.01
  261. recent_scale_upper_bound = 0.1
  262. #print(recent_scale_lower_bound, recent_scale_upper_bound)
  263. plt.axvline(x=recent_scale_lower_bound)
  264. plt.axvline(x=recent_scale_upper_bound)
  265. else:
  266. # years
  267. plt.xlabel("Time (years)")
  268. plt.ylabel("Individuals (N)")
  269. plt.xlim(1e-5, 1)
  270. plt.legend(loc='upper right')
  271. plt.title(title)
  272. plt.savefig(title+'_b'+str(breaks)+'.pdf')
  273. # plot likelihood against nb of breakpoints
  274. # best possible likelihood from SFS
  275. # Segregating sites
  276. S = sum(SFS_stored)
  277. # number of monomorphic sites
  278. L = L_stored
  279. S0 = L-S
  280. # print("SFS", SFS_stored)
  281. # print("S", S, "L", L, "S0=", S0)
  282. # compute Ln
  283. Ln = log_facto(S+S0) - log_facto(S0) + np.log(float(S0)/(S+S0)) * S0
  284. for xi in range(0, len(SFS_stored)):
  285. p_i = SFS_stored[xi] / float(S+S0)
  286. Ln += np.log(p_i) * SFS_stored[xi] - log_facto(SFS_stored[xi])
  287. res = Ln
  288. # print(res)
  289. # basic plot likelihood
  290. plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  291. plt.rcParams['font.size'] = '18'
  292. plt.plot(np.array(brkpt_lik)[:, 0], np.array(brkpt_lik)[:, 1].astype(float), 'o', linestyle = "dotted", lw=2)
  293. # plt.ylim(0,100)
  294. # plt.axhline(y=res)
  295. plt.yscale('log')
  296. plt.xlabel("# breakpoints", fontsize=20)
  297. plt.ylabel("$-\log\mathcal{L}$")
  298. #plt.legend(loc='upper right')
  299. plt.title(title)
  300. plt.savefig(title+'_Breakpts_Likelihood.pdf')
  301. # AIC
  302. plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  303. plt.rcParams['font.size'] = '18'
  304. AIC = 2*(len(brkpt_lik)+1)+2*np.array(brkpt_lik)[:, 1].astype(float)
  305. plt.plot(np.array(brkpt_lik)[:, 0], AIC, 'o', linestyle = "dotted", lw=2)
  306. # plt.axhline(y=106)
  307. plt.yscale('log')
  308. plt.xlabel("# breakpoints", fontsize=20)
  309. plt.ylabel("AIC")
  310. #plt.legend(loc='upper right')
  311. plt.title(title)
  312. plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
  313. def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True, breaks_max = 12, ax = None):
  314. """
  315. Use theta values as is to do basic plots.
  316. """
  317. cpt = 0
  318. epochs = {}
  319. for file_name in os.listdir(folder_path):
  320. cpt +=1
  321. if os.path.isfile(os.path.join(folder_path, file_name)):
  322. for k in range(breaks_max):
  323. thetas,sfs = return_x_y_from_stwp_theta_file_as_is(folder_path+file_name, breaks = k,
  324. tgen = tgen,
  325. mu = mu, relative_theta_scale = theta_scale)
  326. if thetas == 0:
  327. continue
  328. epochs[k] = thetas
  329. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(k)+"\n*******\n")
  330. print(cpt, "theta file(s) have been scanned.")
  331. # intialize figure 1
  332. my_dpi = 300
  333. # plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  334. # multiple fig
  335. if ax is None:
  336. fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  337. # Add some extra space for the second axis at the bottom
  338. fig.subplots_adjust(bottom=0.15)
  339. else:
  340. ax1 = ax[0, 1]
  341. plt.subplots_adjust(wspace=0.3, hspace=0.3)
  342. twin = ax1.twiny()
  343. plots = []
  344. for epoch, theta in epochs.items():
  345. groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
  346. x = []
  347. y = []
  348. thetas = np.array(list(theta.values()), dtype=object)[:, 0]
  349. for i,group in enumerate(groups):
  350. x += group[::-1]
  351. y += list(np.repeat(thetas[i], len(group)))
  352. if epoch == 0:
  353. N0 = y[0]
  354. # compute the proportion of information used at each bin of the SFS
  355. sum_theta_i = 0
  356. for i in range(2, len(y)+2):
  357. sum_theta_i+=y[i-2] / (i-1)
  358. prop = []
  359. for k in range(2, len(y)+2):
  360. prop.append(y[k-2] / (k - 1) / sum_theta_i)
  361. prop = prop[::-1]
  362. # print(prop, "\n", sum(prop))
  363. # normalise to N0 (N0 of epoch1)
  364. for i in range(len(y)):
  365. y[i] = y[i]/N0
  366. # plot
  367. #plt.plot(x, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
  368. p, = ax1.plot(x, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
  369. # add plot to the list of all plots to superimpose
  370. plots.append(p)
  371. # virtual line to get the second x axis for proportions
  372. p0, = twin.plot(prop, y, alpha = 0, label="Proportion")
  373. # Move twinned axis ticks and label from top to bottom
  374. twin.xaxis.set_ticks_position("bottom")
  375. twin.xaxis.set_label_position("bottom")
  376. # Offset the twin axis below the host
  377. twin.spines["bottom"].set_position(("axes", -0.15))
  378. #ax.legend(handles=[p0]+plots)
  379. ax1.set_xlabel("# breaks")
  380. # Set the x-axis locator to reduce the number of ticks to 10
  381. ax1.xaxis.set_major_locator(MaxNLocator(nbins=10))
  382. ax1.set_ylabel("theta")
  383. # twin.set_ylabel("Proportion")
  384. plt.legend(handles=plots, loc='upper right')
  385. if ax is None:
  386. plt.savefig(title+'_raw'+str(k)+'.pdf')
  387. # fig 2 & 3
  388. if ax is None:
  389. fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  390. fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  391. else:
  392. # place of plots on the grid
  393. ax2 = ax[1,0]
  394. ax3 = ax[1,1]
  395. lines_fig2 = []
  396. lines_fig3 = []
  397. #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  398. for epoch, theta in epochs.items():
  399. groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
  400. x = []
  401. y = []
  402. thetas = np.array(list(theta.values()), dtype=object)[:, 0]
  403. for i,group in enumerate(groups):
  404. x += group[::-1]
  405. y += list(np.repeat(thetas[i], len(group)))
  406. if epoch == 0:
  407. N0 = y[0]
  408. for i in range(len(y)):
  409. y[i] = y[i]/N0
  410. x_2 = []
  411. T = 0
  412. for i in range(len(x)):
  413. x[i] = int(x[i])
  414. # compute the times as: theta_k / (k*(k-1))
  415. for i in range(0, len(x)):
  416. T += y[i] / (x[i]*(x[i]-1))
  417. x_2.append(T)
  418. # Plotting (fig 2)
  419. p2, = ax2.plot(x_2, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
  420. lines_fig2.append(p2)
  421. # Plotting (fig 3) which is the same but log scale for x
  422. p3, = ax3.plot(x_2, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
  423. lines_fig3.append(p3)
  424. ax2.set_xlabel("# breaks")
  425. ax2.set_ylabel("theta")
  426. ax2.set_title("Test")
  427. ax2.legend(handles=lines_fig2, loc='upper right')
  428. if ax is None:
  429. plt.savefig(title+'_plot2_'+str(k)+'.pdf')
  430. ax3.set_xscale('log')
  431. ax3.set_xlabel("log()")
  432. ax3.set_ylabel("theta")
  433. ax3.set_title("Test")
  434. ax3.legend(handles=lines_fig3, loc='upper right')
  435. if ax is None:
  436. plt.savefig(title+'_plot3_'+str(k)+'_log.pdf')
  437. # return plots
  438. return ax
  439. def save_combined_pdf(output_path):
  440. with PdfPages(output_path) as pdf:
  441. pdf.savefig()
  442. def save_multi_image(filename):
  443. pp = PdfPages(filename)
  444. fig_nums = plt.get_fignums()
  445. figs = [plt.figure(n) for n in fig_nums]
  446. for fig in figs:
  447. fig.savefig(pp, format='pdf')
  448. pp.close()
  449. def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale = True):
  450. # plot1, plot2, plot3 = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale)
  451. my_dpi = 300
  452. # Add some extra space for the second axis at the bottom
  453. fig, axs = plt.subplots(2, 2, figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  454. ax = plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = axs)
  455. # Adjust layout to prevent clipping of titles
  456. plt.tight_layout()
  457. # Adjust absolute space between the top and bottom rows
  458. plt.subplots_adjust(hspace=0.35) # Adjust this value based on your requirement
  459. # Save the entire grid as a single figure
  460. plt.savefig(title+'_combined.pdf')
  461. if __name__ == "__main__":
  462. if len(sys.argv) != 4:
  463. print("Need 3 args: ThetaFolder MutationRate GenerationTime")
  464. exit(0)
  465. folder_path = sys.argv[1]
  466. mu = sys.argv[2]
  467. tgen = sys.argv[3]
  468. plot_all_epochs_thetafolder(folder_path, mu, tgen)