4 Commits 89813468b5 ... 44449033db

Author SHA1 Message Date
  tforest 44449033db Update SFS plotting function 2 months ago
  tforest fed1a36d79 Get rid of the old plot_all_epochs_theta 4 months ago
  tforest f90938f8d9 Save all epochs in JSON 4 months ago
  tforest 6a6d4bf6f9 Remove unused functions for swp2 4 months ago
3 changed files with 241 additions and 291 deletions
  1. 6 3
      customgraphics.py
  2. 22 10
      sfs_tools.py
  3. 213 278
      swp2.py

+ 6 - 3
customgraphics.py View File

@@ -200,14 +200,14 @@ def scatter(x, y, ylab=None, xlab=None, title=None):
200 200
         plt.title(title)
201 201
     plt.show()
202 202
 
203
-def barplot(x=None, y=None, ylab=None, xlab=None, title=None):
203
+def barplot(x=None, y=None, ylab=None, xlab=None, title=None, label=None, xticks = None, width=1):
204 204
     if x:
205 205
         x = list(x)
206 206
         plt.xticks(x)
207
-        plt.bar(x, y)
207
+        plt.bar(x, y, width=width, label=label)
208 208
     else:
209 209
         x = list(range(len(y)))
210
-        plt.bar(x, y)
210
+        plt.bar(x, y, width=width, label=label)
211 211
         plt.xticks(x)
212 212
     if ylab:
213 213
         plt.ylabel(ylab)
@@ -215,6 +215,9 @@ def barplot(x=None, y=None, ylab=None, xlab=None, title=None):
215 215
         plt.xlabel(xlab)
216 216
     if title:
217 217
         plt.title(title)
218
+    if xticks:
219
+        plt.xticks(xticks)
220
+    plt.legend()
218 221
     plt.show()
219 222
 
220 223
 def plot_chrom_continuity(vcf_entries, chr_id, x=None, y=None, outfile = None,

+ 22 - 10
sfs_tools.py View File

@@ -21,6 +21,7 @@ import gzip
21 21
 import sys
22 22
 import matplotlib.pyplot as plt
23 23
 from frst import customgraphics
24
+import numpy as np
24 25
 
25 26
 def sfs_from_vcf(n, vcf_file, folded = True, diploid = True, phased = False, verbose = False,
26 27
                  strip = False, count_ext = False):
@@ -192,7 +193,7 @@ def sfs_from_parsed_vcf(n, vcf_dict, folded = True, diploid = True, phased = Fal
192 193
     return SFS_values, count_pluriall
193 194
 
194 195
 
195
-def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed = False, normalized = False):
196
+def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed = False, normalized = False, ploidy = 2):
196 197
     sfs_val = []
197 198
     n = len(sfs.values())
198 199
     sum_sites = sum(list(sfs.values()))
@@ -222,7 +223,7 @@ def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed =
222 223
             
223 224
     #terminal case, same for folded or unfolded
224 225
     if transformed:
225
-        last_bin = list(sfs.values())[n-1] * n/2
226
+        last_bin = list(sfs.values())[n-1] * n/ploidy
226 227
     else:
227 228
         last_bin = list(sfs.values())[n-1]
228 229
     sfs_val[-1] = last_bin
@@ -235,22 +236,33 @@ def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed =
235 236
         
236 237
         #print(sum(sfs_val))
237 238
     #build the plot
238
-    title = title+" (n="+str(len(sfs_val))+") [folded="+str(folded)+"]"+" [transformed="+str(transformed)+"]"
239
-    print("SFS =", sfs)
240 239
     if folded:
241 240
         xlab = "Minor allele frequency"
241
+        n_title = n
242
+    else:
243
+        # the spectrum is n-1 long when unfolded
244
+        n_title = n+1
245
+    
246
+    title = title+" (n="+str(n_title)+") [folded="+str(folded)+"]"+" [transformed="+str(transformed)+"]"
247
+    print("SFS =", sfs)
248
+
249
+    X_axis = list(sfs.keys()) 
250
+    
251
+
242 252
     if transformed:
243
-        print("Transformed SFS ( n =",len(sfs_val), ") :", sfs_val)
253
+        print("Transformed SFS ( n =",n_title, ") :", sfs_val)
244 254
         #plt.axhline(y=1/n, color='r', linestyle='-')
