Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 16 additions & 30 deletions api/main_enhanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.agents.nurse_agent import NurseAgent
from src.agents.triage_agent import TriageAgent
from src.agents.triage_agent import TriageAgent, TriageResult
Copy link

Copilot AI Feb 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TriageResult is imported here but never used in this module. Please remove the unused import to avoid dead code and keep the dependency surface minimal.

Suggested change
from src.agents.triage_agent import TriageAgent, TriageResult
from src.agents.triage_agent import TriageAgent

Copilot uses AI. Check for mistakes.
from src.agents.doctor_agent import DoctorAgent
from src.core.models import PatientState

Expand Down Expand Up @@ -425,22 +425,9 @@ async def process_clinical_text(

# ===== STAGE 2: TRIAGE =====
triage_start = time.time()
triage_string = triage_agent.assess(patient_state)
triage_result = triage_agent.triage(patient_state)
triage_time = (time.time() - triage_start) * 1000
# Parse triage string: "πŸ”΄ P1 - CRITICAL" -> extract priority
triage_parts = triage_string.split(' - ')
priority_with_emoji = triage_parts[0].strip() # "πŸ”΄ P1"
priority_name = triage_parts[1].strip() if len(triage_parts) > 1 else "UNKNOWN"
priority_level = priority_with_emoji.split()[-1] # "P1"
triage_result = {
'priority_level': priority_level,
'priority_name': priority_name,
'reason': triage_string,
'confidence': 0.90,
'escalation_indicators': [],
'differential': []
}
logger.info(f"[{request_id}] Triage: {priority_level} - {priority_name}")
logger.info(f"[{request_id}] Triage: {triage_result.priority_level} - {triage_result.priority_name}")

# ===== STAGE 3: DIAGNOSIS =====
diagnosis_start = time.time()
Expand All @@ -452,17 +439,16 @@ async def process_clinical_text(
total_time = (time.time() - total_start) * 1000

# Build compression data
vitals_json = compressed_json.get('vitals', {})
medication_str = compressed_json.get('medication')
compression_data = CompressionData(
chief_complaint=compressed_json.get('chief_complaint'),
vital_signs=VitalSigns(
heart_rate=compressed_json.get('vital_signs', {}).get('heart_rate'),
blood_pressure=compressed_json.get('vital_signs', {}).get('blood_pressure'),
temperature=compressed_json.get('vital_signs', {}).get('temperature'),
respiratory_rate=compressed_json.get('vital_signs', {}).get('respiratory_rate')
heart_rate=vitals_json.get('hr'),
blood_pressure=vitals_json.get('bp'),
temperature=vitals_json.get('temp'),
),
symptoms=compressed_json.get('symptoms', []),
medications=compressed_json.get('medications', []),
oxygen=compressed_json.get('oxygen')
medications=[medication_str] if medication_str else [],
)

# Build response
Expand All @@ -481,17 +467,17 @@ async def process_clinical_text(
compressed_data=compression_data
),
triage=TriageResponse(
priority_level=triage_result['priority_level'],
priority_name=triage_result['priority_name'],
confidence=triage_result.get('confidence', 0.90),
reason=triage_result['reason'],
escalation_indicators=triage_result.get('escalation_indicators', []),
priority_level=triage_result.priority_level,
priority_name=triage_result.priority_name,
confidence=triage_result.confidence,
reason=triage_result.display,
escalation_indicators=[],
triage_time_ms=round(triage_time, 2)
),
diagnosis=DiagnosisResponse(
primary_assessment=doctor_recommendation,
differential=triage_result.get('differential', []),
recommendations=triage_result.get('recommendations', []),
differential=[],
recommendations=[],
model_version="MedGemma-v5",
processing_time_ms=round(diagnosis_time, 2)
),
Expand Down
122 changes: 95 additions & 27 deletions src/agents/triage_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,114 @@

from __future__ import annotations

from typing import Literal

from pydantic import BaseModel

from src.core.models import PatientState

# Display strings keyed by priority level
_DISPLAY = {
"P1": "\U0001f534 P1 - CRITICAL",
"P2": "\U0001f7e1 P2 - URGENT",
"P3": "\U0001f7e2 P3 - STANDARD",
}

_PRIORITY_NAME = {
"P1": "CRITICAL",
"P2": "URGENT",
"P3": "STANDARD",
}


class TriageResult(BaseModel):
"""Structured triage assessment result."""

priority_level: Literal["P1", "P2", "P3"]
priority_name: str
display: str
confidence: float = 0.90


class TriageAgent:
"""Assigns a triage priority level based on patient state."""

def assess(self, patient_state: PatientState) -> str:
"""Assess triage priority from a compressed patient state.
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------

Args:
patient_state: A PatientState produced by the CompText protocol.
@staticmethod
def _parse_systolic(bp: str | None) -> int | None:
"""Extract systolic value from a BP string like '160/90'."""
if not bp:
return None
try:
return int(bp.split("/")[0])
except (ValueError, IndexError):
return None

Returns:
A string indicating the triage priority level.
"""
@staticmethod
def _classify(patient_state: PatientState) -> Literal["P1", "P2", "P3"]:
"""Determine the priority level from *patient_state*."""
protocol = patient_state.meta.get("active_protocol", "")
vitals = patient_state.vitals
systolic = TriageAgent._parse_systolic(vitals.bp)

# Parse systolic BP from string like "160/90"
systolic = None
if vitals.bp:
try:
systolic = int(vitals.bp.split("/")[0])
except (ValueError, IndexError):
pass

# P1 - CRITICAL: high-acuity protocols or critical vitals
# P1 – CRITICAL: high-acuity protocols or critical vitals
critical_protocols = ("Cardiology", "Trauma", "Neurology")
if any(p in protocol for p in critical_protocols):
return "\U0001f534 P1 - CRITICAL"
if vitals.hr is not None and vitals.hr > 120:
return "\U0001f534 P1 - CRITICAL"
if systolic is not None and systolic > 160:
return "\U0001f534 P1 - CRITICAL"
return "P1"
if vitals.hr is not None and vitals.hr > 125:
return "P1"
if systolic is not None and systolic > 165:
return "P1"
if vitals.temp is not None and vitals.temp > 40:
return "P1"
if vitals.temp is not None and vitals.temp < 35:
return "P1"

# P2 - URGENT: respiratory or fever
# P2 – URGENT: respiratory, elevated vitals, or fever
if "Respiratory" in protocol:
return "\U0001f7e1 P2 - URGENT"
if vitals.temp is not None and vitals.temp > 39.0:
return "\U0001f7e1 P2 - URGENT"
return "P2"
if vitals.hr is not None and vitals.hr >= 100:
return "P2"
if vitals.hr is not None and vitals.hr < 50:
return "P2"
if systolic is not None and systolic >= 160:
return "P2"
if vitals.temp is not None and vitals.temp >= 38:
return "P2"

# P3 – STANDARD
return "P3"

# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------

# P3 - STANDARD: everything else
return "\U0001f7e2 P3 - STANDARD"
def triage(self, patient_state: PatientState) -> TriageResult:
"""Return a structured triage result for *patient_state*.

Args:
patient_state: A PatientState produced by the CompText protocol.

Returns:
A :class:`TriageResult` with validated priority level.
"""
level = self._classify(patient_state)
return TriageResult(
priority_level=level,
priority_name=_PRIORITY_NAME[level],
display=_DISPLAY[level],
)

def assess(self, patient_state: PatientState) -> str:
"""Assess triage priority (legacy convenience wrapper).

Args:
patient_state: A PatientState produced by the CompText protocol.

Returns:
A human-readable string such as ``"πŸ”΄ P1 - CRITICAL"``.
"""
return self.triage(patient_state).display
Loading
Loading