Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions inference/eval_harness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
def exact_match_accuracy(predictions, references):
    """Return the fraction of references matched exactly by their prediction.

    Predictions and references are paired positionally with ``zip``, so
    surplus items on the longer side are never compared. The denominator is
    always ``len(references)``.

    Returns:
        ``0.0`` when ``references`` is empty, otherwise
        ``matches / len(references)``.
    """
    if not references:
        return 0.0

    hits = 0
    for predicted, expected in zip(predictions, references):
        if predicted == expected:
            hits += 1
    return hits / len(references)

def eval_step_trace(generated_traces, expected_traces):
    """Score generated traces by whole-trace equality against expectations.

    Traces are paired positionally with ``zip``; a pair counts only when the
    two traces compare equal as a whole. The result is normalised by
    ``len(expected_traces)``.

    Returns:
        ``0.0`` when ``expected_traces`` is empty, otherwise the matching
        fraction.
    """
    if not expected_traces:
        return 0.0

    matching_pairs = [
        pair for pair in zip(generated_traces, expected_traces) if pair[0] == pair[1]
    ]
    return len(matching_pairs) / len(expected_traces)

def compare_models_v1_placeholder(model_out, fallback_out, groq_out, ground_truth):
    """Compare three candidate outputs against one ground truth.

    Returns:
        A dict with the keys ``model_exact_match``, ``fallback_exact_match``
        and ``groq_exact_match``, each holding the result of ``==`` between
        the corresponding output and ``ground_truth``.
    """
    candidates = (
        ("model_exact_match", model_out),
        ("fallback_exact_match", fallback_out),
        ("groq_exact_match", groq_out),
    )
    comparison = {}
    for key, output in candidates:
        comparison[key] = output == ground_truth
    return comparison

def categorize_error_v1_placeholder(prediction, reference):
    """Assign a coarse error label to a prediction/reference pair.

    The checks run in order and the first one that applies wins:

    1. ``"empty_output"`` - ``prediction`` is falsy.
    2. ``"severe_truncation"`` - ``prediction`` is shorter than half the
       length of ``reference``.
    3. ``"casing_mismatch"`` - the two are equal after lower-casing.
    4. ``"logic_or_hallucination"`` - anything else.

    This function does not itself check that the pair differs; a pair of
    identical non-empty strings is labelled ``"casing_mismatch"``.
    """
    if not prediction:
        label = "empty_output"
    elif len(prediction) < len(reference) / 2:
        label = "severe_truncation"
    elif prediction.lower() == reference.lower():
        label = "casing_mismatch"
    else:
        label = "logic_or_hallucination"
    return label

def run_error_analysis_v1_placeholder(predictions, references):
    """Count error categories over all mismatching prediction/reference pairs.

    Pairs are formed positionally with ``zip``. Pairs that compare equal are
    skipped; every other pair is labelled by
    ``categorize_error_v1_placeholder`` and tallied.

    Returns:
        A plain dict mapping category label to its number of occurrences.
        The dict is empty when no pair differs.
    """
    tally = {}
    mismatches = (
        (predicted, expected)
        for predicted, expected in zip(predictions, references)
        if predicted != expected
    )
    for predicted, expected in mismatches:
        label = categorize_error_v1_placeholder(predicted, expected)
        tally.setdefault(label, 0)
        tally[label] += 1
    return tally
11 changes: 11 additions & 0 deletions inference/router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
def standalone_inference(checkpoint_path, input_data, strategy="beam"):
    """Run one inference call with a freshly constructed inferencer.

    Args:
        checkpoint_path: Passed unchanged to ``CalculusSolverInference``.
        input_data: Passed unchanged to the selected decode method.
        strategy: ``"beam"`` calls ``beam_search_decode``; ``"standard"``
            calls ``solve``.

    Returns:
        Whatever the selected inferencer method returns.

    Raises:
        ValueError: If ``strategy`` is not one of the supported values.
    """
    # Validate first so a bad strategy fails before the inferencer is
    # imported and constructed from the checkpoint.
    supported = ("beam", "standard")
    if strategy not in supported:
        raise ValueError(
            f"Unknown strategy: {strategy!r} (expected one of {supported})"
        )

    # Imported inside the function, as in the original implementation.
    from inference.core import CalculusSolverInference

    inferencer = CalculusSolverInference(checkpoint_path)

    if strategy == "beam":
        return inferencer.beam_search_decode(input_data)
    return inferencer.solve(input_data)
23 changes: 23 additions & 0 deletions model/checkpoint_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import torch
import torch.nn as nn

class DummyModel(nn.Module):
    """Minimal module wrapping a single 10-in / 10-out linear layer.

    The layer is registered as ``linear``, which is the prefix of the
    parameter names that appear in the module's state dict.
    """

    def __init__(self):
        super(DummyModel, self).__init__()
        self.linear = nn.Linear(in_features=10, out_features=10)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        projected = self.linear(x)
        return projected

