Browse Source

Update SFS plotting function

tforest 9 months ago
parent
commit
44449033db
2 changed files with 28 additions and 13 deletions
  1. 6 3
      customgraphics.py
  2. 22 10
      sfs_tools.py

+ 6 - 3
customgraphics.py View File

200
         plt.title(title)
200
         plt.title(title)
201
     plt.show()
201
     plt.show()
202
 
202
 
203
-def barplot(x=None, y=None, ylab=None, xlab=None, title=None):
203
+def barplot(x=None, y=None, ylab=None, xlab=None, title=None, label=None, xticks = None, width=1):
204
     if x:
204
     if x:
205
         x = list(x)
205
         x = list(x)
206
         plt.xticks(x)
206
         plt.xticks(x)
207
-        plt.bar(x, y)
207
+        plt.bar(x, y, width=width, label=label)
208
     else:
208
     else:
209
         x = list(range(len(y)))
209
         x = list(range(len(y)))
210
-        plt.bar(x, y)
210
+        plt.bar(x, y, width=width, label=label)
211
         plt.xticks(x)
211
         plt.xticks(x)
212
     if ylab:
212
     if ylab:
213
         plt.ylabel(ylab)
213
         plt.ylabel(ylab)
215
         plt.xlabel(xlab)
215
         plt.xlabel(xlab)
216
     if title:
216
     if title:
217
         plt.title(title)
217
         plt.title(title)
218
+    if xticks:
219
+        plt.xticks(xticks)
220
+    plt.legend()
218
     plt.show()
221
     plt.show()
219
 
222
 
220
 def plot_chrom_continuity(vcf_entries, chr_id, x=None, y=None, outfile = None,
223
 def plot_chrom_continuity(vcf_entries, chr_id, x=None, y=None, outfile = None,

+ 22 - 10
sfs_tools.py View File

21
 import sys
21
 import sys
22
 import matplotlib.pyplot as plt
22
 import matplotlib.pyplot as plt
23
 from frst import customgraphics
23
 from frst import customgraphics
24
+import numpy as np
24
 
25
 
25
 def sfs_from_vcf(n, vcf_file, folded = True, diploid = True, phased = False, verbose = False,
26
 def sfs_from_vcf(n, vcf_file, folded = True, diploid = True, phased = False, verbose = False,
26
                  strip = False, count_ext = False):
27
                  strip = False, count_ext = False):
192
     return SFS_values, count_pluriall
193
     return SFS_values, count_pluriall
193
 
194
 
194
 
195
 
195
-def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed = False, normalized = False):
196
+def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed = False, normalized = False, ploidy = 2):
196
     sfs_val = []
197
     sfs_val = []
197
     n = len(sfs.values())
198
     n = len(sfs.values())
198
     sum_sites = sum(list(sfs.values()))
199
     sum_sites = sum(list(sfs.values()))
222
             
223
             
223
     #terminal case, same for folded or unfolded
224
     #terminal case, same for folded or unfolded
224
     if transformed:
225
     if transformed:
225
-        last_bin = list(sfs.values())[n-1] * n/2
226
+        last_bin = list(sfs.values())[n-1] * n/ploidy
226
     else:
227
     else:
227
         last_bin = list(sfs.values())[n-1]
228
         last_bin = list(sfs.values())[n-1]
228
     sfs_val[-1] = last_bin
229
     sfs_val[-1] = last_bin
235
         
236
         
236
         #print(sum(sfs_val))
237
         #print(sum(sfs_val))
237
     #build the plot
238
     #build the plot
238
-    title = title+" (n="+str(len(sfs_val))+") [folded="+str(folded)+"]"+" [transformed="+str(transformed)+"]"
239
-    print("SFS =", sfs)
240
     if folded:
239
     if folded:
241
         xlab = "Minor allele frequency"
240
         xlab = "Minor allele frequency"
241
+        n_title = n
242
+    else:
243
+        # an unfolded spectrum has (sample size - 1) bins, so the sample size is n+1
244
+        n_title = n+1
245
+    
246
+    title = title+" (n="+str(n_title)+") [folded="+str(folded)+"]"+" [transformed="+str(transformed)+"]"
247
+    print("SFS =", sfs)
248
+
249
+    X_axis = list(sfs.keys()) 
250
+    
251
+
242
     if transformed:
252
     if transformed:
243
-        print("Transformed SFS ( n =",len(sfs_val), ") :", sfs_val)
253
+        print("Transformed SFS ( n =",n_title, ") :", sfs_val)
244
         #plt.axhline(y=1/n, color='r', linestyle='-')
254
         #plt.axhline(y=1/n, color='r', linestyle='-')
255
+        plt.bar([x+0.2 for x in list(sfs.keys())], [1/n]*n, color='r', linestyle='-', width = 0.4, label= "H0 Theoric constant")
256
+
245
     else:
257
     else:
246
         if normalized:
258
         if normalized:
247
             # then plot a theoretical distribution as 1/i
259
             # then plot a theoretical distribution as 1/i
248
-            expected_y = [1/(2*x+1) for x in list(sfs.keys())]
260
+            sum_expected = sum([(1/(i+1)) for i,x in enumerate(list(sfs.keys()))])
261
+            expected_y = [(1/(i+1))/sum_expected for i,x in enumerate(list(sfs.keys()))]
262
+            print(expected_y)
263
+            plt.bar([x+0.2 for x in list(sfs.keys())], expected_y, color='r', linestyle='-', width = 0.4, label= "H0 Theoric constant")
249
             print(sum(expected_y))
264
             print(sum(expected_y))
250
-            #plt.plot([x for x in list(sfs.keys())], expected_y, color='r', linestyle='-')
251
-            #print(expected_y)
252
-            
253
-    customgraphics.barplot(x = [x for x in list(sfs.keys())], y= sfs_val, xlab = xlab, ylab = ylab, title = title)
265
+    customgraphics.barplot(x = [x-0.2 for x in X_axis], width=0.4, y= sfs_val, xlab = xlab, ylab = ylab, title = title, label = "H1 Observed spectrum", xticks =list(sfs.keys()) )
254
     plt.show()
266
     plt.show()
255
 
267
 
256
 if __name__ == "__main__":
268
 if __name__ == "__main__":