255
+        plt.bar([x+0.2 for x in list(sfs.keys())], [1/n]*n, color='r', linestyle='-', width = 0.4, label= "H0 Theoric constant")
256
+
245 257
     else:
246 258
         if normalized:
247 259
             # then plot a theoritical distribution as 1/i
248
-            expected_y = [1/(2*x+1) for x in list(sfs.keys())]
260
+            sum_expected = sum([(1/(i+1)) for i,x in enumerate(list(sfs.keys()))])
261
+            expected_y = [(1/(i+1))/sum_expected for i,x in enumerate(list(sfs.keys()))]
262
+            print(expected_y)
263
+            plt.bar([x+0.2 for x in list(sfs.keys())], expected_y, color='r', linestyle='-', width = 0.4, label= "H0 Theoric constant")
249 264
             print(sum(expected_y))
250
-            #plt.plot([x for x in list(sfs.keys())], expected_y, color='r', linestyle='-')
251
-            #print(expected_y)
252
-            
253
-    customgraphics.barplot(x = [x for x in list(sfs.keys())], y= sfs_val, xlab = xlab, ylab = ylab, title = title)
265
+    customgraphics.barplot(x = [x-0.2 for x in X_axis], width=0.4, y= sfs_val, xlab = xlab, ylab = ylab, title = title, label = "H1 Observed spectrum", xticks =list(sfs.keys()) )
254 266
     plt.show()
255 267
 
256 268
 if __name__ == "__main__":

+ 213 - 278
swp2.py View File

@@ -3,13 +3,11 @@ import os
3 3
 import numpy as np
4 4
 import math
5 5
 import json
6
-import io
7
-from scipy.special import gammaln
8
-from matplotlib.backends.backend_pdf import PdfPages
9
-from matplotlib.ticker import MaxNLocator
10
-from mpl_toolkits.axes_grid1.inset_locator import inset_axes
11
-from matplotlib.ticker import MultipleLocator
6
+
12 7
 def log_facto(k):
8
+    """
9
+    Using the Stirling's approximation
10
+    """
13 11
     k = int(k)
14 12
     if k > 1e6:
15 13
         return k * np.log(k) - k + np.log(2*math.pi*k)/2
@@ -18,15 +16,7 @@ def log_facto(k):
18 16
         val += np.log(i)
19 17
     return val
20 18
 
21
-def log_facto_1(k):
22
-    startf = 1 # start of factorial sequence
23
-    stopf  = int(k+1) # end of of factorial sequence
24
-
25
-    q = gammaln(range(startf+1, stopf+1)) # n! = G(n+1)
26
-
27
-    return q[-1]
28
-
29
-def return_x_y_from_stwp_theta_file(stwp_theta_file, breaks, mu, tgen, relative_theta_scale = False):
19
+def parse_stwp_theta_file(stwp_theta_file, breaks, mu, tgen, relative_theta_scale = False):
30 20
     with open(stwp_theta_file, "r") as swp_file:
31 21
         # Read the first line
32 22
         line = swp_file.readline()
