Browse Source

Improve plotting with concordant colours

tforest 9 months ago
parent
commit
03ea6f938b
1 changed file with 46 additions and 37 deletions
  1. 46 37
      swp2.py

+ 46 - 37
swp2.py View File

176
     ax1.set_title(title)
176
     ax1.set_title(title)
177
     breaks = len(full_dict['all_epochs']['plots'])
177
     breaks = len(full_dict['all_epochs']['plots'])
178
     if ax is None:
178
     if ax is None:
179
-        plt.savefig(title+'_b'+str(breaks)+'.pdf')
179
+        plt.savefig(title+'_'+str(breaks+1)+'_epochs.pdf')
180
     # plot likelihood against nb of breakpoints
180
     # plot likelihood against nb of breakpoints
181
     if ax is None:
181
     if ax is None:
182
         fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
182
         fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
184
     else:
184
     else:
185
         #plt.rcParams['font.size'] = fnt_size
185
         #plt.rcParams['font.size'] = fnt_size
186
         ax2 = ax[0][0,1]
186
         ax2 = ax[0][0,1]
187
-
188
-    ax2.plot(full_dict['Ln_Brks'][0], full_dict['Ln_Brks'][1], 'o', linestyle = "dotted", lw=2)
187
+    # Retrieve the default color cycle from rcParams
188
+    default_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
189
+    # Create an array of colors from the default color cycle
190
+    colors = [default_colors[i % len(default_colors)] for i in range(len(full_dict['Ln_Brks'][0]))]
191
+    ax2.plot(full_dict['Ln_Brks'][0], full_dict['Ln_Brks'][1], "--", lw=1, color="black", zorder=1)
192
+    ax2.scatter(full_dict['Ln_Brks'][0], full_dict['Ln_Brks'][1], s=50, c=colors, marker='o', zorder=2)
189
     ax2.axhline(y=full_dict['best_Ln'], linestyle = "-.", color = "red", label = "$-\log\mathcal{L}$ = "+str(round(full_dict['best_Ln'], 2)))
193
     ax2.axhline(y=full_dict['best_Ln'], linestyle = "-.", color = "red", label = "$-\log\mathcal{L}$ = "+str(round(full_dict['best_Ln'], 2)))
190
     ax2.set_yscale('log')
194
     ax2.set_yscale('log')
191
     ax2.set_xlabel("# breakpoints", fontsize=fnt_size)
195
     ax2.set_xlabel("# breakpoints", fontsize=fnt_size)
202
         #plt.rcParams['font.size'] = fnt_size
206
         #plt.rcParams['font.size'] = fnt_size
203
         ax3 = ax[1][0,1]
207
         ax3 = ax[1][0,1]
204
     AIC = full_dict['AIC_Brks']
208
     AIC = full_dict['AIC_Brks']
205
-    ax3.plot(AIC[0], AIC[1], 'o', linestyle = "dotted", lw=2)
209
+    # ax3.plot(AIC[0], AIC[1], 'o', linestyle = "dotted", lw=2)
210
+    ax3.plot(AIC[0], AIC[1], "--", lw=1, color="black", zorder=1)
211
+    ax3.scatter(AIC[0], AIC[1], s=50, c=colors, marker='o', zorder=2)
206
     ax3.axhline(y=full_dict['best_AIC'], linestyle = "-.", color = "red",
212
     ax3.axhline(y=full_dict['best_AIC'], linestyle = "-.", color = "red",
207
     label = "Min. AIC = "+str(round(full_dict['best_AIC'], 2)))
213
     label = "Min. AIC = "+str(round(full_dict['best_AIC'], 2)))
208
     ax3.set_yscale('log')
214
     ax3.set_yscale('log')
341
     for file_name in os.listdir(folder_path):
347
     for file_name in os.listdir(folder_path):
342
         cpt +=1
348
         cpt +=1
343
         if os.path.isfile(os.path.join(folder_path, file_name)):
349
         if os.path.isfile(os.path.join(folder_path, file_name)):
