sfs_tools.py 8.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. #!/usr/bin/env python3
  2. """
FOREST Thomas (thomas.forest@college-de-france.fr)

Caution: at the moment, this works for gzipped VCF files only.

ARGS
--------
Standalone usage: vcf_to_sfs.py VCF.gz nb_indiv

TODO
_____
- Externalize the SFS transforms in a dedicated function.
- Rectify the SFS computation in the parsed-VCF function.
  12. """
import gzip
import sys
from collections import Counter

import matplotlib.pyplot as plt

from frst import customgraphics
  17. def sfs_from_vcf(n, vcf_file, folded = True, diploid = True, phased = False, verbose = False,
  18. strip = False, count_ext = False):
  19. """
  20. Generates a Site Frequency Spectrum from a gzipped VCF file format.
  21. Parameters
  22. ----------
  23. n : int
  24. Nb of individuals in sample.
  25. vcf_file : str
  26. SNPs in VCF file format.
  27. Used to generate a Site Frequency Spectrum (SFS) from a VCF.
  28. Returns
  29. -------
  30. dict
  31. Site Frequency Spectrum (SFS)
  32. """
  33. if diploid and not folded:
  34. n *= 2
  35. # initiate SFS_values with a zeros dict
  36. # if strip:
  37. # # "[1" removes the 0 bin
  38. # # "n-1]" crop the last bin (n or n/2 for folded)
  39. # SFS_dim = [1, n-1]
  40. # else:
  41. SFS_dim = [0, n+1]
  42. SFS_values = dict.fromkeys(range(SFS_dim[1]),0)
  43. count_pluriall = 0
  44. with gzip.open(vcf_file, "rb") as inputgz:
  45. line = inputgz.readline()
  46. genotypes = []
  47. print("Parsing VCF", vcf_file, "... Please wait...")
  48. while line:
  49. # decode gzipped binary lines
  50. line = line.decode('utf-8').strip()
  51. # every snp line, not comment or header
  52. if not line.startswith("##") and not line.startswith("#"):
  53. FIELDS = line.split("\t")
  54. # REF is col 4 of VCF
  55. REF = FIELDS[3].split(",")
  56. # ALT is col 5 of VCF
  57. ALT = FIELDS[4].split(",")
  58. FORMAT = line.split("\t")[8:9]
  59. SAMPLES = line.split("\t")[9:]
  60. snp_genotypes = []
  61. allele_counts = {}
  62. allele_counts_list = []
  63. # SKIP the SNP if :
  64. # 1 : missing
  65. # 2 : deletion among REF
  66. # 3 : deletion among ALT
  67. if "./.:." in line \
  68. or len(ALT[0]) > 1 \
  69. or len(REF[0]) > 1:
  70. line = inputgz.readline()
  71. continue
  72. for sample in SAMPLES:
  73. if not phased:
  74. # for UNPHASED data
  75. smpl_genotype = [int(a) for a in sample.split(':')[0].split('/') if a != '.']
  76. else:
  77. # for PHASED
  78. smpl_genotype = [int(a) for a in sample.split(':')[0].split('|') if a != '.']
  79. nb_alleles = set(smpl_genotype)
  80. snp_genotypes += smpl_genotype
  81. # skip if all individuals have the same genotype
  82. # if len(set(snp_genotypes)) == 1:
  83. # if folded or (folded == False and snp_genotypes.count(1) == 0) :
  84. # line = inputgz.readline()
  85. # continue
  86. for k in set(snp_genotypes):
  87. allele_counts[snp_genotypes.count(k)] = k
  88. allele_counts_list.append(snp_genotypes.count(k))
  89. #print(allele_counts_list)
  90. if len(set(snp_genotypes)) == 1 or allele_counts_list[0] == allele_counts_list[1]:
  91. # If only heterozygous sites 0/1; skip the site (equivalent to n bin or n/2 bin for folded)
  92. # skip if all individuals have the same genotype
  93. line = inputgz.readline()
  94. continue
  95. if len(ALT) >= 2:
  96. #pass
  97. count_pluriall +=1
  98. # TODO - work in progress
  99. # for al in range(len(ALT)-1):
  100. # SFS_values[min(allele_counts_list)-1] += 1/len(ALT)
  101. # allele_counts_list.remove(min(allele_counts_list))
  102. else:
  103. if folded:
  104. SFS_values[min(allele_counts_list)-SFS_dim[0]] += 1
  105. else :
  106. # if unfolded, count the Ones (ALT allele)
  107. #print(snp_genotypes, snp_genotypes.count(1))
  108. SFS_values[snp_genotypes.count(1)-SFS_dim[0]] += 1
  109. # all the parsing is done, change line
  110. line = inputgz.readline()
  111. if verbose:
  112. print("SFS=", SFS_values)
  113. if strip:
  114. del SFS_values[0]
  115. del SFS_values[n]
  116. print("Pluriallelic sites =", count_pluriall)
  117. return SFS_values, count_pluriall
  118. def sfs_from_parsed_vcf(n, vcf_dict, folded = True, diploid = True, phased = False, verbose = False):
  119. """
  120. Generates a Site Frequency Spectrum from a gzipped VCF file format.
  121. Parameters
  122. ----------
  123. n : int
  124. Nb of individuals in sample.
  125. vcf_file : str
  126. SNPs in VCF file format.
  127. Used to generate a Site Frequency Spectrum (SFS) from a VCF.
  128. Returns
  129. -------
  130. dict
  131. Site Frequency Spectrum (SFS)
  132. """
  133. if diploid and not folded:
  134. n *= 2
  135. # initiate SFS_values with a zeros dict
  136. SFS_values = dict.fromkeys(range(n),0)
  137. count_pluriall = 0
  138. for CHROM in vcf_dict:
  139. for SNP in vcf_dict[CHROM]:
  140. snp_genotypes = []
  141. allele_counts = {}
  142. allele_counts_list = []
  143. print(CHROM, SNP)
  144. for sample in vcf_dict[CHROM][SNP]["SAMPLES"]:
  145. if not phased:
  146. # for UNPHASED data
  147. smpl_genotype = [int(a) for a in sample.split(':')[0].split('/') if a != '.']
  148. else:
  149. # for PHASED
  150. smpl_genotype = [int(a) for a in sample.split(':')[0].split('|') if a != '.']
  151. nb_alleles = set(smpl_genotype)
  152. snp_genotypes += smpl_genotype
  153. # skip if all individuals have the same genotype
  154. if len(set(snp_genotypes)) == 1:
  155. continue
  156. for k in set(snp_genotypes):
  157. allele_counts[snp_genotypes.count(k)] = k
  158. allele_counts_list.append(snp_genotypes.count(k))
  159. SFS_values[min(allele_counts_list)-1] += 1
  160. # sum pluriall counts for this CHR to the rest
  161. count_pluriall += vcf_dict[CHROM]['NB_PLURIALL']
  162. if verbose:
  163. print("SFS=", SFS_values)
  164. print("Pluriallelic sites =", count_pluriall)
  165. return SFS_values, count_pluriall
  166. def barplot_sfs(sfs, xlab, ylab, folded=True, title = "Barplot", transformed = False, normalized = False):
  167. sfs_val = []
  168. n = len(sfs.values())
  169. sum_sites = sum(list(sfs.values()))
  170. for k, ksi in sfs.items():
  171. #ksi = list(sfs.values())[k-1]
  172. # k+1 because k starts from 0
  173. # if folded:
  174. # # ?check if 2*n or not?
  175. # sfs_val.append(ksi * k * (2*n - k))
  176. # else:
  177. # if transformed:
  178. # sfs_val.append(ksi * k)
  179. # else:
  180. # sfs_val.append(ksi)
  181. if transformed:
  182. ylab = r'$ \phi_i $'
  183. if folded:
  184. val = ((k*(2*n - k)) / (2*n))*(ksi)
  185. else:
  186. val = ksi * k
  187. else:
  188. val = ksi
  189. sfs_val.append(val)
  190. if not transformed and not normalized:
  191. ylab = r'$ \eta_i $'
  192. #terminal case, same for folded or unfolded
  193. if transformed:
  194. last_bin = list(sfs.values())[n-1] * n/2
  195. else:
  196. last_bin = list(sfs.values())[n-1]
  197. sfs_val[-1] = last_bin
  198. if normalized:
  199. #ylab = "Fraction of SNPs "
  200. ylab = r'$ \phi_i $'
  201. sum_val = sum(sfs_val)
  202. for k, sfs_bin in enumerate(sfs_val):
  203. sfs_val[k] = sfs_bin / sum_val
  204. #print(sum(sfs_val))
  205. #build the plot
  206. title = title+" (n="+str(len(sfs_val))+") [folded="+str(folded)+"]"+" [transformed="+str(transformed)+"]"
  207. print("SFS =", sfs)
  208. if folded:
  209. xlab = "Minor allele frequency"
  210. if transformed:
  211. print("Transformed SFS ( n =",len(sfs_val), ") :", sfs_val)
  212. #plt.axhline(y=1/n, color='r', linestyle='-')
  213. else:
  214. if normalized:
  215. # then plot a theoritical distribution as 1/i
  216. expected_y = [1/(2*x+1) for x in list(sfs.keys())]
  217. print(sum(expected_y))
  218. #plt.plot([x for x in list(sfs.keys())], expected_y, color='r', linestyle='-')
  219. #print(expected_y)
  220. customgraphics.barplot(x = [x for x in list(sfs.keys())], y= sfs_val, xlab = xlab, ylab = ylab, title = title)
  221. plt.show()
  222. if __name__ == "__main__":
  223. if len(sys.argv) != 3:
  224. print("Need 2 args")
  225. exit(0)
  226. # PARAM : Nb of indiv
  227. n = int(sys.argv[2])
  228. sfs = sfs_from_vcf(n, sys.argv[1], folded = True, diploid = True, phased = False, strip = True)
  229. print(sfs)