Browse Source

First version of the new layout of combined plots

tforest 11 months ago
parent
commit
89813468b5
1 changed files with 89 additions and 58 deletions
  1. 89 58
      swp2.py

+ 89 - 58
swp2.py View File

269
     else:
269
     else:
270
         fnt_size = 12
270
         fnt_size = 12
271
         # plt.rcParams['font.size'] = fnt_size
271
         # plt.rcParams['font.size'] = fnt_size
272
-        ax1 = ax[0,0]
272
+        ax1 = ax[1][0,0]
273
     ax1.set_yscale('log')
273
     ax1.set_yscale('log')
274
     ax1.set_xscale('log')
274
     ax1.set_xscale('log')
275
     ax1.grid(True,which="both", linestyle='--', alpha = 0.3)
275
     ax1.grid(True,which="both", linestyle='--', alpha = 0.3)
346
         # plt.rcParams['font.size'] = fnt_size
346
         # plt.rcParams['font.size'] = fnt_size
347
     else:
347
     else:
348
         #plt.rcParams['font.size'] = fnt_size
348
         #plt.rcParams['font.size'] = fnt_size
349
-        ax2 = ax[2,0]
349
+        ax2 = ax[0][0,1]
350
     ax2.plot(np.array(brkpt_lik)[:, 0], np.array(brkpt_lik)[:, 1].astype(float), 'o', linestyle = "dotted", lw=2)
350
     ax2.plot(np.array(brkpt_lik)[:, 0], np.array(brkpt_lik)[:, 1].astype(float), 'o', linestyle = "dotted", lw=2)
351
     ax2.axhline(y=-Ln, linestyle = "-.", color = "red", label = "$-\log\mathcal{L}$ = "+str(round(-Ln, 2)))
351
     ax2.axhline(y=-Ln, linestyle = "-.", color = "red", label = "$-\log\mathcal{L}$ = "+str(round(-Ln, 2)))
352
     ax2.set_yscale('log')
352
     ax2.set_yscale('log')
362
         # plt.rcParams['font.size'] = '18'
362
         # plt.rcParams['font.size'] = '18'
363
     else:
363
     else:
364
         #plt.rcParams['font.size'] = fnt_size
364
         #plt.rcParams['font.size'] = fnt_size
365
-        ax3 = ax[2,1]
366
-    AIC = 2*(len(brkpt_lik)+1)+2*np.array(brkpt_lik)[:, 1].astype(float)
365
+        ax3 = ax[1][0,1]
366
+    AIC = []
367
+    for brk in np.array(brkpt_lik)[:, 0]:
368
+        brk = int(brk)
369
+        AIC.append((2*brk+1)+2*np.array(brkpt_lik)[brk, 1].astype(float))
367
     ax3.plot(np.array(brkpt_lik)[:, 0], AIC, 'o', linestyle = "dotted", lw=2)
370
     ax3.plot(np.array(brkpt_lik)[:, 0], AIC, 'o', linestyle = "dotted", lw=2)
368
-    AIC_ln = 2*(len(brkpt_lik)+1)-2*Ln
371
+    # AIC = 2*k - 2ln(L) ; where k is the number of parameters, here brks+1
372
+    AIC_ln = 2*(len(brkpt_lik)+1) - 2*Ln
369
     ax3.axhline(y=AIC_ln, linestyle = "-.", color = "red",
373
     ax3.axhline(y=AIC_ln, linestyle = "-.", color = "red",
370
     label = "Min. AIC = "+str(round(AIC_ln, 2)))
374
     label = "Min. AIC = "+str(round(AIC_ln, 2)))
375
+    selected_brks_nb = AIC.index(min(AIC))
371
     ax3.set_yscale('log')
376
     ax3.set_yscale('log')
372
     ax3.set_xlabel("# breakpoints", fontsize=fnt_size)
377
     ax3.set_xlabel("# breakpoints", fontsize=fnt_size)
373
     ax3.set_ylabel("AIC")
378
     ax3.set_ylabel("AIC")
377
         plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
382
         plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
