Browse Source

correction for sfs transform plot

tforest 1 year ago
parent
commit
a92dba2d25
2 changed files with 15 additions and 11 deletions
  1. 3 0
      customgraphics.py
  2. 12 11
      sfs_tools.py

+ 3 - 0
customgraphics.py View File

@@ -202,10 +202,13 @@ def scatter(x, y, ylab=None, xlab=None, title=None):
202 202
 
203 203
 def barplot(x=None, y=None, ylab=None, xlab=None, title=None):
204 204
     if x:
205
+        x = list(x)
206
+        plt.xticks(x)
205 207
         plt.bar(x, y)
206 208
     else:
207 209
         x = list(range(len(y)))
208 210
         plt.bar(x, y)
211
+        plt.xticks(x)
209 212
     if ylab:
210 213
         plt.ylabel(ylab)
211 214
     if xlab:

+ 12 - 11
sfs_tools.py View File

@@ -20,6 +20,7 @@ Rectify SFS comp in parsed funct.
20 20
 import gzip
21 21
 import sys
22 22
 import matplotlib.pyplot as plt
23
+from frst import customgraphics
23 24
 
24 25
 def sfs_from_vcf(n, vcf_file, folded = True, diploid = True, phased = False, verbose = False,
25 26
                  strip = False, count_ext = False):
@@ -194,8 +195,9 @@ def sfs_from_parsed_vcf(n, vcf_dict, folded = True, diploid = True, phased = Fal
194 195
 def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed = False):
195 196
     sfs_val = []
196 197
     n = len(sfs.values())
197
-    for k in range(1, n):
198
-        ksi = list(sfs.values())[k-1]
198
+    print("n =", n)
199
+    for k, ksi in sfs.items():
200
+        #ksi = list(sfs.values())[k-1]
199 201
         # k+1 because k starts from 0
200 202
         # if folded:
201 203
         #     # ?check if 2*n or not?
@@ -207,7 +209,8 @@ def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed =
207 209
         #         sfs_val.append(ksi)
208 210
         if transformed:
209 211
             if folded:
210
-                sfs_val.append(ksi * k * (2*n - k))
212
+                #sfs_val.append(ksi * k * (2*n - k))
213
+                sfs_val.append(((k*(2*n - k)) / (2*n))*ksi) 
211 214
             else:
212 215
                 sfs_val.append(ksi * k)
213 216
         else:
@@ -215,17 +218,15 @@ def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed =
215 218
             
216 219
     #terminal case, same for folded or unfolded
217 220
     if transformed:
218
-        sfs_val.append(list(sfs.values())[n-1] * n)
221
+        sfs_val[-1] = list(sfs.values())[n-1] * n
219 222
     else:
220
-         sfs_val.append(list(sfs.values())[n-1])
223
+        sfs_val[-1] = list(sfs.values())[n-1]
221 224
     #build the plot
222 225
     title = title+" [folded="+str(folded)+"]"+" [transformed="+str(transformed)+"]"
223
-    if ylab:
224
-        plt.ylabel(ylab)
225
-    if xlab:
226
-        plt.xlabel(xlab)
227
-    plt.title(title)
228
-    plt.bar([i+1 for i in sfs.keys()], sfs_val)
226
+    print("SFS =", sfs)
227
+    if transformed:
228
+        print("Transformed SFS ( n =",len(sfs_val), ") :", sfs_val)
229
+    customgraphics.barplot(x = sfs.keys(), y= sfs_val, xlab = xlab, ylab = ylab, title = title)
229 230
     plt.show()
230 231
 
231 232
 if __name__ == "__main__":