@@ -119,107 +109,8 @@ def return_x_y_from_stwp_theta_file(stwp_theta_file, breaks, mu, tgen, relative_
119 109
         x.append(list(t.values())[time])
120 110
         x.append(list(t.values())[time])
121 111
     x.append(list(t.values())[len(t.values())-1])
122
-    # if relative_theta_scale:
123
-    #     # rescale
124
-    #     #N0 = y[0]
125
-    #     # for i in range(len(y)):
126
-    #     #     # divide by N0
127
-    #     #     y[i] = y[i]/N0
128
-    #     #     x[i] = x[i]/N0
129
-    return x,y,likelihood,thetas,sfs,L
130
-
131
-def return_x_y_from_stwp_theta_file_as_is(stwp_theta_file, breaks, mu, tgen, relative_theta_scale = False):
132
-    with open(stwp_theta_file, "r") as swp_file:
133
-        # Read the first line
134
-        line = swp_file.readline()
135
-        L = float(line.split()[2])
136
-        rands = swp_file.readline()
137
-        line = swp_file.readline()
138
-        # skip empty lines before SFS
139
-        while line == "\n":
140
-            line = swp_file.readline()
141
-        sfs = np.array(line.split()).astype(float)
142
-        # Process lines until the end of the file
143
-        while line:
144
-            # check at each line
145
-            if line.startswith("dim") :
146
-                dim = int(line.split()[1])
147
-                if dim == breaks+1:
148
-                    likelihood = line.split()[5]
149
-                    groups = line.split()[6:6+dim]
150
-                    theta_site = line.split()[6+dim:6+dim+1+dim]
151
-                elif dim < breaks+1:
152
-                    line = swp_file.readline()
153
-                    continue
154
-                elif dim > breaks+1:
155
-                    break
156
-                    #return 0,0,0
157
-            # Read the next line
158
-            line = swp_file.readline()
159
-    #### END of parsing
160
-    # quit this file if the number of dimensions is incorrect
161
-    if dim < breaks+1:
162
-        return 0,0
163
-    # get n, the last bin of the last group
164
-    # revert the list of groups as the most recent times correspond
165
-    # to the closest and last leafs of the coal. tree.
166
-    groups = groups[::-1]
167
-    theta_site = theta_site[::-1]
168 112
 
169
-    thetas = {}
170
-
171
-    for i in range(len(groups)):
172
-        groups[i] = groups[i].split(',')
173
-        # print(groups[i], len(groups[i]))
174
-        thetas[i] = [float(theta_site[i]), groups[i], likelihood]
175
-    return thetas, sfs
176
-
177
-def plot_k_epochs_thetafolder(folder_path, mu, tgen, breaks = 2, title = "Title", theta_scale = True):
178
-    scenari = {}
179
-    cpt = 0
180
-    for file_name in os.listdir(folder_path):
181
-        if os.path.isfile(os.path.join(folder_path, file_name)):
182
-            # Perform actions on each file
183
-            x,y,likelihood,sfs,L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
184
-                                                             tgen = tgen,
185
-                                     mu = mu, relative_theta_scale = theta_scale)
186
-            if x == 0 or y == 0:
187
-                continue
188
-            cpt +=1
189
-            scenari[likelihood] = x,y
190
-    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
191
-    print(cpt, "theta file(s) have been scanned.")
192
-    # sort starting by the smallest -log(Likelihood)
193
-    print(scenari)
194
-    best10_scenari = (sorted(list(scenari.keys())))[:10]
195
-    print("10 greatest Likelihoods", best10_scenari)
196
-    greatest_likelihood = best10_scenari[0]
197
-    x, y = scenari[greatest_likelihood]
198
-    my_dpi = 300
199
-    plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
200
-    plt.plot(x, y, 'r-', lw=2, label = 'Lik='+greatest_likelihood)
201
-    #plt.yscale('log')
202
-    plt.xscale('log')
203
-    plt.grid(True,which="both", linestyle='--', alpha = 0.3)
204
-
205
-    for scenario in best10_scenari[1:]:
206
-        x,y = scenari[scenario]
207
-        #print("\n----  Lik:",scenario,"\n\nt=", x,"\n\nN=",y, "\n\n")
208
-        plt.plot(x, y, '--', lw=1, label = 'Lik='+scenario)
209
-    if theta_scale:
210
-        plt.xlabel("Coal. time")
211
-        plt.ylabel("Pop. size scaled by N0")
212
-        recent_scale_lower_bound = y[0] * 0.01
213
-        recent_scale_upper_bound = y[0] * 0.1
214
-        plt.axvline(x=recent_scale_lower_bound)
215
-        plt.axvline(x=recent_scale_upper_bound)
216
-    else:
217
-        # years
218
-        plt.xlabel("Time (years)")
219
-        plt.ylabel("Individuals (N)")
220
-    plt.legend(loc='upper right')
221
-    plt.title(title)
222
-    plt.savefig(title+'_b'+str(breaks)+'.pdf')
113
+    return x,y,likelihood,thetas,sfs,L
223 114
 
224 115
 def plot_straight_x_y(x,y):
225 116
     x_1 = [x[0]]
@@ -233,7 +124,8 @@ def plot_straight_x_y(x,y):
233 124
     x_1.append(x[-1])
234 125
     return x_1, y_1
235 126
 
236
-def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_scale = True, ax = None):
127
+def plot_all_epochs_thetafolder_old(folder_path, mu, tgen, title = "Title",
128
+    theta_scale = True, ax = None, input = None, output = None):
237 129
     #scenari = {}
238 130
     cpt = 0
239 131
     epochs = {}
@@ -241,7 +133,7 @@ def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_sc
241 133
         breaks = 0
242 134
         cpt +=1
243 135
         if os.path.isfile(os.path.join(folder_path, file_name)):