def create_dummy_checkpoint(path):
    """Save the state dict of a freshly constructed ``DummyModel`` to ``path``.

    Args:
        path: Destination handed to ``torch.save``. When it is a filesystem
            path (``str`` or ``os.PathLike``), any missing parent directory
            is created first; other values (e.g. file-like objects) are
            passed through untouched.
    """
    # Function-scope import: this module only imports torch at file level.
    import os

    # torch.save does not create missing directories, so a path such as
    # "checkpoints/dummy_model.pt" fails on a fresh checkout without this.
    if isinstance(path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    model = DummyModel()
    torch.save(model.state_dict(), path)

def validate_checkpoint(checkpoint_path, expected_shapes):
    """Check that a saved state dict contains the expected keys and shapes.

    Args:
        checkpoint_path: Location of the checkpoint, passed to ``torch.load``.
        expected_shapes: Mapping of state-dict key to expected shape. Shapes
            may be given as any sequence (tuple, list, ``torch.Size``).

    Returns:
        ``True`` when every expected key is present with the expected shape.

    Raises:
        KeyError: If an expected key is missing from the checkpoint.
        ValueError: If a stored entry's shape differs from the expected one.
    """
    # NOTE(security): torch.load unpickles the file. Only validate checkpoints
    # from trusted sources, or pass ``weights_only=True`` where the installed
    # torch version supports it.
    # map_location="cpu": only shapes are inspected, so loading must not
    # require the device the checkpoint was saved from.
    state_dict = torch.load(checkpoint_path, map_location="cpu")

    for key, shape in expected_shapes.items():
        if key not in state_dict:
            raise KeyError(key)
        # Normalise both sides to tuples so a list-typed expectation such as
        # [10, 10] compares equal to torch.Size([10, 10]).
        expected = tuple(shape)
        actual = tuple(state_dict[key].shape)
        if actual != expected:
            raise ValueError(
                f"Shape mismatch for {key!r}: expected {expected}, got {actual}"
            )
    return True
72 changes: 72 additions & 0 deletions run_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os
import json
import glob
from model.checkpoint_utils import create_dummy_checkpoint, validate_checkpoint
from inference.router import standalone_inference
from inference.eval_harness import (
exact_match_accuracy,
eval_step_trace,
compare_models_v1_placeholder,
run_error_analysis_v1_placeholder,
)

def load_benchmark_data(benchmark_dir="eval/benchmarks"):
    """Load benchmark samples from every ``*.json`` file in a directory.

    Files are read in sorted path order so the sample order is the same on
    every run and platform. Each file is assumed to hold a JSON array of
    objects (TODO confirm against the benchmark format); for each object the
    ``"input"``, ``"output"`` and ``"trace"`` fields are collected, falling
    back to ``""``, ``""`` and ``[]`` when a field is absent.

    Args:
        benchmark_dir: Directory to scan (non-recursively) for ``*.json``.

    Returns:
        A tuple ``(inputs, ground_truths, expected_traces)`` of three
        parallel lists. All three are empty when the directory is missing.
    """
    inputs = []
    ground_truths = []
    expected_traces = []

    if not os.path.isdir(benchmark_dir):
        return inputs, ground_truths, expected_traces

    # glob returns matches in arbitrary order; sort for reproducible runs.
    for filepath in sorted(glob.glob(os.path.join(benchmark_dir, "*.json"))):
        # Explicit UTF-8 so decoding does not depend on the host locale.
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)
        for item in data:
            inputs.append(item.get("input", ""))
            ground_truths.append(item.get("output", ""))
            expected_traces.append(item.get("trace", []))

    return inputs, ground_truths, expected_traces

def run_end_to_end_pipeline():
    """Validate the checkpoint, run inference on the benchmarks, print scores.

    Steps, in order:

    1. Choose the expected state-dict shapes. For a path containing
       ``"dummy"`` the checkpoint is created if it does not exist yet.
    2. Validate the checkpoint against those shapes.
    3. Load the benchmark data and run beam-strategy inference per input.
       A dict result contributes its ``"prediction"`` and ``"trace"``
       entries; any other result is used as the prediction with an empty
       trace.
    4. Print exact-match accuracy, trace score, and the error report.
    """
    checkpoint_path = "checkpoints/dummy_model.pt"

    if "dummy" in checkpoint_path:
        expected_shapes = {"linear.weight": (10, 10), "linear.bias": (10,)}
        checkpoint_missing = not os.path.exists(checkpoint_path)
        if checkpoint_missing:
            create_dummy_checkpoint(checkpoint_path)
    else:
        expected_shapes = {"CalculusSolverModel.real_key_placeholder": (0, 0)}

    validate_checkpoint(checkpoint_path, expected_shapes)

    inputs, ground_truths, expected_traces = load_benchmark_data()

    predictions, generated_traces = [], []
    for sample in inputs:
        result = standalone_inference(checkpoint_path, sample, strategy="beam")
        if not isinstance(result, dict):
            # Non-dict results carry no trace information.
            predictions.append(result)
            generated_traces.append([])
            continue
        predictions.append(result.get("prediction", ""))
        generated_traces.append(result.get("trace", []))

    em_score = exact_match_accuracy(predictions, ground_truths)
    trace_score = eval_step_trace(generated_traces, expected_traces)
    error_report = run_error_analysis_v1_placeholder(predictions, ground_truths)

    print(f"Accuracy: {em_score}")
    print(f"Trace Score: {trace_score}")
    print(f"Errors: {error_report}")

# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_end_to_end_pipeline()