Browse Source

Fixing some scaling for swp2 output

tforest 2 months ago
parent
commit
1848102140
1 changed files with 98 additions and 36 deletions
  1. 98 36
      swp2.py

+ 98 - 36
swp2.py View File

@@ -176,7 +176,7 @@ def plot_all_epochs_thetafolder(full_dict, mu, tgen, title = "Title",
176 176
     ax1.set_title(title)
177 177
     breaks = len(full_dict['all_epochs']['plots'])
178 178
     if ax is None:
179
-        plt.savefig(title+'_'+str(breaks+1)+'_epochs.pdf')
179
+        plt.savefig(title+'_best_'+str(breaks+1)+'_epochs.pdf')
180 180
     # plot likelihood against nb of breakpoints
181 181
     if ax is None:
182 182
         fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
@@ -299,7 +299,15 @@ def save_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_sc
299 299
     # number of monomorphic sites
300 300
     S0 = L-S
301 301
     # print("SFS", SFS_stored)
302
-    # print("S", S, "L", L, "S0=", S0)
302
+    print("S", S, "L", L, "S0=", S0)
303
+
304
+    my_n = len(SFS_stored)*2
305
+    print("n=",my_n)
306
+    an = 1
307
+    for i in range(2, my_n):
308
+        an +=1.0/i
309
+    
310
+    print("an=", an, "theta_w", S/an, "theta_w_p_site", (S/an)/L)
303 311
     # compute Ln
304 312
     Ln = log_facto(S+S0) - log_facto(S0) + np.log(float(S0)/(S+S0)) * S0
305 313
     for xi in range(0, len(SFS_stored)):
@@ -413,6 +421,25 @@ def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
413 421
         cumul = val+cumul
414 422
     prop = prop_cumul
415 423
 
424
+
425
+    # print("raw stairs", plots[3])
426
+
427
+
428
+    # ###########
429
+
430
+    # time = []
431
+    # for k in plots[0][0]:
432
+    #     k = int(k)
433
+    #     dt = 2.0/(k*(k-1))
434
+    #     time.append(2.0/(k*(k-1)))
435
+
436
+    # Ne = []
437
+    # for values in plots:
438
+    #     Ne.append(np.array(values[1])/(4*mu))
439
+    # print(time)
440
+    # print(Ne[3])
441
+
442
+    
416 443
     lines_fig2 = []
417 444
     for epoch, theta in best_epochs.items():
418 445
         groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
@@ -423,24 +450,33 @@ def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
423 450
             x += group[::-1]
424 451
             y += list(np.repeat(thetas[i], len(group)))
425 452
             if epoch == 0:
426
-                N0 = y[0]
453
+                # watterson theta
454
+                theta_w = y[0]
427 455
         if theta_scale :
428 456
             for i in range(len(y)):
429 457
                 y[i] = y[i]/N0
458
+        for i in range(len(y)):
459
+            y[i] = y[i]/(4*mu)
430 460
         x_2 = []
431 461
         T = 0
432 462
         for i in range(len(x)):
433 463
             x[i] = int(x[i])
434 464
         # compute the times as: theta_k / (k*(k-1))
435 465
         for i in range(0, len(x)):
436
-            T += y[i] / (x[i]*(x[i]-1))
466
+            T += y[i]*2 / (x[i]*(x[i]-1))
437 467
             x_2.append(T)
438 468
         # Save plotting (fig 2)
439
-        x_2 = [0]+x_2
440
-        y = [y[0]]+y
469
+        # x_2 = [0]+x_2
470
+        # y = [y[0]]+y
441 471
         # x2_plot, y2_plot = plot_straight_x_y(x_2, y)
442 472
         p2 = x_2, y
443 473
         lines_fig2.append(p2)
474
+    # print("breaks=", epoch, "scaled_theta", lines_fig2[10])
475
+    # print(lines_fig2[3][1][0]/(4*mu))
476
+    # print(np.array(lines_fig2[3][1])/lines_fig2[3][1][0])
477
+    # print("size list y=", len(lines_fig2[3][1]))
478
+    #exit(0)
479
+
444 480
     if input == None:
445 481
         saved_plots = {"raw_stairs":plots, "scaled_stairs":lines_fig2,
446 482
                         "prop":prop}
@@ -458,9 +494,9 @@ def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
458 494
     return saved_plots
459 495
 
460 496
 def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax = None, n_ticks = 10, subset = None, theta_scale = False):
461
-    recent_limit_years = 500
497
+    recent_limit_years = 100
462 498
     # recent limit in coal. time
463
-    recent_limit = recent_limit_years/tgen*mu
499
+    recent_limit = recent_limit_years/tgen
464 500
     # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
465 501
     nb_epochs = len(plot_lines)
466 502
     # fig 2 & 3
@@ -480,9 +516,9 @@ def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax =
480 516
     #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
481 517
     if swp2_lines:
482 518
         for k in range(len(swp2_lines[0])):
483
-            swp2_lines[0][k] = swp2_lines[0][k]/tgen*mu
519
+            swp2_lines[0][k] = swp2_lines[0][k]/tgen
484 520
         for k in range(len(swp2_lines[1])):
485
-            swp2_lines[1][k] = swp2_lines[1][k]*4*mu
521
+            swp2_lines[1][k] = swp2_lines[1][k]
486 522
         # x2_plot, y2_plot = plot_straight_x_y(swp2_lines[0],swp2_lines[1])
487 523
         x2_plot, y2_plot = swp2_lines[0], swp2_lines[1]
488 524
         p2, = ax2.plot(x2_plot, y2_plot, linestyle="-", alpha=0.75, lw=2, label = 'swp2', color="black")
@@ -508,14 +544,14 @@ def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax =
508 544
 
509 545
                 # skip the base 0 points x_plot[0:3]
510 546
                 t_max_below_limit = 0
511
-                t_min_below_limit = 1
547
+                t_min_below_limit = recent_limit
512 548
                 recent_change = False
513 549
                 for t in x[1:]:
514 550
                     if t <= recent_limit:
515 551
                         recent_change = True
516 552
                         t_max_below_limit = max(t_max_below_limit, t)
517 553
                         t_min_below_limit = min(t_min_below_limit, t)
518
-                        Ne_max_below_limit = y[x.index(t_max_below_limit)]
554
+                        Ne_max_below_limit = y[min(x.index(t_max_below_limit)+1, len(y)-1)]
519 555
                         Ne_min_below_limit = y[x.index(t_min_below_limit)]
520 556
                 if recent_change:
521 557
                     print(f"\n{breaks} breaks ; This is below the recent limit of {recent_limit_years} years:\n",
@@ -547,6 +583,8 @@ def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax =
547 583
             lines_fig3.append(p3)
548 584
     # put the vertical line of the "recent" time limit
549 585
     ax3.axvline(x=recent_limit, linestyle="--")
586
+    ax3.axvline(x=recent_limit/2, linestyle="--", color="green")
587
+
550 588
     if theta_scale:
551 589
         xlabel = "Theta scaled by N0"
552 590
         ylabel = "Theta scaled by N0"
@@ -557,26 +595,36 @@ def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax =
557 595
         # if not ax, then use the plt syntax, not ax...
558 596
         plt.xlabel(xlabel, fontsize=fnt_size)
559 597
         plt.ylabel(ylabel, fontsize=fnt_size)
560
-        #plt.xlim(left=0)
598
+        plt.gca().set_xlim(0, recent_limit * 3)
599
+        if recent_change:
600
+            plt.ylim(Ne_min_below_limit/3, Ne_max_below_limit *3)
601
+        else:
602
+            plt.ylim(y2_plot[0]/3, y2_plot[0])
603
+        # plt.ylim(0, max(max_y+(max_y*0.05), max(swp2_lines[1])+(max(swp2_lines[1])*0.05)))
604
+        #plt.xlim(0, recent_limit * 3)
561 605
         #xlim_val = plt.gca().get_xlim()
562
-        #x_ticks = list(plt.xticks())[0]
563
-        plt.xlim(min(min_x,min(swp2_lines[0])), max(max(swp2_lines[0]), max_x))
564
-        x_ticks = list(plt.gca().get_xticks())
565
-        plt.gca().set_xticks(x_ticks)
606
+        x_ticks = list(plt.xticks())[0]
607
+        # plt.xlim(min(min_x,min(swp2_lines[0])), max(max(swp2_lines[0]), max_x))
608
+        # x_ticks = list(plt.gca().get_xticks())
609
+        # plt.gca().set_xticks(x_ticks)
566 610
         # plt.xticks(x_ticks)
567 611
         # plt.gca().set_xlim(xlim_val)
568
-        plt.gca().set_xticklabels([f'{k:.0e}\n{k/(mu):.0e}\n{k/(mu)*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
612
+        # plt.gca().set_xticklabels([f'{k:.0e}\n{k/(mu):.0e}\n{k/(mu)*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
613
+        plt.gca().set_xticklabels([f'{k:.1f}\n{k*tgen:.1f}' for k in x_ticks], fontsize = fnt_size*0.5)
614
+
569 615
         # rescale y to effective pop size
570 616
         # ylim_val = plt.gca().get_ylim()
571
-        plt.ylim(min(min_y,min(swp2_lines[1])), max(max_y+(max_y*0.05), max(swp2_lines[1])+(max(swp2_lines[1])*0.05)))        
572
-        y_ticks = list(plt.yticks())[0]
573
-        plt.gca().set_yticks(y_ticks)
617
+        # plt.ylim(min(min_y,min(swp2_lines[1])), max(max_y+(max_y*0.05), max(swp2_lines[1])+(max(swp2_lines[1])*0.05)))        
618
+        # y_ticks = list(plt.yticks())[0]
619
+        # plt.gca().set_yticks(y_ticks)
574 620
         # plt.gca().set_ylim(ylim_val)
575
-        plt.yticks(y_ticks)
576
-        plt.gca().set_yticklabels([f'{k/(4*mu):.0e}' for k in y_ticks], fontsize = fnt_size*0.5)
577
-        plt.title(title, fontsize=fnt_size)
578
-        plt.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
579
-        plt.text(-0.13, -0.135, 'Coal. time\nGen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
621
+        # plt.yticks(y_ticks)
622
+        # plt.gca().set_yticklabels([f'{k/(4*mu):.0e}' for k in y_ticks], fontsize = fnt_size*0.5)
623
+        # plt.title(title, fontsize=fnt_size)
624
+        # plt.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
625
+        # # plt.text(-0.13, -0.135, 'Coal. time\nGen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
626
+        plt.text(-0.13, -0.135, 'Gen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
627
+
580 628
         plt.subplots_adjust(bottom=0.2)  # Adjust the value as needed
581 629
         plt.savefig(title+'_plotB_'+str(nb_epochs)+'_epochs.pdf')
582 630
         # close fig2 to save memory
@@ -594,16 +642,30 @@ def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax =
594 642
     ax3.set_xscale('log')
595 643
     ax3.set_yscale('log')
596 644
     # Scale the x-axis
597
-    x_ticks = list(ax3.get_xticks())
598
-    ax3.set_xticks(x_ticks)
599
-    ax3.set_xlim(min(min(x_ticks), min(swp2_lines[0])), max(max_x, max(swp2_lines[0])))
600
-    ax3.set_xticklabels([f'{k:.0e}\n{k/(mu):.0e}\n{k/(mu)*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
645
+    # x_ticks = list(ax3.get_xticks())
646
+    # ax3.set_xticks(x_ticks)
647
+    # x_ticks = [i for i in range(0.1,max(max_x, max(swp2_lines[0]))), ]
648
+    # ax3.set_xticks(x_ticks)
649
+    ax3.set_xlim(0.1, max(max_x, max(swp2_lines[0])))
650
+    x_ticks = ax3.get_xticks()
651
+    # ax3.set_xlim(min(min(x_ticks), min(swp2_lines[0])), max(max_x, max(swp2_lines[0])))
652
+    # ax3.set_xlim(1, max(max_x, max(swp2_lines[0])))
653
+    # ax3.set_xticklabels([f'{k:.0e}\n{k/(mu):.0e}\n{k/(mu)*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
654
+    # ax3.set_xticklabels([f'{k/(mu):.0e}\n{k/(mu)*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
655
+    ax3.set_xticklabels([f'{k:.0e}\n{k*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
656
+
601 657
     # rescale y to effective pop size
602
-    y_ticks = list(ax3.get_yticks())
603
-    ax3.set_yticks(y_ticks)
604
-    ax3.set_ylim(min(min(y_ticks), min(swp2_lines[1])), max(max_y+(max_y*0.5), max(swp2_lines[1])+(max(swp2_lines[1])*0.5)))
605
-    ax3.set_yticklabels([f'{k/(4*mu):.0e}' for k in y_ticks], fontsize = fnt_size*0.5)
606
-    plt.text(-0.13, -0.135, 'Coal. time\nGen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
658
+    # y_ticks = list(ax3.get_yticks())
659
+    # ax3.set_yticks(y_ticks)
660
+    # ax3.set_ylim(min(min(y_ticks), min(swp2_lines[1])), max(max_y+(max_y*0.5), max(swp2_lines[1])+(max(swp2_lines[1])*0.5)))
661
+    # ax3.set_ylim(1, max(max_y, max(swp2_lines[1])))
662
+    ax3.set_ylim(1, max(max_y+(max_y*0.5), max(swp2_lines[1])+(max(swp2_lines[1])*0.5)))
663
+
664
+    # ax3.set_yticklabels([f'{k/(4*mu):.0e}' for k in y_ticks], fontsize = fnt_size*0.5)
665
+    # plt.text(-0.13, -0.135, 'Coal. time\nGen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
666
+    # plt.text(-0.13, -0.135, 'Gen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
667
+    plt.text(-0.13, -0.085, 'Gen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
668
+    
607 669
     plt.subplots_adjust(bottom=0.2)  # Adjust the value as needed
608 670
     if ax is None:
609 671
         # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
@@ -638,7 +700,7 @@ def plot_raw_stairs(plot_lines, prop, title, ax = None, n_ticks = 10, rescale =
638 700
         x,y = plot
639 701
         x_plot, y_plot = plot_straight_x_y(x,y)
640 702
         p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(breaks)+' brks')
641
-
703
+        print("breaks=", breaks, "theta0", y[0])
642 704
         # add plot to the list of all plots to superimpose
643 705
         plots.append(p)
644 706
     x_ticks = x