244
-            x, y, likelihood, theta, sfs, L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
136
+            x, y, likelihood, theta, sfs, L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
245 137
                                                              tgen = tgen,
246 138
                                                              mu = mu, relative_theta_scale = theta_scale)
247 139
             SFS_stored = sfs
@@ -251,7 +143,7 @@ def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_sc
251 143
                     epochs[breaks] = {}
252 144
                 epochs[breaks][likelihood] = x,y
253 145
                 breaks += 1
254
-                x,y,likelihood,theta,sfs,L = return_x_y_from_stwp_theta_file(folder_path+file_name, breaks = breaks,
146
+                x,y,likelihood,theta,sfs,L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
255 147
                                                                  tgen = tgen,
256 148
                                                                   mu = mu, relative_theta_scale = theta_scale)
257 149
             if x == 0:
@@ -384,8 +276,195 @@ def plot_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_sc
384 276
     # return plots
385 277
     return ax[0], ax[1]
386 278
 
279
+def plot_all_epochs_thetafolder(full_dict, mu, tgen, title = "Title",
280
+    theta_scale = True, ax = None, input = None, output = None):
281
+    my_dpi = 300
282
+    if ax is None:
283
+        # intialize figure
284
+        my_dpi = 300
285
+        fnt_size = 18
286
+        # plt.rcParams['font.size'] = fnt_size
287
+        fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
288
+    else:
289
+        fnt_size = 12
290
+        # plt.rcParams['font.size'] = fnt_size
291
+        ax1 = ax[1][0,0]
292
+    ax1.set_yscale('log')
293
+    ax1.set_xscale('log')
294
+    ax1.grid(True,which="both", linestyle='--', alpha = 0.3)
295
+    plot_handles = []
296
+    best_plot = full_dict['all_epochs']['best']
297
+    p0, = ax1.plot(best_plot[0], best_plot[1], 'o', linestyle = "-",
298
+    alpha=1, lw=2, label = str(best_plot[2])+' brks | Lik='+best_plot[3])
299
+    plot_handles.append(p0)
300
+    for k, plot_Lk in enumerate(full_dict['all_epochs']['plots']):
301
+        plot_Lk = str(full_dict['all_epochs']['plots'][k][3])
302
+        # plt.rcParams['font.size'] = fnt_size
303
+        p, = ax1.plot(full_dict['all_epochs']['plots'][k][0], full_dict['all_epochs']['plots'][k][1], 'o', linestyle = "--",
304
+        alpha=1/(k+1), lw=1.5, label = str(full_dict['all_epochs']['plots'][k][2])+' brks | Lik='+plot_Lk)
305
+        plot_handles.append(p)
306
+    if theta_scale:
307
+        ax1.set_xlabel("Coal. time", fontsize=fnt_size)
308
+        ax1.set_ylabel("Pop. size scaled by N0", fontsize=fnt_size)
309
+        # recent_scale_lower_bound = 0.01
310
+        # recent_scale_upper_bound = 0.1
311
+        # ax1.axvline(x=recent_scale_lower_bound)
312
+        # ax1.axvline(x=recent_scale_upper_bound)
313
+    else:
314
+        # years
315
+        plt.set_xlabel("Time (years)", fontsize=fnt_size)
316
+        plt.set_ylabel("Individuals (N)", fontsize=fnt_size)
317
+    # plt.rcParams['font.size'] = fnt_size
318
+    # print(fnt_size, "rcParam font.size=", plt.rcParams['font.size'])
319
+    ax1.legend(handles = plot_handles, loc='best', fontsize = fnt_size*0.5)
320
+    ax1.set_title(title)
321
+    if ax is None:
322
+        plt.savefig(title+'_b'+str(breaks)+'.pdf')
323
+    # plot likelihood against nb of breakpoints
324
+    if ax is None:
325
+        fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
326
+        # plt.rcParams['font.size'] = fnt_size
327
+    else:
328
+        #plt.rcParams['font.size'] = fnt_size
329
+        ax2 = ax[0][0,1]
330
+
331
+    ax2.plot(full_dict['Ln_Brks'][0], full_dict['Ln_Brks'][1], 'o', linestyle = "dotted", lw=2)
332
+    ax2.axhline(y=full_dict['best_Ln'], linestyle = "-.", color = "red", label = "$-\log\mathcal{L}$ = "+str(round(full_dict['best_Ln'], 2)))
333
+    ax2.set_yscale('log')
334
+    ax2.set_xlabel("# breakpoints", fontsize=fnt_size)
335
+    ax2.set_ylabel("$-\log\mathcal{L}$", fontsize=fnt_size)
336
+    ax2.legend(loc='best', fontsize = fnt_size*0.5)
337
+    ax2.set_title(title+" Likelihood gain from # breakpoints")
338
+    if ax is None:
339
+        plt.savefig(title+'_Breakpts_Likelihood.pdf')
340
+    # AIC
341
+    if ax is None:
342
+        fig, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
343
+        # plt.rcParams['font.size'] = '18'
344
+    else:
345
+        #plt.rcParams['font.size'] = fnt_size
346
+        ax3 = ax[1][0,1]
347
+    AIC = full_dict['AIC_Brks']
348
+    ax3.plot(AIC[0], AIC[1], 'o', linestyle = "dotted", lw=2)
349
+    ax3.axhline(y=full_dict['best_AIC'], linestyle = "-.", color = "red",
350
+    label = "Min. AIC = "+str(round(full_dict['best_AIC'], 2)))
351
+    ax3.set_yscale('log')
352
+    ax3.set_xlabel("# breakpoints", fontsize=fnt_size)
353
+    ax3.set_ylabel("AIC")
354
+    ax3.legend(loc='best', fontsize = fnt_size*0.5)
355
+    ax3.set_title(title+" AIC")
356
+    if ax is None:
357
+        plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
358
+    # return plots
359
+    return ax[0], ax[1]
360
+
361
+def save_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title", theta_scale = True, input = None, output = None):
362
+    #scenari = {}
363
+    cpt = 0
364
+    epochs = {}
365
+    plots = {}
366
+    # store ['best'], and [0] for epoch 0 etc...
367
+    for file_name in os.listdir(folder_path):
368
+        breaks = 0
369
+        cpt +=1
370
+        if os.path.isfile(os.path.join(folder_path, file_name)):
371
+            x, y, likelihood, theta, sfs, L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
372
+                                                             tgen = tgen,
373
+                                                             mu = mu, relative_theta_scale = theta_scale)
374
+            SFS_stored = sfs
375
+            L_stored = L
376
+            while not (x == 0 and y == 0):
377
+                if breaks not in epochs.keys():
378
+                    epochs[breaks] = {}
379
+                epochs[breaks][likelihood] = x,y
380
+                breaks += 1
381
+                x,y,likelihood,theta,sfs,L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
382
+                                                                 tgen = tgen,
383
+                                                                  mu = mu, relative_theta_scale = theta_scale)
384
+            if x == 0:
385
+                # last break did not work, then breaks = breaks-1
386
+                breaks -= 1
387
+    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
388
+    print(cpt, "theta file(s) have been scanned.")
389
+    brkpt_lik = []
390
+    top_plots = {}
391
+    for epoch, scenari in epochs.items():
392
+        # sort starting by the smallest -log(Likelihood)
393
+        best10_scenari = (sorted(list(scenari.keys())))[:10]
394
+        greatest_likelihood = best10_scenari[0]
395
+        # store the tuple breakpoints and likelihood for later plot
396
+        brkpt_lik.append((epoch, greatest_likelihood))
397
+        x, y = scenari[greatest_likelihood]
398
+        #without breakpoint
399
+        if epoch == 0:
400
+            # do something with the theta without bp and skip the plotting
401
+            N0 = y[0]
402
+            #continue
403
+        for i in range(len(y)):
404
+            # divide by N0
405
+            y[i] = y[i]/N0
406
+            x[i] = x[i]/N0
407
+        top_plots[greatest_likelihood] = x,y,epoch
408
+    plots_likelihoods = list(top_plots.keys())
409
+    for i in range(len(plots_likelihoods)):
410
+        plots_likelihoods[i] = float(plots_likelihoods[i])
411
+    best10_plots = sorted(plots_likelihoods)[:10]
412
+    top_plot_lik = str(best10_plots[0])
413
+    # store x,y,brks,likelihood
414
+    plots['best'] = (top_plots[top_plot_lik][0], top_plots[top_plot_lik][1], str(top_plots[top_plot_lik][2]), top_plot_lik)
415
+    plots['plots'] = []
416
+    for k, plot_Lk in enumerate(best10_plots[1:]):
417
+        plot_Lk = str(plot_Lk)
418
+        plots['plots'].append([top_plots[plot_Lk][0], top_plots[plot_Lk][1], str(top_plots[plot_Lk][2]), plot_Lk])
419
+    # plot likelihood against nb of breakpoints
420
+    # best possible likelihood from SFS
421
+    # Segregating sites
422
+    S = sum(SFS_stored)
423
+    # Number of kept sites from which the SFS is computed
424
+    L = L_stored
425
+    # number of monomorphic sites
426
+    S0 = L-S
427
+    # print("SFS", SFS_stored)
428
+    # print("S", S, "L", L, "S0=", S0)
429
+    # compute Ln
430
+    Ln = log_facto(S+S0) - log_facto(S0) + np.log(float(S0)/(S+S0)) * S0
431
+    for xi in range(0, len(SFS_stored)):
432
+        p_i = SFS_stored[xi] / float(S+S0)
433
+        Ln += np.log(p_i) * SFS_stored[xi] - log_facto(SFS_stored[xi])
434
+    # basic plot likelihood
435
+    Ln_Brks = [list(np.array(brkpt_lik)[:, 0]), list(np.array(brkpt_lik)[:, 1].astype(float))]
436
+    best_Ln = -Ln
437
+    AIC = []
438
+    for brk in np.array(brkpt_lik)[:, 0]:
439
+        brk = int(brk)
440
+        AIC.append((2*brk+1)+2*np.array(brkpt_lik)[brk, 1].astype(float))
441
+    AIC_Brks = [list(np.array(brkpt_lik)[:, 0]), AIC]
442
+    # AIC = 2*k - 2ln(L) ; where k is the number of parameters, here brks+1
443
+    AIC_ln = 2*(len(brkpt_lik)+1) - 2*Ln
444
+    best_AIC = AIC_ln
445
+    # to return : plots ; Ln_Brks ; AIC_Brks ; best_Ln ; best_AIC
446
+    # 'plots' dict keys: 'best', {epochs}('0', '1',...)
447
+    if input == None:
448
+        saved_plots = {"all_epochs":plots, "Ln_Brks":Ln_Brks,
449
+                        "AIC_Brks":AIC_Brks, "best_Ln":best_Ln,
450
+                        "best_AIC":best_AIC}
451
+    else:
452
+        # if the dict has to be loaded from input
453
+        with open(input, 'r') as json_file:
454
+            saved_plots = json.load(json_file)
455
+            saved_plots["all_epochs"] = plots
456
+            saved_plots["Ln_Brks"] = Ln_Brks
457
+            saved_plots["AIC_Brks"] = AIC_Brks
458
+            saved_plots["best_Ln"] = best_Ln
459
+            saved_plots["best_AIC"] = best_AIC
460
+    if output == None:
461
+        output = title+"_plotdata.json"
462
+    with open(output, 'w') as json_file:
463
+        json.dump(saved_plots, json_file)
464
+    return saved_plots
465
+
387 466
 def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
