diff --git a/tests/unit/test_aragorn_pathfinder.py b/tests/unit/test_aragorn_pathfinder.py index a97747e..17e5382 100644 --- a/tests/unit/test_aragorn_pathfinder.py +++ b/tests/unit/test_aragorn_pathfinder.py @@ -128,48 +128,6 @@ async def test_shadowfax_rejects_multiple_intermediate_categories(redis_mock, mo await shadowfax(_make_task(), logger) -@pytest.mark.asyncio -async def test_shadowfax_uses_intermediate_category_from_constraint(redis_mock, mocker): - """When a constraint provides an intermediate category, the threehop's - intermediates carry that category instead of biolink:NamedThing.""" - msg = _pathfinder_message( - constraints=[{"intermediate_categories": ["biolink:Gene"]}] - ) - mocker.patch( - "workers.aragorn_pathfinder.worker.get_message", - new_callable=mocker.AsyncMock, - return_value=msg, - ) - mocker.patch( - "workers.aragorn_pathfinder.worker.add_callback_id", - new_callable=mocker.AsyncMock, - ) - mocker.patch( - "workers.aragorn_pathfinder.worker.get_running_callbacks", - new_callable=mocker.AsyncMock, - return_value=[], - ) - - mock_response = mocker.Mock() - mock_response.status_code = 200 - mock_httpx = mocker.patch( - "httpx.AsyncClient.post", - new_callable=mocker.AsyncMock, - return_value=mock_response, - ) - - await shadowfax(_make_task(), logger) - - mock_httpx.assert_awaited_once() - - args, kwargs = mock_httpx.call_args - - threehop = kwargs["json"] - nodes = threehop["message"]["query_graph"]["nodes"] - assert nodes["intermediate_0"]["categories"] == ["biolink:Gene"] - assert nodes["intermediate_1"]["categories"] == ["biolink:Gene"] - - @pytest.mark.asyncio async def test_shadowfax_propagates_gandalf_parameters(redis_mock, mocker): """Custom gandalf_parameters in the input should ride along into the diff --git a/workers/aragorn_pathfinder/worker.py b/workers/aragorn_pathfinder/worker.py index 37fbd54..fc82122 100644 --- a/workers/aragorn_pathfinder/worker.py +++ b/workers/aragorn_pathfinder/worker.py @@ -77,6 +77,7 @@ async def shadowfax(task, logger: logging.Logger) -> str: qgraph = message["message"]["query_graph"] pinned_node_keys = [] pinned_node_ids = [] + retriever_query = {"message": message["message"], "parameters": parameters} for node_key, node in qgraph["nodes"].items(): pinned_node_keys.append(node_key) if node.get("ids", None) is not None: @@ -85,7 +86,6 @@ async def shadowfax(task, logger: logging.Logger) -> str: if len(set(pinned_node_ids)) != 2: raise Exception("Pathfinder queries require two pinned nodes.") - intermediate_categories = [] path_key = next(iter(qgraph["paths"].keys())) qpath = qgraph["paths"][path_key] if qpath.get("constraints", None) is not None: @@ -97,23 +97,21 @@ async def shadowfax(task, logger: logging.Logger) -> str: intermediate_categories = ( constraints[0].get("intermediate_categories", None) or [] ) - if len(intermediate_categories) > 1: - raise Exception( - "Pathfinder queries do not support multiple intermediate categories" - ) - else: - intermediate_categories = ["biolink:NamedThing"] + if len(intermediate_categories) > 1: + raise Exception( + "Pathfinder queries do not support multiple intermediate categories" + ) # Create 3-hop query - message["message"]["query_graph"] = { + retriever_query["message"]["query_graph"] = { "nodes": { pinned_node_keys[0]: {"ids": [pinned_node_ids[0]]}, "intermediate_0": { - "categories": intermediate_categories, + "categories": ["biolink:NamedThing"], }, "intermediate_1": { - "categories": intermediate_categories, + "categories": ["biolink:NamedThing"], }, pinned_node_keys[1]: {"ids": [pinned_node_ids[1]]}, }, @@ -212,15 +210,21 @@ async def shadowfax(task, logger: logging.Logger) -> str: # Put callback UID and query ID in postgres await add_callback_id(query_id, callback_id, otel, logger) - message["callback"] = f"{settings.callback_host}/aragorn/callback/{callback_id}" - + retriever_query["callback"] = ( + f"{settings.callback_host}/aragorn/callback/{callback_id}" + ) logger.debug(f"""Sending pathfinder query to {settings.kg_retrieval_url}.""") with tracer.start_as_current_span(f"aragorn.pathfinder.{callback_id}"): async with httpx.AsyncClient(timeout=100) as client: - await client.post( + retriever_async_response = await client.post( settings.kg_retrieval_url, - json=message, + json=retriever_query, ) + try: + retriever_async_response.raise_for_status() + except Exception as e: + logger.error(f"Error contacting retriever: {e}") + logger.debug(f"Error details: {retriever_async_response.json()}") # this worker might have a timeout set for if the lookups don't finish within a certain # amount of time diff --git a/workers/merge_message/worker.py b/workers/merge_message/worker.py index a973840..1c51d7f 100644 --- a/workers/merge_message/worker.py +++ b/workers/merge_message/worker.py @@ -534,6 +534,19 @@ def merge_messages( object_node_id = og_path.get("object") if subject_node_id is None or object_node_id is None: raise KeyError("Missing either subject or object from path.") + + intermediate_category = None + constraints = og_path.get("constraints") or [] + if len(constraints) > 0: + intermediate_categories = ( + constraints[0].get("intermediate_categories") or [] + ) + if len(intermediate_categories) > 0: + intermediate_category = intermediate_categories[0] + + kg_nodes = result["message"]["knowledge_graph"].get("nodes", {}) + kg_edges = result["message"]["knowledge_graph"].get("edges", {}) + aux_counter = 0 score = 0 analyses = [] @@ -544,11 +557,34 @@ def merge_messages( for qg_edge_key, bindings in edge_bindings.items(): for binding in bindings: path_edge_ids.add(binding["id"]) - score = new_result.get("score") - + score = new_result.get("score") if not path_edge_ids: continue + if ( + intermediate_category is not None + and intermediate_category != "biolink:NamedThing" + ): + nb = new_result.get("node_bindings", {}) + pinned_ids = set() + for pinned in (subject_node_id, object_node_id): + for binding in nb.get(pinned, []) or []: + pinned_ids.add(binding["id"]) + intermediate_node_ids = set() + for edge_id in path_edge_ids: + edge = kg_edges.get(edge_id) + if edge is None: + continue + for node_id in (edge.get("subject"), edge.get("object")): + if node_id and node_id not in pinned_ids: + intermediate_node_ids.add(node_id) + if not any( + intermediate_category + in (kg_nodes.get(nid, {}).get("categories") or []) + for nid in intermediate_node_ids + ): + continue + aux_id = f"a_{aux_counter}" aux_counter += 1