Browse Source

Fixing proportion labels

tforest 4 months ago
parent
commit
eb2799bc98
1 changed files with 214 additions and 23 deletions
  1. 214 23
      swp2.py

+ 214 - 23
swp2.py View File

@@ -2,10 +2,11 @@ import matplotlib.pyplot as plt
2 2
 import os
3 3
 import numpy as np
4 4
 import math
5
+import json
6
+import io
5 7
 from scipy.special import gammaln
6 8
 from matplotlib.backends.backend_pdf import PdfPages
7 9
 from matplotlib.ticker import MaxNLocator
8
-import io
9 10
 from mpl_toolkits.axes_grid1.inset_locator import inset_axes
10 11
 from matplotlib.ticker import MultipleLocator
11 12
 def log_facto(k):
@@ -197,8 +198,6 @@ def plot_k_epochs_thetafolder(folder_path, mu, tgen, breaks = 2, title = "Title"
197 198
     my_dpi = 300
198 199
     plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
199 200
     plt.plot(x, y, 'r-', lw=2, label = 'Lik='+greatest_likelihood)
200
-    plt.xlim(1e-3, 1)
201
-    plt.ylim(0, 10)
202 201
     #plt.yscale('log')
203 202
     plt.xscale('log')
204 203
     plt.grid(True,which="both", linestyle='--', alpha = 0.3)
@@ -271,7 +270,6 @@ def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_sc
271 270
         fnt_size = 12
272 271
         # plt.rcParams['font.size'] = fnt_size
273 272
         ax1 = ax[0,0]
274
-    #ax1.set_xlim(1e-3, 1)
275 273
     ax1.set_yscale('log')
276 274
     ax1.set_xscale('log')
277 275
     ax1.grid(True,which="both", linestyle='--', alpha = 0.3)
@@ -321,7 +319,6 @@ def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_sc
321 319
         # years
322 320
         plt.set_xlabel("Time (years)", fontsize=fnt_size)
323 321
         plt.set_ylabel("Individuals (N)", fontsize=fnt_size)
324
-        ax1.set_xlim(1e-5, 1)
325 322
     # plt.rcParams['font.size'] = fnt_size
326 323
     # print(fnt_size, "rcParam font.size=", plt.rcParams['font.size'])
327 324
     ax1.legend(handles = plot_handles, loc='best', fontsize = fnt_size*0.5)
@@ -382,6 +379,189 @@ def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_sc
382 379
     # return plots
383 380
     return ax
384 381
 
382
+def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
383
+    breaks_max = 10, output = None):
384
+    """
385
+    Save theta values as is to do basic plots.
386
+    """
387
+    cpt = 0
388
+    epochs = {}
389
+    len_sfs = 0
390
+    for file_name in os.listdir(folder_path):
391
+        cpt +=1
392
+        if os.path.isfile(os.path.join(folder_path, file_name)):
393
+            for k in range(breaks_max):
394
+                thetas,sfs = return_x_y_from_stwp_theta_file_as_is(folder_path+file_name, breaks = k,
395
+                                                                 tgen = tgen,
396
+                                                                 mu = mu, relative_theta_scale = theta_scale)
397
+                if thetas == 0:
398
+                    continue
399
+                if len(thetas)-1 != k:
400
+                    continue
401
+                if k not in epochs.keys():
402
+                    epochs[k] = {}
403
+                likelihood = str(eval(thetas[k][2]))
404
+                epochs[k][likelihood] = thetas
405
+                #epochs[k] = thetas
406
+    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(k)+"\n*******\n")
407
+    print(cpt, "theta file(s) have been scanned.")
408
+    plots = []
409
+    best_epochs = {}
410
+    for epoch in epochs:
411
+        likelihoods = []
412
+        for key in epochs[epoch].keys():
413
+            likelihoods.append(key)
414
+        likelihoods.sort()
415
+        minLogLn = str(likelihoods[0])
416
+        best_epochs[epoch] = epochs[epoch][minLogLn]
417
+    for epoch, theta in best_epochs.items():
418
+        groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
419
+        x = []
420
+        y = []
421
+        thetas = np.array(list(theta.values()), dtype=object)[:, 0]
422
+        for i,group in enumerate(groups):
423
+            x += group[::-1]
424
+            y += list(np.repeat(thetas[i], len(group)))
425
+            if epoch == 0:
426
+                N0 = y[0]
427
+                # compute the proportion of information used at each bin of the SFS
428
+                sum_theta_i = 0
429
+                for i in range(2, len(y)+2):
430
+                    sum_theta_i+=y[i-2] / (i-1)
431
+                prop = []
432
+                for k in range(2, len(y)+2):
433
+                    prop.append(y[k-2] / (k - 1) / sum_theta_i)
434
+                prop = prop[::-1]
435
+        # normalise to N0 (N0 of epoch1)
436
+        for i in range(len(y)):
437
+            y[i] = y[i]/N0
438
+        # x_plot, y_plot = plot_straight_x_y(x, y)
439
+        p = x, y
440
+        # add plot to the list of all plots to superimpose
441
+        plots.append(p)
442
+    cumul = 0
443
+    prop_cumul = []
444
+    for val in prop:
445
+        prop_cumul.append(val+cumul)
446
+        cumul = val+cumul
447
+    prop = prop_cumul
448
+
449
+    lines_fig2 = []
450
+    for epoch, theta in best_epochs.items():
451
+        groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
452
+        x = []
453
+        y = []
454
+        thetas = np.array(list(theta.values()), dtype=object)[:, 0]
455
+        for i,group in enumerate(groups):
456
+            x += group[::-1]
457
+            y += list(np.repeat(thetas[i], len(group)))
458
+            if epoch == 0:
459
+                N0 = y[0]
460
+        for i in range(len(y)):
461
+            y[i] = y[i]/N0
462
+        x_2 = []
463
+        T = 0
464
+        for i in range(len(x)):
465
+            x[i] = int(x[i])
466
+        # compute the times as: theta_k / (k*(k-1))
467
+        for i in range(0, len(x)):
468
+            T += y[i] / (x[i]*(x[i]-1))
469
+            x_2.append(T)
470
+        # Save plotting (fig 2)
471
+        x_2 = [0]+x_2
472
+        y = [y[0]]+y
473
+        # x2_plot, y2_plot = plot_straight_x_y(x_2, y)
474
+        p2 = x_2, y
475
+        lines_fig2.append(p2)
476
+
477
+    saved_plots = {"raw_stairs":plots, "scaled_stairs":lines_fig2,
478
+                    "prop":prop}
479
+    if output == None:
480
+        output = title+"_plotdata.json"
481
+    with open(output, 'w') as json_file:
482
+        json.dump(saved_plots, json_file)
483
+    return saved_plots
484
+
485
+def plot_raw_stairs(plot_lines, plot_lines2, prop, title, ax = None, n_ticks = 10):
486
+    # multiple fig
487
+    if ax is None:
488
+        # intialize figure 1
489
+        my_dpi = 300
490
+        fnt_size = 18
491
+        # plt.rcParams['font.size'] = fnt_size
492
+        fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
493
+    else:
494
+        fnt_size = 12
495
+        # plt.rcParams['font.size'] = fnt_size
496
+        ax1 = ax[0, 1]
497
+        plt.subplots_adjust(wspace=0.3, hspace=0.3)
498
+    plots = []
499
+
500
+    for epoch, plot in enumerate(plot_lines):
501
+        x,y = plot
502
+        x_plot, y_plot = plot_straight_x_y(x,y)
503
+        p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
504
+
505
+        # add plot to the list of all plots to superimpose
506
+        plots.append(p)
507
+    x_ticks = x
508
+    # print(x_ticks)
509
+    #print(prop, "\n", sum(prop))
510
+    #ax.legend(handles=[p0]+plots)
511
+    ax1.set_xlabel("# bin", fontsize=fnt_size)
512
+    # Set the x-axis locator to reduce the number of ticks to 10
513
+    ax1.set_ylabel("theta", fontsize=fnt_size)
514
+    ax1.set_title("Title", fontsize=fnt_size)
515
+    ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
516
+    ax1.set_xticks(x_ticks)
517
+    step = len(x_ticks)//(n_ticks-1)
518
+    values = x_ticks[::step]
519
+    new_prop = []
520
+    for val in values:
521
+        new_prop.append(prop[int(val)-2])
522
+    new_prop = new_prop[::-1]
523
+    ax1.set_xticks(values)
524
+    ax1.set_xticklabels([f'{values[k]}\n{val:.2f}' for k, val in enumerate(new_prop)], fontsize = fnt_size*0.8)
525
+    if ax is None:
526
+        plt.savefig(title+'_raw'+str(k)+'.pdf')
527
+    # fig 2 & 3
528
+    if ax is None:
529
+        fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
530
+        fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
531
+    else:
532
+        # plt.rcParams['font.size'] = fnt_size
533
+        # place of plots on the grid
534
+        ax2 = ax[1,0]
535
+        ax3 = ax[1,1]
536
+    lines_fig2 = []
537
+    lines_fig3 = []
538
+    #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
539
+    for epoch, plot in enumerate(plot_lines2):
540
+        x,y=plot
541
+        x2_plot, y2_plot = plot_straight_x_y(x,y)
542
+        p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
543
+        lines_fig2.append(p2)
544
+        # Plotting (fig 3) which is the same but log scale for x
545
+        p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
546
+        lines_fig3.append(p3)
547
+    ax2.set_xlabel("Relative scale", fontsize=fnt_size)
548
+    ax2.set_ylabel("theta", fontsize=fnt_size)
549
+    ax2.set_title("Title", fontsize=fnt_size)
550
+    ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
551
+    if ax is None:
552
+        plt.savefig(title+'_plot2_'+str(k)+'.pdf')
553
+    ax3.set_xscale('log')
554
+    ax3.set_yscale('log')
555
+    ax3.set_xlabel("log Relative scale", fontsize=fnt_size)
556
+    ax3.set_ylabel("theta", fontsize=fnt_size)
557
+    ax3.set_title("Title", fontsize=fnt_size)
558
+    ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
559
+    if ax is None:
560
+        plt.savefig(title+'_plot3_'+str(k)+'_log.pdf')
561
+        plt.clf()
562
+    # return plots
563
+    return ax
564
+
385 565
 def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True, breaks_max = 10, ax = None, n_ticks = 10):
386 566
     """
387 567
     Use theta values as is to do basic plots.
@@ -402,7 +582,7 @@ def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
402 582
                     continue
403 583
                 if k not in epochs.keys():
404 584
                     epochs[k] = {}
405
-                likelihood = thetas[k][2]
585
+                likelihood = str(eval(thetas[k][2]))
406 586
                 epochs[k][likelihood] = thetas
407 587
                 #epochs[k] = thetas
408 588
     print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(k)+"\n*******\n")
@@ -424,7 +604,7 @@ def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
424 604
     for epoch in epochs:
425 605
         likelihoods = []
426 606
         for key in epochs[epoch].keys():
427
-            likelihoods.append(float(key))
607
+            likelihoods.append(key)
428 608
         likelihoods.sort()
429 609
         minLogLn = str(likelihoods[0])
430 610
         best_epochs[epoch] = epochs[epoch][minLogLn]
@@ -537,24 +717,35 @@ def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
537 717
 
538 718
 def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale = True):
539 719
     my_dpi = 300
540
-    # Add some extra space for the second axis at the bottom
541
-    #plt.rcParams['font.size'] = 18
542
-    fig, axs = plt.subplots(3, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
543
-    #plt.rcParams['font.size'] = 12
544
-    ax = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = axs)
545
-    ax = plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = axs)
546
-    # Adjust layout to prevent clipping of titles
547
-    plt.tight_layout()
548
-    # Adjust absolute space between the top and bottom rows
549
-    #plt.subplots_adjust(hspace=0.7)  # Adjust this value based on your requirement
550
-    # Save the entire grid as a single figure
551
-    plt.savefig(title+'_combined.pdf')
552
-    plt.clf()
553
-    # # second call for individual plots
554
-    # plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = None)
555
-    # plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = None)
720
+    # # Add some extra space for the second axis at the bottom
721
+    # #plt.rcParams['font.size'] = 18
722
+    # fig, axs = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
723
+    # #plt.rcParams['font.size'] = 12
724
+    # ax = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = axs)
725
+    # ax = plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = axs)
726
+    # # Adjust layout to prevent clipping of titles
727
+    # plt.tight_layout()
728
+    # # Adjust absolute space between the top and bottom rows
729
+    # #plt.subplots_adjust(hspace=0.7)  # Adjust this value based on your requirement
730
+    # # Save the entire grid as a single figure
731
+    # plt.savefig(title+'_combined.pdf')
556 732
     # plt.clf()
733
+    # # # second call for individual plots
734
+    # # plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = None)
735
+    # # plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = None)
736
+    # # plt.clf()
737
+    # save_k_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, output = title+"_plotdata.json")
738
+
739
+    with open(title+"_plotdata.json", 'r') as json_file:
740
+        loaded_data = json.load(json_file)
741
+
742
+    fig1, ax1 = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
743
+    # plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = ax1)
744
+    ax1 = plot_raw_stairs(plot_lines = loaded_data['raw_stairs'], plot_lines2 = loaded_data['scaled_stairs'],
745
+                            prop = loaded_data['prop'], title = title, ax = ax1)
557 746
 
747
+    plt.savefig(title+'_raw_scaled.pdf')
748
+    fig1.clf()
558 749
 if __name__ == "__main__":
559 750
 
560 751
     if len(sys.argv) != 4: