Browse Source

correction for sfs transform plot

tforest 1 year ago
parent
commit
a92dba2d25
2 changed files with 15 additions and 11 deletions
  1. 3 0
      customgraphics.py
  2. 12 11
      sfs_tools.py

+ 3 - 0
customgraphics.py View File

202
 
202
 
203
 def barplot(x=None, y=None, ylab=None, xlab=None, title=None):
203
 def barplot(x=None, y=None, ylab=None, xlab=None, title=None):
204
     if x:
204
     if x:
205
+        x = list(x)
206
+        plt.xticks(x)
205
         plt.bar(x, y)
207
         plt.bar(x, y)
206
     else:
208
     else:
207
         x = list(range(len(y)))
209
         x = list(range(len(y)))
208
         plt.bar(x, y)
210
         plt.bar(x, y)
211
+        plt.xticks(x)
209
     if ylab:
212
     if ylab:
210
         plt.ylabel(ylab)
213
         plt.ylabel(ylab)
211
     if xlab:
214
     if xlab:

+ 12 - 11
sfs_tools.py View File

20
 import gzip
20
 import gzip
21
 import sys
21
 import sys
22
 import matplotlib.pyplot as plt
22
 import matplotlib.pyplot as plt
23
+from frst import customgraphics
23
 
24
 
24
 def sfs_from_vcf(n, vcf_file, folded = True, diploid = True, phased = False, verbose = False,
25
 def sfs_from_vcf(n, vcf_file, folded = True, diploid = True, phased = False, verbose = False,
25
                  strip = False, count_ext = False):
26
                  strip = False, count_ext = False):
194
 def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed = False):
195
 def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed = False):
195
     sfs_val = []
196
     sfs_val = []
196
     n = len(sfs.values())
197
     n = len(sfs.values())
197
-    for k in range(1, n):
198
-        ksi = list(sfs.values())[k-1]
198
+    print("n =", n)
199
+    for k, ksi in sfs.items():
200
+        #ksi = list(sfs.values())[k-1]
199
         # k+1 because k starts from 0
201
         # k+1 because k starts from 0
200
         # if folded:
202
         # if folded:
201
         #     # ?check if 2*n or not?
203
         #     # ?check if 2*n or not?
207
         #         sfs_val.append(ksi)
209
         #         sfs_val.append(ksi)
208
         if transformed:
210
         if transformed:
209
             if folded:
211
             if folded:
210
-                sfs_val.append(ksi * k * (2*n - k))
212
+                #sfs_val.append(ksi * k * (2*n - k))
213
+                sfs_val.append(((k*(2*n - k)) / (2*n))*ksi) 
211
             else:
214
             else:
212
                 sfs_val.append(ksi * k)
215
                 sfs_val.append(ksi * k)
213
         else:
216
         else:
215
             
218
             
216
     #terminal case, same for folded or unfolded
219
     #terminal case, same for folded or unfolded
217
     if transformed:
220
     if transformed:
218
-        sfs_val.append(list(sfs.values())[n-1] * n)
221
+        sfs_val[-1] = list(sfs.values())[n-1] * n
219
     else:
222
     else:
220
-         sfs_val.append(list(sfs.values())[n-1])
223
+        sfs_val[-1] = list(sfs.values())[n-1]
221
     #build the plot
224
     #build the plot
222
     title = title+" [folded="+str(folded)+"]"+" [transformed="+str(transformed)+"]"
225
     title = title+" [folded="+str(folded)+"]"+" [transformed="+str(transformed)+"]"
223
-    if ylab:
224
-        plt.ylabel(ylab)
225
-    if xlab:
226
-        plt.xlabel(xlab)
227
-    plt.title(title)
228
-    plt.bar([i+1 for i in sfs.keys()], sfs_val)
226
+    print("SFS =", sfs)
227
+    if transformed:
228
+        print("Transformed SFS ( n =",len(sfs_val), ") :", sfs_val)
229
+    customgraphics.barplot(x = sfs.keys(), y= sfs_val, xlab = xlab, ylab = ylab, title = title)
229
     plt.show()
230
     plt.show()
230
 
231
 
231
 if __name__ == "__main__":
232
 if __name__ == "__main__":