sfs_tools.py 7.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. #!/usr/bin/env python3
  2. """
FOREST Thomas (thomas.forest@college-de-france.fr)

Caution: at the moment this works for gzipped VCF files only.

ARGS
--------
standalone usage : vcf_to_sfs.py VCF.gz nb_indiv

TODO
_____
Externalize the SFS transforms into a function.
Rectify the SFS computation in the parsed-VCF function.
  12. """
  13. import gzip
  14. import sys
  15. import matplotlib.pyplot as plt
  16. def sfs_from_vcf(n, vcf_file, folded = True, diploid = True, phased = False, verbose = False,
  17. strip = False, count_ext = False):
  18. """
  19. Generates a Site Frequency Spectrum from a gzipped VCF file format.
  20. Parameters
  21. ----------
  22. n : int
  23. Nb of individuals in sample.
  24. vcf_file : str
  25. SNPs in VCF file format.
  26. Used to generate a Site Frequency Spectrum (SFS) from a VCF.
  27. Returns
  28. -------
  29. dict
  30. Site Frequency Spectrum (SFS)
  31. """
  32. if diploid and not folded:
  33. n *= 2
  34. # initiate SFS_values with a zeros dict
  35. # if strip:
  36. # # "[1" removes the 0 bin
  37. # # "n-1]" crop the last bin (n or n/2 for folded)
  38. # SFS_dim = [1, n-1]
  39. # else:
  40. SFS_dim = [0, n+1]
  41. SFS_values = dict.fromkeys(range(SFS_dim[1]),0)
  42. count_pluriall = 0
  43. with gzip.open(vcf_file, "rb") as inputgz:
  44. line = inputgz.readline()
  45. genotypes = []
  46. print("Parsing VCF", vcf_file, "... Please wait...")
  47. while line:
  48. # decode gzipped binary lines
  49. line = line.decode('utf-8').strip()
  50. # every snp line, not comment or header
  51. if not line.startswith("##") and not line.startswith("#"):
  52. FIELDS = line.split("\t")
  53. # REF is col 4 of VCF
  54. REF = FIELDS[3].split(",")
  55. # ALT is col 5 of VCF
  56. ALT = FIELDS[4].split(",")
  57. FORMAT = line.split("\t")[8:9]
  58. SAMPLES = line.split("\t")[9:]
  59. snp_genotypes = []
  60. allele_counts = {}
  61. allele_counts_list = []
  62. # SKIP the SNP if :
  63. # 1 : missing
  64. # 2 : deletion among REF
  65. # 3 : deletion among ALT
  66. if "./.:." in line \
  67. or len(ALT[0]) > 1 \
  68. or len(REF[0]) > 1:
  69. line = inputgz.readline()
  70. continue
  71. for sample in SAMPLES:
  72. if not phased:
  73. # for UNPHASED data
  74. smpl_genotype = [int(a) for a in sample.split(':')[0].split('/') if a != '.']
  75. else:
  76. # for PHASED
  77. smpl_genotype = [int(a) for a in sample.split(':')[0].split('|') if a != '.']
  78. nb_alleles = set(smpl_genotype)
  79. snp_genotypes += smpl_genotype
  80. # skip if all individuals have the same genotype
  81. # if len(set(snp_genotypes)) == 1:
  82. # if folded or (folded == False and snp_genotypes.count(1) == 0) :
  83. # line = inputgz.readline()
  84. # continue
  85. for k in set(snp_genotypes):
  86. allele_counts[snp_genotypes.count(k)] = k
  87. allele_counts_list.append(snp_genotypes.count(k))
  88. #print(allele_counts_list)
  89. if len(set(snp_genotypes)) == 1 or allele_counts_list[0] == allele_counts_list[1]:
  90. # If only heterozygous sites 0/1; skip the site (equivalent to n bin or n/2 bin for folded)
  91. # skip if all individuals have the same genotype
  92. line = inputgz.readline()
  93. continue
  94. if len(ALT) >= 2:
  95. #pass
  96. count_pluriall +=1
  97. # TODO - work in progress
  98. # for al in range(len(ALT)-1):
  99. # SFS_values[min(allele_counts_list)-1] += 1/len(ALT)
  100. # allele_counts_list.remove(min(allele_counts_list))
  101. else:
  102. if folded:
  103. SFS_values[min(allele_counts_list)-SFS_dim[0]] += 1
  104. else :
  105. # if unfolded, count the Ones (ALT allele)
  106. #print(snp_genotypes, snp_genotypes.count(1))
  107. SFS_values[snp_genotypes.count(1)-SFS_dim[0]] += 1
  108. # all the parsing is done, change line
  109. line = inputgz.readline()
  110. if verbose:
  111. print("SFS=", SFS_values)
  112. if strip:
  113. del SFS_values[0]
  114. del SFS_values[n]
  115. print("Pluriallelic sites =", count_pluriall)
  116. return SFS_values, count_pluriall
  117. def sfs_from_parsed_vcf(n, vcf_dict, folded = True, diploid = True, phased = False, verbose = False):
  118. """
  119. Generates a Site Frequency Spectrum from a gzipped VCF file format.
  120. Parameters
  121. ----------
  122. n : int
  123. Nb of individuals in sample.
  124. vcf_file : str
  125. SNPs in VCF file format.
  126. Used to generate a Site Frequency Spectrum (SFS) from a VCF.
  127. Returns
  128. -------
  129. dict
  130. Site Frequency Spectrum (SFS)
  131. """
  132. if diploid and not folded:
  133. n *= 2
  134. # initiate SFS_values with a zeros dict
  135. SFS_values = dict.fromkeys(range(n),0)
  136. count_pluriall = 0
  137. for CHROM in vcf_dict:
  138. for SNP in vcf_dict[CHROM]:
  139. snp_genotypes = []
  140. allele_counts = {}
  141. allele_counts_list = []
  142. print(CHROM, SNP)
  143. for sample in vcf_dict[CHROM][SNP]["SAMPLES"]:
  144. if not phased:
  145. # for UNPHASED data
  146. smpl_genotype = [int(a) for a in sample.split(':')[0].split('/') if a != '.']
  147. else:
  148. # for PHASED
  149. smpl_genotype = [int(a) for a in sample.split(':')[0].split('|') if a != '.']
  150. nb_alleles = set(smpl_genotype)
  151. snp_genotypes += smpl_genotype
  152. # skip if all individuals have the same genotype
  153. if len(set(snp_genotypes)) == 1:
  154. continue
  155. for k in set(snp_genotypes):
  156. allele_counts[snp_genotypes.count(k)] = k
  157. allele_counts_list.append(snp_genotypes.count(k))
  158. SFS_values[min(allele_counts_list)-1] += 1
  159. # sum pluriall counts for this CHR to the rest
  160. count_pluriall += vcf_dict[CHROM]['NB_PLURIALL']
  161. if verbose:
  162. print("SFS=", SFS_values)
  163. print("Pluriallelic sites =", count_pluriall)
  164. return SFS_values, count_pluriall
  165. def barplot_sfs(sfs, xlab, ylab, folded=True, title = "Barplot", transformed = False):
  166. sfs_val = []
  167. n = len(sfs.values())
  168. for k in range(1, n):
  169. ksi = list(sfs.values())[k-1]
  170. # k+1 because k starts from 0
  171. # if folded:
  172. # # ?check if 2*n or not?
  173. # sfs_val.append(ksi * k * (2*n - k))
  174. # else:
  175. # if transformed:
  176. # sfs_val.append(ksi * k)
  177. # else:
  178. # sfs_val.append(ksi)
  179. if transformed:
  180. if folded:
  181. sfs_val.append(ksi * k * (2*n - k))
  182. else:
  183. sfs_val.append(ksi * k)
  184. else:
  185. sfs_val.append(ksi)
  186. #terminal case, same for folded or unfolded
  187. if transformed:
  188. sfs_val.append(list(sfs.values())[n-1] * n)
  189. else:
  190. sfs_val.append(list(sfs.values())[n-1])
  191. #build the plot
  192. title = title+" [folded="+str(folded)+"]"
  193. if ylab:
  194. plt.ylabel(ylab)
  195. if xlab:
  196. plt.xlabel(xlab)
  197. plt.title(title)
  198. plt.bar([i+1 for i in sfs.keys()], sfs_val)
  199. plt.show()
  200. if __name__ == "__main__":
  201. if len(sys.argv) != 3:
  202. print("Need 2 args")
  203. exit(0)
  204. # PARAM : Nb of indiv
  205. n = int(sys.argv[2])
  206. sfs = sfs_from_vcf(n, sys.argv[1], folded = True, diploid = True, phased = False, strip = True)
  207. print(sfs)