sfs_tools.py 9.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. #!/usr/bin/env python3
  2. """
  3. FOREST Thomas (thomas.forest@college-de-france.fr)
  4. Caution : At the moment for gzipped files only.
  5. ARGS
  6. --------
  7. standalone usage : vcf_to_sfs.py VCF.gz nb_indiv
  8. TODO
  9. _____
  10. Externalize sfs transforms in a function
  11. Rectify SFS comp in parsed funct.
  12. """
  13. import gzip
  14. import sys
  15. import matplotlib.pyplot as plt
  16. from frst import customgraphics
  17. import numpy as np
  18. def sfs_from_vcf(n, vcf_file, folded = True, diploid = True, phased = False, verbose = False,
  19. strip = False, count_ext = False):
  20. """
  21. Generates a Site Frequency Spectrum from a gzipped VCF file format.
  22. Parameters
  23. ----------
  24. n : int
  25. Nb of individuals in sample.
  26. vcf_file : str
  27. SNPs in VCF file format.
  28. Used to generate a Site Frequency Spectrum (SFS) from a VCF.
  29. Returns
  30. -------
  31. dict
  32. Site Frequency Spectrum (SFS)
  33. """
  34. if diploid and not folded:
  35. n *= 2
  36. # initiate SFS_values with a zeros dict
  37. # if strip:
  38. # # "[1" removes the 0 bin
  39. # # "n-1]" crop the last bin (n or n/2 for folded)
  40. # SFS_dim = [1, n-1]
  41. # else:
  42. SFS_dim = [0, n+1]
  43. SFS_values = dict.fromkeys(range(SFS_dim[1]),0)
  44. count_pluriall = 0
  45. with gzip.open(vcf_file, "rb") as inputgz:
  46. line = inputgz.readline()
  47. genotypes = []
  48. print("Parsing VCF", vcf_file, "... Please wait...")
  49. while line:
  50. # decode gzipped binary lines
  51. line = line.decode('utf-8').strip()
  52. # every snp line, not comment or header
  53. if not line.startswith("##") and not line.startswith("#"):
  54. FIELDS = line.split("\t")
  55. # REF is col 4 of VCF
  56. REF = FIELDS[3].split(",")
  57. # ALT is col 5 of VCF
  58. ALT = FIELDS[4].split(",")
  59. FORMAT = line.split("\t")[8:9]
  60. SAMPLES = line.split("\t")[9:]
  61. snp_genotypes = []
  62. allele_counts = {}
  63. allele_counts_list = []
  64. # SKIP the SNP if :
  65. # 1 : missing
  66. # 2 : deletion among REF
  67. # 3 : deletion among ALT
  68. if "./.:." in line \
  69. or len(ALT[0]) > 1 \
  70. or len(REF[0]) > 1:
  71. line = inputgz.readline()
  72. continue
  73. for sample in SAMPLES:
  74. if not phased:
  75. # for UNPHASED data
  76. smpl_genotype = [int(a) for a in sample.split(':')[0].split('/') if a != '.']
  77. else:
  78. # for PHASED
  79. smpl_genotype = [int(a) for a in sample.split(':')[0].split('|') if a != '.']
  80. nb_alleles = set(smpl_genotype)
  81. snp_genotypes += smpl_genotype
  82. # skip if all individuals have the same genotype
  83. # if len(set(snp_genotypes)) == 1:
  84. # if folded or (folded == False and snp_genotypes.count(1) == 0) :
  85. # line = inputgz.readline()
  86. # continue
  87. for k in set(snp_genotypes):
  88. allele_counts[snp_genotypes.count(k)] = k
  89. allele_counts_list.append(snp_genotypes.count(k))
  90. #print(allele_counts_list)
  91. if len(set(snp_genotypes)) == 1 or allele_counts_list[0] == allele_counts_list[1]:
  92. # If only heterozygous sites 0/1; skip the site (equivalent to n bin or n/2 bin for folded)
  93. # skip if all individuals have the same genotype
  94. line = inputgz.readline()
  95. continue
  96. if len(ALT) >= 2:
  97. #pass
  98. count_pluriall +=1
  99. # TODO - work in progress
  100. # for al in range(len(ALT)-1):
  101. # SFS_values[min(allele_counts_list)-1] += 1/len(ALT)
  102. # allele_counts_list.remove(min(allele_counts_list))
  103. else:
  104. if folded:
  105. SFS_values[min(allele_counts_list)-SFS_dim[0]] += 1
  106. else :
  107. # if unfolded, count the Ones (ALT allele)
  108. #print(snp_genotypes, snp_genotypes.count(1))
  109. SFS_values[snp_genotypes.count(1)-SFS_dim[0]] += 1
  110. # all the parsing is done, change line
  111. line = inputgz.readline()
  112. if verbose:
  113. print("SFS=", SFS_values)
  114. if strip:
  115. del SFS_values[0]
  116. del SFS_values[n]
  117. print("Pluriallelic sites =", count_pluriall)
  118. return SFS_values, count_pluriall
  119. def sfs_from_parsed_vcf(n, vcf_dict, folded = True, diploid = True, phased = False, verbose = False):
  120. """
  121. Generates a Site Frequency Spectrum from a gzipped VCF file format.
  122. Parameters
  123. ----------
  124. n : int
  125. Nb of individuals in sample.
  126. vcf_file : str
  127. SNPs in VCF file format.
  128. Used to generate a Site Frequency Spectrum (SFS) from a VCF.
  129. Returns
  130. -------
  131. dict
  132. Site Frequency Spectrum (SFS)
  133. """
  134. if diploid and not folded:
  135. n *= 2
  136. # initiate SFS_values with a zeros dict
  137. SFS_values = dict.fromkeys(range(n),0)
  138. count_pluriall = 0
  139. for CHROM in vcf_dict:
  140. for SNP in vcf_dict[CHROM]:
  141. snp_genotypes = []
  142. allele_counts = {}
  143. allele_counts_list = []
  144. print(CHROM, SNP)
  145. for sample in vcf_dict[CHROM][SNP]["SAMPLES"]:
  146. if not phased:
  147. # for UNPHASED data
  148. smpl_genotype = [int(a) for a in sample.split(':')[0].split('/') if a != '.']
  149. else:
  150. # for PHASED
  151. smpl_genotype = [int(a) for a in sample.split(':')[0].split('|') if a != '.']
  152. nb_alleles = set(smpl_genotype)
  153. snp_genotypes += smpl_genotype
  154. # skip if all individuals have the same genotype
  155. if len(set(snp_genotypes)) == 1:
  156. continue
  157. for k in set(snp_genotypes):
  158. allele_counts[snp_genotypes.count(k)] = k
  159. allele_counts_list.append(snp_genotypes.count(k))
  160. SFS_values[min(allele_counts_list)-1] += 1
  161. # sum pluriall counts for this CHR to the rest
  162. count_pluriall += vcf_dict[CHROM]['NB_PLURIALL']
  163. if verbose:
  164. print("SFS=", SFS_values)
  165. print("Pluriallelic sites =", count_pluriall)
  166. return SFS_values, count_pluriall
  167. def barplot_sfs(sfs, xlab, ylab, folded=True, title = "Barplot", transformed = False, normalized = False, ploidy = 2):
  168. sfs_val = []
  169. n = len(sfs.values())
  170. sum_sites = sum(list(sfs.values()))
  171. for k, ksi in sfs.items():
  172. #ksi = list(sfs.values())[k-1]
  173. # k+1 because k starts from 0
  174. # if folded:
  175. # # ?check if 2*n or not?
  176. # sfs_val.append(ksi * k * (2*n - k))
  177. # else:
  178. # if transformed:
  179. # sfs_val.append(ksi * k)
  180. # else:
  181. # sfs_val.append(ksi)
  182. if transformed:
  183. ylab = r'$ \phi_i $'
  184. if folded:
  185. val = ((k*(2*n - k)) / (2*n))*(ksi)
  186. else:
  187. val = ksi * k
  188. else:
  189. val = ksi
  190. sfs_val.append(val)
  191. if not transformed and not normalized:
  192. ylab = r'$ \eta_i $'
  193. #terminal case, same for folded or unfolded
  194. if transformed:
  195. last_bin = list(sfs.values())[n-1] * n/ploidy
  196. else:
  197. last_bin = list(sfs.values())[n-1]
  198. sfs_val[-1] = last_bin
  199. if normalized:
  200. #ylab = "Fraction of SNPs "
  201. ylab = r'$ \phi_i $'
  202. sum_val = sum(sfs_val)
  203. for k, sfs_bin in enumerate(sfs_val):
  204. sfs_val[k] = sfs_bin / sum_val
  205. #print(sum(sfs_val))
  206. #build the plot
  207. if folded:
  208. xlab = "Minor allele frequency"
  209. n_title = n
  210. else:
  211. # the spectrum is n-1 long when unfolded
  212. n_title = n+1
  213. title = title+" (n="+str(n_title)+") [folded="+str(folded)+"]"+" [transformed="+str(transformed)+"]"
  214. print("SFS =", sfs)
  215. X_axis = list(sfs.keys())
  216. if transformed:
  217. print("Transformed SFS ( n =",n_title, ") :", sfs_val)
  218. #plt.axhline(y=1/n, color='r', linestyle='-')
  219. plt.bar([x+0.2 for x in list(sfs.keys())], [1/n]*n, color='r', linestyle='-', width = 0.4, label= "H0 Theoric constant")
  220. else:
  221. if normalized:
  222. # then plot a theoritical distribution as 1/i
  223. sum_expected = sum([(1/(i+1)) for i,x in enumerate(list(sfs.keys()))])
  224. expected_y = [(1/(i+1))/sum_expected for i,x in enumerate(list(sfs.keys()))]
  225. print(expected_y)
  226. plt.bar([x+0.2 for x in list(sfs.keys())], expected_y, color='r', linestyle='-', width = 0.4, label= "H0 Theoric constant")
  227. print(sum(expected_y))
  228. customgraphics.barplot(x = [x-0.2 for x in X_axis], width=0.4, y= sfs_val, xlab = xlab, ylab = ylab, title = title, label = "H1 Observed spectrum", xticks =list(sfs.keys()) )
  229. plt.show()
  230. if __name__ == "__main__":
  231. if len(sys.argv) != 3:
  232. print("Need 2 args")
  233. exit(0)
  234. # PARAM : Nb of indiv
  235. n = int(sys.argv[2])
  236. sfs = sfs_from_vcf(n, sys.argv[1], folded = True, diploid = True, phased = False, strip = True)
  237. print(sfs)