#!/usr/bin/env python3

"""
FOREST Thomas (thomas.forest@college-de-france.fr)

Build Site Frequency Spectra (SFS) from VCF data, read SFS files and plot
spectra as bar charts.

Caution : At the moment for gzipped files only.

ARGS
--------

standalone usage : vcf_to_sfs.py VCF.gz nb_indiv

TODO
_____

Rectify SFS comp in parsed funct.
"""

import gzip
import sys
from collections import Counter

import numpy as np

# The plotting stack is only used by barplot_sfs().  Guarding the import keeps
# the parsing functions usable where matplotlib or the project-local `frst`
# package is unavailable; barplot_sfs() re-raises the failure when called.
try:
    import matplotlib.pyplot as plt
    from frst import customgraphics
    _PLOT_IMPORT_ERROR = None
except ImportError as import_error:
    plt = None
    customgraphics = None
    _PLOT_IMPORT_ERROR = import_error


def parse_sfs(sfs_file):
    """
    Parse a Site Frequency Spectrum (SFS) file and return a masked spectrum.

    The file is read as three whitespace-separated lines:

    1. a header with exactly three tokens, the first being the number of bins
       (the two others are read but not used),
    2. the spectrum, one integer per bin,
    3. the mask, one integer per bin; a bin is kept when its mask value is 0.

    Parameters:
    - sfs_file (str): The path to the SFS file to be parsed, in dadi's .fs format.

    Returns:
    - masked_spectrum (list): The unmasked bins of the SFS as a list of integers.

    Raises:
    - FileNotFoundError: If the specified SFS file is not found.
    - ValueError: If there are inconsistencies in the file format or data.
    """
    with open(sfs_file, 'r') as handle:
        header = handle.readline().split()
        spectrum_data = [int(value) for value in handle.readline().split()]
        mask_data = [int(value) for value in handle.readline().split()]

    if len(header) != 3:
        raise ValueError(
            f"Error: Expected 3 fields in the header line, found {len(header)}.")
    num_individuals = int(header[0])

    # Both data lines must describe the same number of bins as the header.
    if len(spectrum_data) != num_individuals:
        raise ValueError("Error: Number of bins in the spectrum doesn't match the expected number of individuals.")
    if len(mask_data) != num_individuals:
        raise ValueError("Error: Size of the mask doesn't match the number of bins in the spectrum.")

    # Errors are no longer printed and swallowed: the previous version then
    # failed on an unbound `masked_spectrum` instead of reporting the cause.
    return [value for value, masked in zip(spectrum_data, mask_data) if not masked]


def _sample_alleles(sample, phased):
    """Return the integer allele indices found in one VCF sample column.

    Only the first ':'-separated sub-field is read.  Alleles are split on '|'
    when `phased` is true and on '/' otherwise; '.' alleles are dropped.
    """
    separator = '|' if phased else '/'
    genotype = sample.split(':')[0]
    return [int(allele) for allele in genotype.split(separator) if allele != '.']


def _allele_counts(genotypes):
    """Return the number of copies of each allele, ordered by allele index."""
    tally = Counter(genotypes)
    return [tally[allele] for allele in sorted(tally)]


def sfs_from_vcf(n, vcf_file, folded=True, diploid=True, phased=False,
                 verbose=False, strip=False, count_ext=False):
    """
    Generates a Site Frequency Spectrum from a gzipped VCF file format.

    A site is skipped when:

    - the line contains the text "./.:." (missing genotype),
    - the first REF or the first ALT allele is longer than one character,
    - fewer than two distinct alleles are observed among the samples,
    - the two lowest-indexed alleles have the same number of copies.

    Sites with two or more ALT alleles are only counted as pluriallelic; they
    do not contribute to the spectrum.

    Parameters
    ----------
    n : int
        Nb of individuals in sample. Doubled internally when `diploid` is true
        and `folded` is false.
    vcf_file : str
        Path to the gzipped VCF file.
    folded : bool
        If true, a site goes to the bin of its least frequent allele count;
        otherwise to the bin of the number of copies of allele 1.
    diploid : bool
        See `n`.
    phased : bool
        If true, genotypes are split on '|', otherwise on '/'.
    verbose : bool
        If true, print the spectrum before stripping.
    strip : bool
        If true, remove bins 0 and n from the returned spectrum.
    count_ext : bool
        Currently unused; kept for interface compatibility.

    Returns
    -------
    tuple
        (SFS, count_pluriall): the Site Frequency Spectrum as a dict mapping
        bin index to number of sites, and the number of pluriallelic sites.

    Raises
    ------
    KeyError
        If a site falls in a bin above n, i.e. the VCF holds more allele
        copies than `n` allows for.
    """
    if diploid and not folded:
        n *= 2

    # one bin per possible count, 0..n included
    SFS_values = dict.fromkeys(range(n + 1), 0)
    count_pluriall = 0

    with gzip.open(vcf_file, "rb") as inputgz:
        print("Parsing VCF", vcf_file, "... Please wait...")
        for raw_line in inputgz:
            # decode gzipped binary lines
            line = raw_line.decode('utf-8').strip()
            # blank lines, meta-information and header lines carry no site
            if not line or line.startswith("#"):
                continue

            fields = line.split("\t")
            ref = fields[3].split(",")   # REF is col 4 of VCF
            alt = fields[4].split(",")   # ALT is col 5 of VCF
            if "./.:." in line or len(alt[0]) > 1 or len(ref[0]) > 1:
                continue

            snp_genotypes = []
            for sample in fields[9:]:
                snp_genotypes += _sample_alleles(sample, phased)
            counts = _allele_counts(snp_genotypes)

            # Fewer than two alleles: nothing segregates (or nothing could be
            # parsed at all, which used to raise IndexError).
            # Equal counts: skipped as in the original implementation.
            # NOTE(review): in unfolded mode equal counts is a regular middle
            # bin; confirm that skipping it there is intended.
            if len(counts) < 2 or counts[0] == counts[1]:
                continue

            if len(alt) >= 2:
                # TODO - work in progress: pluriallelic sites are only counted
                count_pluriall += 1
            elif folded:
                SFS_values[min(counts)] += 1
            else:
                # if unfolded, count the Ones (ALT allele)
                SFS_values[snp_genotypes.count(1)] += 1

    if verbose:
        print("SFS=", SFS_values)
    if strip:
        del SFS_values[0]
        del SFS_values[n]
    print("Pluriallelic sites =", count_pluriall)
    return SFS_values, count_pluriall


def sfs_from_parsed_vcf(n, vcf_dict, folded=True, diploid=True, phased=False,
                        verbose=False):
    """
    Generates a Site Frequency Spectrum from an already parsed VCF.

    Parameters
    ----------
    n : int
        Nb of individuals in sample. Doubled internally when `diploid` is true
        and `folded` is false.
    vcf_dict : dict
        Maps each chromosome to a dict holding one entry per SNP, each with a
        "SAMPLES" list of VCF sample columns, plus a 'NB_PLURIALL' entry
        holding the number of pluriallelic sites of that chromosome.
    folded : bool
        Only affects the number of bins (see `n`); every site is binned on its
        least frequent allele count.
    diploid : bool
        See `n`.
    phased : bool
        If true, genotypes are split on '|', otherwise on '/'.
    verbose : bool
        If true, print the final spectrum.

    Returns
    -------
    tuple
        (SFS, count_pluriall): the spectrum as a dict with keys 0..n-1, where
        key k-1 holds the number of sites whose least frequent allele has k
        copies, and the summed 'NB_PLURIALL' counters.

    Raises
    ------
    KeyError
        If a chromosome has no 'NB_PLURIALL' entry, or a site falls outside
        the spectrum.
    """
    if diploid and not folded:
        n *= 2

    # initiate SFS_values with a zeros dict
    # NOTE(review): bins are shifted by one compared with sfs_from_vcf
    # (keys 0..n-1 here, 0..n there); kept as is for existing callers.
    SFS_values = dict.fromkeys(range(n), 0)
    count_pluriall = 0

    for CHROM in vcf_dict:
        sites = vcf_dict[CHROM]
        for SNP in sites:
            # 'NB_PLURIALL' is a per-chromosome counter stored next to the
            # SNP records, not a SNP: subscripting it used to raise TypeError.
            if SNP == 'NB_PLURIALL':
                continue
            print(CHROM, SNP)

            snp_genotypes = []
            for sample in sites[SNP]["SAMPLES"]:
                snp_genotypes += _sample_alleles(sample, phased)
            counts = _allele_counts(snp_genotypes)

            # skip if all individuals have the same genotype, or if no allele
            # could be parsed at all
            if len(counts) < 2:
                continue
            SFS_values[min(counts) - 1] += 1

        # sum pluriall counts for this CHR to the rest
        count_pluriall += sites['NB_PLURIALL']

    if verbose:
        print("SFS=", SFS_values)
    print("Pluriallelic sites =", count_pluriall)
    return SFS_values, count_pluriall


