From c2fa8ca1918375d8c535ac0fa9537603e01a5184 Mon Sep 17 00:00:00 2001 From: Dave Lawrence Date: Thu, 11 Jun 2026 07:12:55 +0930 Subject: [PATCH 1/2] Reduce N+1 / per-row queries on page loads #1590 - OntologyVersion.latest(): one query for all candidate imports instead of one per field - OntologyVersion.get_ontology_imports(): lazy QuerySet (subquery in __in filters) instead of 5 FK lazy-loads - related_data_for_samples: batch cohort sample / trio / pedigree queries, select_related trio members - load_genome_fasta_index: bulk_create GenomeFastaContig rows - URLTestCase: opt-in query profiling via VG_QUERY_PROFILE / VG_QUERY_TRACE - Query-count regression tests --- library/django_utils/unittest_utils.py | 73 ++++++++++++++++++- ontology/models/models_ontology.py | 39 ++++++---- ontology/tests/test_query_counts.py | 30 ++++++++ snpdb/genome/fasta_index.py | 11 ++- .../snpdb/tags/related_data_for_cohort.html | 4 +- snpdb/templatetags/related_data_tags.py | 38 +++++++--- snpdb/tests/test_fasta_index.py | 16 ++++ snpdb/tests/test_query_counts.py | 36 +++++++++ 8 files changed, 212 insertions(+), 35 deletions(-) create mode 100644 ontology/tests/test_query_counts.py create mode 100644 snpdb/tests/test_fasta_index.py create mode 100644 snpdb/tests/test_query_counts.py diff --git a/library/django_utils/unittest_utils.py b/library/django_utils/unittest_utils.py index c1fdcc038..ecfa0ead2 100644 --- a/library/django_utils/unittest_utils.py +++ b/library/django_utils/unittest_utils.py @@ -1,11 +1,76 @@ import json import logging +import os +import re +import time +import traceback +from collections import Counter from collections.abc import Mapping +from contextlib import ExitStack +from django.db import connection from django.test import Client, TestCase, override_settings +from django.test.utils import CaptureQueriesContext from django.urls import reverse from django.utils.http import urlencode +QUERY_PROFILE_FILE = os.environ.get("VG_QUERY_PROFILE") +QUERY_TRACE_PATTERN = os.environ.get("VG_QUERY_TRACE") # regex: log stack traces for matching SQL + + +def _normalize_sql(sql: str) -> str: + """ Strip literals so queries differing only by parameters group together (N+1 detection) """ + sql = re.sub(r"'[^']*'", "?", sql) + sql = re.sub(r"\b\d+(\.\d+)?\b", "?", sql) + sql = re.sub(r"IN \([^)]*\)", "IN (?)", sql) + return sql + + +class QueryProfilingClient(Client): + """ Captures SQL for each GET and appends a JSON line to VG_QUERY_PROFILE """ + + def _trace_wrapper(self, path): + def wrapper(execute, sql, params, many, context): + if re.search(QUERY_TRACE_PATTERN, sql): + stack = [line for line in traceback.format_stack() + if "/site-packages/" not in line and "unittest_utils" not in line] + with open(QUERY_PROFILE_FILE + ".trace", "a") as f: + f.write(json.dumps({"path": path, "sql": sql[:300], "stack": stack[-12:]}) + "\n") + return execute(sql, params, many, context) + return wrapper + + def get(self, path, *args, **kwargs): + start = time.monotonic() + with ExitStack() as stack: + ctx = stack.enter_context(CaptureQueriesContext(connection)) + if QUERY_TRACE_PATTERN: + stack.enter_context(connection.execute_wrapper(self._trace_wrapper(path))) + response = super().get(path, *args, **kwargs) + request_ms = (time.monotonic() - start) * 1000 + real_queries = [q for q in ctx.captured_queries + if not q["sql"].startswith(("SAVEPOINT", "RELEASE SAVEPOINT", "ROLLBACK TO SAVEPOINT"))] + queries = [q["sql"] for q in real_queries] + duplicates = [{"count": count, "sql": sql} + for sql, count in Counter(_normalize_sql(sql) for sql in queries).most_common() + if count > 1] + record = { + "path": path, + "status": response.status_code, + "num_queries": len(queries), + "request_ms": round(request_ms, 1), + "sql_ms": round(sum(float(q["time"]) for q in real_queries) * 1000, 1), + "duplicates": duplicates, + } + with open(QUERY_PROFILE_FILE, "a") as f: + f.write(json.dumps(record) + "\n") + return response + + +def _make_test_client() -> Client: + if QUERY_PROFILE_FILE: + return QueryProfilingClient() + return Client() + def prevent_request_warnings(original_function): """ @@ -42,7 +107,7 @@ class URLTestCase(TestCase): and contain the file asked. @see https://stackoverflow.com/a/51580328/295724 """ def _test_urls(self, names_and_kwargs, user=None, expected_code_override=None): - client = Client() + client = _make_test_client() if user: client.force_login(user) @@ -59,7 +124,7 @@ def _test_urls(self, names_and_kwargs, user=None, expected_code_override=None): self.assertEqual(response.status_code, expected_code, msg=f"Url '{url}'") def _test_autocomplete_urls(self, names_obj_kwargs, user, in_results): - client = Client() + client = _make_test_client() client.force_login(user) for name, obj, get_kwargs in names_obj_kwargs: @@ -95,7 +160,7 @@ def _test_datatable_urls(self, names_and_kwargs, user=None, expected_code_overri self._test_urls(datatable_definition_names_and_kwargs, user=user, expected_code_override=expected_code_override) def _test_datatables_grid_urls_contains_objs(self, names_obj, user, in_results): - client = Client() + client = _make_test_client() client.force_login(user) for name, kwargs, obj in names_obj: @@ -139,7 +204,7 @@ def _test_datatables_grid_urls_contains_objs(self, names_obj, user, in_results): def _test_jqgrid_urls_contains_objs(self, names_obj, user, in_results): """ Also allow 403 if not expecting results TODO: Load grid properly call URL with sidx params etc, currently get UnorderedObjectListWarning """ - client = Client() + client = _make_test_client() client.force_login(user) for name, kwargs, obj in names_obj: diff --git a/ontology/models/models_ontology.py b/ontology/models/models_ontology.py index ecdde8a44..d0e571da0 100644 --- a/ontology/models/models_ontology.py +++ b/ontology/models/models_ontology.py @@ -4,6 +4,7 @@ """ import functools import logging +import operator import re from collections import defaultdict from dataclasses import dataclass @@ -789,14 +790,19 @@ def in_ontology_version(ontology_import: OntologyImport) -> bool: @staticmethod def latest(validate=True) -> Optional['OntologyVersion']: - oi_qs = OntologyImport.objects.all() + # Fetch candidates for all import fields in one query, then pick the latest per field + import_q_list = [Q(import_source=import_source, filename__in=filenames) + for import_source, filenames in OntologyVersion.ONTOLOGY_IMPORTS.values()] + candidate_imports = OntologyImport.objects.filter(functools.reduce(operator.or_, import_q_list)).order_by("pk") + kwargs = {} - missing_fields = set() - for field, (import_source, filenames) in OntologyVersion.ONTOLOGY_IMPORTS.items(): - if ont_import := oi_qs.filter(import_source=import_source, filename__in=filenames).order_by("pk").last(): - kwargs[field] = ont_import - elif field not in OntologyVersion.OPTIONAL_IMPORTS: - missing_fields.add(field) + for ont_import in candidate_imports: # ordered by pk so the last match per field wins + for field, (import_source, filenames) in OntologyVersion.ONTOLOGY_IMPORTS.items(): + if ont_import.import_source == import_source and ont_import.filename in filenames: + kwargs[field] = ont_import + + missing_fields = {field for field in OntologyVersion.ONTOLOGY_IMPORTS + if field not in kwargs and field not in OntologyVersion.OPTIONAL_IMPORTS} if not missing_fields: values = list(kwargs.values()) @@ -815,14 +821,17 @@ def latest(validate=True) -> Optional['OntologyVersion']: ontology_version = None return ontology_version - def get_ontology_imports(self): - return [ont_import for ont_import in [ - self.gencc_import, - self.mondo_import, - self.hp_owl_import, - self.hp_phenotype_to_genes_import, - self.omim_import - ] if ont_import is not None] + def get_ontology_imports(self) -> QuerySet[OntologyImport]: + """ Lazy QuerySet - using it in __in filters becomes a subquery (no extra queries, + unlike accessing the FK fields which lazy-loads each OntologyImport individually) """ + import_ids = [import_id for import_id in [ + self.gencc_import_id, + self.mondo_import_id, + self.hp_owl_import_id, + self.hp_phenotype_to_genes_import_id, + self.omim_import_id + ] if import_id is not None] + return OntologyImport.objects.filter(pk__in=import_ids) def get_ontology_term_relations(self): return OntologyTermRelation.objects.filter(from_import__in=self.get_ontology_imports()) diff --git a/ontology/tests/test_query_counts.py b/ontology/tests/test_query_counts.py new file mode 100644 index 000000000..868970b8e --- /dev/null +++ b/ontology/tests/test_query_counts.py @@ -0,0 +1,30 @@ +""" +Locks in query counts for OntologyVersion paths that previously had N+1 +query patterns (one query per import field). If a count here grows with the +number of OntologyImports, that's an N+1 regression to fix. +""" +from django.test import TestCase + +from ontology.models import OntologyVersion +from ontology.tests.test_data_ontology import create_test_ontology_version + + +class OntologyVersionQueryCountTest(TestCase): + @classmethod + def setUpTestData(cls): + cls.ontology_version = create_test_ontology_version() + + def test_latest_query_count(self): + # 2 queries: one for all candidate OntologyImports, one for get_or_create's get + with self.assertNumQueries(2): + latest = OntologyVersion.latest() + self.assertEqual(latest, self.ontology_version) + + def test_get_ontology_imports_is_lazy(self): + # Building the QuerySet must not hit the database (so __in filters use a subquery) + with self.assertNumQueries(0): + imports_qs = self.ontology_version.get_ontology_imports() + # Evaluating it is a single query for all import fields + with self.assertNumQueries(1): + imports = list(imports_qs) + self.assertEqual(len(imports), 5) diff --git a/snpdb/genome/fasta_index.py b/snpdb/genome/fasta_index.py index e1b5f6b95..b28866ec1 100644 --- a/snpdb/genome/fasta_index.py +++ b/snpdb/genome/fasta_index.py @@ -21,6 +21,9 @@ def load_genome_fasta_index(genome_fasta: 'GenomeFasta', genome_build: 'GenomeBu NAMES = ["NAME", "LENGTH", "OFFSET", "LINEBASES", "LINEWIDTH"] df = pd.read_csv(index_filename, sep='\t', header=None, names=NAMES) + # Can't import GenomeFastaContig at module level (models_genome imports this module) + GenomeFastaContig = genome_fasta.genomefastacontig_set.model + genome_fasta_contigs = [] for _, row in df[["NAME", "LENGTH"]].iterrows(): name = row["NAME"] length = row["LENGTH"] @@ -31,10 +34,12 @@ def load_genome_fasta_index(genome_fasta: 'GenomeFasta', genome_build: 'GenomeBu msg = f"Fasta contig {name} length = {length} but {genome_build} contig {contig} length={contig.length}" raise ValueError(msg) - genome_fasta.genomefastacontig_set.create(name=name, - length=length, - contig=contig) + genome_fasta_contigs.append(GenomeFastaContig(genome_fasta=genome_fasta, + name=name, + length=length, + contig=contig)) else: logging.warning("Could not find contig in %s for fasta name '%s'", genome_build, name) + GenomeFastaContig.objects.bulk_create(genome_fasta_contigs, batch_size=2000) return genome_fasta diff --git a/snpdb/templates/snpdb/tags/related_data_for_cohort.html b/snpdb/templates/snpdb/tags/related_data_for_cohort.html index 449ac6365..f8dab7d29 100644 --- a/snpdb/templates/snpdb/tags/related_data_for_cohort.html +++ b/snpdb/templates/snpdb/tags/related_data_for_cohort.html @@ -12,10 +12,10 @@

Sub cohorts

{% endfor %} {% endif %} - {% if cohort.trio_set.exists %} + {% if trios %}