diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ba9c15eb..2a251e69 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,28 +4,13 @@ on: push jobs: check: + runs-on: ubuntu-latest + strategy: matrix: - # Testing Python 3.7 (deb 10), Python 3.9 (deb 11), Python 3.11 (deb 12) - python-version: ["3.9", "3.11"] - django-version: ["3.2.25", "4.2.17", "5.1.4"] - database-engine: ["postgres", "mysql"] - os: [ubuntu-latest] - include: - # 3.7 cannot run on latest ubuntu - - python-version: 3.7 - django-version: 3.2.25 - database-engine: postgres - os: ubuntu-22.04 - - python-version: 3.7 - django-version: 3.2.25 - database-engine: mysql - os: ubuntu-22.04 - exclude: - - python-version: 3.9 - django-version: 5.1.4 - - runs-on: ${{ matrix.os }} + python-version: [ "3.7", "3.8" ] + django-version: [ "2.1.1", "3.1.4" ] + database-engine: [ "postgres", "mysql", "mssql"] services: postgres: @@ -52,19 +37,35 @@ jobs: ports: - 3306:3306 + + mssqldb: + image: mcr.microsoft.com/mssql/server:2017-latest + env: + ACCEPT_EULA: y + SA_PASSWORD: Test + + steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v2 - name: Setup python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} + - name: Retrieve cached venv + uses: actions/cache@v1 + id: cache-venv + with: + path: ./.venv/ + key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.django-version }}-venv-${{ hashFiles('ci-requirements.txt') }} + - name: Install requirements run: | python -m venv .venv - .venv/bin/pip install django==${{ matrix.django-version }} -r ci-requirements.txt + .venv/bin/pip install -qr ci-requirements.txt django==${{ matrix.django-version }} + if: steps.cache-venv.outputs.cache-hit != 'true' - name: Run linting run: .venv/bin/flake8 binder @@ -82,7 +83,7 @@ jobs: - name: Run tests run: | - .venv/bin/coverage run --include="binder/*" -m 
unittest discover -vt . -s tests + .venv/bin/coverage run --include="binder/*" setup.py test env: BINDER_TEST_MYSQL: ${{ matrix.database-engine == 'mysql' && 1 || 0 }} CY_RUNNING_INSIDE_CI: 1 diff --git a/binder/exceptions.py b/binder/exceptions.py index d5aec299..75e4e572 100644 --- a/binder/exceptions.py +++ b/binder/exceptions.py @@ -235,3 +235,15 @@ def __add__(self, other): else: errors[model] = other.errors[model] return BinderValidationError(errors) + + +class BinderSkipSave(BinderException): + """Used to abort the database transaction when validation was successful. + Validation is possible when saving (post, put, multi-put) or deleting models.""" + + http_code = 200 + code = 'SkipSave' + + def __init__(self): + super().__init__() + self.fields['message'] = 'No validation errors were encountered.' diff --git a/binder/models.py b/binder/models.py index 685694e0..df7525c2 100644 --- a/binder/models.py +++ b/binder/models.py @@ -617,8 +617,12 @@ class Meta: abstract = True ordering = ['pk'] - def save(self, *args, **kwargs): - self.full_clean() # Never allow saving invalid models! + def save(self, *args, only_validate=False, **kwargs): + # A validation model might not require all validation checks as it is not a full model + # _validation_model can be used to skip validation checks that are meant for complete models that are actually being saved + self._validation_model = only_validate # Set the model as a validation model when we only want to validate the model + + self.full_clean() # Never allow saving invalid models! 
return super().save(*args, **kwargs) diff --git a/binder/plugins/views/csvexport.py b/binder/plugins/views/csvexport.py index 1afd1794..f9ed6d26 100644 --- a/binder/plugins/views/csvexport.py +++ b/binder/plugins/views/csvexport.py @@ -15,7 +15,8 @@ class ExportFileAdapter: """ __metaclass__ = abc.ABCMeta - def __init__(self, request: HttpRequest): + def __init__(self, request: HttpRequest, csv_settings: 'CsvExportView.CsvExportSettings'): + self.csv_settings = csv_settings self.request = request @abc.abstractmethod @@ -69,8 +70,8 @@ class CsvFileAdapter(ExportFileAdapter): Adapter for returning CSV files """ - def __init__(self, request: HttpRequest): - super().__init__(request) + def __init__(self, request: HttpRequest, csv_settings: 'CsvExportView.CsvExportSettings'): + super().__init__(request, csv_settings) self.response = HttpResponse(content_type='text/csv') self.file_name = 'export' self.writer = csv.writer(self.response) @@ -79,7 +80,7 @@ def set_file_name(self, file_name: str): self.file_name = file_name def set_columns(self, columns: List[str]): - self.add_row(columns) + self.writer.writerow(list(map(lambda x: x[1], self.csv_settings.column_map))) def add_row(self, values: List[str]): self.writer.writerow(values) @@ -91,10 +92,10 @@ def get_response(self) -> HttpResponse: class ExcelFileAdapter(ExportFileAdapter): """ - Adapter for returning excel files + Adapter for returning excel files """ - def __init__(self, request: HttpRequest): - super().__init__(request) + def __init__(self, request: HttpRequest, csv_settings: 'CsvExportView.CsvExportSettings'): + super().__init__(request, csv_settings) # Import pandas locally. 
This means that you can use the CSV adapter without using pandas import openpyxl @@ -103,7 +104,7 @@ def __init__(self, request: HttpRequest): # self.writer = self.pandas.ExcelWriter(self.response) self.work_book = self.openpyxl.Workbook() - self.sheet = self.work_book.active + self.sheet = self.work_book._sheets[0] # The row number we are currently writing to self._row_number = 0 @@ -115,7 +116,7 @@ def set_columns(self, columns: List[str]): self.add_row(columns) def add_row(self, values: List[str]): - for (column_id, value) in enumerate(values): + for (value, column_id) in zip(values, range(1000000)): self.sheet.cell(column=column_id + 1, row=self._row_number + 1, value=value) self._row_number += 1 @@ -129,9 +130,6 @@ def get_response(self) -> HttpResponse: self.response['Content-Disposition'] = 'attachment; filename="{}.xlsx"'.format(self.file_name) return self.response -DEFAULT_RESPONSE_TYPE_MAPPING = { - 'xlsx': ExcelFileAdapter, -} class RequestAwareAdapter(ExportFileAdapter): """ @@ -141,14 +139,14 @@ class RequestAwareAdapter(ExportFileAdapter): returns a xlsx type """ - def __init__(self, request: HttpRequest): - super().__init__(request) + def __init__(self, request: HttpRequest, csv_settings: 'CsvExportView.CsvExportSettings'): + super().__init__(request, csv_settings) - response_type_mapping = DEFAULT_RESPONSE_TYPE_MAPPING response_type = request.GET.get('response_type', '').lower() - AdapterClass = response_type_mapping.get(response_type, CsvFileAdapter) - - self.base_adapter = AdapterClass(request) + AdapterClass = CsvFileAdapter + if response_type == 'xlsx': + AdapterClass = ExcelFileAdapter + self.base_adapter = AdapterClass(request, csv_settings) def set_file_name(self, file_name: str): return self.base_adapter.set_file_name(file_name) @@ -163,8 +161,6 @@ def get_response(self) -> HttpResponse: return self.base_adapter.get_response() - - class CsvExportView: """ This class adds another endpoint to the ModelView, namely GET model/download/. 
This does the same thing as getting a @@ -182,7 +178,7 @@ class CsvExportSettings: """ def __init__(self, withs, column_map, file_name=None, default_file_name='download', multi_value_delimiter=' ', - extra_permission=None, extra_params={}, csv_adapter=RequestAwareAdapter, limit=10000): + extra_permission=None, csv_adapter=RequestAwareAdapter): """ @param withs: String[] An array of all the withs that are necessary for this csv export @param column_map: Tuple[] An array, with all columns of the csv file in order. Each column is represented by a tuple @@ -194,9 +190,6 @@ def __init__(self, withs, column_map, file_name=None, default_file_name='downloa as delimiter between them. This may be if an array is returned, or if we have a one to many relation @param extra_permission: String When set, an extra binder permission check will be done on this permission. @param csv_adapter: Class. Either an object extending - @param response_type_mapping: Mapping between the parameter used in the custom response type - @param limit: Limit for amount of items in the csv. 
This is a fail save that you do not bring down the server with - a big query """ self.withs = withs self.column_map = column_map @@ -204,10 +197,7 @@ def __init__(self, withs, column_map, file_name=None, default_file_name='downloa self.default_file_name = default_file_name self.multi_value_delimiter = multi_value_delimiter self.extra_permission = extra_permission - self.extra_params = extra_params self.csv_adapter = csv_adapter - self.limit = limit - def _generate_csv_file(self, request: HttpRequest, file_adapter: CsvFileAdapter): @@ -220,10 +210,8 @@ def _generate_csv_file(self, request: HttpRequest, file_adapter: CsvFileAdapter) mutable = request.POST._mutable request.GET._mutable = True request.GET['page'] = 1 - request.GET['limit'] = self.csv_settings.limit if self.csv_settings.limit is not None else 'none' + request.GET['limit'] = 10000 request.GET['with'] = ",".join(self.csv_settings.withs) - for key, value in self.csv_settings.extra_params.items(): - request.GET[key] = value request.GET._mutable = mutable parent_result = self.get(request) @@ -269,6 +257,8 @@ def get_datum(data, key, prefix=''): if '.' not in key: if key not in data: raise Exception("{} not found in data: {}".format(key, data)) + if type(data[key]) == list: + return self.csv_settings.multi_value_delimiter.join(data[key]) return data[key] else: """ @@ -284,17 +274,28 @@ def get_datum(data, key, prefix=''): head_key, subkey = key.split('.', 1) if head_key in data: new_prefix = '{}.{}'.format(prefix, head_key) - if isinstance(data[head_key], dict): + if type(data[head_key]) == dict: return get_datum(data[head_key], subkey, new_prefix) else: # Assume that we have a mapping now fk_ids = data[head_key] - if not isinstance(fk_ids, list): + + if fk_ids is None: + # This case happens if we have a nullable foreign key that is null. Treat this as a many + # to one relation with no values. 
+ fk_ids = [] + elif type(fk_ids) != list: fk_ids = [fk_ids] # if head_key not in key_mapping: prefix_key = parent_data['with_mapping'][new_prefix[1:]] - datums = [str(get_datum(key_mapping[prefix_key][fk_id], subkey, new_prefix)) for fk_id in fk_ids] + datums = [] + for fk_id in fk_ids: + try: + datums.append(str(get_datum(key_mapping[prefix_key][fk_id], subkey, new_prefix))) + except KeyError: + pass + # datums = [str(get_datum(key_mapping[prefix_key][fk_id], subkey, new_prefix)) for fk_id in fk_ids] return self.csv_settings.multi_value_delimiter.join( datums ) @@ -309,8 +310,6 @@ def get_datum(data, key, prefix=''): if len(col_definition) >= 3: transform_function = col_definition[2] datum = transform_function(datum, row, key_mapping) - if isinstance(datum, list): - datum = self.csv_settings.multi_value_delimiter.join(datum) data.append(datum) file_adapter.add_row(data) @@ -325,7 +324,7 @@ def download(self, request): if self.csv_settings is None: raise Exception('No csv settings set!') - file_adapter = self.csv_settings.csv_adapter(request) + file_adapter = self.csv_settings.csv_adapter(request, self.csv_settings) self._generate_csv_file(request, file_adapter) diff --git a/binder/views.py b/binder/views.py index 5285e529..04922aa4 100644 --- a/binder/views.py +++ b/binder/views.py @@ -14,14 +14,13 @@ import django from django.views.generic import View -from django.conf import settings from django.core.exceptions import ObjectDoesNotExist, FieldError, ValidationError, FieldDoesNotExist from django.core.files.base import File, ContentFile -from django.http import HttpResponse, HttpResponseForbidden, FileResponse -from django.http.request import RawPostDataException +from django.http import HttpResponse, StreamingHttpResponse, HttpResponseForbidden +from django.http.request import RawPostDataException, QueryDict from django.http.multipartparser import MultiPartParser from django.db import models, connections -from django.db.models import Q, F, Count, Case, When 
+from django.db.models import Q, F from django.db.models.lookups import Transform from django.utils import timezone from django.db import transaction @@ -29,72 +28,15 @@ from django.db.models.fields.reverse_related import ForeignObjectRel -from .exceptions import BinderException, BinderFieldTypeError, BinderFileSizeExceeded, BinderForbidden, BinderImageError, BinderImageSizeExceeded, BinderInvalidField, BinderIsDeleted, BinderIsNotDeleted, BinderMethodNotAllowed, BinderNotAuthenticated, BinderNotFound, BinderReadOnlyFieldError, BinderRequestError, BinderValidationError, BinderFileTypeIncorrect, BinderInvalidURI +from .exceptions import ( + BinderException, BinderFieldTypeError, BinderFileSizeExceeded, BinderForbidden, BinderImageError, BinderImageSizeExceeded, + BinderInvalidField, BinderIsDeleted, BinderIsNotDeleted, BinderMethodNotAllowed, BinderNotAuthenticated, BinderNotFound, + BinderReadOnlyFieldError, BinderRequestError, BinderValidationError, BinderFileTypeIncorrect, BinderInvalidURI, BinderSkipSave +) from . 
import history from .orderable_agg import OrderableArrayAgg, GroupConcat, StringAgg from .models import FieldFilter, BinderModel, ContextAnnotation, OptionalAnnotation, BinderFileField, BinderImageField -from .json import JsonResponse, jsonloads, jsondumps -from .route_decorators import list_route - - -# expr: an aggregate expr to get the statistic, -# filter: a dict of filters to filter the queryset with before getting the aggregate, leading dot not included (optional), -# group_by: a field to group by separated by dots if following relations (optional), -# annotations: a list of annotation names that have to be applied to the queryset for the expr to work (optional), -Stat = namedtuple( - 'Stat', - ['expr', 'filters', 'group_by', 'annotations', 'min_value', 'max_values'], - defaults=[{}, None, [], None, None], -) - - -DEFAULT_STATS = { - 'total_records': Stat(Count(Value(1))), -} - - -def get_joins_from_queryset(queryset): - """ - Given a queryset returns a set of lines that are used to determine which - tables will be joined and how. In essence this is the FROM-statement and - every JOIN-statement in a set as a string. - - This is useful to compare the joins between querysets. - """ - # So to generate sql we need the compiler and connection for the right db - compiler = queryset.query.get_compiler(queryset.db) - connection = connections[queryset.db] - # Now we will just go through all tables in the alias_map - lines = set() - for alias in queryset.query.alias_map.values(): - line, params = alias.as_sql(compiler, connection) - # We assert we have no params for now, you need to do custom stuff for - # these to appear in joins and substituting params into the sql - # is not something we can easily do safely for now, just passing them - # along with the str will make use potentially have unhashable lines - # which will ruin the set. 
- assert not params - lines.add(line) - return lines - - -def q_get_flat_filters(q): - """ - Given a Q-object returns an iterator of all filters used in this Q-object. - - So for example for Q(foo=1, bar=2) this would yield 'foo' and 'bar', but it - will also work for more complicated nested Q-objects. - - This is useful to detect which fields are used in a Q-object. - """ - for child in q.children: - if isinstance(child, Q): - # If the child is another Q-object we can just yield recursively - yield from q_get_flat_filters(child) - elif isinstance(child, tuple): - # So now the child is a 2-tuple of filter & value, we just need the - # filter so we yield that - yield child[0] +from .json import JsonResponse, jsonloads def split_par_aware(content): @@ -342,7 +284,6 @@ def prefix_q_expression(value, prefix, antiprefix=None, model=None): children.append((prefix + '__' + child[0], child[1])) return Q(*children, _connector=value.connector, _negated=value.negated) - class ModelView(View): # Model this is a view for. Use None for views not tied to a particular model. model = None @@ -392,6 +333,9 @@ class ModelView(View): # NOTE: custom _store__foo() methods will still be called for unupdatable fields. unupdatable_fields = [] + # Allow validation without saving. + allow_standalone_validation = False + # Fields to use for ?search=foo. Empty tuple for disabled search. # NOTE: only string fields and 'id' are supported. # id is hardcoded to be treated as an integer. 
@@ -442,21 +386,6 @@ class ModelView(View): # } virtual_relations = {} - # A dict that looks like: - # {'filter_name': [ - # 'field.icontains', - # 'otherField.field:icontains' - # ], - # } - # Allows you to hide multiple filters behind a filter name - # see _parse_filter, see api.md > "Filtering by groups (alternative_filters)" - # NOTICE: alternative_filters may not contain a field or annotation as key - alternative_filters = {} - - # A dict that maps stat name to instances of binder.views.Stat - # These statistics can then be used in the stats view - stats = {} - @property def AggStrategy(self): if connections[self.model.objects.db].vendor == 'mysql': @@ -513,13 +442,17 @@ def dispatch(self, request, *args, **kwargs): response = None try: + # only allow standalone validation if you know what you are doing + if 'validate' in request.GET and request.GET['validate'] == 'true' and not self.allow_standalone_validation: + raise BinderRequestError('Standalone validation not enabled. You must enable this feature explicitly.') + #### START TRANSACTION with ExitStack() as stack, history.atomic(source='http', user=request.user, uuid=request.request_id): transaction_dbs = ['default'] # Check if the TRANSACTION_DATABASES is set in the settings.py, and if so, use that instead try: - transaction_dbs = settings.TRANSACTION_DATABASES + transaction_dbs = django.conf.settings.TRANSACTION_DATABASES except AttributeError: pass @@ -610,7 +543,7 @@ def _get_reverse_relations(self): # Kinda like model_to_dict() for multiple objects. # Return a list of dictionaries, one per object in the queryset. # Includes a list of ids for all m2m fields (including reverse relations). 
- def _get_objs(self, queryset, request, annotations=None, to_annotate={}): + def _get_objs(self, queryset, request, annotations=None): datas = [] datas_by_id = {} # Save datas so we can annotate m2m fields later (avoiding a query) objs_by_id = {} # Same for original objects @@ -628,52 +561,6 @@ def _get_objs(self, queryset, request, annotations=None, to_annotate={}): else: annotations &= set(self.shown_annotations) - # So now annotations are only being used for showing, so we filter out - # all that do not have to be shown - to_annotate = { - key: value - for key, value in to_annotate.items() - if key in annotations - } - - # So now we will divide annotations based on the joins they do - base_joins = get_joins_from_queryset(queryset) - annotation_sets = [] - - for name, expr in list(to_annotate.items()): - annotation_joins = get_joins_from_queryset( - self.model.objects.annotate(**{name: expr}) - ) - annotation_annotations = {name: expr} - # First check if the queryset already does all joins, in that case - # we can just add it to the main queryset without any performance - # hits - if annotation_joins <= base_joins: - queryset = queryset.annotate(**annotation_annotations) - to_annotate.pop(name) - continue - # Then try to merge it into the annotation sets - i = 0 - while i < len(annotation_sets): - set_joins, set_annotations = annotation_sets[i] - # If our joins are a subset of the annotation set we just add - # our annotation to the set and break - if annotation_joins <= set_joins: - set_annotations.update(annotation_annotations) - break - # If our joins are a superset of the annotation set we take its - # annotations and add it to ours - elif set_joins <= annotation_joins: - annotation_annotations.update(set_annotations) - annotation_sets.pop(i) - # Go on to the next - else: - i += 1 - # If no annotation set existed that matched our joins we create a - # new one - else: - annotation_sets.append((annotation_joins, annotation_annotations)) - for obj in queryset: 
# So we tend to make binder call queryset.distinct when necessary # to prevent duplicate results, this is however not always possible @@ -703,8 +590,7 @@ def _get_objs(self, queryset, request, annotations=None, to_annotate={}): data[f.name] = getattr(obj, f.attname) for a in annotations: - if a not in to_annotate: - data[a] = getattr(obj, a) + data[a] = getattr(obj, a) for prop in self.shown_properties: data[prop] = getattr(obj, prop) @@ -716,18 +602,6 @@ def _get_objs(self, queryset, request, annotations=None, to_annotate={}): datas_by_id[obj.pk] = data objs_by_id[obj.pk] = obj - for _, set_annotations in annotation_sets: - for set_values in ( - self.model.objects - .filter(pk__in=datas_by_id) - .annotate(**set_annotations) - .values('pk', *set_annotations) - ): - pk_ = set_values.pop('pk') - for name, value in set_values.items(): - datas_by_id[pk_][name] = value - setattr(objs_by_id[pk_], name, value) - self._annotate_objs(datas_by_id, objs_by_id) return datas @@ -751,7 +625,7 @@ def _annotate_objs(self, datas_by_id, objs_by_id): for obj_id, data in datas_by_id.items(): # TODO: Don't require OneToOneFields in the m2m_fields list if isinstance(local_field, models.OneToOneRel): - assert len(idmap[obj_id]) <= 1 + assert(len(idmap[obj_id]) <= 1) data[field_name] = idmap[obj_id][0] if len(idmap[obj_id]) == 1 else None else: data[field_name] = idmap[obj_id] @@ -767,15 +641,10 @@ def _annotate_objs(self, datas_by_id, objs_by_id): def _get_obj(self, pk, request, include_annotations=None): if include_annotations is None: include_annotations = self._parse_include_annotations(request) - annotations = include_annotations.get('') results = self._get_objs( - self.get_queryset(request).filter(pk=pk), + annotate(self.get_queryset(request).filter(pk=pk), request, include_annotations.get('')), request=request, - annotations=annotations, - to_annotate={ - name: value['expr'] - for name, value in get_annotations(self.model, request, annotations).items() - }, + 
annotations=include_annotations.get(''), ) if results: return results[0] @@ -947,13 +816,9 @@ def withs_to_nested_set(withs, result={}): view.router = self.router for annotations, with_pks in annotation_ids.items(): objs = view._get_objs( - view.get_queryset(request).filter(pk__in=with_pks), + annotate(view.get_queryset(request).filter(pk__in=with_pks), request, annotations), request=request, annotations=annotations, - to_annotate={ - name: value['expr'] - for name, value in get_annotations(view.model, request, annotations).items() - }, ) for obj in objs: view._annotate_obj_with_related_withs(obj, withs_per_model[model_name]) @@ -1191,53 +1056,10 @@ def _filter_relation(self, field_name, where_map, request, include_annotations): return FilterDescription(q, need_distinct) - # handle self.alternative_filters as part of _parse_filter - def _parse_alternative_filters(self, field, filter_heads, *args, **kwargs): - head, tail = re.fullmatch(r'([^.:]*)(.*)', field).groups() - - # split not/any/all of the tail, by default a filter is treated as any - # NOTE: not is_any ==> is_all - is_not = bool(re.search(r':not\b', tail)) - - is_any = bool(re.match(r':any\b', tail)) - if is_any: - tail = tail[4:] - is_all = bool(re.match(r':all\b', tail)) - if is_all and not is_any: - # if both is_any and is_all are true, we want the filter to fail - tail = tail[4:] - - alts = [] - for head in filter_heads: - field = head + tail - alt = self._parse_filter(field, *args, **kwargs) - alts.append(alt) - q, needs_distinct = alts[0] - for q_, needs_distinct_ in alts[1:]: - if is_not == is_all: - # :any (default) - # :not:all (NOTE that not is maintained inside the filter) - q |= q_ - else: - # :all - # :not:any (NOTE that not is maintained inside the filter) - q &= q_ - needs_distinct = needs_distinct or needs_distinct_ - return FilterDescription(q, needs_distinct) - def _parse_filter(self, field, value, request, include_annotations, partial=''): head, *tail = field.split('.') need_distinct = 
False - alt_head = head.split(':')[0] - try: - # get head without trailing :sorter - filter_heads = self.alternative_filters[alt_head] - except KeyError: - pass - else: - return self._parse_alternative_filters(field, filter_heads, value, request, include_annotations, partial) - if not tail: invert = False try: @@ -1349,9 +1171,10 @@ def _parse_order_by(self, queryset, field, request, partial=''): return (queryset, partial + head, nulls_last) - def _search_base(self, search, request): + + def search(self, queryset, search, request): if not search: - return ~Q(pk__in=[]) + return queryset if not (self.searches or self.transformed_searches): raise BinderRequestError('No search fields defined for this view.') @@ -1365,12 +1188,7 @@ def _search_base(self, search, request): q |= Q(**{s: transform(search)}) except ValueError: pass - - return q - - - def search(self, queryset, search, request): - return queryset.filter(self._search_base(search, request)) + return queryset.filter(q) def filter_deleted(self, queryset, pk, deleted, request): @@ -1431,23 +1249,13 @@ def get_queryset(self, request): - def _order_by_base(self, queryset, request, annotations): + def order_by(self, queryset, request): #### order_by order_bys = list(filter(None, request.GET.get('order_by', '').split(','))) orders = [] if order_bys: for o in order_bys: - # We split of a leading - (descending sorting) and the - # suffixes nulls_last and nulls_first - head = re.match(r'^-?(.*?)(__nulls_last|__nulls_first)?$', o).group(1) - try: - expr = annotations.pop(head) - except KeyError: - pass - else: - queryset = queryset.annotate(**{head: expr}) - if o.startswith('-'): queryset, order, nulls_last = self._parse_order_by(queryset, o[1:], request, partial='-') else: @@ -1483,29 +1291,16 @@ def _order_by_base(self, queryset, request, annotations): return queryset - def order_by(self, queryset, request): - return self._order_by_base(queryset, request, {}) - - def _annotate_obj_with_related_withs(self, obj, 
field_results): for (w, (view, ids_dict, is_singular)) in field_results.items(): if '.' not in w: if is_singular: try: obj[w] = list(ids_dict[obj['id']])[0] - except (IndexError): - """ - Indexerror => no relation is found - KeyERror => No relation is known for this model. - """ + except IndexError: obj[w] = None - except KeyError: - pass else: - try: - obj[w] = list(ids_dict[obj['id']]) - except KeyError: - pass + obj[w] = list(ids_dict[obj['id']]) def _generate_meta(self, include_meta, queryset, request, pk=None): @@ -1515,111 +1310,16 @@ def _generate_meta(self, include_meta, queryset, request, pk=None): # Only 'pk' values should reduce DB server memory a (little?) bit, making # things faster. Not prefetching related models here makes it faster still. # See also https://code.djangoproject.com/ticket/23771 and related tickets. - meta['total_records'] = queryset.order_by().prefetch_related(None).values('pk').count() + meta['total_records'] = queryset.prefetch_related(None).values('pk').count() return meta - def _apply_q_with_possible_annotations(self, queryset, q, annotations): - for filter in q_get_flat_filters(q): - head = filter.split('__', 1)[0] - try: - expr = annotations.pop(head) - except KeyError: - pass - else: - queryset = queryset.annotate(**{head: expr}) - - return queryset.filter(q) - - - def _after_expr(self, request, after_id, include_annotations): + def get_filtered_queryset(self, request, pk=None, include_annotations=None): """ - This method given a request and an id returns a boolean expression that - indicates if a record would show up after the provided id for the - ordering specified by this request. + Returns a scoped queryset with filtering and sorting applied as + specified by the request. 
""" - queryset = self.get_queryset(request) - annotations = { - name: value['expr'] - for name, value in self.annotations(request, include_annotations).items() - } - - # We do an order by on a copy of annotations so that we see which keys - # it pops - annotations_copy = annotations.copy() - ordering = self._order_by_base(queryset, request, annotations_copy).query.order_by - required_annotations = set(annotations) - set(annotations_copy) - - queryset = queryset.annotate(**{name: annotations[name] for name in required_annotations}) - try: - obj = queryset.get(pk=int(after_id)) - except (ValueError, self.model.DoesNotExist): - raise BinderRequestError(f'invalid value for after_id: {after_id!r}') - - # Now we will build up a comparison expr based on the order by - whens = [] - - for field in ordering: - # Fields are generally strings, except in some edge case, where it is an OrderBy expression. - # In thase case, we need to change it back to starting (sigh) for compatability reasons - - if type(field) is not str: - _field = field - - field = field.expression.name - - if _field.descending: - field = f'-{field}' - if _field.nulls_first: - field = f'{field}__nulls_first' - if _field.nulls_last: - field = f'{field}__nulls_last' - - # First we have to split of a leading '-' as indicating reverse - reverse = field.startswith('-') - if reverse: - field = field[1:] - - # Then we determine if nulls come last - if field.endswith('__nulls_last'): - field = field[:-12] - nulls_last = True - elif field.endswith('__nulls_first'): - field = field[:-13] - nulls_last = False - elif connections[self.model.objects.db].vendor == 'mysql': - # In MySQL null is considered to be the lowest possible value for ordering - nulls_last = reverse - else: - # In other databases null is considered to be the highest possible value for ordering - nulls_last = not reverse - - # Then we determine what the value is for the obj we need to be after - value = obj - for attr in field.split('__'): - value = 
getattr(value, attr, None) - if isinstance(value, models.Model): - value = value.pk - - # Now we add some conditions for the comparison - if value is None: - # If the value is None, that means we have to add a condition for when the field is not None because only then it is different - # What the result should be in that case is determined by nulls last - whens.append(When(Q(**{field + '__isnull': False}), then=Value(not nulls_last))) - else: - # If the field is None we give a result based on nulls last - whens.append(When(Q(**{field: None}), then=Value(nulls_last))) - # Otherwise we check with comparisons, note that equality is intentionally left open with these two options so in that case we go on to the next field - whens.append(When(Q(**{field + '__lt': value}), then=Value(reverse))) - whens.append(When(Q(**{field + '__gt': value}), then=Value(not reverse))) - - expr = Case(*whens, default=Value(False)) - - return expr, required_annotations - - - def _get_filtered_queryset_base(self, request, pk=None, include_annotations=None): queryset = self.get_queryset(request) if pk: queryset = queryset.filter(pk=int(pk)) @@ -1636,80 +1336,43 @@ def _get_filtered_queryset_base(self, request, pk=None, include_annotations=None if include_annotations is None: include_annotations = self._parse_include_annotations(request) - annotations = { - name: value['expr'] - for name, value in self.annotations(request, include_annotations).items() - } + queryset = annotate(queryset, request, include_annotations.get('')) #### filters filters = {k.lstrip('.'): v for k, v in request.GET.lists() if k.startswith('.')} for field, values in filters.items(): - for v in values: q, distinct = self._parse_filter(field, v, request, include_annotations) - queryset = self._apply_q_with_possible_annotations(queryset, q, annotations) + queryset = queryset.filter(q) if distinct: queryset = queryset.distinct() #### search if 'search' in request.GET: - q = self._search_base(request.GET['search'], request) - 
queryset = self._apply_q_with_possible_annotations(queryset, q, annotations) - - #### after - try: - after = request.GET['after'] - except KeyError: - pass - else: - after_expr, required_annotations = self._after_expr(request, after, include_annotations) - for name in required_annotations: - try: - expr = annotations.pop(name) - except KeyError: - pass - else: - queryset = queryset.annotate(**{name: expr}) - queryset = queryset.filter(after_expr) - - return queryset, annotations + queryset = self.search(queryset, request.GET['search'], request) - def get_filtered_queryset(self, request, *args, **kwargs): - """ - Returns a scoped queryset with filtering and sorting applied as - specified by the request. - """ - queryset, annotations = self._get_filtered_queryset_base(request, *args, **kwargs) - queryset = queryset.annotate(**annotations) queryset = self.order_by(queryset, request) + return queryset + def get(self, request, pk=None, withs=None, include_annotations=None): include_meta = request.GET.get('include_meta', 'total_records').split(',') if include_annotations is None: include_annotations = self._parse_include_annotations(request) - queryset, annotations = self._get_filtered_queryset_base(request, pk, include_annotations) + queryset = self.get_filtered_queryset(request, pk, include_annotations) meta = self._generate_meta(include_meta, queryset, request, pk) - queryset = self._order_by_base(queryset, request, annotations) queryset = self._paginate(queryset, request) - # We fetch the data with only the currently applied annotations - data = self._get_objs( - queryset, - request=request, - annotations=include_annotations.get(''), - to_annotate=annotations, - ) - - # Now we add all remaining annotations to this data - data_by_pk = {obj['id']: obj for obj in data} - pks = set(data_by_pk) - #### with # parse wheres from request + data = self._get_objs(queryset, request=request, annotations=include_annotations.get('')) + + pks = [obj['id'] for obj in data] + extras, 
extras_mapping, extras_reverse_mapping, field_results = self._get_withs(pks, withs, request=request, include_annotations=include_annotations) for obj in data: @@ -1725,7 +1388,7 @@ def get(self, request, pk=None, withs=None, include_annotations=None): meta['comment'] = self.comment debug = {'request_id': request.request_id} - if settings.DEBUG and 'debug' in request.GET: + if django.conf.settings.DEBUG and 'debug' in request.GET: debug['queries'] = ['{}s: {}'.format(q['time'], q['sql'].replace('"', '')) for q in django.db.connection.queries] debug['query_count'] = len(django.db.connection.queries) @@ -1775,6 +1438,18 @@ def binder_validation_error(self, obj, validation_error, pk=None): }) + + def _abort_when_standalone_validation(self, request): + """Raise a `BinderSkipSave` exception when this is a validation request.""" + if 'validate' in request.GET and request.GET['validate'] == 'true': + if self.allow_standalone_validation: + params = QueryDict(request.body) + raise BinderSkipSave + else: + raise BinderRequestError('Standalone validation not enabled. You must enable this feature explicitly.') + + + # Deserialize JSON to Django Model objects. # obj: Model object to update (for PUT), newly created object (for POST) # values: Python dict of {field name: value} (parsed JSON) @@ -1784,6 +1459,9 @@ def _store(self, obj, values, request, ignore_unknown_fields=False, pk=None): ignored_fields = [] validation_errors = [] + # When only validating and not saving we attach a parameter so that we can skip or add validation checks + only_validate = request.GET.get('validate') == 'true' or request.GET.get('validate') == 'True' + if obj.pk is None: self._require_model_perm('add', request, obj.pk) else: @@ -1821,8 +1499,8 @@ def store_m2m_field(obj, field, value, request): raise sum(validation_errors, None) try: - obj.save() - assert obj.pk is not None # At this point, the object must have been created. 
+ obj.save(only_validate=only_validate) + assert(obj.pk is not None) # At this point, the object must have been created. except ValidationError as ve: validation_errors.append(self.binder_validation_error(obj, ve, pk=pk)) @@ -1852,15 +1530,10 @@ def store_m2m_field(obj, field, value, request): # Permission checks are done at this point, so we can avoid get_queryset() include_annotations = self._parse_include_annotations(request) - annotations = include_annotations.get('') data = self._get_objs( - self.model.objects.filter(pk=obj.pk), + annotate(self.model.objects.filter(pk=obj.pk), request, include_annotations.get('')), request=request, - annotations=annotations, - to_annotate={ - name: value['expr'] - for name, value in get_annotations(self.model, request, annotations).items() - }, + annotations=include_annotations.get(''), )[0] data['_meta'] = {'ignored_fields': ignored_fields} return data @@ -1872,6 +1545,9 @@ def store_m2m_field(obj, field, value, request): def _store_m2m_field(self, obj, field, value, request): validation_errors = [] + # When only validating and not saving we attach a parameter so that we can skip or add validation checks + only_validate = request.GET.get('validate') == 'true' or request.GET.get('validate') == 'True' + # Can't use isinstance() because apparantly ManyToManyDescriptor is a subclass of # ReverseManyToOneDescriptor. Yes, really. 
if getattr(obj._meta.model, field).__class__ == models.fields.related.ReverseManyToOneDescriptor: @@ -1914,11 +1590,11 @@ def _store_m2m_field(self, obj, field, value, request): for addobj in obj_field.model.objects.filter(id__in=new_ids - old_ids): setattr(addobj, obj_field.field.name, obj) try: - addobj.save() + addobj.save(only_validate=only_validate) except ValidationError as ve: validation_errors.append(self.binder_validation_error(addobj, ve)) else: - addobj.save() + addobj.save(only_validate=only_validate) elif getattr(obj._meta.model, field).__class__ == models.fields.related.ReverseOneToOneDescriptor: #### XXX FIXME XXX ugly quick fix for reverse relation + multiput issue if any(v for v in value if v is not None and v < 0): @@ -1934,7 +1610,7 @@ def _store_m2m_field(self, obj, field, value, request): remote_obj = field_descriptor.related.remote_field.model.objects.get(pk=value[0]) setattr(remote_obj, field_descriptor.related.remote_field.name, obj) try: - remote_obj.save() + remote_obj.save(only_validate=only_validate) remote_obj.refresh_from_db() except ValidationError as ve: validation_errors.append(self.binder_validation_error(remote_obj, ve)) @@ -2023,7 +1699,7 @@ def _clean_image_file(self, field, value): # Resize images that are too large. 
if width > max_width or height > max_height: - img.thumbnail((max_width, max_height), Image.LANCZOS) + img.thumbnail((max_width, max_height), Image.ANTIALIAS) logger.info('image dimensions ({}x{}) exceeded ({}, {}), resizing.'.format(width, height, max_width, max_height)) if img.mode not in ["1", "L", "P", "RGB", "RGBA"]: img = img.convert("RGB") @@ -2359,17 +2035,8 @@ def _multi_put_override_superclass(self, objects): # Collect overrides for (cls, mid), data in objects.items(): for subcls in getsubclasses(cls): - # Get remote field of the subclass - remote_field = subcls._meta.pk.remote_field - - # In some scenarios with proxy models - # The remote field may not exist - # Because proxy models are just pure python wrappers(without its own db table) for other models - if remote_field is None: - continue - # Get key of field pointing to subclass - subkey = remote_field.name + subkey = subcls._meta.pk.remote_field.name # Get id of subclass subid = data.pop(subkey, None) if subid is None: @@ -2527,7 +2194,7 @@ def _multi_put_save_objects(self, ordered_objects, objects, request): # corresponding view here? That would make it # more consistent with non-multi-PUT and POST, # also requiring view permissions. - qs = model.objects.filter(pk__in=oids).order_by().select_for_update() + qs = model.objects.filter(pk__in=oids).select_for_update() for obj in qs: locked_objects[(model, obj.pk)] = obj @@ -2640,7 +2307,7 @@ def _multi_put_deletions(self, deletions, new_id_map, request): def multi_put(self, request): - logger.info('ACTIVATING THE MULTI-PUT!!!1!') + logger.info('ACTIVATING THE MULTI-PUT!!!!!') # Hack to communicate to _store() that we're not interested in # the new data (for perf reasons). 
@@ -2648,13 +2315,15 @@ def multi_put(self, request): data, deletions = self._multi_put_parse_request(request) objects = self._multi_put_collect_objects(data) - objects, overrides = self._multi_put_override_superclass(objects) + objects, overrides = self._multi_put_override_superclass(objects) # model inheritance objects = self._multi_put_convert_backref_to_forwardref(objects) dependencies = self._multi_put_calculate_dependencies(objects) ordered_objects = self._multi_put_order_dependencies(dependencies) - new_id_map = self._multi_put_save_objects(ordered_objects, objects, request) - self._multi_put_id_map_add_overrides(new_id_map, overrides) - new_id_map = self._multi_put_deletions(deletions, new_id_map, request) + new_id_map = self._multi_put_save_objects(ordered_objects, objects, request) # may raise validation errors + self._multi_put_id_map_add_overrides(new_id_map, overrides) # model inheritance + new_id_map = self._multi_put_deletions(deletions, new_id_map, request) # may raise validation errors + + self._abort_when_standalone_validation(request) output = defaultdict(list) for (model, oid), nid in new_id_map.items(): @@ -2732,24 +2401,13 @@ def put(self, request, pk=None): values = self._get_request_values(request) try: - # Step one does the permission check. We cannot do a select for update here, since the queryeset - # of get_queryset can potentially have outer joins on nullable values - obj = self.get_queryset(request).get(pk=int(pk)) - - # Now that we know we have access to this moel, we can get it again, this time with lock. 
- obj = self.model.objects.select_for_update().get(pk=obj.pk) - + obj = self.get_queryset(request).select_for_update().get(pk=int(pk)) # Permission checks are done at this point, so we can avoid get_queryset() include_annotations = self._parse_include_annotations(request) - annotations = include_annotations.get('') old = self._get_objs( - self.model.objects.filter(pk=int(pk)), - request=request, - annotations=annotations, - to_annotate={ - name: value['expr'] - for name, value in get_annotations(self.model, request, annotations).items() - }, + annotate(self.model.objects.filter(pk=int(pk)), request, include_annotations.get('')), + request, + include_annotations.get(''), )[0] except ObjectDoesNotExist: raise BinderNotFound() @@ -2759,6 +2417,8 @@ def put(self, request, pk=None): data = self._store(obj, values, request) + self._abort_when_standalone_validation(request) + new = dict(data) new.pop('_meta', None) @@ -2789,6 +2449,8 @@ def post(self, request, pk=None): data = self._store(self.model(), values, request) + self._abort_when_standalone_validation(request) + new = dict(data) new.pop('_meta', None) @@ -2820,16 +2482,15 @@ def delete(self, request, pk=None, undelete=False, skip_body_check=False): except ValueError: pass - # This make sure we do all permission checks. We cannot do a select for update here, since there is a possiblity - # that we create a queryset that we cannot use in a for select clause. try: - obj = self.get_queryset(request).get(pk=int(pk)) + obj = self.get_queryset(request).select_for_update().get(pk=int(pk)) except ObjectDoesNotExist: raise BinderNotFound() - # We now retrieve the model again, without the permission checks, which we already this. 
- obj = self.model.objects.select_for_update().get(pk=obj.pk) self.delete_obj(obj, undelete, request) + + self._abort_when_standalone_validation(request) + logger.info('{}DELETEd {} #{}'.format('UN' if undelete else '', self._model_name(), pk)) return HttpResponse(status=204) # No content @@ -2842,6 +2503,10 @@ def delete_obj(self, obj, undelete, request): def soft_delete(self, obj, undelete, request): + + # When only validating and not saving we attach a parameter so that we can skip or add validation checks + only_validate = request.GET.get('validate') == 'true' or request.GET.get('validate') == 'True' + # Not only for soft delets, actually handles all deletions try: if obj.deleted and not undelete: @@ -2873,7 +2538,7 @@ def soft_delete(self, obj, undelete, request): obj.deleted = not undelete try: - obj.save() + obj.save(only_validate=only_validate) except ValidationError as ve: raise self.binder_validation_error(obj, ve) @@ -2894,33 +2559,17 @@ def dispatch_file_field(self, request, pk=None, file_field=None): file_field_name = file_field file_field = getattr(obj, file_field_name) - field = self.model._meta.get_field(file_field_name) if request.method == 'GET': if not file_field: raise BinderNotFound(file_field_name) - guess = mimetypes.guess_type(file_field.name) - content_type = (guess and guess[0]) or 'application/octet-stream' - serve_directly = isinstance(field, BinderFileField) and field.serve_directly - - is_video = content_type and content_type.startswith('video/') - + guess = mimetypes.guess_type(file_field.path) + guess = guess[0] if guess and guess[0] else 'application/octet-stream' try: - if serve_directly: - resp = HttpResponse(content_type=content_type) - resp[settings.INTERNAL_MEDIA_HEADER] = os.path.join(settings.INTERNAL_MEDIA_LOCATION, file_field.name) - # if the filefield does not start with '/' it is likely a http address (S3) instead of path - # and because of that we set the redirect url - if not file_field.url.startswith('/'): - 
resp['redirect_url'] = file_field.url - else: - file_handle = file_field.open('rb') - resp = FileResponse(file_handle, content_type=content_type) - if is_video: - resp['Accept-Ranges'] = 'bytes' + resp = StreamingHttpResponse(open(file_field.path, 'rb'), content_type=guess) except FileNotFoundError: - logger.error('Expected file {} not found'.format(file_field.name)) + logger.error('Expected file {} not found'.format(file_field.path)) raise BinderNotFound(file_field_name) if 'download' in request.GET: @@ -2972,7 +2621,7 @@ def filefield_get_name(self, instance=None, request=None, file_field=None): try: method = getattr(self, 'filefield_get_name_' + file_field.field.name) except AttributeError: - return os.path.basename(file_field.name) + return os.path.basename(file_field.path) return method(instance=instance, request=request, file_field=file_field) @@ -2983,7 +2632,7 @@ def view_history(self, request, pk=None, **kwargs): debug = kwargs['history'] == 'debug' - if debug and not settings.ENABLE_DEBUG_ENDPOINTS: + if debug and not django.conf.settings.ENABLE_DEBUG_ENDPOINTS: logger.warning('Debug endpoints disabled.') return HttpResponseForbidden('Debug endpoints disabled.') @@ -2991,131 +2640,7 @@ def view_history(self, request, pk=None, **kwargs): if debug: return history.view_changesets_debug(request, changesets.order_by('-id')) else: - return history.view_changesets(request, changesets.order_by('-id'), self.model, pk) - - - @list_route('stats', methods=['GET']) - def stats_view(self, request): - # We only apply annotations when used, so we can just pretend everything is included to simplify stuff - try: - annotations = self.model.Annotations - except AttributeError: - include_annotations = {'': []} - else: - include_annotations = {'': [ - attr - for attr in dir(annotations) - if not (attr.startswith('__') and attr.endswith('__')) - ]} - - queryset, annotations = self._get_filtered_queryset_base(request, None, include_annotations) - - try: - stats = 
request.GET['stats'] - except KeyError: - stats = [] - else: - stats = stats.split(',') - - return JsonResponse({ - stat: self._get_stat(request, queryset, stat, annotations.copy(), include_annotations) - for stat in stats - }) - - - def _get_stat(self, request, queryset, stat, annotations, include_annotations): - # NOTE: uses annotations! If called multiple times, provide a copy - # Get stat definition - try: - stat = self.stats[stat] - except KeyError: - try: - stat = DEFAULT_STATS[stat] - except KeyError: - raise BinderRequestError(f'unknown stat: {stat}') - - # Apply filters - for key, value in stat.filters.items(): - q, distinct = self._parse_filter(key, value, request, include_annotations) - queryset = self._apply_q_with_possible_annotations(queryset, q, annotations) - if distinct: - queryset = queryset.distinct() - - # Apply required annotations - for key in stat.annotations: - try: - expr = annotations.pop(key) - except KeyError: - pass - else: - queryset = queryset.annotate(**{key: expr}) - - if stat.group_by is None: - # No group by so just return a simple stat - return { - 'value': queryset.aggregate(result=stat.expr)['result'], - 'filters': stat.filters, - } - - group_by = stat.group_by.replace('.', '__') - - value = { - # The jsonloads/jsondumps is to make sure we can handle different - # types as keys, an example is dates. 
- jsonloads(jsondumps(key)): value - for key, value in ( - queryset - .order_by() - .exclude(**{group_by: None}) - .values(group_by) - .annotate(_binder_stat=stat.expr) - .values_list(group_by, '_binder_stat') - ) - } - - other = 0 - if stat.min_value is not None: - min_value = stat.min_value * sum(value.values()) - new_value = {} - - others = 0 - for key, sub_value in value.items(): - if sub_value >= min_value: - new_value[key] = sub_value - else: - other += sub_value - others += 1 - - if others > 1: - value = new_value - else: - other = 0 - - elif stat.max_values is not None: - keys = sorted(value, key=lambda key: value[key], reverse=True) - for key in keys[stat.max_values:]: - other += value.pop(key) - - return { - 'value': value, - 'other': other, - 'group_by': stat.group_by, - 'filters': stat.filters, - } - - - def _apply_annotations(self, queryset, annotations, *fields): - for field in fields: - if field is None: - continue - field = field.split('__', 1)[0] - try: - annotation = annotations.pop(field) - except KeyError: - pass - else: - queryset = queryset.annotate(**{field: annotation}) - return queryset + return history.view_changesets(request, changesets.order_by('-id')) @@ -3137,7 +2662,7 @@ def debug_changesets_24h(request): logger.warning('Not authenticated.') return HttpResponseForbidden('Not authenticated.') - if not settings.ENABLE_DEBUG_ENDPOINTS: + if not django.conf.settings.ENABLE_DEBUG_ENDPOINTS: logger.warning('Debug endpoints disabled.') return HttpResponseForbidden('Debug endpoints disabled.') diff --git a/docs/api.md b/docs/api.md index 1b2f2d93..0156a614 100644 --- a/docs/api.md +++ b/docs/api.md @@ -97,7 +97,7 @@ Ordering is a simple matter of enumerating the fields in the `order_by` query pa The default sort order is ascending. If you want to sort in descending order, simply prefix the attribute name with a minus sign. 
This honors the scoping, so `api/animal?order_by=-name,id` will sort by `name` in descending order and by `id` in ascending order.
 
-### Saving a model
+### Saving or updating a model
 
 Creating a new model is possible with `POST api/animal/`, and updating a model with `PUT api/animal/`.
 Both requests accept a JSON body, like this:
@@ -191,6 +191,17 @@
 If this request succeeds, you'll get back a mapping of the fake ids and the real ids:
 
 It is also possible to update existing models with multi PUT. If you use a "real" id instead of a fake one, the model will be updated instead of created.
 
+
+#### Standalone Validation (without saving models)
+
+Sometimes you want to validate the model that you are going to save without actually saving it. This is useful, for example, when you want to inform the user of validation errors on the frontend, without having to implement the validation logic again. You may check for validation errors by sending a `POST`, `PUT` or `PATCH` request with an additional query parameter `validate`.
+
+Currently this is implemented by raising a `BinderSkipSave` exception, which makes sure that the atomic database transaction is aborted. Ideally, you would only want to call the validation logic on the models, so only calling validation for fields and validation for the model (`clean()`). But for now, we do it this way, at the cost of a performance penalty.
+
+It is important to realize that in this way, the normal `save()` function is called on a model, so it is possible that side effects are triggered, when these are implemented directly in `save()`, as opposed to in a signal method, which would be preferable. In other words, we cannot guarantee that the request will be idempotent. Therefore, the validation-only feature is disabled by default and must be enabled by setting `allow_standalone_validation=True` on the view.
+
+When a model is being validated and not actually being saved, the `_validation_model` property of the binder model is set to True. This allows whitelisting of certain validation checks, such as with certain relations that are not included with the validation model.
+
 ### Uploading files
 
 To upload a file, you have to add it to the `file_fields` of the `ModelView`:
diff --git a/tests/plugins/test_csvexport.py b/tests/plugins/test_csvexport.py
index 6837c4a2..d2d4d010 100644
--- a/tests/plugins/test_csvexport.py
+++ b/tests/plugins/test_csvexport.py
@@ -1,5 +1,4 @@
 from PIL import Image
-from binder.plugins.views.csvexport import ExcelFileAdapter
 from os import urandom
 from tempfile import NamedTemporaryFile
 import io
@@ -8,8 +7,7 @@
 from django.core.files import File
 from django.contrib.auth.models import User
 
-from ..testapp.models import Picture, Animal, Caretaker
-from ..testapp.views import PictureView
+from ..testapp.models import Picture, Animal, PictureBook
 
 import csv
 import openpyxl
@@ -33,8 +31,10 @@ def setUp(self):
 
 		self.pictures = []
 
+		picture_book = PictureBook.objects.create(name='Holiday 2012')
+
 		for i in range(3):
-			picture = Picture(animal=animal)
+			picture = Picture(animal=animal, picture_book=picture_book)
 			file = CsvExportTest.temp_imagefile(50, 50, 'jpeg')
 			picture.file.save('picture.jpg', File(file), save=False)
 			picture.original_file.save('picture_copy.jpg', File(file), save=False)
@@ -50,6 +50,7 @@ def setUp(self):
 
 	def test_csv_download(self):
 		response = self.client.get('/picture/download_csv/')
+		print(response.content)
 		self.assertEqual(200, response.status_code)
 
 		response_data = csv.reader(io.StringIO(response.content.decode("utf-8")))
@@ -71,8 +72,7 @@ def test_excel_download(self):
 			tmp.write(response.content)
 
 		wb = openpyxl.load_workbook(tmp.name)
-		self.assertEqual(1, len(wb._sheets))
-		sheet = wb._sheets[0]
+		sheet = wb._sheets[1]
 
 		_values = list(sheet.values)
 
@@ -105,28 +105,6 @@ def test_context_aware_downloader_default_csv(self):
self.assertEqual(data[2], [str(self.pictures[1].id), str(self.pictures[1].animal_id), str(self.pictures[1].id ** 2)]) self.assertEqual(data[3], [str(self.pictures[2].id), str(self.pictures[2].animal_id), str(self.pictures[2].id ** 2)]) - def test_download_extra_params(self): - caretaker_1 = Caretaker(name='Foo') - caretaker_1.save() - caretaker_2 = Caretaker(name='Bar') - caretaker_2.save() - caretaker_3 = Caretaker(name='Baz') - caretaker_3.save() - - response = self.client.get('/caretaker/download/') - self.assertEqual(200, response.status_code) - response_data = csv.reader(io.StringIO(response.content.decode("utf-8"))) - - data = list(response_data) - - # First line needs to be the header - self.assertEqual(data[0], ['ID', 'Name', 'Scary']) - - # All other data needs to be ordered using the default ordering (by id, asc) - self.assertEqual(data[1], [str(caretaker_1.id), 'Foo', 'boo!']) - self.assertEqual(data[2], [str(caretaker_2.id), 'Bar', 'boo!']) - self.assertEqual(data[3], [str(caretaker_3.id), 'Baz', 'boo!']) - def test_context_aware_download_xlsx(self): response = self.client.get('/picture/download/?response_type=xlsx') self.assertEqual(200, response.status_code) @@ -135,7 +113,7 @@ def test_context_aware_download_xlsx(self): tmp.write(response.content) wb = openpyxl.load_workbook(tmp.name) - sheet = wb._sheets[0] + sheet = wb._sheets[1] _values = list(sheet.values) @@ -150,63 +128,23 @@ def test_context_aware_download_xlsx(self): self.assertEqual(list(_values[3]), [self.pictures[2].id, str(self.pictures[2].animal_id), (self.pictures[2].id ** 2)]) - def test_csv_export_custom_limit(self): - old_limit = PictureView.csv_settings.limit; - PictureView.csv_settings.limit = 1 - response = self.client.get('/picture/download/') - self.assertEqual(200, response.status_code) - response_data = csv.reader(io.StringIO(response.content.decode("utf-8"))) - - # Header - self.assertEqual(next(response_data), ['picture identifier', 'animal identifier', 'squared picture 
identifier']) - # 1 REcord - self.assertIsNotNone(next(response_data)) - - # EOF - with self.assertRaises(StopIteration): - self.assertIsNone(next(response_data)) - - ###### Limit 2 - PictureView.csv_settings.limit = 2 - response = self.client.get('/picture/download/') - self.assertEqual(200, response.status_code) - response_data = csv.reader(io.StringIO(response.content.decode("utf-8"))) - - # Header - self.assertEqual(next(response_data), ['picture identifier', 'animal identifier', 'squared picture identifier']) - # 1 REcord - self.assertIsNotNone(next(response_data)) - # 2 Records - self.assertIsNotNone(next(response_data)) - # EOF - with self.assertRaises(StopIteration): - self.assertIsNone(next(response_data)) - - PictureView.csv_settings.limit = old_limit; + def test_none_foreign_key(self): + """ + If we have a relation that we have to include which is nullable, and we have a foreign key, we do not want the + csv export to crash. I.e. we also want for a picture to export the picture book name, even though not all pioctures + belong to a picture book + :return: + """ + self.pictures[0].picture_book = None + self.pictures[0].save() - def test_csv_settings_limit_none_working(self): - # Limit None should download everything - - old_limit = PictureView.csv_settings.limit; - PictureView.csv_settings.limit = None response = self.client.get('/picture/download/') self.assertEqual(200, response.status_code) response_data = csv.reader(io.StringIO(response.content.decode("utf-8"))) - # Header - self.assertEqual(next(response_data), ['picture identifier', 'animal identifier', 'squared picture identifier']) - # 3 REcords, everything we have in the database - self.assertIsNotNone(next(response_data)) - self.assertIsNotNone(next(response_data)) - self.assertIsNotNone(next(response_data)) - - # EOF - with self.assertRaises(StopIteration): - self.assertIsNone(next(response_data)) + data = list(response_data) + content = data[1:] - PictureView.csv_settings.limit = old_limit; + 
picture_books = [c[-1] for c in content]
 
-class TestExcelFileAdapter(TestCase):
-	def test_one_sheet_after_init(self):
-		file_adapter = ExcelFileAdapter(None)
-		self.assertEqual(len(file_adapter.work_book.worksheets), 1)
+		self.assertEqual(['', 'Holiday 2012', 'Holiday 2012'], picture_books)
\ No newline at end of file
diff --git a/tests/test_model_validation.py b/tests/test_model_validation.py
new file mode 100644
index 00000000..b05dba5a
--- /dev/null
+++ b/tests/test_model_validation.py
@@ -0,0 +1,271 @@
+from re import I
+from tests.testapp.models import contact_person
+from tests.testapp.models.contact_person import ContactPerson
+from django.test import TestCase, Client
+
+import json
+from binder.json import jsonloads
+from django.contrib.auth.models import User
+from .testapp.models import Animal, Caretaker, ContactPerson
+
+
+class TestModelValidation(TestCase):
+	"""
+	Test the validate-only functionality.
+
+	We check that the validation is executed as normal, but that the models
+	are not created when the validate query parameter is set to true.
+ + We check validation for: + - post + - put + - multi-put + - delete + """ + + + def setUp(self): + super().setUp() + u = User(username='testuser', is_active=True, is_superuser=True) + u.set_password('test') + u.save() + self.client = Client() + r = self.client.login(username='testuser', password='test') + self.assertTrue(r) + + # some update payload + self.model_data_with_error = { + 'name': 'very_special_forbidden_contact_person_name', # see `contact_person.py` + } + self.model_data_with_non_validation_error = { + 'name': 'very_special_validation_contact_person_name', # see `contact_person.py` + } + self.model_data = { + 'name': 'Scooooooby', + } + + + ### helpers ### + + + def assert_validation_error(self, response, person_id=None): + if person_id is None: + person_id = 'null' # for post + + self.assertEqual(response.status_code, 400) + + returned_data = jsonloads(response.content) + + # check that there were validation errors + self.assertEqual(returned_data.get('code'), 'ValidationError') + + # check that the validation error is present + validation_error = returned_data.get('errors').get('contact_person').get(str(person_id)).get('__all__')[0] + self.assertEqual(validation_error.get('code'), 'invalid') + self.assertEqual(validation_error.get('message'), 'Very special validation check that we need in `tests.M2MStoreErrorsTest`.') + + + def assert_multi_put_validation_error(self, response): + self.assertEqual(response.status_code, 400) + + returned_data = jsonloads(response.content) + + # check that there were validation errors + self.assertEqual(returned_data.get('code'), 'ValidationError') + + # check that all (two) the validation errors are present + for error in returned_data.get('errors').get('contact_person').values(): + validation_error = error.get('__all__')[0] + self.assertEqual(validation_error.get('code'), 'invalid') + self.assertEqual(validation_error.get('message'), 'Very special validation check that we need in `tests.M2MStoreErrorsTest`.') + + + 
### tests ### + + + def assert_no_validation_error(self, response): + self.assertEqual(response.status_code, 200) + + # check that the validation was successful + returned_data = jsonloads(response.content) + self.assertEqual(returned_data.get('code'), 'SkipSave') + self.assertEqual(returned_data.get('message'), 'No validation errors were encountered.') + + + def test_validate_on_post(self): + self.assertEqual(0, ContactPerson.objects.count()) + + # trigger a validation error + response = self.client.post('/contact_person/?validate=true', data=json.dumps(self.model_data_with_error), content_type='application/json') + self.assert_validation_error(response) + self.assertEqual(0, ContactPerson.objects.count()) + + # now without validation errors + response = self.client.post('/contact_person/?validate=true', data=json.dumps(self.model_data), content_type='application/json') + self.assert_no_validation_error(response) + self.assertEqual(0, ContactPerson.objects.count()) + + # now for real + response = self.client.post('/contact_person/', data=json.dumps(self.model_data), content_type='application/json') + self.assertEqual(response.status_code, 200) + self.assertEqual('Scooooooby', ContactPerson.objects.first().name) + + + def test_validate_on_put(self): + person_id = ContactPerson.objects.create(name='Scooby Doo').id + self.assertEqual('Scooby Doo', ContactPerson.objects.first().name) + + # trigger a validation error + response = self.client.put(f'/contact_person/{person_id}/?validate=true', data=json.dumps(self.model_data_with_error), content_type='application/json') + self.assert_validation_error(response, person_id) + self.assertEqual('Scooby Doo', ContactPerson.objects.first().name) + + # now without validation errors + response = self.client.put(f'/contact_person/{person_id}/?validate=true', data=json.dumps(self.model_data), content_type='application/json') + self.assert_no_validation_error(response) + self.assertEqual('Scooby Doo', 
ContactPerson.objects.first().name) + + # now for real + response = self.client.put(f'/contact_person/{person_id}/', data=json.dumps(self.model_data), content_type='application/json') + self.assertEqual(response.status_code, 200) + self.assertEqual('Scooooooby', ContactPerson.objects.first().name) + + def test_validate_model_validation_whitelisting(self): + person_id = ContactPerson.objects.create(name='Scooby Doo').id + self.assertEqual('Scooby Doo', ContactPerson.objects.first().name) + + # the normal request should give a validation error + response = self.client.put(f'/contact_person/{person_id}/', data=json.dumps(self.model_data_with_non_validation_error), content_type='application/json') + self.assert_validation_error(response, person_id) + self.assertEqual('Scooby Doo', ContactPerson.objects.first().name) + + # when just validating we want to ignore this validation error, so with validation it should not throw an error + response = self.client.put(f'/contact_person/{person_id}/?validate=true', data=json.dumps(self.model_data), content_type='application/json') + self.assert_no_validation_error(response) + self.assertEqual('Scooby Doo', ContactPerson.objects.first().name) + + + + def test_validate_on_multiput(self): + person_1_id = ContactPerson.objects.create(name='Scooby Doo 1').id + person_2_id = ContactPerson.objects.create(name='Scooby Doo 2').id + + multi_put_data = {'data': [ + { + 'id': person_1_id, + 'name': 'New Scooby', + }, + { + 'id': person_2_id, + 'name': 'New Doo' + } + ]} + + multi_put_data_with_error = {'data': [ + { + 'id': person_1_id, + 'name': 'very_special_forbidden_contact_person_name', + }, + { + 'id': person_2_id, + 'name': 'very_special_forbidden_contact_person_name' + } + ]} + + multi_put_data_with_validation_whitelist = {'data': [ + { + 'id': person_1_id, + 'name': 'very_special_validation_contact_person_name', + }, + { + 'id': person_2_id, + 'name': 'very_special_validation_contact_person_other_name' + } + ]} + + # trigger a 
validation error + response = self.client.put(f'/contact_person/?validate=true', data=json.dumps(multi_put_data_with_error), content_type='application/json') + self.assert_multi_put_validation_error(response) + self.assertEqual('Scooby Doo 1', ContactPerson.objects.get(id=person_1_id).name) + self.assertEqual('Scooby Doo 2', ContactPerson.objects.get(id=person_2_id).name) + + + # now without validation error + response = self.client.put(f'/contact_person/?validate=true', data=json.dumps(multi_put_data), content_type='application/json') + self.assert_no_validation_error(response) + self.assertEqual('Scooby Doo 1', ContactPerson.objects.get(id=person_1_id).name) + self.assertEqual('Scooby Doo 2', ContactPerson.objects.get(id=person_2_id).name) + + # multi put validation whitelist test + response = self.client.put(f'/contact_person/?validate=true', data=json.dumps(multi_put_data_with_validation_whitelist), content_type='application/json') + self.assert_no_validation_error(response) + self.assertEqual('Scooby Doo 1', ContactPerson.objects.get(id=person_1_id).name) + self.assertEqual('Scooby Doo 2', ContactPerson.objects.get(id=person_2_id).name) + + # multi put non validation whitelist test error + response = self.client.put(f'/contact_person/', + data=json.dumps(multi_put_data_with_validation_whitelist), + content_type='application/json') + self.assert_multi_put_validation_error(response) + self.assertEqual('Scooby Doo 1', ContactPerson.objects.get(id=person_1_id).name) + self.assertEqual('Scooby Doo 2', ContactPerson.objects.get(id=person_2_id).name) + + # now for real + response = self.client.put(f'/contact_person/', data=json.dumps(multi_put_data), content_type='application/json') + self.assertEqual(response.status_code, 200) + self.assertEqual('New Scooby', ContactPerson.objects.get(id=person_1_id).name) + self.assertEqual('New Doo', ContactPerson.objects.get(id=person_2_id).name) + + + def test_validate_on_delete(self): + '''Check if deletion is cancelled when we 
only attempt to validate + the delete operation. This test only covers validation of the + on_delete=PROTECT constraint of a fk.''' + + def is_deleted(obj): + '''Whether the obj was soft-deleted, so when the 'deleted' + attribute is not present, or when it is True.''' + + try: + obj.refresh_from_db() + except obj.DoesNotExist: + return True # hard-deleted + return animal.__dict__.get('deleted') or False + + + # animal has a fk to caretaker with on_delete=PROTECT + caretaker = Caretaker.objects.create(name='Connie Care') + animal = Animal.objects.create(name='Pony', caretaker=caretaker) + + + ### with validation error + + response = self.client.delete(f'/caretaker/{caretaker.id}/?validate=true') + # assert validation error + # and check that it was about the PROTECTED constraint + self.assertEqual(response.status_code, 400) + returned_data = jsonloads(response.content) + self.assertEqual(returned_data.get('code'), 'ValidationError') + self.assertEqual(returned_data.get('errors').get('caretaker').get(str(caretaker.id)).get('id')[0].get('code'), 'protected') + + self.assertFalse(is_deleted(caretaker)) + + + ### without validation error + + # now we delete the animal to make sure that deletion is possible + # note that soft-deleting will of course not remove the validation error + animal.delete() + + # now no validation error should be trown + response = self.client.delete(f'/caretaker/{caretaker.id}/?validate=true') + print(response.content) + self.assert_no_validation_error(response) + + self.assertFalse(is_deleted(caretaker)) + + + ### now for real + + response = self.client.delete(f'/caretaker/{caretaker.id}/') + self.assertTrue(is_deleted(caretaker)) diff --git a/tests/testapp/models/__init__.py b/tests/testapp/models/__init__.py index f944a1ca..48e0e08f 100644 --- a/tests/testapp/models/__init__.py +++ b/tests/testapp/models/__init__.py @@ -3,7 +3,6 @@ # We have it all, from A to Z! 
from .animal import Animal -from .donor import Donor from .caretaker import Caretaker from .contact_person import ContactPerson from .costume import Costume @@ -16,13 +15,6 @@ from .city import City, CityState, PermanentCity from .country import Country from .web_page import WebPage -from .pet import Pet -from .reverse_config_models import ( - ReverseParent, - ReverseChild, - ReverseParentNoChildHistory, - ReverseChildNoHistory, -) # This is Postgres-specific if os.environ.get('BINDER_TEST_MYSQL', '0') != '1': diff --git a/tests/testapp/models/animal.py b/tests/testapp/models/animal.py index fde79d7b..ae8abc1f 100644 --- a/tests/testapp/models/animal.py +++ b/tests/testapp/models/animal.py @@ -14,7 +14,7 @@ class Animal(LoadedValuesMixin, BinderModel): name = models.TextField(max_length=64) zoo = models.ForeignKey('Zoo', on_delete=models.CASCADE, related_name='animals', blank=True, null=True) zoo_of_birth = models.ForeignKey('Zoo', on_delete=models.CASCADE, related_name='+', blank=True, null=True) # might've been born outside captivity - caretaker = models.ForeignKey('Caretaker', on_delete=models.PROTECT, related_name='animals', blank=True, null=True) + caretaker = models.ForeignKey('Caretaker', on_delete=models.PROTECT, related_name='animals', blank=True, null=True) # we use the fact that this one is PROTECT in `test_model_validation.py` deleted = models.BooleanField(default=False) # Softdelete birth_date = models.DateField(blank=True, null=True) diff --git a/tests/testapp/models/contact_person.py b/tests/testapp/models/contact_person.py index baa1abcb..6b699f5a 100644 --- a/tests/testapp/models/contact_person.py +++ b/tests/testapp/models/contact_person.py @@ -23,3 +23,10 @@ def clean(self): code='invalid', message='Very special validation check that we need in `tests.M2MStoreErrorsTest`.' 
) + + # Should only give an error when model is not a validation model + if (self.name == 'very_special_validation_contact_person_name' or self.name == 'very_special_validation_contact_person_other_name') and not self._validation_model: + raise ValidationError( + code='invalid', + message='Very special validation check that we need in `tests.M2MStoreErrorsTest`.' + ) diff --git a/tests/testapp/models/picture.py b/tests/testapp/models/picture.py index ffbe59bb..862bfb3e 100644 --- a/tests/testapp/models/picture.py +++ b/tests/testapp/models/picture.py @@ -13,6 +13,18 @@ def delete_files(sender, instance=None, **kwargs): except Exception: pass + +class PictureBook(BinderModel): + """ + Sometimes customers like to commemorate their visit to the zoo. Of course there are always some shitty pictures that + we do not want in a picture album + + """ + + name = models.TextField() + + + # At the website of the zoo there are some pictures of animals. This model links the picture to an animal. # # A picture has two files, the original uploaded file, and the modified file. 
This model is used for testing the @@ -21,6 +33,7 @@ class Picture(BinderModel): animal = models.ForeignKey('Animal', on_delete=models.CASCADE, related_name='picture') file = models.ImageField(upload_to='floor-plans') original_file = models.ImageField(upload_to='floor-plans') + picture_book = models.ForeignKey('PictureBook', on_delete=models.CASCADE, null=True, blank=True) def __str__(self): return 'picture %d: (Picture for animal %s)' % (self.pk or 0, self.animal.name) diff --git a/tests/testapp/views/__init__.py b/tests/testapp/views/__init__.py index 31fb1049..40fcd3c9 100644 --- a/tests/testapp/views/__init__.py +++ b/tests/testapp/views/__init__.py @@ -20,5 +20,3 @@ from .zoo import ZooView from .zoo_employee import ZooEmployeeView from .web_page import WebPageView -from .donor import DonorView -from .pet import PetView diff --git a/tests/testapp/views/caretaker.py b/tests/testapp/views/caretaker.py index 85698354..a166685a 100644 --- a/tests/testapp/views/caretaker.py +++ b/tests/testapp/views/caretaker.py @@ -9,6 +9,9 @@ class CaretakerView(CsvExportView, ModelView): unupdatable_fields = ['first_seen'] model = Caretaker + # see `test_model_validation.py` + allow_standalone_validation = True + csv_settings = CsvExportView.CsvExportSettings( withs=[], column_map=[ @@ -17,4 +20,4 @@ class CaretakerView(CsvExportView, ModelView): ('scary', 'Scary'), ], extra_params={'include_annotations': 'scary'}, - ) + ) \ No newline at end of file diff --git a/tests/testapp/views/contact_person.py b/tests/testapp/views/contact_person.py index c1e90c5b..bc2f25fe 100644 --- a/tests/testapp/views/contact_person.py +++ b/tests/testapp/views/contact_person.py @@ -6,3 +6,6 @@ class ContactPersonView(ModelView): model = ContactPerson m2m_fields = ['zoos'] unwritable_fields = ['created_at', 'updated_at'] + + # see `test_model_validation.py` + allow_standalone_validation = True diff --git a/tests/testapp/views/picture.py b/tests/testapp/views/picture.py index 7dadfe99..0ba35f99 100644 
--- a/tests/testapp/views/picture.py +++ b/tests/testapp/views/picture.py @@ -3,16 +3,19 @@ from binder.views import ModelView from binder.plugins.views import ImageView, CsvExportView -from ..models import Picture +from ..models import Picture, PictureBook +class PictureBookView(ModelView): + model = PictureBook class PictureView(ModelView, ImageView, CsvExportView): model = Picture file_fields = ['file', 'original_file'] - csv_settings = CsvExportView.CsvExportSettings(['animal'], [ + csv_settings = CsvExportView.CsvExportSettings(['animal', 'picture_book'], [ ('id', 'picture identifier'), ('animal.id', 'animal identifier'), ('id', 'squared picture identifier', lambda id, row, mapping: id**2), + ('picture_book.name', 'Picturebook name') ]) @list_route(name='download_csv', methods=['GET'])