Browse Source

Update SFS plotting function

tforest 2 months ago
parent
commit
44449033db
2 changed files with 28 additions and 13 deletions
  1. 6 3
      customgraphics.py
  2. 22 10
      sfs_tools.py

+ 6 - 3
customgraphics.py View File

@@ -200,14 +200,14 @@ def scatter(x, y, ylab=None, xlab=None, title=None):
200 200
         plt.title(title)
201 201
     plt.show()
202 202
 
203
-def barplot(x=None, y=None, ylab=None, xlab=None, title=None):
203
+def barplot(x=None, y=None, ylab=None, xlab=None, title=None, label=None, xticks = None, width=1):
204 204
     if x:
205 205
         x = list(x)
206 206
         plt.xticks(x)
207
-        plt.bar(x, y)
207
+        plt.bar(x, y, width=width, label=label)
208 208
     else:
209 209
         x = list(range(len(y)))
210
-        plt.bar(x, y)
210
+        plt.bar(x, y, width=width, label=label)
211 211
         plt.xticks(x)
212 212
     if ylab:
213 213
         plt.ylabel(ylab)
@@ -215,6 +215,9 @@ def barplot(x=None, y=None, ylab=None, xlab=None, title=None):
215 215
         plt.xlabel(xlab)
216 216
     if title:
217 217
         plt.title(title)
218
+    if xticks:
219
+        plt.xticks(xticks)
220
+    plt.legend()
218 221
     plt.show()
219 222
 
220 223
 def plot_chrom_continuity(vcf_entries, chr_id, x=None, y=None, outfile = None,

+ 22 - 10
sfs_tools.py View File

@@ -21,6 +21,7 @@ import gzip
21 21
 import sys
22 22
 import matplotlib.pyplot as plt
23 23
 from frst import customgraphics
24
+import numpy as np
24 25
 
25 26
 def sfs_from_vcf(n, vcf_file, folded = True, diploid = True, phased = False, verbose = False,
26 27
                  strip = False, count_ext = False):
@@ -192,7 +193,7 @@ def sfs_from_parsed_vcf(n, vcf_dict, folded = True, diploid = True, phased = Fal
192 193
     return SFS_values, count_pluriall
193 194
 
194 195
 
195
-def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed = False, normalized = False):
196
+def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed = False, normalized = False, ploidy = 2):
196 197
     sfs_val = []
197 198
     n = len(sfs.values())
198 199
     sum_sites = sum(list(sfs.values()))
@@ -222,7 +223,7 @@ def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed =
222 223
             
223 224
     #terminal case, same for folded or unfolded
224 225
     if transformed:
225
-        last_bin = list(sfs.values())[n-1] * n/2
226
+        last_bin = list(sfs.values())[n-1] * n/ploidy
226 227
     else:
227 228
         last_bin = list(sfs.values())[n-1]
228 229
     sfs_val[-1] = last_bin
@@ -235,22 +236,33 @@ def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed =
235 236
         
236 237
         #print(sum(sfs_val))
237 238
     #build the plot
238
-    title = title+" (n="+str(len(sfs_val))+") [folded="+str(folded)+"]"+" [transformed="+str(transformed)+"]"
239
-    print("SFS =", sfs)
240 239
     if folded:
241 240
         xlab = "Minor allele frequency"
241
+        n_title = n
242
+    else:
243
+        # the unfolded spectrum has one bin fewer than the sample size, so add 1 back for the title
244
+        n_title = n+1
245
+    
246
+    title = title+" (n="+str(n_title)+") [folded="+str(folded)+"]"+" [transformed="+str(transformed)+"]"
247
+    print("SFS =", sfs)
248
+
249
+    X_axis = list(sfs.keys()) 
250
+    
251
+
242 252
     if transformed:
243
-        print("Transformed SFS ( n =",len(sfs_val), ") :", sfs_val)
253
+        print("Transformed SFS ( n =",n_title, ") :", sfs_val)
244 254
         #plt.axhline(y=1/n, color='r', linestyle='-')
255
+        plt.bar([x+0.2 for x in list(sfs.keys())], [1/n]*n, color='r', linestyle='-', width = 0.4, label= "H0 Theoric constant")
256
+
245 257
     else:
246 258
         if normalized:
247 259
             # then plot a theoretical distribution as 1/i
248
-            expected_y = [1/(2*x+1) for x in list(sfs.keys())]
260
+            sum_expected = sum([(1/(i+1)) for i,x in enumerate(list(sfs.keys()))])
261
+            expected_y = [(1/(i+1))/sum_expected for i,x in enumerate(list(sfs.keys()))]
262
+            print(expected_y)
263
+            plt.bar([x+0.2 for x in list(sfs.keys())], expected_y, color='r', linestyle='-', width = 0.4, label= "H0 Theoric constant")
249 264
             print(sum(expected_y))
250
-            #plt.plot([x for x in list(sfs.keys())], expected_y, color='r', linestyle='-')
251
-            #print(expected_y)
252
-            
253
-    customgraphics.barplot(x = [x for x in list(sfs.keys())], y= sfs_val, xlab = xlab, ylab = ylab, title = title)
265
+    customgraphics.barplot(x = [x-0.2 for x in X_axis], width=0.4, y= sfs_val, xlab = xlab, ylab = ylab, title = title, label = "H1 Observed spectrum", xticks =list(sfs.keys()) )
254 266
     plt.show()
255 267
 
256 268
 if __name__ == "__main__":