388
-    breaks_max = 10, output = None):
467
+    breaks_max = 10, input = None, output = None):
389 468
     """
390 469
     Save theta values as is to do basic plots.
391 470
     """
@@ -396,7 +475,7 @@ def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
396 475
         cpt +=1
397 476
         if os.path.isfile(os.path.join(folder_path, file_name)):
398 477
             for k in range(breaks_max):
399
-                thetas,sfs = return_x_y_from_stwp_theta_file_as_is(folder_path+file_name, breaks = k,
478
+                x,y,likelihood,thetas,sfs,L = parse_stwp_theta_file(folder_path+file_name, breaks = k,
400 479
                                                                  tgen = tgen,
401 480
                                                                  mu = mu, relative_theta_scale = theta_scale)
402 481
                 if thetas == 0:
@@ -478,9 +557,16 @@ def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
478 557
         # x2_plot, y2_plot = plot_straight_x_y(x_2, y)
479 558
         p2 = x_2, y
480 559
         lines_fig2.append(p2)
481
-
482
-    saved_plots = {"raw_stairs":plots, "scaled_stairs":lines_fig2,
483
-                    "prop":prop}
560
+    if input == None:
561
+        saved_plots = {"raw_stairs":plots, "scaled_stairs":lines_fig2,
562
+                        "prop":prop}
563
+    else:
564
+        # if the dict has to be loaded from input
565
+        with open(input, 'r') as json_file:
566
+            saved_plots = json.load(json_file)
567
+        saved_plots["raw_stairs"] = plots
568
+        saved_plots["scaled_stairs"] = lines_fig2
569
+        saved_plots["prop"] = prop
484 570
     if output == None:
485 571
         output = title+"_plotdata.json"
486 572
     with open(output, 'w') as json_file:
@@ -580,159 +666,6 @@ def plot_raw_stairs(plot_lines, prop, title, ax = None, n_ticks = 10):
580 666
     # return plots
581 667
     return ax
582 668
 
583
-def plot_test_theta(folder_path, mu, tgen, title = "Title", theta_scale = True, breaks_max = 10, ax = None, n_ticks = 10):
584
-    """
585
-    Use theta values as is to do basic plots.
586
-    """
587
-    cpt = 0
588
-    epochs = {}
589
-    len_sfs = 0
590
-    for file_name in os.listdir(folder_path):
591
-        cpt +=1
592
-        if os.path.isfile(os.path.join(folder_path, file_name)):
593
-            for k in range(breaks_max):
594
-                thetas,sfs = return_x_y_from_stwp_theta_file_as_is(folder_path+file_name, breaks = k,
595
-                                                                 tgen = tgen,
596
-                                                                 mu = mu, relative_theta_scale = theta_scale)
597
-                if thetas == 0:
598
-                    continue
599
-                if len(thetas)-1 != k:
600
-                    continue
601
-                if k not in epochs.keys():
602
-                    epochs[k] = {}
603
-                likelihood = str(eval(thetas[k][2]))
604
-                epochs[k][likelihood] = thetas
605
-                #epochs[k] = thetas
606
-    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(k)+"\n*******\n")
607
-    print(cpt, "theta file(s) have been scanned.")
608
-    # multiple fig
609
-    if ax is None:
610
-        # intialize figure 1
611
-        my_dpi = 300
612
-        fnt_size = 18
613
-        # plt.rcParams['font.size'] = fnt_size
614
-        fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
615
-    else:
616
-        fnt_size = 12
617
-        # plt.rcParams['font.size'] = fnt_size
618
-        ax1 = ax[0, 1]
619
-        plt.subplots_adjust(wspace=0.3, hspace=0.3)
620
-    plots = []
621
-    best_epochs = {}
622
-    for epoch in epochs:
623
-        likelihoods = []
624
-        for key in epochs[epoch].keys():
625
-            likelihoods.append(key)
626
-        likelihoods.sort()
627
-        minLogLn = str(likelihoods[0])
628
-        best_epochs[epoch] = epochs[epoch][minLogLn]
629
-    for epoch, theta in best_epochs.items():
630
-        groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
631
-        x = []
632
-        y = []
633
-        thetas = np.array(list(theta.values()), dtype=object)[:, 0]
634
-        for i,group in enumerate(groups):
635
-            x += group[::-1]
636
-            y += list(np.repeat(thetas[i], len(group)))
637
-            if epoch == 0:
638
-                N0 = y[0]
639
-                # compute the proportion of information used at each bin of the SFS
640
-                sum_theta_i = 0
641
-                for i in range(2, len(y)+2):
642
-                    sum_theta_i+=y[i-2] / (i-1)
643
-                prop = []
644
-                for k in range(2, len(y)+2):
645
-                    prop.append(y[k-2] / (k - 1) / sum_theta_i)
646
-                prop = prop[::-1]
647
-                # print(prop, "\n", sum(prop))
648
-        # normalise to N0 (N0 of epoch1)
649
-        x_ticks = ax1.get_xticks()
650
-        for i in range(len(y)):
651
-            y[i] = y[i]/N0
652
-        # plot
653
-        x_plot, y_plot = plot_straight_x_y(x, y)
654
-        #plt.plot(x, y, 'o', linestyle="dotted", alpha=0.75, lw=2, label = str(epoch)+' brks')
655
-        p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
656
-        # add plot to the list of all plots to superimpose
657
-        plots.append(p)
658
-    #print(prop, "\n", sum(prop))
659
-    #ax.legend(handles=[p0]+plots)
660
-    ax1.set_xlabel("# bin", fontsize=fnt_size)
661
-    # Set the x-axis locator to reduce the number of ticks to 10
662
-    ax1.set_ylabel("theta", fontsize=fnt_size)
663
-    ax1.set_title(title, fontsize=fnt_size)
664
-    ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
665
-    ax1.set_xticks(x_ticks)
666
-    if len(prop) >= 18:
667
-        ax1.locator_params(nbins=n_ticks)
668
-    # new scale of ticks if too many values
669
-    cumul = 0
670
-    prop_cumul = []
671
-    for val in prop:
672
-        prop_cumul.append(val+cumul)
673
-        cumul = val+cumul
674
-    ax1.set_xticklabels([f'{x[k]}\n{val:.2f}' for k, val in enumerate(prop_cumul)])
675
-    if ax is None:
676
-        plt.savefig(title+'_raw'+str(k)+'.pdf')
677
-    # fig 2 & 3
678
-    if ax is None:
679
-        fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
680
-        fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
681
-    else:
682
-        # plt.rcParams['font.size'] = fnt_size
683
-        # place of plots on the grid
684
-        ax2 = ax[1,0]
685
-        ax3 = ax[1,1]
686
-    lines_fig2 = []
687
-    lines_fig3 = []
688
-    #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
689
-    for epoch, theta in best_epochs.items():
690
-        groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
691
-        x = []
692
-        y = []
693
-        thetas = np.array(list(theta.values()), dtype=object)[:, 0]
694
-        for i,group in enumerate(groups):
695
-            x += group[::-1]
696
-            y += list(np.repeat(thetas[i], len(group)))
697
-            if epoch == 0:
698
-                N0 = y[0]
699
-        for i in range(len(y)):
700
-            y[i] = y[i]/N0
701
-        x_2 = []
702
-        T = 0
703
-        for i in range(len(x)):
704
-            x[i] = int(x[i])
705
-        # compute the times as: theta_k / (k*(k-1))
706
-        for i in range(0, len(x)):
707
-            T += y[i] / (x[i]*(x[i]-1))
708
-            x_2.append(T)
709
-        # Plotting (fig 2)
710
-        x_2 = [0]+x_2
711
-        y = [y[0]]+y
712
-        x2_plot, y2_plot = plot_straight_x_y(x_2, y)
713
-        p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
714
-        lines_fig2.append(p2)
715
-        # Plotting (fig 3) which is the same but log scale for x
716
-        p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(epoch)+' brks')
717
-        lines_fig3.append(p3)
718
-    ax2.set_xlabel("Relative scale", fontsize=fnt_size)
719
-    ax2.set_ylabel("theta", fontsize=fnt_size)
720
-    ax2.set_title(title, fontsize=fnt_size)
721
-    ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
722
-    if ax is None:
723
-        plt.savefig(title+'_plot2_'+str(k)+'.pdf')
724
-    ax3.set_xscale('log')
725
-    ax3.set_yscale('log')
726
-    ax3.set_xlabel("log Relative scale", fontsize=fnt_size)
727
-    ax3.set_ylabel("theta", fontsize=fnt_size)
728
-    ax3.set_title(title, fontsize=fnt_size)
729
-    ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
730
-    if ax is None:
731
-        plt.savefig(title+'_plot3_'+str(k)+'_log.pdf')
732
-        plt.clf()
733
-    # return plots
734
-    return ax
735
-
736 669
 def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale = True):
