Browse Source

Update on swp2 plots

tforest 1 year ago
parent
commit
2e4aca9951
1 changed file with 76 additions and 41 deletions
  1. 76 41
      swp2.py

+ 76 - 41
swp2.py View File

216
     plt.title(title)
216
     plt.title(title)
217
     plt.savefig(title+'_b'+str(breaks)+'.pdf')
217
     plt.savefig(title+'_b'+str(breaks)+'.pdf')
218
 
218
 
219
def plot_straight_x_y(x, y):
    """Expand (x, y) points into the vertices of an axis-aligned step line.

    Each coordinate is emitted twice so that a plain ``plot`` call joins the
    returned vertices with horizontal and vertical segments only.  For
    ``x = [x0, x1, x2]`` and ``y = [y0, y1, y2]`` the result is
    ``([x0, x0, x0, x1, x1, x2], [y0, y0, y1, y1, y2, y2])``.

    Parameters:
        x: indexable sequence of abscissas; it is read at index 0, at every
           index below ``len(y) - 1`` and at index -1.
        y: indexable sequence of ordinates.

    Returns:
        A tuple ``(x_steps, y_steps)`` of two lists, each of length
        ``2 * len(y)``.  Two empty lists are returned when either input is
        empty, so callers can pass the result straight to a plot call.
    """
    # len() rather than truthiness: the inputs may be NumPy arrays, whose
    # truth value is ambiguous.
    if len(x) == 0 or len(y) == 0:
        return [], []
    # The first abscissa appears once more than the others (it opens the
    # list), the last one once less (it only closes the list).
    x_steps = [x[0]]
    y_steps = []
    for i in range(len(y) - 1):
        x_steps += [x[i], x[i]]
        y_steps += [y[i], y[i]]
    x_steps.append(x[-1])
    y_steps += [y[-1], y[-1]]
    return x_steps, y_steps
230
+
219
 def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_scale = True, ax = None):
231
 def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_scale = True, ax = None):
220
     #scenari = {}
232
     #scenari = {}
221
     cpt = 0
233
     cpt = 0
237
                 x,y,likelihood,sfs,L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
249
                 x,y,likelihood,sfs,L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
238
                                                                  tgen = tgen,
250
                                                                  tgen = tgen,
239
                                                                   mu = mu, relative_theta_scale = theta_scale)
251
                                                                   mu = mu, relative_theta_scale = theta_scale)
252
+            if x == 0:
253
+                # last break did not work, then breaks = breaks-1
254
+                breaks -= 1
240
     print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
255
     print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
241
     print(cpt, "theta file(s) have been scanned.")
256
     print(cpt, "theta file(s) have been scanned.")
242
     my_dpi = 300
257
     my_dpi = 300
249
         fnt_size = 12
264
         fnt_size = 12
250
         plt.rcParams['font.size'] = fnt_size
265
         plt.rcParams['font.size'] = fnt_size
251
         ax1 = ax[0,0]
266
         ax1 = ax[0,0]
252
-    ax1.set_xlim(1e-3, 1)
253
-    #plt.ylim(0, 10)
267
+    #ax1.set_xlim(1e-3, 1)
254
     ax1.set_yscale('log')
268
     ax1.set_yscale('log')
255
     ax1.set_xscale('log')
269
     ax1.set_xscale('log')
256
     ax1.grid(True,which="both", linestyle='--', alpha = 0.3)
270
     ax1.grid(True,which="both", linestyle='--', alpha = 0.3)
257
     brkpt_lik = []
271
     brkpt_lik = []
272
+    top_plots = {}
258
     for epoch, scenari in epochs.items():
273
     for epoch, scenari in epochs.items():
259
         # sort starting by the smallest -log(Likelihood)
274
         # sort starting by the smallest -log(Likelihood)
260
         best10_scenari = (sorted(list(scenari.keys())))[:10]
275
         best10_scenari = (sorted(list(scenari.keys())))[:10]
271
             # divide by N0
286
             # divide by N0
272
             y[i] = y[i]/N0
287
             y[i] = y[i]/N0
273
             x[i] = x[i]/N0
288
             x[i] = x[i]/N0
274
-        ax1.plot(x, y, 'o', linestyle = "-", alpha=0.75, lw=2, label = str(epoch)+' BrkPt | Lik='+greatest_likelihood)
275
-        if theta_scale:
276
-            ax1.set_xlabel("Coal. time")
277
-            ax1.set_ylabel("Pop. size scaled by N0")
278
-            recent_scale_lower_bound = 0.01
279
-            recent_scale_upper_bound = 0.1
280
-            #print(recent_scale_lower_bound, recent_scale_upper_bound)
281
-            ax1.axvline(x=recent_scale_lower_bound)
282
-            ax1.axvline(x=recent_scale_upper_bound)
283
-        else:
284
-            # years
285
-            plt.set_xlabel("Time (years)")
286
-            plt.set_ylabel("Individuals (N)")
289
+        top_plots[greatest_likelihood] = x,y,epoch
290
+    plots_likelihoods = list(top_plots.keys())
291
+    for i in range(len(plots_likelihoods)):
292
+        plots_likelihoods[i] = float(plots_likelihoods[i])
293
+    best10_plots = sorted(plots_likelihoods)[:10]
294
+    top_plot_lik = str(best10_plots[0])
295
+    plot_handles = []
296
+    p0, = ax1.plot(top_plots[top_plot_lik][0], top_plots[top_plot_lik][1], 'o', linestyle = "-",
297
+    alpha=1, lw=2, label = str(top_plots[top_plot_lik][2])+' epoch | Lik='+top_plot_lik)
298
+    plot_handles.append(p0)
299
+    for k, plot_Lk in enumerate(best10_plots[1:]):
300
+        plot_Lk = str(plot_Lk)
301
+        p, = ax1.plot(top_plots[plot_Lk][0], top_plots[plot_Lk][1], 'o', linestyle = "--",
302
+        alpha=1/(k+1), lw=1.5, label = str(top_plots[plot_Lk][2])+' epoch | Lik='+plot_Lk)
303
+        plot_handles.append(p)
304
+    if theta_scale:
305
+        ax1.set_xlabel("Coal. time")
306
+        ax1.set_ylabel("Pop. size scaled by N0")
307
+        # recent_scale_lower_bound = 0.01
308
+        # recent_scale_upper_bound = 0.1
309
+        # ax1.axvline(x=recent_scale_lower_bound)
310
+        # ax1.axvline(x=recent_scale_upper_bound)
311
+    else:
312
+        # years
313
+        plt.set_xlabel("Time (years)")
314
+        plt.set_ylabel("Individuals (N)")
287
         ax1.set_xlim(1e-5, 1)
315
         ax1.set_xlim(1e-5, 1)
288
-        ax1.legend(loc='best', fontsize = fnt_size*0.5)
289
-        ax1.set_title(title)
290
-        if ax is None:
291
-            plt.savefig(title+'_b'+str(breaks)+'.pdf')
316
+    ax1.legend(handles = plot_handles, loc='best', fontsize = fnt_size*0.5)
317
+    ax1.set_title(title)
318
+    if ax is None:
319
+        plt.savefig(title+'_b'+str(breaks)+'.pdf')
292
     # plot likelihood against nb of breakpoints
320
     # plot likelihood against nb of breakpoints
293
     # best possible likelihood from SFS
321
     # best possible likelihood from SFS
294
     # Segregating sites
322
     # Segregating sites
295
     S = sum(SFS_stored)
323
     S = sum(SFS_stored)
296
-    # number of monomorphic sites
324
+    # Number of kept sites from which the SFS is computed
297
     L = L_stored
325
     L = L_stored
326
+    # number of monomorphic sites
298
     S0 = L-S
327
     S0 = L-S
299
     # print("SFS", SFS_stored)
328
     # print("SFS", SFS_stored)
300
     # print("S", S, "L", L, "S0=", S0)
329
     # print("S", S, "L", L, "S0=", S0)
303
     for xi in range(0, len(SFS_stored)):
332
     for xi in range(0, len(SFS_stored)):
304
         p_i = SFS_stored[xi] / float(S+S0)
333
         p_i = SFS_stored[xi] / float(S+S0)
305
         Ln += np.log(p_i) * SFS_stored[xi] - log_facto(SFS_stored[xi])
334
         Ln += np.log(p_i) * SFS_stored[xi] - log_facto(SFS_stored[xi])
306
-    res = Ln
307
-    # print(res)
308
     # basic plot likelihood
335
     # basic plot likelihood
309
     if ax is None:
336
     if ax is None:
310
         fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
337
         fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
313
         plt.rcParams['font.size'] = fnt_size
340
         plt.rcParams['font.size'] = fnt_size
314
         ax2 = ax[2,0]
341
         ax2 = ax[2,0]
315
     ax2.plot(np.array(brkpt_lik)[:, 0], np.array(brkpt_lik)[:, 1].astype(float), 'o', linestyle = "dotted", lw=2)
342
     ax2.plot(np.array(brkpt_lik)[:, 0], np.array(brkpt_lik)[:, 1].astype(float), 'o', linestyle = "dotted", lw=2)
316
-    # plt.ylim(0,100)
317
-    ax2.axhline(y=-Ln)
343
+    ax2.axhline(y=-Ln, linestyle = "-.", color = "red", label = "$-\log\mathcal{L}$ = "+str(round(-Ln, 2)))
318
     ax2.set_yscale('log')
344
     ax2.set_yscale('log')
319
     ax2.set_xlabel("# breakpoints", fontsize=fnt_size)
345
     ax2.set_xlabel("# breakpoints", fontsize=fnt_size)
320
     ax2.set_ylabel("$-\log\mathcal{L}$")
346
     ax2.set_ylabel("$-\log\mathcal{L}$")
321
-    #ax2.legend(loc='best', fontsize = fnt_size*0.5)
322
-    ax2.set_title(title)
347
+    ax2.legend(loc='best', fontsize = fnt_size*0.8)
348
+    ax2.set_title(title+" Likelihood gain from # breakpoints")
323
     if ax is None:
349
     if ax is None:
324
         plt.savefig(title+'_Breakpts_Likelihood.pdf')
350
         plt.savefig(title+'_Breakpts_Likelihood.pdf')
325
     # AIC
351
     # AIC
330
         ax3 = ax[2,1]
356
         ax3 = ax[2,1]
331
     AIC = 2*(len(brkpt_lik)+1)+2*np.array(brkpt_lik)[:, 1].astype(float)
357
     AIC = 2*(len(brkpt_lik)+1)+2*np.array(brkpt_lik)[:, 1].astype(float)
332
     ax3.plot(np.array(brkpt_lik)[:, 0], AIC, 'o', linestyle = "dotted", lw=2)
358
     ax3.plot(np.array(brkpt_lik)[:, 0], AIC, 'o', linestyle = "dotted", lw=2)
333
-    ax3.axhline(y=2*(len(brkpt_lik)+1)-2*Ln)
359
+    AIC_ln = 2*(len(brkpt_lik)+1)-2*Ln
360
+    ax3.axhline(y=AIC_ln, linestyle = "-.", color = "red",
361
+    label = "Min. AIC = "+str(round(AIC_ln, 2)))
334
     ax3.set_yscale('log')
362
     ax3.set_yscale('log')
335
     ax3.set_xlabel("# breakpoints", fontsize=fnt_size)
363
     ax3.set_xlabel("# breakpoints", fontsize=fnt_size)
336
     ax3.set_ylabel("AIC")
364
     ax3.set_ylabel("AIC")
337
-    #ax3.legend(loc='best', fontsize = fnt_size*0.5)
338
-    ax3.set_title(title)
365
+    ax3.legend(loc='best', fontsize = fnt_size*0.8)
366
+    ax3.set_title(title+" AIC")
339
     if ax is None:
367
     if ax is None:
340
         plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
368
         plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
341
     return ax
369
     return ax
396
         for i in range(len(y)):
424
         for i in range(len(y)):
397
             y[i] = y[i]/N0
425
             y[i] = y[i]/N0
398
         # plot
426
         # plot
427
+        x_plot, y_plot = plot_straight_x_y(x, y)
399
         #plt.plot(x, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
428
         #plt.plot(x, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
400
-        p, = ax1.plot(x, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
429
+        p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-.", alpha=0.75, lw=2, label = str(epoch)+' brks')
401
         # add plot to the list of all plots to superimpose
430
         # add plot to the list of all plots to superimpose
402
         plots.append(p)
431
         plots.append(p)
403
     # virtual line to get the second x axis for proportions
432
     # virtual line to get the second x axis for proportions
413
         # in a combined plot, more space between the fig and the axis
442
         # in a combined plot, more space between the fig and the axis
414
         twin.spines["bottom"].set_position(("axes", -0.35))
443
         twin.spines["bottom"].set_position(("axes", -0.35))
415
     #ax.legend(handles=[p0]+plots)
444
     #ax.legend(handles=[p0]+plots)
416
-    ax1.set_xlabel("# breaks")
445
+    ax1.set_xlabel("# bin")
417
     # Set the x-axis locator to reduce the number of ticks to 10
446
     # Set the x-axis locator to reduce the number of ticks to 10
418
     ax1.xaxis.set_major_locator(MaxNLocator(nbins=10))
447
     ax1.xaxis.set_major_locator(MaxNLocator(nbins=10))
448
+    twin.xaxis.set_major_locator(MaxNLocator(nbins=10))
419
     ax1.set_ylabel("theta")
449
     ax1.set_ylabel("theta")
420
     twin.set_ylabel("Proportion")
450
     twin.set_ylabel("Proportion")
451
+    ax1.set_title("Title")
421
     ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
452
     ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
422
     if ax is None:
453
     if ax is None:
423
         plt.savefig(title+'_raw'+str(k)+'.pdf')
454
         plt.savefig(title+'_raw'+str(k)+'.pdf')
453
             T += y[i] / (x[i]*(x[i]-1))
484
             T += y[i] / (x[i]*(x[i]-1))
454
             x_2.append(T)
485
             x_2.append(T)
455
         # Plotting (fig 2)
486
         # Plotting (fig 2)
456
-        p2, = ax2.plot(x_2, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
487
+        x_2 = [0]+x_2
488
+        y = [y[0]]+y
489
+        x2_plot, y2_plot = plot_straight_x_y(x_2, y)
490
+        p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-.", alpha=0.75, lw=2, label = str(epoch)+' brks')
457
         lines_fig2.append(p2)
491
         lines_fig2.append(p2)
458
         # Plotting (fig 3) which is the same but log scale for x
492
         # Plotting (fig 3) which is the same but log scale for x
459
-        p3, = ax3.plot(x_2, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
493
+        p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-.", alpha=0.75, lw=2, label = str(epoch)+' brks')
460
         lines_fig3.append(p3)
494
         lines_fig3.append(p3)
461
-    ax2.set_xlabel("# breaks")
495
+    ax2.set_xlabel("Relative scale")
462
     ax2.set_ylabel("theta")
496
     ax2.set_ylabel("theta")
463
-    ax2.set_title("Test")
497
+    ax2.set_title("Title")
464
     ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
498
     ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
465
     if ax is None:
499
     if ax is None:
466
         plt.savefig(title+'_plot2_'+str(k)+'.pdf')
500
         plt.savefig(title+'_plot2_'+str(k)+'.pdf')
467
     ax3.set_xscale('log')
501
     ax3.set_xscale('log')
468
-    ax3.set_xlabel("log()")
502
+    ax3.set_xlabel("log Relative scale")
469
     ax3.set_ylabel("theta")
503
     ax3.set_ylabel("theta")
470
-    ax3.set_title("Test")
504
+    ax3.set_title("Title")
471
     ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
505
     ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
472
     if ax is None:
506
     if ax is None:
473
         plt.savefig(title+'_plot3_'+str(k)+'_log.pdf')
507
         plt.savefig(title+'_plot3_'+str(k)+'_log.pdf')
477
 def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale = True):
     """Save one 3x2 summary figure, then redo the plots individually.

     A 3x2 grid of axes is created and handed to
     ``plot_all_epochs_thetafolder`` and ``plot_test_theta``.  The grid is
     saved as ``<title>_combined.pdf`` and the current figure is closed.
     Both helpers are then called a second time with ``ax = None`` to
     produce the individual plots.

     Parameters:
         folder_path: forwarded unchanged to both helpers.
         mu, tgen: forwarded unchanged to both helpers
             (presumably mutation rate and generation time; confirm).
         breaks: forwarded to ``plot_test_theta`` as ``breaks_max``.
         title: passed to both helpers and used as the output file prefix.
         theta_scale: forwarded unchanged to both helpers.

     Returns:
         None.
     """
     dpi = 300
     # Figure size is expressed in pixels and converted to inches.
     width_in, height_in = 5000/dpi, 2970/dpi
     _fig, grid = plt.subplots(3, 2, figsize=(width_in, height_in), dpi=dpi)
     # First pass: both helpers draw into the shared grid of axes.
     plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = grid)
     plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = grid)
     # Keep titles from being clipped, then widen the vertical gap between
     # rows (room for the extra axis drawn under some panels).
     plt.tight_layout()
     plt.subplots_adjust(hspace=0.7)
     # The whole grid goes into a single file.
     plt.savefig(title+'_combined.pdf')
     plt.close()
     # Second pass without axes, for the individual plots.
     plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = None)
     plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = None)
492
 
527
 
493
 if __name__ == "__main__":
528
 if __name__ == "__main__":
494
 
529