378
     print("S", S)
383
     print("S", S)
379
     # return plots
384
     # return plots
380
-    return ax
385
+    return ax[0], ax[1]
381
 
386
 
382
 def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
387
 def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
383
     breaks_max = 10, output = None):
388
     breaks_max = 10, output = None):
482
         json.dump(saved_plots, json_file)
487
         json.dump(saved_plots, json_file)
483
     return saved_plots
488
     return saved_plots
484
 
489
 
485
-def plot_raw_stairs(plot_lines, plot_lines2, prop, title, ax = None, n_ticks = 10):
490
+def plot_scaled_theta(plot_lines, prop, title, ax = None, n_ticks = 10):
491
+    # fig 2 & 3
492
+    if ax is None:
493
+        my_dpi = 300
494
+        fnt_size = 18
495
+        fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
496
+        fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
497
+    else:
498
+        # plt.rcParams['font.size'] = fnt_size
499
+        fnt_size = 12
500
+        # place of plots on the grid
501
+        ax2 = ax[1,0]
502
+        ax3 = ax[1,1]
503
+    lines_fig2 = []
504
+    lines_fig3 = []
505
+    #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
506
+    for epoch, plot in enumerate(plot_lines):
507
+        x,y=plot
508
+        x2_plot, y2_plot = plot_straight_x_y(x,y)
509
+        p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
510
+        lines_fig2.append(p2)
511
+        # Plotting (fig 3) which is the same but log scale for x
512
+        p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
513
+        lines_fig3.append(p3)
514
+    ax2.set_xlabel("Relative scale", fontsize=fnt_size)
515
+    ax2.set_ylabel("theta", fontsize=fnt_size)
516
+    ax2.set_title(title, fontsize=fnt_size)
517
+    ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
518
+    if ax is None:
519
+        # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
520
+        plt.savefig(title+'_plot2_'+str(len(plot_lines))+'.pdf')
521
+        # close fig2 to save memory
522
+        plt.close(fig2)
523
+    ax3.set_xscale('log')
524
+    ax3.set_yscale('log')
525
+    ax3.set_xlabel("log Relative scale", fontsize=fnt_size)
526
+    ax3.set_ylabel("theta", fontsize=fnt_size)
527
+    ax3.set_title(title, fontsize=fnt_size)
528
+    ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
529
+    if ax is None:
530
+        # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
531
+        plt.savefig(title+'_plot3_'+str(len(plot_lines))+'_log.pdf')
532
+        # close fig3 to save memory
533
+        plt.close(fig3)
534
+    return ax
535
+
536
+def plot_raw_stairs(plot_lines, prop, title, ax = None, n_ticks = 10):
486
     # multiple fig
537
     # multiple fig
487
     if ax is None:
538
     if ax is None:
488
         # intialize figure 1
539
         # intialize figure 1
493
     else:
544
     else:
494
         fnt_size = 12
545
         fnt_size = 12
495
         # plt.rcParams['font.size'] = fnt_size
546
         # plt.rcParams['font.size'] = fnt_size
496
-        ax1 = ax[0, 1]
547
+        ax1 = ax[0, 0]
497
         plt.subplots_adjust(wspace=0.3, hspace=0.3)
548
         plt.subplots_adjust(wspace=0.3, hspace=0.3)
498
     plots = []
549
     plots = []
499
 
550
 
508
     # print(x_ticks)
559
     # print(x_ticks)
509
     #print(prop, "\n", sum(prop))
560
     #print(prop, "\n", sum(prop))
510
     #ax.legend(handles=[p0]+plots)
561
     #ax.legend(handles=[p0]+plots)
511
-    ax1.set_xlabel("# bin", fontsize=fnt_size)
562
+    ax1.set_xlabel("# bin & cumul. prop. of sites", fontsize=fnt_size)
512
     # Set the x-axis locator to reduce the number of ticks to 10
563
     # Set the x-axis locator to reduce the number of ticks to 10
513
     ax1.set_ylabel("theta", fontsize=fnt_size)
564
     ax1.set_ylabel("theta", fontsize=fnt_size)
