sfs_tools.py 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
#!/usr/bin/env python3
"""
FOREST Thomas (thomas.forest@college-de-france.fr)

Caution: at the moment, VCF input is supported for gzipped files only.

ARGS
--------
standalone usage : vcf_to_sfs.py VCF.gz nb_indiv

TODO
--------
- Externalize the SFS transforms into a function.
- Rectify the SFS computation in the parsed-VCF function.
"""
  13. import gzip
  14. import sys
  15. import matplotlib.pyplot as plt
  16. from frst import customgraphics
  17. import numpy as np
  18. def parse_sfs(sfs_file):
  19. """
  20. Parse a Site Frequency Spectrum (SFS) file and return a masked spectrum.
  21. This function reads an SFS file, extracts the spectrum data, and applies a mask to it.
  22. The mask excludes specific bins from the spectrum, resulting in a masked SFS.
  23. Parameters:
  24. - sfs_file (str): The path to the SFS file to be parsed, in dadi's .fs format.
  25. Returns:
  26. - masked_spectrum (list): A masked SFS as a list of integers.
  27. Raises:
  28. - FileNotFoundError: If the specified SFS file is not found.
  29. - ValueError: If there are inconsistencies in the file format or data.
  30. Note: The actual structure of the SFS file is based on dadi's fs format.
  31. """
  32. try:
  33. with open(sfs_file, 'r') as file:
  34. # Read the first line which contains information about the file
  35. num_individuals, mode, species_name = file.readline().strip().split()
  36. num_individuals = int(num_individuals)
  37. # Read the spectrum data
  38. spectrum_data = list(map(int, file.readline().strip().split()))
  39. # Check if the number of bins in the spectrum matches the expected number
  40. if len(spectrum_data) != num_individuals:
  41. raise ValueError("Error: Number of bins in the spectrum doesn't match the expected number of individuals.")
  42. # Read the mask data
  43. mask_data = list(map(int, file.readline().strip().split()))
  44. # Check if the size of the mask matches the number of bins in the spectrum
  45. if len(mask_data) != num_individuals:
  46. raise ValueError("Error: Size of the mask doesn't match the number of bins in the spectrum.")
  47. # Apply the mask to the spectrum
  48. masked_spectrum = [spectrum_data[i] for i in range(num_individuals) if not mask_data[i]]
  49. # Error handling
  50. except FileNotFoundError:
  51. print(f"Error: File not found - {sfs_file}")
  52. except ValueError as ve:
  53. print(f"Error: {ve}")
  54. except Exception as e:
  55. print(f"Error: {e}")
  56. # final return of SFS as a list
  57. return masked_spectrum
  58. def sfs_from_vcf(n, vcf_file, folded = True, diploid = True, phased = False, verbose = False,
  59. strip = False, count_ext = False):
  60. """
  61. Generates a Site Frequency Spectrum from a gzipped VCF file format.
  62. Parameters
  63. ----------
  64. n : int
  65. Nb of individuals in sample.
  66. vcf_file : str
  67. SNPs in VCF file format.
  68. Used to generate a Site Frequency Spectrum (SFS) from a VCF.
  69. Returns
  70. -------
  71. dict
  72. Site Frequency Spectrum (SFS)
  73. """
  74. if diploid and not folded:
  75. n *= 2
  76. # initiate SFS_values with a zeros dict
  77. # if strip:
  78. # # "[1" removes the 0 bin
  79. # # "n-1]" crop the last bin (n or n/2 for folded)
  80. # SFS_dim = [1, n-1]
  81. # else:
  82. SFS_dim = [0, n+1]
  83. SFS_values = dict.fromkeys(range(SFS_dim[1]),0)
  84. count_pluriall = 0
  85. with gzip.open(vcf_file, "rb") as inputgz:
  86. line = inputgz.readline()
  87. genotypes = []
  88. print("Parsing VCF", vcf_file, "... Please wait...")
  89. while line:
  90. # decode gzipped binary lines
  91. line = line.decode('utf-8').strip()
  92. # every snp line, not comment or header
  93. if not line.startswith("##") and not line.startswith("#"):
  94. FIELDS = line.split("\t")
  95. # REF is col 4 of VCF
  96. REF = FIELDS[3].split(",")
  97. # ALT is col 5 of VCF
  98. ALT = FIELDS[4].split(",")
  99. FORMAT = line.split("\t")[8:9]
  100. SAMPLES = line.split("\t")[9:]
  101. snp_genotypes = []
  102. allele_counts = {}
  103. allele_counts_list = []
  104. # SKIP the SNP if :
  105. # 1 : missing
  106. # 2 : deletion among REF
  107. # 3 : deletion among ALT
  108. if "./.:." in line \
  109. or len(ALT[0]) > 1 \
  110. or len(REF[0]) > 1:
  111. line = inputgz.readline()
  112. continue
  113. for sample in SAMPLES:
  114. if not phased:
  115. # for UNPHASED data
  116. smpl_genotype = [int(a) for a in sample.split(':')[0].split('/') if a != '.']
  117. else:
  118. # for PHASED
  119. smpl_genotype = [int(a) for a in sample.split(':')[0].split('|') if a != '.']
  120. nb_alleles = set(smpl_genotype)
  121. snp_genotypes += smpl_genotype
  122. # skip if all individuals have the same genotype
  123. # if len(set(snp_genotypes)) == 1:
  124. # if folded or (folded == False and snp_genotypes.count(1) == 0) :
  125. # line = inputgz.readline()
  126. # continue
  127. for k in set(snp_genotypes):
  128. allele_counts[snp_genotypes.count(k)] = k
  129. allele_counts_list.append(snp_genotypes.count(k))
  130. #print(allele_counts_list)
  131. if len(set(snp_genotypes)) == 1 or allele_counts_list[0] == allele_counts_list[1]:
  132. # If only heterozygous sites 0/1; skip the site (equivalent to n bin or n/2 bin for folded)
  133. # skip if all individuals have the same genotype
  134. line = inputgz.readline()
  135. continue
  136. if len(ALT) >= 2:
  137. #pass
  138. count_pluriall +=1
  139. # TODO - work in progress
  140. # for al in range(len(ALT)-1):
  141. # SFS_values[min(allele_counts_list)-1] += 1/len(ALT)
  142. # allele_counts_list.remove(min(allele_counts_list))
  143. else:
  144. if folded:
  145. SFS_values[min(allele_counts_list)-SFS_dim[0]] += 1
  146. else :
  147. # if unfolded, count the Ones (ALT allele)
  148. #print(snp_genotypes, snp_genotypes.count(1))
  149. SFS_values[snp_genotypes.count(1)-SFS_dim[0]] += 1
  150. # all the parsing is done, change line
  151. line = inputgz.readline()
  152. if verbose:
  153. print("SFS=", SFS_values)
  154. if strip:
  155. del SFS_values[0]
  156. del SFS_values[n]
  157. print("Pluriallelic sites =", count_pluriall)
  158. return SFS_values, count_pluriall
  159. def sfs_from_parsed_vcf(n, vcf_dict, folded = True, diploid = True, phased = False, verbose = False):
  160. """
  161. Generates a Site Frequency Spectrum from a gzipped VCF file format.
  162. Parameters
  163. ----------
  164. n : int
  165. Nb of individuals in sample.
  166. vcf_file : str
  167. SNPs in VCF file format.
  168. Used to generate a Site Frequency Spectrum (SFS) from a VCF.
  169. Returns
  170. -------
  171. dict
  172. Site Frequency Spectrum (SFS)
  173. """
  174. if diploid and not folded:
  175. n *= 2
  176. # initiate SFS_values with a zeros dict
  177. SFS_values = dict.fromkeys(range(n),0)
  178. count_pluriall = 0
  179. for CHROM in vcf_dict:
  180. for SNP in vcf_dict[CHROM]:
  181. snp_genotypes = []
  182. allele_counts = {}
  183. allele_counts_list = []
  184. print(CHROM, SNP)
  185. for sample in vcf_dict[CHROM][SNP]["SAMPLES"]:
  186. if not phased:
  187. # for UNPHASED data
  188. smpl_genotype = [int(a) for a in sample.split(':')[0].split('/') if a != '.']
  189. else:
  190. # for PHASED
  191. smpl_genotype = [int(a) for a in sample.split(':')[0].split('|') if a != '.']
  192. nb_alleles = set(smpl_genotype)
  193. snp_genotypes += smpl_genotype
  194. # skip if all individuals have the same genotype
  195. if len(set(snp_genotypes)) == 1:
  196. continue
  197. for k in set(snp_genotypes):
  198. allele_counts[snp_genotypes.count(k)] = k
  199. allele_counts_list.append(snp_genotypes.count(k))
  200. SFS_values[min(allele_counts_list)-1] += 1
  201. # sum pluriall counts for this CHR to the rest
  202. count_pluriall += vcf_dict[CHROM]['NB_PLURIALL']
  203. if verbose:
  204. print("SFS=", SFS_values)
  205. print("Pluriallelic sites =", count_pluriall)
  206. return SFS_values, count_pluriall
  207. def barplot_sfs(sfs, xlab, ylab, folded=True, title = "Barplot", transformed = False, normalized = False, ploidy = 2, output = None):
  208. sfs_val = []
  209. n = len(sfs.values())
  210. sum_sites = sum(list(sfs.values()))
  211. for k, ksi in sfs.items():
  212. #ksi = list(sfs.values())[k-1]
  213. # k+1 because k starts from 0
  214. # if folded:
  215. # # ?check if 2*n or not?
  216. # sfs_val.append(ksi * k * (2*n - k))
  217. # else:
  218. # if transformed:
  219. # sfs_val.append(ksi * k)
  220. # else:
  221. # sfs_val.append(ksi)
  222. if transformed:
  223. ylab = r'$ \phi_i $'
  224. if folded:
  225. val = ((k*(2*n - k)) / (2*n))*(ksi)
  226. else:
  227. val = ksi * k
  228. else:
  229. val = ksi
  230. sfs_val.append(val)
  231. if not transformed and not normalized:
  232. ylab = r'$ \eta_i $'
  233. #terminal case, same for folded or unfolded
  234. if transformed:
  235. last_bin = list(sfs.values())[n-1] * n/ploidy
  236. else:
  237. last_bin = list(sfs.values())[n-1]
  238. sfs_val[-1] = last_bin
  239. if normalized:
  240. #ylab = "Fraction of SNPs "
  241. ylab = r'$ \phi_i $'
  242. sum_val = sum(sfs_val)
  243. for k, sfs_bin in enumerate(sfs_val):
  244. sfs_val[k] = sfs_bin / sum_val
  245. #print(sum(sfs_val))
  246. #build the plot
  247. if folded:
  248. xlab = "Minor allele frequency"
  249. n_title = n
  250. else:
  251. # the spectrum is n-1 long when unfolded
  252. n_title = n+1
  253. original_title = title
  254. # reformat title and add infos
  255. title = title+" (n="+str(n_title)+") [folded="+str(folded)+"]"+" [transformed="+str(transformed)+"]"
  256. print("SFS =", sfs)
  257. X_axis = list(sfs.keys())
  258. if transformed:
  259. print("Transformed SFS ( n =",n_title, ") :", sfs_val)
  260. #plt.axhline(y=1/n, color='r', linestyle='-')
  261. plt.bar([x+0.2 for x in list(sfs.keys())], [1/n]*n, fill=False, hatch="///", linestyle='-', width = 0.4, label= "H0 Theoric constant")
  262. else:
  263. if normalized:
  264. # then plot a theoritical distribution as 1/i
  265. sum_expected = sum([(1/(i+1)) for i,x in enumerate(list(sfs.keys()))])
  266. expected_y = [(1/(i+1))/sum_expected for i,x in enumerate(list(sfs.keys()))]
  267. print(expected_y)
  268. plt.bar([x+0.2 for x in list(sfs.keys())], expected_y, fill=False, hatch="///", linestyle='-', width = 0.4, label= "H0 Theoric constant")
  269. print(sum(expected_y))
  270. if output is not None:
  271. # if write in a file, don't open the window dynamically
  272. plot = False
  273. else:
  274. plot = True
  275. customgraphics.barplot(x = [x-0.2 for x in X_axis], width=0.4, y= sfs_val, xlab = xlab, ylab = ylab, title = title, label = "H1 Observed spectrum", xticks =list(sfs.keys()), plot = plot )
  276. if output:
  277. plt.savefig(f"{output}/{original_title}_SFS.pdf")
  278. else:
  279. plt.show()
  280. plt.close()
  281. if __name__ == "__main__":
  282. if len(sys.argv) != 3:
  283. print("Need 2 args")
  284. exit(0)
  285. # PARAM : Nb of indiv
  286. n = int(sys.argv[2])
  287. sfs = sfs_from_vcf(n, sys.argv[1], folded = True, diploid = True, phased = False, strip = True)
  288. print(sfs)