sfs_tools.py 4.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. #!/usr/bin/env python3
  2. """
  3. FOREST Thomas (thomas.forest@college-de-france.fr)
  4. Caution : At the moment for gzipped files only.
  5. ARGS
  6. --------
  7. standalone usage : vcf_to_sfs.py VCF.gz nb_indiv
  8. """
  9. import gzip
  10. import sys
  11. import matplotlib.pyplot as plt
  12. def sfs_from_vcf(n, vcf_file, folded = True, diploid = True, phased = False, verbose = False):
  13. """
  14. Multiplication de deux nombres entiers.
  15. Cette fonction ne sert pas à grand chose.
  16. Parameters
  17. ----------
  18. n : int
  19. Nb of individuals in sample.
  20. vcf_file : str
  21. SNPs in VCF file format.
  22. Used to generate a Site Frequency Spectrum (SFS) from a VCF.
  23. Returns
  24. -------
  25. dict
  26. Site Frequency Spectrum (SFS)
  27. """
  28. if diploid and not folded:
  29. n *= 2
  30. # initiate SFS_values with a zeros dict
  31. SFS_values = dict.fromkeys(range(n),0)
  32. count_pluriall = 0
  33. with gzip.open(vcf_file, "rb") as inputgz:
  34. line = inputgz.readline()
  35. genotypes = []
  36. print("Parsing VCF", vcf_file, "... Please wait...")
  37. while line:
  38. # decode gzipped binary lines
  39. line = line.decode('utf-8').strip()
  40. # every snp line, not comment or header
  41. if not line.startswith("##") and not line.startswith("#"):
  42. FIELDS = line.split("\t")
  43. # REF is col 4 of VCF
  44. REF = FIELDS[3].split(",")
  45. # ALT is col 5 of VCF
  46. ALT = FIELDS[4].split(",")
  47. FORMAT = line.split("\t")[8:9]
  48. SAMPLES = line.split("\t")[9:]
  49. snp_genotypes = []
  50. allele_counts = {}
  51. allele_counts_list = []
  52. # SKIP the SNP if :
  53. # 1 : missing
  54. # 2 : deletion among REF
  55. # 3 : deletion among ALT
  56. if "./.:." in line \
  57. or len(ALT[0]) > 1 \
  58. or len(REF[0]) > 1:
  59. line = inputgz.readline()
  60. continue
  61. for sample in SAMPLES:
  62. if not phased:
  63. # for UNPHASED data
  64. smpl_genotype = [int(a) for a in sample.split(':')[0].split('/') if a != '.']
  65. else:
  66. # for PHASED
  67. smpl_genotype = [int(a) for a in sample.split(':')[0].split('|') if a != '.']
  68. nb_alleles = set(smpl_genotype)
  69. snp_genotypes += smpl_genotype
  70. # skip if all individuals have the same genotype
  71. if len(set(snp_genotypes)) == 1:
  72. line = inputgz.readline()
  73. continue
  74. for k in set(snp_genotypes):
  75. allele_counts[snp_genotypes.count(k)] = k
  76. allele_counts_list.append(snp_genotypes.count(k))
  77. if folded and len(ALT) >= 2:
  78. #pass
  79. count_pluriall +=1
  80. # TODO - work in progress
  81. # for al in range(len(ALT)-1):
  82. # SFS_values[min(allele_counts_list)-1] += 1/len(ALT)
  83. # allele_counts_list.remove(min(allele_counts_list))
  84. else:
  85. SFS_values[min(allele_counts_list)-1] += 1
  86. line = inputgz.readline()
  87. if verbose:
  88. print("SFS=", SFS_values)
  89. print("Pluriallelic sites =", count_pluriall)
  90. return SFS_values
  91. def barplot_sfs(sfs, folded=True, title = "Barplot"):
  92. sfs_val = []
  93. n = len(sfs.values())
  94. for k in range(1, n):
  95. ksi = list(sfs.values())[k-1]
  96. # k+1 because k starts from 0
  97. if folded:
  98. sfs_val.append(ksi * k * (n - k))
  99. else:
  100. sfs_val.append(ksi * k)
  101. #terminal case, same for folded or unfolded
  102. sfs_val.append(list(sfs.values())[n-1] * n)
  103. #build the plot
  104. title = title+" [folded="+str(folded)+"]"
  105. plt.title(title)
  106. plt.bar(sfs.keys(), sfs_val)
  107. plt.show()
  108. if __name__ == "__main__":
  109. if len(sys.argv) != 3:
  110. print("Need 2 args")
  111. exit(0)
  112. # PARAM : Nb of indiv
  113. n = int(sys.argv[2])
  114. sfs = sfs_from_vcf(n, sys.argv[1], folded = True, diploid = True, phased = False)
  115. print(sfs)