diff --git a/Snakefile b/Snakefile index 65ad44c..97831cc 100644 --- a/Snakefile +++ b/Snakefile @@ -77,6 +77,7 @@ checkpoint prescore: prescored=temp("{file}.pre.tsv"), log: "{file}.prescore.log", + threads: workflow.cores, params: cadd=os.environ["CADD"], shell: @@ -85,20 +86,18 @@ checkpoint prescore: echo '## Prescored variant file' > {output.prescored} 2> {log}; PRESCORED_FILES=`find -L {input.prescored} -maxdepth 1 -type f -name \\*.tsv.gz | wc -l` cp {input.vcf} {input.vcf}.new - if [ ${{PRESCORED_FILES}} -gt 0 ]; - then - for PRESCORED in $(ls {input.prescored}/*.tsv.gz) - do + if [ ${{PRESCORED_FILES}} -gt 0 ]; then + for PRESCORED in $(ls {input.prescored}/*.tsv.gz); do cat {input.vcf}.new \ | python {params.cadd}/src/scripts/extract_scored.py --header \ - -p $PRESCORED --found_out={output.prescored}.tmp \ + -p $PRESCORED --found_out={output.prescored}.tmp --threads {threads} \ > {input.vcf}.tmp 2>> {log}; cat {output.prescored}.tmp >> {output.prescored} - mv {input.vcf}.tmp {input.vcf}.new &> {log}; + mv {input.vcf}.tmp {input.vcf}.new 2>> {log}; done; - rm {output.prescored}.tmp &>> {log} + rm {output.prescored}.tmp 2>> {log} fi - mv {input.vcf}.new {output.novel} &>> {log} + mv {input.vcf}.new {output.novel} 2>> {log} """ diff --git a/src/scripts/extract_scored.py b/src/scripts/extract_scored.py index e88f103..6b49cc5 100755 --- a/src/scripts/extract_scored.py +++ b/src/scripts/extract_scored.py @@ -5,69 +5,254 @@ import os import pysam from optparse import OptionParser +import multiprocessing as mp -parser = OptionParser() -parser.add_option("-p", "--path", dest="path", help="Path to scored variants.") -parser.add_option("-i", "--input", dest="input", help="Read variants from vcf file (default stdin)", default=None) -parser.add_option("--found_out", dest="found_out", help="Write found variants to file (default: stdout)", default=None) -parser.add_option("--header", dest="header", help="Write full header to output (default none)", - default=False, action="store_true") -(options, args) = parser.parse_args() - -if options.input: - stdin = open(options.input, 'r') -else: - stdin = sys.stdin - -if options.found_out: - found_out = open(options.found_out, 'w') -else: - found_out = sys.stdout - -fpos, fref, falt = 1, 2, 3 -if os.path.exists(options.path) and os.path.exists(options.path+".tbi"): - filename = options.path - sys.stderr.write("Opening %s...\n" % (filename)) - regionTabix = pysam.Tabixfile(filename, 'r') - header = list(regionTabix.header) - for line in header: - if options.header: - found_out.write(line+"\n") + +def buffer_vcf_by_chromosome(input_stream): + """Read VCF from input stream and buffer by chromosome""" + vcf_by_chrom = {} + header_lines = [] + + for line in input_stream: + if line.startswith('#'): + header_lines.append(line) + continue + + fields = line.strip().split('\t') + chrom = fields[0] + + if chrom not in vcf_by_chrom: + vcf_by_chrom[chrom] = [] + vcf_by_chrom[chrom].append(line) + + return header_lines, vcf_by_chrom + + +def setup_output_dir(output_base, chrom): + """Create chromosome-specific output directory""" + chrom_dir = os.path.join(output_base, chrom) + try: + os.makedirs(chrom_dir) + except OSError: + if not os.path.isdir(chrom_dir): + raise + return chrom_dir + + +def extract_prescored_chromosome(input_file, output_base, chrom): + """Extract records for a single chromosome from prescored TSV file""" + try: + # Setup output directory + chrom_dir = setup_output_dir(output_base, chrom) + input_file_name = os.path.basename(input_file) + 
input_file_name_base = input_file_name.replace(".tsv.gz", "")
+        assert input_file_name_base != input_file_name, "The input file name {0} is not valid".format(input_file_name)
+        output_file = os.path.join(chrom_dir, "{0}.{1}.tsv".format(input_file_name_base, chrom))
+        compressed_file = "{0}.gz".format(output_file)
+
+        # Check if extraction is needed; tabix_only must be initialised here,
+        # otherwise it is referenced before assignment whenever the compressed
+        # file exists but is older than the input file
+        tabix_only = False
+        if os.path.exists(compressed_file):
+            if os.path.getmtime(compressed_file) > os.path.getmtime(input_file):
+                if os.path.exists(compressed_file + ".tbi"):
+                    if os.path.getmtime(compressed_file + ".tbi") > os.path.getmtime(compressed_file):
+                        sys.stderr.write("The prescored file {0} for chromosome {1} is up to date, skipping the extraction\n".format(compressed_file, chrom))
+                        return compressed_file
+                    else:
+                        # Extract is current but the index is stale: rebuild the index only
+                        tabix_only = True
+                else:
+                    tabix_only = True
+
+            if tabix_only:
+                pysam.tabix_index(compressed_file,
+                                  preset=None,
+                                  force=True,
+                                  seq_col=0,
+                                  start_col=1,
+                                  end_col=1,
+                                  zerobased=False)
+                return compressed_file
+
+        # Extract records for this chromosome using tabix; a chromosome that is
+        # absent from the prescored file raises ValueError and is treated as
+        # having no prescored variants
+        tbx = pysam.TabixFile(input_file)
+        try:
+            rows = tbx.fetch(chrom)
+        except ValueError:
+            sys.stderr.write("Chromosome {0} is not covered by {1}\n".format(chrom, input_file))
+            return None
+        with open(output_file, 'w') as f:
+            for row in rows:
+                f.write("{0}\n".format(row))
+
+        # Compress and index the output file
+        pysam.tabix_compress(output_file, compressed_file, force=True)
+        pysam.tabix_index(compressed_file,
+                          preset=None,
+                          force=True,
+                          seq_col=0,
+                          start_col=1,
+                          end_col=1,
+                          zerobased=False)
+
+        # Remove uncompressed file
+        os.remove(output_file)
+        sys.stderr.write("The prescored file {0} for chromosome {1} was extracted\n".format(compressed_file, chrom))
+        return compressed_file
+
+    except Exception as e:
+        raise Exception("Error extracting prescored chromosome {0}: {1}".format(chrom, str(e)))
+
+
+def process_chromosome(args):
+    """Process a single chromosome"""
+    chrom, vcf_lines, prescored_file, temp_dir, fpos, fref, falt = args
+    try:
+        # First extract prescored records for this chromosome
+        prescored_chrom_file = extract_prescored_chromosome(
+            prescored_file,
+            os.path.dirname(prescored_file),
+            chrom
+        )
+
+        # Setup output files for this chromosome
+        found_file = os.path.join(temp_dir, "matches", "found.{0}.tmp".format(chrom))
+        notfound_file = os.path.join(temp_dir, "matches", "notfound.{0}.tmp".format(chrom))
+        try:
+            os.makedirs(os.path.dirname(found_file))
+        except OSError:
+            if not os.path.isdir(os.path.dirname(found_file)):
+                raise
+
+        if prescored_chrom_file is None:
+            # No prescored data for this chromosome: every record is novel
+            with open(notfound_file, 'w') as f_notfound:
+                for line in vcf_lines:
+                    f_notfound.write(line)
+            return chrom, True
+
+        # Open prescored tabix file
+        pre_tbx = pysam.TabixFile(prescored_chrom_file)
+
+        with open(found_file, 'w') as f_found, open(notfound_file, 'w') as f_notfound:
+            # Process each variant
+            for line in vcf_lines:
+                fields = line.strip().split('\t')
+                pos = int(fields[1])
+                lref, allele = fields[-2], fields[-1].strip()
+                found = False
+
+                # Look for matches in prescored file
+                for pre_line in pre_tbx.fetch(chrom, pos-1, pos):
+                    vfields = pre_line.rstrip().split('\t')
+                    if (vfields[fref] == lref) and (vfields[falt] == allele) and (vfields[fpos] == fields[1]):
+                        f_found.write(pre_line + '\n')
+                        found = True
+                        break
+
+                if not found:
+                    f_notfound.write(line)
+
+        return chrom, True
+    except Exception as e:
+        sys.stderr.write('Error processing chromosome {0}: {1}\n'.format(chrom, str(e)))
+        return chrom, False
+
+
+def main():
+    parser = OptionParser()
+    parser.add_option("-p", "--path", dest="path", help="Path to scored variants.")
+    parser.add_option("-i",
"--input", dest="input", help="Read variants from vcf file (default stdin)", default=None) + parser.add_option("--found_out", dest="found_out", help="Write found variants to file (default: stdout)", default=None) + parser.add_option("--header", dest="header", help="Write full header to output (default none)", + default=False, action="store_true") + parser.add_option("-t", "--threads", dest="threads", help="Number of threads to use (default: 1)", default=1) + (options, args) = parser.parse_args() + + # Setup input stream + input_stream = sys.stdin + if options.input and options.input != "-": try: - fref = line.split('\t').index('Ref') - falt = line.split('\t').index('Alt') - except ValueError: - pass -else: - raise IOError("No valid file with pre-scored variants.\n") - -for line in stdin: - line = line.rstrip('\n\r') - if line.startswith('#'): - sys.stdout.write(line + '\n') - continue + input_stream = open(options.input, 'r') + except IOError as e: + sys.stderr.write("Error opening input file: {0}\n".format(str(e))) + sys.exit(1) + # Setup output stream + found_out = open(options.found_out, 'w') if options.found_out else sys.stdout + try: - fields = line.split('\t') - found = False - chrom = fields[0] - pos = int(fields[1]) - lref, allele = fields[-2], fields[-1] - for regionHit in regionTabix.fetch(chrom, pos-1, pos): - vfields = regionHit.rstrip().split('\t') - if (vfields[fref] == lref) and (vfields[falt] == allele) and (vfields[fpos] == fields[1]): - found_out.write(regionHit+"\n") - found = True - - if not found: - sys.stdout.write(line + '\n') - - except ValueError: - sys.stderr.write('Encountered uncovered chromosome\n') - sys.stdout.write(line + '\n') - -if options.input: - stdin.close() - -if options.found_out: - found_out.close() + # Initialize column indices + fpos, fref, falt = 1, 2, 3 + + # Check prescored file + if not (os.path.exists(options.path) and os.path.exists(options.path+".tbi")): + raise IOError("No valid file with pre-scored variants.\n") + + # Get header and column indices from prescored file + pre_tbx = pysam.TabixFile(options.path, 'r') + header = list(pre_tbx.header) + + # Write headers to output files if requested + if options.header: + for line in header: + found_out.write(line+"\n") + + # Get column indices from header + for line in header: + try: + fref = line.split('\t').index('Ref') + falt = line.split('\t').index('Alt') + except ValueError: + pass + + # Buffer VCF data and get chromosomes + header_lines, vcf_by_chrom = buffer_vcf_by_chromosome(input_stream) + chromosomes = sorted(vcf_by_chrom.keys()) + sys.stderr.write("The chromosomes are {0}\n".format(chromosomes)) + sys.stderr.write("There are in total {} lines of records got from the buffer of the input VCF file\n".format(sum([len(vcf_by_chrom[chrom]) for chrom in chromosomes]))) + + # Write VCF headers to stdout + for line in header_lines: + sys.stdout.write(line) + + # Get number of threads from Snakemake + threads = min(options.threads, len(chromosomes)) + sys.stderr.write("Using {0} threads to extract the scored variants across all chromosomes\n".format(threads)) + + temp_dir = os.environ.get("TMPDIR", "/tmp") + + # Setup parallel processing args + process_args = [ + (chrom, vcf_by_chrom[chrom], options.path, temp_dir, fpos, fref, falt) + for chrom in chromosomes + ] + + # Process chromosomes in parallel + pool = mp.Pool(threads) + results = pool.map(process_chromosome, process_args) + pool.close() + pool.join() + + # Combine results + for chrom, success in results: + if success: + found_file = 
+            notfound_file = os.path.join(temp_dir, "matches", "notfound.{0}.tmp".format(chrom))
+
+            if os.path.exists(found_file):
+                with open(found_file) as f:
+                    for line in f:
+                        found_out.write(line)
+                os.remove(found_file)
+
+            if os.path.exists(notfound_file):
+                with open(notfound_file) as f:
+                    for line in f:
+                        sys.stdout.write(line)
+                os.remove(notfound_file)
+
+    finally:
+        # Close input file if it's not stdin
+        if options.input and options.input != "-":
+            input_stream.close()
+
+        # Close output file if it's not stdout
+        if options.found_out:
+            found_out.close()
+
+if __name__ == "__main__":
+    main()
diff --git a/src/scripts/lib/tools/esmScore/esmScore_frameshift_av.py b/src/scripts/lib/tools/esmScore/esmScore_frameshift_av.py
index f6ba8bf..d7497d1 100644
--- a/src/scripts/lib/tools/esmScore/esmScore_frameshift_av.py
+++ b/src/scripts/lib/tools/esmScore/esmScore_frameshift_av.py
@@ -13,6 +13,8 @@
 Author: Thorben Maass, Max Schubach
 Contact: tho.maass@uni-luebeck.de
 Year:2023
+
+Refactored by yangyxt (uses NumPy arrays instead of list appending, which greatly improves performance on huge VCF files)
 """
 
 import warnings
@@ -22,6 +24,231 @@ from esm import pretrained
 import click
 
 
+# Constants
+WINDOW_SIZE = 250
+BATCH_SIZE = 20
+
+def read_and_extract_vcf_data(input_file):
+    """Reads the VCF file and extracts relevant information."""
+    vcf_data = []
+    with BgzfReader(input_file, "r") as vcf_file:
+        for line in vcf_file:
+            vcf_data.append(line)
+
+    info_pos = {}
+    for line in vcf_data:
+        if line.startswith("##INFO="):
+            info = line.split("|")
+            for i, item in enumerate(info):
+                if item in ("Feature", "Protein_position", "Amino_acids", "Consequence"):
+                    info_pos[item] = i
+            if len(info_pos) == 4:
+                break
+
+    # Preallocate NumPy arrays; count every comma-separated VEP entry as an
+    # upper bound, since one variant line can yield several transcript records
+    num_variants = sum(len(line.split(",")) for line in vcf_data if not line.startswith("#"))
+    variant_ids = np.empty(num_variants, dtype=object)
+    transcript_ids = np.empty(num_variants, dtype=object)
+    oAA = np.empty(num_variants, dtype=object)
+    nAA = np.empty(num_variants, dtype=object)
+    prot_pos_start = np.empty(num_variants, dtype=int)
+    prot_pos_end = np.empty(num_variants, dtype=int)
+    cons = np.empty(num_variants, dtype=object)
+
+    idx = 0
+    for variant in vcf_data:
+        if not variant.startswith("#"):
+            variant_entry = variant.split(",")
+            for i in range(len(variant_entry)):
+                variant_info = variant_entry[i].split("|")
+                consequences = variant_info[info_pos["Consequence"]].split("&")
+                if ("frameshift_variant" in consequences or "stop_gained" in consequences) and len(variant_info[info_pos["Amino_acids"]].split("/")) == 2:
+                    variant_ids[idx] = variant_entry[0].split("|")[0]
+                    # Keep the "transcript:" prefix so the IDs match the FASTA
+                    # headers in process_transcript_data and the [11:] slicing
+                    # used during annotation
+                    transcript_ids[idx] = "transcript:" + variant_info[info_pos["Feature"]]
+                    cons[idx] = consequences
+                    oAA[idx] = variant_info[info_pos["Amino_acids"]].split("/")[0]
+                    nAA[idx] = variant_info[info_pos["Amino_acids"]].split("/")[1]
+                    prot_pos_range = variant_info[info_pos["Protein_position"]].split("/")[0]
+                    if "-" in prot_pos_range:
+                        start, end = map(int, prot_pos_range.split("-"))
+                        prot_pos_start[idx] = start
+                        prot_pos_end[idx] = end
+                    else:
+                        pos = int(prot_pos_range)
+                        prot_pos_start[idx] = pos
+                        prot_pos_end[idx] = pos
+                    idx += 1
+
+    # Trim arrays to actual size
+    variant_ids = variant_ids[:idx]
+    transcript_ids = transcript_ids[:idx]
+    cons = cons[:idx]
+    oAA = oAA[:idx]
+    nAA = nAA[:idx]
+    prot_pos_start = prot_pos_start[:idx]
+    prot_pos_end = prot_pos_end[:idx]
+
+    return vcf_data, variant_ids, transcript_ids, oAA, nAA, prot_pos_start, prot_pos_end, cons
+
+def process_transcript_data(transcript_file, transcript_ids, prot_pos_start, prot_pos_end):
+    """Processes transcript data and creates aa_seq_ref."""
+    with open(transcript_file, "r") as f:
+        transcript_info_entries = f.read().split(">")[1:]
+
+    transcript_info = []
+    transcript_info_id = []
+    for entry in transcript_info_entries:
+        parts = entry.split(" ")
+        transcript_info.append(parts)
+        transcript_id_full = parts[4]
+        # Drop the ENST version so the ID matches the VEP annotation
+        transcript_id = transcript_id_full.split(".")[0]
+        transcript_info_id.append(transcript_id)
+
+    # Preallocate arrays
+    num_transcripts = len(transcript_ids)
+    aa_seq_ref = np.empty(num_transcripts, dtype=object)
+    total_stop_codons = np.zeros(num_transcripts, dtype=int)
+    stop_codons_before_mutation = np.zeros(num_transcripts, dtype=int)
+    stop_codons_in_indel = np.zeros(num_transcripts, dtype=int)
+
+    for j, transcript_id in enumerate(transcript_ids):
+        transcript_found = False
+        for i, info_id in enumerate(transcript_info_id):
+            if info_id == transcript_id:
+                transcript_found = True
+                temp_seq = transcript_info[i][-1].replace("\n", "")
+
+                stop_codons_before_mutation[j] = temp_seq[:prot_pos_start[j]].count("*")
+                stop_codons_in_indel[j] = temp_seq[prot_pos_start[j]:prot_pos_end[j]].count("*")
+                total_stop_codons[j] = temp_seq.count("*")
+
+                aa_seq_ref[j] = temp_seq.replace("*", "")
+                break
+
+        if not transcript_found:
+            aa_seq_ref[j] = "NA"
+            stop_codons_before_mutation[j] = 9999
+            total_stop_codons[j] = 9999
+            stop_codons_in_indel[j] = 9999
+
+    return aa_seq_ref, total_stop_codons, stop_codons_before_mutation, stop_codons_in_indel
+
+def prepare_data_for_esm(aa_seq, transcript_ids, prot_pos_start, stop_codons_before_mutation):
+    """Prepares data for the ESM model."""
+    data = []
+    prot_pos_mod = []
+    for i in range(len(aa_seq)):
+        if aa_seq[i] == "NA":
+            # Keep a placeholder so that `data` and `prot_pos_mod` stay
+            # index-aligned with `conseq`; NA entries are scored as 0 in
+            # calculate_esm_scores and reported as NA during annotation
+            data.append((transcript_ids[i], aa_seq[i]))
+            prot_pos_mod.append(0)
+            continue
+
+        adjusted_pos = prot_pos_start[i] - stop_codons_before_mutation[i]
+        seq_len = len(aa_seq[i])
+
+        if seq_len < WINDOW_SIZE:
+            data.append((transcript_ids[i], aa_seq[i]))
+            prot_pos_mod.append(adjusted_pos)
+        elif adjusted_pos + 1 + WINDOW_SIZE // 2 <= seq_len and adjusted_pos + 1 - WINDOW_SIZE // 2 >= 1:
+            start = adjusted_pos - WINDOW_SIZE // 2
+            end = adjusted_pos + WINDOW_SIZE // 2
+            data.append((transcript_ids[i], aa_seq[i][start:end]))
+            prot_pos_mod.append(WINDOW_SIZE // 2)
+        elif seq_len >= WINDOW_SIZE and adjusted_pos + 1 - WINDOW_SIZE // 2 < 1:
+            data.append((transcript_ids[i], aa_seq[i][:WINDOW_SIZE]))
+            prot_pos_mod.append(adjusted_pos)
+        else:
+            data.append((transcript_ids[i], aa_seq[i][-WINDOW_SIZE:]))
+            prot_pos_mod.append(adjusted_pos - (seq_len - WINDOW_SIZE))
+
+    return data, prot_pos_mod
+
+def calculate_esm_scores(data, prot_pos_mod, modelsToUse, conseq, batch_size=BATCH_SIZE):
+    """Runs the ESM model and calculates scores."""
+    model_scores = []
+    for model_name in modelsToUse:
+        torch.cuda.empty_cache()
+        model, alphabet = pretrained.load_model_and_alphabet(model_name)
+        model.eval()
+        batch_converter = alphabet.get_batch_converter()
+
+        if torch.cuda.is_available():
+            model = model.cuda()
+
+        seq_scores = []
+        for i in range(0, len(data), batch_size):
+            batch_data = data[i:i + batch_size]
+            batch_labels, batch_strs, batch_tokens = batch_converter(batch_data)
+
+            with torch.no_grad():
+                if torch.cuda.is_available():
+                    batch_tokens = batch_tokens.cuda()
+                token_probs = torch.log_softmax(model(batch_tokens)["logits"], dim=-1).cpu()
+
+            for j, (transcript_id, seq) in enumerate(batch_data):
+                idx = i + j
+                if conseq[idx] == "FS":
+                    score = 0
+                    for y, aa in enumerate(seq):
+                        if y < prot_pos_mod[idx]:
+                            score += token_probs[j, y + 1, alphabet.get_idx(aa)].item()
+                        else:
+                            # Past the frameshift: use the median log-probability
+                            # over the 20 standard amino acids (tokens 4-23) plus
+                            # selenocysteine (token 26)
+                            aa_scores = [token_probs[j, y + 1, k].item() for k in range(4, 24)]
+                            aa_scores.append(token_probs[j, y + 1, 26].item())
+                            aa_scores.sort()
+                            mid = len(aa_scores) // 2
+                            median = (aa_scores[mid] + aa_scores[~mid]) / 2
+                            score += median
+                    seq_scores.append(score)
+                elif conseq[idx] == "NA":
+                    seq_scores.append(0)
+
+        model_scores.append(seq_scores)
+
+    return np.array(model_scores)
+
+def annotate_vcf_and_write_output(vcf_data, variant_ids, transcript_ids, np_array_score_diff, modelsToUse, output_file, aa_seq_ref):
+    """Adds scores to the VCF file and writes the output."""
+    header_end = 0
+    for i, line in enumerate(vcf_data):
+        if line.startswith("#CHROM"):
+            vcf_data[i - 1] += '##INFO=<ID=EsmScoreFrameshift,Number=.,Type=String,Description="ESM score for frameshift variants, per transcript, averaged over models">\n'
+            header_end = i
+            break
+
+    vcf_data_modified = vcf_data[:header_end + 1]
+
+    for i in range(header_end + 1, len(vcf_data)):
+        line = vcf_data[i]
+        new_line = line
+        j = 0
+        while j < len(variant_ids):
+            if line.split("|")[0] == variant_ids[j]:
+                # Count the consecutive VEP entries belonging to this variant
+                num_scores = 0
+                for l in range(j, len(variant_ids)):
+                    if line.split("|")[0] == variant_ids[l]:
+                        num_scores += 1
+                    else:
+                        break
+
+                new_line = new_line[:-1] + ";EsmScoreFrameshift" + "=" + new_line[-1:]
+
+                for h in range(num_scores):
+                    if aa_seq_ref[j + h] != "NA":
+                        avg_score = np.mean(np_array_score_diff[:, j + h])
+                        new_line = new_line[:-1] + "{0}|{1:.3f}".format(transcript_ids[j + h][11:], avg_score) + new_line[-1:]
+                    else:
+                        new_line = new_line[:-1] + "{0}|NA".format(transcript_ids[j + h][11:]) + new_line[-1:]
+
+                    if h < num_scores - 1:
+                        new_line = new_line[:-1] + "," + new_line[-1:]
+
+                j += num_scores
+            else:
+                j += 1
+        vcf_data_modified.append(new_line)
+
+    with BgzfWriter(output_file, "w") as vcf_file_output:
+        for line in vcf_data_modified:
+            vcf_file_output.write(line)

 @click.command()
 @click.option(
@@ -72,196 +299,23 @@
     "--batch-size",
     "batch_size",
     type=int,
-    default=20,
+    default=BATCH_SIZE,
     help="Batch size for esm model, default is 20",
 )
 def cli(
     input_file, transcript_file, model_directory, modelsToUse, output_file, batch_size
 ):
+    """Main CLI function."""
     torch.hub.set_dir(model_directory)

-    # get information from vcf file with SNVs and write them into lists (erstmal Bsp, später automatisch aus info zeile extrahieren)
-    vcf_file_data = BgzfReader(input_file, "r")  # TM_example.vcf.gz
-    vcf_data = []
-    for line in vcf_file_data:
-        vcf_data.append(line)
-
-    info_pos_Feature = False  # TranscriptID
-    info_pos_ProteinPosition = False  # resdidue in protein that is mutated
-    info_pos_AA = False  # mutation from aa (amino acid) x to y
-    info_pos_consequence = False
-    # identify positions of annotations importnat for esm score
-    for line in vcf_data:
-        if line[0:7] == "##INFO=":
-            info = line.split("|")
-            for i in range(0, len(info), 1):
-                if info[i] == "Feature":
-                    info_pos_Feature = i
-                if info[i] == "Protein_position":
-                    info_pos_ProteinPosition = i
-                if info[i] == "Amino_acids":
-                    info_pos_AA = i
-                if info[i] == "Consequence":
-                    info_pos_consequence = i
-            break
-
-    # extract annotations important for esm score, "NA" for non-coding variants
-    variant_ids = []
-    transcript_id = []
-    oAA = []
-    nAA = []
-    protPosStart = []
-    protPosEnd = []
-    protPos_mod = []
-    cons = []
-
-    for variant in vcf_data:
-        if variant[0:1] != "#":
-            variant_entry = variant.split(",")
-            for i in range(0, len(variant_entry), 1):
-                variant_info = variant_entry[i].split("|")
-                consequences = variant_info[info_pos_consequence].split("&")
-                if (
"frameshift_variant" in consequences - or "stop_gained" in consequences - ) and len(variant_info[info_pos_AA].split("/")) == 2: - variant_ids.append(variant_entry[0].split("|")[0]) - transcript_id.append("transcript:" + variant_info[info_pos_Feature]) - cons.append(variant_info[info_pos_consequence].split("&")) - oAA.append( - variant_info[info_pos_AA].split("/")[0] - ) # can also be "-" if there is an insertion - nAA.append(variant_info[info_pos_AA].split("/")[1]) - if ( - "-" in variant_info[info_pos_ProteinPosition].split("/")[0] - ): # in case of frameshifts, vep only gives X as the new aa - protPosStart.append( - int( - variant_info[info_pos_ProteinPosition] - .split("/")[0] - .split("-")[0] - ) - ) - protPosEnd.append( - int( - variant_info[info_pos_ProteinPosition] - .split("/")[0] - .split("-")[1] - ) - ) - else: - protPosStart.append( - int(variant_info[info_pos_ProteinPosition].split("/")[0]) - ) - protPosEnd.append( - int(variant_info[info_pos_ProteinPosition].split("/")[0]) - ) - protPos_mod.append(False) - - # dissect file with all aa seqs to entries - transcript_data = open( - transcript_file, "r" - ) # - transcript_info_entries = transcript_data.read().split( - ">" - ) # evtl erstes > in file weglöschen - transcript_data.close() - transcript_info = [] - transcript_info_id = [] - - # transcript info contains aa seqs, becomes processed later - for i in range(0, len(transcript_info_entries), 1): - if transcript_info_entries[i] != "": - transcript_info.append(transcript_info_entries[i].split(" ")) - - # transcript ids - for i in range(0, len(transcript_info_entries), 1): - if transcript_info_entries[i] != "": - transcript_info_tmp = transcript_info_entries[i].split(" ")[4] - pointAt = False - # remove version of ENST ID vor comparison with vep annotation - for p in range(0, len(transcript_info_tmp), 1): - if transcript_info_tmp[p] == ".": - pointAt = p - - transcript_info_tmp = transcript_info_tmp[:pointAt] - transcript_info_id.append(transcript_info_tmp) - if (len(transcript_info_id)) != len(transcript_info): - print("ERROR!!!!!!") - - # create list with aa_seq_refs of transcript_ids, mal gucken, ob man alle auf einmal uebergebenkann an esm model - aa_seq_ref = [] - totalNumberOfStopCodons = [] - numberOfStopCodons = [] - numberOfStopCodonsInIndel = [] - for j in range(0, len(transcript_id), 1): - transcript_found = False - for i in range( - 1, len(transcript_info_id), 1 - ): # start bei 1 statt 0 weil das inputfile mit ">" anfaengt und 0. 
element in aa_seq_ref einfach [] ist - if transcript_info_id[i] == transcript_id[j]: - transcript_found = True - # prepare Seq remove remainings of header - temp_seq = transcript_info[i][-1] - for k in range(0, len(temp_seq), 1): - if temp_seq[k] != "\n": - k = k + 1 - else: - k = k + 1 - temp_seq = temp_seq[k:] - break - - # prepare seq (remove /n) - forbidden_chars = "\n" - for char in forbidden_chars: - temp_seq = temp_seq.replace(char, "") - - # count stop codons in seq before site of mutation - numberOfStopCodons.append(0) - if "*" in temp_seq: - for k in range(0, len(temp_seq), 1): - if temp_seq[k] == "*" and k < protPosStart[j]: - numberOfStopCodons[j] = numberOfStopCodons[j] + 1 - - # count stop codons in Indel - numberOfStopCodonsInIndel.append(0) - if "*" in temp_seq: - for k in range(0, len(temp_seq), 1): - if ( - temp_seq[k] == "*" - and k >= protPosStart[j] - and k < protPosEnd[j] - ): - numberOfStopCodonsInIndel[j] = ( - numberOfStopCodonsInIndel[j] + 1 - ) - - # count stop codons in seq - totalNumberOfStopCodons.append(0) - if "*" in temp_seq: - for k in range(0, len(temp_seq), 1): - if temp_seq[k] == "*": - totalNumberOfStopCodons[j] = totalNumberOfStopCodons[j] + 1 - - # remove additional stop codons (remove *) - forbidden_chars = "*" - for char in forbidden_chars: - temp_seq = temp_seq.replace(char, "") - - aa_seq_ref.append(temp_seq) - if transcript_found == False: - aa_seq_ref.append("NA") - numberOfStopCodons.append(9999) - totalNumberOfStopCodons.append(9999) - numberOfStopCodonsInIndel.append(9999) + vcf_data, variant_ids, transcript_ids, oAA, nAA, prot_pos_start, prot_pos_end, cons = read_and_extract_vcf_data(input_file) + aa_seq_ref, total_stop_codons, stop_codons_before_mutation, stop_codons_in_indel = process_transcript_data( + transcript_file, transcript_ids, prot_pos_start, prot_pos_end + ) conseq = [] aa_seq_alt = [] - for j in range(0, len(aa_seq_ref), 1): - # print(nAA[j]) - # print(oAA[j]) - # print("\n") - + for j in range(0, len(aa_seq_ref)): if aa_seq_ref[j] == "NA": aa_seq_alt.append("NA") conseq.append("NA") @@ -273,268 +327,17 @@ def cli( conseq.append("NA") warnings.warn( "there is a problem with the ensembl data base and vep. The ESMframesift score of this variant will be artificially set to 0. 
Affected transcript is " - + str(transcript_id[j]) - ) - - # prepare data array for esm model - - window = 250 - data_ref = [] - for i in range(0, len(transcript_id), 1): - if len(aa_seq_ref[i]) < window: - data_ref.append((transcript_id[i], aa_seq_ref[i])) - protPos_mod[i] = protPosStart[i] - numberOfStopCodons[i] - - elif ( - (len(aa_seq_ref[i]) >= window) - and ( - protPosStart[i] - numberOfStopCodons[i] + 1 + window / 2 - <= len(aa_seq_ref[i]) - ) - and (protPosStart[i] - numberOfStopCodons[i] + 1 - window / 2 >= 1) - ): - data_ref.append( - ( - transcript_id[i], - aa_seq_ref[i][ - protPosStart[i] - - numberOfStopCodons[i] - - int(window / 2) : protPosStart[i] - - numberOfStopCodons[i] - + int(window / 2) - ], - ) - ) # esm model can only handle 1024 amino acids, so if the sequence is longer , just the sequece around the mutaion i - protPos_mod[i] = int( - len( - aa_seq_ref[i][ - protPosStart[i] - - numberOfStopCodons[i] - - int(window / 2) : protPosStart[i] - - numberOfStopCodons[i] - + int(window / 2) - ] - ) - / 2 + + str(transcript_ids[j]) ) - elif ( - len(aa_seq_ref[i]) >= window - and protPosStart[i] - numberOfStopCodons[i] + 1 - window / 2 < 1 - ): - data_ref.append((transcript_id[i], aa_seq_ref[i][:window])) - protPos_mod[i] = protPosStart[i] - numberOfStopCodons[i] - - else: - data_ref.append((transcript_id[i], aa_seq_ref[i][-window:])) - protPos_mod[i] = ( - protPosStart[i] - numberOfStopCodons[i] - (len(aa_seq_ref[i]) - window) - ) - - data_alt = [] - - for i in range(0, len(transcript_id), 1): - if len(aa_seq_alt[i]) < window: - data_alt.append((transcript_id[i], aa_seq_alt[i])) - - elif ( - (len(aa_seq_alt[i]) >= window) - and ( - protPosStart[i] - numberOfStopCodons[i] + 1 + window / 2 - <= len(aa_seq_alt[i]) - ) - and (protPosStart[i] - numberOfStopCodons[i] + 1 - window / 2 >= 1) - ): - data_alt.append( - ( - transcript_id[i], - aa_seq_alt[i][ - protPosStart[i] - - numberOfStopCodons[i] - - int(window / 2) : protPosStart[i] - - numberOfStopCodons[i] - + int(window / 2) - ], - ) - ) # esm model can only handle 1024 amino acids, so if the sequence is longer , just the sequece around the mutaion i - - elif ( - len(aa_seq_alt[i]) >= window - and protPosStart[i] - numberOfStopCodons[i] + 1 - window / 2 < 1 - ): - data_alt.append((transcript_id[i], aa_seq_alt[i][:window])) - - else: - data_alt.append((transcript_id[i], aa_seq_alt[i][-window:])) - - ref_alt_scores = [] - # load esm model(s) - for o in range(0, len([data_ref, data_alt]), 1): - data = [data_ref, data_alt][o] - modelScores = [] # scores of different models - if len(data) >= 1: - for k in range(0, len(modelsToUse), 1): - torch.cuda.empty_cache() - model, alphabet = pretrained.load_model_and_alphabet(modelsToUse[k]) - model.eval() # disables dropout for deterministic results - batch_converter = alphabet.get_batch_converter() - - if torch.cuda.is_available(): - model = model.cuda() - # print("transferred to GPU") - - # apply es model to sequence, tokenProbs hat probs von allen aa an jeder pos basierend auf der seq in "data" - seq_scores = [] - for t in range(0, len(data), batch_size): - if t + batch_size > len(data): - batch_data = data[t:] - else: - batch_data = data[t : t + batch_size] - - batch_labels, batch_strs, batch_tokens = batch_converter(batch_data) - with torch.no_grad(): # setzt irgeineine flag auf false - if torch.cuda.is_available(): - token_probs = torch.log_softmax( - model(batch_tokens.cuda())["logits"], dim=-1 - ) - else: - token_probs = torch.log_softmax( - model(batch_tokens)["logits"], dim=-1 - 
) - - # test and extract scores from tokenProbs - if o == 1: # alt seqences - for i in range(0, len(batch_data), 1): - # print (str(t+i)+" of "+ str(len(data))+ "alt seqs") - if conseq[i + t] == "FS": - score = 0 - for y in range( - 0, len(batch_data[i][1]), 1 - ): # iterating over single AA in sequence - if y < protPos_mod[i + t]: - score = ( - score - + token_probs[ - i, - y + 1, - alphabet.get_idx(batch_data[i][1][y]), - ] - ) - else: - # calc mean of all possible aa at this position - aa_scores = [] - for k in range(4, 24, 1): - aa_scores.append( - token_probs[i, y + 1, k] - ) # for all aa (except selenocystein) - aa_scores.append( - token_probs[i, y + 1, 26] - ) # for selenocystein - aa_scores.sort() - mid = len(aa_scores) // 2 - median = (aa_scores[mid] + aa_scores[~mid]) / 2 - score = score + median - - seq_scores.append(float(score)) - elif conseq[i + t] == "NA": - score = 0 - seq_scores.append(float(score)) - elif o == 0: # ref sequences - for i in range(0, len(batch_data), 1): - if conseq[i + t] == "FS": - score = 0 - for y in range( - 0, len(batch_data[i][1]), 1 - ): # iterating over single AA in sequence - score = ( - score - + token_probs[ - i, - y + 1, - alphabet.get_idx(batch_data[i][1][y]), - ] - ) - seq_scores.append(float(score)) - elif conseq[i + t] == "NA": - score = 999 # sollte nacher rausgeschissen werden, kein score sollte -999 sein - seq_scores.append(float(score)) - - modelScores.append(seq_scores) - ref_alt_scores.append(modelScores) - - np_array_scores = np.array(ref_alt_scores) - np_array_score_diff = np_array_scores[1] - np_array_scores[0] - - # write scores in cvf. file - - # get information from vcf file with SNVs and write them into lists (erstmal Bsp, später automatisch aus info zeile extrahieren) - - # identify positions of annotations important for esm score - header_end = False - for i in range(0, len(vcf_data), 1): - if vcf_data[i][0:6] == "#CHROM": - vcf_data[i - 1] = ( - vcf_data[i - 1] - + "##INFO=\n' - ) - header_end = i - break - - for i in range(header_end + 1, len(vcf_data), 1): - j = 0 - while j < len(variant_ids): - if vcf_data[i].split("|")[0] == variant_ids[j]: - # count number of vep entires per variant that result in an esm score (i.e. 
with consequence "missense") - numberOfEsmScoresPerVariant = 0 - for l in range(j, len(variant_ids), 1): - if vcf_data[i].split("|")[0] == variant_ids[l]: - numberOfEsmScoresPerVariant = numberOfEsmScoresPerVariant + 1 - else: - break - - # annotate vcf line with esm scores - # for k in range (0, len(modelsToUse), 1): - vcf_data[i] = ( - vcf_data[i][:-1] + ";EsmScoreFrameshift" + "=" + vcf_data[i][-1:] - ) - for h in range(0, numberOfEsmScoresPerVariant, 1): - if aa_seq_ref[j + h] != "NA": - average_score = 0 - for k in range(0, len(modelsToUse), 1): - average_score = average_score + float( - np_array_score_diff[k][j + h] - ) - average_score = average_score / len(modelsToUse) - vcf_data[i] = ( - vcf_data[i][:-1] - + str(transcript_id[j + h][11:]) - + "|" - + str(round(float(average_score), 3)) - + vcf_data[i][-1:] - ) - else: - vcf_data[i] = ( - vcf_data[i][:-1] - + str(transcript_id[j + h][11:]) - + "|" - + "NA" - + vcf_data[i][-1:] - ) - - if h != numberOfEsmScoresPerVariant - 1: - vcf_data[i] = vcf_data[i][:-1] + "," + vcf_data[i][-1:] - - j = j + numberOfEsmScoresPerVariant - else: - j = j + 1 - - vcf_file_output = BgzfWriter(output_file, "w") - for line in vcf_data: - vcf_file_output.write(line) + data_ref, prot_pos_mod_ref = prepare_data_for_esm(aa_seq_ref, transcript_ids, prot_pos_start, stop_codons_before_mutation) + data_alt, prot_pos_mod_alt = prepare_data_for_esm(aa_seq_alt, transcript_ids, prot_pos_start, stop_codons_before_mutation) - vcf_file_output.close() + ref_scores = calculate_esm_scores(data_ref, prot_pos_mod_ref, modelsToUse, conseq, batch_size) + alt_scores = calculate_esm_scores(data_alt, prot_pos_mod_alt, modelsToUse, conseq, batch_size) + np_array_score_diff = alt_scores - ref_scores + annotate_vcf_and_write_output(vcf_data, variant_ids, transcript_ids, np_array_score_diff, modelsToUse, output_file, aa_seq_ref) if __name__ == "__main__": - cli() + cli() \ No newline at end of file diff --git a/src/scripts/lib/tools/esmScore/esmScore_inFrame_av.py b/src/scripts/lib/tools/esmScore/esmScore_inFrame_av.py index a8ca043..93bf320 100644 --- a/src/scripts/lib/tools/esmScore/esmScore_inFrame_av.py +++ b/src/scripts/lib/tools/esmScore/esmScore_inFrame_av.py @@ -8,11 +8,12 @@ allele using the Ensembl VEP tools' annotations and the reference sequence. To calculate scores for inframe InDel variants, log transformed probabilities of the entire reference and alternative sequences were added up, respectively, and substracted from each other, yielding log odds ratios. The log odds ratios resulting from each of the five models were than averaged and used as final score. -Author: thorben Maass +Author: Thorben Maass, Max Schubach Contact: tho.maass@uni-luebeck.de Year:2023 -""" +Refractored by yangyxt, replacing list appending with numpy array storing. Improving computation speed by 20x when dealing with large amount of variants. 
+""" import numpy as np from Bio.bgzf import BgzfReader, BgzfWriter @@ -20,6 +21,202 @@ from esm import pretrained import click +# Constants +WINDOW_SIZE = 250 +BATCH_SIZE = 20 + +def read_and_extract_vcf_data(input_file): + """Reads the VCF file and extracts relevant information.""" + vcf_data = [] + with BgzfReader(input_file, "r") as vcf_file: + for line in vcf_file: + vcf_data.append(line) + + info_pos = {} + for line in vcf_data: + if line.startswith("##INFO="): + info = line.split("|") + for i, item in enumerate(info): + if item in ("Feature", "Protein_position", "Amino_acids", "Consequence"): + info_pos[item] = i + if len(info_pos) == 4: + break + + # Preallocate NumPy arrays + num_variants = sum(1 for line in vcf_data if not line.startswith("#")) + variant_ids = np.empty(num_variants, dtype=object) + transcript_ids = np.empty(num_variants, dtype=object) + oAA = np.empty(num_variants, dtype=object) + nAA = np.empty(num_variants, dtype=object) + prot_pos_start = np.empty(num_variants, dtype=int) + prot_pos_end = np.empty(num_variants, dtype=int) + cons = np.empty(num_variants, dtype=object) + + idx = 0 + for variant in vcf_data: + if not variant.startswith("#"): + variant_entry = variant.split(",") + for i in range(len(variant_entry)): + variant_info = variant_entry[i].split("|") + consequences = variant_info[info_pos["Consequence"]].split("&") + if ("inframe_insertion" in consequences or "inframe_deletion" in consequences or "missense_variant" in consequences) and len(variant_info[info_pos["Amino_acids"]].split("/")) == 2: + variant_ids[idx] = variant_entry[0].split("|")[0] + transcript_ids[idx] = variant_info[info_pos["Feature"]].split(".")[0] + cons[idx] = consequences + oAA[idx] = variant_info[info_pos["Amino_acids"]].split("/")[0] + nAA[idx] = variant_info[info_pos["Amino_acids"]].split("/")[1] + prot_pos_range = variant_info[info_pos["Protein_position"]].split("/")[0] + if "-" in prot_pos_range: + start, end = map(int, prot_pos_range.split("-")) + prot_pos_start[idx] = start + prot_pos_end[idx] = end + else: + pos = int(prot_pos_range) + prot_pos_start[idx] = pos + prot_pos_end[idx] = pos + idx += 1 + + # Trim arrays to actual size + variant_ids = variant_ids[:idx] + transcript_ids = transcript_ids[:idx] + cons = cons[:idx] + oAA = oAA[:idx] + nAA = nAA[:idx] + prot_pos_start = prot_pos_start[:idx] + prot_pos_end = prot_pos_end[:idx] + + return vcf_data, variant_ids, transcript_ids, oAA, nAA, prot_pos_start, prot_pos_end, cons + +def process_transcript_data(transcript_file, transcript_ids, prot_pos_start, prot_pos_end): + """Processes transcript data and creates aa_seq_ref.""" + with open(transcript_file, "r") as f: + transcript_info_entries = f.read().split(">")[1:] + + transcript_info = [] + transcript_info_id = [] + for entry in transcript_info_entries: + parts = entry.split(" ") + transcript_info.append(parts) + transcript_id_full = parts[4] + transcript_id = transcript_id_full.split(".")[0] + transcript_info_id.append(transcript_id) + + # Preallocate arrays + num_transcripts = len(transcript_ids) + aa_seq_ref = np.empty(num_transcripts, dtype=object) + total_stop_codons = np.zeros(num_transcripts, dtype=int) + stop_codons_before_mutation = np.zeros(num_transcripts, dtype=int) + stop_codons_in_indel = np.zeros(num_transcripts, dtype=int) + + for j, transcript_id in enumerate(transcript_ids): + transcript_found = False + for i, info_id in enumerate(transcript_info_id): + if info_id == transcript_id: + transcript_found = True + temp_seq = transcript_info[i][-1].replace("\n", 
"") + + stop_codons_before_mutation[j] = temp_seq[:prot_pos_start[j]].count("*") + stop_codons_in_indel[j] = temp_seq[prot_pos_start[j]:prot_pos_end[j]].count("*") + total_stop_codons[j] = temp_seq.count("*") + + aa_seq_ref[j] = temp_seq.replace("*", "") + break + + if not transcript_found: + aa_seq_ref[j] = "NA" + stop_codons_before_mutation[j] = 9999 + total_stop_codons[j] = 9999 + stop_codons_in_indel[j] = 9999 + + return aa_seq_ref, total_stop_codons, stop_codons_before_mutation, stop_codons_in_indel + +def prepare_data_for_esm(aa_seq, transcript_ids, prot_pos_start, stop_codons_before_mutation, window_size): + """Prepares data for ESM, handling windowing and edge cases.""" + data = [] + prot_pos_mod = np.copy(prot_pos_start) + + for i, seq in enumerate(aa_seq): + if seq == "NA": + continue + + if len(seq) < window_size: + data.append((transcript_ids[i], seq)) + prot_pos_mod[i] -= stop_codons_before_mutation[i] + else: + start = prot_pos_start[i] - stop_codons_before_mutation[i] + if start + 1 + window_size // 2 <= len(seq) and start + 1 - window_size // 2 >= 1: + seq_window = seq[start - window_size // 2 : start + window_size // 2] + data.append((transcript_ids[i], seq_window)) + prot_pos_mod[i] = len(seq_window) // 2 + elif start + 1 - window_size // 2 < 1: + data.append((transcript_ids[i], seq[:window_size])) + prot_pos_mod[i] = start + else: + data.append((transcript_ids[i], seq[-window_size:])) + prot_pos_mod[i] = start - (len(seq) - window_size) + + return data, prot_pos_mod + +def calculate_esm_scores(data, prot_pos_mod, modelsToUse, conseq, batch_size): + """Calculates ESM scores for a given dataset.""" + model_scores = [] + for model_name in modelsToUse: + model, alphabet = pretrained.load_model_and_alphabet(model_name) + batch_converter = alphabet.get_batch_converter() + model.eval() + if torch.cuda.is_available(): + model.cuda() + + seq_scores = [] + for i in range(0, len(data), batch_size): + batch_end = min(i + batch_size, len(data)) + batch_data = data[i:batch_end] + batch_labels, batch_strs, batch_tokens = batch_converter(batch_data) + + with torch.no_grad(): + if torch.cuda.is_available(): + batch_tokens = batch_tokens.cuda() + token_probs = torch.log_softmax(model(batch_tokens)["logits"], dim=-1).cpu() + + for j, (_, seq) in enumerate(batch_data): + score = 0 + if conseq[i + j] in ["inFrame", "MultiMissense"]: + for y, aa in enumerate(seq): + score += token_probs[j, y + 1, alphabet.get_idx(aa)].item() + elif conseq[i + j] == "NA": + score = 999 + seq_scores.append(score) + + model_scores.append(seq_scores) + + return np.array(model_scores) + +def annotate_vcf_and_write_output(vcf_data, variant_ids, transcript_ids, score_diff, modelsToUse, output_file, aa_seq_ref): + """Annotates the VCF with ESM scores and writes the output.""" + header_end = 0 + for i, line in enumerate(vcf_data): + if line.startswith("#CHROM"): + vcf_data[i - 1] += "##INFO=\n" + header_end = i + break + + vcf_output = BgzfWriter(output_file, "w") + for line in vcf_data[:header_end + 1]: + vcf_output.write(line) + + for i, line in enumerate(vcf_data[header_end + 1:]): + j = 0 + while j < len(variant_ids): + if line.split("|")[0] == variant_ids[j]: + if aa_seq_ref[j] != "NA": + score_string = [str(round(score_diff[m][j], 3)) for m in range(len(modelsToUse))] + line = line.strip() + ";EsmScoreInFrame=" + ",".join(score_string) + "\n" + else: + line = line.strip() + ";EsmScoreInFrame=NA\n" + break + j += 1 + vcf_output.write(line) + vcf_output.close() @click.command() @click.option( @@ -70,559 +267,88 @@ 
"--batch-size", "batch_size", type=int, - default=20, + default=BATCH_SIZE, help="Batch size for esm model, default is 20", ) -def cli(input_file, transcript_file, model_directory, modelsToUse, output_file, batch_size): +def cli( + input_file, transcript_file, model_directory, modelsToUse, output_file, batch_size +): + """Main CLI function.""" torch.hub.set_dir(model_directory) - # get information from vcf file with SNVs and write them into lists - vcf_file_data = BgzfReader(input_file, "r") # TM_example.vcf.gz - vcf_data = [] - for line in vcf_file_data: - vcf_data.append(line) - - info_pos_Feature = False # TranscriptID - info_pos_ProteinPosition = False # resdidue in protein that is mutated - info_pos_AA = False # mutation from aa (amino acid) x to y - info_pos_consequence = False - # identify positions of annotations importnat for esm score - for line in vcf_data: - if line[0:7] == "##INFO=": - info = line.split("|") - for i in range(0, len(info), 1): - if info[i] == "Feature": - info_pos_Feature = i - if info[i] == "Protein_position": - info_pos_ProteinPosition = i - if info[i] == "Amino_acids": - info_pos_AA = i - if info[i] == "Consequence": - info_pos_consequence = i - break - - # extract annotations important for esm score, "NA" for non-coding variants - variant_ids = [] - transcript_id = [] - oAA = [] - nAA = [] - protPosStart = [] - protPosEnd = [] - protPos_mod = [] - cons = [] - # protPos_mod=[]#falls protein laenger als 1024 aa - - for variant in vcf_data: - if variant[0:1] != "#": - variant_entry = variant.split(",") - for i in range(0, len(variant_entry), 1): - variant_info = variant_entry[i].split("|") - consequences = variant_info[info_pos_consequence].split("&") - if ( - ( - "inframe_insertion" in consequences - or "inframe_deletion" in consequences - ) - and len(variant_info[info_pos_AA].split("/")) == 2 - and "stop_gained" not in consequences - and "stop_lost" not in consequences - and "stop_retained_variant" not in consequences - ): - variant_ids.append(variant_entry[0].split("|")[0]) - transcript_id.append("transcript:" + variant_info[info_pos_Feature]) - cons.append(variant_info[info_pos_consequence].split("&")) - - oAA.append( - variant_info[info_pos_AA].split("/")[0] - ) # can also be "-" if there is an insertion - nAA.append(variant_info[info_pos_AA].split("/")[1]) - if ( - "-" in variant_info[info_pos_ProteinPosition].split("/")[0] - ): # in case of frameshifts, vep only gives X as the new aa - protPosStart.append( - int( - variant_info[info_pos_ProteinPosition] - .split("/")[0] - .split("-")[0] - ) - ) - protPosEnd.append( - int( - variant_info[info_pos_ProteinPosition] - .split("/")[0] - .split("-")[1] - ) - ) - else: - protPosStart.append( - int(variant_info[info_pos_ProteinPosition].split("/")[0]) - ) - protPosEnd.append( - int(variant_info[info_pos_ProteinPosition].split("/")[0]) - ) - protPos_mod.append(False) - elif ( - "missense_variant" in consequences - and "-" in variant_info[info_pos_ProteinPosition].split("/")[0] - and len(variant_info[info_pos_AA].split("/")) == 2 - and "stop_gained" not in consequences - and "stop_lost" not in consequences - and "stop_retained_variant" not in consequences - ): - variant_ids.append(variant_entry[0].split("|")[0]) - transcript_id.append("transcript:" + variant_info[info_pos_Feature]) - cons.append(variant_info[info_pos_consequence].split("&")) - oAA.append( - variant_info[info_pos_AA].split("/")[0] - ) # can also be "-" if there is an insertion - nAA.append(variant_info[info_pos_AA].split("/")[1]) - 
protPosStart.append( - int( - variant_info[info_pos_ProteinPosition] - .split("/")[0] - .split("-")[0] - ) - ) - protPosEnd.append( - int( - variant_info[info_pos_ProteinPosition] - .split("/")[0] - .split("-")[1] - ) - ) - protPos_mod.append(False) - - # dissect file with all aa seqs to entries - transcript_data = open( - transcript_file, "r" - ) # - transcript_info_entries = transcript_data.read().split( - ">" - ) # evtl erstes > in file weglöschen - transcript_data.close() - transcript_info = [] - transcript_info_id = [] - - # transcript info contains aa seqs, becomes processed later - for i in range(0, len(transcript_info_entries), 1): - if transcript_info_entries[i] != "": - transcript_info.append(transcript_info_entries[i].split(" ")) - - # transcript ids - for i in range(0, len(transcript_info_entries), 1): - if transcript_info_entries[i] != "": - transcript_info_tmp = transcript_info_entries[i].split(" ")[4] - pointAt = False - # remove version of ENST ID vor comparison with vep annotation - for p in range(0, len(transcript_info_tmp), 1): - if transcript_info_tmp[p] == ".": - pointAt = p - - transcript_info_tmp = transcript_info_tmp[:pointAt] - - transcript_info_id.append(transcript_info_tmp) - - if (len(transcript_info_id)) != len(transcript_info): - print("ERROR!!!!!!") - - # create list with aa_seq_refs of transcript_ids, mal gucken, ob man alle auf einmal uebergebenkann an esm model - aa_seq_ref = [] - totalNumberOfStopCodons = [] - numberOfStopCodons = [] - numberOfStopCodonsInIndel = [] - for j in range(0, len(transcript_id), 1): - transcript_found = False - for i in range( - 1, len(transcript_info_id), 1 - ): # start bei 1 statt 0 weil das inputfile mit ">" anfaengt und 0. element in aa_seq_ref einfach [] ist - if transcript_info_id[i] == transcript_id[j]: # -2 damit ".9" usw wegfaellt - transcript_found = True - # prepare Seq remove remainings of header - temp_seq = transcript_info[i][-1] - for k in range(0, len(temp_seq), 1): - if temp_seq[k] != "\n": - k = k + 1 - else: - k = k + 1 - temp_seq = temp_seq[k:] - break - - # prepare seq (remove /n) - forbidden_chars = "\n" - for char in forbidden_chars: - temp_seq = temp_seq.replace(char, "") - - # count stop codons in seq before site of mutation - numberOfStopCodons.append(0) - if "*" in temp_seq: - for k in range(0, len(temp_seq), 1): - if temp_seq[k] == "*" and k < protPosStart[j]: - numberOfStopCodons[j] = numberOfStopCodons[j] + 1 - - # count stop codons in Indel - numberOfStopCodonsInIndel.append(0) - if "*" in temp_seq: - for k in range(0, len(temp_seq), 1): - if ( - temp_seq[k] == "*" - and k >= protPosStart[j] - and k < protPosEnd[j] - ): - numberOfStopCodonsInIndel[j] = ( - numberOfStopCodonsInIndel[j] + 1 - ) - - # count stop codons in seq - totalNumberOfStopCodons.append(0) - if "*" in temp_seq: - for k in range(0, len(temp_seq), 1): - if temp_seq[k] == "*": - totalNumberOfStopCodons[j] = totalNumberOfStopCodons[j] + 1 - - # remove additional stop codons (remove *) - forbidden_chars = "*" - for char in forbidden_chars: - temp_seq = temp_seq.replace(char, "") - - aa_seq_ref.append(temp_seq) - if transcript_found == False: - aa_seq_ref.append("NA") - numberOfStopCodons.append(9999) - totalNumberOfStopCodons.append(9999) - numberOfStopCodonsInIndel.append(9999) + vcf_data, variant_ids, transcript_ids, oAA, nAA, prot_pos_start, prot_pos_end, cons = read_and_extract_vcf_data(input_file) + aa_seq_ref, total_stop_codons, stop_codons_before_mutation, stop_codons_in_indel = process_transcript_data( + transcript_file, 
transcript_ids, prot_pos_start, prot_pos_end + ) conseq = [] - aa_seq_alt = [] + aa_seq_alt = np.empty(len(aa_seq_ref), dtype=object) for j in range(0, len(aa_seq_ref), 1): if aa_seq_ref[j] == "NA": - aa_seq_alt.append("NA") + aa_seq_alt[j] = "NA" conseq.append("NA") - - elif ( - len(nAA[j]) == len(oAA[j]) and "-" not in oAA[j] and "-" not in nAA[j] - ): # inframe ins wenn gleich viele weg kommen wie dazu ommen (zB AAA/GGG) - nAA_mod = nAA[j].replace( - "*", "" - ) # falls oAA und nAA ein nicht terminales stopp codon haben (A*G/P*K) - aa_seq_alt.append( - aa_seq_ref[j][0 : protPosStart[j] - numberOfStopCodons[j] - 1] + elif len(nAA[j]) == len(oAA[j]) and "-" not in oAA[j] and "-" not in nAA[j]: + nAA_mod = nAA[j].replace("*", "") + aa_seq_alt[j] = ( + aa_seq_ref[j][: prot_pos_start[j] - stop_codons_before_mutation[j] - 1] + nAA_mod + aa_seq_ref[j][ - protPosEnd[j] - - numberOfStopCodons[j] - - numberOfStopCodonsInIndel[j] : + prot_pos_end[j] + - stop_codons_before_mutation[j] + - stop_codons_in_indel[j] : ] ) conseq.append("MultiMissense") - - elif ( - len(nAA[j]) >= len(oAA[j]) and "-" in oAA[j] - ): # inframe ins wenn keine alte aa weg kommt (zB -/GP) - aa_seq_alt.append( - aa_seq_ref[j][0 : protPosStart[j] - numberOfStopCodons[j] - 1] + elif len(nAA[j]) >= len(oAA[j]) and "-" in oAA[j]: + aa_seq_alt[j] = ( + aa_seq_ref[j][: prot_pos_start[j] - stop_codons_before_mutation[j] - 1] + nAA[j] - + aa_seq_ref[j][protPosStart[j] - numberOfStopCodons[j] - 1 :] + + aa_seq_ref[j][prot_pos_start[j] - stop_codons_before_mutation[j] - 1 :] ) conseq.append("inFrame") - - elif len(nAA[j]) > len( - oAA[j] - ): # inframe ins wenn alte aa zerstoert wird (zB Q/PE) - nAA_mod = nAA[j].replace( - "*", "" - ) # falls aa vor eigentlichen stopp codon eingefuegt wird und altes stopp dabei zerstoert wird, zaehlt auch als inframe aber ohne stop gain (zB */Y*) - aa_seq_alt.append( - aa_seq_ref[j][0 : protPosStart[j] - numberOfStopCodons[j] - 1] + elif len(nAA[j]) > len(oAA[j]): + nAA_mod = nAA[j].replace("*", "") + aa_seq_alt[j] = ( + aa_seq_ref[j][: prot_pos_start[j] - stop_codons_before_mutation[j] - 1] + nAA_mod + aa_seq_ref[j][ - protPosEnd[j] - - numberOfStopCodons[j] - - numberOfStopCodonsInIndel[j] : + prot_pos_end[j] + - stop_codons_before_mutation[j] + - stop_codons_in_indel[j] : ] ) conseq.append("inFrame") - - elif ( - len(nAA[j]) <= len(oAA[j]) and "-" in nAA[j] - ): # inframe deletion wenn alte aa nicht zerstoert wird (zB QQ/-) - aa_seq_alt.append( - aa_seq_ref[j][0 : protPosStart[j] - numberOfStopCodons[j] - 1] + elif len(nAA[j]) <= len(oAA[j]) and "-" in nAA[j]: + aa_seq_alt[j] = ( + aa_seq_ref[j][: prot_pos_start[j] - stop_codons_before_mutation[j] - 1] + aa_seq_ref[j][ - protPosEnd[j] - - numberOfStopCodonsInIndel[j] - - numberOfStopCodons[j] : + prot_pos_end[j] + - stop_codons_in_indel[j] + - stop_codons_before_mutation[j] : ] ) conseq.append("inFrame") - - elif len(nAA[j]) < len( - oAA[j] - ): # inframe deletion wenn alte aa zerstoert wird (zB KE/K) - nAA_mod = nAA[j].replace( - "*", "" - ) # falls alte aa und stop zerstoert wird und neues stopp eingefuegt wird, wird dann von vep nicht als stop gained bezeichnet (Y*/*) - aa_seq_alt.append( - aa_seq_ref[j][0 : protPosStart[j] - numberOfStopCodons[j] - 1] + elif len(nAA[j]) < len(oAA[j]): + nAA_mod = nAA[j].replace("*", "") + aa_seq_alt[j] = ( + aa_seq_ref[j][: prot_pos_start[j] - stop_codons_before_mutation[j] - 1] + nAA_mod + aa_seq_ref[j][ - protPosEnd[j] - - numberOfStopCodons[j] - - numberOfStopCodonsInIndel[j] : + prot_pos_end[j] + - 
stop_codons_before_mutation[j] + - stop_codons_in_indel[j] : ] ) conseq.append("inFrame") - # prepare data array for esm model, Problem: only give the coding sequences i - - window = 250 - data_ref = [] - for i in range(0, len(transcript_id), 1): - if len(aa_seq_ref[i]) < window: - data_ref.append((transcript_id[i], aa_seq_ref[i])) - protPos_mod[i] = protPosStart[i] - numberOfStopCodons[i] - - elif ( - (len(aa_seq_ref[i]) >= window) - and ( - protPosStart[i] - numberOfStopCodons[i] + 1 + window / 2 - <= len(aa_seq_ref[i]) - ) - and (protPosStart[i] - numberOfStopCodons[i] + 1 - window / 2 >= 1) - ): - data_ref.append( - ( - transcript_id[i], - aa_seq_ref[i][ - protPosStart[i] - - numberOfStopCodons[i] - - int(window / 2) : protPosStart[i] - - numberOfStopCodons[i] - + int(window / 2) - ], - ) - ) # esm model can only handle 1024 amino acids, so if the sequence is longer , just the sequece around the mutaion i - protPos_mod[i] = int( - len( - aa_seq_ref[i][ - protPosStart[i] - - numberOfStopCodons[i] - - int(window / 2) : protPosStart[i] - - numberOfStopCodons[i] - + int(window / 2) - ] - ) - / 2 - ) - - elif ( - len(aa_seq_ref[i]) >= window - and protPosStart[i] - numberOfStopCodons[i] + 1 - window / 2 < 1 - ): - data_ref.append((transcript_id[i], aa_seq_ref[i][:window])) - protPos_mod[i] = protPosStart[i] - numberOfStopCodons[i] - - else: - data_ref.append((transcript_id[i], aa_seq_ref[i][-window:])) - protPos_mod[i] = ( - protPosStart[i] - numberOfStopCodons[i] - (len(aa_seq_ref[i]) - window) - ) - - data_alt = [] - - for i in range(0, len(transcript_id), 1): - if len(aa_seq_alt[i]) < window: - data_alt.append((transcript_id[i], aa_seq_alt[i])) - - elif ( - (len(aa_seq_alt[i]) >= window) - and ( - protPosStart[i] - numberOfStopCodons[i] + 1 + window / 2 - <= len(aa_seq_alt[i]) - ) - and (protPosStart[i] - numberOfStopCodons[i] + 1 - window / 2 >= 1) - ): - data_alt.append( - ( - transcript_id[i], - aa_seq_alt[i][ - protPosStart[i] - - numberOfStopCodons[i] - - int(window / 2) : protPosStart[i] - - numberOfStopCodons[i] - + int(window / 2) - ], - ) - ) # esm model can only handle 1024 amino acids, so if the sequence is longer , just the sequece around the mutaion i - - elif ( - len(aa_seq_alt[i]) >= window - and protPosStart[i] - numberOfStopCodons[i] + 1 - window / 2 < 1 - ): - data_alt.append((transcript_id[i], aa_seq_alt[i][:window])) - - else: - data_alt.append((transcript_id[i], aa_seq_alt[i][-window:])) - - ref_alt_scores = [] - # load esm model(s) - for o in range(0, len([data_ref, data_alt]), 1): - data = [data_ref, data_alt][o] - modelScores = [] # scores of different models - if len(data) >= 1: - for k in range(0, len(modelsToUse), 1): - torch.cuda.empty_cache() - model, alphabet = pretrained.load_model_and_alphabet(modelsToUse[k]) - model.eval() # disables dropout for deterministic results - batch_converter = alphabet.get_batch_converter() - - if torch.cuda.is_available(): - model = model.cuda() - # print("transferred to GPU") - - # apply es model to sequence, tokenProbs hat probs von allen aa an jeder pos basierend auf der seq in "data" - seq_scores = [] - for t in range(0, len(data), batch_size): - # print (t) - if t + batch_size > len(data): - batch_data = data[t:] - else: - batch_data = data[t : t + batch_size] - - batch_labels, batch_strs, batch_tokens = batch_converter(batch_data) - with torch.no_grad(): # setzt irgeineine flag auf false - if torch.cuda.is_available(): - token_probs = torch.log_softmax( - model(batch_tokens.cuda())["logits"], dim=-1 - ) - else: - 
token_probs = torch.log_softmax( - model(batch_tokens)["logits"], dim=-1 - ) - - # test and extract scores from tokenProbs - if o == 1: # alt seqences - for i in range(0, len(batch_data), 1): - if ( - conseq[i + t] == "inFrame" - or conseq[i + t] == "MultiMissense" - ): - score = 0 - for y in range( - 0, len(batch_data[i][1]), 1 - ): # iterating over single AA in sequence - score = ( - score - + token_probs[ - i, - y + 1, - alphabet.get_idx(batch_data[i][1][y]), - ] - ) - seq_scores.append(float(score)) - - elif conseq[i + t] == "NA": - score = 0 - seq_scores.append(float(score)) - elif o == 0: # ref sequences - for i in range(0, len(batch_data), 1): - if ( - conseq[i + t] == "inFrame" - or conseq[i + t] == "MultiMissense" - ): - score = 0 - for y in range( - 0, len(batch_data[i][1]), 1 - ): # iterating over single AA in sequence - score = ( - score - + token_probs[ - i, - y + 1, - alphabet.get_idx(batch_data[i][1][y]), - ] - ) - seq_scores.append(float(score)) - - elif conseq[i + t] == "NA": - score = 999 # sollte nacher rausgeschissen werden, kein score sollte -999 sein - seq_scores.append(float(score)) - - modelScores.append(seq_scores) - ref_alt_scores.append(modelScores) - - np_array_scores = np.array(ref_alt_scores) - np_array_score_diff = np_array_scores[0] - np_array_scores[1] - - # write scores in cvf. file - - # get information from vcf file with SNVs and write them into lists (erstmal Bsp, später automatisch aus info zeile extrahieren) - - # identify positions of annotations important for esm score - header_end = False - for i in range(0, len(vcf_data), 1): - if vcf_data[i][0:6] == "#CHROM": - vcf_data[i - 1] = ( - vcf_data[i - 1] - + "##INFO=\n' - ) - header_end = i - break - - for i in range(header_end + 1, len(vcf_data), 1): - j = 0 - while j < len(variant_ids): - if vcf_data[i].split("|")[0] == variant_ids[j]: - # count number of vep entires per variant that result in an esm score (i.e. 
with consequence "missense") - numberOfEsmScoresPerVariant = 0 - for l in range(j, len(variant_ids), 1): - if vcf_data[i].split("|")[0] == variant_ids[l]: - numberOfEsmScoresPerVariant = numberOfEsmScoresPerVariant + 1 - else: - break - - # annotate vcf line with esm scores - # for k in range (0, len(modelsToUse), 1): - vcf_data[i] = ( - vcf_data[i][:-1] + ";EsmScoreInFrame" + "=" + vcf_data[i][-1:] - ) - for h in range(0, numberOfEsmScoresPerVariant, 1): - if aa_seq_ref[j + h] != "NA": - average_score = 0 - for k in range(0, len(modelsToUse), 1): - average_score = average_score + float( - np_array_score_diff[k][j + h] - ) - average_score = average_score / len(modelsToUse) - vcf_data[i] = ( - vcf_data[i][:-1] - + str(transcript_id[j + h][11:]) - + "|" - + str(round(float(average_score), 3)) - + vcf_data[i][-1:] - ) - else: - vcf_data[i] = ( - vcf_data[i][:-1] - + str(transcript_id[j + h][11:]) - + "|" - + "NA" - + vcf_data[i][-1:] - ) - - if h != numberOfEsmScoresPerVariant - 1: - vcf_data[i] = vcf_data[i][:-1] + "," + vcf_data[i][-1:] - - j = j + numberOfEsmScoresPerVariant - else: - j = j + 1 - - vcf_file_output = BgzfWriter(output_file, "w") - for line in vcf_data: - vcf_file_output.write(line) + data_ref, prot_pos_mod_ref = prepare_data_for_esm(aa_seq_ref, transcript_ids, prot_pos_start, stop_codons_before_mutation, WINDOW_SIZE) + data_alt, prot_pos_mod_alt = prepare_data_for_esm(aa_seq_alt, transcript_ids, prot_pos_start, stop_codons_before_mutation, WINDOW_SIZE) - vcf_file_output.close() + ref_scores = calculate_esm_scores(data_ref, prot_pos_mod_ref, modelsToUse, conseq, batch_size) + alt_scores = calculate_esm_scores(data_alt, prot_pos_mod_alt, modelsToUse, conseq, batch_size) + np_array_score_diff = ref_scores - alt_scores + annotate_vcf_and_write_output(vcf_data, variant_ids, transcript_ids, np_array_score_diff, modelsToUse, output_file, aa_seq_ref) if __name__ == "__main__": - cli() + cli() \ No newline at end of file