123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336 |
- #!/usr/bin/env python3
-
- """
- FOREST Thomas (thomas.forest@college-de-france.fr)
-
- Caution : At the moment for gzipped files only.
-
- ARGS
- --------
-
- standalone usage : vcf_to_sfs.py VCF.gz nb_indiv
-
- TODO
- _____
- Externalize sfs transforms in a function
- Rectify SFS comp in parsed funct.
-
- """
-
- import gzip
- import sys
- import matplotlib.pyplot as plt
- from frst import customgraphics
- import numpy as np
-
def parse_sfs(sfs_file):
    """
    Parse a Site Frequency Spectrum (SFS) file and return a masked spectrum.

    The file is read in dadi's 1-D ``.fs`` layout: a header line whose first
    field is the number of bins, then one line of integer bin counts, then
    one line of integer mask flags (non-zero means the bin is masked).
    Blank lines and lines starting with ``#`` are skipped.

    Parameters:
    - sfs_file (str): The path to the SFS file to be parsed, in dadi's .fs format.

    Returns:
    - masked_spectrum (list): the counts of the unmasked bins, in file order.

    Raises:
    - FileNotFoundError: If the specified SFS file is not found.
    - ValueError: If the header, spectrum or mask line is missing or not
      integer, or if the spectrum/mask length differs from the header count.
    """
    def next_data_line(handle):
        # Return the next non-blank, non-comment line, or "" at end of file.
        for raw in handle:
            stripped = raw.strip()
            if stripped and not stripped.startswith("#"):
                return stripped
        return ""

    try:
        with open(sfs_file, "r") as handle:
            header = next_data_line(handle).split()
            if not header:
                raise ValueError("Error: Missing header line in SFS file.")
            # Only the bin count is needed; further header fields
            # (folded/unfolded flag, optional population name) are ignored.
            num_bins = int(header[0])
            spectrum_data = [int(x) for x in next_data_line(handle).split()]
            # Check if the number of bins in the spectrum matches the expected number
            if len(spectrum_data) != num_bins:
                raise ValueError("Error: Number of bins in the spectrum doesn't match the expected number of individuals.")
            mask_data = [int(x) for x in next_data_line(handle).split()]
            # Check if the size of the mask matches the number of bins in the spectrum
            if len(mask_data) != num_bins:
                raise ValueError("Error: Size of the mask doesn't match the number of bins in the spectrum.")
    except FileNotFoundError:
        print(f"Error: File not found - {sfs_file}")
        # Propagate instead of falling through to an unbound return value.
        raise
    except ValueError as ve:
        print(f"Error: {ve}")
        raise
    # Keep only the bins whose mask flag is 0.
    return [count for count, masked in zip(spectrum_data, mask_data) if not masked]
-
-
def sfs_from_vcf(n, vcf_file, folded = True, diploid = True, phased = False, verbose = False,
                 strip = False, count_ext = False):
    """
    Generate a Site Frequency Spectrum (SFS) from a gzipped VCF file.

    A data line is skipped when it contains ``./.:.``, when the first REF or
    first ALT allele is longer than one character, when fewer than two
    distinct alleles are called across all samples, or when the first two
    allele counts are equal (e.g. equal REF and ALT counts). Remaining sites
    with two or more ALT alleles are counted as pluriallelic and not binned.

    Parameters
    ----------
    n : int
        Nb of individuals in sample.
    vcf_file : str
        Path to a gzip-compressed VCF file.
    folded : bool
        True: bin by minor allele count. False: bin by the number of ALT
        (allele index 1) copies.
    diploid : bool
        When True and ``folded`` is False, ``n`` is doubled before the bins
        are created.
    phased : bool
        Genotypes are split on '|' when True, on '/' otherwise.
    verbose : bool
        Print the SFS before stripping.
    strip : bool
        Remove bins 0 and n from the returned dict.
    count_ext : bool
        Currently unused.

    Returns
    -------
    tuple (dict, int)
        The SFS as {bin: number of sites} with bins 0..n (n possibly doubled,
        see ``diploid``), and the number of pluriallelic sites encountered.
    """
    if diploid and not folded:
        n *= 2
    # one zero-initialised bin per possible count, 0..n inclusive
    sfs_values = dict.fromkeys(range(n + 1), 0)
    count_pluriall = 0
    separator = "|" if phased else "/"
    with gzip.open(vcf_file, "rt", encoding="utf-8") as inputgz:
        print("Parsing VCF", vcf_file, "... Please wait...")
        for raw_line in inputgz:
            line = raw_line.strip()
            # skip blank lines and meta-information / header lines
            if not line or line.startswith("#"):
                continue
            fields = line.split("\t")
            ref = fields[3].split(",")  # REF is col 4 of VCF
            alt = fields[4].split(",")  # ALT is col 5 of VCF
            # SKIP the SNP if: 1: missing, 2: deletion among REF, 3: deletion among ALT
            if "./.:." in line or len(alt[0]) > 1 or len(ref[0]) > 1:
                continue
            snp_genotypes = []
            for sample in fields[9:]:
                genotype_field = sample.split(":")[0]
                snp_genotypes += [int(a) for a in genotype_field.split(separator) if a != "."]
            allele_counts_list = [snp_genotypes.count(k) for k in set(snp_genotypes)]
            # Skip monomorphic sites, sites with no called allele at all
            # (previously an IndexError), and ties between the first two counts.
            if len(allele_counts_list) < 2 or allele_counts_list[0] == allele_counts_list[1]:
                continue
            if len(alt) >= 2:
                # TODO - pluriallelic sites are only counted, not binned
                count_pluriall += 1
            elif folded:
                sfs_values[min(allele_counts_list)] += 1
            else:
                # if unfolded, count the Ones (ALT allele)
                sfs_values[snp_genotypes.count(1)] += 1
    if verbose:
        print("SFS=", sfs_values)
    if strip:
        del sfs_values[0]
        del sfs_values[n]
    print("Pluriallelic sites =", count_pluriall)
    return sfs_values, count_pluriall
-
-
def sfs_from_parsed_vcf(n, vcf_dict, folded = True, diploid = True, phased = False, verbose = False):
    """
    Generate a Site Frequency Spectrum (SFS) from an already-parsed VCF.

    Parameters
    ----------
    n : int
        Nb of individuals in sample.
    vcf_dict : dict
        {CHROM: {SNP key: {"SAMPLES": [VCF sample strings, ...]}, ...,
        "NB_PLURIALL": int}}. NOTE(review): layout inferred from the accesses
        in this function; the "NB_PLURIALL" entry is optional here.
    folded : bool
        True: bin by minor allele count. False: bin by ALT (allele index 1)
        copy count.
    diploid : bool
        When True and ``folded`` is False, ``n`` is doubled before the bins
        are created.
    phased : bool
        Genotypes are split on '|' when True, on '/' otherwise.
    verbose : bool
        Print each processed (CHROM, SNP) and the final SFS.

    Returns
    -------
    tuple (dict, int)
        The SFS as {key: number of sites} with keys 0..n-1, where key i-1
        holds sites whose allele count is i (note: this -1 offset differs
        from ``sfs_from_vcf``), and the summed pluriallelic-site count.
    """
    if diploid and not folded:
        n *= 2
    # initiate SFS_values with a zeros dict
    sfs_values = dict.fromkeys(range(n), 0)
    count_pluriall = 0
    separator = "|" if phased else "/"

    for chrom, chrom_data in vcf_dict.items():
        for snp, record in chrom_data.items():
            # the per-chromosome pluriallelic counter lives next to the SNPs
            if snp == "NB_PLURIALL":
                continue
            if verbose:
                print(chrom, snp)
            snp_genotypes = []
            for sample in record["SAMPLES"]:
                genotype_field = sample.split(":")[0]
                snp_genotypes += [int(a) for a in genotype_field.split(separator) if a != "."]
            allele_counts_list = [snp_genotypes.count(k) for k in set(snp_genotypes)]
            # skip monomorphic sites and sites with no called allele
            if len(allele_counts_list) < 2:
                continue
            count = min(allele_counts_list) if folded else snp_genotypes.count(1)
            # unfolded sites without any ALT (index 1) copy have no bin
            if count == 0:
                continue
            sfs_values[count - 1] += 1
        # sum pluriall counts for this CHR to the rest
        count_pluriall += chrom_data.get("NB_PLURIALL", 0)

    if verbose:
        print("SFS=", sfs_values)
        print("Pluriallelic sites =", count_pluriall)

    return sfs_values, count_pluriall
-
-
def _sfs_bar_heights(sfs, folded, transformed, normalized, ploidy):
    """Return the observed-spectrum bar heights plotted by ``barplot_sfs``."""
    n = len(sfs)
    values = list(sfs.values())
    heights = []
    for k, ksi in sfs.items():
        if not transformed:
            heights.append(ksi)
        elif folded:
            heights.append(((k * (2 * n - k)) / (2 * n)) * ksi)
        else:
            heights.append(ksi * k)
    # terminal case, same for folded or unfolded: the last bin is replaced
    heights[-1] = values[n - 1] * n / ploidy if transformed else values[n - 1]
    if normalized:
        total = sum(heights)
        # an all-zero spectrum stays all-zero instead of dividing by zero
        if total:
            heights = [height / total for height in heights]
    return heights


def barplot_sfs(sfs, xlab, ylab, folded=True, title = "Barplot", transformed = False, normalized = False, ploidy = 2, output = None):
    """
    Draw the SFS as a bar plot next to a reference ("H0") bar series.

    Parameters
    ----------
    sfs : dict
        {bin: number of sites}; keys are used as x positions and, when
        ``transformed``, as the bin index in the transform.
    xlab : str
        X-axis label; replaced by "Minor allele frequency" when ``folded``.
    ylab : str
        Y-axis label. NOTE(review): always replaced by eta_i or phi_i below.
    folded : bool
        Selects the folded transform and the title's n.
    title : str
        Plot title; also used in the output PDF file name.
    transformed : bool
        Plot the transformed spectrum against a constant 1/n reference.
    normalized : bool
        Divide bars by their sum; when not transformed, plot a 1/i reference.
    ploidy : int
        Divisor applied to the last bin when ``transformed``.
    output : str or None
        Directory where "<title>_SFS.pdf" is saved; when falsy the figure is
        shown interactively instead.
    """
    n = len(sfs)
    sfs_val = _sfs_bar_heights(sfs, folded, transformed, normalized, ploidy)
    # NOTE: indexing the last bin of an empty spectrum raises IndexError above.
    ylab = r'$ \phi_i $' if (transformed or normalized) else r'$ \eta_i $'

    #build the plot
    if folded:
        xlab = "Minor allele frequency"
        n_title = n
    else:
        # the spectrum is n-1 long when unfolded
        n_title = n + 1
    original_title = title
    # reformat title and add infos
    title = f"{title} (n={n_title}) [folded={folded}] [transformed={transformed}]"
    print("SFS =", sfs)

    X_axis = list(sfs.keys())

    if transformed:
        print("Transformed SFS ( n =", n_title, ") :", sfs_val)
        plt.bar([x + 0.2 for x in X_axis], [1/n]*n, fill=False, hatch="///", linestyle='-', width = 0.4, label= "H0 Theoric constant")
    elif normalized:
        # then plot a theoretical distribution as 1/i (i = bin position, from 1)
        sum_expected = sum(1/(i+1) for i in range(n))
        expected_y = [(1/(i+1))/sum_expected for i in range(n)]
        print(expected_y)
        plt.bar([x + 0.2 for x in X_axis], expected_y, fill=False, hatch="///", linestyle='-', width = 0.4, label= "H0 Theoric constant")
        print(sum(expected_y))

    # if writing to a file, don't open the window dynamically; use the same
    # truthiness test for both decisions so output="" behaves like None
    save_to_file = bool(output)
    customgraphics.barplot(x = [x - 0.2 for x in X_axis], width=0.4, y= sfs_val, xlab = xlab, ylab = ylab, title = title, label = "H1 Observed spectrum", xticks =list(sfs.keys()), plot = not save_to_file)
    if save_to_file:
        plt.savefig(f"{output}/{original_title}_SFS.pdf")
    else:
        plt.show()
    plt.close()
-
if __name__ == "__main__":

    if len(sys.argv) != 3:
        # Wrong usage is an error: report on stderr and exit non-zero
        # (exit(0) signalled success and relies on the site-only builtin).
        print("Need 2 args", file=sys.stderr)
        print("Usage: vcf_to_sfs.py VCF.gz nb_indiv", file=sys.stderr)
        sys.exit(1)

    # PARAM : Nb of indiv
    n = int(sys.argv[2])
    # sfs_from_vcf returns (SFS dict, number of pluriallelic sites)
    result = sfs_from_vcf(n, sys.argv[1], folded = True, diploid = True, phased = False, strip = True)
    print(result)
|