瀏覽代碼

Fixing plot font size issues

tforest 5 月之前
父節點
當前提交
6287f2aa71
共有 1 個文件被更改,包括 57 次插入48 次删除
  1. 57 48
      swp2.py

+ 57 - 48
swp2.py 查看文件

7
 from matplotlib.ticker import MaxNLocator
7
 from matplotlib.ticker import MaxNLocator
8
 import io
8
 import io
9
 from mpl_toolkits.axes_grid1.inset_locator import inset_axes
9
 from mpl_toolkits.axes_grid1.inset_locator import inset_axes
10
-
10
+from matplotlib.ticker import MultipleLocator
11
 def log_facto(k):
11
 def log_facto(k):
12
     k = int(k)
12
     k = int(k)
13
     if k > 1e6:
13
     if k > 1e6:
259
         # intialize figure
259
         # intialize figure
260
         my_dpi = 300
260
         my_dpi = 300
261
         fnt_size = 18
261
         fnt_size = 18
262
+        # plt.rcParams['font.size'] = fnt_size
262
         fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
263
         fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
263
     else:
264
     else:
264
         fnt_size = 12
265
         fnt_size = 12
265
-        plt.rcParams['font.size'] = fnt_size
266
+        # plt.rcParams['font.size'] = fnt_size
266
         ax1 = ax[0,0]
267
         ax1 = ax[0,0]
267
     #ax1.set_xlim(1e-3, 1)
268
     #ax1.set_xlim(1e-3, 1)
268
     ax1.set_yscale('log')
269
     ax1.set_yscale('log')
293
     best10_plots = sorted(plots_likelihoods)[:10]
294
     best10_plots = sorted(plots_likelihoods)[:10]
294
     top_plot_lik = str(best10_plots[0])
295
     top_plot_lik = str(best10_plots[0])
295
     plot_handles = []
296
     plot_handles = []
297
+    # plt.rcParams['font.size'] = fnt_size
296
     p0, = ax1.plot(top_plots[top_plot_lik][0], top_plots[top_plot_lik][1], 'o', linestyle = "-",
298
     p0, = ax1.plot(top_plots[top_plot_lik][0], top_plots[top_plot_lik][1], 'o', linestyle = "-",
297
-    alpha=1, lw=2, label = str(top_plots[top_plot_lik][2])+' epoch | Lik='+top_plot_lik)
299
+    alpha=1, lw=2, label = str(top_plots[top_plot_lik][2])+' brks | Lik='+top_plot_lik)
298
     plot_handles.append(p0)
300
     plot_handles.append(p0)
299
     for k, plot_Lk in enumerate(best10_plots[1:]):
301
     for k, plot_Lk in enumerate(best10_plots[1:]):
300
         plot_Lk = str(plot_Lk)
302
         plot_Lk = str(plot_Lk)
303
+        # plt.rcParams['font.size'] = fnt_size
301
         p, = ax1.plot(top_plots[plot_Lk][0], top_plots[plot_Lk][1], 'o', linestyle = "--",
304
         p, = ax1.plot(top_plots[plot_Lk][0], top_plots[plot_Lk][1], 'o', linestyle = "--",
302
-        alpha=1/(k+1), lw=1.5, label = str(top_plots[plot_Lk][2])+' epoch | Lik='+plot_Lk)
305
+        alpha=1/(k+1), lw=1.5, label = str(top_plots[plot_Lk][2])+' brks | Lik='+plot_Lk)
303
         plot_handles.append(p)
306
         plot_handles.append(p)
304
     if theta_scale:
307
     if theta_scale:
305
-        ax1.set_xlabel("Coal. time")
306
-        ax1.set_ylabel("Pop. size scaled by N0")
308
+        ax1.set_xlabel("Coal. time", fontsize=fnt_size)
309
+        ax1.set_ylabel("Pop. size scaled by N0", fontsize=fnt_size)
307
         # recent_scale_lower_bound = 0.01
310
         # recent_scale_lower_bound = 0.01
308
         # recent_scale_upper_bound = 0.1
311
         # recent_scale_upper_bound = 0.1
309
         # ax1.axvline(x=recent_scale_lower_bound)
312
         # ax1.axvline(x=recent_scale_lower_bound)
310
         # ax1.axvline(x=recent_scale_upper_bound)
313
         # ax1.axvline(x=recent_scale_upper_bound)
311
     else:
314
     else:
312
         # years
315
         # years
313
-        plt.set_xlabel("Time (years)")
314
-        plt.set_ylabel("Individuals (N)")
316
+        plt.set_xlabel("Time (years)", fontsize=fnt_size)
317
+        plt.set_ylabel("Individuals (N)", fontsize=fnt_size)
315
         ax1.set_xlim(1e-5, 1)
318
         ax1.set_xlim(1e-5, 1)
319
+    # plt.rcParams['font.size'] = fnt_size
320
+    # print(fnt_size, "rcParam font.size=", plt.rcParams['font.size'])
316
     ax1.legend(handles = plot_handles, loc='best', fontsize = fnt_size*0.5)
321
     ax1.legend(handles = plot_handles, loc='best', fontsize = fnt_size*0.5)
317
     ax1.set_title(title)
322
     ax1.set_title(title)
318
     if ax is None:
323
     if ax is None:
335
     # basic plot likelihood
340
     # basic plot likelihood
336
     if ax is None:
341
     if ax is None:
337
         fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
342
         fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
338
-        plt.rcParams['font.size'] = '18'
343
+        # plt.rcParams['font.size'] = fnt_size
339
     else:
344
     else:
340
-        plt.rcParams['font.size'] = fnt_size
345
+        #plt.rcParams['font.size'] = fnt_size
341
         ax2 = ax[2,0]
346
         ax2 = ax[2,0]
342
     ax2.plot(np.array(brkpt_lik)[:, 0], np.array(brkpt_lik)[:, 1].astype(float), 'o', linestyle = "dotted", lw=2)
347
     ax2.plot(np.array(brkpt_lik)[:, 0], np.array(brkpt_lik)[:, 1].astype(float), 'o', linestyle = "dotted", lw=2)
343
     ax2.axhline(y=-Ln, linestyle = "-.", color = "red", label = "$-\log\mathcal{L}$ = "+str(round(-Ln, 2)))
348
     ax2.axhline(y=-Ln, linestyle = "-.", color = "red", label = "$-\log\mathcal{L}$ = "+str(round(-Ln, 2)))
344
     ax2.set_yscale('log')
349
     ax2.set_yscale('log')
345
     ax2.set_xlabel("# breakpoints", fontsize=fnt_size)
350
     ax2.set_xlabel("# breakpoints", fontsize=fnt_size)
346
-    ax2.set_ylabel("$-\log\mathcal{L}$")
347
-    ax2.legend(loc='best', fontsize = fnt_size*0.8)
351
+    ax2.set_ylabel("$-\log\mathcal{L}$", fontsize=fnt_size)
352
+    ax2.legend(loc='best', fontsize = fnt_size*0.5)
348
     ax2.set_title(title+" Likelihood gain from # breakpoints")
353
     ax2.set_title(title+" Likelihood gain from # breakpoints")
349
     if ax is None:
354
     if ax is None:
350
         plt.savefig(title+'_Breakpts_Likelihood.pdf')
355
         plt.savefig(title+'_Breakpts_Likelihood.pdf')
351
     # AIC
356
     # AIC
352
     if ax is None:
357
     if ax is None:
353
         fig, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
358
         fig, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
354
-        plt.rcParams['font.size'] = '18'
359
+        # plt.rcParams['font.size'] = '18'
355
     else:
360
     else:
361
+        #plt.rcParams['font.size'] = fnt_size
356
         ax3 = ax[2,1]
362
         ax3 = ax[2,1]
357
     AIC = 2*(len(brkpt_lik)+1)+2*np.array(brkpt_lik)[:, 1].astype(float)
363
     AIC = 2*(len(brkpt_lik)+1)+2*np.array(brkpt_lik)[:, 1].astype(float)
358
     ax3.plot(np.array(brkpt_lik)[:, 0], AIC, 'o', linestyle = "dotted", lw=2)
364
     ax3.plot(np.array(brkpt_lik)[:, 0], AIC, 'o', linestyle = "dotted", lw=2)
362
     ax3.set_yscale('log')
368
     ax3.set_yscale('log')
363
     ax3.set_xlabel("# breakpoints", fontsize=fnt_size)
369
     ax3.set_xlabel("# breakpoints", fontsize=fnt_size)
364
     ax3.set_ylabel("AIC")
370
     ax3.set_ylabel("AIC")
365
-    ax3.legend(loc='best', fontsize = fnt_size*0.8)
371
+    ax3.legend(loc='best', fontsize = fnt_size*0.5)
366
     ax3.set_title(title+" AIC")
372
     ax3.set_title(title+" AIC")
367
     if ax is None:
373
     if ax is None:
368
         plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
374
         plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
375
+    print("S", S)
369
     return ax
376
     return ax
370
-def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True, breaks_max = 10, ax = None):
377
+
378
+def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True, breaks_max = 10, ax = None, n_ticks = 10):
371
     """
379
     """
372
     Use theta values as is to do basic plots.
380
     Use theta values as is to do basic plots.
373
     """
381
     """
390
         # intialize figure 1
398
         # intialize figure 1
391
         my_dpi = 300
399
         my_dpi = 300
392
         fnt_size = 18
400
         fnt_size = 18
401
+        # plt.rcParams['font.size'] = fnt_size
393
         fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
402
         fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
394
-        # Add some extra space for the second axis at the bottom
395
-        fig.subplots_adjust(bottom=0.15)
396
     else:
403
     else:
397
         fnt_size = 12
404
         fnt_size = 12
405
+        # plt.rcParams['font.size'] = fnt_size
398
         ax1 = ax[0, 1]
406
         ax1 = ax[0, 1]
399
         plt.subplots_adjust(wspace=0.3, hspace=0.3)
407
         plt.subplots_adjust(wspace=0.3, hspace=0.3)
400
 
408
 
401
-    twin = ax1.twiny()
402
-
403
     plots = []
409
     plots = []
404
     for epoch, theta in epochs.items():
410
     for epoch, theta in epochs.items():
405
         groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
411
         groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
421
                 prop = prop[::-1]
427
                 prop = prop[::-1]
422
                 # print(prop, "\n", sum(prop))
428
                 # print(prop, "\n", sum(prop))
423
         # normalise to N0 (N0 of epoch1)
429
         # normalise to N0 (N0 of epoch1)
430
+        x_ticks = ax1.get_xticks()
424
         for i in range(len(y)):
431
         for i in range(len(y)):
425
             y[i] = y[i]/N0
432
             y[i] = y[i]/N0
426
         # plot
433
         # plot
427
         x_plot, y_plot = plot_straight_x_y(x, y)
434
         x_plot, y_plot = plot_straight_x_y(x, y)
428
         #plt.plot(x, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
435
         #plt.plot(x, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
429
-        p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-.", alpha=0.75, lw=2, label = str(epoch)+' brks')
436
+        p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
430
         # add plot to the list of all plots to superimpose
437
         # add plot to the list of all plots to superimpose
431
         plots.append(p)
438
         plots.append(p)
432
-    # virtual line to get the second x axis for proportions
433
-    p0, = twin.plot(prop, y, alpha = 0, label="Proportion")
434
-    # Move twinned axis ticks and label from top to bottom
435
-    twin.xaxis.set_ticks_position("bottom")
436
-    twin.xaxis.set_label_position("bottom")
437
-    # Offset the twin axis below the host
438
-    if ax is None:
439
-        # arrange differently the second x axis if the plot is plain
440
-        twin.spines["bottom"].set_position(("axes", -0.15))
441
-    else:
442
-        # in a combined plot, more space between the fig and the axis
443
-        twin.spines["bottom"].set_position(("axes", -0.35))
439
+    #print(prop, "\n", sum(prop))
444
     #ax.legend(handles=[p0]+plots)
440
     #ax.legend(handles=[p0]+plots)
445
-    ax1.set_xlabel("# bin")
441
+    ax1.set_xlabel("# bin", fontsize=fnt_size)
446
     # Set the x-axis locator to reduce the number of ticks to 10
442
     # Set the x-axis locator to reduce the number of ticks to 10
447
-    ax1.xaxis.set_major_locator(MaxNLocator(nbins=10))
448
-    twin.xaxis.set_major_locator(MaxNLocator(nbins=10))
449
-    ax1.set_ylabel("theta")
450
-    twin.set_ylabel("Proportion")
451
-    ax1.set_title("Title")
443
+    ax1.set_ylabel("theta", fontsize=fnt_size)
444
+    ax1.set_title("Title", fontsize=fnt_size)
452
     ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
445
     ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
446
+    ax1.set_xticks(x_ticks)
447
+    if len(prop) >= 18:
448
+        ax1.locator_params(nbins=n_ticks)
449
+    # new scale of ticks if too many values
450
+    cumul = 0
451
+    prop_cumul = []
452
+    for val in prop:
453
+        prop_cumul.append(val+cumul)
454
+        cumul = val+cumul
455
+    ax1.set_xticklabels([f'{x[k]}\n{val:.2f}' for k, val in enumerate(prop_cumul)])
453
     if ax is None:
456
     if ax is None:
454
         plt.savefig(title+'_raw'+str(k)+'.pdf')
457
         plt.savefig(title+'_raw'+str(k)+'.pdf')
455
     # fig 2 & 3
458
     # fig 2 & 3
457
         fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
460
         fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
458
         fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
461
         fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
459
     else:
462
     else:
463
+        # plt.rcParams['font.size'] = fnt_size
460
         # place of plots on the grid
464
         # place of plots on the grid
461
         ax2 = ax[1,0]
465
         ax2 = ax[1,0]
462
         ax3 = ax[1,1]
466
         ax3 = ax[1,1]
487
         x_2 = [0]+x_2
491
         x_2 = [0]+x_2
488
         y = [y[0]]+y
492
         y = [y[0]]+y
489
         x2_plot, y2_plot = plot_straight_x_y(x_2, y)
493
         x2_plot, y2_plot = plot_straight_x_y(x_2, y)
490
-        p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-.", alpha=0.75, lw=2, label = str(epoch)+' brks')
494
+        p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
491
         lines_fig2.append(p2)
495
         lines_fig2.append(p2)
492
         # Plotting (fig 3) which is the same but log scale for x
496
         # Plotting (fig 3) which is the same but log scale for x
493
-        p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-.", alpha=0.75, lw=2, label = str(epoch)+' brks')
497
+        p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
494
         lines_fig3.append(p3)
498
         lines_fig3.append(p3)
495
-    ax2.set_xlabel("Relative scale")
496
-    ax2.set_ylabel("theta")
497
-    ax2.set_title("Title")
499
+    ax2.set_xlabel("Relative scale", fontsize=fnt_size)
500
+    ax2.set_ylabel("theta", fontsize=fnt_size)
501
+    ax2.set_title("Title", fontsize=fnt_size)
498
     ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
502
     ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
499
     if ax is None:
503
     if ax is None:
500
         plt.savefig(title+'_plot2_'+str(k)+'.pdf')
504
         plt.savefig(title+'_plot2_'+str(k)+'.pdf')
501
     ax3.set_xscale('log')
505
     ax3.set_xscale('log')
502
-    ax3.set_xlabel("log Relative scale")
503
-    ax3.set_ylabel("theta")
504
-    ax3.set_title("Title")
506
+    ax3.set_yscale('log')
507
+    ax3.set_xlabel("log Relative scale", fontsize=fnt_size)
508
+    ax3.set_ylabel("theta", fontsize=fnt_size)
509
+    ax3.set_title("Title", fontsize=fnt_size)
505
     ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
510
     ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
506
     if ax is None:
511
     if ax is None:
507
         plt.savefig(title+'_plot3_'+str(k)+'_log.pdf')
512
         plt.savefig(title+'_plot3_'+str(k)+'_log.pdf')
513
+        plt.clf()
508
     # return plots
514
     # return plots
509
     return ax
515
     return ax
510
 
516
 
511
 def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale = True):
517
 def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale = True):
512
     my_dpi = 300
518
     my_dpi = 300
513
     # Add some extra space for the second axis at the bottom
519
     # Add some extra space for the second axis at the bottom
520
+    #plt.rcParams['font.size'] = 18
514
     fig, axs = plt.subplots(3, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
521
     fig, axs = plt.subplots(3, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
522
+    #plt.rcParams['font.size'] = 12
515
     ax = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = axs)
523
     ax = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = axs)
516
     ax = plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = axs)
524
     ax = plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = axs)
517
     # Adjust layout to prevent clipping of titles
525
     # Adjust layout to prevent clipping of titles
518
     plt.tight_layout()
526
     plt.tight_layout()
519
     # Adjust absolute space between the top and bottom rows
527
     # Adjust absolute space between the top and bottom rows
520
-    plt.subplots_adjust(hspace=0.7)  # Adjust this value based on your requirement
528
+    #plt.subplots_adjust(hspace=0.7)  # Adjust this value based on your requirement
521
     # Save the entire grid as a single figure
529
     # Save the entire grid as a single figure
522
     plt.savefig(title+'_combined.pdf')
530
     plt.savefig(title+'_combined.pdf')
523
-    plt.close()
531
+    plt.clf()
524
     # second call for individual plots
532
     # second call for individual plots
525
     plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = None)
533
     plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = None)
526
     plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = None)
534
     plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = None)
535
+    plt.clf()
527
 
536
 
528
 if __name__ == "__main__":
537
 if __name__ == "__main__":
529
 
538