diff --git a/classification/tests/views/test_query_scaling.py b/classification/tests/views/test_query_scaling.py new file mode 100644 index 000000000..ee0441c28 --- /dev/null +++ b/classification/tests/views/test_query_scaling.py @@ -0,0 +1,62 @@ +""" +The classification datatable's query count must not grow with the number of +classification rows returned - per-row work in server renderers is the main +N+1 risk in DataTables endpoints (row data itself comes from .values()). +""" +from django.contrib.auth.models import User +from django.db import connection +from django.test import Client +from django.test.utils import CaptureQueriesContext +from django.urls import reverse + +from annotation.fake_annotation import get_fake_annotation_version, create_fake_variants +from classification.autopopulate_evidence_keys.autopopulate_evidence_keys import \ + create_classification_for_sample_and_variant_objects +from library.django_utils.unittest_utils import URLTestCase, production_query_count +from snpdb.models import GenomeBuild, Variant, Country, Lab, Organization + + +class ClassificationDatatableScalingTest(URLTestCase): + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.user = User.objects.get_or_create(username='classification_scaling_user')[0] + organization = Organization.objects.get_or_create(name="Fake Org", group_name="fake_org")[0] + australia = Country.objects.get_or_create(name="Australia")[0] + cls.lab = Lab.objects.get_or_create(name="Fake Lab", city="Adelaide", country=australia, + organization=organization, group_name="fake_org/fake_lab")[0] + cls.lab.group.user_set.add(cls.user) + + cls.genome_build = GenomeBuild.get_name_or_alias("GRCh37") + cls.annotation_version = get_fake_annotation_version(cls.genome_build) + create_fake_variants(cls.genome_build) + cls.variants = list(Variant.objects.filter(Variant.get_no_reference_q()).order_by("pk")[:10]) + cls._create_classifications(cls.variants[:2]) + + @classmethod + def _create_classifications(cls, variants): + for variant in variants: + classification = create_classification_for_sample_and_variant_objects( + cls.user, cls.lab, None, variant, cls.genome_build, + annotation_version=cls.annotation_version) + classification.patch_value({"clinical_significance": "VUS"}, user=cls.user, save=True) + classification.publish_latest(cls.user) + + def _datatable_production_query_count(self, client) -> int: + url = reverse('classification_datatables') + with CaptureQueriesContext(connection) as ctx: + response = client.get(url) + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertGreaterEqual(len(data["data"]), 2) + return production_query_count(ctx.captured_queries) + + def test_datatable_query_count_flat_with_more_rows(self): + client = Client() + client.force_login(self.user) + self._datatable_production_query_count(client) # warm up per-process caches + + num_queries_two_rows = self._datatable_production_query_count(client) + self._create_classifications(self.variants[2:]) + num_queries_ten_rows = self._datatable_production_query_count(client) + self.assertEqual(num_queries_two_rows, num_queries_ten_rows) diff --git a/library/django_utils/unittest_utils.py b/library/django_utils/unittest_utils.py index c1fdcc038..cee8d9047 100644 --- a/library/django_utils/unittest_utils.py +++ b/library/django_utils/unittest_utils.py @@ -1,11 +1,103 @@ import json import logging +import os +import re +import time +import traceback +from collections import Counter from collections.abc import Mapping +from contextlib import ExitStack +from django.db import connection from django.test import Client, TestCase, override_settings +from django.test.utils import CaptureQueriesContext from django.urls import reverse from django.utils.http import urlencode +QUERY_PROFILE_FILE = os.environ.get("VG_QUERY_PROFILE") +QUERY_TRACE_PATTERN = os.environ.get("VG_QUERY_TRACE") # regex: log stack traces for matching SQL + +# Models whose managers (ObjectManagerCachingImmutable/Request) cache lookups in production +# but disable caching under settings.UNIT_TEST - so repeated gets on these tables in a test +# are not real production queries. Used by production_query_count(). +PRODUCTION_CACHED_TABLES = ( + "snpdb_genomebuild", + "genes_genesymbol", + "flags_flagtype", + "classification_resolvedvariantinfo", + "snpdb_allele", + "snpdb_organization", + "snpdb_lab", +) + + +def production_query_count(captured_queries) -> int: + """ Number of queries that would hit the database in production: excludes test + savepoints and lookups on tables whose object managers cache in production """ + count = 0 + for query in captured_queries: + sql = query["sql"] + if sql.startswith(("SAVEPOINT", "RELEASE SAVEPOINT", "ROLLBACK TO SAVEPOINT")): + continue + if sql.startswith("SELECT") and any(f'FROM "{table}"' in sql for table in PRODUCTION_CACHED_TABLES): + continue + count += 1 + return count + + +def _normalize_sql(sql: str) -> str: + """ Strip literals so queries differing only by parameters group together (N+1 detection) """ + sql = re.sub(r"'[^']*'", "?", sql) + sql = re.sub(r"\b\d+(\.\d+)?\b", "?", sql) + sql = re.sub(r"IN \([^)]*\)", "IN (?)", sql) + return sql + + +class QueryProfilingClient(Client): + """ Captures SQL for each GET and appends a JSON line to VG_QUERY_PROFILE """ + + def _trace_wrapper(self, path): + def wrapper(execute, sql, params, many, context): + if re.search(QUERY_TRACE_PATTERN, sql): + stack = [line for line in traceback.format_stack() + if "/site-packages/" not in line and "unittest_utils" not in line] + with open(QUERY_PROFILE_FILE + ".trace", "a") as f: + f.write(json.dumps({"path": path, "sql": sql[:300], "stack": stack[-12:]}) + "\n") + return execute(sql, params, many, context) + return wrapper + + def get(self, path, *args, **kwargs): + start = time.monotonic() + with ExitStack() as stack: + ctx = stack.enter_context(CaptureQueriesContext(connection)) + if QUERY_TRACE_PATTERN: + stack.enter_context(connection.execute_wrapper(self._trace_wrapper(path))) + response = super().get(path, *args, **kwargs) + request_ms = (time.monotonic() - start) * 1000 + real_queries = [q for q in ctx.captured_queries + if not q["sql"].startswith(("SAVEPOINT", "RELEASE SAVEPOINT", "ROLLBACK TO SAVEPOINT"))] + queries = [q["sql"] for q in real_queries] + duplicates = [{"count": count, "sql": sql} + for sql, count in Counter(_normalize_sql(sql) for sql in queries).most_common() + if count > 1] + record = { + "path": path, + "status": response.status_code, + "num_queries": len(queries), + "request_ms": round(request_ms, 1), + "sql_ms": round(sum(float(q["time"]) for q in real_queries) * 1000, 1), + "duplicates": duplicates, + } + with open(QUERY_PROFILE_FILE, "a") as f: + f.write(json.dumps(record) + "\n") + return response + + +def _make_test_client() -> Client: + if QUERY_PROFILE_FILE: + return QueryProfilingClient() + return Client() + def prevent_request_warnings(original_function): """ @@ -42,7 +134,7 @@ class URLTestCase(TestCase): and contain the file asked. @see https://stackoverflow.com/a/51580328/295724 """ def _test_urls(self, names_and_kwargs, user=None, expected_code_override=None): - client = Client() + client = _make_test_client() if user: client.force_login(user) @@ -59,7 +151,7 @@ def _test_urls(self, names_and_kwargs, user=None, expected_code_override=None): self.assertEqual(response.status_code, expected_code, msg=f"Url '{url}'") def _test_autocomplete_urls(self, names_obj_kwargs, user, in_results): - client = Client() + client = _make_test_client() client.force_login(user) for name, obj, get_kwargs in names_obj_kwargs: @@ -95,7 +187,7 @@ def _test_datatable_urls(self, names_and_kwargs, user=None, expected_code_overri self._test_urls(datatable_definition_names_and_kwargs, user=user, expected_code_override=expected_code_override) def _test_datatables_grid_urls_contains_objs(self, names_obj, user, in_results): - client = Client() + client = _make_test_client() client.force_login(user) for name, kwargs, obj in names_obj: @@ -139,7 +231,7 @@ def _test_datatables_grid_urls_contains_objs(self, names_obj, user, in_results): def _test_jqgrid_urls_contains_objs(self, names_obj, user, in_results): """ Also allow 403 if not expecting results TODO: Load grid properly call URL with sidx params etc, currently get UnorderedObjectListWarning """ - client = Client() + client = _make_test_client() client.force_login(user) for name, kwargs, obj in names_obj: diff --git a/ontology/models/models_ontology.py b/ontology/models/models_ontology.py index ecdde8a44..d0e571da0 100644 --- a/ontology/models/models_ontology.py +++ b/ontology/models/models_ontology.py @@ -4,6 +4,7 @@ """ import functools import logging +import operator import re from collections import defaultdict from dataclasses import dataclass @@ -789,14 +790,19 @@ def in_ontology_version(ontology_import: OntologyImport) -> bool: @staticmethod def latest(validate=True) -> Optional['OntologyVersion']: - oi_qs = OntologyImport.objects.all() + # Fetch candidates for all import fields in one query, then pick the latest per field + import_q_list = [Q(import_source=import_source, filename__in=filenames) + for import_source, filenames in OntologyVersion.ONTOLOGY_IMPORTS.values()] + candidate_imports = OntologyImport.objects.filter(functools.reduce(operator.or_, import_q_list)).order_by("pk") + kwargs = {} - missing_fields = set() - for field, (import_source, filenames) in OntologyVersion.ONTOLOGY_IMPORTS.items(): - if ont_import := oi_qs.filter(import_source=import_source, filename__in=filenames).order_by("pk").last(): - kwargs[field] = ont_import - elif field not in OntologyVersion.OPTIONAL_IMPORTS: - missing_fields.add(field) + for ont_import in candidate_imports: # ordered by pk so the last match per field wins + for field, (import_source, filenames) in OntologyVersion.ONTOLOGY_IMPORTS.items(): + if ont_import.import_source == import_source and ont_import.filename in filenames: + kwargs[field] = ont_import + + missing_fields = {field for field in OntologyVersion.ONTOLOGY_IMPORTS + if field not in kwargs and field not in OntologyVersion.OPTIONAL_IMPORTS} if not missing_fields: values = list(kwargs.values()) @@ -815,14 +821,17 @@ def latest(validate=True) -> Optional['OntologyVersion']: ontology_version = None return ontology_version - def get_ontology_imports(self): - return [ont_import for ont_import in [ - self.gencc_import, - self.mondo_import, - self.hp_owl_import, - self.hp_phenotype_to_genes_import, - self.omim_import - ] if ont_import is not None] + def get_ontology_imports(self) -> QuerySet[OntologyImport]: + """ Lazy QuerySet - using it in __in filters becomes a subquery (no extra queries, + unlike accessing the FK fields which lazy-loads each OntologyImport individually) """ + import_ids = [import_id for import_id in [ + self.gencc_import_id, + self.mondo_import_id, + self.hp_owl_import_id, + self.hp_phenotype_to_genes_import_id, + self.omim_import_id + ] if import_id is not None] + return OntologyImport.objects.filter(pk__in=import_ids) def get_ontology_term_relations(self): return OntologyTermRelation.objects.filter(from_import__in=self.get_ontology_imports()) diff --git a/ontology/tests/test_query_counts.py b/ontology/tests/test_query_counts.py new file mode 100644 index 000000000..868970b8e --- /dev/null +++ b/ontology/tests/test_query_counts.py @@ -0,0 +1,30 @@ +""" +Locks in query counts for OntologyVersion paths that previously had N+1 +query patterns (one query per import field). If a count here grows with the +number of OntologyImports, that's an N+1 regression to fix. +""" +from django.test import TestCase + +from ontology.models import OntologyVersion +from ontology.tests.test_data_ontology import create_test_ontology_version + + +class OntologyVersionQueryCountTest(TestCase): + @classmethod + def setUpTestData(cls): + cls.ontology_version = create_test_ontology_version() + + def test_latest_query_count(self): + # 2 queries: one for all candidate OntologyImports, one for get_or_create's get + with self.assertNumQueries(2): + latest = OntologyVersion.latest() + self.assertEqual(latest, self.ontology_version) + + def test_get_ontology_imports_is_lazy(self): + # Building the QuerySet must not hit the database (so __in filters use a subquery) + with self.assertNumQueries(0): + imports_qs = self.ontology_version.get_ontology_imports() + # Evaluating it is a single query for all import fields + with self.assertNumQueries(1): + imports = list(imports_qs) + self.assertEqual(len(imports), 5) diff --git a/snpdb/genome/fasta_index.py b/snpdb/genome/fasta_index.py index e1b5f6b95..b28866ec1 100644 --- a/snpdb/genome/fasta_index.py +++ b/snpdb/genome/fasta_index.py @@ -21,6 +21,9 @@ def load_genome_fasta_index(genome_fasta: 'GenomeFasta', genome_build: 'GenomeBu NAMES = ["NAME", "LENGTH", "OFFSET", "LINEBASES", "LINEWIDTH"] df = pd.read_csv(index_filename, sep='\t', header=None, names=NAMES) + # Can't import GenomeFastaContig at module level (models_genome imports this module) + GenomeFastaContig = genome_fasta.genomefastacontig_set.model + genome_fasta_contigs = [] for _, row in df[["NAME", "LENGTH"]].iterrows(): name = row["NAME"] length = row["LENGTH"] @@ -31,10 +34,12 @@ def load_genome_fasta_index(genome_fasta: 'GenomeFasta', genome_build: 'GenomeBu msg = f"Fasta contig {name} length = {length} but {genome_build} contig {contig} length={contig.length}" raise ValueError(msg) - genome_fasta.genomefastacontig_set.create(name=name, - length=length, - contig=contig) + genome_fasta_contigs.append(GenomeFastaContig(genome_fasta=genome_fasta, + name=name, + length=length, + contig=contig)) else: logging.warning("Could not find contig in %s for fasta name '%s'", genome_build, name) + GenomeFastaContig.objects.bulk_create(genome_fasta_contigs, batch_size=2000) return genome_fasta diff --git a/snpdb/templates/snpdb/tags/related_data_for_cohort.html b/snpdb/templates/snpdb/tags/related_data_for_cohort.html index 449ac6365..f8dab7d29 100644 --- a/snpdb/templates/snpdb/tags/related_data_for_cohort.html +++ b/snpdb/templates/snpdb/tags/related_data_for_cohort.html @@ -12,10 +12,10 @@

Sub cohorts

{% endfor %} {% endif %} - {% if cohort.trio_set.exists %} + {% if trios %}