Skip to content

Commit 0f14f9f

Browse files
committed
alignment changes and substring matching for variant
1 parent 69acd61 commit 0f14f9f

1 file changed

Lines changed: 110 additions & 3 deletions

File tree

src/benchmark/fa_benchmark.py

Lines changed: 110 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Any, Optional
1+
from typing import Dict, List, Any, Optional, Tuple
22
from difflib import SequenceMatcher
33
import numpy as np
44
import re
@@ -45,6 +45,95 @@ def expand_annotations_by_variant(annotations: List[Dict[str, Any]]) -> List[Dic
4545
return expanded
4646

4747

48+
def align_fa_annotations_by_variant(
    ground_truth_fa: List[Dict[str, Any]],
    predictions_fa: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[str]]:
    """
    Align FA annotations by their 'Variant/Haplotypes' string.

    Both inputs are first passed through ``expand_annotations_by_variant``
    (``None`` is treated as an empty list). Each expanded ground-truth record
    is then paired with the first expanded prediction that either:
      1) shares at least one rsID (``rs<digits>``, case-insensitive) with it, or
      2) failing that, whose normalized, lower-cased variant string contains
         the ground-truth's normalized, lower-cased variant string.

    Ground-truth records without a match are dropped. Predictions are not
    consumed, so one prediction may be paired with several ground-truth
    records.

    Args:
        ground_truth_fa: Ground-truth annotation dicts.
        predictions_fa: Predicted annotation dicts.

    Returns:
        ``(aligned_gt, aligned_pred, display_keys)`` as three parallel lists.
        The display key is the lexicographically smallest shared rsID for an
        rsID match; for a substring match it is the smallest ground-truth rsID
        if there is one, otherwise the normalized ground-truth variant string.
    """
    rs_re = re.compile(r"rs\d+", re.IGNORECASE)

    def _variant_keys(rec: Dict[str, Any]) -> Tuple[set, str]:
        # One place to derive both lookup keys so GT and predictions are
        # always normalized the same way.
        raw = (rec.get('Variant/Haplotypes') or '').strip()
        norm = normalize_variant(raw).lower()
        rsids = {m.group(0).lower() for m in rs_re.finditer(raw)}
        return rsids, norm

    gt_expanded = expand_annotations_by_variant(ground_truth_fa or [])
    pred_expanded = expand_annotations_by_variant(predictions_fa or [])

    pred_index: List[Tuple[set, str, Dict[str, Any]]] = []
    for rec in pred_expanded:
        pred_rs, pred_norm = _variant_keys(rec)
        pred_index.append((pred_rs, pred_norm, rec))

    aligned_gt: List[Dict[str, Any]] = []
    aligned_pred: List[Dict[str, Any]] = []
    display_keys: List[str] = []

    for gt_rec in gt_expanded:
        gt_rs, gt_norm = _variant_keys(gt_rec)

        match = None
        disp = ''

        # Preferred strategy: any rsID in common.
        if gt_rs:
            for pred_rs, _, pred_rec in pred_index:
                shared = pred_rs & gt_rs
                if shared:
                    match = pred_rec
                    # min() instead of next(iter(set)): set iteration order
                    # for strings is not stable across interpreter runs, and
                    # the key should be an rsID both sides actually share.
                    disp = min(shared)
                    break

        # Fallback: ground-truth string contained in the prediction string.
        # The gt_norm guard matters because '' is a substring of everything.
        if match is None and gt_norm:
            for _, pred_norm, pred_rec in pred_index:
                if gt_norm in pred_norm:
                    match = pred_rec
                    disp = min(gt_rs) if gt_rs else gt_norm
                    break

        if match is None:
            continue

        aligned_gt.append(gt_rec)
        aligned_pred.append(match)
        display_keys.append(disp)

    return aligned_gt, aligned_pred, display_keys
98+
99+
def evaluate_fa_from_articles(
    ground_truth_article: Dict[str, Any],
    predictions_article: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Evaluate the 'var_fa_ann' annotations of two article dicts.

    The annotation lists are aligned with ``align_fa_annotations_by_variant``
    and the aligned pairs are scored with
    ``_evaluate_functional_analysis_pairs``.

    Returns the scoring results with two extra keys: 'aligned_variants' (the
    display keys from alignment) and 'status', which is one of:
      - 'missing_var_fa_ann': either article is falsy or has no annotations
      - 'no_overlap_after_alignment': alignment produced no pairs
      - 'ok': evaluation ran on at least one aligned pair
    For the two non-'ok' statuses an empty, zero-scored result is returned.
    """

    def _empty_result(status: str) -> Dict[str, Any]:
        # Fresh containers on every call so callers can mutate the result.
        return {
            'total_samples': 0,
            'field_scores': {},
            'overall_score': 0.0,
            'detailed_results': [],
            'aligned_variants': [],
            'status': status,
        }

    gt_article = ground_truth_article or {}
    pred_article = predictions_article or {}
    gt_annotations = gt_article.get('var_fa_ann', []) or []
    pred_annotations = pred_article.get('var_fa_ann', []) or []

    if not (gt_annotations and pred_annotations):
        return _empty_result('missing_var_fa_ann')

    gt_pairs, pred_pairs, variant_keys = align_fa_annotations_by_variant(
        gt_annotations, pred_annotations
    )
    if not gt_pairs:
        return _empty_result('no_overlap_after_alignment')

    outcome = _evaluate_functional_analysis_pairs(gt_pairs, pred_pairs, None)
    outcome['aligned_variants'] = variant_keys
    outcome['status'] = 'ok'
    return outcome
135+
136+
48137
def validate_external_data(annotation: Dict[str, Any]) -> List[str]:
49138
issues: List[str] = []
50139
rsid_pattern = re.compile(r'^rs\d+$', re.IGNORECASE)
@@ -138,7 +227,14 @@ def evaluate_functional_analysis(samples: List[Dict[str, Any]]) -> Dict[str, Any
138227

139228
gt_list: List[Dict[str, Any]] = [gt]
140229
pred_list: List[Dict[str, Any]] = [pred]
230+
return _evaluate_functional_analysis_pairs(gt_list, pred_list, None)
141231

232+
233+
def _evaluate_functional_analysis_pairs(
234+
gt_list: List[Dict[str, Any]],
235+
pred_list: List[Dict[str, Any]],
236+
study_parameters: Optional[List[Dict[str, Any]]],
237+
) -> Dict[str, Any]:
142238
model = _get_model()
143239

144240
def exact_match(gt_val: Any, pred_val: Any) -> float:
@@ -205,8 +301,19 @@ def is_star_allele(variant: str) -> bool:
205301
covered_count += 1
206302
return covered_count / len(gt_list_filtered)
207303

304+
def variant_substring_match(gt_val: Any, pred_val: Any) -> float:
    """
    Score a variant field by case-insensitive substring containment.

    Returns 1.0 when the trimmed, lower-cased ground-truth string occurs
    inside the trimmed, lower-cased prediction string, else 0.0. Two ``None``
    values, or two strings that are empty after trimming, count as a match;
    a ``None`` or empty value on one side only does not.
    """
    gt_missing = gt_val is None
    pred_missing = pred_val is None
    if gt_missing or pred_missing:
        return 1.0 if (gt_missing and pred_missing) else 0.0

    needle = str(gt_val).strip().lower()
    haystack = str(pred_val).strip().lower()

    # An empty needle is a substring of anything, so handle it explicitly:
    # it only matches an equally empty prediction.
    if needle == '':
        return 0.0 if haystack else 1.0

    return float(needle in haystack)
314+
208315
field_evaluators = {
209-
'Variant/Haplotypes': variant_coverage,
316+
'Variant/Haplotypes': variant_substring_match,
210317
'Gene': semantic_similarity,
211318
'Drug(s)': semantic_similarity,
212319
'PMID': exact_match,
@@ -241,7 +348,7 @@ def is_star_allele(variant: str) -> bool:
241348
sample_result: Dict[str, Any] = {'sample_id': i, 'field_scores': {}}
242349
for field, evaluator in field_evaluators.items():
243350
sample_result['field_scores'][field] = evaluator(gt.get(field), pred.get(field))
244-
dependency_issues = validate_all_dependencies(pred, None)
351+
dependency_issues = validate_all_dependencies(pred, study_parameters)
245352
sample_result['dependency_issues'] = dependency_issues
246353
if dependency_issues:
247354
penalty_per_issue = 0.05

0 commit comments

Comments
 (0)