344
-            for k in range(breaks_max):
350
+            for k in range(breaks_max+1):
345
                 x,y,likelihood,thetas,sfs,L = parse_stwp_theta_file(folder_path+file_name, breaks = k,
351
                 x,y,likelihood,thetas,sfs,L = parse_stwp_theta_file(folder_path+file_name, breaks = k,
346
                                                                  tgen = tgen,
352
                                                                  tgen = tgen,
347
                                                                  mu = mu, relative_theta_scale = theta_scale)
353
                                                                  mu = mu, relative_theta_scale = theta_scale)
443
     return saved_plots
449
     return saved_plots
444
 
450
 
445
 def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax = None, n_ticks = 10, subset = None, theta_scale = False):
451
 def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax = None, n_ticks = 10, subset = None, theta_scale = False):
452
+    # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
453
+    nb_epochs = len(plot_lines)
446
     # fig 2 & 3
454
     # fig 2 & 3
447
     if ax is None:
455
     if ax is None:
448
         my_dpi = 500
456
         my_dpi = 500
463
             swp2_lines[0][k] = swp2_lines[0][k]/tgen*mu
471
             swp2_lines[0][k] = swp2_lines[0][k]/tgen*mu
464
         for k in range(len(swp2_lines[1])):
472
         for k in range(len(swp2_lines[1])):
465
             swp2_lines[1][k] = swp2_lines[1][k]*4*mu
473
             swp2_lines[1][k] = swp2_lines[1][k]*4*mu
466
-        # plot_lines = [[swp2_lines[0], swp2_lines[1]]]+plot_lines 
467
-
468
         x2_plot, y2_plot = plot_straight_x_y(swp2_lines[0],swp2_lines[1])
474
         x2_plot, y2_plot = plot_straight_x_y(swp2_lines[0],swp2_lines[1])
469
-        p2, = ax2.plot(x2_plot, y2_plot, linestyle="-", alpha=0.75, lw=2, label = 'swp2')
475
+        p2, = ax2.plot(x2_plot, y2_plot, linestyle="-", alpha=0.75, lw=2, label = 'swp2', color="black")
470
         lines_fig2.append(p2)
476
         lines_fig2.append(p2)
471
         # Plotting (fig 3) which is the same but log scale for x
477
         # Plotting (fig 3) which is the same but log scale for x
472
-        p3, = ax3.plot(x2_plot, y2_plot, linestyle="-", alpha=0.75, lw=2, label = 'swp2')
478
+        p3, = ax3.plot(x2_plot, y2_plot, linestyle="-", alpha=0.75, lw=2, label = 'swp2', color="black")
473
         lines_fig3.append(p3)
479
         lines_fig3.append(p3)
474
-    nb_breaks = len(plot_lines)
475
     for breaks, plot in enumerate(plot_lines):
480
     for breaks, plot in enumerate(plot_lines):
476
-        if subset is not None:
477
-            if breaks not in subset :
478
-                # skip if not in subset
479
-                if max(subset) > nb_breaks and breaks == nb_breaks:
480
-                    pass
481
-                else:
482
-                    continue
483
         x,y=plot
481
         x,y=plot
484
-        # y = [k/(4*mu) for k in y]
485
-        # x = [k/(mu)*tgen for k in x]
486
         x2_plot, y2_plot = plot_straight_x_y(x,y)
482
         x2_plot, y2_plot = plot_straight_x_y(x,y)
487
-        p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(breaks)+' brks')
488
-        lines_fig2.append(p2)
483
+        if subset is not None:
484
+            if breaks in subset:
485
+                masking_alpha = 0.75
486
+            else:
487
+                masking_alpha = 0
488
+        p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=masking_alpha, lw=2, label = str(breaks)+' brks')
489
         # Plotting (fig 3) which is the same but log scale for x
489
         # Plotting (fig 3) which is the same but log scale for x
490
-        p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(breaks)+' brks')
491
-        lines_fig3.append(p3)
492
-
490
+        p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=masking_alpha, lw=2, label = str(breaks)+' brks')
491
+        if subset is not None and breaks in subset:
492
+            # store for legend
493
+            lines_fig2.append(p2)
494
+            lines_fig3.append(p3)
493
     ax3.axvline(x=500/tgen*mu, linestyle="--")
495
     ax3.axvline(x=500/tgen*mu, linestyle="--")
494
     if theta_scale:
496
     if theta_scale:
495
         xlabel = "Theta scaled by N0"
497
         xlabel = "Theta scaled by N0"
496
         ylabel = "Theta scaled by N0"
498
         ylabel = "Theta scaled by N0"
497
     else:
499
     else:
498
-        xlabel = "Theta scale"
499
-        ylabel = "Theta"
500
+        xlabel = "t"
501
+        ylabel = r"$\theta$"
500
     if ax is None:
502
     if ax is None:
501
         # if not ax, then use the plt syntax, not ax...
503
         # if not ax, then use the plt syntax, not ax...
502
         plt.xlabel(xlabel, fontsize=fnt_size)
504
         plt.xlabel(xlabel, fontsize=fnt_size)
509
         plt.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
511
         plt.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
510
         plt.text(-0.13, -0.135, 'Coal. time\nGen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
512
         plt.text(-0.13, -0.135, 'Coal. time\nGen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
511
         plt.subplots_adjust(bottom=0.2)  # Adjust the value as needed
513
         plt.subplots_adjust(bottom=0.2)  # Adjust the value as needed
512
-        # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
513
-        plt.savefig(title+'_plot2_'+str(len(plot_lines))+'.pdf')
514
+        plt.savefig(title+'_plotB_'+str(nb_epochs)+'_epochs.pdf')
514
         # close fig2 to save memory
515
         # close fig2 to save memory
515
         plt.close(fig2)
516
         plt.close(fig2)
516
     else:
517
     else:
533
     plt.subplots_adjust(bottom=0.2)  # Adjust the value as needed
534
     plt.subplots_adjust(bottom=0.2)  # Adjust the value as needed
534
     if ax is None:
535
     if ax is None:
535
         # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
536
         # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
536
-        plt.savefig(title+'_plot3_'+str(len(plot_lines))+'_log.pdf')
537
+        plt.savefig(title+'_plotC_'+str(nb_epochs)+'_epochs_log.pdf')
537
         # close fig3 to save memory
538
         # close fig3 to save memory
538
         plt.close(fig3)
539
         plt.close(fig3)
539
     return ax
540
     return ax
540
 
541
 
541
-def plot_raw_stairs(plot_lines, prop, title, ax = None, n_ticks = 10, rescale = False, subset = None):
542
+def plot_raw_stairs(plot_lines, prop, title, ax = None, n_ticks = 10, rescale = False, subset = None, max_breaks = None):
543
+    if max_breaks:
544
+        nb_breaks = max_breaks
545
+    else:
546
+        nb_breaks = len(plot_lines)+1
542
     # multiple fig
547
     # multiple fig
543
     if ax is None:
548
     if ax is None:
544
         # initialize figure 1
549
         # initialize figure 1
545
-        my_dpi = 300
550
+        my_dpi = 500
546
         fnt_size = 18
551
         fnt_size = 18
547
         # plt.rcParams['font.size'] = fnt_size
552
         # plt.rcParams['font.size'] = fnt_size
548
         fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
553
         fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
554
+        plt.subplots_adjust(bottom=0.2)  # Adjust the value as needed
549
     else:
555
     else:
550
         fnt_size = 12
556
         fnt_size = 12
551
         # plt.rcParams['font.size'] = fnt_size
557
         # plt.rcParams['font.size'] = fnt_size
552
         ax1 = ax[0, 0]
558
         ax1 = ax[0, 0]
553
         plt.subplots_adjust(wspace=0.3, hspace=0.3)
559
         plt.subplots_adjust(wspace=0.3, hspace=0.3)
554
     plots = []
560
     plots = []
555
-    for epoch, plot in enumerate(plot_lines):
561
+    for breaks, plot in enumerate(plot_lines):
562
+        if max_breaks and breaks > max_breaks:
563
+            # stop plotting if it exceeds the limit
564
+            continue
556
         x,y = plot
565
         x,y = plot
557
         x_plot, y_plot = plot_straight_x_y(x,y)
566
         x_plot, y_plot = plot_straight_x_y(x,y)
558
-        p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
567
+        p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(breaks)+' brks')
559
 
568
 
560
         # add plot to the list of all plots to superimpose
569
         # add plot to the list of all plots to superimpose
561
         plots.append(p)
570
         plots.append(p)
565
     #ax.legend(handles=[p0]+plots)
574
     #ax.legend(handles=[p0]+plots)
566
     ax1.set_xlabel("# bin & cumul. prop. of sites", fontsize=fnt_size)
575
     ax1.set_xlabel("# bin & cumul. prop. of sites", fontsize=fnt_size)
567
     # Set the x-axis locator to reduce the number of ticks to 10
576
     # Set the x-axis locator to reduce the number of ticks to 10
568
-    ax1.set_ylabel("theta", fontsize=fnt_size)
577
+    ax1.set_ylabel(r'$\theta_k$', fontsize=fnt_size, rotation = 90)
569
     ax1.set_title(title, fontsize=fnt_size)
578
     ax1.set_title(title, fontsize=fnt_size)
570
     ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
579
     ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
571
     ax1.set_xticks(x_ticks)
580
     ax1.set_xticks(x_ticks)
579
     ax1.set_xticklabels([f'{values[k]}\n{val:.2f}' for k, val in enumerate(new_prop)], fontsize = fnt_size*0.8)
588
     ax1.set_xticklabels([f'{values[k]}\n{val:.2f}' for k, val in enumerate(new_prop)], fontsize = fnt_size*0.8)
580
     if ax is None:
589
     if ax is None:
581
         # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
590
         # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
582
-        plt.savefig(title+'_raw'+str(len(plot_lines))+'.pdf')
591
+        plt.savefig(title+'_raw_'+str(nb_breaks)+'_breaks.pdf')
583
         plt.close(fig)
592
         plt.close(fig)
584
     # return plots
593
     # return plots
585
     return ax
594
     return ax
588
     my_dpi = 300
597
     my_dpi = 300
589
     saved_plots_dict = save_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, output = title+"_plotdata.json")
