Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions classification/tests/views/test_query_scaling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
The classification datatable's query count must not grow with the number of
classification rows returned - per-row work in server renderers is the main
N+1 risk in DataTables endpoints (row data itself comes from .values()).
"""
from django.contrib.auth.models import User
from django.db import connection
from django.test import Client
from django.test.utils import CaptureQueriesContext
from django.urls import reverse

from annotation.fake_annotation import get_fake_annotation_version, create_fake_variants
from classification.autopopulate_evidence_keys.autopopulate_evidence_keys import \
create_classification_for_sample_and_variant_objects
from library.django_utils.unittest_utils import URLTestCase, production_query_count
from snpdb.models import GenomeBuild, Variant, Country, Lab, Organization


class ClassificationDatatableScalingTest(URLTestCase):
    """ Checks that the classification datatable endpoint issues the same number of
        production queries no matter how many classification rows it returns """

    @classmethod
    def setUpTestData(cls):
        """ Creates a user in a fake lab, fake variants, and published classifications
            for the first two variants (the "small" result set) """
        super().setUpTestData()
        cls.user = User.objects.get_or_create(username='classification_scaling_user')[0]
        organization = Organization.objects.get_or_create(name="Fake Org", group_name="fake_org")[0]
        australia = Country.objects.get_or_create(name="Australia")[0]
        cls.lab = Lab.objects.get_or_create(name="Fake Lab", city="Adelaide", country=australia,
                                            organization=organization, group_name="fake_org/fake_lab")[0]
        cls.lab.group.user_set.add(cls.user)

        cls.genome_build = GenomeBuild.get_name_or_alias("GRCh37")
        cls.annotation_version = get_fake_annotation_version(cls.genome_build)
        create_fake_variants(cls.genome_build)
        cls.variants = list(Variant.objects.filter(Variant.get_no_reference_q()).order_by("pk")[:10])
        cls._create_classifications(cls.variants[:2])

    @classmethod
    def _create_classifications(cls, variants):
        """ Creates and publishes one VUS classification per variant, owned by cls.user / cls.lab """
        for variant in variants:
            classification = create_classification_for_sample_and_variant_objects(
                cls.user, cls.lab, None, variant, cls.genome_build,
                annotation_version=cls.annotation_version)
            classification.patch_value({"clinical_significance": "VUS"}, user=cls.user, save=True)
            classification.publish_latest(cls.user)

    def _datatable_rows_and_production_queries(self, client):
        """ GETs the classification datatable once.

            :return: tuple of (number of rows in the response, production query count) """
        url = reverse('classification_datatables')
        with CaptureQueriesContext(connection) as ctx:
            response = client.get(url)
        self.assertEqual(response.status_code, 200)
        num_rows = len(response.json()["data"])
        self.assertGreaterEqual(num_rows, 2)
        return num_rows, production_query_count(ctx.captured_queries)

    def _datatable_production_query_count(self, client) -> int:
        """ Production query count for a single GET of the classification datatable """
        return self._datatable_rows_and_production_queries(client)[1]

    def test_datatable_query_count_flat_with_more_rows(self):
        client = Client()
        client.force_login(self.user)
        self._datatable_production_query_count(client)  # warm up per-process caches

        rows_before, num_queries_before = self._datatable_rows_and_production_queries(client)
        self._create_classifications(self.variants[2:])
        rows_after, num_queries_after = self._datatable_rows_and_production_queries(client)

        # Without this the comparison below is vacuous: equal query counts only prove
        # anything if the second request really returned more rows than the first
        self.assertGreater(rows_after, rows_before,
                           msg="Datatable did not return more rows after creating more classifications")
        self.assertEqual(num_queries_before, num_queries_after)
100 changes: 96 additions & 4 deletions library/django_utils/unittest_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,103 @@
import json
import logging
import os
import re
import time
import traceback
from collections import Counter
from collections.abc import Mapping
from contextlib import ExitStack

from django.db import connection
from django.test import Client, TestCase, override_settings
from django.test.utils import CaptureQueriesContext
from django.urls import reverse
from django.utils.http import urlencode

# Output path for query profiling, read from the VG_QUERY_PROFILE environment variable.
# When set, _make_test_client() returns a QueryProfilingClient which appends one JSON line per GET to this file.
QUERY_PROFILE_FILE = os.environ.get("VG_QUERY_PROFILE")
# Optional, read from VG_QUERY_TRACE: matching statements also get their call stack
# appended to QUERY_PROFILE_FILE + ".trace"
QUERY_TRACE_PATTERN = os.environ.get("VG_QUERY_TRACE")  # regex: log stack traces for matching SQL

# Models whose managers (ObjectManagerCachingImmutable/Request) cache lookups in production
# but disable caching under settings.UNIT_TEST - so repeated gets on these tables in a test
# are not real production queries. Used by production_query_count().
PRODUCTION_CACHED_TABLES = (
    "snpdb_genomebuild",
    "genes_genesymbol",
    "flags_flagtype",
    "classification_resolvedvariantinfo",
    "snpdb_allele",
    "snpdb_organization",
    "snpdb_lab",
)


def production_query_count(captured_queries) -> int:
    """ Count the captured queries that would reach the database in production.

        Skipped: test savepoint statements, and SELECTs that read from a table listed in
        PRODUCTION_CACHED_TABLES (their object managers cache in production).

        :param captured_queries: iterable of dicts with a "sql" key (eg CaptureQueriesContext.captured_queries)
        :return: number of remaining queries """
    savepoint_prefixes = ("SAVEPOINT", "RELEASE SAVEPOINT", "ROLLBACK TO SAVEPOINT")
    cached_from_clauses = tuple(f'FROM "{table}"' for table in PRODUCTION_CACHED_TABLES)

    def hits_database_in_production(sql: str) -> bool:
        if sql.startswith(savepoint_prefixes):
            return False
        cached_lookup = sql.startswith("SELECT") and any(clause in sql for clause in cached_from_clauses)
        return not cached_lookup

    return sum(1 for query in captured_queries if hits_database_in_production(query["sql"]))


def _normalize_sql(sql: str) -> str:
""" Strip literals so queries differing only by parameters group together (N+1 detection) """
sql = re.sub(r"'[^']*'", "?", sql)
sql = re.sub(r"\b\d+(\.\d+)?\b", "?", sql)
sql = re.sub(r"IN \([^)]*\)", "IN (?)", sql)
return sql


class QueryProfilingClient(Client):
    """ Captures SQL for each GET and appends a JSON line to VG_QUERY_PROFILE

        If VG_QUERY_TRACE is also set, every statement matching that regex additionally gets
        a JSON line (path, truncated SQL, call stack) appended to VG_QUERY_PROFILE + ".trace".

        When VG_QUERY_PROFILE is not set there is nowhere to write to, so get() behaves
        exactly like the plain Django test Client. """

    _SAVEPOINT_PREFIXES = ("SAVEPOINT", "RELEASE SAVEPOINT", "ROLLBACK TO SAVEPOINT")

    def _trace_wrapper(self, path):
        """ Build a callable for connection.execute_wrapper() that logs the stack of matching SQL

            :param path: request path, written into every trace line
            :return: wrapper taking (execute, sql, params, many, context) """
        def wrapper(execute, sql, params, many, context):
            if re.search(QUERY_TRACE_PATTERN, sql):
                # Keep only project frames: drop installed packages and this module
                stack = [line for line in traceback.format_stack()
                         if "/site-packages/" not in line and "unittest_utils" not in line]
                with open(QUERY_PROFILE_FILE + ".trace", "a", encoding="utf-8") as f:
                    f.write(json.dumps({"path": path, "sql": sql[:300], "stack": stack[-12:]}) + "\n")
            return execute(sql, params, many, context)
        return wrapper

    def get(self, path, *args, **kwargs):
        """ Performs the GET, then appends a profile record for it to QUERY_PROFILE_FILE

            The record holds: path, status, num_queries, request_ms, sql_ms and duplicates
            (normalized SQL seen more than once, most frequent first).

            :return: the response from Client.get() """
        if not QUERY_PROFILE_FILE:
            # Profiling is not configured - don't fail on open(None), just act as a normal Client
            return super().get(path, *args, **kwargs)

        start = time.monotonic()
        with ExitStack() as stack:
            ctx = stack.enter_context(CaptureQueriesContext(connection))
            if QUERY_TRACE_PATTERN:
                stack.enter_context(connection.execute_wrapper(self._trace_wrapper(path)))
            response = super().get(path, *args, **kwargs)
        request_ms = (time.monotonic() - start) * 1000

        # Savepoints come from the test transaction wrapping, not from the view
        real_queries = [q for q in ctx.captured_queries
                        if not q["sql"].startswith(self._SAVEPOINT_PREFIXES)]
        queries = [q["sql"] for q in real_queries]
        duplicates = [{"count": count, "sql": sql}
                      for sql, count in Counter(_normalize_sql(sql) for sql in queries).most_common()
                      if count > 1]
        record = {
            "path": path,
            "status": response.status_code,
            "num_queries": len(queries),
            "request_ms": round(request_ms, 1),
            "sql_ms": round(sum(float(q["time"]) for q in real_queries) * 1000, 1),
            "duplicates": duplicates,
        }
        with open(QUERY_PROFILE_FILE, "a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")
        return response


def _make_test_client() -> Client:
    """ Test client factory: a QueryProfilingClient when QUERY_PROFILE_FILE is set,
        otherwise the standard Django Client """
    client_class = QueryProfilingClient if QUERY_PROFILE_FILE else Client
    return client_class()


def prevent_request_warnings(original_function):
"""
Expand Down Expand Up @@ -42,7 +134,7 @@ class URLTestCase(TestCase):
and contain the file asked. @see https://stackoverflow.com/a/51580328/295724 """

