Browse Source

Get rid of the old plot_all_epochs_theta

tforest 1 year ago
parent
commit
fed1a36d79
1 changed files with 88 additions and 358 deletions
  1. 88 358
      swp2.py

+ 88 - 358
swp2.py View File

@@ -112,53 +112,6 @@ def parse_stwp_theta_file(stwp_theta_file, breaks, mu, tgen, relative_theta_scal
112 112
 
113 113
     return x,y,likelihood,thetas,sfs,L
114 114
 
115
-def plot_k_epochs_thetafolder(folder_path, mu, tgen, breaks = 2, title = "Title", theta_scale = True):
116
-    scenari = {}
117
-    cpt = 0
118
-    for file_name in os.listdir(folder_path):
119
-        if os.path.isfile(os.path.join(folder_path, file_name)):
120
-            # Perform actions on each file
121
-            x, y, likelihood, theta, sfs, L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
122
-                                                             tgen = tgen,
123
-                                     mu = mu, relative_theta_scale = theta_scale)
124
-            if x == 0 or y == 0:
125
-                continue
126
-            cpt +=1
127
-            scenari[likelihood] = x,y
128
-    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
129
-    print(cpt, "theta file(s) have been scanned.")
130
-    # sort starting by the smallest -log(Likelihood)
131
-    print(scenari)
132
-    best10_scenari = (sorted(list(scenari.keys())))[:10]
133
-    print("10 greatest Likelihoods", best10_scenari)
134
-    greatest_likelihood = best10_scenari[0]
135
-    x, y = scenari[greatest_likelihood]
136
-    my_dpi = 300
137
-    plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
138
-    plt.plot(x, y, 'r-', lw=2, label = 'Lik='+greatest_likelihood)
139
-    #plt.yscale('log')
140
-    plt.xscale('log')
141
-    plt.grid(True,which="both", linestyle='--', alpha = 0.3)
142
-
143
-    for scenario in best10_scenari[1:]:
144
-        x,y = scenari[scenario]
145
-        #print("\n----  Lik:",scenario,"\n\nt=", x,"\n\nN=",y, "\n\n")
146
-        plt.plot(x, y, '--', lw=1, label = 'Lik='+scenario)
147
-    if theta_scale:
148
-        plt.xlabel("Coal. time")
149
-        plt.ylabel("Pop. size scaled by N0")
150
-        recent_scale_lower_bound = y[0] * 0.01
151
-        recent_scale_upper_bound = y[0] * 0.1
152
-        plt.axvline(x=recent_scale_lower_bound)
153
-        plt.axvline(x=recent_scale_upper_bound)
154
-    else:
155
-        # years
156
-        plt.xlabel("Time (years)")
157
-        plt.ylabel("Individuals (N)")
158
-    plt.legend(loc='upper right')
159
-    plt.title(title)
160
-    plt.savefig(title+'_b'+str(breaks)+'.pdf')
161
-
162 115
 def plot_straight_x_y(x,y):
163 116
     x_1 = [x[0]]
164 117
     y_1 = []
@@ -171,7 +124,7 @@ def plot_straight_x_y(x,y):
171 124
     x_1.append(x[-1])
172 125
     return x_1, y_1
173 126
 
174
-def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title",
127
+def plot_all_epochs_thetafolder_old(folder_path, mu, tgen, title = "Title",
175 128
     theta_scale = True, ax = None, input = None, output = None):
176 129
     #scenari = {}
177 130
     cpt = 0
@@ -323,6 +276,88 @@ def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title",
323 276
     # return plots
324 277
     return ax[0], ax[1]
325 278
 
