Browse Source

Improve plotting with concordant colours

tforest 2 months ago
parent
commit
03ea6f938b
1 changed file with 46 additions and 37 deletions
  1. 46 37
      swp2.py

+ 46 - 37
swp2.py View File

@@ -176,7 +176,7 @@ def plot_all_epochs_thetafolder(full_dict, mu, tgen, title = "Title",
176 176
     ax1.set_title(title)
177 177
     breaks = len(full_dict['all_epochs']['plots'])
178 178
     if ax is None:
179
-        plt.savefig(title+'_b'+str(breaks)+'.pdf')
179
+        plt.savefig(title+'_'+str(breaks+1)+'_epochs.pdf')
180 180
     # plot likelihood against nb of breakpoints
181 181
     if ax is None:
182 182
         fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
@@ -184,8 +184,12 @@ def plot_all_epochs_thetafolder(full_dict, mu, tgen, title = "Title",
184 184
     else:
185 185
         #plt.rcParams['font.size'] = fnt_size
186 186
         ax2 = ax[0][0,1]
187
-
188
-    ax2.plot(full_dict['Ln_Brks'][0], full_dict['Ln_Brks'][1], 'o', linestyle = "dotted", lw=2)
187
+    # Retrieve the default color cycle from rcParams
188
+    default_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
189
+    # Create an array of colors from the default color cycle
190
+    colors = [default_colors[i % len(default_colors)] for i in range(len(full_dict['Ln_Brks'][0]))]
191
+    ax2.plot(full_dict['Ln_Brks'][0], full_dict['Ln_Brks'][1], "--", lw=1, color="black", zorder=1)
192
+    ax2.scatter(full_dict['Ln_Brks'][0], full_dict['Ln_Brks'][1], s=50, c=colors, marker='o', zorder=2)
189 193
     ax2.axhline(y=full_dict['best_Ln'], linestyle = "-.", color = "red", label = "$-\log\mathcal{L}$ = "+str(round(full_dict['best_Ln'], 2)))
190 194
     ax2.set_yscale('log')
191 195
     ax2.set_xlabel("# breakpoints", fontsize=fnt_size)
@@ -202,7 +206,9 @@ def plot_all_epochs_thetafolder(full_dict, mu, tgen, title = "Title",
202 206
         #plt.rcParams['font.size'] = fnt_size
203 207
         ax3 = ax[1][0,1]
204 208
     AIC = full_dict['AIC_Brks']
205
-    ax3.plot(AIC[0], AIC[1], 'o', linestyle = "dotted", lw=2)
209
+    # ax3.plot(AIC[0], AIC[1], 'o', linestyle = "dotted", lw=2)
210
+    ax3.plot(AIC[0], AIC[1], "--", lw=1, color="black", zorder=1)
211
+    ax3.scatter(AIC[0], AIC[1], s=50, c=colors, marker='o', zorder=2)
206 212
     ax3.axhline(y=full_dict['best_AIC'], linestyle = "-.", color = "red",
207 213
     label = "Min. AIC = "+str(round(full_dict['best_AIC'], 2)))
208 214
     ax3.set_yscale('log')
@@ -341,7 +347,7 @@ def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
341 347
     for file_name in os.listdir(folder_path):
342 348
         cpt +=1
343 349
         if os.path.isfile(os.path.join(folder_path, file_name)):
344
-            for k in range(breaks_max):
350
+            for k in range(breaks_max+1):
345 351
                 x,y,likelihood,thetas,sfs,L = parse_stwp_theta_file(folder_path+file_name, breaks = k,
346 352
                                                                  tgen = tgen,
347 353
                                                                  mu = mu, relative_theta_scale = theta_scale)
@@ -443,6 +449,8 @@ def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
443 449
     return saved_plots
444 450
 
445 451
 def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax = None, n_ticks = 10, subset = None, theta_scale = False):
452
+    # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
453
+    nb_epochs = len(plot_lines)
446 454
     # fig 2 & 3
447 455
     if ax is None:
448 456
         my_dpi = 500
@@ -463,40 +471,34 @@ def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax =
463 471
             swp2_lines[0][k] = swp2_lines[0][k]/tgen*mu
464 472
         for k in range(len(swp2_lines[1])):
465 473
             swp2_lines[1][k] = swp2_lines[1][k]*4*mu
466
-        # plot_lines = [[swp2_lines[0], swp2_lines[1]]]+plot_lines 
467
-
468 474
         x2_plot, y2_plot = plot_straight_x_y(swp2_lines[0],swp2_lines[1])
469
-        p2, = ax2.plot(x2_plot, y2_plot, linestyle="-", alpha=0.75, lw=2, label = 'swp2')
475
+        p2, = ax2.plot(x2_plot, y2_plot, linestyle="-", alpha=0.75, lw=2, label = 'swp2', color="black")
470 476
         lines_fig2.append(p2)
471 477
         # Plotting (fig 3) which is the same but log scale for x
472
-        p3, = ax3.plot(x2_plot, y2_plot, linestyle="-", alpha=0.75, lw=2, label = 'swp2')
478
+        p3, = ax3.plot(x2_plot, y2_plot, linestyle="-", alpha=0.75, lw=2, label = 'swp2', color="black")
473 479
         lines_fig3.append(p3)
474
-    nb_breaks = len(plot_lines)
475 480
     for breaks, plot in enumerate(plot_lines):
476
-        if subset is not None:
477
-            if breaks not in subset :
478
-                # skip if not in subset
479
-                if max(subset) > nb_breaks and breaks == nb_breaks:
480
-                    pass
481
-                else:
482
-                    continue
483 481
         x,y=plot
484
-        # y = [k/(4*mu) for k in y]
485
-        # x = [k/(mu)*tgen for k in x]
486 482
         x2_plot, y2_plot = plot_straight_x_y(x,y)
487
-        p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(breaks)+' brks')
488
-        lines_fig2.append(p2)
483
+        if subset is not None:
484
+            if breaks in subset:
485
+                masking_alpha = 0.75
486
+            else:
487
+                masking_alpha = 0
488
+        p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=masking_alpha, lw=2, label = str(breaks)+' brks')
489 489
         # Plotting (fig 3) which is the same but log scale for x
490
-        p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(breaks)+' brks')
491
-        lines_fig3.append(p3)
492
-
490
+        p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=masking_alpha, lw=2, label = str(breaks)+' brks')
491
+        if subset is not None and breaks in subset:
492
+            # store for legend
493
+            lines_fig2.append(p2)
494
+            lines_fig3.append(p3)
493 495
     ax3.axvline(x=500/tgen*mu, linestyle="--")
494 496
     if theta_scale:
495 497
         xlabel = "Theta scaled by N0"
496 498
         ylabel = "Theta scaled by N0"
497 499
     else:
498
-        xlabel = "Theta scale"
499
-        ylabel = "Theta"
500
+        xlabel = "t"
501
+        ylabel = r"$\theta$"
500 502
     if ax is None:
501 503
         # if not ax, then use the plt syntax, not ax...
502 504
         plt.xlabel(xlabel, fontsize=fnt_size)
@@ -509,8 +511,7 @@ def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax =
509 511
         plt.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
510 512
         plt.text(-0.13, -0.135, 'Coal. time\nGen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
511 513
         plt.subplots_adjust(bottom=0.2)  # Adjust the value as needed
512
-        # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
513
-        plt.savefig(title+'_plot2_'+str(len(plot_lines))+'.pdf')
514
+        plt.savefig(title+'_plotB_'+str(nb_epochs)+'_epochs.pdf')
514 515
         # close fig2 to save memory
515 516
         plt.close(fig2)
516 517
     else:
@@ -533,29 +534,37 @@ def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax =
533 534
     plt.subplots_adjust(bottom=0.2)  # Adjust the value as needed
534 535
     if ax is None:
535 536
         # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
536
-        plt.savefig(title+'_plot3_'+str(len(plot_lines))+'_log.pdf')
537
+        plt.savefig(title+'_plotC_'+str(nb_epochs)+'_epochs_log.pdf')
537 538
         # close fig3 to save memory
538 539
         plt.close(fig3)
539 540
     return ax
540 541
 
541
-def plot_raw_stairs(plot_lines, prop, title, ax = None, n_ticks = 10, rescale = False, subset = None):
542
+def plot_raw_stairs(plot_lines, prop, title, ax = None, n_ticks = 10, rescale = False, subset = None, max_breaks = None):
543
+    if max_breaks:
544
+        nb_breaks = max_breaks
545
+    else:
546
+        nb_breaks = len(plot_lines)+1
542 547
     # multiple fig
543 548
     if ax is None:
544 549
         # intialize figure 1
545
-        my_dpi = 300
550
+        my_dpi = 500
546 551
         fnt_size = 18
547 552
         # plt.rcParams['font.size'] = fnt_size
548 553
         fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
554
+        plt.subplots_adjust(bottom=0.2)  # Adjust the value as needed
549 555
     else:
550 556
         fnt_size = 12
551 557
         # plt.rcParams['font.size'] = fnt_size
552 558
         ax1 = ax[0, 0]
553 559
         plt.subplots_adjust(wspace=0.3, hspace=0.3)
554 560
     plots = []
555
-    for epoch, plot in enumerate(plot_lines):
561
+    for breaks, plot in enumerate(plot_lines):
562
+        if max_breaks and breaks > max_breaks:
563
+            # stop plotting if it exceeds the limit
564
+            continue
556 565
         x,y = plot
557 566
         x_plot, y_plot = plot_straight_x_y(x,y)
558
-        p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
567
+        p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(breaks)+' brks')
559 568
 
560 569
         # add plot to the list of all plots to superimpose
561 570
         plots.append(p)
@@ -565,7 +574,7 @@ def plot_raw_stairs(plot_lines, prop, title, ax = None, n_ticks = 10, rescale =
565 574
     #ax.legend(handles=[p0]+plots)
566 575
     ax1.set_xlabel("# bin & cumul. prop. of sites", fontsize=fnt_size)
567 576
     # Set the x-axis locator to reduce the number of ticks to 10
568
-    ax1.set_ylabel("theta", fontsize=fnt_size)
577
+    ax1.set_ylabel(r'$\theta_k$', fontsize=fnt_size, rotation = 90)
569 578
     ax1.set_title(title, fontsize=fnt_size)
570 579
     ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
571 580
     ax1.set_xticks(x_ticks)
@@ -579,7 +588,7 @@ def plot_raw_stairs(plot_lines, prop, title, ax = None, n_ticks = 10, rescale =
579 588
     ax1.set_xticklabels([f'{values[k]}\n{val:.2f}' for k, val in enumerate(new_prop)], fontsize = fnt_size*0.8)
580 589
     if ax is None:
581 590
         # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
582
-        plt.savefig(title+'_raw'+str(len(plot_lines))+'.pdf')
591
+        plt.savefig(title+'_raw_'+str(nb_breaks)+'_breaks.pdf')
583 592
         plt.close(fig)
584 593
     # return plots
585 594
     return ax
@@ -588,8 +597,8 @@ def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale =
588 597
     my_dpi = 300
589 598
     saved_plots_dict = save_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, output = title+"_plotdata.json")
590 599
     nb_of_epochs = len(saved_plots_dict["all_epochs"]["plots"])
591
-    print(nb_of_epochs)
592 600
     best_epoch = saved_plots_dict["best_epoch_by_AIC"]
601
+    print("Best epoch based on AIC =", best_epoch)
593 602
     save_k_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = nb_of_epochs, input = title+"_plotdata.json", output = title+"_plotdata.json")
594 603
 
595 604
     with open(title+"_plotdata.json", 'r') as json_file:
@@ -628,7 +637,7 @@ def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale =
628 637
     swp2_x, swp2_y = swp2_vals[0], swp2_vals[1]
629 638
     # End of Parsing real swp2 output
630 639
     plot_raw_stairs(plot_lines = loaded_data['raw_stairs'],
631
-                            prop = loaded_data['prop'], title = title, ax = None)
640
+                            prop = loaded_data['prop'], title = title, ax = None, max_breaks = breaks)
632 641
     plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'], mu = mu, tgen = tgen, subset=[loaded_data['best_epoch_by_AIC']]+selected_breaks,
633 642
     # plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'], subset=list(range(0,3))+[loaded_data['best_epoch_by_AIC']]+selected_breaks,
634 643
                             prop = loaded_data['prop'], title = title, swp2_lines = [swp2_x, swp2_y], ax = None)