def _test_urls(self, names_and_kwargs, user=None, expected_code_override=None):
client = Client()
client = _make_test_client()
if user:
client.force_login(user)

Expand All @@ -59,7 +151,7 @@ def _test_urls(self, names_and_kwargs, user=None, expected_code_override=None):
self.assertEqual(response.status_code, expected_code, msg=f"Url '{url}'")

def _test_autocomplete_urls(self, names_obj_kwargs, user, in_results):
client = Client()
client = _make_test_client()
client.force_login(user)

for name, obj, get_kwargs in names_obj_kwargs:
Expand Down Expand Up @@ -95,7 +187,7 @@ def _test_datatable_urls(self, names_and_kwargs, user=None, expected_code_overri
self._test_urls(datatable_definition_names_and_kwargs, user=user, expected_code_override=expected_code_override)

def _test_datatables_grid_urls_contains_objs(self, names_obj, user, in_results):
client = Client()
client = _make_test_client()
client.force_login(user)

for name, kwargs, obj in names_obj:
Expand Down Expand Up @@ -139,7 +231,7 @@ def _test_datatables_grid_urls_contains_objs(self, names_obj, user, in_results):
def _test_jqgrid_urls_contains_objs(self, names_obj, user, in_results):
""" Also allow 403 if not expecting results
TODO: Load grid properly call URL with sidx params etc, currently get UnorderedObjectListWarning """
client = Client()
client = _make_test_client()
client.force_login(user)

for name, kwargs, obj in names_obj:
Expand Down
39 changes: 24 additions & 15 deletions ontology/models/models_ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
import functools
import logging
import operator
import re
from collections import defaultdict
from dataclasses import dataclass
Expand Down Expand Up @@ -789,14 +790,19 @@ def in_ontology_version(ontology_import: OntologyImport) -> bool:

@staticmethod
def latest(validate=True) -> Optional['OntologyVersion']:
oi_qs = OntologyImport.objects.all()
# Fetch candidates for all import fields in one query, then pick the latest per field
import_q_list = [Q(import_source=import_source, filename__in=filenames)
for import_source, filenames in OntologyVersion.ONTOLOGY_IMPORTS.values()]
candidate_imports = OntologyImport.objects.filter(functools.reduce(operator.or_, import_q_list)).order_by("pk")

kwargs = {}
missing_fields = set()
for field, (import_source, filenames) in OntologyVersion.ONTOLOGY_IMPORTS.items():
if ont_import := oi_qs.filter(import_source=import_source, filename__in=filenames).order_by("pk").last():
kwargs[field] = ont_import
elif field not in OntologyVersion.OPTIONAL_IMPORTS:
missing_fields.add(field)
for ont_import in candidate_imports: # ordered by pk so the last match per field wins
for field, (import_source, filenames) in OntologyVersion.ONTOLOGY_IMPORTS.items():
if ont_import.import_source == import_source and ont_import.filename in filenames:
kwargs[field] = ont_import

missing_fields = {field for field in OntologyVersion.ONTOLOGY_IMPORTS
if field not in kwargs and field not in OntologyVersion.OPTIONAL_IMPORTS}

if not missing_fields:
values = list(kwargs.values())
Expand All @@ -815,14 +821,17 @@ def latest(validate=True) -> Optional['OntologyVersion']:
ontology_version = None
return ontology_version

def get_ontology_imports(self):
return [ont_import for ont_import in [
self.gencc_import,
self.mondo_import,
self.hp_owl_import,
self.hp_phenotype_to_genes_import,
self.omim_import
] if ont_import is not None]
def get_ontology_imports(self) -> QuerySet[OntologyImport]:
    """ OntologyImports referenced by this version, as a lazy QuerySet.

        Built from the *_id columns only, so it costs no query until evaluated; used in an
        __in filter it becomes a subquery (reading the FK fields instead would lazy-load
        each OntologyImport individually) """
    candidate_ids = (
        self.gencc_import_id,
        self.mondo_import_id,
        self.hp_owl_import_id,
        self.hp_phenotype_to_genes_import_id,
        self.omim_import_id,
    )
    populated_ids = [import_id for import_id in candidate_ids if import_id is not None]
    return OntologyImport.objects.filter(pk__in=populated_ids)

def get_ontology_term_relations(self):
    """ OntologyTermRelations whose from_import is one of this version's imports """
    version_imports = self.get_ontology_imports()
    return OntologyTermRelation.objects.filter(from_import__in=version_imports)
Expand Down
30 changes: 30 additions & 0 deletions ontology/tests/test_query_counts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
Locks in query counts for OntologyVersion paths that previously had N+1
query patterns (one query per import field). If a count here grows with the
number of OntologyImports, that's an N+1 regression to fix.
"""
from django.test import TestCase

from ontology.models import OntologyVersion
from ontology.tests.test_data_ontology import create_test_ontology_version


class OntologyVersionQueryCountTest(TestCase):
    """ Pins the number of queries issued by OntologyVersion.latest() and get_ontology_imports() """

    @classmethod
    def setUpTestData(cls):
        cls.ontology_version = create_test_ontology_version()

    def test_latest_query_count(self):
        # Expected queries: (1) all candidate OntologyImports, (2) the get of get_or_create
        with self.assertNumQueries(2):
            found_version = OntologyVersion.latest()
        self.assertEqual(found_version, self.ontology_version)

    def test_get_ontology_imports_is_lazy(self):
        # Constructing the QuerySet must be free (so __in filters can use it as a subquery)
        with self.assertNumQueries(0):
            lazy_imports = self.ontology_version.get_ontology_imports()

        # Evaluation fetches every import field with one query
        with self.assertNumQueries(1):
            loaded_imports = list(lazy_imports)
        self.assertEqual(len(loaded_imports), 5)
11 changes: 8 additions & 3 deletions snpdb/genome/fasta_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def load_genome_fasta_index(genome_fasta: 'GenomeFasta', genome_build: 'GenomeBu

NAMES = ["NAME", "LENGTH", "OFFSET", "LINEBASES", "LINEWIDTH"]
df = pd.read_csv(index_filename, sep='\t', header=None, names=NAMES)
# Can't import GenomeFastaContig at module level (models_genome imports this module)
GenomeFastaContig = genome_fasta.genomefastacontig_set.model
genome_fasta_contigs = []
for _, row in df[["NAME", "LENGTH"]].iterrows():
name = row["NAME"]
length = row["LENGTH"]
Expand All @@ -31,10 +34,12 @@ def load_genome_fasta_index(genome_fasta: 'GenomeFasta', genome_build: 'GenomeBu
msg = f"Fasta contig {name} length = {length} but {genome_build} contig {contig} length={contig.length}"
raise ValueError(msg)

genome_fasta.genomefastacontig_set.create(name=name,
length=length,
contig=contig)
genome_fasta_contigs.append(GenomeFastaContig(genome_fasta=genome_fasta,
name=name,
length=length,
contig=contig))
else:
logging.warning("Could not find contig in %s for fasta name '%s'", genome_build, name)

GenomeFastaContig.objects.bulk_create(genome_fasta_contigs, batch_size=2000)
return genome_fasta
4 changes: 2 additions & 2 deletions snpdb/templates/snpdb/tags/related_data_for_cohort.html
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ <h3>Sub cohorts</h3>
{% endfor %}
{% endif %}

{% if cohort.trio_set.exists %}
{% if trios %}
<div id='related-trios'>
<h3>Trios</h3>
{% for trio in cohort.trio_set.all %}
{% for trio in trios %}
<div>
<a href="{% url 'view_trio' trio.pk %}">{{ trio }}</a> - {% trio_short_description trio %}
</div>
Expand Down
38 changes: 27 additions & 11 deletions snpdb/templatetags/related_data_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@

from classification.models import Classification
from classification.views.classification_datatables import ClassificationColumns
from pedigree.models import CohortSamplePedFileRecord
from snpdb.models import CohortSample, Trio
from snpdb.models.models_enums import ImportStatus

register = Library()

# Relations passed to select_related() whenever Trios are loaded in this module, so the
# mother / father / proband and their samples are joined into the Trio query
TRIO_SAMPLES_SELECT_RELATED = ("mother__sample", "father__sample", "proband__sample")


def related_data_context(context, samples):
tag_context = {
Expand Down Expand Up @@ -47,17 +51,28 @@ def related_data_for_samples(context, samples, show_sample_info=True):
trios_and_samples_list = defaultdict(list)
pedigrees_and_samples_list = defaultdict(list)

for sample in samples:
for cs in sample.cohortsample_set.all():
cohort = cs.cohort
if cohort.import_status == ImportStatus.SUCCESS:
cohorts_and_samples_list[cohort].append(sample.name)

for trio in cohort.trio_set.filter(Q(mother=cs) | Q(father=cs) | Q(proband=cs)):
trios_and_samples_list[trio].append(sample.name)

for pedigree in cohort.pedigree_set.filter(cohortsamplepedfilerecord__cohort_sample=cs).distinct():
pedigrees_and_samples_list[pedigree].append(sample.name)
cohort_samples = list(CohortSample.objects.filter(sample__in=samples).select_related("cohort", "sample"))
successful_cohort_samples = [cs for cs in cohort_samples
if cs.cohort.import_status == ImportStatus.SUCCESS]
for cs in successful_cohort_samples:
cohorts_and_samples_list[cs.cohort].append(cs.sample.name)

if successful_cohort_samples:
cs_ids = {cs.pk for cs in successful_cohort_samples}
trio_qs = Trio.objects.filter(Q(mother__in=cs_ids) | Q(father__in=cs_ids) | Q(proband__in=cs_ids))
for trio in trio_qs.select_related(*TRIO_SAMPLES_SELECT_RELATED):
for cs in trio.get_cohort_samples():
if cs.pk in cs_ids:
trios_and_samples_list[trio].append(cs.sample.name)

if cohort_samples:
seen_pedigree_cohort_samples = set()
record_qs = CohortSamplePedFileRecord.objects.filter(cohort_sample__in=cohort_samples)
for record in record_qs.select_related("pedigree", "cohort_sample__sample"):
pair = (record.pedigree_id, record.cohort_sample_id)
if pair not in seen_pedigree_cohort_samples:
seen_pedigree_cohort_samples.add(pair)
pedigrees_and_samples_list[record.pedigree].append(record.cohort_sample.sample.name)

cohorts_and_samples = summarise_samples(cohorts_and_samples_list, show_sample_info)
trios_and_samples = summarise_samples(trios_and_samples_list, show_sample_info)
Expand All @@ -75,6 +90,7 @@ def related_data_for_samples(context, samples, show_sample_info=True):
def related_data_for_cohort(context, cohort):
    """ Tag context for a cohort: the related data of its samples, plus the cohort itself
        and its trios (loaded up front with their mother/father/proband samples joined) """
    cohort_trios = list(cohort.trio_set.select_related(*TRIO_SAMPLES_SELECT_RELATED))
    tag_context = related_data_context(context, cohort.get_samples())
    tag_context["cohort"] = cohort
    tag_context["trios"] = cohort_trios
    return tag_context


Expand Down
Loading
Loading