737 670
     my_dpi = 300
738 671
     # # Add some extra space for the second axis at the bottom
@@ -751,7 +684,9 @@ def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale =
751 684
     # # plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = None)
752 685
     # # plot_test_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, ax = None)
753 686
     # # plt.clf()
754
-    # save_k_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, output = title+"_plotdata.json")
687
+    save_k_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = breaks, output = title+"_plotdata.json")
688
+    save_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, input = title+"_plotdata.json", output = title+"_plotdata.json")
689
+
755 690
     with open(title+"_plotdata.json", 'r') as json_file:
756 691
         loaded_data = json.load(json_file)
757 692
     # plot page 1 of summary
@@ -767,7 +702,7 @@ def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale =
767 702
 
768 703
     ax1 = plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'],
769 704
                             prop = loaded_data['prop'], title = title, ax = ax1)
770
-    ax1, ax2 = plot_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale, ax = [ax1, ax2])
705
+    ax1, ax2 = plot_all_epochs_thetafolder(loaded_data, mu, tgen, title, theta_scale, ax = [ax1, ax2])
771 706
     fig1.savefig(title+'_combined_p1.pdf')
772 707
     fig2.savefig(title+'_combined_p2.pdf')
773 708
     plot_raw_stairs(plot_lines = loaded_data['raw_stairs'],