514
-    ax1.set_title("Title", fontsize=fnt_size)
565
+    ax1.set_title(title, fontsize=fnt_size)
515
     ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
566
     ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
516
     ax1.set_xticks(x_ticks)
567
     ax1.set_xticks(x_ticks)
517
     step = len(x_ticks)//(n_ticks-1)
568
     step = len(x_ticks)//(n_ticks-1)
523
     ax1.set_xticks(values)
574
     ax1.set_xticks(values)
524
     ax1.set_xticklabels([f'{values[k]}\n{val:.2f}' for k, val in enumerate(new_prop)], fontsize = fnt_size*0.8)
575
     ax1.set_xticklabels([f'{values[k]}\n{val:.2f}' for k, val in enumerate(new_prop)], fontsize = fnt_size*0.8)
525
     if ax is None:
576
     if ax is None:
526
-        plt.savefig(title+'_raw'+str(k)+'.pdf')
527
-    # fig 2 & 3
528
-    if ax is None:
529
-        fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
530
-        fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
531
-    else:
532
-        # plt.rcParams['font.size'] = fnt_size
533
-        # place of plots on the grid
534
-        ax2 = ax[1,0]
535
-        ax3 = ax[1,1]
536
-    lines_fig2 = []
537
-    lines_fig3 = []
538
-    #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
539
-    for epoch, plot in enumerate(plot_lines2):
540
-        x,y=plot
541
-        x2_plot, y2_plot = plot_straight_x_y(x,y)
542
-        p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
543
-        lines_fig2.append(p2)
544
-        # Plotting (fig 3) which is the same but log scale for x
545
-        p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
546
-        lines_fig3.append(p3)
547
-    ax2.set_xlabel("Relative scale", fontsize=fnt_size)
548
-    ax2.set_ylabel("theta", fontsize=fnt_size)
549
-    ax2.set_title("Title", fontsize=fnt_size)
550
-    ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
551
-    if ax is None:
552
-        plt.savefig(title+'_plot2_'+str(k)+'.pdf')
553
-    ax3.set_xscale('log')
554
-    ax3.set_yscale('log')
555
-    ax3.set_xlabel("log Relative scale", fontsize=fnt_size)
556
-    ax3.set_ylabel("theta", fontsize=fnt_size)
557
-    ax3.set_title("Title", fontsize=fnt_size)
558
-    ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
559
-    if ax is None:
560
-        plt.savefig(title+'_plot3_'+str(k)+'_log.pdf')
561
-        plt.clf()
577
+        # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
578
+        plt.savefig(title+'_raw'+str(len(plot_lines))+'.pdf')
579
+        plt.close(fig)
562
     # return plots
580
     # return plots
563
     return ax
581
     return ax
564
 
582
 
642
     ax1.set_xlabel("# bin", fontsize=fnt_size)
660
     ax1.set_xlabel("# bin", fontsize=fnt_size)
643
     # Set the x-axis locator to reduce the number of ticks to 10
661
     # Set the x-axis locator to reduce the number of ticks to 10
644
     ax1.set_ylabel("theta", fontsize=fnt_size)
662
     ax1.set_ylabel("theta", fontsize=fnt_size)
645
-    ax1.set_title("Title", fontsize=fnt_size)
663
+    ax1.set_title(title, fontsize=fnt_size)
646
     ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
664
     ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
647
     ax1.set_xticks(x_ticks)
665
     ax1.set_xticks(x_ticks)
648
     if len(prop) >= 18:
666
     if len(prop) >= 18:
699
         lines_fig3.append(p3)
717
         lines_fig3.append(p3)
700
     ax2.set_xlabel("Relative scale", fontsize=fnt_size)
718
     ax2.set_xlabel("Relative scale", fontsize=fnt_size)
701
     ax2.set_ylabel("theta", fontsize=fnt_size)
719
     ax2.set_ylabel("theta", fontsize=fnt_size)
702
-    ax2.set_title("Title", fontsize=fnt_size)
720
+    ax2.set_title(title, fontsize=fnt_size)
703
     ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