def _sfs_bar_heights(sfs, folded, transformed, normalized, ploidy):
    """Return the bar heights plotted by barplot_sfs for a non-empty `sfs`.

    With n = len(sfs), key k and value ksi:

    - not transformed: the height is ksi;
    - transformed and folded: k * (2n - k) / (2n) * ksi;
    - transformed and unfolded: ksi * k;
    - transformed, last bin (either case): ksi * n / ploidy.

    With `normalized`, heights are divided by their sum; when that sum is 0
    they are left untouched instead of dividing by zero.
    """
    n = len(sfs)
    counts = list(sfs.values())

    if not transformed:
        heights = counts
    else:
        if folded:
            heights = [((k * (2 * n - k)) / (2 * n)) * ksi for k, ksi in sfs.items()]
        else:
            heights = [ksi * k for k, ksi in sfs.items()]
        # terminal case, same for folded or unfolded
        heights[-1] = counts[-1] * n / ploidy

    if normalized:
        total = sum(heights)
        if total:
            heights = [height / total for height in heights]
    return heights


def barplot_sfs(sfs, xlab, ylab, folded=True, title="Barplot",
                transformed=False, normalized=False, ploidy=2, output=None):
    """
    Plot a Site Frequency Spectrum as a bar chart, next to a reference series.

    Parameters
    ----------
    sfs : dict
        Spectrum mapping bin index to number of sites; must not be empty.
    xlab : str
        X axis label; replaced by "Minor allele frequency" when `folded`.
    ylab : str
        Accepted for interface compatibility; the label actually used is
        derived from `transformed` and `normalized`.
    folded : bool
        Selects the transform formula, the x label and the n shown in the
        title (len(sfs) when folded, len(sfs) + 1 otherwise).
    title : str
        Plot title; also used to name the output file.
    transformed : bool
        Plot transformed values, with a constant 1/len(sfs) reference series.
    normalized : bool
        Divide the values by their sum; when not `transformed`, a normalized
        1/i reference series is drawn as well.
    ploidy : int
        Divisor applied to the last bin when `transformed`.
    output : str or None
        Directory in which "<title>_SFS.pdf" is saved. When falsy, the figure
        is shown instead.

    Raises
    ------
    ImportError
        If matplotlib or frst.customgraphics could not be imported.
    ValueError
        If `sfs` is empty.
    """
    if plt is None or customgraphics is None:
        raise ImportError(
            "barplot_sfs requires matplotlib and frst.customgraphics"
        ) from _PLOT_IMPORT_ERROR
    if not sfs:
        raise ValueError("Cannot plot an empty SFS.")

    n = len(sfs)
    sfs_val = _sfs_bar_heights(sfs, folded, transformed, normalized, ploidy)

    if transformed or normalized:
        ylab = r'$ \phi_i $'
    else:
        ylab = r'$ \eta_i $'

    # build the plot
    if folded:
        xlab = "Minor allele frequency"
        n_title = n
    else:
        # the spectrum is n-1 long when unfolded
        n_title = n + 1

    original_title = title
    # reformat title and add infos
    title = title+" (n="+str(n_title)+") [folded="+str(folded)+"]"+" [transformed="+str(transformed)+"]"
    print("SFS =", sfs)

    X_axis = list(sfs.keys())
    reference_x = [x + 0.2 for x in X_axis]
    if transformed:
        print("Transformed SFS ( n =", n_title, ") :", sfs_val)
        plt.bar(reference_x, [1/n]*n, fill=False, hatch="///", linestyle='-',
                width=0.4, label="H0 Theoric constant")
    elif normalized:
        # then plot a theoretical distribution as 1/i
        inverse_ranks = [1/(i+1) for i in range(n)]
        sum_expected = sum(inverse_ranks)
        expected_y = [value/sum_expected for value in inverse_ranks]
        print(expected_y)
        plt.bar(reference_x, expected_y, fill=False, hatch="///", linestyle='-',
                width=0.4, label="H0 Theoric constant")
        print(sum(expected_y))

    # if write in a file, don't open the window dynamically
    plot = output is None
    customgraphics.barplot(x=[x - 0.2 for x in X_axis], width=0.4, y=sfs_val,
                           xlab=xlab, ylab=ylab, title=title,
                           label="H1 Observed spectrum", xticks=list(sfs.keys()),
                           plot=plot)
    if output:
        plt.savefig(f"{output}/{original_title}_SFS.pdf")
    else:
        plt.show()
    plt.close()


if __name__ == "__main__":

    if len(sys.argv) != 3:
        print("Need 2 args")
        # a usage error must not report success
        sys.exit(1)

    # PARAM : Nb of indiv
    n = int(sys.argv[2])

    sfs = sfs_from_vcf(n, sys.argv[1], folded=True, diploid=True, phased=False, strip=True)
    print(sfs)