浏览代码

Save all epochs in JSON

tforest 11 个月前
父节点
当前提交
f90938f8d9
共有 1 个文件被更改,包括 278 次插入18 次删除
  1. 278 18
      swp2.py

+ 278 - 18
swp2.py 查看文件

@@ -16,7 +16,7 @@ def log_facto(k):
16 16
         val += np.log(i)
17 17
     return val
18 18
 
19
-def return_x_y_from_stwp_theta_file(stwp_theta_file, breaks, mu, tgen, relative_theta_scale = False):
19
+def parse_stwp_theta_file(stwp_theta_file, breaks, mu, tgen, relative_theta_scale = False):
20 20
     with open(stwp_theta_file, "r") as swp_file:
21 21
         # Read the first line
22 22
         line = swp_file.readline()
@@ -109,15 +109,8 @@ def return_x_y_from_stwp_theta_file(stwp_theta_file, breaks, mu, tgen, relative_
109 109
         x.append(list(t.values())[time])
110 110
         x.append(list(t.values())[time])
111 111
     x.append(list(t.values())[len(t.values())-1])
112
-    # if relative_theta_scale:
113
-    #     # rescale
114
-    #     #N0 = y[0]
115
-    #     # for i in range(len(y)):
116
-    #     #     # divide by N0
117
-    #     #     y[i] = y[i]/N0
118
-    #     #     x[i] = x[i]/N0
119
-    return x,y,likelihood,thetas,sfs,L
120 112
 
113
+    return x,y,likelihood,thetas,sfs,L
121 114
 
122 115
 def plot_k_epochs_thetafolder(folder_path, mu, tgen, breaks = 2, title = "Title", theta_scale = True):
123 116
     scenari = {}
@@ -125,7 +118,7 @@ def plot_k_epochs_thetafolder(folder_path, mu, tgen, breaks = 2, title = "Title"
125 118
     for file_name in os.listdir(folder_path):
126 119
         if os.path.isfile(os.path.join(folder_path, file_name)):
127 120
             # Perform actions on each file
128
-            x, y, likelihood, theta, sfs, L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
121
+            x, y, likelihood, theta, sfs, L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
129 122
                                                              tgen = tgen,
130 123
                                      mu = mu, relative_theta_scale = theta_scale)
131 124
             if x == 0 or y == 0:
@@ -178,6 +171,265 @@ def plot_straight_x_y(x,y):
178 171
     x_1.append(x[-1])
179 172
     return x_1, y_1
180 173
 
174
+def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title",
175
+    theta_scale = True, ax = None, input = None, output = None):
176
+    #scenari = {}
177
+    cpt = 0
178
+    epochs = {}
179
+    for file_name in os.listdir(folder_path):
180
+        breaks = 0
181
+        cpt +=1
182
+        if os.path.isfile(os.path.join(folder_path, file_name)):
183
+            x, y, likelihood, theta, sfs, L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
184
+                                                             tgen = tgen,
185
+                                                             mu = mu, relative_theta_scale = theta_scale)
186
+            SFS_stored = sfs
187
+            L_stored = L
188
+            while not (x == 0 and y == 0):
189
+                if breaks not in epochs.keys():
190
+                    epochs[breaks] = {}
191
+                epochs[breaks][likelihood] = x,y
192
+                breaks += 1
193
+                x,y,likelihood,theta,sfs,L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
194
+                                                                 tgen = tgen,
195
+                                                                  mu = mu, relative_theta_scale = theta_scale)
196
+            if x == 0:
197
+                # last break did not work, then breaks = breaks-1
198
+                breaks -= 1
199
+    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
200
+    print(cpt, "theta file(s) have been scanned.")
201
+    my_dpi = 300
202
+    if ax is None:
203
+        # intialize figure
204
+        my_dpi = 300
205
+        fnt_size = 18
206
+        # plt.rcParams['font.size'] = fnt_size
207
+        fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
208
+    else:
209
+        fnt_size = 12
210
+        # plt.rcParams['font.size'] = fnt_size
211
+        ax1 = ax[1][0,0]
212
+    ax1.set_yscale('log')
213
+    ax1.set_xscale('log')
214
+    ax1.grid(True,which="both", linestyle='--', alpha = 0.3)
215
+    brkpt_lik = []
216
+    top_plots = {}
217
+    for epoch, scenari in epochs.items():
218
+        # sort starting by the smallest -log(Likelihood)
219
+        best10_scenari = (sorted(list(scenari.keys())))[:10]
220
+        greatest_likelihood = best10_scenari[0]
221
+        # store the tuple breakpoints and likelihood for later plot
222
+        brkpt_lik.append((epoch, greatest_likelihood))
223
+        x, y = scenari[greatest_likelihood]
224
+        #without breakpoint
225
+        if epoch == 0:
226
+            # do something with the theta without bp and skip the plotting
227
+            N0 = y[0]
228
+            #continue
229
+        for i in range(len(y)):
230
+            # divide by N0
231
+            y[i] = y[i]/N0
232
+            x[i] = x[i]/N0
233
+        top_plots[greatest_likelihood] = x,y,epoch
234
+    plots_likelihoods = list(top_plots.keys())
235
+    for i in range(len(plots_likelihoods)):
236
+        plots_likelihoods[i] = float(plots_likelihoods[i])
237
+    best10_plots = sorted(plots_likelihoods)[:10]
238
+    top_plot_lik = str(best10_plots[0])
239
+    plot_handles = []
240
+    # plt.rcParams['font.size'] = fnt_size
241
+    p0, = ax1.plot(top_plots[top_plot_lik][0], top_plots[top_plot_lik][1], 'o', linestyle = "-",
242
+    alpha=1, lw=2, label = str(top_plots[top_plot_lik][2])+' brks | Lik='+top_plot_lik)
243
+    plot_handles.append(p0)
244
+    for k, plot_Lk in enumerate(best10_plots[1:]):
245
+        plot_Lk = str(plot_Lk)
246
+        # plt.rcParams['font.size'] = fnt_size
247
+        p, = ax1.plot(top_plots[plot_Lk][0], top_plots[plot_Lk][1], 'o', linestyle = "--",
248
+        alpha=1/(k+1), lw=1.5, label = str(top_plots[plot_Lk][2])+' brks | Lik='+plot_Lk)
249
+        plot_handles.append(p)
250
+    if theta_scale:
251
+        ax1.set_xlabel("Coal. time", fontsize=fnt_size)
252
+        ax1.set_ylabel("Pop. size scaled by N0", fontsize=fnt_size)
253
+        # recent_scale_lower_bound = 0.01
254
+        # recent_scale_upper_bound = 0.1
255
+        # ax1.axvline(x=recent_scale_lower_bound)
256
+        # ax1.axvline(x=recent_scale_upper_bound)
257
+    else:
258
+        # years
259
+        plt.set_xlabel("Time (years)", fontsize=fnt_size)
260
+        plt.set_ylabel("Individuals (N)", fontsize=fnt_size)
261
+    # plt.rcParams['font.size'] = fnt_size
262
+    # print(fnt_size, "rcParam font.size=", plt.rcParams['font.size'])
263
+    ax1.legend(handles = plot_handles, loc='best', fontsize = fnt_size*0.5)
264
+    ax1.set_title(title)
265
+    if ax is None:
266
+        plt.savefig(title+'_b'+str(breaks)+'.pdf')
267
+    # plot likelihood against nb of breakpoints
268
+    # best possible likelihood from SFS
269
+    # Segregating sites
270
+    S = sum(SFS_stored)
271
+    # Number of kept sites from which the SFS is computed
272
+    L = L_stored
273
+    # number of monomorphic sites
274
+    S0 = L-S
275
+    # print("SFS", SFS_stored)
276
+    # print("S", S, "L", L, "S0=", S0)
277
+    # compute Ln
278
+    Ln = log_facto(S+S0) - log_facto(S0) + np.log(float(S0)/(S+S0)) * S0
279
+    for xi in range(0, len(SFS_stored)):
280
+        p_i = SFS_stored[xi] / float(S+S0)
281
+        Ln += np.log(p_i) * SFS_stored[xi] - log_facto(SFS_stored[xi])
282
+    # basic plot likelihood
283
+    if ax is None:
284
+        fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
285
+        # plt.rcParams['font.size'] = fnt_size
286
+    else:
287
+        #plt.rcParams['font.size'] = fnt_size
288
+        ax2 = ax[0][0,1]
289
+    ax2.plot(np.array(brkpt_lik)[:, 0], np.array(brkpt_lik)[:, 1].astype(float), 'o', linestyle = "dotted", lw=2)
290
+    ax2.axhline(y=-Ln, linestyle = "-.", color = "red", label = "$-\log\mathcal{L}$ = "+str(round(-Ln, 2)))
291
+    ax2.set_yscale('log')
292
+    ax2.set_xlabel("# breakpoints", fontsize=fnt_size)
293
+    ax2.set_ylabel("$-\log\mathcal{L}$", fontsize=fnt_size)
294
+    ax2.legend(loc='best', fontsize = fnt_size*0.5)
295
+    ax2.set_title(title+" Likelihood gain from # breakpoints")
296
+    if ax is None:
297
+        plt.savefig(title+'_Breakpts_Likelihood.pdf')
298
+    # AIC
299
+    if ax is None:
300
+        fig, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
301
+        # plt.rcParams['font.size'] = '18'
302
+    else:
303
+        #plt.rcParams['font.size'] = fnt_size
304
+        ax3 = ax[1][0,1]
305
+    AIC = []
306
+    for brk in np.array(brkpt_lik)[:, 0]:
307
+        brk = int(brk)
308
+        AIC.append((2*brk+1)+2*np.array(brkpt_lik)[brk, 1].astype(float))
309
+    ax3.plot(np.array(brkpt_lik)[:, 0], AIC, 'o', linestyle = "dotted", lw=2)
310
+    # AIC = 2*k - 2ln(L) ; where k is the number of parameters, here brks+1
311
+    AIC_ln = 2*(len(brkpt_lik)+1) - 2*Ln
312
+    ax3.axhline(y=AIC_ln, linestyle = "-.", color = "red",
313
+    label = "Min. AIC = "+str(round(AIC_ln, 2)))
314
+    selected_brks_nb = AIC.index(min(AIC))
315
+    ax3.set_yscale('log')
316
+    ax3.set_xlabel("# breakpoints", fontsize=fnt_size)
317
+    ax3.set_ylabel("AIC")
318
+    ax3.legend(loc='best', fontsize = fnt_size*0.5)
319
+    ax3.set_title(title+" AIC")
320
+    if ax is None:
321
+        plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
322
+    print("S", S)
323
+    # return plots
324
+    return ax[0], ax[1]
325
+
326
+def save_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_scale = True, input = None, output = None):
327
+    #scenari = {}
328
+    cpt = 0
329
+    epochs = {}
330
+    plots = {}
331
+    # store ['best'], and [0] for epoch 0 etc...
332
+    for file_name in os.listdir(folder_path):
333
+        breaks = 0
334
+        cpt +=1
335
+        if os.path.isfile(os.path.join(folder_path, file_name)):
336
+            x, y, likelihood, theta, sfs, L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
337
+                                                             tgen = tgen,
338
+                                                             mu = mu, relative_theta_scale = theta_scale)
339
+            SFS_stored = sfs
340
+            L_stored = L
341
+            while not (x == 0 and y == 0):
342
+                if breaks not in epochs.keys():
343
+                    epochs[breaks] = {}
344
+                epochs[breaks][likelihood] = x,y
345
+                breaks += 1
346
+                x,y,likelihood,theta,sfs,L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
347
+                                                                 tgen = tgen,
348
+                                                                  mu = mu, relative_theta_scale = theta_scale)
349
+            if x == 0:
350
+                # last break did not work, then breaks = breaks-1
351
+                breaks -= 1
352
+    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
353
+    print(cpt, "theta file(s) have been scanned.")
354
+
355
+    brkpt_lik = []
356
+    top_plots = {}
357
+    for epoch, scenari in epochs.items():
358
+        # sort starting by the smallest -log(Likelihood)
359
+        best10_scenari = (sorted(list(scenari.keys())))[:10]
360
+        greatest_likelihood = best10_scenari[0]
361
+        # store the tuple breakpoints and likelihood for later plot
362
+        brkpt_lik.append((epoch, greatest_likelihood))
363
+        x, y = scenari[greatest_likelihood]
364
+        #without breakpoint
365
+        if epoch == 0:
366
+            # do something with the theta without bp and skip the plotting
367
+            N0 = y[0]
368
+            #continue
369
+        for i in range(len(y)):
370
+            # divide by N0
371
+            y[i] = y[i]/N0
372
+            x[i] = x[i]/N0
373
+        top_plots[greatest_likelihood] = x,y,epoch
374
+    plots_likelihoods = list(top_plots.keys())
375
+    for i in range(len(plots_likelihoods)):
376
+        plots_likelihoods[i] = float(plots_likelihoods[i])
377
+    best10_plots = sorted(plots_likelihoods)[:10]
378
+    top_plot_lik = str(best10_plots[0])
379
+    # store x,y,brks,likelihood
380
+    plots['best'] = (top_plots[top_plot_lik][0], top_plots[top_plot_lik][1], str(top_plots[top_plot_lik][2]), top_plot_lik)
381
+    for k, plot_Lk in enumerate(best10_plots[1:]):
382
+        plot_Lk = str(plot_Lk)
383
+        plots[str(top_plots[plot_Lk][2])] = (top_plots[plot_Lk][0], top_plots[plot_Lk][1], str(top_plots[plot_Lk][2]), plot_Lk)
384
+
385
+    # plot likelihood against nb of breakpoints
386
+    # best possible likelihood from SFS
387
+    # Segregating sites
388
+    S = sum(SFS_stored)
389
+    # Number of kept sites from which the SFS is computed
390
+    L = L_stored
391
+    # number of monomorphic sites
392
+    S0 = L-S
393
+    # print("SFS", SFS_stored)
394
+    # print("S", S, "L", L, "S0=", S0)
395
+    # compute Ln
396
+    Ln = log_facto(S+S0) - log_facto(S0) + np.log(float(S0)/(S+S0)) * S0
397
+    for xi in range(0, len(SFS_stored)):
398
+        p_i = SFS_stored[xi] / float(S+S0)
399
+        Ln += np.log(p_i) * SFS_stored[xi] - log_facto(SFS_stored[xi])
400
+    # basic plot likelihood
401
+    Ln_Brks = [list(np.array(brkpt_lik)[:, 0]), list(np.array(brkpt_lik)[:, 1].astype(float))]
402
+    best_Ln = -Ln
403
+    AIC = []
404
+    for brk in np.array(brkpt_lik)[:, 0]:
405
+        brk = int(brk)
406
+        AIC.append((2*brk+1)+2*np.array(brkpt_lik)[brk, 1].astype(float))
407
+    AIC_Brks = [list(np.array(brkpt_lik)[:, 0]), AIC]
408
+    # AIC = 2*k - 2ln(L) ; where k is the number of parameters, here brks+1
409
+    AIC_ln = 2*(len(brkpt_lik)+1) - 2*Ln
410
+    best_AIC = AIC_ln
411
+
412
+    # to return : plots ; Ln_Brks ; AIC_Brks ; best_Ln ; best_AIC
413
+    # 'plots' dict keys: 'best', {epochs}('0', '1',...)
414
+    if input == None:
415
+        saved_plots = {"all_epochs":plots, "Ln_Brks":Ln_Brks,
416
+                        "AIC_Brks":AIC_Brks, "best_Ln":best_Ln,
417
+                        "best_AIC":best_AIC}
418
+    else:
419
+        # if the dict has to be loaded from input
420
+        with open(input, 'r') as json_file:
421
+            saved_plots = json.load(json_file)
422
+            saved_plots["all_epochs"] = plots
423
+            saved_plots["Ln_Brks"] = Ln_Brks
424
+            saved_plots["AIC_Brks"] = AIC_Brks
425
+            saved_plots["best_Ln"] = best_Ln
426
+            saved_plots["best_AIC"] = best_AIC
427
+    if output == None:
428
+        output = title+"_plotdata.json"
429
+    with open(output, 'w') as json_file:
430
+        json.dump(saved_plots, json_file)
431
+    return saved_plots
432
+
181 433
 def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_scale = True, ax = None):
182 434
     #scenari = {}
183 435
     cpt = 0
@@ -186,7 +438,7 @@ def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_sc
186 438
         breaks = 0
187 439
         cpt +=1
188 440
         if os.path.isfile(os.path.join(folder_path, file_name)):
189
-            x, y, likelihood, theta, sfs, L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
441
+            x, y, likelihood, theta, sfs, L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
190 442
                                                              tgen = tgen,
191 443
                                                              mu = mu, relative_theta_scale = theta_scale)
192 444
             SFS_stored = sfs
@@ -196,7 +448,7 @@ def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_sc
196 448
                     epochs[breaks] = {}
197 449
                 epochs[breaks][likelihood] = x,y
198 450
                 breaks += 1
199
-                x,y,likelihood,theta,sfs,L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
451
+                x,y,likelihood,theta,sfs,L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
200 452
                                                                  tgen = tgen,
201 453
                                                                   mu = mu, relative_theta_scale = theta_scale)
202 454
             if x == 0:
@@ -330,7 +582,7 @@ def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_sc
330 582
     return ax[0], ax[1]
331 583
 
332 584
 def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
333
-    breaks_max = 10, output = None):
585
+    breaks_max = 10, input = None, output = None):
334 586
     """
335 587
     Save theta values as is to do basic plots.
336 588
     """
@@ -341,7 +593,7 @@ def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
341 593
         cpt +=1
342 594
         if os.path.isfile(os.path.join(folder_path, file_name)):
343 595
             for k in range(breaks_max):
344
-                x,y,likelihood,thetas,sfs,L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = k,
596
+                x,y,likelihood,thetas,sfs,L = parse_stwp_theta_file(folder_path+file_name, breaks = k,
345 597
                                                                  tgen = tgen,
346 598
                                                                  mu = mu, relative_theta_scale = theta_scale)
347 599
                 if thetas == 0:
@@ -423,9 +675,16 @@ def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
423 675
         # x2_plot, y2_plot = plot_straight_x_y(x_2, y)
424 676
         p2 = x_2, y
425 677
         lines_fig2.append(p2)
426
-
427
-    saved_plots = {"raw_stairs":plots, "scaled_stairs":lines_fig2,
428
-                    "prop":prop}
678
+    if input == None:
679
+        saved_plots = {"raw_stairs":plots, "scaled_stairs":lines_fig2,
680
+                        "prop":prop}
681
+    else:
682
+        # if the dict has to be loaded from input
683
+        with open(input, 'r') as json_file:
684
+            saved_plots = json.load(json_file)
685
+        saved_plots["raw_stairs"] = plots
686
+        saved_plots["scaled_stairs"] = lines_fig2
687
+        saved_plots["prop"] = prop
429 688
     if output == None:
430 689
         output = title+"_plotdata.json"
431 690
     with open(output, 'w') as json_file:
@@ -536,7 +795,7 @@ def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
536 795
         cpt +=1
537 796
         if os.path.isfile(os.path.join(folder_path, file_name)):
538 797
             for k in range(breaks_max):
539
-                x, y, likelihood, theta, sfs, L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = k,
798
+                x, y, likelihood, theta, sfs, L = parse_stwp_theta_file(folder_path+file_name, breaks = k,
540 799
                                                                  tgen = tgen,
541 800
                                                                  mu = mu, relative_theta_scale = theta_scale)
542 801
                 if thetas == 0:
@@ -713,6 +972,7 @@ def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale =
713 972
     ax1 = plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'],
714 973
                             prop = loaded_data['prop'], title = title, ax = ax1)
715 974
     ax1, ax2 = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = [ax1, ax2])
975
+    save_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, input = title+"_plotdata.json", output = title+"_plotdata.json")
716 976
     fig1.savefig(title+'_combined_p1.pdf')
717 977
     fig2.savefig(title+'_combined_p2.pdf')
718 978
     plot_raw_stairs(plot_lines = loaded_data['raw_stairs'],