721
     ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
704
     if ax is None:
722
     if ax is None:
705
         plt.savefig(title+'_plot2_'+str(k)+'.pdf')
723
         plt.savefig(title+'_plot2_'+str(k)+'.pdf')
707
     ax3.set_yscale('log')
725
     ax3.set_yscale('log')
708
     ax3.set_xlabel("log Relative scale", fontsize=fnt_size)
726
     ax3.set_xlabel("log Relative scale", fontsize=fnt_size)
709
     ax3.set_ylabel("theta", fontsize=fnt_size)
727
     ax3.set_ylabel("theta", fontsize=fnt_size)
710
-    ax3.set_title("Title", fontsize=fnt_size)
728
+    ax3.set_title(title, fontsize=fnt_size)
711
     ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
729
     ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
712
     if ax is None:
730
     if ax is None:
713
         plt.savefig(title+'_plot3_'+str(k)+'_log.pdf')
731
         plt.savefig(title+'_plot3_'+str(k)+'_log.pdf')
724
     # ax = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = axs)
742
     # ax = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = axs)
725
     # ax = plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = axs)
743
     # ax = plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = axs)
726
     # # Adjust layout to prevent clipping of titles
744
     # # Adjust layout to prevent clipping of titles
727
-    # plt.tight_layout()
728
-    # # Adjust absolute space between the top and bottom rows
729
-    # #plt.subplots_adjust(hspace=0.7)  # Adjust this value based on your requirement
745
+    #
746
+
730
     # # Save the entire grid as a single figure
747
     # # Save the entire grid as a single figure
731
     # plt.savefig(title+'_combined.pdf')
748
     # plt.savefig(title+'_combined.pdf')
732
     # plt.clf()
749
     # plt.clf()
735
     # # plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = None)
752
     # # plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = None)
736
     # # plt.clf()
753
     # # plt.clf()
737
     # save_k_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, output = title+"_plotdata.json")
754
     # save_k_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, output = title+"_plotdata.json")
738
-
739
     with open(title+"_plotdata.json", 'r') as json_file:
755
     with open(title+"_plotdata.json", 'r') as json_file:
740
         loaded_data = json.load(json_file)
756
         loaded_data = json.load(json_file)
741
-
757
+    # plot page 1 of summary
742
     fig1, ax1 = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
758
     fig1, ax1 = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
743
-    # plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = ax1)
744
-    ax1 = plot_raw_stairs(plot_lines = loaded_data['raw_stairs'], plot_lines2 = loaded_data['scaled_stairs'],
759
+    # fig1.tight_layout()
760
+    # Adjust absolute space between the top and bottom rows
761
+    fig1.subplots_adjust(hspace=0.35)  # Adjust this value based on your requirement
762
+    # plot page 2 of summary
763
+    fig2, ax2 = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
764
+    # fig2.tight_layout()
765
+    ax1 = plot_raw_stairs(plot_lines = loaded_data['raw_stairs'],
766
+                            prop = loaded_data['prop'], title = title, ax = ax1)
767
+
768
+    ax1 = plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'],
745
                             prop = loaded_data['prop'], title = title, ax = ax1)
769
                             prop = loaded_data['prop'], title = title, ax = ax1)
770
+    ax1, ax2 = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = [ax1, ax2])
771
+    fig1.savefig(title+'_combined_p1.pdf')
772
+    fig2.savefig(title+'_combined_p2.pdf')
773
+    plot_raw_stairs(plot_lines = loaded_data['raw_stairs'],
774
+                            prop = loaded_data['prop'], title = title, ax = None)
775
+    plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'],
776
+                            prop = loaded_data['prop'], title = title, ax = None)
746
 
777
 
747
-    plt.savefig(title+'_raw_scaled.pdf')
748
-    fig1.clf()
778
+    plt.close(fig1)
779
+    plt.close(fig2)
749
 if __name__ == "__main__":
780
 if __name__ == "__main__":
750
 
781
 
751
     if len(sys.argv) != 4:
782
     if len(sys.argv) != 4: