2 Commits 14538a747b ... 1848102140

Author SHA1 Message Date
  tforest 1848102140 Fixing some scaling for swp2 output 2 months ago
  tforest 25d7ef0858 Improving sfs plotting 2 months ago
3 changed files with 166 additions and 45 deletions
  1. 4 3
      customgraphics.py
  2. 64 6
      sfs_tools.py
  3. 98 36
      swp2.py

+ 4 - 3
customgraphics.py View File

@@ -200,11 +200,11 @@ def scatter(x, y, ylab=None, xlab=None, title=None):
200 200
         plt.title(title)
201 201
     plt.show()
202 202
 
203
-def barplot(x=None, y=None, ylab=None, xlab=None, title=None, label=None, xticks = None, width=1):
203
+def barplot(x=None, y=None, ylab=None, xlab=None, title=None, label=None, xticks = None, width=1, plot = True):
204 204
     if x:
205 205
         x = list(x)
206 206
         plt.xticks(x)
207
-        plt.bar(x, y, width=width, label=label)
207
+        plt.bar(x, y, width=width, label=label, color="tab:blue")
208 208
     else:
209 209
         x = list(range(len(y)))
210 210
         plt.bar(x, y, width=width, label=label)
@@ -218,7 +218,8 @@ def barplot(x=None, y=None, ylab=None, xlab=None, title=None, label=None, xticks
218 218
     if xticks:
219 219
         plt.xticks(xticks)
220 220
     plt.legend()
221
-    plt.show()
221
+    if plot:
222
+        plt.show()
222 223
 
223 224
 def plot_chrom_continuity(vcf_entries, chr_id, x=None, y=None, outfile = None,
224 225
                           outfolder = None, returned=False, show=True, label=True, step=1, nb_subplots = None,

+ 64 - 6
sfs_tools.py View File

@@ -23,6 +23,54 @@ import matplotlib.pyplot as plt
23 23
 from frst import customgraphics
24 24
 import numpy as np
25 25
 
26
def parse_sfs(sfs_file):
    """
    Parse a Site Frequency Spectrum (SFS) file and return the unmasked bins.

    The file is read as three whitespace-separated lines:

    1. a header made of exactly three fields: the number of bins, a mode
       string and a species name (only the first field is used here);
    2. the spectrum, one integer per bin;
    3. the mask, one integer per bin; a non-zero value excludes the bin.

    No comment or blank line is skipped before the header.
    NOTE(review): the original author describes this layout as dadi's .fs
    format; confirm that the files fed to this function really hold integer
    counts and carry no leading comment lines.

    Parameters:
    - sfs_file (str): The path to the SFS file to be parsed.

    Returns:
    - masked_spectrum (list): The spectrum values whose mask entry is 0,
      in file order, as a list of integers.

    Raises:
    - FileNotFoundError: If the specified SFS file is not found.
    - ValueError: If the header does not hold exactly three fields, if a
      count is not an integer, or if the spectrum or the mask length differs
      from the number of bins announced in the header.
    """
    # Errors are deliberately left to propagate: the previous version printed
    # them and then fell through to `return masked_spectrum`, which was unbound
    # on every failure path and surfaced as a misleading UnboundLocalError.
    with open(sfs_file, 'r') as handle:
        # Header: "<number of bins> <mode> <species name>"
        header_size, _mode, _species_name = handle.readline().strip().split()
        num_individuals = int(header_size)

        # Second line: the spectrum itself
        spectrum_data = [int(token) for token in handle.readline().strip().split()]
        if len(spectrum_data) != num_individuals:
            raise ValueError("Error: Number of bins in the spectrum doesn't match the expected number of individuals.")

        # Third line: the mask, one flag per bin
        mask_data = [int(token) for token in handle.readline().strip().split()]
        if len(mask_data) != num_individuals:
            raise ValueError("Error: Size of the mask doesn't match the number of bins in the spectrum.")

    # Both lists have been checked to be the same length, so zip is lossless.
    masked_spectrum = [value for value, masked in zip(spectrum_data, mask_data) if not masked]
    return masked_spectrum
72
+
73
+
26 74
 def sfs_from_vcf(n, vcf_file, folded = True, diploid = True, phased = False, verbose = False,
27 75
                  strip = False, count_ext = False):
28 76
 
@@ -193,7 +241,7 @@ def sfs_from_parsed_vcf(n, vcf_dict, folded = True, diploid = True, phased = Fal
193 241
     return SFS_values, count_pluriall
194 242
 
195 243
 
196
-def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed = False, normalized = False, ploidy = 2):
244
+def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed = False, normalized = False, ploidy = 2, output = None):
197 245
     sfs_val = []
198 246
     n = len(sfs.values())
199 247
     sum_sites = sum(list(sfs.values()))
@@ -242,7 +290,8 @@ def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed =
242 290
     else:
243 291
         # the spectrum is n-1 long when unfolded
244 292
         n_title = n+1
245
-    
293
+    original_title = title
294
+    # reformat title and add infos
246 295
     title = title+" (n="+str(n_title)+") [folded="+str(folded)+"]"+" [transformed="+str(transformed)+"]"
247 296
     print("SFS =", sfs)
248 297
 
@@ -252,7 +301,7 @@ def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed =
252 301
     if transformed:
253 302
         print("Transformed SFS ( n =",n_title, ") :", sfs_val)
254 303
         #plt.axhline(y=1/n, color='r', linestyle='-')
255
-        plt.bar([x+0.2 for x in list(sfs.keys())], [1/n]*n, color='r', linestyle='-', width = 0.4, label= "H0 Theoric constant")
304
+        plt.bar([x+0.2 for x in list(sfs.keys())], [1/n]*n, fill=False, hatch="///", linestyle='-', width = 0.4, label= "H0 Theoric constant")
256 305
 
257 306
     else:
258 307
         if normalized:
@@ -260,10 +309,19 @@ def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed =
260 309
             sum_expected = sum([(1/(i+1)) for i,x in enumerate(list(sfs.keys()))])
261 310
             expected_y = [(1/(i+1))/sum_expected for i,x in enumerate(list(sfs.keys()))]
262 311
             print(expected_y)
263
-            plt.bar([x+0.2 for x in list(sfs.keys())], expected_y, color='r', linestyle='-', width = 0.4, label= "H0 Theoric constant")
312
+            plt.bar([x+0.2 for x in list(sfs.keys())], expected_y, fill=False, hatch="///", linestyle='-', width = 0.4, label= "H0 Theoric constant")
264 313
             print(sum(expected_y))
265
-    customgraphics.barplot(x = [x-0.2 for x in X_axis], width=0.4, y= sfs_val, xlab = xlab, ylab = ylab, title = title, label = "H1 Observed spectrum", xticks =list(sfs.keys()) )
266
-    plt.show()
314
+    if output is not None:
315
+        # if write in a file, don't open the window dynamically
316
+        plot = False
317
+    else:
318
+        plot = True
319
+    customgraphics.barplot(x = [x-0.2 for x in X_axis], width=0.4, y= sfs_val, xlab = xlab, ylab = ylab, title = title, label = "H1 Observed spectrum", xticks =list(sfs.keys()), plot = plot )
320
+    if output:
321
+       plt.savefig(f"{output}/{original_title}_SFS.pdf")
322
+    else:
323
+        plt.show()
324
+    plt.close()
267 325
 
268 326
 if __name__ == "__main__":
269 327
             

+ 98 - 36
swp2.py View File

@@ -176,7 +176,7 @@ def plot_all_epochs_thetafolder(full_dict, mu, tgen, title = "Title",
176 176
     ax1.set_title(title)
177 177
     breaks = len(full_dict['all_epochs']['plots'])
178 178
     if ax is None:
179
-        plt.savefig(title+'_'+str(breaks+1)+'_epochs.pdf')
179
+        plt.savefig(title+'_best_'+str(breaks+1)+'_epochs.pdf')
180 180
     # plot likelihood against nb of breakpoints
181 181
     if ax is None:
182 182
         fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
@@ -299,7 +299,15 @@ def save_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_sc
299 299
     # number of monomorphic sites
300 300
     S0 = L-S
301 301
     # print("SFS", SFS_stored)
302
-    # print("S", S, "L", L, "S0=", S0)
302
+    print("S", S, "L", L, "S0=", S0)
303
+
304
+    my_n = len(SFS_stored)*2
305
+    print("n=",my_n)
306
+    an = 1
307
+    for i in range(2, my_n):
308
+        an +=1.0/i
309
+    
310
+    print("an=", an, "theta_w", S/an, "theta_w_p_site", (S/an)/L)
303 311
     # compute Ln
304 312
     Ln = log_facto(S+S0) - log_facto(S0) + np.log(float(S0)/(S+S0)) * S0
305 313
     for xi in range(0, len(SFS_stored)):
@@ -413,6 +421,25 @@ def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
413 421
         cumul = val+cumul
414 422
     prop = prop_cumul
415 423
 
424
+
425
+    # print("raw stairs", plots[3])
426
+
427
+
428
+    # ###########
429
+
430
+    # time = []
431
+    # for k in plots[0][0]:
432
+    #     k = int(k)
433
+    #     dt = 2.0/(k*(k-1))
434
+    #     time.append(2.0/(k*(k-1)))
435
+
436
+    # Ne = []
437
+    # for values in plots:
438
+    #     Ne.append(np.array(values[1])/(4*mu))
439
+    # print(time)
440
+    # print(Ne[3])
441
+
442
+    
416 443
     lines_fig2 = []
417 444
     for epoch, theta in best_epochs.items():
418 445
         groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
@@ -423,24 +450,33 @@ def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
423 450
             x += group[::-1]
424 451
             y += list(np.repeat(thetas[i], len(group)))
425 452
             if epoch == 0:
426
-                N0 = y[0]
453
+                # watterson theta
454
+                theta_w = y[0]
427 455
         if theta_scale :
428 456
             for i in range(len(y)):
429 457
                 y[i] = y[i]/N0
458
+        for i in range(len(y)):
459
+            y[i] = y[i]/(4*mu)
430 460
         x_2 = []
431 461
         T = 0
432 462
         for i in range(len(x)):
433 463
             x[i] = int(x[i])
434 464
         # compute the times as: theta_k / (k*(k-1))
435 465
         for i in range(0, len(x)):
436
-            T += y[i] / (x[i]*(x[i]-1))
466
+            T += y[i]*2 / (x[i]*(x[i]-1))
437 467
             x_2.append(T)
438 468
         # Save plotting (fig 2)
439
-        x_2 = [0]+x_2
440
-        y = [y[0]]+y
469
+        # x_2 = [0]+x_2
470
+        # y = [y[0]]+y
441 471
         # x2_plot, y2_plot = plot_straight_x_y(x_2, y)
442 472
         p2 = x_2, y
443 473
         lines_fig2.append(p2)
474
+    # print("breaks=", epoch, "scaled_theta", lines_fig2[10])
475
+    # print(lines_fig2[3][1][0]/(4*mu))
476
+    # print(np.array(lines_fig2[3][1])/lines_fig2[3][1][0])
477
+    # print("size list y=", len(lines_fig2[3][1]))
478
+    #exit(0)
479
+
444 480
     if input == None:
445 481
         saved_plots = {"raw_stairs":plots, "scaled_stairs":lines_fig2,
446 482
                         "prop":prop}
@@ -458,9 +494,9 @@ def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
458 494
     return saved_plots
459 495
 
460 496
 def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax = None, n_ticks = 10, subset = None, theta_scale = False):
461
-    recent_limit_years = 500
497
+    recent_limit_years = 100
462 498
     # recent limit in coal. time
463
-    recent_limit = recent_limit_years/tgen*mu
499
+    recent_limit = recent_limit_years/tgen
464 500
     # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
465 501
     nb_epochs = len(plot_lines)
466 502
     # fig 2 & 3
@@ -480,9 +516,9 @@ def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax =
480 516
     #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
481 517
     if swp2_lines:
482 518
         for k in range(len(swp2_lines[0])):
483
-            swp2_lines[0][k] = swp2_lines[0][k]/tgen*mu
519
+            swp2_lines[0][k] = swp2_lines[0][k]/tgen
484 520
         for k in range(len(swp2_lines[1])):
485
-            swp2_lines[1][k] = swp2_lines[1][k]*4*mu
521
+            swp2_lines[1][k] = swp2_lines[1][k]
486 522
         # x2_plot, y2_plot = plot_straight_x_y(swp2_lines[0],swp2_lines[1])
487 523
         x2_plot, y2_plot = swp2_lines[0], swp2_lines[1]
488 524
         p2, = ax2.plot(x2_plot, y2_plot, linestyle="-", alpha=0.75, lw=2, label = 'swp2', color="black")
@@ -508,14 +544,14 @@ def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax =
508 544
 
509 545
                 # skip the base 0 points x_plot[0:3]
510 546
                 t_max_below_limit = 0
511
-                t_min_below_limit = 1
547
+                t_min_below_limit = recent_limit
512 548
                 recent_change = False
513 549
                 for t in x[1:]:
514 550
                     if t <= recent_limit:
515 551
                         recent_change = True
516 552
                         t_max_below_limit = max(t_max_below_limit, t)
517 553
                         t_min_below_limit = min(t_min_below_limit, t)
518
-                        Ne_max_below_limit = y[x.index(t_max_below_limit)]
554
+                        Ne_max_below_limit = y[min(x.index(t_max_below_limit)+1, len(y)-1)]
519 555
                         Ne_min_below_limit = y[x.index(t_min_below_limit)]
520 556
                 if recent_change:
521 557
                     print(f"\n{breaks} breaks ; This is below the recent limit of {recent_limit_years} years:\n",
@@ -547,6 +583,8 @@ def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax =
547 583
             lines_fig3.append(p3)
548 584
     # put the vertical line of the "recent" time limit
549 585
     ax3.axvline(x=recent_limit, linestyle="--")
586
+    ax3.axvline(x=recent_limit/2, linestyle="--", color="green")
587
+
550 588
     if theta_scale:
551 589
         xlabel = "Theta scaled by N0"
552 590
         ylabel = "Theta scaled by N0"
@@ -557,26 +595,36 @@ def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax =
557 595
         # if not ax, then use the plt syntax, not ax...
558 596
         plt.xlabel(xlabel, fontsize=fnt_size)
559 597
         plt.ylabel(ylabel, fontsize=fnt_size)
560
-        #plt.xlim(left=0)
598
+        plt.gca().set_xlim(0, recent_limit * 3)
599
+        if recent_change:
600
+            plt.ylim(Ne_min_below_limit/3, Ne_max_below_limit *3)
601
+        else:
602
+            plt.ylim(y2_plot[0]/3, y2_plot[0])
603
+        # plt.ylim(0, max(max_y+(max_y*0.05), max(swp2_lines[1])+(max(swp2_lines[1])*0.05)))
604
+        #plt.xlim(0, recent_limit * 3)
561 605
         #xlim_val = plt.gca().get_xlim()
562
-        #x_ticks = list(plt.xticks())[0]
563
-        plt.xlim(min(min_x,min(swp2_lines[0])), max(max(swp2_lines[0]), max_x))
564
-        x_ticks = list(plt.gca().get_xticks())
565
-        plt.gca().set_xticks(x_ticks)
606
+        x_ticks = list(plt.xticks())[0]
607
+        # plt.xlim(min(min_x,min(swp2_lines[0])), max(max(swp2_lines[0]), max_x))
608
+        # x_ticks = list(plt.gca().get_xticks())
609
+        # plt.gca().set_xticks(x_ticks)
566 610
         # plt.xticks(x_ticks)
567 611
         # plt.gca().set_xlim(xlim_val)
568
-        plt.gca().set_xticklabels([f'{k:.0e}\n{k/(mu):.0e}\n{k/(mu)*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
612
+        # plt.gca().set_xticklabels([f'{k:.0e}\n{k/(mu):.0e}\n{k/(mu)*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
613
+        plt.gca().set_xticklabels([f'{k:.1f}\n{k*tgen:.1f}' for k in x_ticks], fontsize = fnt_size*0.5)
614
+
569 615
         # rescale y to effective pop size
570 616
         # ylim_val = plt.gca().get_ylim()
571
-        plt.ylim(min(min_y,min(swp2_lines[1])), max(max_y+(max_y*0.05), max(swp2_lines[1])+(max(swp2_lines[1])*0.05)))        
572
-        y_ticks = list(plt.yticks())[0]
573
-        plt.gca().set_yticks(y_ticks)
617
+        # plt.ylim(min(min_y,min(swp2_lines[1])), max(max_y+(max_y*0.05), max(swp2_lines[1])+(max(swp2_lines[1])*0.05)))        
618
+        # y_ticks = list(plt.yticks())[0]
619
+        # plt.gca().set_yticks(y_ticks)
574 620
         # plt.gca().set_ylim(ylim_val)
575
-        plt.yticks(y_ticks)
576
-        plt.gca().set_yticklabels([f'{k/(4*mu):.0e}' for k in y_ticks], fontsize = fnt_size*0.5)
577
-        plt.title(title, fontsize=fnt_size)
578
-        plt.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
579
-        plt.text(-0.13, -0.135, 'Coal. time\nGen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
621
+        # plt.yticks(y_ticks)
622
+        # plt.gca().set_yticklabels([f'{k/(4*mu):.0e}' for k in y_ticks], fontsize = fnt_size*0.5)
623
+        # plt.title(title, fontsize=fnt_size)
624
+        # plt.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
625
+        # # plt.text(-0.13, -0.135, 'Coal. time\nGen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
626
+        plt.text(-0.13, -0.135, 'Gen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
627
+
580 628
         plt.subplots_adjust(bottom=0.2)  # Adjust the value as needed
581 629
         plt.savefig(title+'_plotB_'+str(nb_epochs)+'_epochs.pdf')
582 630
         # close fig2 to save memory
@@ -594,16 +642,30 @@ def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax =
594 642
     ax3.set_xscale('log')
595 643
     ax3.set_yscale('log')
596 644
     # Scale the x-axis
597
-    x_ticks = list(ax3.get_xticks())
598
-    ax3.set_xticks(x_ticks)
599
-    ax3.set_xlim(min(min(x_ticks), min(swp2_lines[0])), max(max_x, max(swp2_lines[0])))
600
-    ax3.set_xticklabels([f'{k:.0e}\n{k/(mu):.0e}\n{k/(mu)*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
645
+    # x_ticks = list(ax3.get_xticks())
646
+    # ax3.set_xticks(x_ticks)
647
+    # x_ticks = [i for i in range(0.1,max(max_x, max(swp2_lines[0]))), ]
648
+    # ax3.set_xticks(x_ticks)
649
+    ax3.set_xlim(0.1, max(max_x, max(swp2_lines[0])))
650
+    x_ticks = ax3.get_xticks()
651
+    # ax3.set_xlim(min(min(x_ticks), min(swp2_lines[0])), max(max_x, max(swp2_lines[0])))
652
+    # ax3.set_xlim(1, max(max_x, max(swp2_lines[0])))
653
+    # ax3.set_xticklabels([f'{k:.0e}\n{k/(mu):.0e}\n{k/(mu)*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
654
+    # ax3.set_xticklabels([f'{k/(mu):.0e}\n{k/(mu)*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
655
+    ax3.set_xticklabels([f'{k:.0e}\n{k*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
656
+
601 657
     # rescale y to effective pop size
602
-    y_ticks = list(ax3.get_yticks())
603
-    ax3.set_yticks(y_ticks)
604
-    ax3.set_ylim(min(min(y_ticks), min(swp2_lines[1])), max(max_y+(max_y*0.5), max(swp2_lines[1])+(max(swp2_lines[1])*0.5)))
605
-    ax3.set_yticklabels([f'{k/(4*mu):.0e}' for k in y_ticks], fontsize = fnt_size*0.5)
606
-    plt.text(-0.13, -0.135, 'Coal. time\nGen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
658
+    # y_ticks = list(ax3.get_yticks())
659
+    # ax3.set_yticks(y_ticks)
660
+    # ax3.set_ylim(min(min(y_ticks), min(swp2_lines[1])), max(max_y+(max_y*0.5), max(swp2_lines[1])+(max(swp2_lines[1])*0.5)))
661
+    # ax3.set_ylim(1, max(max_y, max(swp2_lines[1])))
662
+    ax3.set_ylim(1, max(max_y+(max_y*0.5), max(swp2_lines[1])+(max(swp2_lines[1])*0.5)))
663
+
664
+    # ax3.set_yticklabels([f'{k/(4*mu):.0e}' for k in y_ticks], fontsize = fnt_size*0.5)
665
+    # plt.text(-0.13, -0.135, 'Coal. time\nGen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
666
+    # plt.text(-0.13, -0.135, 'Gen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
667
+    plt.text(-0.13, -0.085, 'Gen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
668
+    
607 669
     plt.subplots_adjust(bottom=0.2)  # Adjust the value as needed
608 670
     if ax is None:
609 671
         # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
@@ -638,7 +700,7 @@ def plot_raw_stairs(plot_lines, prop, title, ax = None, n_ticks = 10, rescale =
638 700
         x,y = plot
639 701
         x_plot, y_plot = plot_straight_x_y(x,y)
640 702
         p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(breaks)+' brks')
641
-
703
+        print("breaks=", breaks, "theta0", y[0])
642 704
         # add plot to the list of all plots to superimpose
643 705
         plots.append(p)
644 706
     x_ticks = x