Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 0 additions & 42 deletions tests/unit/test_aragorn_pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,48 +128,6 @@ async def test_shadowfax_rejects_multiple_intermediate_categories(redis_mock, mo
await shadowfax(_make_task(), logger)


@pytest.mark.asyncio
async def test_shadowfax_uses_intermediate_category_from_constraint(redis_mock, mocker):
"""When a constraint provides an intermediate category, the threehop's
intermediates carry that category instead of biolink:NamedThing."""
msg = _pathfinder_message(
constraints=[{"intermediate_categories": ["biolink:Gene"]}]
)
mocker.patch(
"workers.aragorn_pathfinder.worker.get_message",
new_callable=mocker.AsyncMock,
return_value=msg,
)
mocker.patch(
"workers.aragorn_pathfinder.worker.add_callback_id",
new_callable=mocker.AsyncMock,
)
mocker.patch(
"workers.aragorn_pathfinder.worker.get_running_callbacks",
new_callable=mocker.AsyncMock,
return_value=[],
)

mock_response = mocker.Mock()
mock_response.status_code = 200
mock_httpx = mocker.patch(
"httpx.AsyncClient.post",
new_callable=mocker.AsyncMock,
return_value=mock_response,
)

await shadowfax(_make_task(), logger)

mock_httpx.assert_awaited_once()

args, kwargs = mock_httpx.call_args

threehop = kwargs["json"]
nodes = threehop["message"]["query_graph"]["nodes"]
assert nodes["intermediate_0"]["categories"] == ["biolink:Gene"]
assert nodes["intermediate_1"]["categories"] == ["biolink:Gene"]


@pytest.mark.asyncio
async def test_shadowfax_propagates_gandalf_parameters(redis_mock, mocker):
"""Custom gandalf_parameters in the input should ride along into the
Expand Down
32 changes: 18 additions & 14 deletions workers/aragorn_pathfinder/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ async def shadowfax(task, logger: logging.Logger) -> str:
qgraph = message["message"]["query_graph"]
pinned_node_keys = []
pinned_node_ids = []
retriever_query = {"message": message["message"], "parameters": parameters}
for node_key, node in qgraph["nodes"].items():
pinned_node_keys.append(node_key)
if node.get("ids", None) is not None:
Expand All @@ -85,7 +86,6 @@ async def shadowfax(task, logger: logging.Logger) -> str:
if len(set(pinned_node_ids)) != 2:
raise Exception("Pathfinder queries require two pinned nodes.")

intermediate_categories = []
path_key = next(iter(qgraph["paths"].keys()))
qpath = qgraph["paths"][path_key]
if qpath.get("constraints", None) is not None:
Expand All @@ -97,23 +97,21 @@ async def shadowfax(task, logger: logging.Logger) -> str:
intermediate_categories = (
constraints[0].get("intermediate_categories", None) or []
)
if len(intermediate_categories) > 1:
raise Exception(
"Pathfinder queries do not support multiple intermediate categories"
)
else:
intermediate_categories = ["biolink:NamedThing"]
if len(intermediate_categories) > 1:
raise Exception(
"Pathfinder queries do not support multiple intermediate categories"
)

# Create 3-hop query

message["message"]["query_graph"] = {
retriever_query["message"]["query_graph"] = {
"nodes": {
pinned_node_keys[0]: {"ids": [pinned_node_ids[0]]},
"intermediate_0": {
"categories": intermediate_categories,
"categories": ["biolink:NamedThing"],
},
"intermediate_1": {
"categories": intermediate_categories,
"categories": ["biolink:NamedThing"],
},
pinned_node_keys[1]: {"ids": [pinned_node_ids[1]]},
},
Expand Down Expand Up @@ -212,15 +210,21 @@ async def shadowfax(task, logger: logging.Logger) -> str:
# Put callback UID and query ID in postgres
await add_callback_id(query_id, callback_id, otel, logger)

message["callback"] = f"{settings.callback_host}/aragorn/callback/{callback_id}"

retriever_query["callback"] = (
f"{settings.callback_host}/aragorn/callback/{callback_id}"
)
logger.debug(f"""Sending pathfinder query to {settings.kg_retrieval_url}.""")
with tracer.start_as_current_span(f"aragorn.pathfinder.{callback_id}"):
async with httpx.AsyncClient(timeout=100) as client:
await client.post(
retriever_async_response = await client.post(
settings.kg_retrieval_url,
json=message,
json=retriever_query,
)
try:
retriever_async_response.raise_for_status()
except Exception as e:
logger.error(f"Error contacting retriever: {e}")
logger.debug(f"Error details: {retriever_async_response.json()}")

# this worker might have a timeout set for if the lookups don't finish within a certain
# amount of time
Expand Down
40 changes: 38 additions & 2 deletions workers/merge_message/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,19 @@ def merge_messages(
object_node_id = og_path.get("object")
if subject_node_id is None or object_node_id is None:
raise KeyError("Missing either subject or object from path.")

intermediate_category = None
constraints = og_path.get("constraints") or []
if len(constraints) > 0:
intermediate_categories = (
constraints[0].get("intermediate_categories") or []
)
if len(intermediate_categories) > 0:
intermediate_category = intermediate_categories[0]

kg_nodes = result["message"]["knowledge_graph"].get("nodes", {})
kg_edges = result["message"]["knowledge_graph"].get("edges", {})

aux_counter = 0
score = 0
analyses = []
Expand All @@ -544,11 +557,34 @@ def merge_messages(
for qg_edge_key, bindings in edge_bindings.items():
for binding in bindings:
path_edge_ids.add(binding["id"])
score = new_result.get("score")

score = new_result.get("score")
if not path_edge_ids:
continue

if (
intermediate_category is not None
and intermediate_category != "biolink:NamedThing"
):
nb = new_result.get("node_bindings", {})
pinned_ids = set()
for pinned in (subject_node_id, object_node_id):
for binding in nb.get(pinned, []) or []:
pinned_ids.add(binding["id"])
intermediate_node_ids = set()
for edge_id in path_edge_ids:
edge = kg_edges.get(edge_id)
if edge is None:
continue
for node_id in (edge.get("subject"), edge.get("object")):
if node_id and node_id not in pinned_ids:
intermediate_node_ids.add(node_id)
if not any(
intermediate_category
in (kg_nodes.get(nid, {}).get("categories") or [])
for nid in intermediate_node_ids
):
continue

aux_id = f"a_{aux_counter}"
aux_counter += 1

Expand Down
Loading