swp2.py 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429
import math
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
from scipy.special import gammaln
  7. def log_facto(k):
  8. k = int(k)
  9. if k > 1e6:
  10. return k * np.log(k) - k + np.log(2*math.pi*k)/2
  11. val = 0
  12. for i in range(2, k+1):
  13. val += np.log(i)
  14. return val
  15. def log_facto_1(k):
  16. startf = 1 # start of factorial sequence
  17. stopf = int(k+1) # end of of factorial sequence
  18. q = gammaln(range(startf+1, stopf+1)) # n! = G(n+1)
  19. return q[-1]
  20. def return_x_y_from_stwp_theta_file(stwp_theta_file, breaks, mu, tgen, relative_theta_scale = False):
  21. with open(stwp_theta_file, "r") as swp_file:
  22. # Read the first line
  23. line = swp_file.readline()
  24. L = float(line.split()[2])
  25. rands = swp_file.readline()
  26. line = swp_file.readline()
  27. # skip empty lines before SFS
  28. while line == "\n":
  29. line = swp_file.readline()
  30. sfs = np.array(line.split()).astype(float)
  31. # Process lines until the end of the file
  32. while line:
  33. # check at each line
  34. if line.startswith("dim") :
  35. dim = int(line.split()[1])
  36. if dim == breaks+1:
  37. likelihood = line.split()[5]
  38. groups = line.split()[6:6+dim]
  39. theta_site = line.split()[6+dim:6+dim+1+dim]
  40. elif dim < breaks+1:
  41. line = swp_file.readline()
  42. continue
  43. elif dim > breaks+1:
  44. break
  45. #return 0,0,0
  46. # Read the next line
  47. line = swp_file.readline()
  48. #### END of parsing
  49. # quit this file if the number of dimensions is incorrect
  50. if dim < breaks+1:
  51. return 0,0,0,0,0
  52. # get n, the last bin of the last group
  53. # revert the list of groups as the most recent times correspond
  54. # to the closest and last leafs of the coal. tree.
  55. groups = groups[::-1]
  56. theta_site = theta_site[::-1]
  57. # initiate the dict of times
  58. t = {}
  59. # list of thetas
  60. theta_L = []
  61. sum_t = 0
  62. for group_nb, group in enumerate(groups):
  63. ###print(group_nb, group, theta_site[group_nb], len(theta_site))
  64. # store all the thetas one by one, with one theta per group
  65. theta_L.append(float(theta_site[group_nb]))
  66. # if the group is of size 1
  67. if len(group.split(',')) == 1:
  68. i = int(group)
  69. # if the group size is >1, take the first elem of the group
  70. # i is the first bin of each group, straight after a breakpoint
  71. else:
  72. i = int(group.split(",")[0])
  73. j = int(group.split(",")[-1])
  74. t[i] = 0
  75. #t =
  76. if len(group.split(',')) == 1:
  77. k = i
  78. if relative_theta_scale:
  79. t[i] += ((theta_L[group_nb] ) / (k*(k-1)))
  80. else:
  81. t[i] += ((theta_L[group_nb] ) / (k*(k-1)) * tgen) / mu
  82. else:
  83. for k in range(j, i-1, -1 ):
  84. if relative_theta_scale:
  85. t[i] += ((theta_L[group_nb] ) / (k*(k-1)))
  86. else:
  87. t[i] += ((theta_L[group_nb] ) / (k*(k-1)) * tgen) / mu
  88. # we add the cumulative times at the end
  89. t[i] += sum_t
  90. sum_t = t[i]
  91. # build the y axis (sizes)
  92. y = []
  93. for theta in theta_L:
  94. if relative_theta_scale:
  95. size = theta
  96. else:
  97. # with size N = theta/4mu
  98. size = theta / (4*mu)
  99. y.append(size)
  100. y.append(size)
  101. # build the time x axis
  102. x = [0]
  103. for time in range(0, len(t.values())-1):
  104. x.append(list(t.values())[time])
  105. x.append(list(t.values())[time])
  106. x.append(list(t.values())[len(t.values())-1])
  107. # if relative_theta_scale:
  108. # # rescale
  109. # #N0 = y[0]
  110. # # for i in range(len(y)):
  111. # # # divide by N0
  112. # # y[i] = y[i]/N0
  113. # # x[i] = x[i]/N0
  114. return x,y,likelihood,sfs,L
  115. def return_x_y_from_stwp_theta_file_as_is(stwp_theta_file, breaks, mu, tgen, relative_theta_scale = False):
  116. with open(stwp_theta_file, "r") as swp_file:
  117. # Read the first line
  118. line = swp_file.readline()
  119. L = float(line.split()[2])
  120. rands = swp_file.readline()
  121. line = swp_file.readline()
  122. # skip empty lines before SFS
  123. while line == "\n":
  124. line = swp_file.readline()
  125. sfs = np.array(line.split()).astype(float)
  126. # Process lines until the end of the file
  127. while line:
  128. # check at each line
  129. if line.startswith("dim") :
  130. dim = int(line.split()[1])
  131. if dim == breaks+1:
  132. likelihood = line.split()[5]
  133. groups = line.split()[6:6+dim]
  134. theta_site = line.split()[6+dim:6+dim+1+dim]
  135. elif dim < breaks+1:
  136. line = swp_file.readline()
  137. continue
  138. elif dim > breaks+1:
  139. break
  140. #return 0,0,0
  141. # Read the next line
  142. line = swp_file.readline()
  143. #### END of parsing
  144. # quit this file if the number of dimensions is incorrect
  145. if dim < breaks+1:
  146. return 0,0
  147. # get n, the last bin of the last group
  148. # revert the list of groups as the most recent times correspond
  149. # to the closest and last leafs of the coal. tree.
  150. groups = groups[::-1]
  151. theta_site = theta_site[::-1]
  152. thetas = {}
  153. for i in range(len(groups)):
  154. groups[i] = groups[i].split(',')
  155. #print(groups[i], len(groups[i]))
  156. thetas[i] = [float(theta_site[i]), groups[i], likelihood]
  157. return thetas, sfs
  158. def plot_k_epochs_thetafolder(folder_path, mu, tgen, breaks = 2, title = "Title", theta_scale = True):
  159. scenari = {}
  160. cpt = 0
  161. for file_name in os.listdir(folder_path):
  162. if os.path.isfile(os.path.join(folder_path, file_name)):
  163. # Perform actions on each file
  164. x,y,likelihood,sfs,L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
  165. tgen = tgen,
  166. mu = mu, relative_theta_scale = theta_scale)
  167. if x == 0 or y == 0:
  168. continue
  169. cpt +=1
  170. scenari[likelihood] = x,y
  171. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
  172. print(cpt, "theta file(s) have been scanned.")
  173. # sort starting by the smallest -log(Likelihood)
  174. print(scenari)
  175. best10_scenari = (sorted(list(scenari.keys())))[:10]
  176. print("10 greatest Likelihoods", best10_scenari)
  177. greatest_likelihood = best10_scenari[0]
  178. x, y = scenari[greatest_likelihood]
  179. my_dpi = 300
  180. plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  181. plt.plot(x, y, 'r-', lw=2, label = 'Lik='+greatest_likelihood)
  182. plt.xlim(1e-3, 1)
  183. plt.ylim(0, 10)
  184. #plt.yscale('log')
  185. plt.xscale('log')
  186. plt.grid(True,which="both", linestyle='--', alpha = 0.3)
  187. for scenario in best10_scenari[1:]:
  188. x,y = scenari[scenario]
  189. #print("\n---- Lik:",scenario,"\n\nt=", x,"\n\nN=",y, "\n\n")
  190. plt.plot(x, y, '--', lw=1, label = 'Lik='+scenario)
  191. if theta_scale:
  192. plt.xlabel("Coal. time")
  193. plt.ylabel("Pop. size scaled by N0")
  194. recent_scale_lower_bound = y[0] * 0.01
  195. recent_scale_upper_bound = y[0] * 0.1
  196. plt.axvline(x=recent_scale_lower_bound)
  197. plt.axvline(x=recent_scale_upper_bound)
  198. else:
  199. # years
  200. plt.xlabel("Time (years)")
  201. plt.ylabel("Individuals (N)")
  202. plt.legend(loc='upper right')
  203. plt.title(title)
  204. plt.savefig(title+'_b'+str(breaks)+'.pdf')
  205. def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_scale = True):
  206. #scenari = {}
  207. cpt = 0
  208. epochs = {}
  209. for file_name in os.listdir(folder_path):
  210. breaks = 0
  211. cpt +=1
  212. if os.path.isfile(os.path.join(folder_path, file_name)):
  213. x, y, likelihood, sfs, L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
  214. tgen = tgen,
  215. mu = mu, relative_theta_scale = theta_scale)
  216. SFS_stored = sfs
  217. L_stored = L
  218. while not (x == 0 and y == 0):
  219. if breaks not in epochs.keys():
  220. epochs[breaks] = {}
  221. epochs[breaks][likelihood] = x,y
  222. breaks += 1
  223. x,y,likelihood,sfs,L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
  224. tgen = tgen,
  225. mu = mu, relative_theta_scale = theta_scale)
  226. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
  227. print(cpt, "theta file(s) have been scanned.")
  228. # intialize figure
  229. my_dpi = 300
  230. plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  231. plt.xlim(1e-3, 1)
  232. #plt.ylim(0, 10)
  233. plt.yscale('log')
  234. plt.xscale('log')
  235. plt.grid(True,which="both", linestyle='--', alpha = 0.3)
  236. brkpt_lik = []
  237. for epoch, scenari in epochs.items():
  238. # sort starting by the smallest -log(Likelihood)
  239. best10_scenari = (sorted(list(scenari.keys())))[:10]
  240. greatest_likelihood = best10_scenari[0]
  241. # store the tuple breakpoints and likelihood for later plot
  242. brkpt_lik.append((epoch, greatest_likelihood))
  243. x, y = scenari[greatest_likelihood]
  244. #without breakpoint
  245. if epoch == 0:
  246. # do something with the theta without bp and skip the plotting
  247. N0 = y[0]
  248. #continue
  249. for i in range(len(y)):
  250. # divide by N0
  251. y[i] = y[i]/N0
  252. x[i] = x[i]/N0
  253. sum_theta_i = 0
  254. print(epoch, x, y)
  255. for i in range(2, len(y)-1):
  256. sum_theta_i=y[i] / (i-1)
  257. prop = []
  258. for k in range(2, len(y)-1):
  259. prop.append(y[k+1] / (k - 1) / sum_theta_i)
  260. #print(epoch, prop)
  261. plt.plot(x, y, 'o', linestyle = "-", alpha=0.75, lw=2, label = str(epoch)+' BrkPt | Lik='+greatest_likelihood)
  262. if theta_scale:
  263. plt.xlabel("Coal. time")
  264. plt.ylabel("Pop. size scaled by N0")
  265. recent_scale_lower_bound = 0.01
  266. recent_scale_upper_bound = 0.1
  267. #print(recent_scale_lower_bound, recent_scale_upper_bound)
  268. plt.axvline(x=recent_scale_lower_bound)
  269. plt.axvline(x=recent_scale_upper_bound)
  270. else:
  271. # years
  272. plt.xlabel("Time (years)")
  273. plt.ylabel("Individuals (N)")
  274. plt.xlim(1e-5, 1)
  275. plt.legend(loc='upper right')
  276. plt.title(title)
  277. plt.savefig(title+'_b'+str(breaks)+'.pdf')
  278. # plot likelihood against nb of breakpoints
  279. # best possible likelihood from SFS
  280. # Segregating sites
  281. S = sum(SFS_stored)
  282. # number of monomorphic sites
  283. L = L_stored
  284. S0 = L-S
  285. print("SFS", SFS_stored)
  286. print("S", S, "L", L, "S0=", S0)
  287. # compute Ln
  288. Ln = log_facto(S+S0) - log_facto(S0) + np.log(float(S0)/(S+S0)) * S0
  289. for xi in range(0, len(SFS_stored)):
  290. p_i = SFS_stored[xi] / float(S+S0)
  291. Ln += np.log(p_i) * SFS_stored[xi] - log_facto(SFS_stored[xi])
  292. res = Ln
  293. print(res)
  294. # basic plot likelihood
  295. plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  296. plt.rcParams['font.size'] = '18'
  297. plt.plot(np.array(brkpt_lik)[:, 0], np.array(brkpt_lik)[:, 1].astype(float), 'o', linestyle = "dotted", lw=2)
  298. # plt.ylim(0,100)
  299. # plt.axhline(y=res)
  300. plt.yscale('log')
  301. plt.xlabel("# breakpoints", fontsize=20)
  302. plt.ylabel("$-\log\mathcal{L}$")
  303. #plt.legend(loc='upper right')
  304. plt.title(title)
  305. plt.savefig(title+'_Breakpts_Likelihood.pdf')
  306. # AIC
  307. plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  308. plt.rcParams['font.size'] = '18'
  309. AIC = 2*(len(brkpt_lik)+1)+2*np.array(brkpt_lik)[:, 1].astype(float)
  310. plt.plot(np.array(brkpt_lik)[:, 0], AIC, 'o', linestyle = "dotted", lw=2)
  311. # plt.axhline(y=106)
  312. plt.yscale('log')
  313. plt.xlabel("# breakpoints", fontsize=20)
  314. plt.ylabel("AIC")
  315. #plt.legend(loc='upper right')
  316. plt.title(title)
  317. plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
  318. def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True, breaks_max = 5):
  319. """
  320. Use theta values as is to do basic plots.
  321. """
  322. cpt = 0
  323. epochs = {}
  324. for file_name in os.listdir(folder_path):
  325. cpt +=1
  326. if os.path.isfile(os.path.join(folder_path, file_name)):
  327. for k in range(breaks_max):
  328. thetas,sfs = return_x_y_from_stwp_theta_file_as_is(folder_path+file_name, breaks = k,
  329. tgen = tgen,
  330. mu = mu, relative_theta_scale = theta_scale)
  331. if thetas == 0:
  332. continue
  333. epochs[k] = thetas
  334. print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(k)+"\n*******\n")
  335. print(cpt, "theta file(s) have been scanned.")
  336. # intialize figure 1
  337. my_dpi = 300
  338. plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  339. for epoch, theta in epochs.items():
  340. groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
  341. x = []
  342. y = []
  343. thetas = np.array(list(theta.values()), dtype=object)[:, 0]
  344. for i,group in enumerate(groups):
  345. x += group[::-1]
  346. y += list(np.repeat(thetas[i], len(group)))
  347. if epoch == 0:
  348. N0 = y[0]
  349. for i in range(len(y)):
  350. y[i] = y[i]/N0
  351. plt.plot(x, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
  352. plt.xlabel("# breaks")
  353. plt.ylabel("theta")
  354. plt.legend(loc='upper right')
  355. plt.savefig(title+'_test'+str(k)+'.pdf')
  356. # fig 2 & 3
  357. plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
  358. for epoch, theta in epochs.items():
  359. groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
  360. x = []
  361. y = []
  362. thetas = np.array(list(theta.values()), dtype=object)[:, 0]
  363. for i,group in enumerate(groups):
  364. x += group[::-1]
  365. y += list(np.repeat(thetas[i], len(group)))
  366. if epoch == 0:
  367. N0 = y[0]
  368. for i in range(len(y)):
  369. y[i] = y[i]/N0
  370. x_2 = []
  371. T = 0
  372. for i in range(len(x)):
  373. x[i] = int(x[i])
  374. # compute the times as: theta_k / (k*(k-1))
  375. for i in range(0, len(x)):
  376. T += y[i] / (x[i]*(x[i]-1))
  377. x_2.append(T)
  378. # Plotting (fig 2)
  379. plt.plot(x_2, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
  380. plt.xlabel("# breaks")
  381. plt.ylabel("theta")
  382. plt.legend(loc='upper right')
  383. plt.savefig(title+'_test'+str(k)+'.pdf')
  384. # Plotting (fig 3) which is the same but log scale for x
  385. plt.plot(x_2, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
  386. plt.xscale('log')
  387. plt.xlabel("# breaks")
  388. plt.ylabel("theta")
  389. plt.legend(loc='upper right')
  390. plt.savefig(title+'_test'+str(k)+'_log.pdf')
  391. def save_multi_image(filename):
  392. pp = PdfPages(filename)
  393. fig_nums = plt.get_fignums()
  394. figs = [plt.figure(n) for n in fig_nums]
  395. for fig in figs:
  396. fig.savefig(pp, format='pdf')
  397. pp.close()
  398. def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale = True):
  399. plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale)
  400. plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks)
  401. save_multi_image(title+"_combined.pdf")
  402. if __name__ == "__main__":
  403. if len(sys.argv) != 4:
  404. print("Need 3 args: ThetaFolder MutationRate GenerationTime")
  405. exit(0)
  406. folder_path = sys.argv[1]
  407. mu = sys.argv[2]
  408. tgen = sys.argv[3]
  409. plot_all_epochs_thetafolder(folder_path, mu, tgen)