import matplotlib.pyplot as plt
import os
import sys
import numpy as np
import math
import json


def log_facto(k):
    """
    Using the Stirling's approximation
    """
    k = int(k)
    if k > 1e6:
        return k * np.log(k) - k + np.log(2*math.pi*k)/2
    val = 0
    for i in range(2, k+1):
        val += np.log(i)
    return val


def parse_stwp_theta_file(stwp_theta_file, breaks, mu, tgen, relative_theta_scale = False):
    with open(stwp_theta_file, "r") as swp_file:
        # Read the first line
        line = swp_file.readline()
        L = float(line.split()[2])
        rands = swp_file.readline()
        line = swp_file.readline()
        # skip empty lines before SFS
        while line == "\n":
            line = swp_file.readline()
        sfs = np.array(line.split()).astype(float)
        # Process lines until the end of the file
        while line:
            # check at each line
            if line.startswith("dim"):
                dim = int(line.split()[1])
                if dim == breaks+1:
                    likelihood = line.split()[5]
                    groups = line.split()[6:6+dim]
                    theta_site = line.split()[6+dim:6+dim+1+dim]
                elif dim < breaks+1:
                    line = swp_file.readline()
                    continue
                elif dim > breaks+1:
                    break  #return 0,0,0
            # Read the next line
            line = swp_file.readline()
    #### END of parsing
    # quit this file if the number of dimensions is incorrect
    if dim < breaks+1:
        return 0,0,0,0,0,0
    # get n, the last bin of the last group
    # revert the list of groups as the most recent times correspond
    # to the closest and last leafs of the coal. tree.
    groups = groups[::-1]
    theta_site = theta_site[::-1]
    # store thetas for later use
    grps = groups.copy()
    thetas = {}
    for i in range(len(groups)):
        grps[i] = grps[i].split(',')
        thetas[i] = [float(theta_site[i]), grps[i], likelihood]
    # initiate the dict of times
    t = {}
    # list of thetas
    theta_L = []
    sum_t = 0
    for group_nb, group in enumerate(groups):
        ###print(group_nb, group, theta_site[group_nb], len(theta_site))
        # store all the thetas one by one, with one theta per group
        theta_L.append(float(theta_site[group_nb]))
        # if the group is of size 1
        if len(group.split(',')) == 1:
            i = int(group)
        # if the group size is >1, take the first elem of the group
        # i is the first bin of each group, straight after a breakpoint
        else:
            i = int(group.split(",")[0])
            j = int(group.split(",")[-1])
        t[i] = 0
        #t =
        if len(group.split(',')) == 1:
            k = i
            if relative_theta_scale:
                t[i] += ((theta_L[group_nb]) / (k*(k-1)))
            else:
                t[i] += ((theta_L[group_nb]) / (k*(k-1)) * tgen) / mu
        else:
            for k in range(j, i-1, -1):
                if relative_theta_scale:
                    t[i] += ((theta_L[group_nb]) / (k*(k-1)))
                else:
                    t[i] += ((theta_L[group_nb]) / (k*(k-1)) * tgen) / mu
        # we add the cumulative times at the end
        t[i] += sum_t
        sum_t = t[i]
    # build the y axis (sizes)
    y = []
    for theta in theta_L:
        if relative_theta_scale:
            size = theta
        else:
            # with size N = theta/4mu
            size = theta / (4*mu)
        y.append(size)
        y.append(size)
    # build the time x axis
    x = [0]
    for time in range(0, len(t.values())-1):
        x.append(list(t.values())[time])
        x.append(list(t.values())[time])
    x.append(list(t.values())[len(t.values())-1])
    return x,y,likelihood,thetas,sfs,L


def plot_straight_x_y(x,y):
    x_1 = [x[0]]
    y_1 = []
    for i in range(0, len(y)-1):
        x_1.append(x[i])
        x_1.append(x[i])
        y_1.append(y[i])
        y_1.append(y[i])
    y_1 = y_1+[y[-1],y[-1]]
    x_1.append(x[-1])
    return x_1, y_1
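
# A minimal usage sketch of the two helpers above (the file name and the mu/tgen
# values are hypothetical, not taken from a real run):
#
#   x, y, lik, thetas, sfs, L = parse_stwp_theta_file("run1/pop.theta",
#                                                     breaks=2, mu=1e-8, tgen=25)
#   x_step, y_step = plot_straight_x_y(x, y)
#
# plot_straight_x_y() duplicates coordinates so the trajectory is drawn as a
# step function, e.g. plot_straight_x_y([0, 1, 2], [10, 20, 30]) returns
# ([0, 0, 0, 1, 1, 2], [10, 10, 20, 20, 30, 30]).
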

def plot_all_epochs_thetafolder(full_dict, mu, tgen, title = "Title",
                                theta_scale = True, ax = None, input = None, output = None):
    my_dpi = 500
    L = full_dict["L"]
    if ax is None:
        # intialize figure
        #my_dpi = 300
        fnt_size = 18
        # plt.rcParams['font.size'] = fnt_size
        fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
    else:
        fnt_size = 12
        # plt.rcParams['font.size'] = fnt_size
        ax1 = ax[1][0,0]
    ax1.set_yscale('log')
    ax1.set_xscale('log')
    plot_handles = []
    best_plot = full_dict['all_epochs']['best']
    p0, = ax1.plot(best_plot[0], best_plot[1], linestyle = "-", alpha=1, lw=2,
                   label = str(best_plot[2])+' brks | Lik='+best_plot[3])
    plot_handles.append(p0)
    #ax1.grid(True,which="both", linestyle='--', alpha = 0.3)
    for k, plot_Lk in enumerate(full_dict['all_epochs']['plots']):
        plot_Lk = str(full_dict['all_epochs']['plots'][k][3])
        # plt.rcParams['font.size'] = fnt_size
        p, = ax1.plot(full_dict['all_epochs']['plots'][k][0], full_dict['all_epochs']['plots'][k][1],
                      linestyle = "-", alpha=1/(k+1), lw=1.5,
                      label = str(full_dict['all_epochs']['plots'][k][2])+' brks | Lik='+plot_Lk)
        plot_handles.append(p)
    if theta_scale:
        ax1.set_xlabel("Coal. time", fontsize=fnt_size)
        ax1.set_ylabel("Pop. size scaled by N0", fontsize=fnt_size)
        # recent_scale_lower_bound = 0.01
        # recent_scale_upper_bound = 0.1
        # ax1.axvline(x=recent_scale_lower_bound)
        # ax1.axvline(x=recent_scale_upper_bound)
    else:
        # years
        if ax is not None:
            ax1.set_xlabel("Time (years)", fontsize=fnt_size)
            ax1.set_ylabel("Effective pop. size (Ne)", fontsize=fnt_size)
        else:
            plt.xlabel("Time (years)", fontsize=fnt_size)
            plt.ylabel("Effective pop. size (Ne)", fontsize=fnt_size)
    # x_ticks = ax1.get_xticks()
    # ax1.set_xticklabels([f'{k:.0e}\n{k/(mu):.0e}\n{k/(mu)*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
    # ax1.set_xticklabels([f'{k}\n{k/(mu)}\n{k/(mu)*tgen}' for k in x_ticks], fontsize = fnt_size*0.8)
    # plt.rcParams['font.size'] = fnt_size
    # print(fnt_size, "rcParam font.size=", plt.rcParams['font.size'])
    ax1.legend(handles = plot_handles, loc='best', fontsize = fnt_size*0.5)
    ax1.set_title(title)
    breaks = len(full_dict['all_epochs']['plots'])
    if ax is None:
        plt.savefig(title+'_best_'+str(breaks+1)+'_epochs.pdf')
    # plot likelihood against nb of breakpoints
    if ax is None:
        fig, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
        # plt.rcParams['font.size'] = fnt_size
    else:
        #plt.rcParams['font.size'] = fnt_size
        ax2 = ax[0][0,1]
    # Retrieve the default color cycle from rcParams
    default_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    # Create an array of colors from the default color cycle
    colors = [default_colors[i % len(default_colors)] for i in range(len(full_dict['Ln_Brks'][0]))]
    ax2.plot(full_dict['Ln_Brks'][0], full_dict['Ln_Brks'][1], "--", lw=1, color="black", zorder=1)
    ax2.scatter(full_dict['Ln_Brks'][0], full_dict['Ln_Brks'][1], s=50, c=colors, marker='o', zorder=2)
    ax2.axhline(y=full_dict['best_Ln'], linestyle = "-.", color = "red",
                label = r"$-\log\mathcal{L}$ = "+str(round(full_dict['best_Ln'], 2)))
    ax2.set_yscale('log')
    ax2.set_xlabel("# breakpoints", fontsize=fnt_size)
    ax2.set_ylabel(r"$-\log\mathcal{L}$", fontsize=fnt_size)
    ax2.legend(loc='best', fontsize = fnt_size*0.5)
    ax2.set_title(title+" Likelihood gain from # breakpoints")
    if ax is None:
        plt.savefig(title+'_Breakpts_Likelihood.pdf')
    # AIC
    if ax is None:
        fig, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
        # plt.rcParams['font.size'] = '18'
    else:
        #plt.rcParams['font.size'] = fnt_size
        ax3 = ax[1][0,1]
    AIC = full_dict['AIC_Brks']
    # ax3.plot(AIC[0], AIC[1], 'o', linestyle = "dotted", lw=2)
    ax3.plot(AIC[0], AIC[1], "--", lw=1, color="black", zorder=1)
    ax3.scatter(AIC[0], AIC[1], s=50, c=colors, marker='o', zorder=2)
    ax3.axhline(y=full_dict['best_AIC'], linestyle = "-.", color = "red",
                label = "Min. AIC = "+str(round(full_dict['best_AIC'], 2)))
    ax3.set_yscale('log')
    ax3.set_xlabel("# breakpoints", fontsize=fnt_size)
    ax3.set_ylabel("AIC")
    ax3.legend(loc='best', fontsize = fnt_size*0.5)
    ax3.set_title(title+" AIC")
    if ax is None:
        plt.savefig(title+'_Breakpts_Likelihood_AIC.pdf')
    else:
        # return plots
        return ax[0], ax[1]
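
# Shape of the dictionary exchanged between save_all_epochs_thetafolder() below
# and plot_all_epochs_thetafolder() above (keys taken from the code, values
# illustrative only):
#
#   {
#     "S": ..., "S0": ..., "L": ...,          # segregating, monomorphic and total sites
#     "mu": ..., "tgen": ...,
#     "all_epochs": {
#         "best":  [x, y, "<breaks>", "<-logL>"],
#         "plots": [[x, y, "<breaks>", "<-logL>"], ...],
#     },
#     "Ln_Brks":  [[breakpoints], [-logL values]],
#     "AIC_Brks": [[breakpoints], [AIC values]],
#     "best_Ln": ..., "best_AIC": ..., "best_epoch_by_AIC": ...,
#   }
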

def save_all_epochs_thetafolder(folder_path, mu, tgen, title = "Title",
                                theta_scale = True, input = None, output = None):
    #scenari = {}
    cpt = 0
    epochs = {}
    plots = {}
    # store ['best'], and [0] for epoch 0 etc...
    for file_name in os.listdir(folder_path):
        breaks = 0
        cpt +=1
        if os.path.isfile(os.path.join(folder_path, file_name)):
            x, y, likelihood, theta, sfs, L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
                                                                    tgen = tgen, mu = mu, relative_theta_scale = theta_scale)
            SFS_stored = sfs
            L_stored = L
            while not (x == 0 and y == 0):
                if breaks not in epochs.keys():
                    epochs[breaks] = {}
                epochs[breaks][likelihood] = x,y
                breaks += 1
                x,y,likelihood,theta,sfs,L = parse_stwp_theta_file(folder_path+file_name, breaks = breaks,
                                                                   tgen = tgen, mu = mu, relative_theta_scale = theta_scale)
            if x == 0:
                # last break did not work, then breaks = breaks-1
                breaks -= 1
    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(breaks)+"\n*******\n")
    print(cpt, "theta file(s) have been scanned.")
    brkpt_lik = []
    top_plots = {}
    best_scenario_for_epoch = {}
    for epoch, scenari in epochs.items():
        # sort starting by the smallest -log(Likelihood)
        best10_scenari = (sorted(list(scenari.keys())))[:10]
        greatest_likelihood = best10_scenari[0]
        # store the tuple breakpoints and likelihood for later plot
        brkpt_lik.append((epoch, greatest_likelihood))
        x, y = scenari[greatest_likelihood]
        # without breakpoint
        if epoch == 0:
            # do something with the theta without bp and skip the plotting
            N0 = y[0]
            #continue
        if theta_scale:
            for i in range(len(y)):
                # divide by N0
                y[i] = y[i]/N0
                x[i] = x[i]/N0
        top_plots[greatest_likelihood] = x,y,epoch
        best_scenario_for_epoch[epoch] = x,y,greatest_likelihood
    plots_likelihoods = list(top_plots.keys())
    for i in range(len(plots_likelihoods)):
        plots_likelihoods[i] = float(plots_likelihoods[i])
    best10_plots = sorted(plots_likelihoods)[:10]
    top_plot_lik = str(best10_plots[0])
    # store x,y,brks,likelihood
    plots['best'] = (top_plots[top_plot_lik][0], top_plots[top_plot_lik][1],
                     str(top_plots[top_plot_lik][2]), top_plot_lik)
    plots['plots'] = []
    for k, epoch in enumerate(best_scenario_for_epoch.keys()):
        plot_Lk = str(best_scenario_for_epoch[epoch][2])
        x,y = best_scenario_for_epoch[epoch][0], best_scenario_for_epoch[epoch][1]
        plots['plots'].append([x, y, str(epoch), plot_Lk])
    plots['plots'] = sorted(plots['plots'], key=lambda x: float(x[3]))
    plots['plots'] = plots['plots'][1:]
    # Previous version. Was this correct????
    # for k, plot_Lk in enumerate(best10_plots[1:]):
    #     plot_Lk = str(plot_Lk)
    #     plots['plots'].append([top_plots[plot_Lk][0], top_plots[plot_Lk][1], str(top_plots[plot_Lk][2]), plot_Lk])
    # plot likelihood against nb of breakpoints
    # best possible likelihood from SFS
    # Segregating sites
    S = sum(SFS_stored)
    # Number of kept sites from which the SFS is computed
    L = L_stored
    # number of monomorphic sites
    S0 = L-S
    # print("SFS", SFS_stored)
    print("S", S, "L", L, "S0=", S0)
    my_n = len(SFS_stored)*2
    print("n=", my_n)
    an = 1
    for i in range(2, my_n):
        an += 1.0/i
    print("an=", an, "theta_w", S/an, "theta_w_p_site", (S/an)/L)
    # compute Ln
    Ln = log_facto(S+S0) - log_facto(S0) + np.log(float(S0)/(S+S0)) * S0
    for xi in range(0, len(SFS_stored)):
        p_i = SFS_stored[xi] / float(S+S0)
        Ln += np.log(p_i) * SFS_stored[xi] - log_facto(SFS_stored[xi])
    # basic plot likelihood
    Ln_Brks = [list(np.array(brkpt_lik)[:, 0]), list(np.array(brkpt_lik)[:, 1].astype(float))]
    best_Ln = -Ln
    AIC = []
    for brk in np.array(brkpt_lik)[:, 0]:
        brk = int(brk)
        AIC.append((2*brk+1)+2*np.array(brkpt_lik)[brk, 1].astype(float))
    AIC_Brks = [list(np.array(brkpt_lik)[:, 0]), AIC]
    # AIC = 2*k - 2ln(L) ; where k is the number of parameters, here brks+1
    AIC_ln = 2*(len(brkpt_lik)+1) - 2*Ln
    best_AIC = AIC_ln
    selected_brks_nb = AIC.index(min(AIC))
    # to return : plots ; Ln_Brks ; AIC_Brks ; best_Ln ; best_AIC
    # 'plots' dict keys: 'best', {epochs}('0', '1',...)
    if input == None:
        saved_plots = {"S":S, "S0":S0, "L":L, "mu":mu, "tgen":tgen,
                       "all_epochs":plots, "Ln_Brks":Ln_Brks, "AIC_Brks":AIC_Brks,
                       "best_Ln":best_Ln, "best_AIC":best_AIC, "best_epoch_by_AIC":selected_brks_nb}
    else:
        # if the dict has to be loaded from input
        with open(input, 'r') as json_file:
            saved_plots = json.load(json_file)
        saved_plots["S"] = S
        saved_plots["S0"] = S0
        saved_plots["L"] = L
        saved_plots["mu"] = mu
        saved_plots["tgen"] = tgen
        saved_plots["all_epochs"] = plots
        saved_plots["Ln_Brks"] = Ln_Brks
        saved_plots["AIC_Brks"] = AIC_Brks
        saved_plots["best_Ln"] = best_Ln
        saved_plots["best_AIC"] = best_AIC
        saved_plots["best_epoch_by_AIC"] = selected_brks_nb
    if output == None:
        output = title+"_plotdata.json"
    with open(output, 'w') as json_file:
        json.dump(saved_plots, json_file)
    return saved_plots


def save_k_theta(folder_path, mu, tgen, title = "Title", theta_scale = True,
                 breaks_max = 10, input = None, output = None):
    """
    Save theta values as is to do basic plots.
    """
    cpt = 0
    epochs = {}
    len_sfs = 0
    for file_name in os.listdir(folder_path):
        cpt +=1
        if os.path.isfile(os.path.join(folder_path, file_name)):
            for k in range(breaks_max+1):
                x,y,likelihood,thetas,sfs,L = parse_stwp_theta_file(folder_path+file_name, breaks = k,
                                                                    tgen = tgen, mu = mu, relative_theta_scale = theta_scale)
                if thetas == 0:
                    continue
                if len(thetas)-1 != k:
                    continue
                if k not in epochs.keys():
                    epochs[k] = {}
                likelihood = str(eval(thetas[k][2]))
                epochs[k][likelihood] = thetas
                #epochs[k] = thetas
    print("\n*******\n"+title+"\n--------\n"+"mu="+str(mu)+"\ntgen="+str(tgen)+"\nbreaks="+str(k)+"\n*******\n")
    print(cpt, "theta file(s) have been scanned.")
    plots = []
    best_epochs = {}
    for epoch in epochs:
        likelihoods = []
        for key in epochs[epoch].keys():
            likelihoods.append(key)
        likelihoods.sort()
        minLogLn = str(likelihoods[0])
        best_epochs[epoch] = epochs[epoch][minLogLn]
    for epoch, theta in best_epochs.items():
        groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
        x = []
        y = []
        thetas = np.array(list(theta.values()), dtype=object)[:, 0]
        for i,group in enumerate(groups):
            x += group[::-1]
            y += list(np.repeat(thetas[i], len(group)))
        if epoch == 0:
            N0 = y[0]
            # compute the proportion of information used at each bin of the SFS
            sum_theta_i = 0
            for i in range(2, len(y)+2):
                sum_theta_i += y[i-2] / (i-1)
            prop = []
            for k in range(2, len(y)+2):
                prop.append(y[k-2] / (k - 1) / sum_theta_i)
            prop = prop[::-1]
        if theta_scale:
            # normalise to N0 (N0 of epoch1)
            for i in range(len(y)):
                y[i] = y[i]/N0
        # x_plot, y_plot = plot_straight_x_y(x, y)
        p = x, y
        # add plot to the list of all plots to superimpose
        plots.append(p)
    cumul = 0
    prop_cumul = []
    for val in prop:
        prop_cumul.append(val+cumul)
        cumul = val+cumul
    prop = prop_cumul
    # print("raw stairs", plots[3])
    # ###########
    # time = []
    # for k in plots[0][0]:
    #     k = int(k)
    #     dt = 2.0/(k*(k-1))
    #     time.append(2.0/(k*(k-1)))
    # Ne = []
    # for values in plots:
    #     Ne.append(np.array(values[1])/(4*mu))
    # print(time)
    # print(Ne[3])
    lines_fig2 = []
    for epoch, theta in best_epochs.items():
        groups = np.array(list(theta.values()), dtype=object)[:, 1].tolist()
        x = []
        y = []
        thetas = np.array(list(theta.values()), dtype=object)[:, 0]
        for i,group in enumerate(groups):
            x += group[::-1]
            y += list(np.repeat(thetas[i], len(group)))
        if epoch == 0:
            # watterson theta
            theta_w = y[0]
        if theta_scale:
            for i in range(len(y)):
                y[i] = y[i]/N0
        for i in range(len(y)):
            y[i] = y[i]/(4*mu)
        x_2 = []
        T = 0
        for i in range(len(x)):
            x[i] = int(x[i])
        # compute the times as: theta_k / (k*(k-1))
        for i in range(0, len(x)):
            T += y[i]*2 / (x[i]*(x[i]-1))
            x_2.append(T)
        # Save plotting (fig 2)
        # x_2 = [0]+x_2
        # y = [y[0]]+y
        # x2_plot, y2_plot = plot_straight_x_y(x_2, y)
        p2 = x_2, y
        lines_fig2.append(p2)
    # print("breaks=", epoch, "scaled_theta", lines_fig2[10])
    # print(lines_fig2[3][1][0]/(4*mu))
    # print(np.array(lines_fig2[3][1])/lines_fig2[3][1][0])
    # print("size list y=", len(lines_fig2[3][1]))
    #exit(0)
    if input == None:
        saved_plots = {"raw_stairs":plots, "scaled_stairs":lines_fig2, "prop":prop}
    else:
        # if the dict has to be loaded from input
        with open(input, 'r') as json_file:
            saved_plots = json.load(json_file)
        saved_plots["raw_stairs"] = plots
        saved_plots["scaled_stairs"] = lines_fig2
        saved_plots["prop"] = prop
    if output == None:
        output = title+"_plotdata.json"
    with open(output, 'w') as json_file:
        json.dump(saved_plots, json_file)
    return saved_plots

def plot_scaled_theta(plot_lines, prop, title, mu, tgen, swp2_lines = None, ax = None,
                      n_ticks = 10, subset = None, theta_scale = False):
    recent_limit_years = 100
    # recent limit expressed in generations
    recent_limit = recent_limit_years/tgen
    # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
    nb_epochs = len(plot_lines)
    # fig 2 & 3
    if ax is None:
        my_dpi = 500
        fnt_size = 18
        fig2, ax2 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
        fig3, ax3 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
    else:
        # plt.rcParams['font.size'] = fnt_size
        fnt_size = 12
        # place of plots on the grid
        ax2 = ax[1,0]
        ax3 = ax[1,1]
    lines_fig2 = []
    lines_fig3 = []
    #plt.figure(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
    if swp2_lines:
        for k in range(len(swp2_lines[0])):
            swp2_lines[0][k] = swp2_lines[0][k]/tgen
        for k in range(len(swp2_lines[1])):
            swp2_lines[1][k] = swp2_lines[1][k]
        # x2_plot, y2_plot = plot_straight_x_y(swp2_lines[0],swp2_lines[1])
        x2_plot, y2_plot = swp2_lines[0], swp2_lines[1]
        p2, = ax2.plot(x2_plot, y2_plot, linestyle="-", alpha=0.75, lw=2, label = 'swp2', color="black")
        lines_fig2.append(p2)
        # Plotting (fig 3) which is the same but log scale for x
        p3, = ax3.plot(x2_plot, y2_plot, linestyle="-", alpha=0.75, lw=2, label = 'swp2', color="black")
        lines_fig3.append(p3)
    min_x = 1
    min_y = 1
    max_x = 0
    max_y = 0
    # defaults used when no subset is given (otherwise set in the branch below)
    masking_alpha = 0.75
    recent_change = False
    for breaks, plot in enumerate(plot_lines):
        x,y = plot
        x2_plot, y2_plot = plot_straight_x_y(x,y)
        if subset is not None:
            if breaks in subset:
                masking_alpha = 0.75
                autoscale = True
                min_x = min(min_x, min(x2_plot))
                min_y = min(min_y, min(y2_plot))
                max_x = max(max_x, max(x2_plot))
                max_y = max(max_y, max(y2_plot))
                # skip the base 0 points x_plot[0:3]
                t_max_below_limit = 0
                t_min_below_limit = recent_limit
                recent_change = False
                for t in x[1:]:
                    if t <= recent_limit:
                        recent_change = True
                        t_max_below_limit = max(t_max_below_limit, t)
                        t_min_below_limit = min(t_min_below_limit, t)
                        Ne_max_below_limit = y[min(x.index(t_max_below_limit)+1, len(y)-1)]
                        Ne_min_below_limit = y[x.index(t_min_below_limit)]
                if recent_change:
                    print(f"\n{breaks} breaks ; This is below the recent limit of {recent_limit_years} years:\n",
                          f"t_min (most recent time point under the limit) : {t_min_below_limit/mu*tgen:.1f} t_max (most ancient time point under the limit) : {t_max_below_limit/mu*tgen:.1f}",
                          f"\nNe_min (effective size at t_min) : {Ne_min_below_limit/(4*mu):.1f} Ne_max (effective size at t_max) : {Ne_max_below_limit/(4*mu):.1f}",
                          f"\nNe_min/Ne_max = {(Ne_min_below_limit/(4*mu)) / (Ne_max_below_limit/(4*mu)):.1f}",
                          f"\nEvolution: {((Ne_min_below_limit/(4*mu)) - (Ne_max_below_limit/(4*mu)))/((Ne_max_below_limit/(4*mu)))*100:.1f}%")
                else:
                    print(f"Recent event under {recent_limit_years} years: NA")
                    # need to compute the last change and when it occured
                    tmin = x[1]
                    tmin_plus_1 = x[2]
                    Ne_min = y[1]
                    Ne_min_plus_1 = y[2]
                    print(f"Last was {tmin/mu*tgen:.1f} years ago. And was of {((Ne_min/(4*mu)) - (Ne_min_plus_1/(4*mu)))/(Ne_min_plus_1/(4*mu))*100:.1f}%")
            else:
                masking_alpha = 0
                autoscale = False
                ax2.set_autoscale_on(autoscale)
                ax3.set_autoscale_on(autoscale)
        p2, = ax2.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=masking_alpha, lw=2, label = str(breaks)+' brks')
        # Plotting (fig 3) which is the same but log scale for x
        p3, = ax3.plot(x2_plot, y2_plot, 'o', linestyle="-", alpha=masking_alpha, lw=2, label = str(breaks)+' brks')
        if subset is not None and breaks in subset:
            # store for legend
            lines_fig2.append(p2)
            lines_fig3.append(p3)
    # put the vertical line of the "recent" time limit
    ax3.axvline(x=recent_limit, linestyle="--")
    ax3.axvline(x=recent_limit/2, linestyle="--", color="green")
    if theta_scale:
        xlabel = "Theta scaled by N0"
        ylabel = "Theta scaled by N0"
    else:
        xlabel = "time"
        ylabel = "Effective pop. size (Ne)"
    if ax is None:
        # if not ax, then use the plt syntax, not ax...
        plt.xlabel(xlabel, fontsize=fnt_size)
        plt.ylabel(ylabel, fontsize=fnt_size)
        plt.gca().set_xlim(0, recent_limit * 3)
        if recent_change:
            plt.ylim(Ne_min_below_limit/3, Ne_max_below_limit*3)
        else:
            plt.ylim(y2_plot[0]/3, y2_plot[0])
        # plt.ylim(0, max(max_y+(max_y*0.05), max(swp2_lines[1])+(max(swp2_lines[1])*0.05)))
        #plt.xlim(0, recent_limit * 3)
        #xlim_val = plt.gca().get_xlim()
        x_ticks = list(plt.xticks())[0]
        # plt.xlim(min(min_x,min(swp2_lines[0])), max(max(swp2_lines[0]), max_x))
        # x_ticks = list(plt.gca().get_xticks())
        # plt.gca().set_xticks(x_ticks)
        # plt.xticks(x_ticks)
        # plt.gca().set_xlim(xlim_val)
        # plt.gca().set_xticklabels([f'{k:.0e}\n{k/(mu):.0e}\n{k/(mu)*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
        plt.gca().set_xticklabels([f'{k:.1f}\n{k*tgen:.1f}' for k in x_ticks], fontsize = fnt_size*0.5)
        # rescale y to effective pop size
        # ylim_val = plt.gca().get_ylim()
        # plt.ylim(min(min_y,min(swp2_lines[1])), max(max_y+(max_y*0.05), max(swp2_lines[1])+(max(swp2_lines[1])*0.05)))
        # y_ticks = list(plt.yticks())[0]
        # plt.gca().set_yticks(y_ticks)
        # plt.gca().set_ylim(ylim_val)
        # plt.yticks(y_ticks)
        # plt.gca().set_yticklabels([f'{k/(4*mu):.0e}' for k in y_ticks], fontsize = fnt_size*0.5)
        # plt.title(title, fontsize=fnt_size)
        # plt.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
        # # plt.text(-0.13, -0.135, 'Coal. time\nGen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
        plt.text(-0.13, -0.135, 'Gen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
        plt.subplots_adjust(bottom=0.2)  # Adjust the value as needed
        plt.savefig(title+'_plotB_'+str(nb_epochs)+'_epochs.pdf')
        # close fig2 to save memory
        plt.close(fig2)
    else:
        # when ax subplotting is used
        ax2.set_xlabel(xlabel, fontsize=fnt_size)
        ax2.set_ylabel(ylabel, fontsize=fnt_size)
        ax2.set_title(title, fontsize=fnt_size)
        ax2.legend(handles=lines_fig2, loc='best', fontsize = fnt_size*0.5)
    ax3.set_xlabel(xlabel, fontsize=fnt_size)
    ax3.set_ylabel(ylabel, fontsize=fnt_size)
    ax3.set_title(title, fontsize=fnt_size)
    ax3.legend(handles=lines_fig3, loc='best', fontsize = fnt_size*0.5)
    ax3.set_xscale('log')
    ax3.set_yscale('log')
    # Scale the x-axis
    # x_ticks = list(ax3.get_xticks())
    # ax3.set_xticks(x_ticks)
    # x_ticks = [i for i in range(0.1,max(max_x, max(swp2_lines[0]))), ]
    # ax3.set_xticks(x_ticks)
    ax3.set_xlim(0.1, max(max_x, max(swp2_lines[0])))
    x_ticks = ax3.get_xticks()
    # ax3.set_xlim(min(min(x_ticks), min(swp2_lines[0])), max(max_x, max(swp2_lines[0])))
    # ax3.set_xlim(1, max(max_x, max(swp2_lines[0])))
    # ax3.set_xticklabels([f'{k:.0e}\n{k/(mu):.0e}\n{k/(mu)*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
    # ax3.set_xticklabels([f'{k/(mu):.0e}\n{k/(mu)*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
    ax3.set_xticklabels([f'{k:.0e}\n{k*tgen:.0e}' for k in x_ticks], fontsize = fnt_size*0.5)
    # rescale y to effective pop size
    # y_ticks = list(ax3.get_yticks())
    # ax3.set_yticks(y_ticks)
    # ax3.set_ylim(min(min(y_ticks), min(swp2_lines[1])), max(max_y+(max_y*0.5), max(swp2_lines[1])+(max(swp2_lines[1])*0.5)))
    # ax3.set_ylim(1, max(max_y, max(swp2_lines[1])))
    ax3.set_ylim(1, max(max_y+(max_y*0.5), max(swp2_lines[1])+(max(swp2_lines[1])*0.5)))
    # ax3.set_yticklabels([f'{k/(4*mu):.0e}' for k in y_ticks], fontsize = fnt_size*0.5)
    # plt.text(-0.13, -0.135, 'Coal. time\nGen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
    # plt.text(-0.13, -0.135, 'Gen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
    plt.text(-0.13, -0.085, 'Gen. time\nYears', ha='left', va='bottom', transform=ax3.transAxes)
    plt.subplots_adjust(bottom=0.2)  # Adjust the value as needed
    if ax is None:
        # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
        plt.savefig(title+'_plotC_'+str(nb_epochs)+'_epochs_log.pdf')
        # close fig3 to save memory
        plt.close(fig3)
    return ax


def plot_raw_stairs(plot_lines, prop, title, ax = None, n_ticks = 10, rescale = False,
                    subset = None, max_breaks = None):
    if max_breaks:
        nb_breaks = max_breaks
    else:
        nb_breaks = len(plot_lines)+1
    # multiple fig
    if ax is None:
        # intialize figure 1
        my_dpi = 500
        fnt_size = 18
        # plt.rcParams['font.size'] = fnt_size
        fig, ax1 = plt.subplots(figsize=(5000/my_dpi, 2800/my_dpi), dpi=my_dpi)
        plt.subplots_adjust(bottom=0.2)  # Adjust the value as needed
    else:
        fnt_size = 12
        # plt.rcParams['font.size'] = fnt_size
        ax1 = ax[0, 0]
        plt.subplots_adjust(wspace=0.3, hspace=0.3)
    plots = []
    for breaks, plot in enumerate(plot_lines):
        if max_breaks and breaks > max_breaks:
            # stop plotting if it exceeds the limit
            continue
        x,y = plot
        x_plot, y_plot = plot_straight_x_y(x,y)
        p, = ax1.plot(x_plot, y_plot, 'o', linestyle="-", alpha=0.75, lw=2, label = str(breaks)+' brks')
        print("breaks=", breaks, "theta0", y[0])
        # add plot to the list of all plots to superimpose
        plots.append(p)
        x_ticks = x
    # print(x_ticks)
    #print(prop, "\n", sum(prop))
    #ax.legend(handles=[p0]+plots)
    ax1.set_xlabel("# bin & cumul. prop. of sites", fontsize=fnt_size)
    # Set the x-axis locator to reduce the number of ticks to 10
    ax1.set_ylabel(r'$\theta_k$', fontsize=fnt_size, rotation = 90)
    ax1.set_title(title, fontsize=fnt_size)
    ax1.legend(handles=plots, loc='best', fontsize = fnt_size*0.5)
    ax1.set_xticks(x_ticks)
    step = len(x_ticks)//(n_ticks-1)
    values = x_ticks[::step]
    new_prop = []
    for val in values:
        new_prop.append(prop[int(val)-2])
    new_prop = new_prop[::-1]
    ax1.set_xticks(values)
    ax1.set_xticklabels([f'{values[k]}\n{val:.2f}' for k, val in enumerate(new_prop)], fontsize = fnt_size*0.8)
    if ax is None:
        # nb of plot_lines represent the number of epochs stored (len(plot_lines) = #breaks+1)
        plt.savefig(title+'_raw_'+str(nb_breaks)+'_breaks.pdf')
        plt.close(fig)
    # return plots
    return ax


def combined_plot(folder_path, mu, tgen, breaks, title = "Title", theta_scale = False, selected_breaks = []):
    my_dpi = 300
    saved_plots_dict = save_all_epochs_thetafolder(folder_path, mu, tgen, title, theta_scale,
                                                   output = title+"_plotdata.json")
    nb_of_epochs = len(saved_plots_dict["all_epochs"]["plots"])
    best_epoch = saved_plots_dict["best_epoch_by_AIC"]
    print("Best epoch based on AIC =", best_epoch)
    save_k_theta(folder_path, mu, tgen, title, theta_scale, breaks_max = nb_of_epochs,
                 input = title+"_plotdata.json", output = title+"_plotdata.json")
    with open(title+"_plotdata.json", 'r') as json_file:
        loaded_data = json.load(json_file)
    # START OF COMBINED PLOT CODE
    # # plot page 1 of summary
    # fig1, ax1 = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
    # # fig1.tight_layout()
    # # Adjust absolute space between the top and bottom rows
    # fig1.subplots_adjust(hspace=0.35)  # Adjust this value based on your requirement
    # # plot page 2 of summary
    # fig2, ax2 = plt.subplots(2, 2, figsize=(5000/my_dpi, 2970/my_dpi), dpi=my_dpi)
    # # fig2.tight_layout()
    # ax1 = plot_raw_stairs(plot_lines = loaded_data['raw_stairs'],
    #                       prop = loaded_data['prop'], title = title, ax = ax1)
    # ax1 = plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'],
    #                         prop = loaded_data['prop'], title = title, ax = ax1,
    #                         subset=[loaded_data['best_epoch_by_AIC']]+selected_breaks)
    # ax2 = plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'],
    #                         prop = loaded_data['prop'], title = title, ax = ax2)
    # ax1, ax2 = plot_all_epochs_thetafolder(loaded_data, mu, tgen, title, theta_scale, ax = [ax1, ax2])
    # fig1.savefig(title+'_combined_p1.pdf')
    # print("Wrote", title+'_combined_p1.pdf')
    # fig2.savefig(title+'_combined_p2.pdf')
    # print("Wrote", title+'_combined_p2.pdf')
    # END OF COMBINED PLOT CODE
    # Start of Parsing real swp2 output
    folder_splitted = folder_path.split("/")
    swp2_summary = "/".join(folder_splitted[:-2])+'/'+folder_splitted[-3]+".final.summary"
    swp2_vals = parse_stairwayplot_output_summary(stwplt_out = swp2_summary)
    swp2_x, swp2_y = swp2_vals[0], swp2_vals[1]
    remove_back_and_forth_points(swp2_x, swp2_y)
    # End of Parsing real swp2 output
    plot_raw_stairs(plot_lines = loaded_data['raw_stairs'],
                    prop = loaded_data['prop'], title = title, ax = None, max_breaks = breaks)
    plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'], mu = mu, tgen = tgen,
                      subset=[loaded_data['best_epoch_by_AIC']]+selected_breaks,
                      # plot_scaled_theta(plot_lines = loaded_data['scaled_stairs'], subset=list(range(0,3))+[loaded_data['best_epoch_by_AIC']]+selected_breaks,
                      prop = loaded_data['prop'], title = title, swp2_lines = [swp2_x, swp2_y], ax = None)
    plot_all_epochs_thetafolder(loaded_data, mu, tgen, title, theta_scale, ax = None)
    # plt.close(fig1)
    # plt.close(fig2)


def remove_back_and_forth_points(x_values, y_values):
    # to deal with some weirdness of plotting that occur sometimes with the swp2 output
    # sometimes the line is going back and forth as x_k > x_(k+1), which is normally not possible
    i = 0
    while i < len(x_values) - 1:
        if x_values[i] >= x_values[i+1]:
            del x_values[i]
            del y_values[i]
        else:
            i += 1


def parse_stairwayplot_output_summary(stwplt_out, xlim = None, ylim = None, title = "default title", plot = False):
    # col 5
    year = []
    # col 6
    ne_median = []
    ne_2_5 = []
    ne_97_5 = []
    ne_12_5 = []
    # col 10
    ne_87_5 = []
    with open(stwplt_out, "r") as stwplt_stream:
        for line in stwplt_stream:
            ## Line format
            # mutation_per_site n_estimation theta_per_site_median theta_per_site_2.5% theta_per_site_97.5% year Ne_median Ne_2.5% Ne_97.5% Ne_12.5% Ne_87.5%
            if not line.startswith("mutation_per_site"):
                # not header
                values = line.strip().split()
                year.append(float(values[5]))
                ne_median.append(float(values[6]))
                ne_2_5.append(float(values[7]))
                ne_97_5.append(float(values[8]))
                ne_12_5.append(float(values[9]))
                ne_87_5.append(float(values[10]))
    vals = [year, ne_median, ne_2_5, ne_97_5, ne_12_5, ne_87_5]
    if plot:
        # plot parsed data
        label = ["Ne median", "Ne 2.5%", "Ne 97.5%", "Ne 12.5%", "Ne 87.5%"]
        for i in range(1, 5):
            # vals[i+1] holds the CI curve matching label[i] (vals[0] is year, vals[1] the median)
            fig, = plt.plot(year, vals[i+1], '--', alpha = 0.4)
            fig.set_label(label[i])
        # # last plot is median
        fig, = plt.plot(year, ne_median, 'r-', lw=2)
        fig.set_label(label[0])
        plt.legend()
        plt.ylabel("Individuals (Ne)")
        plt.xlabel("Time (years)")
        if xlim:
            plt.xlim(xlim)
        if ylim:
            plt.ylim(ylim)
        plt.title(title)
        plt.show()
        plt.close()
    return vals


if __name__ == "__main__":
    if len(sys.argv) != 4:
        print("Need 3 args: ThetaFolder MutationRate GenerationTime")
        exit(0)
    folder_path = sys.argv[1]
    mu = float(sys.argv[2])
    tgen = float(sys.argv[3])
    # NOTE: plot_all_epochs_thetafolder() expects the dict produced by
    # save_all_epochs_thetafolder(), not a folder path (probable leftover
    # from an older signature).
    plot_all_epochs_thetafolder(folder_path, mu, tgen)
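
# A minimal sketch (hypothetical folder and values) of driving the whole
# pipeline through combined_plot(), which chains save_all_epochs_thetafolder(),
# save_k_theta() and the plotting helpers above. folder_path is concatenated
# directly with file names and split on "/" to locate the stairway-plot-2
# "*.final.summary" file, so it should end with a "/":
#
#   combined_plot("runs/mypop/theta/", mu=1e-8, tgen=25,
#                 breaks=5, title="mypop", selected_breaks=[1, 2])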