sfs_tools.py 6.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. #!/usr/bin/env python3
"""
Tools to build and plot a Site Frequency Spectrum (SFS) from VCF data.

Author : FOREST Thomas (thomas.forest@college-de-france.fr)

Caution : at the moment, only gzipped VCF files are supported.

ARGS
--------
standalone usage : sfs_tools.py VCF.gz nb_indiv
"""
  9. import gzip
  10. import sys
  11. import matplotlib.pyplot as plt
  12. def sfs_from_vcf(n, vcf_file, folded = True, diploid = True, phased = False, verbose = False):
  13. """
  14. Generates a Site Frequency Spectrum from a gzipped VCF file format.
  15. Parameters
  16. ----------
  17. n : int
  18. Nb of individuals in sample.
  19. vcf_file : str
  20. SNPs in VCF file format.
  21. Used to generate a Site Frequency Spectrum (SFS) from a VCF.
  22. Returns
  23. -------
  24. dict
  25. Site Frequency Spectrum (SFS)
  26. """
  27. if diploid and not folded:
  28. n *= 2
  29. # initiate SFS_values with a zeros dict
  30. SFS_values = dict.fromkeys(range(n),0)
  31. count_pluriall = 0
  32. with gzip.open(vcf_file, "rb") as inputgz:
  33. line = inputgz.readline()
  34. genotypes = []
  35. print("Parsing VCF", vcf_file, "... Please wait...")
  36. while line:
  37. # decode gzipped binary lines
  38. line = line.decode('utf-8').strip()
  39. # every snp line, not comment or header
  40. if not line.startswith("##") and not line.startswith("#"):
  41. FIELDS = line.split("\t")
  42. # REF is col 4 of VCF
  43. REF = FIELDS[3].split(",")
  44. # ALT is col 5 of VCF
  45. ALT = FIELDS[4].split(",")
  46. FORMAT = line.split("\t")[8:9]
  47. SAMPLES = line.split("\t")[9:]
  48. snp_genotypes = []
  49. allele_counts = {}
  50. allele_counts_list = []
  51. # SKIP the SNP if :
  52. # 1 : missing
  53. # 2 : deletion among REF
  54. # 3 : deletion among ALT
  55. if "./.:." in line \
  56. or len(ALT[0]) > 1 \
  57. or len(REF[0]) > 1:
  58. line = inputgz.readline()
  59. continue
  60. for sample in SAMPLES:
  61. if not phased:
  62. # for UNPHASED data
  63. smpl_genotype = [int(a) for a in sample.split(':')[0].split('/') if a != '.']
  64. else:
  65. # for PHASED
  66. smpl_genotype = [int(a) for a in sample.split(':')[0].split('|') if a != '.']
  67. nb_alleles = set(smpl_genotype)
  68. snp_genotypes += smpl_genotype
  69. # skip if all individuals have the same genotype
  70. if len(set(snp_genotypes)) == 1:
  71. line = inputgz.readline()
  72. continue
  73. for k in set(snp_genotypes):
  74. allele_counts[snp_genotypes.count(k)] = k
  75. allele_counts_list.append(snp_genotypes.count(k))
  76. if folded and len(ALT) >= 2:
  77. #pass
  78. count_pluriall +=1
  79. # TODO - work in progress
  80. # for al in range(len(ALT)-1):
  81. # SFS_values[min(allele_counts_list)-1] += 1/len(ALT)
  82. # allele_counts_list.remove(min(allele_counts_list))
  83. else:
  84. SFS_values[min(allele_counts_list)-1] += 1
  85. line = inputgz.readline()
  86. if verbose:
  87. print("SFS=", SFS_values)
  88. print("Pluriallelic sites =", count_pluriall)
  89. return SFS_values, count_pluriall
  90. def sfs_from_parsed_vcf(n, vcf_dict, folded = True, diploid = True, phased = False, verbose = False):
  91. """
  92. Generates a Site Frequency Spectrum from a gzipped VCF file format.
  93. Parameters
  94. ----------
  95. n : int
  96. Nb of individuals in sample.
  97. vcf_file : str
  98. SNPs in VCF file format.
  99. Used to generate a Site Frequency Spectrum (SFS) from a VCF.
  100. Returns
  101. -------
  102. dict
  103. Site Frequency Spectrum (SFS)
  104. """
  105. if diploid and not folded:
  106. n *= 2
  107. # initiate SFS_values with a zeros dict
  108. SFS_values = dict.fromkeys(range(n),0)
  109. count_pluriall = 0
  110. for CHROM in vcf_dict:
  111. for SNP in vcf_dict[CHROM]:
  112. snp_genotypes = []
  113. allele_counts = {}
  114. allele_counts_list = []
  115. print(CHROM, SNP)
  116. for sample in vcf_dict[CHROM][SNP]["SAMPLES"]:
  117. if not phased:
  118. # for UNPHASED data
  119. smpl_genotype = [int(a) for a in sample.split(':')[0].split('/') if a != '.']
  120. else:
  121. # for PHASED
  122. smpl_genotype = [int(a) for a in sample.split(':')[0].split('|') if a != '.']
  123. nb_alleles = set(smpl_genotype)
  124. snp_genotypes += smpl_genotype
  125. # skip if all individuals have the same genotype
  126. if len(set(snp_genotypes)) == 1:
  127. continue
  128. for k in set(snp_genotypes):
  129. allele_counts[snp_genotypes.count(k)] = k
  130. allele_counts_list.append(snp_genotypes.count(k))
  131. SFS_values[min(allele_counts_list)-1] += 1
  132. # sum pluriall counts for this CHR to the rest
  133. count_pluriall += vcf_dict[CHROM]['NB_PLURIALL']
  134. if verbose:
  135. print("SFS=", SFS_values)
  136. print("Pluriallelic sites =", count_pluriall)
  137. return SFS_values, count_pluriall
  138. def barplot_sfs(sfs, folded=True, title = "Barplot"):
  139. sfs_val = []
  140. n = len(sfs.values())
  141. for k in range(1, n):
  142. ksi = list(sfs.values())[k-1]
  143. # k+1 because k starts from 0
  144. if folded:
  145. sfs_val.append(ksi * k * (n - k))
  146. else:
  147. sfs_val.append(ksi * k)
  148. #terminal case, same for folded or unfolded
  149. sfs_val.append(list(sfs.values())[n-1] * n)
  150. #build the plot
  151. title = title+" [folded="+str(folded)+"]"
  152. plt.title(title)
  153. plt.bar([i+1 for i in sfs.keys()], sfs_val)
  154. plt.show()
  155. if __name__ == "__main__":
  156. if len(sys.argv) != 3:
  157. print("Need 2 args")
  158. exit(0)
  159. # PARAM : Nb of indiv
  160. n = int(sys.argv[2])
  161. sfs = sfs_from_vcf(n, sys.argv[1], folded = True, diploid = True, phased = False)
  162. print(sfs)