598
     saved_plots_dict = save_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, output = title+"_plotdata.json")
590
     nb_of_epochs = len(saved_plots_dict["all_epochs"]["plots"])
599
     nb_of_epochs = len(saved_plots_dict["all_epochs"]["plots"])
591
-    print(nb_of_epochs)
592
     best_epoch = saved_plots_dict["best_epoch_by_AIC"]
600
     best_epoch = saved_plots_dict["best_epoch_by_AIC"]
601
+    print("Best epoch based on AIC =", best_epoch)
593
     save_k_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = nb_of_epochs, input = title+"_plotdata.json", output = title+"_plotdata.json")
602
     save_k_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = nb_of_epochs, input = title+"_plotdata.json", output = title+"_plotdata.json")
594
 
603
 
595
     with open(title+"_plotdata.json", 'r') as json_file:
604
     with open(title+"_plotdata.json", 'r') as json_file:
628
     swp2_x, swp2_y = swp2_vals[0], swp2_vals[1]
637
     swp2_x, swp2_y = swp2_vals[0], swp2_vals[1]
629
     # End of Parsing real swp2 output
638
     # End of Parsing real swp2 output
630
     plot_raw_stairs(plot_lines = loaded_data['raw_stairs'],
639
     plot_raw_stairs(plot_lines = loaded_data['raw_stairs'],
631
-                            prop = loaded_data['prop'], title = title, ax = None)
640
+                            prop = loaded_data['prop'], title = title, ax = None, max_breaks = breaks)
632
     plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'], mu = mu, tgen = tgen, subset=[loaded_data['best_epoch_by_AIC']]+selected_breaks,
641
     plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'], mu = mu, tgen = tgen, subset=[loaded_data['best_epoch_by_AIC']]+selected_breaks,
633
     # plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'], subset=list(range(0,3))+[loaded_data['best_epoch_by_AIC']]+selected_breaks,
642
     # plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'], subset=list(range(0,3))+[loaded_data['best_epoch_by_AIC']]+selected_breaks,
634
                             prop = loaded_data['prop'], title = title, swp2_lines = [swp2_x, swp2_y], ax = None)
643
                             prop = loaded_data['prop'], title = title, swp2_lines = [swp2_x, swp2_y], ax = None)