diff --git a/core/common/search.py b/core/common/search.py index a812a98d..bbb04fb3 100644 --- a/core/common/search.py +++ b/core/common/search.py @@ -11,6 +11,63 @@ from core.common.constants import ES_REQUEST_TIMEOUT from core.common.utils import is_url_encoded_string +from core.orgs.constants import ORG_OBJECT_TYPE +from core.users.constants import USER_OBJECT_TYPE + + +def get_document_public_visibility_criteria( + user, + include_creator_private_access=False, + include_owner_private_access=False, + include_organization_memberships=False, +): + """Return a shared Elasticsearch visibility criterion for owner-scoped documents.""" + criteria = Q('term', public_can_view=True) + if not getattr(user, 'is_authenticated', False): + return criteria + + private_criteria = None + username = getattr(user, 'username', None) + if username and include_creator_private_access: + private_criteria = Q('term', created_by=username) + + if username and include_owner_private_access: + owner_criteria = Q('term', owner_type=USER_OBJECT_TYPE) & Q('term', owner=username.lower()) + private_criteria = owner_criteria if private_criteria is None else private_criteria | owner_criteria + + if include_organization_memberships: + organization_mnemonics = [ + mnemonic.lower() for mnemonic in user.organizations.values_list('mnemonic', flat=True) + ] + if organization_mnemonics: + org_criteria = Q('term', owner_type=ORG_OBJECT_TYPE) & Q('terms', owner=organization_mnemonics) + private_criteria = org_criteria if private_criteria is None else private_criteria | org_criteria + + if private_criteria is None: + return criteria + + return criteria | (Q('term', public_can_view=False) & private_criteria) + + +def apply_document_public_visibility_filter( + search, + user, + include_creator_private_access=False, + include_owner_private_access=False, + include_organization_memberships=False, +): + """Apply a shared Elasticsearch visibility filter without changing staff searches.""" + if getattr(user, 'is_staff', False): + return search + + return search.filter( + get_document_public_visibility_criteria( + user, + include_creator_private_access=include_creator_private_access, + include_owner_private_access=include_owner_private_access, + include_organization_memberships=include_organization_memberships, + ) + ) class CustomESFacetedSearch(FacetedSearch): diff --git a/core/common/views.py b/core/common/views.py index 9c711d78..c3d6ca74 100644 --- a/core/common/views.py +++ b/core/common/views.py @@ -28,12 +28,20 @@ CANONICAL_URL_REQUEST_PARAM, CHECKSUMS_PARAM, ACCESS_TYPE_NONE from core.common.exceptions import Http400 from core.common.mixins import PathWalkerMixin -from core.common.search import CustomESSearch +from core.common.search import CustomESSearch, get_document_public_visibility_criteria from core.common.serializers import RootSerializer from core.common.swagger_parameters import all_resource_query_param from core.common.throttling import ThrottleUtil from core.common.utils import compact_dict_by_values, to_snake_case, parse_updated_since_param, \ to_int, get_falsy_values, get_truthy_values, format_url_for_search +from core.concepts.search import ( + get_concept_exact_search_criterion, + get_concept_fuzzy_search_criterion, + get_concept_mandatory_exclude_words_criteria, + get_concept_mandatory_words_criteria, + get_concept_search_rescore, + get_concept_wildcard_search_criterion, +) from core.concepts.permissions import CanViewParentDictionary, CanEditParentDictionary from core.orgs.constants import ORG_OBJECT_TYPE from core.users.constants import USER_OBJECT_TYPE @@ -292,6 +300,12 @@ def get_sort_attributes(self): return result def get_fuzzy_search_criterion(self, boost_divide_by=10, expansions=5): + if self.is_concept_document(): + return get_concept_fuzzy_search_criterion( + self.get_raw_search_string(), + boost_divide_by=boost_divide_by, + expansions=expansions, + ) return CustomESSearch.get_fuzzy_match_criterion( search_str=self.get_search_string(decode=False), fields=self.get_fuzzy_search_fields(), @@ -300,6 +314,11 @@ def get_fuzzy_search_criterion(self, boost_divide_by=10, expansions=5): ) def get_wildcard_search_criterion(self, search_str=None): + if self.is_concept_document(): + return get_concept_wildcard_search_criterion( + search_str or self.get_raw_search_string(), + include_map_codes=self.request.query_params.get(SEARCH_MAP_CODES_PARAM) not in get_falsy_values(), + ) fields = self.get_wildcard_search_fields() return CustomESSearch.get_wildcard_match_criterion( search_str=search_str or self.get_search_string(), @@ -307,6 +326,11 @@ def get_wildcard_search_criterion(self, search_str=None): ), fields.keys() def get_exact_search_criterion(self): + if self.is_concept_document(): + return get_concept_exact_search_criterion( + self.get_raw_search_string(), + include_map_codes=self.request.query_params.get(SEARCH_MAP_CODES_PARAM) not in get_falsy_values(), + ) match_phrase_field_list = self.document_model.get_match_phrase_attrs() match_word_fields_map = self.clean_fields(self.document_model.get_exact_match_attrs()) fields = match_phrase_field_list + list(match_word_fields_map.keys()) @@ -662,8 +686,8 @@ def is_user_scope(self): return False def get_public_criteria(self): - criteria = Q('term', public_can_view=True) user = self.request.user + criteria = Q('term', public_can_view=True) if user.is_authenticated: username = user.username @@ -671,7 +695,10 @@ def get_public_criteria(self): if self.document_model in [OrganizationDocument]: criteria |= (Q('term', public_can_view=False) & Q('term', user=username)) if self.is_concept_container_document_model() or self.is_source_child_document_model(): - criteria |= (Q('term', public_can_view=False) & Q('term', created_by=username)) + return get_document_public_visibility_criteria( + user, + include_creator_private_access=True, + ) return criteria @@ -884,42 +911,18 @@ def __get_search_results(self, ignore_retired_filter=False, sort=True, highlight sort_attrs = self._get_sort_attribute() if self.is_concept_document() and (not sort_attrs or '_score' in get(sort_attrs, '0', {})): - search_str = self.get_search_string(lower=False) - results = results.extra( - rescore={ - "window_size": 400, - "query": { - "score_mode": "total", - "query_weight": 1.0, - "rescore_query_weight": 800.0, - "rescore_query": { - "dis_max": { - "tie_breaker": 0.0, - "queries": [ - { - "constant_score": { - "filter": { "term": { "_name": { "value": search_str, "case_insensitive": True } } }, - "boost": 10.0 - } - }, - { - "constant_score": { - "filter": { "term": { "_synonyms": { "value": search_str, "case_insensitive": True } } }, - "boost": 8.0 - } - } - ] - } - } - } - } - ) + results = results.extra(rescore=get_concept_search_rescore(self.get_raw_search_string())) if fields and highlight and self.request.query_params.get(INCLUDE_SEARCH_META_PARAM) in get_truthy_values(): results = results.highlight(*self.clean_fields_for_highlight(fields)) results = results.source(excludes=['_synonyms_embeddings', '_embeddings']) return results.sort(*sort_attrs) if sort else results def get_mandatory_words_criteria(self): + if self.is_concept_document(): + return get_concept_mandatory_words_criteria( + self.get_raw_search_string(), + include_map_codes=self.request.query_params.get(SEARCH_MAP_CODES_PARAM) not in get_falsy_values(), + ) criterion = None for must_have in CustomESSearch.get_must_haves(self.get_raw_search_string()): criteria, _ = self.get_wildcard_search_criterion(f"{must_have}*") @@ -927,6 +930,11 @@ def get_mandatory_words_criteria(self): return criterion def get_mandatory_exclude_words_criteria(self): + if self.is_concept_document(): + return get_concept_mandatory_exclude_words_criteria( + self.get_raw_search_string(), + include_map_codes=self.request.query_params.get(SEARCH_MAP_CODES_PARAM) not in get_falsy_values(), + ) criterion = None for must_not_have in CustomESSearch.get_must_not_haves(self.get_raw_search_string()): criteria, _ = self.get_wildcard_search_criterion(f"{must_not_have}*") diff --git a/core/concepts/search.py b/core/concepts/search.py index eb5b23ed..f94bfcb4 100644 --- a/core/concepts/search.py +++ b/core/concepts/search.py @@ -4,8 +4,192 @@ from core.common.constants import FACET_SIZE, HEAD from core.common.search import CustomESFacetedSearch, CustomESSearch from core.common.utils import get_embeddings, is_canonical_uri +from core.concepts.documents import ConceptDocument from core.concepts.models import Concept +CONCEPT_FUZZY_BOOST_DIVIDE_BY = 10000 +CONCEPT_FUZZY_EXPANSIONS = 2 + + +def normalize_concept_search_query(query): + """Normalize raw concept search text so all callers share the same preprocessing.""" + return str(query or '').replace('"', '').replace("'", '').strip() + + +def filter_concept_search_fields(fields, include_map_codes=True): + """Optionally remove map-code fields to match the REST search toggle semantics.""" + if include_map_codes: + return fields + + if isinstance(fields, dict): + return {key: value for key, value in fields.items() if not key.endswith('map_codes')} + + return [field for field in fields if not field.endswith('map_codes')] + + +def get_concept_search_string(query, lower=True, decode=True): + """Return the normalized concept query string in the same format used by REST search.""" + return CustomESSearch.get_search_string( + normalize_concept_search_query(query), + lower=lower, + decode=decode, + ) + + +def get_concept_exact_search_criterion(query, include_map_codes=True): + """Build the exact-match clause used by both REST and GraphQL concept search.""" + match_phrase_field_list = ConceptDocument.get_match_phrase_attrs() + match_word_fields_map = filter_concept_search_fields( + ConceptDocument.get_exact_match_attrs(), + include_map_codes=include_map_codes, + ) + fields = match_phrase_field_list + list(match_word_fields_map.keys()) + return CustomESSearch.get_exact_match_criterion( + get_concept_search_string(query, lower=False, decode=False), + match_phrase_field_list, + match_word_fields_map, + ), fields + + +def get_concept_wildcard_search_criterion(query, include_map_codes=True): + """Build the wildcard clause used by both REST and GraphQL concept search.""" + fields = filter_concept_search_fields( + ConceptDocument.get_wildcard_search_attrs(), + include_map_codes=include_map_codes, + ) + return CustomESSearch.get_wildcard_match_criterion( + search_str=get_concept_search_string(query), + fields=fields, + ), list(fields.keys()) + + +def get_concept_fuzzy_search_criterion( + query, + boost_divide_by=CONCEPT_FUZZY_BOOST_DIVIDE_BY, + expansions=CONCEPT_FUZZY_EXPANSIONS, +): + """Build the fuzzy clause used by both REST and GraphQL concept search.""" + return CustomESSearch.get_fuzzy_match_criterion( + search_str=get_concept_search_string(query, decode=False), + fields=ConceptDocument.get_fuzzy_search_attrs(), + boost_divide_by=boost_divide_by, + expansions=expansions, + ) + + +def get_concept_mandatory_words_criteria(query, include_map_codes=True): + """Build the required-word wildcard clauses shared by REST and GraphQL.""" + criterion = None + for must_have in CustomESSearch.get_must_haves(normalize_concept_search_query(query)): + criteria, _ = get_concept_wildcard_search_criterion( + f"{must_have}*", + include_map_codes=include_map_codes, + ) + criterion = criteria if criterion is None else criterion & criteria + return criterion + + +def get_concept_mandatory_exclude_words_criteria(query, include_map_codes=True): + """Build the excluded-word wildcard clauses shared by REST and GraphQL.""" + criterion = None + for must_not_have in CustomESSearch.get_must_not_haves(normalize_concept_search_query(query)): + criteria, _ = get_concept_wildcard_search_criterion( + f"{must_not_have}*", + include_map_codes=include_map_codes, + ) + criterion = criteria if criterion is None else criterion | criteria + return criterion + + +def get_concept_search_rescore(query): + """Return the concept-specific ES rescore block shared by REST and GraphQL.""" + search_str = get_concept_search_string(query, lower=False) + return { + "window_size": 400, + "query": { + "score_mode": "total", + "query_weight": 1.0, + "rescore_query_weight": 800.0, + "rescore_query": { + "dis_max": { + "tie_breaker": 0.0, + "queries": [ + { + "constant_score": { + "filter": { + "term": { + "_name": { + "value": search_str, + "case_insensitive": True, + } + } + }, + "boost": 10.0, + } + }, + { + "constant_score": { + "filter": { + "term": { + "_synonyms": { + "value": search_str, + "case_insensitive": True, + } + } + }, + "boost": 8.0, + } + }, + ] + } + }, + }, + } + + +def apply_concept_text_search( + search, + query, + include_wildcard=True, + include_fuzzy=True, + include_map_codes=True, + fuzzy_boost_divide_by=CONCEPT_FUZZY_BOOST_DIVIDE_BY, + fuzzy_expansions=CONCEPT_FUZZY_EXPANSIONS, + include_rescore=False, +): + """Apply the shared concept text-search clauses to an Elasticsearch search object.""" + criterion, fields = get_concept_exact_search_criterion(query, include_map_codes=include_map_codes) + + if include_wildcard: + wildcard_criterion, wildcard_fields = get_concept_wildcard_search_criterion( + query, + include_map_codes=include_map_codes, + ) + criterion |= wildcard_criterion + fields += wildcard_fields + + if include_fuzzy: + criterion |= get_concept_fuzzy_search_criterion( + query, + boost_divide_by=fuzzy_boost_divide_by, + expansions=fuzzy_expansions, + ) + + search = search.query(criterion) + + must_have_criterion = get_concept_mandatory_words_criteria(query, include_map_codes=include_map_codes) + if must_have_criterion is not None: + search = search.filter(must_have_criterion) + + must_not_criterion = get_concept_mandatory_exclude_words_criteria(query, include_map_codes=include_map_codes) + if must_not_criterion is not None: + search = search.filter(~must_not_criterion) + + if include_rescore: + search = search.extra(rescore=get_concept_search_rescore(query)) + + return search, fields + class ConceptFacetedSearch(CustomESFacetedSearch): index = 'concepts' diff --git a/core/graphql/constants.py b/core/graphql/constants.py new file mode 100644 index 00000000..8be7f106 --- /dev/null +++ b/core/graphql/constants.py @@ -0,0 +1,40 @@ +"""Shared GraphQL error metadata used by views, resolvers, and tests.""" + +from strawberry.exceptions import GraphQLError + +AUTHENTICATION_FAILED = 'AUTHENTICATION_FAILED' +FORBIDDEN = 'FORBIDDEN' +SEARCH_UNAVAILABLE = 'SEARCH_UNAVAILABLE' + +GRAPHQL_ERROR_DEFINITIONS = { + AUTHENTICATION_FAILED: { + 'message': 'Authentication failure', + 'description': 'The provided credentials are invalid for the GraphQL API.', + }, + FORBIDDEN: { + 'message': 'Forbidden', + 'description': 'The current user cannot access the requested repository.', + }, + SEARCH_UNAVAILABLE: { + 'message': 'Search unavailable', + 'description': 'Global concept search requires Elasticsearch and is temporarily unavailable.', + }, +} +EXPECTED_GRAPHQL_ERROR_CODES = frozenset(GRAPHQL_ERROR_DEFINITIONS.keys()) + + +def build_expected_graphql_error(code): + """Return a GraphQL error with a stable code and a short client-facing description.""" + detail = GRAPHQL_ERROR_DEFINITIONS[code] + return GraphQLError( + detail['message'], + extensions={ + 'code': code, + 'description': detail['description'], + }, + ) + + +def get_graphql_error_code(error): + """Read the machine-readable error code attached to a GraphQL error when present.""" + return (getattr(error, 'extensions', None) or {}).get('code') diff --git a/core/graphql/permissions.py b/core/graphql/permissions.py new file mode 100644 index 00000000..3820f6cb --- /dev/null +++ b/core/graphql/permissions.py @@ -0,0 +1,136 @@ +"""Reusable permission helpers for GraphQL resolvers.""" + +from __future__ import annotations + +from functools import wraps +from types import SimpleNamespace +from typing import Any, Awaitable, Callable, Optional + +from asgiref.sync import sync_to_async +from django.contrib.auth.models import AnonymousUser +from strawberry.exceptions import GraphQLError + +from core.common.constants import ACCESS_TYPE_NONE +from core.common.permissions import CanViewConceptDictionary +from core.common.search import apply_document_public_visibility_filter + +from .constants import AUTHENTICATION_FAILED, FORBIDDEN, build_expected_graphql_error + +SOURCE_VERSION_CACHE_ATTR = '_graphql_source_version_cache' + + +def get_permission_target(instance, resolver): + """Return a resolver helper instance even when Strawberry passes a null root value.""" + if instance is not None: + return instance + + owner_name = resolver.__qualname__.split('.', 1)[0] + owner_class = resolver.__globals__.get(owner_name) + if owner_class is None: + raise GraphQLError('Resolver permission target is not available.') + return owner_class() + + +async def ensure_can_view_repo(user, source_version) -> None: + """Raise a GraphQL forbidden error when the repository is not visible to the user.""" + request = SimpleNamespace(user=user) + permission = CanViewConceptDictionary() + allowed = await sync_to_async( + permission.has_object_permission, + thread_sensitive=True, + )(request, None, source_version) + + if not allowed: + raise build_expected_graphql_error(FORBIDDEN) + + +def filter_global_queryset(qs, user): + """Apply the same global visibility rules used by the REST concept and mapping APIs.""" + if getattr(user, 'is_anonymous', True): + return qs.exclude(public_access=ACCESS_TYPE_NONE) + if not getattr(user, 'is_staff', False): + apply_user_criteria = getattr(qs.model, 'apply_user_criteria', None) + if apply_user_criteria: + return apply_user_criteria(qs, user) + return qs + + +def apply_es_visibility_filter(search, user): + """Mirror REST visibility rules in Elasticsearch so totals stay aligned with the DB.""" + return apply_document_public_visibility_filter( + search, + user, + include_owner_private_access=True, + include_organization_memberships=True, + ) + + +def check_user_permission( + resolver: Callable[..., Awaitable[Any]] +) -> Callable[..., Awaitable[Any]]: + """Deny repository-scoped access early while allowing global queries to continue.""" + + @wraps(resolver) + async def wrapper(self, info, *args, **kwargs): + permission_target = get_permission_target(self, resolver) + if getattr(info.context, 'auth_status', 'none') == 'invalid': + # Reject invalid credentials before repo resolution so private/public repos look the same. + raise build_expected_graphql_error(AUTHENTICATION_FAILED) + org = kwargs.get('org') + owner = kwargs.get('owner') + source = kwargs.get('source') + version = kwargs.get('version') + + if source and (org or owner): + source_version = await permission_target.get_source_version(info, org, owner, source, version) + user = getattr(info.context, 'user', AnonymousUser()) + await permission_target.ensure_can_view_repo(user, source_version) + + return await resolver(self, info, *args, **kwargs) + + return wrapper + + +class PermissionsMixin: + """Provide cached source resolution and shared permission helpers to resolvers.""" + + async def resolve_source_version_for_permissions( + self, + org: Optional[str], + owner: Optional[str], + source: str, + version: Optional[str], + ): + """Allow GraphQL query types to plug in their own source-version resolver.""" + raise NotImplementedError + + async def get_source_version( + self, + info, + org: Optional[str], + owner: Optional[str], + source: str, + version: Optional[str], + ): + """Resolve and cache the source version for the current GraphQL request.""" + cache = getattr(info.context, SOURCE_VERSION_CACHE_ATTR, None) or {} + cache_key = (org, owner, source, version) + if cache_key in cache: + return cache[cache_key] + + source_version = await self.resolve_source_version_for_permissions(org, owner, source, version) + cache[cache_key] = source_version + setattr(info.context, SOURCE_VERSION_CACHE_ATTR, cache) + return source_version + + async def ensure_can_view_repo(self, user, source_version) -> None: + """Delegate repository permission checks to the shared helper.""" + await ensure_can_view_repo(user, source_version) + + def filter_global_queryset(self, qs, user): + """Delegate global queryset visibility rules to the shared helper.""" + return filter_global_queryset(qs, user) + + def apply_es_visibility_filter(self, search, user): + """Delegate Elasticsearch visibility rules to the shared helper.""" + return apply_es_visibility_filter(search, user) diff --git a/core/graphql/queries.py b/core/graphql/queries.py index dbe2007e..74712b1e 100644 --- a/core/graphql/queries.py +++ b/core/graphql/queries.py @@ -6,19 +6,29 @@ import strawberry from asgiref.sync import sync_to_async -from django.db.models import Case, IntegerField, Prefetch, Q, When +from django.contrib.auth.models import AnonymousUser +from django.db.models import Case, F, IntegerField, Prefetch, Q, When from django.utils import timezone from elasticsearch import ConnectionError as ESConnectionError, TransportError -from elasticsearch_dsl import Q as ES_Q from pydash import get from strawberry.exceptions import GraphQLError from core.common.constants import HEAD from core.concepts.documents import ConceptDocument from core.concepts.models import Concept +from core.concepts.search import apply_concept_text_search from core.mappings.models import Mapping +from core.orgs.constants import ORG_OBJECT_TYPE from core.sources.models import Source - +from core.users.constants import USER_OBJECT_TYPE + +from .constants import SEARCH_UNAVAILABLE, build_expected_graphql_error +from .permissions import ( + PermissionsMixin, + apply_es_visibility_filter, + check_user_permission, + filter_global_queryset, +) from .types import ( CodedDatatypeDetails, ConceptNameType, @@ -67,8 +77,18 @@ class ConceptSearchResult: ) -async def resolve_source_version(org: str, source: str, version: Optional[str]) -> Source: - filters = {'organization__mnemonic': org} +async def resolve_source_version( + org: Optional[str], + owner: Optional[str], + source: str, + version: Optional[str], +) -> Source: + if org: + filters = {'organization__mnemonic': org} + elif owner: + filters = {'user__username': owner} + else: + raise GraphQLError("Either org or owner must be provided to resolve a source version.") target_version = version or HEAD instance = await sync_to_async(Source.get_version)(source, target_version, filters) @@ -76,15 +96,19 @@ async def resolve_source_version(org: str, source: str, version: Optional[str]) instance = await sync_to_async(Source.find_latest_released_version_by)({**filters, 'mnemonic': source}) if not instance: + owner_label = org or owner + owner_kind = "org" if org else "owner" raise GraphQLError( - f"Source '{source}' with version '{version or 'HEAD'}' was not found for org '{org}'." + f"Source '{source}' with version '{version or 'HEAD'}' was not found for {owner_kind} '{owner_label}'." ) return instance -def build_base_queryset(source_version: Source): - return source_version.get_concepts_queryset().filter(is_active=True, retired=False) +def build_base_queryset(source_version: Source = None): + if source_version: + return source_version.get_concepts_queryset().filter(is_active=True, retired=False) + return Concept.objects.filter(is_active=True, retired=False, id=F('versioned_object_id')) def build_mapping_prefetch(source_version: Source) -> Prefetch: @@ -103,7 +127,8 @@ def build_mapping_prefetch(source_version: Source) -> Prefetch: return Prefetch('mappings_from', queryset=mapping_qs, to_attr='graphql_mappings') -def build_global_mapping_prefetch() -> Prefetch: +def build_global_mapping_prefetch(user=None) -> Prefetch: + """Build the global mapping prefetch using the same visibility rules as REST list endpoints.""" mapping_qs = ( Mapping.objects.filter( from_concept_id__isnull=False, @@ -115,6 +140,8 @@ def build_global_mapping_prefetch() -> Prefetch: .distinct() ) + # Mapping visibility must be filtered independently because a public concept can still reference private mappings. + mapping_qs = filter_global_queryset(mapping_qs, user or AnonymousUser()) return Prefetch('mappings_from', queryset=mapping_qs, to_attr='graphql_mappings') @@ -353,6 +380,10 @@ def concept_ids_from_es( query: str, source_version: Optional[Source], pagination: Optional[dict], + owner: Optional[str] = None, + owner_type: Optional[str] = None, + version_label: Optional[str] = None, + user=None, ) -> Optional[tuple[list[int], int]]: trimmed = query.strip() if not trimmed: @@ -360,20 +391,23 @@ def concept_ids_from_es( try: search = ConceptDocument.search() + search = search.filter('term', retired=False) if source_version: - search = search.filter('term', source=source_version.mnemonic.lower()) - if source_version.is_head: + search = search.filter('term', source=source_version.mnemonic) + if owner and owner_type: + search = search.filter('term', owner=owner).filter('term', owner_type=owner_type) + + effective_version = version_label or HEAD + if effective_version == HEAD: + search = search.filter('term', source_version=HEAD) search = search.filter('term', is_latest_version=True) else: - search = search.filter('term', source_version=source_version.version) - search = search.filter('term', retired=False) + search = search.filter('term', source_version=effective_version) + else: + search = search.filter('term', is_latest_version=True) + search = apply_es_visibility_filter(search, user or AnonymousUser()) - should_queries = [ - ES_Q('match', id={'query': trimmed, 'boost': 6, 'operator': 'AND'}), - ES_Q('match_phrase_prefix', name={'query': trimmed, 'boost': 4}), - ES_Q('match', synonyms={'query': trimmed, 'boost': 2, 'operator': 'AND'}), - ] - search = search.query(ES_Q('bool', should=should_queries, minimum_should_match=1)) + search, _ = apply_concept_text_search(search, trimmed, include_rescore=True) if pagination: search = search[pagination['start']:pagination['end']] @@ -393,67 +427,79 @@ def concept_ids_from_es( return None -def fallback_db_search(base_qs, query: str): - trimmed = query.strip() - if not trimmed: - return base_qs.none() - return base_qs.filter( - Q(mnemonic__icontains=trimmed) | Q(names__name__icontains=trimmed) - ).distinct() - - async def concepts_for_ids( base_qs, concept_ids: Sequence[str], pagination: Optional[dict], mapping_prefetch: Prefetch, ) -> tuple[List[Concept], int]: - unique_ids = list(dict.fromkeys([cid for cid in concept_ids if cid])) - if not unique_ids: - raise GraphQLError('conceptIds must include at least one value when provided.') + """Fetch concepts by mnemonic while preserving the client-provided ordering.""" + ordered_ids = list(dict.fromkeys(concept_id for concept_id in concept_ids if concept_id)) + if not ordered_ids: + raise GraphQLError('conceptIds must contain at least one value.') - qs = base_qs.filter(mnemonic__in=unique_ids) - total = await sync_to_async(qs.count)() ordering = Case( - *[When(mnemonic=value, then=pos) for pos, value in enumerate(unique_ids)], - output_field=IntegerField() + *[When(mnemonic=concept_id, then=pos) for pos, concept_id in enumerate(ordered_ids)], + output_field=IntegerField(), ) - qs = qs.order_by(ordering, 'mnemonic') + qs = base_qs.filter(mnemonic__in=ordered_ids).order_by(ordering) + total = await sync_to_async(qs.count)() qs = apply_slice(qs, pagination) qs = with_concept_related(qs, mapping_prefetch) return await sync_to_async(list)(qs), total +def build_db_search_queryset(base_qs, query: str): + """Build the database fallback used when Elasticsearch is unavailable or stale.""" + trimmed = query.strip() + if not trimmed: + return base_qs.none() + + return base_qs.filter( + Q(names__name__icontains=trimmed) | Q(descriptions__name__icontains=trimmed) + ).distinct() + + async def concepts_for_query( base_qs, query: str, source_version: Source, pagination: Optional[dict], mapping_prefetch: Prefetch, + owner: Optional[str] = None, + owner_type: Optional[str] = None, + version_label: Optional[str] = None, + user=None, ) -> tuple[List[Concept], int]: - es_result = await sync_to_async(concept_ids_from_es)(query, source_version, pagination) + es_result = await sync_to_async(concept_ids_from_es)( + query, + source_version, + pagination, + owner=owner, + owner_type=owner_type, + version_label=version_label, + user=user, + ) if es_result is not None: concept_ids, total = es_result - if not concept_ids: - if total == 0: - logger.info( - 'ES returned zero hits for query="%s" in source "%s" version "%s". Falling back to DB search.', - query, - get(source_version, 'mnemonic'), - get(source_version, 'version'), - ) - else: - return [], total - else: + if concept_ids: ordering = Case( *[When(id=pk, then=pos) for pos, pk in enumerate(concept_ids)], output_field=IntegerField() ) qs = base_qs.filter(id__in=concept_ids).order_by(ordering) qs = with_concept_related(qs, mapping_prefetch) - return await sync_to_async(list)(qs), total + concepts = await sync_to_async(list)(qs) + if len(concepts) == len(concept_ids): + return concepts, total + elif total > 0: + return [], total - qs = fallback_db_search(base_qs, query).order_by('mnemonic') + if source_version is None: + # Global search is ES-backed because the DB fallback is both broader and much more expensive under outage. + raise build_expected_graphql_error(SEARCH_UNAVAILABLE) + + qs = build_db_search_queryset(base_qs, query).order_by('mnemonic') total = await sync_to_async(qs.count)() qs = apply_slice(qs, pagination) qs = with_concept_related(qs, mapping_prefetch) @@ -461,12 +507,24 @@ async def concepts_for_query( @strawberry.type -class Query: +class Query(PermissionsMixin): + async def resolve_source_version_for_permissions( + self, + org: Optional[str], + owner: Optional[str], + source: str, + version: Optional[str], + ) -> Source: + """Resolve repository versions through the shared GraphQL helper.""" + return await resolve_source_version(org, owner, source, version) + @strawberry.field(name="concepts") + @check_user_permission async def concepts( # pylint: disable=too-many-arguments,too-many-locals self, info, # pylint: disable=unused-argument org: Optional[str] = None, + owner: Optional[str] = None, source: Optional[str] = None, version: Optional[str] = None, conceptIds: Optional[List[str]] = None, @@ -474,29 +532,34 @@ async def concepts( # pylint: disable=too-many-arguments,too-many-locals page: Optional[int] = None, limit: Optional[int] = None, ) -> ConceptSearchResult: - if info.context.auth_status == 'none': - raise GraphQLError('Authentication required') - - if info.context.auth_status == 'invalid': - raise GraphQLError('Authentication failure') - + permission_target = self or Query() concept_ids_param = conceptIds or [] text_query = (query or '').strip() + user = getattr(info.context, 'user', AnonymousUser()) if not concept_ids_param and not text_query: raise GraphQLError('Either conceptIds or query must be provided.') pagination = normalize_pagination(page, limit) - if org and source: - source_version = await resolve_source_version(org, source, version) + if org and owner: + raise GraphQLError('Provide either org or owner, not both.') + + if source and not org and not owner: + raise GraphQLError('Either org or owner must be provided when source is specified.') + + owner_value = org or owner + owner_type = ORG_OBJECT_TYPE if org else (USER_OBJECT_TYPE if owner else None) + + if (org or owner) and source: + source_version = await permission_target.get_source_version(info, org, owner, source, version) base_qs = build_base_queryset(source_version) mapping_prefetch = build_mapping_prefetch(source_version) else: # Global search across all repositories source_version = None - base_qs = Concept.objects.filter(is_active=True, retired=False) - mapping_prefetch = build_global_mapping_prefetch() + base_qs = permission_target.filter_global_queryset(build_base_queryset(), user) + mapping_prefetch = build_global_mapping_prefetch(user) if concept_ids_param: concepts, total = await concepts_for_ids(base_qs, concept_ids_param, pagination, mapping_prefetch) @@ -507,6 +570,10 @@ async def concepts( # pylint: disable=too-many-arguments,too-many-locals source_version, pagination, mapping_prefetch, + owner=owner_value, + owner_type=owner_type, + version_label=version or HEAD if source_version else None, + user=user, ) serialized = await sync_to_async(serialize_concepts)(concepts) diff --git a/core/graphql/schema.py b/core/graphql/schema.py index 70874634..f91259ce 100644 --- a/core/graphql/schema.py +++ b/core/graphql/schema.py @@ -1,9 +1,21 @@ import strawberry from strawberry_django.optimizer import DjangoOptimizerExtension +from .constants import EXPECTED_GRAPHQL_ERROR_CODES, get_graphql_error_code from .queries import Query -schema = strawberry.Schema( + +class OCLGraphQLSchema(strawberry.Schema): + def process_errors(self, errors, execution_context=None): + # Expected business-rule failures should reach clients, but they should not be recorded as server errors. + unexpected_errors = [ + error for error in errors if get_graphql_error_code(error) not in EXPECTED_GRAPHQL_ERROR_CODES + ] + if unexpected_errors: + super().process_errors(unexpected_errors, execution_context) + + +schema = OCLGraphQLSchema( query=Query, extensions=[DjangoOptimizerExtension], ) diff --git a/core/graphql/tests/test_concepts_from_source.py b/core/graphql/tests/test_concepts_from_source.py index c4caf401..c972edf6 100644 --- a/core/graphql/tests/test_concepts_from_source.py +++ b/core/graphql/tests/test_concepts_from_source.py @@ -438,7 +438,9 @@ def test_fetch_concepts_for_specific_version(self): self.assertEqual(payload['versionResolved'], self.release_version.version) self.assertEqual(payload['results'][0]['conceptId'], self.concept1.mnemonic) - def test_fetch_concepts_global_search(self): + @mock.patch('core.graphql.queries.concept_ids_from_es') + def test_fetch_concepts_global_search(self, mock_es): + mock_es.return_value = ([self.concept1.id], 1) query = """ query GlobalConcepts($query: String!) { concepts(query: $query) { diff --git a/core/graphql/tests/test_graphql_view.py b/core/graphql/tests/test_graphql_view.py index 510f1b5c..0bd02964 100644 --- a/core/graphql/tests/test_graphql_view.py +++ b/core/graphql/tests/test_graphql_view.py @@ -9,6 +9,7 @@ from rest_framework.exceptions import AuthenticationFailed from core.common.tests import OCLTestCase +from core.graphql.constants import AUTHENTICATION_FAILED, SEARCH_UNAVAILABLE from core.graphql.tests.conftest import bootstrap_super_user, create_user_with_token @@ -94,7 +95,7 @@ def authenticate(self, request): token_type='Bearer', authentication_backend_class=ModelBackend, ) - ): + ), patch('strawberry.schema.base.StrawberryLogger.error') as error_logger: response = self._post_graphql( headers={"HTTP_AUTHORIZATION": "Bearer invalid-oidc-token"}, query=query @@ -104,3 +105,19 @@ def authenticate(self, request): self.assertEqual(response.status_code, 200) self.assertIn('errors', payload) self.assertIn('Authentication failure', payload['errors'][0]['message']) + self.assertEqual(payload['errors'][0]['extensions']['code'], AUTHENTICATION_FAILED) + error_logger.assert_not_called() + + @patch('core.graphql.queries.concept_ids_from_es', return_value=None) + def test_global_search_returns_explicit_error_when_es_is_unavailable(self, _mock_es): + headers = {"HTTP_AUTHORIZATION": f"Token {self.token.key}"} + query = "query { concepts(query:\"test\") { totalCount } }" + + with patch('strawberry.schema.base.StrawberryLogger.error') as error_logger: + response = self._post_graphql(headers=headers, query=query) + + payload = response.json() + self.assertEqual(response.status_code, 200) + self.assertIn('errors', payload) + self.assertEqual(payload['errors'][0]['extensions']['code'], SEARCH_UNAVAILABLE) + error_logger.assert_not_called() diff --git a/core/graphql/tests/test_query_helpers.py b/core/graphql/tests/test_query_helpers.py index 911e31ca..dcb90d07 100644 --- a/core/graphql/tests/test_query_helpers.py +++ b/core/graphql/tests/test_query_helpers.py @@ -11,7 +11,7 @@ from strawberry.django.views import AsyncGraphQLView from strawberry.exceptions import GraphQLError -from core.common.constants import HEAD +from core.common.constants import ACCESS_TYPE_NONE, ACCESS_TYPE_VIEW, HEAD from core.common.tests import OCLTestCase from core.concepts.models import Concept from core.concepts.tests.factories import ( @@ -31,7 +31,6 @@ concept_ids_from_es, concepts_for_ids, concepts_for_query, - fallback_db_search, format_datetime_for_api, has_next, normalize_pagination, @@ -47,6 +46,12 @@ serialize_names, with_concept_related, ) +from core.graphql.constants import ( + AUTHENTICATION_FAILED, + FORBIDDEN, + SEARCH_UNAVAILABLE, + build_expected_graphql_error, +) from core.graphql.schema import schema from core.graphql.tests.conftest import bootstrap_super_user, create_user_with_token from core.graphql.views import AuthenticatedGraphQLView @@ -54,6 +59,7 @@ from core.orgs.tests.factories import OrganizationFactory from core.sources.models import Source from core.sources.tests.factories import OrganizationSourceFactory +from core.users.tests.factories import UserProfileFactory class AuthenticatedGraphQLViewTests(OCLTestCase): @@ -128,6 +134,15 @@ def test_get_context_handles_session_and_token_states(self): self.assertEqual(context.user, self.user) self.assertEqual(context.auth_status, 'valid') + def test_schema_process_errors_skips_expected_business_errors(self): + with patch('strawberry.schema.base.StrawberryLogger.error') as error_logger: + schema.process_errors([build_expected_graphql_error(AUTHENTICATION_FAILED)]) + error_logger.assert_not_called() + + with patch('strawberry.schema.base.StrawberryLogger.error') as error_logger: + schema.process_errors([GraphQLError('unexpected boom')]) + error_logger.assert_called_once() + class QueryHelperTests(OCLTestCase): maxDiff = None @@ -221,7 +236,7 @@ def test_resolve_source_version_and_base_queries(self): ) with patch('core.graphql.queries.Source.get_version', return_value=self.source): success = async_to_sync(resolve_source_version)( - self.organization.mnemonic, self.source.mnemonic, None + self.organization.mnemonic, None, self.source.mnemonic, None ) self.assertEqual(success, self.source) @@ -229,12 +244,12 @@ def test_resolve_source_version_and_base_queries(self): 'core.graphql.queries.Source.find_latest_released_version_by', return_value=fallback_only ): resolved = async_to_sync(resolve_source_version)( - self.organization.mnemonic, fallback_only.mnemonic, None + self.organization.mnemonic, None, fallback_only.mnemonic, None ) self.assertEqual(resolved, fallback_only) with self.assertRaises(GraphQLError): async_to_sync(resolve_source_version)( - self.organization.mnemonic, 'missing-source', 'v-does-not-exist' + self.organization.mnemonic, None, 'missing-source', 'v-does-not-exist' ) base_qs = build_base_queryset(self.source) @@ -253,12 +268,41 @@ def test_resolve_source_version_and_base_queries(self): related_qs = with_concept_related(base_qs, mapping_prefetch) self.assertGreaterEqual(related_qs.count(), 2) + def test_build_global_mapping_prefetch_filters_private_mappings(self): + private_mapping = MappingFactory( + parent=self.source, + from_concept=self.concept1, + to_concept=self.concept2, + public_access=ACCESS_TYPE_NONE, + created_by=self.audit_user, + updated_by=self.audit_user, + ) + anonymous_qs = with_concept_related( + build_base_queryset(), + build_global_mapping_prefetch(AnonymousUser()), + ).filter(id=self.concept1.id) + anonymous_concept = list(anonymous_qs)[0] + self.assertTrue(all(mapping.public_access != ACCESS_TYPE_NONE for mapping in anonymous_concept.graphql_mappings)) + + member = UserProfileFactory( + username='graphql-mapping-member', + created_by=self.super_user, + updated_by=self.super_user, + ) + self.organization.members.add(member) + member_qs = with_concept_related( + build_base_queryset(), + build_global_mapping_prefetch(member), + ).filter(id=self.concept1.id) + member_concept = list(member_qs)[0] + self.assertTrue(any(mapping.id == private_mapping.id for mapping in member_concept.graphql_mappings)) + def test_resolve_source_version_error_path_and_pagination_defaults(self): with patch('core.graphql.queries.Source.get_version', return_value=None), patch( 'core.graphql.queries.Source.find_latest_released_version_by', return_value=None ): with self.assertRaises(GraphQLError): - async_to_sync(resolve_source_version)('ORG', 'SRC', None) + async_to_sync(resolve_source_version)('ORG', None, 'SRC', None) self.assertIsNone(normalize_pagination(None, None)) self.assertFalse(has_next(10, None)) @@ -398,6 +442,9 @@ def __getitem__(self, key): def params(self, **_kwargs): return self + def extra(self, **_kwargs): + return self + def execute(self): return FakeResponse(self._items, self._total) @@ -412,11 +459,49 @@ def execute(self): with patch('core.graphql.queries.ConceptDocument.search', side_effect=Exception('boom')): self.assertIsNone(concept_ids_from_es('text', self.source, None)) - def test_fallback_and_concepts_queries(self): - base_qs = build_base_queryset(self.source) - self.assertEqual(fallback_db_search(base_qs, ' ').count(), 0) - self.assertIn(self.concept1.id, list(fallback_db_search(base_qs, 'UTIL').values_list('id', flat=True))) + def test_concept_ids_from_es_applies_global_visibility_filter(self): + class RecordingResponse: + def __init__(self): + self.hits = SimpleNamespace(total=SimpleNamespace(value=0)) + + def __iter__(self): + return iter([]) + class RecordingSearch: + def __init__(self): + self.filters = [] + + def filter(self, *args, **kwargs): + self.filters.append((args, kwargs)) + return self + + def query(self, *_args, **_kwargs): + return self + + def __getitem__(self, _key): + return self + + def params(self, **_kwargs): + return self + + def extra(self, **_kwargs): + return self + + def execute(self): + return RecordingResponse() + + anonymous_search = RecordingSearch() + with patch('core.graphql.queries.ConceptDocument.search', return_value=anonymous_search): + concept_ids_from_es('shared', None, None, user=AnonymousUser()) + self.assertTrue( + any( + len(args) == 1 and not kwargs and 'public_can_view' in str(args[0]) + for args, kwargs in anonymous_search.filters + ) + ) + + def test_concepts_queries_behavior(self): + base_qs = build_base_queryset(self.source) mapping_prefetch = build_mapping_prefetch(self.source) with self.assertRaises(GraphQLError): async_to_sync(concepts_for_ids)(base_qs, [], normalize_pagination(1, 1), mapping_prefetch) @@ -441,7 +526,20 @@ def test_fallback_and_concepts_queries(self): concepts, total = async_to_sync(concepts_for_query)( base_qs, 'UTIL', self.source, normalize_pagination(1, 1), mapping_prefetch ) - self.assertGreaterEqual(total, 1) + self.assertEqual(total, 0) + self.assertEqual(concepts, []) + + with patch('core.graphql.queries.concept_ids_from_es', return_value=None): + with self.assertRaises(GraphQLError) as unavailable: + async_to_sync(concepts_for_query)( + build_base_queryset(), + 'UTIL', + None, + normalize_pagination(1, 1), + build_global_mapping_prefetch(AnonymousUser()), + user=AnonymousUser(), + ) + self.assertEqual(unavailable.exception.extensions['code'], SEARCH_UNAVAILABLE) with patch('core.graphql.queries.concept_ids_from_es', return_value=([], 2)): concepts, total = async_to_sync(concepts_for_query)( @@ -451,15 +549,16 @@ def test_fallback_and_concepts_queries(self): self.assertEqual(concepts, []) def test_query_concepts_auth_and_results(self): - info_none = SimpleNamespace(context=SimpleNamespace(auth_status='none')) + info_none = SimpleNamespace(context=SimpleNamespace(auth_status='none', user=AnonymousUser())) with self.assertRaises(GraphQLError): async_to_sync(Query().concepts)(info_none) - info_invalid = SimpleNamespace(context=SimpleNamespace(auth_status='invalid')) - with self.assertRaises(GraphQLError): + info_invalid = SimpleNamespace(context=SimpleNamespace(auth_status='invalid', user=AnonymousUser())) + with self.assertRaises(GraphQLError) as invalid: async_to_sync(Query().concepts)(info_invalid, query='test') + self.assertEqual(invalid.exception.extensions['code'], AUTHENTICATION_FAILED) - info_valid = SimpleNamespace(context=SimpleNamespace(auth_status='valid')) + info_valid = SimpleNamespace(context=SimpleNamespace(auth_status='valid', user=self.audit_user)) with self.assertRaises(GraphQLError): async_to_sync(Query().concepts)(info_valid) @@ -478,12 +577,12 @@ def test_query_concepts_auth_and_results(self): self.assertEqual(result_ids.limit, 1) with patch('core.graphql.queries.concept_ids_from_es', return_value=None): - result_query = async_to_sync(Query().concepts)( - info_valid, - query='UTIL', - ) - self.assertGreaterEqual(result_query.total_count, 1) - self.assertFalse(result_query.has_next_page) + with self.assertRaises(GraphQLError) as unavailable: + async_to_sync(Query().concepts)( + info_valid, + query='UTIL', + ) + self.assertEqual(unavailable.exception.extensions['code'], SEARCH_UNAVAILABLE) with patch('core.graphql.queries.concept_ids_from_es', return_value=([], 2)), patch( 'core.graphql.queries.resolve_source_version', return_value=self.source @@ -492,7 +591,139 @@ def test_query_concepts_auth_and_results(self): self.assertEqual(result_es_empty.total_count, 2) self.assertEqual(result_es_empty.results, []) - with patch('core.graphql.queries.resolve_source_version', return_value=self.source): - result_global = async_to_sync(Query().concepts)(info_valid, query='UTIL') + with patch('core.graphql.queries.concept_ids_from_es', return_value=([self.concept1.id], 1)): + result_global = async_to_sync(Query().concepts)( + info_valid, + query='UTIL', + ) self.assertIsNone(result_global.org) self.assertIsNone(result_global.source) + + def test_query_concepts_enforces_repo_permissions_and_filters_global_results(self): + private_org = OrganizationFactory( + mnemonic='PRIVATE', + created_by=self.super_user, + updated_by=self.super_user, + ) + private_source = OrganizationSourceFactory( + organization=private_org, + mnemonic='PRIVATE-SRC', + public_access=ACCESS_TYPE_NONE, + created_by=self.super_user, + updated_by=self.super_user, + ) + private_concept = ConceptFactory( + parent=private_source, + mnemonic='PRIVATE-CONCEPT', + public_access=ACCESS_TYPE_NONE, + created_by=self.audit_user, + updated_by=self.audit_user, + ) + ConceptNameFactory( + concept=private_concept, + name='Shared Visibility', + locale='en', + locale_preferred=True, + ) + + public_org = OrganizationFactory( + mnemonic='PUBLIC', + created_by=self.super_user, + updated_by=self.super_user, + ) + public_source = OrganizationSourceFactory( + organization=public_org, + mnemonic='PUBLIC-SRC', + public_access=ACCESS_TYPE_VIEW, + created_by=self.super_user, + updated_by=self.super_user, + ) + public_concept = ConceptFactory( + parent=public_source, + mnemonic='PUBLIC-CONCEPT', + public_access=ACCESS_TYPE_VIEW, + created_by=self.audit_user, + updated_by=self.audit_user, + ) + ConceptNameFactory( + concept=public_concept, + name='Shared Visibility', + locale='en', + locale_preferred=True, + ) + + outsider = UserProfileFactory( + username='graphql-outsider', + created_by=self.super_user, + updated_by=self.super_user, + ) + member = UserProfileFactory( + username='graphql-member', + created_by=self.super_user, + updated_by=self.super_user, + ) + private_org.members.add(member) + + anonymous_info = SimpleNamespace(context=SimpleNamespace(auth_status='none', user=AnonymousUser())) + outsider_info = SimpleNamespace(context=SimpleNamespace(auth_status='valid', user=outsider)) + member_info = SimpleNamespace(context=SimpleNamespace(auth_status='valid', user=member)) + invalid_info = SimpleNamespace(context=SimpleNamespace(auth_status='invalid', user=AnonymousUser())) + + with self.assertRaises(GraphQLError) as forbidden: + async_to_sync(Query().concepts)( + outsider_info, + org=private_org.mnemonic, + source=private_source.mnemonic, + conceptIds=[private_concept.mnemonic], + ) + self.assertEqual(str(forbidden.exception), 'Forbidden') + self.assertEqual(forbidden.exception.extensions['code'], FORBIDDEN) + + with self.assertRaises(GraphQLError) as invalid_private: + async_to_sync(Query().concepts)( + invalid_info, + org=private_org.mnemonic, + source=private_source.mnemonic, + conceptIds=[private_concept.mnemonic], + ) + self.assertEqual(str(invalid_private.exception), 'Authentication failure') + self.assertEqual(invalid_private.exception.extensions['code'], AUTHENTICATION_FAILED) + + with self.assertRaises(GraphQLError) as invalid_public: + async_to_sync(Query().concepts)( + invalid_info, + org=public_org.mnemonic, + source=public_source.mnemonic, + conceptIds=[public_concept.mnemonic], + ) + self.assertEqual(str(invalid_public.exception), 'Authentication failure') + self.assertEqual(invalid_public.exception.extensions['code'], AUTHENTICATION_FAILED) + + public_repo_result = async_to_sync(Query().concepts)( + anonymous_info, + org=public_org.mnemonic, + source=public_source.mnemonic, + conceptIds=[public_concept.mnemonic], + ) + self.assertEqual(public_repo_result.total_count, 1) + self.assertEqual(public_repo_result.results[0].concept_id, public_concept.mnemonic) + + with patch( + 'core.graphql.queries.concept_ids_from_es', + return_value=([public_concept.id], 1), + ): + anonymous_global = async_to_sync(Query().concepts)(anonymous_info, query='Shared Visibility') + self.assertEqual( + [concept.concept_id for concept in anonymous_global.results], + [public_concept.mnemonic], + ) + + with patch( + 'core.graphql.queries.concept_ids_from_es', + return_value=([private_concept.id, public_concept.id], 2), + ): + member_global = async_to_sync(Query().concepts)(member_info, query='Shared Visibility') + self.assertEqual( + {concept.concept_id for concept in member_global.results}, + {private_concept.mnemonic, public_concept.mnemonic}, + ) diff --git a/core/graphql/views.py b/core/graphql/views.py index eb1c8313..fc138fad 100644 --- a/core/graphql/views.py +++ b/core/graphql/views.py @@ -12,6 +12,7 @@ from django.middleware.csrf import CsrfViewMiddleware, get_token from django.utils.decorators import method_decorator from django.views.decorators.csrf import csrf_exempt +from graphql import ExecutionResult from rest_framework.authentication import get_authorization_header from rest_framework.exceptions import AuthenticationFailed from rest_framework.request import Request @@ -20,6 +21,8 @@ from core.common.authentication import OCLAuthentication from core.users.constants import GRAPHQL_API_GROUP +from .constants import AUTHENTICATION_FAILED, build_expected_graphql_error + # https://strawberry.rocks/docs/breaking-changes/0.243.0 GraphQL Strawberry needs manually handling CSRF @method_decorator(csrf_exempt, name='dispatch') @@ -104,3 +107,13 @@ def make_invalid(auth_status='invalid'): context.auth_status = 'valid' if getattr(user, 'is_authenticated', False) else 'invalid' return context + + async def execute_operation(self, request, context, root_value, sub_response): + # Invalid credentials should become a normal GraphQL error payload before resolver execution starts. + if getattr(context, 'auth_status', 'none') == 'invalid': + return ExecutionResult( + data=None, + errors=[build_expected_graphql_error(AUTHENTICATION_FAILED)], + ) + + return await super().execute_operation(request, context, root_value, sub_response)