279
+def plot_all_epochs_thetafolder(full_dict, mu, tgen, title = "Title",
280
+    theta_scale = True, ax = None, input = None, output = None):
281
+    my_dpi = 300
282
+    if ax is None:
283
+        # intialize figure
284
+        my_dpi = 300
285
+        fnt_size = 18
286
+        # plt.rcParams['font.size'] = fnt_size
287
+        fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
288
+    else:
289
+        fnt_size = 12
290
+        # plt.rcParams['font.size'] = fnt_size
291
+        ax1 = ax[1][0,0]
292
+    ax1.set_yscale('log')
293
+    ax1.set_xscale('log')
294
+    ax1.grid(True,which="both", linestyle='--', alpha = 0.3)
295
+    plot_handles = []
296
+    best_plot = full_dict['all_epochs']['best']
297
+    p0, = ax1.plot(best_plot[0], best_plot[1], 'o', linestyle = "-",
298
+    alpha=1, lw=2, label = str(best_plot[2])+' brks | Lik='+best_plot[3])
299
+    plot_handles.append(p0)
300
+    for k, plot_Lk in enumerate(full_dict['all_epochs']['plots']):
301
+        plot_Lk = str(full_dict['all_epochs']['plots'][k][3])
302
+        # plt.rcParams['font.size'] = fnt_size
303
+        p, = ax1.plot(full_dict['all_epochs']['plots'][k][0], full_dict['all_epochs']['plots'][k][1], 'o', linestyle = "--",
304
+        alpha=1/(k+1), lw=1.5, label = str(full_dict['all_epochs']['plots'][k][2])+' brks | Lik='+plot_Lk)
305
+        plot_handles.append(p)
306
+    if theta_scale:
307
+        ax1.set_xlabel("Coal. time", fontsize=fnt_size)
308
+        ax1.set_ylabel("Pop. size scaled by N0", fontsize=fnt_size)
309
+        # recent_scale_lower_bound = 0.01
310
+        # recent_scale_upper_bound = 0.1
311
+        # ax1.axvline(x=recent_scale_lower_bound)
312
+        # ax1.axvline(x=recent_scale_upper_bound)
313
+    else:
314
+        # years
315
+        plt.set_xlabel("Time (years)", fontsize=fnt_size)
316
+        plt.set_ylabel("Individuals (N)", fontsize=fnt_size)
317
+    # plt.rcParams['font.size'] = fnt_size
318
+    # print(fnt_size, "rcParam font.size=", plt.rcParams['font.size'])
319
+    ax1.legend(handles = plot_handles, loc='best', fontsize = fnt_size*0.5)
320
+    ax1.set_title(title)
321
+    if ax is None:
322
+        plt.savefig(title+'_b'+str(breaks)+'.pdf')
323
+    # plot likelihood against nb of breakpoints
324
+    if ax is None:
325
+        fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
326
+        # plt.rcParams['font.size'] = fnt_size
327
+    else:
328
+        #plt.rcParams['font.size'] = fnt_size
329
+        ax2 = ax[0][0,1]
330
+
331
+    ax2.plot(full_dict['Ln_Brks'][0], full_dict['Ln_Brks'][1], 'o', linestyle = "dotted", lw=2)
332
+    ax2.axhline(y=full_dict['best_Ln'], linestyle = "-.", color = "red", label = "$-\log\mathcal{L}$ = "+str(round(full_dict['best_Ln'], 2)))
333
+    ax2.set_yscale('log')
334
+    ax2.set_xlabel("# breakpoints", fontsize=fnt_size)
335
+    ax2.set_ylabel("$-\log\mathcal{L}$", fontsize=fnt_size)
336
+    ax2.legend(loc='best', fontsize = fnt_size*0.5)
337
+    ax2.set_title(title+" Likelihood gain from # breakpoints")
338
+    if ax is None:
339
+        plt.savefig(title+'_Breakpts_Likelihood.pdf')
340
+    # AIC
341
+    if ax is None:
342
+        fig, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
343
+        # plt.rcParams['font.size'] = '18'
344
+    else:
345
+        #plt.rcParams['font.size'] = fnt_size
346
+        ax3 = ax[1][0,1]
347
+    AIC = full_dict['AIC_Brks']
348
+    ax3.plot(AIC[0], AIC[1], 'o', linestyle = "dotted", lw=2)
349
+    ax3.axhline(y=full_dict['best_AIC'], linestyle = "-.", color = "red",
350
+    label = "Min. AIC = "+str(round(full_dict['best_AIC'], 2)))
351
+    ax3.set_yscale('log')
352
+    ax3.set_xlabel("# breakpoints", fontsize=fnt_size)
353
+    ax3.set_ylabel("AIC")
354
+    ax3.legend(loc='best', fontsize = fnt_size*0.5)
355
+    ax3.set_title(title+" AIC")
356
+    if ax is None:
357
+        plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
358
+    # return plots
359
+    return ax[0], ax[1]
360
+
326 361
 def save_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_scale = True, input = None, output = None):
327 362
     #scenari = {}
328 363
     cpt = 0
@@ -351,7 +386,6 @@ def save_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_sc
351 386
                 breaks -= 1
352 387
     print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
353 388
     print(cpt, "theta file(s) have been scanned.")
354
-
355 389
     brkpt_lik = []
356 390
     top_plots = {}
357 391
     for epoch, scenari in epochs.items():
@@ -378,10 +412,10 @@ def save_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_sc
378 412
     top_plot_lik = str(best10_plots[0])
379 413
     # store x,y,brks,likelihood
380 414
     plots['best'] = (top_plots[top_plot_lik][0], top_plots[top_plot_lik][1], str(top_plots[top_plot_lik][2]), top_plot_lik)
415
+    plots['plots'] = []
381 416
     for k, plot_Lk in enumerate(best10_plots[1:]):
382 417
         plot_Lk = str(plot_Lk)
383
-        plots[str(top_plots[plot_Lk][2])] = (top_plots[plot_Lk][0], top_plots[plot_Lk][1], str(top_plots[plot_Lk][2]), plot_Lk)
384
-
418
+        plots['plots'].append([top_plots[plot_Lk][0], top_plots[plot_Lk][1], str(top_plots[plot_Lk][2]), plot_Lk])
385 419
     # plot likelihood against nb of breakpoints
386 420
     # best possible likelihood from SFS
387 421
     # Segregating sites
@@ -408,7 +442,6 @@ def save_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_sc
408 442
     # AIC = 2*k - 2ln(L) ; where k is the number of parameters, here brks+1
409 443
     AIC_ln = 2*(len(brkpt_lik)+1) - 2*Ln
410 444
     best_AIC = AIC_ln
411
-
412 445
     # to return : plots ; Ln_Brks ; AIC_Brks ; best_Ln ; best_AIC
413 446
     # 'plots' dict keys: 'best', {epochs}('0', '1',...)
414 447
     if input == None:
@@ -430,157 +463,6 @@ def save_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_sc
430 463
         json.dump(saved_plots, json_file)
431 464
     return saved_plots
432 465
 
433
-def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_scale = True, ax = None):
434
-    #scenari = {}
435
-    cpt = 0
436
-    epochs = {}
437
-    for file_name in os.listdir(folder_path):
438
-        breaks = 0
439
-        cpt +=1
440
-        if os.path.isfile(os.path.join(folder_path, file_name)):
441
-            x, y, likelihood, theta, sfs, L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
442
-                                                             tgen = tgen,
443
-                                                             mu = mu, relative_theta_scale = theta_scale)
444
-            SFS_stored = sfs
445
-            L_stored = L
446
-            while not (x == 0 and y == 0):
447
-                if breaks not in epochs.keys():
448
-                    epochs[breaks] = {}
449
-                epochs[breaks][likelihood] = x,y
450
-                breaks += 1
451
-                x,y,likelihood,theta,sfs,L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
452
-                                                                 tgen = tgen,
453
-                                                                  mu = mu, relative_theta_scale = theta_scale)
454
-            if x == 0:
455
-                # last break did not work, then breaks = breaks-1
456
-                breaks -= 1
457
-    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
458
-    print(cpt, "theta file(s) have been scanned.")
459
-    my_dpi = 300
460
-    if ax is None:
461
-        # intialize figure
462
-        my_dpi = 300
463
-        fnt_size = 18
464
-        # plt.rcParams['font.size'] = fnt_size
465
-        fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
466
-    else:
467
-        fnt_size = 12
468
-        # plt.rcParams['font.size'] = fnt_size
469
-        ax1 = ax[1][0,0]
470
-    ax1.set_yscale('log')
471
-    ax1.set_xscale('log')
472
-    ax1.grid(True,which="both", linestyle='--', alpha = 0.3)
473
-    brkpt_lik = []
474
-    top_plots = {}
475
-    for epoch, scenari in epochs.items():
476
-        # sort starting by the smallest -log(Likelihood)
477
-        best10_scenari = (sorted(list(scenari.keys())))[:10]
478
-        greatest_likelihood = best10_scenari[0]
479
-        # store the tuple breakpoints and likelihood for later plot
480
-        brkpt_lik.append((epoch, greatest_likelihood))
481
-        x, y = scenari[greatest_likelihood]
482
-        #without breakpoint
483
-        if epoch == 0:
484
-            # do something with the theta without bp and skip the plotting
485
-            N0 = y[0]
486
-            #continue
487
-        for i in range(len(y)):
488
-            # divide by N0
489
-            y[i] = y[i]/N0
490
-            x[i] = x[i]/N0
491
-        top_plots[greatest_likelihood] = x,y,epoch
492
-    plots_likelihoods = list(top_plots.keys())
493
-    for i in range(len(plots_likelihoods)):
494
-        plots_likelihoods[i] = float(plots_likelihoods[i])
495
-    best10_plots = sorted(plots_likelihoods)[:10]
496
-    top_plot_lik = str(best10_plots[0])
497
-    plot_handles = []
498
-    # plt.rcParams['font.size'] = fnt_size
499
-    p0, = ax1.plot(top_plots[top_plot_lik][0], top_plots[top_plot_lik][1], 'o', linestyle = "-",
500
-    alpha=1, lw=2, label = str(top_plots[top_plot_lik][2])+' brks | Lik='+top_plot_lik)
501
-    plot_handles.append(p0)
502
-    for k, plot_Lk in enumerate(best10_plots[1:]):
503
-        plot_Lk = str(plot_Lk)
504
-        # plt.rcParams['font.size'] = fnt_size
505
-        p, = ax1.plot(top_plots[plot_Lk][0], top_plots[plot_Lk][1], 'o', linestyle = "--",
506
-        alpha=1/(k+1), lw=1.5, label = str(top_plots[plot_Lk][2])+' brks | Lik='+plot_Lk)
507
-        plot_handles.append(p)
508
-    if theta_scale:
509
-        ax1.set_xlabel("Coal. time", fontsize=fnt_size)
510
-        ax1.set_ylabel("Pop. size scaled by N0", fontsize=fnt_size)
511
-        # recent_scale_lower_bound = 0.01
512
-        # recent_scale_upper_bound = 0.1
513
-        # ax1.axvline(x=recent_scale_lower_bound)
514
-        # ax1.axvline(x=recent_scale_upper_bound)
515
-    else:
516
-        # years
517
-        plt.set_xlabel("Time (years)", fontsize=fnt_size)
518
-        plt.set_ylabel("Individuals (N)", fontsize=fnt_size)
519
-    # plt.rcParams['font.size'] = fnt_size
520
-    # print(fnt_size, "rcParam font.size=", plt.rcParams['font.size'])
521
-    ax1.legend(handles = plot_handles, loc='best', fontsize = fnt_size*0.5)
522
-    ax1.set_title(title)
523
-    if ax is None:
524
-        plt.savefig(title+'_b'+str(breaks)+'.pdf')
525
-    # plot likelihood against nb of breakpoints
526
-    # best possible likelihood from SFS
527
-    # Segregating sites
528
-    S = sum(SFS_stored)
529
-    # Number of kept sites from which the SFS is computed
530
-    L = L_stored
531
-    # number of monomorphic sites
532
-    S0 = L-S
533
-    # print("SFS", SFS_stored)
534
-    # print("S", S, "L", L, "S0=", S0)
535
-    # compute Ln
536
-    Ln = log_facto(S+S0) - log_facto(S0) + np.log(float(S0)/(S+S0)) * S0
537
-    for xi in range(0, len(SFS_stored)):
538
-        p_i = SFS_stored[xi] / float(S+S0)
539
-        Ln += np.log(p_i) * SFS_stored[xi] - log_facto(SFS_stored[xi])
540
-    # basic plot likelihood
541
-    if ax is None:
542
-        fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
543
-        # plt.rcParams['font.size'] = fnt_size
544
-    else:
545
-        #plt.rcParams['font.size'] = fnt_size
546
-        ax2 = ax[0][0,1]
547
-    ax2.plot(np.array(brkpt_lik)[:, 0], np.array(brkpt_lik)[:, 1].astype(float), 'o', linestyle = "dotted", lw=2)
548
-    ax2.axhline(y=-Ln, linestyle = "-.", color = "red", label = "$-\log\mathcal{L}$ = "+str(round(-Ln, 2)))
549
-    ax2.set_yscale('log')
550
-    ax2.set_xlabel("# breakpoints", fontsize=fnt_size)
551
-    ax2.set_ylabel("$-\log\mathcal{L}$", fontsize=fnt_size)
552
-    ax2.legend(loc='best', fontsize = fnt_size*0.5)
553
-    ax2.set_title(title+" Likelihood gain from # breakpoints")
554
-    if ax is None:
555
-        plt.savefig(title+'_Breakpts_Likelihood.pdf')
556
-    # AIC
557
-    if ax is None:
558
-        fig, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
559
-        # plt.rcParams['font.size'] = '18'
560
-    else:
561
-        #plt.rcParams['font.size'] = fnt_size
562
-        ax3 = ax[1][0,1]
563
-    AIC = []
564
-    for brk in np.array(brkpt_lik)[:, 0]:
565
-        brk = int(brk)
566
-        AIC.append((2*brk+1)+2*np.array(brkpt_lik)[brk, 1].astype(float))
567
-    ax3.plot(np.array(brkpt_lik)[:, 0], AIC, 'o', linestyle = "dotted", lw=2)
568
-    # AIC = 2*k - 2ln(L) ; where k is the number of parameters, here brks+1
569
-    AIC_ln = 2*(len(brkpt_lik)+1) - 2*Ln
570
-    ax3.axhline(y=AIC_ln, linestyle = "-.", color = "red",
571
-    label = "Min. AIC = "+str(round(AIC_ln, 2)))
572
-    selected_brks_nb = AIC.index(min(AIC))
573
-    ax3.set_yscale('log')
574
-    ax3.set_xlabel("# breakpoints", fontsize=fnt_size)
575
-    ax3.set_ylabel("AIC")
576
-    ax3.legend(loc='best', fontsize = fnt_size*0.5)
577
-    ax3.set_title(title+" AIC")
578
-    if ax is None:
579
-        plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
580
-    print("S", S)
581
-    # return plots
582
-    return ax[0], ax[1]
583
-
584 466
 def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
585 467
     breaks_max = 10, input = None, output = None):
586 468
     """
@@ -784,159 +666,6 @@ def plot_raw_stairs(plot_lines, prop, title, ax = None, n_ticks = 10):
784 666
     # return plots
785 667
     return ax
786 668
 
787
-def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True, breaks_max = 10, ax = None, n_ticks = 10):
788
-    """
789
-    Use theta values as is to do basic plots.
790
-    """
791
-    cpt = 0
792
-    epochs = {}
793
-    len_sfs = 0
794
-    for file_name in os.listdir(folder_path):
795
-        cpt +=1
796
-        if os.path.isfile(os.path.join(folder_path, file_name)):
797
-            for k in range(breaks_max):
798
-                x, y, likelihood, theta, sfs, L = parse_stwp_theta_file(folder_path+file_name, breaks = k,
799
-                                                                 tgen = tgen,
800
-                                                                 mu = mu, relative_theta_scale = theta_scale)
801
-                if thetas == 0:
802
-                    continue
803
-                if len(thetas)-1 != k:
804
-                    continue
805
-                if k not in epochs.keys():
806
-                    epochs[k] = {}
807
-                likelihood = str(eval(thetas[k][2]))
808
-                epochs[k][likelihood] = thetas
809
-                #epochs[k] = thetas
810
-    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(k)+"\n*******\n")
811
-    print(cpt, "theta file(s) have been scanned.")
812
-    # multiple fig
813
-    if ax is None:
814
-        # intialize figure 1
815
-        my_dpi = 300
816
-        fnt_size = 18
817
-        # plt.rcParams['font.size'] = fnt_size
818
-        fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
819
-    else:
820
-        fnt_size = 12
821
-        # plt.rcParams['font.size'] = fnt_size
822
-        ax1 = ax[0, 1]
823
-        plt.subplots_adjust(wspace=0.3, hspace=0.3)
824
-    plots = []
825
-    best_epochs = {}
826
-    for epoch in epochs:
827
-        likelihoods = []
828
-        for key in epochs[epoch].keys():
829
-            likelihoods.append(key)
830
-        likelihoods.sort()
831
-        minLogLn = str(likelihoods[0])
832
-        best_epochs[epoch] = epochs[epoch][minLogLn]
833
-    for epoch, theta in best_epochs.items():
834
-        groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
835
-        x = []
836
-        y = []
837
-        thetas = np.array(list(theta.values()), dtype=object)[:, 0]
838
-        for i,group in enumerate(groups):
839
-            x += group[::-1]
840
-            y += list(np.repeat(thetas[i], len(group)))
841
-            if epoch == 0:
842
-                N0 = y[0]
843
-                # compute the proportion of information used at each bin of the SFS
844
-                sum_theta_i = 0
845
-                for i in range(2, len(y)+2):
846
-                    sum_theta_i+=y[i-2] / (i-1)
847
-                prop = []
848
-                for k in range(2, len(y)+2):
849
-                    prop.append(y[k-2] / (k - 1) / sum_theta_i)
850
-                prop = prop[::-1]
851
-                # print(prop, "\n", sum(prop))
852
-        # normalise to N0 (N0 of epoch1)
853
-        x_ticks = ax1.get_xticks()
854
-        for i in range(len(y)):
855
-            y[i] = y[i]/N0
856
-        # plot
857
-        x_plot, y_plot = plot_straight_x_y(x, y)
858
-        #plt.plot(x, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
859
-        p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
860
-        # add plot to the list of all plots to superimpose
861
-        plots.append(p)
862
-    #print(prop, "\n", sum(prop))
863
-    #ax.legend(handles=[p0]+plots)
864
-    ax1.set_xlabel("# bin", fontsize=fnt_size)
865
-    # Set the x-axis locator to reduce the number of ticks to 10
866
-    ax1.set_ylabel("theta", fontsize=fnt_size)
867
-    ax1.set_title(title, fontsize=fnt_size)
868
-    ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
869
-    ax1.set_xticks(x_ticks)
870
-    if len(prop) >= 18:
871
-        ax1.locator_params(nbins=n_ticks)
872
-    # new scale of ticks if too many values
873
-    cumul = 0
874
-    prop_cumul = []
875
-    for val in prop:
876
-        prop_cumul.append(val+cumul)
877
-        cumul = val+cumul
878
-    ax1.set_xticklabels([f'{x[k]}\n{val:.2f}' for k, val in enumerate(prop_cumul)])
879
-    if ax is None:
880
-        plt.savefig(title+'_raw'+str(k)+'.pdf')
881
-    # fig 2 & 3
882
-    if ax is None:
883
-        fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
884
-        fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
885
-    else:
886
-        # plt.rcParams['font.size'] = fnt_size
887
-        # place of plots on the grid
888
-        ax2 = ax[1,0]
889
-        ax3 = ax[1,1]
890
-    lines_fig2 = []
891
-    lines_fig3 = []
892
-    #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
893
-    for epoch, theta in best_epochs.items():
894
-        groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
895
-        x = []
896
-        y = []
897
-        thetas = np.array(list(theta.values()), dtype=object)[:, 0]
898
-        for i,group in enumerate(groups):
899
-            x += group[::-1]
900
-            y += list(np.repeat(thetas[i], len(group)))
901
-            if epoch == 0:
902
-                N0 = y[0]
903
-        for i in range(len(y)):
904
-            y[i] = y[i]/N0
905
-        x_2 = []
906
-        T = 0
907
-        for i in range(len(x)):
908
-            x[i] = int(x[i])
909
-        # compute the times as: theta_k / (k*(k-1))
910
-        for i in range(0, len(x)):
911
-            T += y[i] / (x[i]*(x[i]-1))
912
-            x_2.append(T)
913
-        # Plotting (fig 2)
914
-        x_2 = [0]+x_2
915
-        y = [y[0]]+y
916
-        x2_plot, y2_plot = plot_straight_x_y(x_2, y)
917
-        p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
918
-        lines_fig2.append(p2)
919
-        # Plotting (fig 3) which is the same but log scale for x
920
-        p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
921
-        lines_fig3.append(p3)
922
-    ax2.set_xlabel("Relative scale", fontsize=fnt_size)
923
-    ax2.set_ylabel("theta", fontsize=fnt_size)
924
-    ax2.set_title(title, fontsize=fnt_size)
925
-    ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
926
-    if ax is None:
927
-        plt.savefig(title+'_plot2_'+str(k)+'.pdf')
928
-    ax3.set_xscale('log')
929
-    ax3.set_yscale('log')
930
-    ax3.set_xlabel("log Relative scale", fontsize=fnt_size)
931
-    ax3.set_ylabel("theta", fontsize=fnt_size)
932
-    ax3.set_title(title, fontsize=fnt_size)
933
-    ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
934
-    if ax is None:
935
-        plt.savefig(title+'_plot3_'+str(k)+'_log.pdf')
936
-        plt.clf()
937
-    # return plots
938
-    return ax
939
-
940 669
 def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale = True):
941 670
     my_dpi = 300
942 671
     # # Add some extra space for the second axis at the bottom
@@ -956,6 +685,8 @@ def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale =
956 685
     # # plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = None)
957 686
     # # plt.clf()
958 687
     save_k_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, output = title+"_plotdata.json")
688
+    save_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, input = title+"_plotdata.json", output = title+"_plotdata.json")
689
+
959 690
     with open(title+"_plotdata.json", 'r') as json_file:
960 691
         loaded_data = json.load(json_file)
961 692
     # plot page 1 of summary
@@ -971,8 +702,7 @@ def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale =
971 702
 
972 703
     ax1 = plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'],
973 704
                             prop = loaded_data['prop'], title = title, ax = ax1)
974
-    ax1, ax2 = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = [ax1, ax2])
975
-    save_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, input = title+"_plotdata.json", output = title+"_plotdata.json")
705
+    ax1, ax2 = plot_all_epochs_thetafolder(loaded_data, mu, tgen, title, theta_scale, ax = [ax1, ax2])
976 706
     fig1.savefig(title+'_combined_p1.pdf')
977 707
     fig2.savefig(title+'_combined_p2.pdf')
978 708
     plot_raw_stairs(plot_lines = loaded_data['raw_stairs'],