Browse Source

Some progress on subplots

tforest 5 months ago
parent
commit
43f689e764
1 changed files with 69 additions and 28 deletions
  1. 69 28
      swp2.py

+ 69 - 28
swp2.py View File

@@ -5,6 +5,8 @@ import math
5 5
 from scipy.special import gammaln
6 6
 from matplotlib.backends.backend_pdf import PdfPages
7 7
 from matplotlib.ticker import MaxNLocator
8
+import io
9
+from mpl_toolkits.axes_grid1.inset_locator import inset_axes
8 10
 
9 11
 def log_facto(k):
10 12
     k = int(k)
@@ -321,7 +323,7 @@ def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_sc
321 323
     plt.title(title)
322 324
     plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
323 325
 
324
-def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True, breaks_max = 12):
326
+def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True, breaks_max = 12, ax = None):
325 327
     """
326 328
     Use theta values as is to do basic plots.
327 329
     """
@@ -343,12 +345,16 @@ def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
343 345
     my_dpi = 300
344 346
     # plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
345 347
     # multiple fig
346
-    fig, ax = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
347
-    # Add some extra space for the second axis at the bottom
348
-    fig.subplots_adjust(bottom=0.15)
349
-    twin = ax.twiny()
350
-    # Offset the right spine of twin2
351
-    # twin.spines.right.set_position(("axes", 1.2))
348
+    if ax is None:
349
+        fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
350
+        # Add some extra space for the second axis at the bottom
351
+        fig.subplots_adjust(bottom=0.15)
352
+    else:
353
+        ax1 = ax[0, 1]
354
+        plt.subplots_adjust(wspace=0.3, hspace=0.3)
355
+
356
+    twin = ax1.twiny()
357
+
352 358
     plots = []
353 359
     for epoch, theta in epochs.items():
354 360
         groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
@@ -374,7 +380,7 @@ def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
374 380
             y[i] = y[i]/N0
375 381
         # plot
376 382
         #plt.plot(x, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
377
-        p, = ax.plot(x, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
383
+        p, = ax1.plot(x, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
378 384
         # add plot to the list of all plots to superimpose
379 385
         plots.append(p)
380 386
     # virtual line to get the second x axis for proportions
@@ -383,17 +389,27 @@ def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
383 389
     twin.xaxis.set_ticks_position("bottom")
384 390
     twin.xaxis.set_label_position("bottom")
385 391
     # Offset the twin axis below the host
386
-    twin.spines["bottom"].set_position(("axes", -0.1))
392
+    twin.spines["bottom"].set_position(("axes", -0.15))
387 393
     #ax.legend(handles=[p0]+plots)
388
-    ax.set_xlabel("# breaks")
389
-    # Set the x-axis locator to reduce the number of ticks = 10
390
-    ax.xaxis.set_major_locator(MaxNLocator(nbins=10))
391
-    ax.set_ylabel("theta")
394
+    ax1.set_xlabel("# breaks")
395
+    # Set the x-axis locator to reduce the number of ticks to 10
396
+    ax1.xaxis.set_major_locator(MaxNLocator(nbins=10))
397
+    ax1.set_ylabel("theta")
392 398
     # twin.set_ylabel("Proportion")
393 399
     plt.legend(handles=plots, loc='upper right')
394
-    plt.savefig(title+'_raw'+str(k)+'.pdf')
400
+    if ax is None:
401
+        plt.savefig(title+'_raw'+str(k)+'.pdf')
395 402
     # fig 2 & 3
396
-    plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
403
+    if ax is None:
404
+        fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
405
+        fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
406
+    else:
407
+        # place of plots on the grid
408
+        ax2 = ax[1,0]
409
+        ax3 = ax[1,1]
410
+    lines_fig2 = []
411
+    lines_fig3 = []
412
+    #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
397 413
     for epoch, theta in epochs.items():
398 414
         groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
399 415
         x = []
@@ -415,18 +431,30 @@ def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
415 431
             T += y[i] / (x[i]*(x[i]-1))
416 432
             x_2.append(T)
417 433
         # Plotting (fig 2)
418
-        plt.plot(x_2, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
419
-        plt.xlabel("# breaks")
420
-        plt.ylabel("theta")
421
-        plt.legend(loc='upper right')
422
-        plt.savefig(title+'_plot2_'+str(k)+'.pdf')
434
+        p2, = ax2.plot(x_2, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
435
+        lines_fig2.append(p2)
423 436
         # Plotting (fig 3) which is the same but log scale for x
424
-        plt.plot(x_2, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
425
-        plt.xscale('log')
426
-        plt.xlabel("# breaks")
427
-        plt.ylabel("theta")
428
-        plt.legend(loc='upper right')
437
+        p3, = ax3.plot(x_2, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
438
+        lines_fig3.append(p3)
439
+    ax2.set_xlabel("# breaks")
440
+    ax2.set_ylabel("theta")
441
+    ax2.set_title("Test")
442
+    ax2.legend(handles=lines_fig2, loc='upper right')
443
+    if ax is None:
444
+        plt.savefig(title+'_plot2_'+str(k)+'.pdf')
445
+    ax3.set_xscale('log')
446
+    ax3.set_xlabel("log()")
447
+    ax3.set_ylabel("theta")
448
+    ax3.set_title("Test")
449
+    ax3.legend(handles=lines_fig3, loc='upper right')
450
+    if ax is None:
429 451
         plt.savefig(title+'_plot3_'+str(k)+'_log.pdf')
452
+    # return plots
453
+    return ax
454
+
455
+def save_combined_pdf(output_path):
456
+    with PdfPages(output_path) as pdf:
457
+        pdf.savefig()
430 458
 
431 459
 def save_multi_image(filename):
432 460
     pp = PdfPages(filename)
@@ -437,9 +465,22 @@ def save_multi_image(filename):
437 465
     pp.close()
438 466
 
439 467
 def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale = True):
440
-    plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale)
441
-    plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks)
442
-    save_multi_image(title+"_combined.pdf")
468
+    # plot1, plot2, plot3 = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale)
469
+    my_dpi = 300
470
+    # Add some extra space for the second axis at the bottom
471
+    fig, axs = plt.subplots(2, 2, figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
472
+
473
+    ax = plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = axs)
474
+
475
+    # Adjust layout to prevent clipping of titles
476
+    plt.tight_layout()
477
+    # Adjust absolute space between the top and bottom rows
478
+    plt.subplots_adjust(hspace=0.35)  # Adjust this value based on your requirement
479
+
480
+    # Save the entire grid as a single figure
481
+    plt.savefig(title+'_combined.pdf')
482
+
483
+
443 484
 if __name__ == "__main__":
444 485
 
445 486
     if len(sys.argv) != 4: