Browse Source

Improving sfs plotting

tforest 9 months ago
parent
commit
25d7ef0858
2 changed files with 68 additions and 9 deletions
  1. 4 3
      customgraphics.py
  2. 64 6
      sfs_tools.py

+ 4 - 3
customgraphics.py View File

@@ -200,11 +200,11 @@ def scatter(x, y, ylab=None, xlab=None, title=None):
200 200
         plt.title(title)
201 201
     plt.show()
202 202
 
203
-def barplot(x=None, y=None, ylab=None, xlab=None, title=None, label=None, xticks = None, width=1):
203
+def barplot(x=None, y=None, ylab=None, xlab=None, title=None, label=None, xticks = None, width=1, plot = True):
204 204
     if x:
205 205
         x = list(x)
206 206
         plt.xticks(x)
207
-        plt.bar(x, y, width=width, label=label)
207
+        plt.bar(x, y, width=width, label=label, color="tab:blue")
208 208
     else:
209 209
         x = list(range(len(y)))
210 210
         plt.bar(x, y, width=width, label=label)
@@ -218,7 +218,8 @@ def barplot(x=None, y=None, ylab=None, xlab=None, title=None, label=None, xticks
218 218
     if xticks:
219 219
         plt.xticks(xticks)
220 220
     plt.legend()
221
-    plt.show()
221
+    if plot:
222
+        plt.show()
222 223
 
223 224
 def plot_chrom_continuity(vcf_entries, chr_id, x=None, y=None, outfile = None,
224 225
                           outfolder = None, returned=False, show=True, label=True, step=1, nb_subplots = None,

+ 64 - 6
sfs_tools.py View File

@@ -23,6 +23,54 @@ import matplotlib.pyplot as plt
23 23
 from frst import customgraphics
24 24
 import numpy as np
25 25
 
26
+def parse_sfs(sfs_file):
27
+    """
28
+    Parse a Site Frequency Spectrum (SFS) file and return a masked spectrum.
29
+
30
+    This function reads an SFS file, extracts the spectrum data, and applies a mask to it.
31
+    The mask excludes specific bins from the spectrum, resulting in a masked SFS.
32
+
33
+    Parameters:
34
+    - sfs_file (str): The path to the SFS file to be parsed, in dadi's .fs format.
35
+
36
+    Returns:
37
+    - masked_spectrum (list): A masked SFS as a list of integers.
38
+
39
+    Raises:
40
+    - FileNotFoundError: If the specified SFS file is not found.
41
+    - ValueError: If there are inconsistencies in the file format or data.
42
+
43
+    Note: The actual structure of the SFS file is based on dadi's fs format.
44
+    """
45
+    try:
46
+        with open(sfs_file, 'r') as file:
47
+            # Read the first line which contains information about the file
48
+            num_individuals, mode, species_name = file.readline().strip().split()
49
+            num_individuals = int(num_individuals)
50
+            # Read the spectrum data
51
+            spectrum_data = list(map(int, file.readline().strip().split()))
52
+            # Check if the number of bins in the spectrum matches the expected number
53
+            if len(spectrum_data) != num_individuals:
54
+                raise ValueError("Error: Number of bins in the spectrum doesn't match the expected number of individuals.")
55
+            # Read the mask data
56
+            mask_data = list(map(int, file.readline().strip().split()))
57
+
58
+            # Check if the size of the mask matches the number of bins in the spectrum
59
+            if len(mask_data) != num_individuals:
60
+                raise ValueError("Error: Size of the mask doesn't match the number of bins in the spectrum.")
61
+            # Apply the mask to the spectrum
62
+            masked_spectrum = [spectrum_data[i] for i in range(num_individuals) if not mask_data[i]]
63
+    # Error handling
64
+    except FileNotFoundError:
65
+        print(f"Error: File not found - {sfs_file}")
66
+    except ValueError as ve:
67
+        print(f"Error: {ve}")
68
+    except Exception as e:
69
+        print(f"Error: {e}")
70
+    # final return of SFS as a list
71
+    return masked_spectrum
72
+
73
+
26 74
 def sfs_from_vcf(n, vcf_file, folded = True, diploid = True, phased = False, verbose = False,
27 75
                  strip = False, count_ext = False):
28 76
 
@@ -193,7 +241,7 @@ def sfs_from_parsed_vcf(n, vcf_dict, folded = True, diploid = True, phased = Fal
193 241
     return SFS_values, count_pluriall
194 242
 
195 243
 
196
-def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed = False, normalized = False, ploidy = 2):
244
+def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed = False, normalized = False, ploidy = 2, output = None):
197 245
     sfs_val = []
198 246
     n = len(sfs.values())
199 247
     sum_sites = sum(list(sfs.values()))
@@ -242,7 +290,8 @@ def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed =
242 290
     else:
243 291
         # the spectrum is n-1 long when unfolded
244 292
         n_title = n+1
245
-    
293
+    original_title = title
294
+    # reformat title and add infos
246 295
     title = title+" (n="+str(n_title)+") [folded="+str(folded)+"]"+" [transformed="+str(transformed)+"]"
247 296
     print("SFS =", sfs)
248 297
 
@@ -252,7 +301,7 @@ def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed =
252 301
     if transformed:
253 302
         print("Transformed SFS ( n =",n_title, ") :", sfs_val)
254 303
         #plt.axhline(y=1/n, color='r', linestyle='-')
255
-        plt.bar([x+0.2 for x in list(sfs.keys())], [1/n]*n, color='r', linestyle='-', width = 0.4, label= "H0 Theoric constant")
304
+        plt.bar([x+0.2 for x in list(sfs.keys())], [1/n]*n, fill=False, hatch="///", linestyle='-', width = 0.4, label= "H0 Theoric constant")
256 305
 
257 306
     else:
258 307
         if normalized:
@@ -260,10 +309,19 @@ def barplot_sfs(sfs,  xlab, ylab, folded=True, title = "Barplot", transformed =
260 309
             sum_expected = sum([(1/(i+1)) for i,x in enumerate(list(sfs.keys()))])
261 310
             expected_y = [(1/(i+1))/sum_expected for i,x in enumerate(list(sfs.keys()))]
262 311
             print(expected_y)
263
-            plt.bar([x+0.2 for x in list(sfs.keys())], expected_y, color='r', linestyle='-', width = 0.4, label= "H0 Theoric constant")
312
+            plt.bar([x+0.2 for x in list(sfs.keys())], expected_y, fill=False, hatch="///", linestyle='-', width = 0.4, label= "H0 Theoric constant")
264 313
             print(sum(expected_y))
265
-    customgraphics.barplot(x = [x-0.2 for x in X_axis], width=0.4, y= sfs_val, xlab = xlab, ylab = ylab, title = title, label = "H1 Observed spectrum", xticks =list(sfs.keys()) )
266
-    plt.show()
314
+    if output is not None:
315
+        # if write in a file, don't open the window dynamically
316
+        plot = False
317
+    else:
318
+        plot = True
319
+    customgraphics.barplot(x = [x-0.2 for x in X_axis], width=0.4, y= sfs_val, xlab = xlab, ylab = ylab, title = title, label = "H1 Observed spectrum", xticks =list(sfs.keys()), plot = plot )
320
+    if output:
321
+       plt.savefig(f"{output}/{original_title}_SFS.pdf")
322
+    else:
323
+        plt.show()
324
+    plt.close()
267 325
 
268 326
 if __name__ == "__main__":
269 327