From 779ce66f6400c339298797600b7d6fccbe33471d Mon Sep 17 00:00:00 2001 From: Jeroen van Riel Date: Mon, 29 Mar 2021 11:22:05 +0200 Subject: [PATCH 01/33] Fix horrible typo --- binder/views.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/binder/views.py b/binder/views.py index 5285e529..1a271e52 100644 --- a/binder/views.py +++ b/binder/views.py @@ -2640,7 +2640,7 @@ def _multi_put_deletions(self, deletions, new_id_map, request): def multi_put(self, request): - logger.info('ACTIVATING THE MULTI-PUT!!!1!') + logger.info('ACTIVATING THE MULTI-PUT!!!!!') # Hack to communicate to _store() that we're not interested in # the new data (for perf reasons). From 18a8343898163715ae09ee073a4eb9c9e8cc368f Mon Sep 17 00:00:00 2001 From: Jeroen van Riel Date: Mon, 29 Mar 2021 13:27:38 +0200 Subject: [PATCH 02/33] Add standalone validation feature --- binder/exceptions.py | 10 ++++++++++ binder/views.py | 37 ++++++++++++++++++++++++++++++++----- docs/api.md | 11 ++++++++++- 3 files changed, 52 insertions(+), 6 deletions(-) diff --git a/binder/exceptions.py b/binder/exceptions.py index d5aec299..fbf9bd6f 100644 --- a/binder/exceptions.py +++ b/binder/exceptions.py @@ -235,3 +235,13 @@ def __add__(self, other): else: errors[model] = other.errors[model] return BinderValidationError(errors) + + +class BinderSkipSave(BinderException): + """Used to abort the database transaction when performing a non-save validation request.""" + http_code = 200 + code = 'SkipSave' + + def __init__(self): + super().__init__() + self.fields['message'] = 'No validation errors were encountered.' 
diff --git a/binder/views.py b/binder/views.py index 1a271e52..3a91d65b 100644 --- a/binder/views.py +++ b/binder/views.py @@ -29,7 +29,11 @@ from django.db.models.fields.reverse_related import ForeignObjectRel -from .exceptions import BinderException, BinderFieldTypeError, BinderFileSizeExceeded, BinderForbidden, BinderImageError, BinderImageSizeExceeded, BinderInvalidField, BinderIsDeleted, BinderIsNotDeleted, BinderMethodNotAllowed, BinderNotAuthenticated, BinderNotFound, BinderReadOnlyFieldError, BinderRequestError, BinderValidationError, BinderFileTypeIncorrect, BinderInvalidURI +from .exceptions import ( + BinderException, BinderFieldTypeError, BinderFileSizeExceeded, BinderForbidden, BinderImageError, BinderImageSizeExceeded, + BinderInvalidField, BinderIsDeleted, BinderIsNotDeleted, BinderMethodNotAllowed, BinderNotAuthenticated, BinderNotFound, + BinderReadOnlyFieldError, BinderRequestError, BinderValidationError, BinderFileTypeIncorrect, BinderInvalidURI, BinderSkipSave +) from . import history from .orderable_agg import OrderableArrayAgg, GroupConcat, StringAgg from .models import FieldFilter, BinderModel, ContextAnnotation, OptionalAnnotation, BinderFileField, BinderImageField @@ -392,6 +396,9 @@ class ModelView(View): # NOTE: custom _store__foo() methods will still be called for unupdatable fields. unupdatable_fields = [] + # Allow validation without saving. + allow_standalone_validation = False + # Fields to use for ?search=foo. Empty tuple for disabled search. # NOTE: only string fields and 'id' are supported. # id is hardcoded to be treated as an integer. @@ -1775,6 +1782,17 @@ def binder_validation_error(self, obj, validation_error, pk=None): }) + + def _abort_when_standalone_validation(self, request): + """Raise a `BinderSkipSave` exception when this is a standalone request.""" + if self.allow_standalone_validation: + if 'validate' in request.GET: + raise BinderSkipSave + else: + raise BinderException('Standalone validation not enabled. 
You must enable this feature explicitly.') + + + # Deserialize JSON to Django Model objects. # obj: Model object to update (for PUT), newly created object (for POST) # values: Python dict of {field name: value} (parsed JSON) @@ -2648,13 +2666,15 @@ def multi_put(self, request): data, deletions = self._multi_put_parse_request(request) objects = self._multi_put_collect_objects(data) - objects, overrides = self._multi_put_override_superclass(objects) + objects, overrides = self._multi_put_override_superclass(objects) # model inheritance objects = self._multi_put_convert_backref_to_forwardref(objects) dependencies = self._multi_put_calculate_dependencies(objects) ordered_objects = self._multi_put_order_dependencies(dependencies) - new_id_map = self._multi_put_save_objects(ordered_objects, objects, request) - self._multi_put_id_map_add_overrides(new_id_map, overrides) - new_id_map = self._multi_put_deletions(deletions, new_id_map, request) + new_id_map = self._multi_put_save_objects(ordered_objects, objects, request) # may raise validation errors + self._multi_put_id_map_add_overrides(new_id_map, overrides) # model inheritance + new_id_map = self._multi_put_deletions(deletions, new_id_map, request) # may raise validation errors + + self._abort_when_standalone_validation(request) output = defaultdict(list) for (model, oid), nid in new_id_map.items(): @@ -2759,6 +2779,8 @@ def put(self, request, pk=None): data = self._store(obj, values, request) + self._abort_when_standalone_validation(request) + new = dict(data) new.pop('_meta', None) @@ -2789,6 +2811,8 @@ def post(self, request, pk=None): data = self._store(self.model(), values, request) + self._abort_when_standalone_validation(request) + new = dict(data) new.pop('_meta', None) @@ -2830,6 +2854,9 @@ def delete(self, request, pk=None, undelete=False, skip_body_check=False): obj = self.model.objects.select_for_update().get(pk=obj.pk) self.delete_obj(obj, undelete, request) + + 
self._abort_when_standalone_validation(request) + logger.info('{}DELETEd {} #{}'.format('UN' if undelete else '', self._model_name(), pk)) return HttpResponse(status=204) # No content diff --git a/docs/api.md b/docs/api.md index 1b2f2d93..850a9bb8 100644 --- a/docs/api.md +++ b/docs/api.md @@ -97,7 +97,7 @@ Ordering is a simple matter of enumerating the fields in the `order_by` query pa The default sort order is ascending. If you want to sort in descending order, simply prefix the attribute name with a minus sign. This honors the scoping, so `api/animal?order_by=-name,id` will sort by `name` in descending order and by `id` in ascending order. -### Saving a model +### Saving or updating a model Creating a new model is possible with `POST api/animal/`, and updating a model with `PUT api/animal/`. Both requests accept a JSON body, like this: @@ -191,6 +191,15 @@ If this request succeeds, you'll get back a mapping of the fake ids and the real It is also possible to update existing models with multi PUT. If you use a "real" id instead of a fake one, the model will be updated instead of created. + +#### Standalone Validation (without saving models) + +Sometimes you want to validate the model that you are going to save without actually saving it. This is useful, for example, when you want to inform the user of validation errors on the frontend, without having to implement the validation logic again. You may check for validation errors by sending a `POST`, `PUT` or `PATCH` request with an additional query parameter `validate`. + +Currently this is implemented by raising a `BinderSkipSave` exception, which makes sure that the atomic database transaction is aborted. Ideally, you would only want to call the validation logic on the models, so only calling validation for fields and validation for model (`clean()`). But for now, we do it this way, at the cost of a performance penalty. 
+ +It is important to realize that in this way, the normal `save()` function is called on a model, so it is possible that possible side effects are triggered, when these are implemented directly in `save()`, as opposed to in a signal method, which would be preferable. In other words, we cannot guarantee that the request will be idempotent. Therefore, the validation only feature is disabled by default and must be enabled by setting `allow_standalone_validation=True` on the view. + ### Uploading files To upload a file, you have to add it to the `file_fields` of the `ModelView`: From d0469aec5b329c38c33e281129d455e58bf88348 Mon Sep 17 00:00:00 2001 From: Jeroen van Riel Date: Mon, 29 Mar 2021 14:17:05 +0200 Subject: [PATCH 03/33] Correctly extract querystring parameter for PUT --- binder/views.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/binder/views.py b/binder/views.py index 3a91d65b..acb1ca16 100644 --- a/binder/views.py +++ b/binder/views.py @@ -17,8 +17,8 @@ from django.conf import settings from django.core.exceptions import ObjectDoesNotExist, FieldError, ValidationError, FieldDoesNotExist from django.core.files.base import File, ContentFile -from django.http import HttpResponse, HttpResponseForbidden, FileResponse -from django.http.request import RawPostDataException +from django.http import HttpResponse, StreamingHttpResponse, HttpResponseForbidden, FileResponse +from django.http.request import RawPostDataException, QueryDict from django.http.multipartparser import MultiPartParser from django.db import models, connections from django.db.models import Q, F, Count, Case, When @@ -212,7 +212,7 @@ def get_annotations(model, request=None, annotations=None): if isinstance(expr, F): field = expr._output_field_or_none elif isinstance(expr, BaseExpression): - field = expr.field.clone() + field = field.clone() field.name = attr field.model = model else: @@ -423,7 +423,7 @@ class ModelView(View): # for the information of the 
front-end developer. comment = None - # If True, the (first part of the) request body will be logged at level=debug. + # If True, the (first part of) the request body will be logged at level=debug. # Set this to False for endpoints that receive passwords etc. log_request_body = True @@ -1786,7 +1786,8 @@ def binder_validation_error(self, obj, validation_error, pk=None): def _abort_when_standalone_validation(self, request): """Raise a `BinderSkipSave` exception when this is a standalone request.""" if self.allow_standalone_validation: - if 'validate' in request.GET: + params = QueryDict(request.body) + if 'validate' in params: raise BinderSkipSave else: raise BinderException('Standalone validation not enabled. You must enable this feature explicitly.') @@ -3179,4 +3180,4 @@ def handler500(request): except Exception as e: request_id = str(e) - return HttpResponse('{"code": "InternalServerError", "debug": {"request_id": "' + request_id + '"}}', status=500) + return HttpResponse('{"code": "InternalServerError", "debug": {"request_id": "' + request_id + '"}}', status=500) \ No newline at end of file From c0398e38ed11c4f75467fe6688c27ca009f32e80 Mon Sep 17 00:00:00 2001 From: Jeroen van Riel Date: Mon, 29 Mar 2021 14:42:13 +0200 Subject: [PATCH 04/33] Raise the correct exception when flag is not set --- binder/views.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/binder/views.py b/binder/views.py index acb1ca16..d410b5e6 100644 --- a/binder/views.py +++ b/binder/views.py @@ -1790,7 +1790,7 @@ def _abort_when_standalone_validation(self, request): if 'validate' in params: raise BinderSkipSave else: - raise BinderException('Standalone validation not enabled. You must enable this feature explicitly.') + raise BinderRequestError('Standalone validation not enabled. 
You must enable this feature explicitly.') From a9ff25642ad515bede4d51477cf016aae8514d37 Mon Sep 17 00:00:00 2001 From: Jeroen van Riel Date: Fri, 25 Jun 2021 12:11:54 +0200 Subject: [PATCH 05/33] Improve comment --- binder/exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/binder/exceptions.py b/binder/exceptions.py index fbf9bd6f..2844f488 100644 --- a/binder/exceptions.py +++ b/binder/exceptions.py @@ -238,7 +238,7 @@ def __add__(self, other): class BinderSkipSave(BinderException): - """Used to abort the database transaction when performing a non-save validation request.""" + """Used to abort the database transaction when non-save validation was successfull.""" http_code = 200 code = 'SkipSave' From b36f701be99d97136a7d73828880894b1af90376 Mon Sep 17 00:00:00 2001 From: Jeroen van Riel Date: Mon, 29 Mar 2021 14:52:39 +0200 Subject: [PATCH 06/33] Fix stupid logic --- binder/views.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/binder/views.py b/binder/views.py index d410b5e6..32566139 100644 --- a/binder/views.py +++ b/binder/views.py @@ -1785,12 +1785,12 @@ def binder_validation_error(self, obj, validation_error, pk=None): def _abort_when_standalone_validation(self, request): """Raise a `BinderSkipSave` exception when this is a standalone request.""" - if self.allow_standalone_validation: - params = QueryDict(request.body) - if 'validate' in params: + if 'validate' in params: + if self.allow_standalone_validation: + params = QueryDict(request.body) raise BinderSkipSave - else: - raise BinderRequestError('Standalone validation not enabled. You must enable this feature explicitly.') + else: + raise BinderRequestError('Standalone validation not enabled. 
You must enable this feature explicitly.') From 7759380d4484440066e221e3163d1695e278153c Mon Sep 17 00:00:00 2001 From: Jeroen van Riel Date: Fri, 25 Jun 2021 14:53:22 +0200 Subject: [PATCH 07/33] Stop earlier when flag not set --- binder/views.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/binder/views.py b/binder/views.py index 32566139..f21e8369 100644 --- a/binder/views.py +++ b/binder/views.py @@ -520,6 +520,10 @@ def dispatch(self, request, *args, **kwargs): response = None try: + # only allow standalone validation if you know what you are doing + if 'validate' in request.GET and request.GET['validate'] == 'true' and not self.allow_standalone_validation: + raise BinderRequestError('Standalone validation not enabled. You must enable this feature explicitly.') + #### START TRANSACTION with ExitStack() as stack, history.atomic(source='http', user=request.user, uuid=request.request_id): transaction_dbs = ['default'] @@ -1785,11 +1789,12 @@ def binder_validation_error(self, obj, validation_error, pk=None): def _abort_when_standalone_validation(self, request): """Raise a `BinderSkipSave` exception when this is a standalone request.""" - if 'validate' in params: + if 'validate' in request.GET and request.GET['validate'] == 'true': if self.allow_standalone_validation: params = QueryDict(request.body) raise BinderSkipSave else: + print('validate not enabled') raise BinderRequestError('Standalone validation not enabled. You must enable this feature explicitly.') @@ -2778,6 +2783,9 @@ def put(self, request, pk=None): if hasattr(obj, 'deleted') and obj.deleted: raise BinderIsDeleted() + + logger.info('storing') + data = self._store(obj, values, request) self._abort_when_standalone_validation(request) From ab22be7698e44cfef8df12593243949be4316cdf Mon Sep 17 00:00:00 2001 From: Jeroen van Riel Date: Wed, 30 Jun 2021 14:26:01 +0200 Subject: [PATCH 08/33] Add tests for validation flow. 
We check if the action is not performed, and thus aborted internally by a BinderSkipSave exception, for POST, PUT, MULTI-PUT and DELETE. --- tests/test_model_validation.py | 228 ++++++++++++++++++++++++++ tests/testapp/models/animal.py | 2 +- tests/testapp/views/caretaker.py | 5 +- tests/testapp/views/contact_person.py | 3 + 4 files changed, 236 insertions(+), 2 deletions(-) create mode 100644 tests/test_model_validation.py diff --git a/tests/test_model_validation.py b/tests/test_model_validation.py new file mode 100644 index 00000000..59dac762 --- /dev/null +++ b/tests/test_model_validation.py @@ -0,0 +1,228 @@ +from re import I +from tests.testapp.models import contact_person +from tests.testapp.models.contact_person import ContactPerson +from django.test import TestCase, Client + +import json +from binder.json import jsonloads +from django.contrib.auth.models import User +from .testapp.models import Animal, Caretaker, ContactPerson + + +class TestModelValidation(TestCase): + """ + Test the validate-only functionality. + + We check that the validation is executed as normal, but that the models + are not created when the validate query parameter is set to true. 
+ + We check validation for: + - post + - put + - multi-put + - delete + """ + + + def setUp(self): + super().setUp() + u = User(username='testuser', is_active=True, is_superuser=True) + u.set_password('test') + u.save() + self.client = Client() + r = self.client.login(username='testuser', password='test') + self.assertTrue(r) + + # some update payload + self.model_data_with_error = { + 'name': 'very_special_forbidden_contact_person_name', # see `contact_person.py` + } + self.model_data = { + 'name': 'Scooooooby', + } + + + ### helpers ### + + + def assert_validation_error(self, response, person_id=None): + if person_id is None: + person_id = 'null' # for post + + self.assertEqual(response.status_code, 400) + + returned_data = jsonloads(response.content) + + # check that there were validation errors + self.assertEqual(returned_data.get('code'), 'ValidationError') + + # check that the validation error is present + validation_error = returned_data.get('errors').get('contact_person').get(str(person_id)).get('__all__')[0] + self.assertEqual(validation_error.get('code'), 'invalid') + self.assertEqual(validation_error.get('message'), 'Very special validation check that we need in `tests.M2MStoreErrorsTest`.') + + + def assert_multi_put_validation_error(self, response): + self.assertEqual(response.status_code, 400) + + returned_data = jsonloads(response.content) + + # check that there were validation errors + self.assertEqual(returned_data.get('code'), 'ValidationError') + + # check that all (two) the validation errors are present + for error in returned_data.get('errors').get('contact_person').values(): + validation_error = error.get('__all__')[0] + self.assertEqual(validation_error.get('code'), 'invalid') + self.assertEqual(validation_error.get('message'), 'Very special validation check that we need in `tests.M2MStoreErrorsTest`.') + + + ### tests ### + + + def assert_no_validation_error(self, response): + self.assertEqual(response.status_code, 200) + + # check that the 
validation was successful + returned_data = jsonloads(response.content) + self.assertEqual(returned_data.get('code'), 'SkipSave') + self.assertEqual(returned_data.get('message'), 'No validation errors were encountered.') + + + def test_validate_on_post(self): + self.assertEqual(0, ContactPerson.objects.count()) + + # trigger a validation error + response = self.client.post('/contact_person/?validate=true', data=json.dumps(self.model_data_with_error), content_type='application/json') + self.assert_validation_error(response) + self.assertEqual(0, ContactPerson.objects.count()) + + # now without validation errors + response = self.client.post('/contact_person/?validate=true', data=json.dumps(self.model_data), content_type='application/json') + self.assert_no_validation_error(response) + self.assertEqual(0, ContactPerson.objects.count()) + + # now for real + response = self.client.post('/contact_person/', data=json.dumps(self.model_data), content_type='application/json') + self.assertEqual(response.status_code, 200) + self.assertEqual('Scooooooby', ContactPerson.objects.first().name) + + + def test_validate_on_put(self): + person_id = ContactPerson.objects.create(name='Scooby Doo').id + self.assertEqual('Scooby Doo', ContactPerson.objects.first().name) + + # trigger a validation error + response = self.client.put(f'/contact_person/{person_id}/?validate=true', data=json.dumps(self.model_data_with_error), content_type='application/json') + self.assert_validation_error(response, person_id) + self.assertEqual('Scooby Doo', ContactPerson.objects.first().name) + + # now without validation errors + response = self.client.put(f'/contact_person/{person_id}/?validate=true', data=json.dumps(self.model_data), content_type='application/json') + self.assert_no_validation_error(response) + self.assertEqual('Scooby Doo', ContactPerson.objects.first().name) + + # now for real + response = self.client.put(f'/contact_person/{person_id}/', data=json.dumps(self.model_data), 
content_type='application/json') + self.assertEqual(response.status_code, 200) + self.assertEqual('Scooooooby', ContactPerson.objects.first().name) + + + def test_validate_on_multiput(self): + person_1_id = ContactPerson.objects.create(name='Scooby Doo 1').id + person_2_id = ContactPerson.objects.create(name='Scooby Doo 2').id + + multi_put_data = {'data': [ + { + 'id': person_1_id, + 'name': 'New Scooby', + }, + { + 'id': person_2_id, + 'name': 'New Doo' + } + ]} + + multi_put_data_with_error = {'data': [ + { + 'id': person_1_id, + 'name': 'very_special_forbidden_contact_person_name', + }, + { + 'id': person_2_id, + 'name': 'very_special_forbidden_contact_person_name' + } + ]} + + # trigger a validation error + response = self.client.put(f'/contact_person/?validate=true', data=json.dumps(multi_put_data_with_error), content_type='application/json') + self.assert_multi_put_validation_error(response) + self.assertEqual('Scooby Doo 1', ContactPerson.objects.get(id=person_1_id).name) + self.assertEqual('Scooby Doo 2', ContactPerson.objects.get(id=person_2_id).name) + + + # now without validation error + response = self.client.put(f'/contact_person/?validate=true', data=json.dumps(multi_put_data), content_type='application/json') + self.assert_no_validation_error(response) + self.assertEqual('Scooby Doo 1', ContactPerson.objects.get(id=person_1_id).name) + self.assertEqual('Scooby Doo 2', ContactPerson.objects.get(id=person_2_id).name) + + # now for real + response = self.client.put(f'/contact_person/', data=json.dumps(multi_put_data), content_type='application/json') + self.assertEqual(response.status_code, 200) + self.assertEqual('New Scooby', ContactPerson.objects.get(id=person_1_id).name) + self.assertEqual('New Doo', ContactPerson.objects.get(id=person_2_id).name) + + + def test_validate_on_delete(self): + '''Check if deletion is cancelled when we only attempt to validate + the delete operation. 
This test only covers validation of the + on_delete=PROTECT constraint of a fk.''' + + def is_deleted(obj): + '''Whether the obj was soft-deleted, so when the 'deleted' + attribute is not present, or when it is True.''' + + try: + obj.refresh_from_db() + except obj.DoesNotExist: + return True # hard-deleted + return animal.__dict__.get('deleted') or False + + + # animal has a fk to caretaker with on_delete=PROTECT + caretaker = Caretaker.objects.create(name='Connie Care') + animal = Animal.objects.create(name='Pony', caretaker=caretaker) + + + ### with validation error + + response = self.client.delete(f'/caretaker/{caretaker.id}/?validate=true') + # assert validation error + # and check that it was about the PROTECTED constraint + self.assertEqual(response.status_code, 400) + returned_data = jsonloads(response.content) + self.assertEqual(returned_data.get('code'), 'ValidationError') + self.assertEqual(returned_data.get('errors').get('caretaker').get(str(caretaker.id)).get('id')[0].get('code'), 'protected') + + self.assertFalse(is_deleted(caretaker)) + + + ### without validation error + + # now we delete the animal to make sure that deletion is possible + # note that soft-deleting will of course not remove the validation error + animal.delete() + + # now no validation error should be trown + response = self.client.delete(f'/caretaker/{caretaker.id}/?validate=true') + print(response.content) + self.assert_no_validation_error(response) + + self.assertFalse(is_deleted(caretaker)) + + + ### now for real + + response = self.client.delete(f'/caretaker/{caretaker.id}/') + self.assertTrue(is_deleted(caretaker)) diff --git a/tests/testapp/models/animal.py b/tests/testapp/models/animal.py index fde79d7b..ae8abc1f 100644 --- a/tests/testapp/models/animal.py +++ b/tests/testapp/models/animal.py @@ -14,7 +14,7 @@ class Animal(LoadedValuesMixin, BinderModel): name = models.TextField(max_length=64) zoo = models.ForeignKey('Zoo', on_delete=models.CASCADE, related_name='animals', 
blank=True, null=True) zoo_of_birth = models.ForeignKey('Zoo', on_delete=models.CASCADE, related_name='+', blank=True, null=True) # might've been born outside captivity - caretaker = models.ForeignKey('Caretaker', on_delete=models.PROTECT, related_name='animals', blank=True, null=True) + caretaker = models.ForeignKey('Caretaker', on_delete=models.PROTECT, related_name='animals', blank=True, null=True) # we use the fact that this one is PROTECT in `test_model_validation.py` deleted = models.BooleanField(default=False) # Softdelete birth_date = models.DateField(blank=True, null=True) diff --git a/tests/testapp/views/caretaker.py b/tests/testapp/views/caretaker.py index 85698354..a166685a 100644 --- a/tests/testapp/views/caretaker.py +++ b/tests/testapp/views/caretaker.py @@ -9,6 +9,9 @@ class CaretakerView(CsvExportView, ModelView): unupdatable_fields = ['first_seen'] model = Caretaker + # see `test_model_validation.py` + allow_standalone_validation = True + csv_settings = CsvExportView.CsvExportSettings( withs=[], column_map=[ @@ -17,4 +20,4 @@ class CaretakerView(CsvExportView, ModelView): ('scary', 'Scary'), ], extra_params={'include_annotations': 'scary'}, - ) + ) \ No newline at end of file diff --git a/tests/testapp/views/contact_person.py b/tests/testapp/views/contact_person.py index c1e90c5b..bc2f25fe 100644 --- a/tests/testapp/views/contact_person.py +++ b/tests/testapp/views/contact_person.py @@ -6,3 +6,6 @@ class ContactPersonView(ModelView): model = ContactPerson m2m_fields = ['zoos'] unwritable_fields = ['created_at', 'updated_at'] + + # see `test_model_validation.py` + allow_standalone_validation = True From 8218ac03005b6af5da7e5fca9c7d1b1bf113c3e6 Mon Sep 17 00:00:00 2001 From: Jeroen van Riel Date: Wed, 30 Jun 2021 14:30:59 +0200 Subject: [PATCH 09/33] Some cleanup --- binder/exceptions.py | 4 +++- binder/views.py | 6 +----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/binder/exceptions.py b/binder/exceptions.py index 
2844f488..75e4e572 100644 --- a/binder/exceptions.py +++ b/binder/exceptions.py @@ -238,7 +238,9 @@ def __add__(self, other): class BinderSkipSave(BinderException): - """Used to abort the database transaction when non-save validation was successfull.""" + """Used to abort the database transaction when validation was successfull. + Validation is possible when saving (post, put, multi-put) or deleting models.""" + http_code = 200 code = 'SkipSave' diff --git a/binder/views.py b/binder/views.py index f21e8369..5bb5f1f3 100644 --- a/binder/views.py +++ b/binder/views.py @@ -1788,13 +1788,12 @@ def binder_validation_error(self, obj, validation_error, pk=None): def _abort_when_standalone_validation(self, request): - """Raise a `BinderSkipSave` exception when this is a standalone request.""" + """Raise a `BinderSkipSave` exception when this is a validation request.""" if 'validate' in request.GET and request.GET['validate'] == 'true': if self.allow_standalone_validation: params = QueryDict(request.body) raise BinderSkipSave else: - print('validate not enabled') raise BinderRequestError('Standalone validation not enabled. 
You must enable this feature explicitly.') @@ -2783,9 +2782,6 @@ def put(self, request, pk=None): if hasattr(obj, 'deleted') and obj.deleted: raise BinderIsDeleted() - - logger.info('storing') - data = self._store(obj, values, request) self._abort_when_standalone_validation(request) From 007676fbb0b8e2873c9b5715a97e483354a61cf0 Mon Sep 17 00:00:00 2001 From: robin Date: Thu, 22 Jul 2021 11:34:28 +0200 Subject: [PATCH 10/33] Test _validation_model field for clean --- binder/models.py | 8 +- binder/views.py | 1158 ++++++++++------------------------------------ 2 files changed, 253 insertions(+), 913 deletions(-) diff --git a/binder/models.py b/binder/models.py index 685694e0..afc481b0 100644 --- a/binder/models.py +++ b/binder/models.py @@ -617,8 +617,12 @@ class Meta: abstract = True ordering = ['pk'] - def save(self, *args, **kwargs): - self.full_clean() # Never allow saving invalid models! + def save(self, *args, only_validate=False, **kwargs): + # A validation model might not require all validation checks as it is not a full model + # _validation_model can be used to skip validation checks that are meant for complete models that are actually being saved + self._validation_model = only_validate # Set the model as a validation model when we only want to validate the model + + self.full_clean(only_validate=only_validate) # Never allow saving invalid models! 
return super().save(*args, **kwargs) diff --git a/binder/views.py b/binder/views.py index 5bb5f1f3..aec0f62f 100644 --- a/binder/views.py +++ b/binder/views.py @@ -1,11 +1,12 @@ import logging import time +import io import inspect import os +import hashlib import datetime import mimetypes import functools -import re from collections import defaultdict, namedtuple from contextlib import ExitStack @@ -14,14 +15,11 @@ import django from django.views.generic import View -from django.conf import settings from django.core.exceptions import ObjectDoesNotExist, FieldError, ValidationError, FieldDoesNotExist -from django.core.files.base import File, ContentFile -from django.http import HttpResponse, StreamingHttpResponse, HttpResponseForbidden, FileResponse +from django.http import HttpResponse, StreamingHttpResponse, HttpResponseForbidden from django.http.request import RawPostDataException, QueryDict -from django.http.multipartparser import MultiPartParser from django.db import models, connections -from django.db.models import Q, F, Count, Case, When +from django.db.models import Q, F from django.db.models.lookups import Transform from django.utils import timezone from django.db import transaction @@ -36,69 +34,8 @@ ) from . 
import history from .orderable_agg import OrderableArrayAgg, GroupConcat, StringAgg -from .models import FieldFilter, BinderModel, ContextAnnotation, OptionalAnnotation, BinderFileField, BinderImageField -from .json import JsonResponse, jsonloads, jsondumps -from .route_decorators import list_route - - -# expr: an aggregate expr to get the statistic, -# filter: a dict of filters to filter the queryset with before getting the aggregate, leading dot not included (optional), -# group_by: a field to group by separated by dots if following relations (optional), -# annotations: a list of annotation names that have to be applied to the queryset for the expr to work (optional), -Stat = namedtuple( - 'Stat', - ['expr', 'filters', 'group_by', 'annotations', 'min_value', 'max_values'], - defaults=[{}, None, [], None, None], -) - - -DEFAULT_STATS = { - 'total_records': Stat(Count(Value(1))), -} - - -def get_joins_from_queryset(queryset): - """ - Given a queryset returns a set of lines that are used to determine which - tables will be joined and how. In essence this is the FROM-statement and - every JOIN-statement in a set as a string. - - This is useful to compare the joins between querysets. - """ - # So to generate sql we need the compiler and connection for the right db - compiler = queryset.query.get_compiler(queryset.db) - connection = connections[queryset.db] - # Now we will just go through all tables in the alias_map - lines = set() - for alias in queryset.query.alias_map.values(): - line, params = alias.as_sql(compiler, connection) - # We assert we have no params for now, you need to do custom stuff for - # these to appear in joins and substituting params into the sql - # is not something we can easily do safely for now, just passing them - # along with the str will make use potentially have unhashable lines - # which will ruin the set. 
- assert not params - lines.add(line) - return lines - - -def q_get_flat_filters(q): - """ - Given a Q-object returns an iterator of all filters used in this Q-object. - - So for example for Q(foo=1, bar=2) this would yield 'foo' and 'bar', but it - will also work for more complicated nested Q-objects. - - This is useful to detect which fields are used in a Q-object. - """ - for child in q.children: - if isinstance(child, Q): - # If the child is another Q-object we can just yield recursively - yield from q_get_flat_filters(child) - elif isinstance(child, tuple): - # So now the child is a 2-tuple of filter & value, we just need the - # filter so we yield that - yield child[0] +from .models import FieldFilter, BinderModel, ContextAnnotation, OptionalAnnotation, BinderFileField +from .json import JsonResponse, jsonloads def split_par_aware(content): @@ -134,42 +71,6 @@ def get_default_annotations(model): return annotations -def split_path(path): - """ - This function splits a dot seperated path into a list of keys. - The advantage of this function over a simple path.split('.') is that it - handles backslash escaping so that keys can contain dot characters. - """ - path = iter(path) - chars = [] - for char in path: - if char == '.': - yield ''.join(chars) - chars.clear() - continue - if char == '\\': - char = next(path, '') - chars.append(char) - yield ''.join(chars) - - -def join_path(keys): - """ - This function joins a list of keys into a dot seperated path. - The advantage of this function over a simple '.'.join(keys) is that it - handles backslash escaping so that keys can contain dot characters. 
- """ - chars = [] - for i, key in enumerate(keys): - if i != 0: - chars.append('.') - for char in key: - if char in '\\.': - char = '\\' + char - chars.append(char) - return ''.join(chars) - - # Haha kill me now def multiput_get_id(bla): return bla['id'] if isinstance(bla, dict) else bla @@ -212,7 +113,7 @@ def get_annotations(model, request=None, annotations=None): if isinstance(expr, F): field = expr._output_field_or_none elif isinstance(expr, BaseExpression): - field = field.clone() + field = expr.field.clone() field.name = attr field.model = model else: @@ -319,34 +220,6 @@ def prefix_db_expression(value, prefix): raise ValueError('Unknown expression type, cannot apply db prefix: %s', value) -# Prefix a q expression by adding a prefix to all filters. You can also supply -# an 'antiprefix', if the filter starts with this value it will be removed -# instead of the prefix being added. This is useful for reversed fields. -def prefix_q_expression(value, prefix, antiprefix=None, model=None): - children = [] - for child in value.children: - if isinstance(child, Q): - children.append(prefix_q_expression(child, prefix, antiprefix, model)) - # pk__in with empty list is often used for identity values since django - # doesnt properly support them - elif child[0] == 'pk__in' and child[1] == []: - children.append(child) - elif antiprefix is not None and child[0] == antiprefix: - children.append(('pk', child[1])) - elif antiprefix is not None and child[0].startswith(antiprefix + '__'): - key = child[0][len(antiprefix) + 2:] - head = key.split('__', 1)[0] - if head != 'pk': - try: - model._meta.get_field(head) - except FieldDoesNotExist: - key = 'pk__' + key - children.append((key, child[1])) - else: - children.append((prefix + '__' + child[0], child[1])) - return Q(*children, _connector=value.connector, _negated=value.negated) - - class ModelView(View): # Model this is a view for. Use None for views not tied to a particular model. 
model = None @@ -423,7 +296,7 @@ class ModelView(View): # for the information of the front-end developer. comment = None - # If True, the (first part of) the request body will be logged at level=debug. + # If True, the (first part of the) request body will be logged at level=debug. # Set this to False for endpoints that receive passwords etc. log_request_body = True @@ -449,21 +322,6 @@ class ModelView(View): # } virtual_relations = {} - # A dict that looks like: - # {'filter_name': [ - # 'field.icontains', - # 'otherField.field:icontains' - # ], - # } - # Allows you to hide multiple filters behind a filter name - # see _parse_filter, see api.md > "Filtering by groups (alternative_filters)" - # NOTICE: alternative_filters may not contain a field or annotation as key - alternative_filters = {} - - # A dict that maps stat name to instances of binder.views.Stat - # These statistics can then be used in the stats view - stats = {} - @property def AggStrategy(self): if connections[self.model.objects.db].vendor == 'mysql': @@ -530,7 +388,7 @@ def dispatch(self, request, *args, **kwargs): # Check if the TRANSACTION_DATABASES is set in the settings.py, and if so, use that instead try: - transaction_dbs = settings.TRANSACTION_DATABASES + transaction_dbs = django.conf.settings.TRANSACTION_DATABASES except AttributeError: pass @@ -621,7 +479,7 @@ def _get_reverse_relations(self): # Kinda like model_to_dict() for multiple objects. # Return a list of dictionaries, one per object in the queryset. # Includes a list of ids for all m2m fields (including reverse relations). 
- def _get_objs(self, queryset, request, annotations=None, to_annotate={}): + def _get_objs(self, queryset, request, annotations=None): datas = [] datas_by_id = {} # Save datas so we can annotate m2m fields later (avoiding a query) objs_by_id = {} # Same for original objects @@ -639,52 +497,6 @@ def _get_objs(self, queryset, request, annotations=None, to_annotate={}): else: annotations &= set(self.shown_annotations) - # So now annotations are only being used for showing, so we filter out - # all that do not have to be shown - to_annotate = { - key: value - for key, value in to_annotate.items() - if key in annotations - } - - # So now we will divide annotations based on the joins they do - base_joins = get_joins_from_queryset(queryset) - annotation_sets = [] - - for name, expr in list(to_annotate.items()): - annotation_joins = get_joins_from_queryset( - self.model.objects.annotate(**{name: expr}) - ) - annotation_annotations = {name: expr} - # First check if the queryset already does all joins, in that case - # we can just add it to the main queryset without any performance - # hits - if annotation_joins <= base_joins: - queryset = queryset.annotate(**annotation_annotations) - to_annotate.pop(name) - continue - # Then try to merge it into the annotation sets - i = 0 - while i < len(annotation_sets): - set_joins, set_annotations = annotation_sets[i] - # If our joins are a subset of the annotation set we just add - # our annotation to the set and break - if annotation_joins <= set_joins: - set_annotations.update(annotation_annotations) - break - # If our joins are a superset of the annotation set we take its - # annotations and add it to ours - elif set_joins <= annotation_joins: - annotation_annotations.update(set_annotations) - annotation_sets.pop(i) - # Go on to the next - else: - i += 1 - # If no annotation set existed that matched our joins we create a - # new one - else: - annotation_sets.append((annotation_joins, annotation_annotations)) - for obj in queryset: 
# So we tend to make binder call queryset.distinct when necessary # to prevent duplicate results, this is however not always possible @@ -696,11 +508,12 @@ def _get_objs(self, queryset, request, annotations=None, to_annotate={}): data = {} for f in fields: - if isinstance(f, models.FileField): + if isinstance(f, models.fields.files.FileField): file = getattr(obj, f.attname) if file: # {router-view-instance} data[f.name] = self.router.model_route(self.model, obj.id, f) + # {duplicate-binder-file-field-hash-code} if isinstance(f, BinderFileField): data[f.name] += '?h={}&content_type={}&filename={}'.format( @@ -714,8 +527,7 @@ def _get_objs(self, queryset, request, annotations=None, to_annotate={}): data[f.name] = getattr(obj, f.attname) for a in annotations: - if a not in to_annotate: - data[a] = getattr(obj, a) + data[a] = getattr(obj, a) for prop in self.shown_properties: data[prop] = getattr(obj, prop) @@ -727,18 +539,6 @@ def _get_objs(self, queryset, request, annotations=None, to_annotate={}): datas_by_id[obj.pk] = data objs_by_id[obj.pk] = obj - for _, set_annotations in annotation_sets: - for set_values in ( - self.model.objects - .filter(pk__in=datas_by_id) - .annotate(**set_annotations) - .values('pk', *set_annotations) - ): - pk_ = set_values.pop('pk') - for name, value in set_values.items(): - datas_by_id[pk_][name] = value - setattr(objs_by_id[pk_], name, value) - self._annotate_objs(datas_by_id, objs_by_id) return datas @@ -762,7 +562,7 @@ def _annotate_objs(self, datas_by_id, objs_by_id): for obj_id, data in datas_by_id.items(): # TODO: Don't require OneToOneFields in the m2m_fields list if isinstance(local_field, models.OneToOneRel): - assert len(idmap[obj_id]) <= 1 + assert(len(idmap[obj_id]) <= 1) data[field_name] = idmap[obj_id][0] if len(idmap[obj_id]) == 1 else None else: data[field_name] = idmap[obj_id] @@ -778,15 +578,10 @@ def _annotate_objs(self, datas_by_id, objs_by_id): def _get_obj(self, pk, request, include_annotations=None): if 
include_annotations is None: include_annotations = self._parse_include_annotations(request) - annotations = include_annotations.get('') results = self._get_objs( - self.get_queryset(request).filter(pk=pk), + annotate(self.get_queryset(request).filter(pk=pk), request, include_annotations.get('')), request=request, - annotations=annotations, - to_annotate={ - name: value['expr'] - for name, value in get_annotations(self.model, request, annotations).items() - }, + annotations=include_annotations.get(''), ) if results: return results[0] @@ -958,13 +753,9 @@ def withs_to_nested_set(withs, result={}): view.router = self.router for annotations, with_pks in annotation_ids.items(): objs = view._get_objs( - view.get_queryset(request).filter(pk__in=with_pks), + annotate(view.get_queryset(request).filter(pk__in=with_pks), request, annotations), request=request, annotations=annotations, - to_annotate={ - name: value['expr'] - for name, value in get_annotations(view.model, request, annotations).items() - }, ) for obj in objs: view._annotate_obj_with_related_withs(obj, withs_per_model[model_name]) @@ -1091,10 +882,13 @@ def _follow_related(self, fieldspec): def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): result = {} + annotations = {} singular_fields = set() rel_ids_by_field_by_id = defaultdict(lambda: defaultdict(list)) virtual_fields = set() + Agg = self.AggStrategy + for field in with_map: vr = self.virtual_relations.get(field, None) @@ -1106,6 +900,12 @@ def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): if rel == field or rel.startswith(field + '.') }) + # Model default orders (this sometimes matters) + orders = [] + field_alias = field + '___annotation' if vr else field + for o in (view.model._meta.ordering if view.model._meta.ordering else BinderModel.Meta.ordering): + orders.append(prefix_db_expression(o, field_alias)) + # Virtual relation if vr: virtual_fields.add(field) @@ -1121,43 +921,51 @@ def 
_get_with_ids(self, pks, request, include_annotations, with_map, where_map): # annotations, so allow for fetching of ids, instead. # This does mean you can't filter on this relation # unless you write a custom filter, too. - try: - func = getattr(self, virtual_annotation) - except AttributeError: - raise BinderRequestError('Annotation for virtual relation {{{}}}.{{{}}} is {{{}}}, but no method by that name exists.'.format( - self.model.__name__, field, virtual_annotation - )) - rel_ids_by_field_by_id[field] = func(request, pks, q) + if isinstance(virtual_annotation, Q): + annotations[field_alias] = Agg(virtual_annotation, filter=q, ordering=orders) + else: + try: + func = getattr(self, virtual_annotation) + except AttributeError: + raise BinderRequestError('Annotation for virtual relation {{{}}}.{{{}}} is {{{}}}, but no method by that name exists.'.format( + self.model.__name__, field, virtual_annotation + )) + rel_ids_by_field_by_id[field] = func(request, pks, q) # Actual relation else: - f = self.model._meta.get_field(field) - - if f.one_to_one or f.many_to_one: + if (getattr(self.model, field).__class__ == models.fields.related.ReverseOneToOneDescriptor or + not any(f.name == field for f in (list(self.model._meta.many_to_many) + list(self._get_reverse_relations())))): singular_fields.add(field) - if any(f.name == field for f in self._get_reverse_relations()): - rev_field = f.remote_field.name - query = ( - view.model.objects - .filter(prefix_q_expression(q, rev_field, field, view.model), **{rev_field + '__in': pks}) - .values_list(rev_field + '__pk', 'pk') - .distinct() - ) - else: - # Model default orders (this sometimes matters) - orders = [] - for o in (view.model._meta.ordering if view.model._meta.ordering else BinderModel.Meta.ordering): - orders.append(prefix_db_expression(o, field)) - query = ( - self.model.objects - .order_by(*orders) - .filter(q, pk__in=pks, **{field + '__isnull': False}) - .values_list('pk', field + '__pk') - .distinct() - ) + if Agg 
!= GroupConcat: # HACKK (GROUP_CONCAT can't filter and excludes NULL already) + q &= Q(**{field+'__pk__isnull': False}) + annotations[field_alias] = Agg(field+'__pk', filter=q, ordering=orders) + + + qs = self.model.objects.filter(pk__in=pks).values('pk').annotate(**annotations) + for record in qs: + for field in with_map: + field_alias = field+'___annotation' if field in virtual_fields else field - for pk, rel_pk in query: - rel_ids_by_field_by_id[field][pk].append(rel_pk) + if field_alias in annotations: + value = record[field_alias] + + # Make the values distinct. We can't do this in + # the Agg() call, because then we get an error + # regarding order by and values needing to be the + # same :( + # We also can't just put it in a set, because we + # need to preserve the ordering. So we use a set + # to keep track of what we've seen and only add + # new items. + seen_values = set() + distinct_values = [] + for v in value: + if v not in seen_values: + distinct_values.append(v) + seen_values.add(v) + + rel_ids_by_field_by_id[field][record['pk']] += distinct_values for field, sub_fields in with_map.items(): next = self._follow_related(field)[0].model @@ -1202,53 +1010,10 @@ def _filter_relation(self, field_name, where_map, request, include_annotations): return FilterDescription(q, need_distinct) - # handle self.alternative_filters as part of _parse_filter - def _parse_alternative_filters(self, field, filter_heads, *args, **kwargs): - head, tail = re.fullmatch(r'([^.:]*)(.*)', field).groups() - - # split not/any/all of the tail, by default a filter is treated as any - # NOTE: not is_any ==> is_all - is_not = bool(re.search(r':not\b', tail)) - - is_any = bool(re.match(r':any\b', tail)) - if is_any: - tail = tail[4:] - is_all = bool(re.match(r':all\b', tail)) - if is_all and not is_any: - # if both is_any and is_all are true, we want the filter to fail - tail = tail[4:] - - alts = [] - for head in filter_heads: - field = head + tail - alt = self._parse_filter(field, 
*args, **kwargs) - alts.append(alt) - q, needs_distinct = alts[0] - for q_, needs_distinct_ in alts[1:]: - if is_not == is_all: - # :any (default) - # :not:all (NOTE that not is maintained inside the filter) - q |= q_ - else: - # :all - # :not:any (NOTE that not is maintained inside the filter) - q &= q_ - needs_distinct = needs_distinct or needs_distinct_ - return FilterDescription(q, needs_distinct) - def _parse_filter(self, field, value, request, include_annotations, partial=''): head, *tail = field.split('.') need_distinct = False - alt_head = head.split(':')[0] - try: - # get head without trailing :sorter - filter_heads = self.alternative_filters[alt_head] - except KeyError: - pass - else: - return self._parse_alternative_filters(field, filter_heads, value, request, include_annotations, partial) - if not tail: invert = False try: @@ -1360,9 +1125,10 @@ def _parse_order_by(self, queryset, field, request, partial=''): return (queryset, partial + head, nulls_last) - def _search_base(self, search, request): + + def search(self, queryset, search, request): if not search: - return ~Q(pk__in=[]) + return queryset if not (self.searches or self.transformed_searches): raise BinderRequestError('No search fields defined for this view.') @@ -1376,12 +1142,7 @@ def _search_base(self, search, request): q |= Q(**{s: transform(search)}) except ValueError: pass - - return q - - - def search(self, queryset, search, request): - return queryset.filter(self._search_base(search, request)) + return queryset.filter(q) def filter_deleted(self, queryset, pk, deleted, request): @@ -1442,23 +1203,13 @@ def get_queryset(self, request): - def _order_by_base(self, queryset, request, annotations): + def order_by(self, queryset, request): #### order_by order_bys = list(filter(None, request.GET.get('order_by', '').split(','))) orders = [] if order_bys: for o in order_bys: - # We split of a leading - (descending sorting) and the - # suffixes nulls_last and nulls_first - head = 
re.match(r'^-?(.*?)(__nulls_last|__nulls_first)?$', o).group(1) - try: - expr = annotations.pop(head) - except KeyError: - pass - else: - queryset = queryset.annotate(**{head: expr}) - if o.startswith('-'): queryset, order, nulls_last = self._parse_order_by(queryset, o[1:], request, partial='-') else: @@ -1494,29 +1245,16 @@ def _order_by_base(self, queryset, request, annotations): return queryset - def order_by(self, queryset, request): - return self._order_by_base(queryset, request, {}) - - def _annotate_obj_with_related_withs(self, obj, field_results): for (w, (view, ids_dict, is_singular)) in field_results.items(): if '.' not in w: if is_singular: try: obj[w] = list(ids_dict[obj['id']])[0] - except (IndexError): - """ - Indexerror => no relation is found - KeyERror => No relation is known for this model. - """ + except IndexError: obj[w] = None - except KeyError: - pass else: - try: - obj[w] = list(ids_dict[obj['id']]) - except KeyError: - pass + obj[w] = list(ids_dict[obj['id']]) def _generate_meta(self, include_meta, queryset, request, pk=None): @@ -1526,111 +1264,14 @@ def _generate_meta(self, include_meta, queryset, request, pk=None): # Only 'pk' values should reduce DB server memory a (little?) bit, making # things faster. Not prefetching related models here makes it faster still. # See also https://code.djangoproject.com/ticket/23771 and related tickets. 
- meta['total_records'] = queryset.order_by().prefetch_related(None).values('pk').count() + meta['total_records'] = queryset.prefetch_related(None).values('pk').count() return meta - def _apply_q_with_possible_annotations(self, queryset, q, annotations): - for filter in q_get_flat_filters(q): - head = filter.split('__', 1)[0] - try: - expr = annotations.pop(head) - except KeyError: - pass - else: - queryset = queryset.annotate(**{head: expr}) - - return queryset.filter(q) - - - def _after_expr(self, request, after_id, include_annotations): - """ - This method given a request and an id returns a boolean expression that - indicates if a record would show up after the provided id for the - ordering specified by this request. - """ - queryset = self.get_queryset(request) - annotations = { - name: value['expr'] - for name, value in self.annotations(request, include_annotations).items() - } - - # We do an order by on a copy of annotations so that we see which keys - # it pops - annotations_copy = annotations.copy() - ordering = self._order_by_base(queryset, request, annotations_copy).query.order_by - required_annotations = set(annotations) - set(annotations_copy) - - queryset = queryset.annotate(**{name: annotations[name] for name in required_annotations}) - try: - obj = queryset.get(pk=int(after_id)) - except (ValueError, self.model.DoesNotExist): - raise BinderRequestError(f'invalid value for after_id: {after_id!r}') - - # Now we will build up a comparison expr based on the order by - whens = [] - - for field in ordering: - # Fields are generally strings, except in some edge case, where it is an OrderBy expression. 
- # In thase case, we need to change it back to starting (sigh) for compatability reasons - - if type(field) is not str: - _field = field - - field = field.expression.name - - if _field.descending: - field = f'-{field}' - if _field.nulls_first: - field = f'{field}__nulls_first' - if _field.nulls_last: - field = f'{field}__nulls_last' - - # First we have to split of a leading '-' as indicating reverse - reverse = field.startswith('-') - if reverse: - field = field[1:] - - # Then we determine if nulls come last - if field.endswith('__nulls_last'): - field = field[:-12] - nulls_last = True - elif field.endswith('__nulls_first'): - field = field[:-13] - nulls_last = False - elif connections[self.model.objects.db].vendor == 'mysql': - # In MySQL null is considered to be the lowest possible value for ordering - nulls_last = reverse - else: - # In other databases null is considered to be the highest possible value for ordering - nulls_last = not reverse - - # Then we determine what the value is for the obj we need to be after - value = obj - for attr in field.split('__'): - value = getattr(value, attr, None) - if isinstance(value, models.Model): - value = value.pk - - # Now we add some conditions for the comparison - if value is None: - # If the value is None, that means we have to add a condition for when the field is not None because only then it is different - # What the result should be in that case is determined by nulls last - whens.append(When(Q(**{field + '__isnull': False}), then=Value(not nulls_last))) - else: - # If the field is None we give a result based on nulls last - whens.append(When(Q(**{field: None}), then=Value(nulls_last))) - # Otherwise we check with comparisons, note that equality is intentionally left open with these two options so in that case we go on to the next field - whens.append(When(Q(**{field + '__lt': value}), then=Value(reverse))) - whens.append(When(Q(**{field + '__gt': value}), then=Value(not reverse))) - - expr = Case(*whens, 
default=Value(False)) - - return expr, required_annotations - + def get(self, request, pk=None, withs=None, include_annotations=None): + include_meta = request.GET.get('include_meta', 'total_records').split(',') - def _get_filtered_queryset_base(self, request, pk=None, include_annotations=None): queryset = self.get_queryset(request) if pk: queryset = queryset.filter(pk=int(pk)) @@ -1646,83 +1287,32 @@ def _get_filtered_queryset_base(self, request, pk=None, include_annotations=None #### annotations if include_annotations is None: include_annotations = self._parse_include_annotations(request) - - annotations = { - name: value['expr'] - for name, value in self.annotations(request, include_annotations).items() - } + queryset = annotate(queryset, request, include_annotations.get('')) #### filters filters = {k.lstrip('.'): v for k, v in request.GET.lists() if k.startswith('.')} for field, values in filters.items(): - for v in values: q, distinct = self._parse_filter(field, v, request, include_annotations) - queryset = self._apply_q_with_possible_annotations(queryset, q, annotations) + queryset = queryset.filter(q) if distinct: queryset = queryset.distinct() #### search if 'search' in request.GET: - q = self._search_base(request.GET['search'], request) - queryset = self._apply_q_with_possible_annotations(queryset, q, annotations) + queryset = self.search(queryset, request.GET['search'], request) - #### after - try: - after = request.GET['after'] - except KeyError: - pass - else: - after_expr, required_annotations = self._after_expr(request, after, include_annotations) - for name in required_annotations: - try: - expr = annotations.pop(name) - except KeyError: - pass - else: - queryset = queryset.annotate(**{name: expr}) - queryset = queryset.filter(after_expr) - - return queryset, annotations - - def get_filtered_queryset(self, request, *args, **kwargs): - """ - Returns a scoped queryset with filtering and sorting applied as - specified by the request. 
- """ - queryset, annotations = self._get_filtered_queryset_base(request, *args, **kwargs) - queryset = queryset.annotate(**annotations) queryset = self.order_by(queryset, request) - return queryset - - def get(self, request, pk=None, withs=None, include_annotations=None): - include_meta = request.GET.get('include_meta', 'total_records').split(',') - if include_annotations is None: - include_annotations = self._parse_include_annotations(request) - - queryset, annotations = self._get_filtered_queryset_base(request, pk, include_annotations) meta = self._generate_meta(include_meta, queryset, request, pk) - queryset = self._order_by_base(queryset, request, annotations) queryset = self._paginate(queryset, request) - # We fetch the data with only the currently applied annotations - data = self._get_objs( - queryset, - request=request, - annotations=include_annotations.get(''), - to_annotate=annotations, - ) - - # Now we add all remaining annotations to this data - data_by_pk = {obj['id']: obj for obj in data} - pks = set(data_by_pk) - #### with # parse wheres from request - extras, extras_mapping, extras_reverse_mapping, field_results = self._get_withs(pks, withs, request=request, include_annotations=include_annotations) + extras, extras_mapping, extras_reverse_mapping, field_results = self._get_withs(queryset, withs, request=request, include_annotations=include_annotations) + data = self._get_objs(queryset, request=request, annotations=include_annotations.get('')) for obj in data: self._annotate_obj_with_related_withs(obj, field_results) @@ -1736,7 +1326,7 @@ def get(self, request, pk=None, withs=None, include_annotations=None): meta['comment'] = self.comment debug = {'request_id': request.request_id} - if settings.DEBUG and 'debug' in request.GET: + if django.conf.settings.DEBUG and 'debug' in request.GET: debug['queries'] = ['{}s: {}'.format(q['time'], q['sql'].replace('"', '')) for q in django.db.connection.queries] debug['query_count'] = 
len(django.db.connection.queries) @@ -1807,6 +1397,9 @@ def _store(self, obj, values, request, ignore_unknown_fields=False, pk=None): ignored_fields = [] validation_errors = [] + # When only validating and not saving we attach a parameter so that we can skip or add validation checks + only_validate = request.GET.get('validate') == 'true' or request.GET.get('validate') == 'True' + if obj.pk is None: self._require_model_perm('add', request, obj.pk) else: @@ -1844,8 +1437,8 @@ def store_m2m_field(obj, field, value, request): raise sum(validation_errors, None) try: - obj.save() - assert obj.pk is not None # At this point, the object must have been created. + obj.save(only_validate=only_validate) + assert(obj.pk is not None) # At this point, the object must have been created. except ValidationError as ve: validation_errors.append(self.binder_validation_error(obj, ve, pk=pk)) @@ -1867,23 +1460,15 @@ def store_m2m_field(obj, field, value, request): # Skip re-fetch and serialization via get_objs if we're in # multi-put (data is discarded!). 
- if ( - getattr(request, '_is_multi_put', False) or # Multi put handles its own return data - getattr(request, '_is_file_upload', False) # Dispatch file field handles its own return data - ): + if getattr(request, '_is_multi_put', False): return None # Permission checks are done at this point, so we can avoid get_queryset() include_annotations = self._parse_include_annotations(request) - annotations = include_annotations.get('') data = self._get_objs( - self.model.objects.filter(pk=obj.pk), + annotate(self.model.objects.filter(pk=obj.pk), request, include_annotations.get('')), request=request, - annotations=annotations, - to_annotate={ - name: value['expr'] - for name, value in get_annotations(self.model, request, annotations).items() - }, + annotations=include_annotations.get(''), )[0] data['_meta'] = {'ignored_fields': ignored_fields} return data @@ -1895,6 +1480,9 @@ def store_m2m_field(obj, field, value, request): def _store_m2m_field(self, obj, field, value, request): validation_errors = [] + # When only validating and not saving we attach a parameter so that we can skip or add validation checks + only_validate = request.GET.get('validate') == 'true' or request.GET.get('validate') == 'True' + # Can't use isinstance() because apparantly ManyToManyDescriptor is a subclass of # ReverseManyToOneDescriptor. Yes, really. 
if getattr(obj._meta.model, field).__class__ == models.fields.related.ReverseManyToOneDescriptor: @@ -1937,11 +1525,11 @@ def _store_m2m_field(self, obj, field, value, request): for addobj in obj_field.model.objects.filter(id__in=new_ids - old_ids): setattr(addobj, obj_field.field.name, obj) try: - addobj.save() + addobj.save(only_validate=only_validate) except ValidationError as ve: validation_errors.append(self.binder_validation_error(addobj, ve)) else: - addobj.save() + addobj.save(only_validate=only_validate) elif getattr(obj._meta.model, field).__class__ == models.fields.related.ReverseOneToOneDescriptor: #### XXX FIXME XXX ugly quick fix for reverse relation + multiput issue if any(v for v in value if v is not None and v < 0): @@ -1957,7 +1545,7 @@ def _store_m2m_field(self, obj, field, value, request): remote_obj = field_descriptor.related.remote_field.model.objects.get(pk=value[0]) setattr(remote_obj, field_descriptor.related.remote_field.name, obj) try: - remote_obj.save() + remote_obj.save(only_validate=only_validate) remote_obj.refresh_from_db() except ValidationError as ve: validation_errors.append(self.binder_validation_error(remote_obj, ve)) @@ -1975,105 +1563,6 @@ def _store_m2m_field(self, obj, field, value, request): raise sum(validation_errors, None) - def _resolve_file_path(self, value, request): - match = re.match(r'/api/(\w+)/(\d+)/(\w+)/', value) - if not match: - return None - - model, pk, field = match.groups() - try: - model = self.router.name_models[model] - except KeyError: - return None - - view = self.get_model_view(model) - if field not in view.file_fields: - return None - - try: - obj = view.get_queryset(request).get(pk=pk) - except model.DoesNotExist: - return None - - value = getattr(obj, field) - with value.open('rb') as f: - return ContentFile(f.read(), value.name) - - - def _clean_image_file(self, field, value): - try: - img = Image.open(value) - except Exception: - raise BinderImageError('Could not parse the file as an image.') 
- - format = img.format.lower() - if not format in ('png', 'gif', 'jpeg'): - raise BinderFileTypeIncorrect([{'extension': t, 'mimetype': 'image/' + t} for t in ['jpeg', 'png', 'gif']]) - - width, height = img.size - if format == 'jpeg': - img2 = image_transpose_exif(img) - - if img2 != img: - value.seek(0) # Do not append to the existing file! - value.truncate() - img2.save(value, 'jpeg') - img = img2 - - # Determine resize threshold - try: - max_size = self.image_resize_threshold[field] - except TypeError: - max_size = self.image_resize_threshold - - try: - max_width, max_height = max_size - except (TypeError, ValueError): - max_width, max_height = max_size, max_size - - try: - format_override = self.image_format_override.get(field) - except AttributeError: - format_override = self.image_format_override - - changes = False - - # FIXME: hardcoded max - # Flat out refuse images exceeding this size, to prevent DoS. - width_limit, height_limit = max(max_width, 4096), max(max_height, 4096) - if width > width_limit or height > height_limit: - raise BinderImageSizeExceeded(width_limit, height_limit) - - # Resize images that are too large. - if width > max_width or height > max_height: - img.thumbnail((max_width, max_height), Image.LANCZOS) - logger.info('image dimensions ({}x{}) exceeded ({}, {}), resizing.'.format(width, height, max_width, max_height)) - if img.mode not in ["1", "L", "P", "RGB", "RGBA"]: - img = img.convert("RGB") - if format != 'jpeg': - format = 'png' - changes = True - - # Saving a JPEG with mode RGBA will crash because JPEG does not support - # an alpha channel, so in this case we convert to RGB - if format_override == 'jpeg' and img.mode == 'RGBA': - img = img.convert('RGB') - - if format_override and format != format_override: - format = format_override - changes = True - - name, ext = os.path.splitext(value.name) - if ext != '.' + format and not (ext == '.jpg' and format == 'jpeg'): - ext = '.' 
+ format - changes = True - - if changes: - filename = name + ext - value = ContentFile(b'', name=filename) - img.save(value, format) - - return value # Override _store_field example for a "FOO" field @@ -2093,6 +1582,7 @@ def _store_field(self, obj, field, value, request, pk=None): 'id', 'pk', 'deleted', '_meta', *self.unwritable_fields, *self.shown_properties, + *self.file_fields, *self.annotations(request), ]: raise BinderReadOnlyFieldError(self.model.__name__, field) @@ -2125,7 +1615,6 @@ def _store_field(self, obj, field, value, request, pk=None): # Hack, set the id directly. This does the actual check, and throws the BinderError in # the same way the old case has. setattr(obj, f.attname, value) - elif isinstance(f, models.IntegerField): if value is None or value == '': value = None @@ -2146,7 +1635,6 @@ def _store_field(self, obj, field, value, request, pk=None): } }) setattr(obj, f.attname, value) - elif isinstance(f, models.TextField): # Django doesn't enforce max_length on TextFields, so we do. 
if f.max_length is not None: @@ -2167,46 +1655,6 @@ def _store_field(self, obj, field, value, request, pk=None): } }) setattr(obj, f.attname, value) - - elif isinstance(f, models.FileField): - # If the value is a str that matches how we return file - # fields we convert it to the correct file - if isinstance(value, str): - value_ = self._resolve_file_path(value, request) - if value_ is not None: - value = value_ - - if not isinstance(value, File) and value is not None: - raise BinderFieldTypeError(self.model.__name__, field) - - if value is not None: - if value.size > self.max_upload_size * 10**6: - raise BinderFileSizeExceeded(self.max_upload_size) - - if isinstance(f, models.ImageField): - allowed_extensions = ['png', 'gif', 'jpg', 'jpeg'] - elif isinstance(f, BinderFileField): - allowed_extensions = f.allowed_extensions - else: - allowed_extensions = None - - if allowed_extensions is not None: - extension = os.path.splitext(value.name)[1][1:].lower() - allowed_extensions = {ext.lower() for ext in allowed_extensions} - if extension not in allowed_extensions: - raise BinderFileTypeIncorrect([{'extension': t} for t in allowed_extensions]) - - if isinstance(f, (models.ImageField, BinderImageField)): - value = self._clean_image_file(field, value) - - old_value = getattr(obj, f.attname) - if old_value.name: - @transaction.on_commit - def delete_old_file(): - old_value.storage.delete(old_value.name) - - setattr(obj, f.attname, value) - else: try: f.to_python(value) @@ -2319,7 +1767,7 @@ def _obj_diff(self, old, new, name): # Put data and with on one big pile, that's easier for us def _multi_put_parse_request(self, request): - body = self._get_request_values(request) + body = jsonloads(request.body) data = body.get('with', {}) if not isinstance(data, dict): @@ -2382,17 +1830,8 @@ def _multi_put_override_superclass(self, objects): # Collect overrides for (cls, mid), data in objects.items(): for subcls in getsubclasses(cls): - # Get remote field of the subclass - 
remote_field = subcls._meta.pk.remote_field - - # In some scenarios with proxy models - # The remote field may not exist - # Because proxy models are just pure python wrappers(without its own db table) for other models - if remote_field is None: - continue - # Get key of field pointing to subclass - subkey = remote_field.name + subkey = subcls._meta.pk.remote_field.name # Get id of subclass subid = data.pop(subkey, None) if subid is None: @@ -2550,7 +1989,7 @@ def _multi_put_save_objects(self, ordered_objects, objects, request): # corresponding view here? That would make it # more consistent with non-multi-PUT and POST, # also requiring view permissions. - qs = model.objects.filter(pk__in=oids).order_by().select_for_update() + qs = model.objects.filter(pk__in=oids).select_for_update() for obj in qs: locked_objects[(model, obj.pk)] = obj @@ -2688,65 +2127,7 @@ def multi_put(self, request): return JsonResponse({'idmap': output}) def _get_request_values(self, request): - # So normally we just parse json here but for multipart form data we have some special logic to inject files - # The values then would look a bit like this: - # - # data={"foo": 1, "bar": null, "baz": [1, 2, null]} - # file:bar= - # file:baz.2= - # - # Where the output will be the data but with the null values replaced by the specified files. - # The paths that the files use as key have to be in the data and have value null. - # This is because we do not want to alter the structure of the data because that gets messy with things like array indexes etc. 
- - if request.content_type == 'multipart/form-data': - # Django only parses multipart automatically on POST, so for PUT/PATCH/DELETE etc we do it manually - if request.method == 'POST': - fields = request.POST - files = request.FILES - else: - parser = MultiPartParser(request.META, request, request.upload_handlers) - fields, files = parser.parse() - try: - data = fields['data'] - except KeyError: - raise BinderRequestError('data field is required in multipart body') - else: - data = request.body - files = {} - - data = jsonloads(data) - - for path, value in files.items(): - if not path.startswith('file:'): - continue - target = data - keys = list(split_path(path[5:])) - for i, key in enumerate(keys): - if isinstance(target, list): - try: - key = int(key) - except ValueError: - raise BinderRequestError( - 'expected integer key at path: ' + join_path(keys[:i + 1]) - ) - if not ( - (isinstance(target, dict) and key in target) or - (isinstance(target, list) and 0 <= key < len(target)) - ): - raise BinderRequestError( - 'unexpected key at path: ' + join_path(keys[:i + 1]) - ) - if i != len(keys) - 1: - target = target[key] - elif target[key] is not None: - raise BinderRequestError( - 'expected null at path: ' + join_path(keys) - ) - else: - target[key] = value - - return data + return jsonloads(request.body) def put(self, request, pk=None): if pk is None: @@ -2757,24 +2138,13 @@ def put(self, request, pk=None): values = self._get_request_values(request) try: - # Step one does the permission check. We cannot do a select for update here, since the queryeset - # of get_queryset can potentially have outer joins on nullable values - obj = self.get_queryset(request).get(pk=int(pk)) - - # Now that we know we have access to this moel, we can get it again, this time with lock. 
- obj = self.model.objects.select_for_update().get(pk=obj.pk) - + obj = self.get_queryset(request).select_for_update().get(pk=int(pk)) # Permission checks are done at this point, so we can avoid get_queryset() include_annotations = self._parse_include_annotations(request) - annotations = include_annotations.get('') old = self._get_objs( - self.model.objects.filter(pk=int(pk)), - request=request, - annotations=annotations, - to_annotate={ - name: value['expr'] - for name, value in get_annotations(self.model, request, annotations).items() - }, + annotate(self.model.objects.filter(pk=int(pk)), request, include_annotations.get('')), + request, + include_annotations.get(''), )[0] except ObjectDoesNotExist: raise BinderNotFound() @@ -2849,14 +2219,10 @@ def delete(self, request, pk=None, undelete=False, skip_body_check=False): except ValueError: pass - # This make sure we do all permission checks. We cannot do a select for update here, since there is a possiblity - # that we create a queryset that we cannot use in a for select clause. try: - obj = self.get_queryset(request).get(pk=int(pk)) + obj = self.get_queryset(request).select_for_update().get(pk=int(pk)) except ObjectDoesNotExist: raise BinderNotFound() - # We now retrieve the model again, without the permission checks, which we already this. 
- obj = self.model.objects.select_for_update().get(pk=obj.pk) self.delete_obj(obj, undelete, request) @@ -2874,6 +2240,10 @@ def delete_obj(self, obj, undelete, request): def soft_delete(self, obj, undelete, request): + + # When only validating and not saving we attach a parameter so that we can skip or add validation checks + only_validate = request.GET.get('validate') == 'true' or request.GET.get('validate') == 'True' + # Not only for soft delets, actually handles all deletions try: if obj.deleted and not undelete: @@ -2905,7 +2275,7 @@ def soft_delete(self, obj, undelete, request): obj.deleted = not undelete try: - obj.save() + obj.save(only_validate=only_validate) except ValidationError as ve: raise self.binder_validation_error(obj, ve) @@ -2926,33 +2296,17 @@ def dispatch_file_field(self, request, pk=None, file_field=None): file_field_name = file_field file_field = getattr(obj, file_field_name) - field = self.model._meta.get_field(file_field_name) if request.method == 'GET': if not file_field: raise BinderNotFound(file_field_name) - guess = mimetypes.guess_type(file_field.name) - content_type = (guess and guess[0]) or 'application/octet-stream' - serve_directly = isinstance(field, BinderFileField) and field.serve_directly - - is_video = content_type and content_type.startswith('video/') - + guess = mimetypes.guess_type(file_field.path) + guess = guess[0] if guess and guess[0] else 'application/octet-stream' try: - if serve_directly: - resp = HttpResponse(content_type=content_type) - resp[settings.INTERNAL_MEDIA_HEADER] = os.path.join(settings.INTERNAL_MEDIA_LOCATION, file_field.name) - # if the filefield does not start with '/' it is likely a http address (S3) instead of path - # and because of that we set the redirect url - if not file_field.url.startswith('/'): - resp['redirect_url'] = file_field.url - else: - file_handle = file_field.open('rb') - resp = FileResponse(file_handle, content_type=content_type) - if is_video: - resp['Accept-Ranges'] = 'bytes' + 
resp = StreamingHttpResponse(open(file_field.path, 'rb'), content_type=guess) except FileNotFoundError: - logger.error('Expected file {} not found'.format(file_field.name)) + logger.error('Expected file {} not found'.format(file_field.path)) raise BinderNotFound(file_field_name) if 'download' in request.GET: @@ -2963,40 +2317,146 @@ def dispatch_file_field(self, request, pk=None, file_field=None): return resp if request.method == 'POST': + self._require_model_perm('change', request) + try: # Take an arbitrary uploaded file file = next(request.FILES.values()) except StopIteration: raise BinderRequestError('File POST should use multipart/form-data (with an arbitrary key for the file data).') - # Hack to communicate to _store() that we're not interested in - # the new data (for perf reasons). - request._is_file_upload = True - self._store(obj, {file_field_name: file}, request, pk=pk) - - field = self.model._meta.get_field(file_field_name) - path = self.router.model_route(self.model, obj.id, field) - # {duplicate-binder-file-field-hash-code} - if isinstance(field, BinderFileField): - file_field = getattr(obj, file_field_name) - path += '?h={}&content_type={}&filename={}'.format( - file_field.content_hash, - file_field.content_type or '', - os.path.basename(file_field.name), - ) + try: + if file.size > self.max_upload_size * 10**6: + raise BinderFileSizeExceeded(self.max_upload_size) + + field = self.model._meta.get_field(file_field_name) + + if getattr(field, 'allowed_extensions', None) is not None: + extension = None if '.' 
not in file.name else file.name.split('.')[-1] + + if extension not in field.allowed_extensions: + raise BinderFileTypeIncorrect([{'extension': t} for t in field.allowed_extensions]) + + if isinstance(field, models.fields.files.ImageField): + try: + img = Image.open(file) + except Exception: + raise BinderImageError('Could not parse the file as an image.') - return JsonResponse({'data': {file_field_name: path}}) + format = img.format.lower() + if not format in ('png', 'gif', 'jpeg'): + raise BinderFileTypeIncorrect([{'extension': t, 'mimetype': 'image/' + t} for t in ['jpeg', 'png', 'gif']]) + + width, height = img.size + if format == 'jpeg': + img2 = image_transpose_exif(img) + + if img2 != img: + file.seek(0) # Do not append to the existing file! + file.truncate() + img2.save(file, 'jpeg') + img = img2 + + # Determine resize threshold + try: + max_size = self.image_resize_threshold[file_field_name] + except TypeError: + max_size = self.image_resize_threshold + + try: + max_width, max_height = max_size + except (TypeError, ValueError): + max_width, max_height = max_size, max_size + + try: + format_override = self.image_format_override.get(file_field_name) + except AttributeError: + format_override = self.image_format_override + + changes = False + + # FIXME: hardcoded max + # Flat out refuse images exceeding this size, to prevent DoS. + width_limit, height_limit = max(max_width, 4096), max(max_height, 4096) + if width > width_limit or height > height_limit: + raise BinderImageSizeExceeded(width_limit, height_limit) + + # Resize images that are too large. 
+ if width > max_width or height > max_height: + img.thumbnail((max_width, max_height), Image.ANTIALIAS) + logger.info('image dimensions ({}x{}) exceeded ({}, {}), resizing.'.format(width, height, max_width, max_height)) + if img.mode not in ["1", "L", "P", "RGB", "RGBA"]: + img = img.convert("RGB") + if format != 'jpeg': + format = 'png' + changes = True + + if format_override and format != format_override: + format = format_override + changes = True + + filename = '{}.{}'.format(os.path.basename(file.name), format) + + if changes: + file = io.BytesIO() + img.save(file, format) + else: + filename = file.name + + # FIXME: duplicate code + if file_field: + try: + old_hash = hashlib.sha256() + for c in file_field.file.chunks(): + old_hash.update(c) + old_hash = old_hash.hexdigest() + except FileNotFoundError: + logger.warning('Old file {} missing!'.format(file_field)) + old_hash = None + else: + old_hash = None + + file_field.delete(save=False) + # This triggers a save on obj + file_field.save(filename, django.core.files.File(file)) + + # FIXME: duplicate code + new_hash = hashlib.sha256() + for c in file_field.file.chunks(): + new_hash.update(c) + new_hash = new_hash.hexdigest() + + logger.info('POST updated {}[{}].{}: {} -> {}'.format(self._model_name(), pk, file_field_name, old_hash, new_hash)) + path = self.router.model_route(self.model, obj.id, field) + + # {duplicate-binder-file-field-hash-code} + if isinstance(field, BinderFileField): + path += '?h={}&content_type={}&filename={}'.format( + file_field.content_hash, + file_field.content_type or '', + os.path.basename(file_field.name), + ) + + return JsonResponse( {"data": {file_field_name: path}} ) + + except ValidationError as ve: + raise self.binder_validation_error(obj, ve, pk=pk) if request.method == 'DELETE': + self._require_model_perm('change', request) if not file_field: raise BinderIsDeleted() - # Hack to communicate to _store() that we're not interested in - # the new data (for perf reasons). 
- request._is_file_upload = True - self._store(obj, {file_field_name: None}, request, pk=pk) + # FIXME: duplicate code + old_hash = hashlib.sha256() + for c in file_field.file.chunks(): + old_hash.update(c) + old_hash = old_hash.hexdigest() + + file_field.delete() - return JsonResponse({'data': {file_field_name: None}}) + logger.info('DELETEd {}[{}].{}: {}'.format(self._model_name(), pk, file_field_name, old_hash)) + return JsonResponse( {"data": {file_field_name: None}} ) @@ -3004,7 +2464,7 @@ def filefield_get_name(self, instance=None, request=None, file_field=None): try: method = getattr(self, 'filefield_get_name_' + file_field.field.name) except AttributeError: - return os.path.basename(file_field.name) + return os.path.basename(file_field.path) return method(instance=instance, request=request, file_field=file_field) @@ -3015,7 +2475,7 @@ def view_history(self, request, pk=None, **kwargs): debug = kwargs['history'] == 'debug' - if debug and not settings.ENABLE_DEBUG_ENDPOINTS: + if debug and not django.conf.settings.ENABLE_DEBUG_ENDPOINTS: logger.warning('Debug endpoints disabled.') return HttpResponseForbidden('Debug endpoints disabled.') @@ -3023,131 +2483,7 @@ def view_history(self, request, pk=None, **kwargs): if debug: return history.view_changesets_debug(request, changesets.order_by('-id')) else: - return history.view_changesets(request, changesets.order_by('-id'), self.model, pk) - - - @list_route('stats', methods=['GET']) - def stats_view(self, request): - # We only apply annotations when used, so we can just pretend everything is included to simplify stuff - try: - annotations = self.model.Annotations - except AttributeError: - include_annotations = {'': []} - else: - include_annotations = {'': [ - attr - for attr in dir(annotations) - if not (attr.startswith('__') and attr.endswith('__')) - ]} - - queryset, annotations = self._get_filtered_queryset_base(request, None, include_annotations) - - try: - stats = request.GET['stats'] - except KeyError: - 
stats = [] - else: - stats = stats.split(',') - - return JsonResponse({ - stat: self._get_stat(request, queryset, stat, annotations.copy(), include_annotations) - for stat in stats - }) - - - def _get_stat(self, request, queryset, stat, annotations, include_annotations): - # NOTE: uses annotations! If called multiple times, provide a copy - # Get stat definition - try: - stat = self.stats[stat] - except KeyError: - try: - stat = DEFAULT_STATS[stat] - except KeyError: - raise BinderRequestError(f'unknown stat: {stat}') - - # Apply filters - for key, value in stat.filters.items(): - q, distinct = self._parse_filter(key, value, request, include_annotations) - queryset = self._apply_q_with_possible_annotations(queryset, q, annotations) - if distinct: - queryset = queryset.distinct() - - # Apply required annotations - for key in stat.annotations: - try: - expr = annotations.pop(key) - except KeyError: - pass - else: - queryset = queryset.annotate(**{key: expr}) - - if stat.group_by is None: - # No group by so just return a simple stat - return { - 'value': queryset.aggregate(result=stat.expr)['result'], - 'filters': stat.filters, - } - - group_by = stat.group_by.replace('.', '__') - - value = { - # The jsonloads/jsondumps is to make sure we can handle different - # types as keys, an example is dates. 
- jsonloads(jsondumps(key)): value - for key, value in ( - queryset - .order_by() - .exclude(**{group_by: None}) - .values(group_by) - .annotate(_binder_stat=stat.expr) - .values_list(group_by, '_binder_stat') - ) - } - - other = 0 - if stat.min_value is not None: - min_value = stat.min_value * sum(value.values()) - new_value = {} - - others = 0 - for key, sub_value in value.items(): - if sub_value >= min_value: - new_value[key] = sub_value - else: - other += sub_value - others += 1 - - if others > 1: - value = new_value - else: - other = 0 - - elif stat.max_values is not None: - keys = sorted(value, key=lambda key: value[key], reverse=True) - for key in keys[stat.max_values:]: - other += value.pop(key) - - return { - 'value': value, - 'other': other, - 'group_by': stat.group_by, - 'filters': stat.filters, - } - - - def _apply_annotations(self, queryset, annotations, *fields): - for field in fields: - if field is None: - continue - field = field.split('__', 1)[0] - try: - annotation = annotations.pop(field) - except KeyError: - pass - else: - queryset = queryset.annotate(**{field: annotation}) - return queryset + return history.view_changesets(request, changesets.order_by('-id')) @@ -3169,7 +2505,7 @@ def debug_changesets_24h(request): logger.warning('Not authenticated.') return HttpResponseForbidden('Not authenticated.') - if not settings.ENABLE_DEBUG_ENDPOINTS: + if not django.conf.settings.ENABLE_DEBUG_ENDPOINTS: logger.warning('Debug endpoints disabled.') return HttpResponseForbidden('Debug endpoints disabled.') @@ -3184,4 +2520,4 @@ def handler500(request): except Exception as e: request_id = str(e) - return HttpResponse('{"code": "InternalServerError", "debug": {"request_id": "' + request_id + '"}}', status=500) \ No newline at end of file + return HttpResponse('{"code": "InternalServerError", "debug": {"request_id": "' + request_id + '"}}', status=500) From e0510c6579c123aaacb48588b5b2a82e9dded855 Mon Sep 17 00:00:00 2001 From: robin Date: Thu, 22 Jul 2021 
11:54:23 +0200 Subject: [PATCH 11/33] Remove incorrect argument from full_clean --- binder/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/binder/models.py b/binder/models.py index afc481b0..df7525c2 100644 --- a/binder/models.py +++ b/binder/models.py @@ -622,7 +622,7 @@ def save(self, *args, only_validate=False, **kwargs): # _validation_model can be used to skip validation checks that are meant for complete models that are actually being saved self._validation_model = only_validate # Set the model as a validation model when we only want to validate the model - self.full_clean(only_validate=only_validate) # Never allow saving invalid models! + self.full_clean() # Never allow saving invalid models! return super().save(*args, **kwargs) From 26cf0ed5789e6b28ae3012ee6d7536f60fa980f8 Mon Sep 17 00:00:00 2001 From: robin Date: Thu, 22 Jul 2021 14:46:37 +0200 Subject: [PATCH 12/33] Add tests --- tests/test_model_validation.py | 45 +++++++++++++++++++++++++- tests/testapp/models/contact_person.py | 7 ++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/tests/test_model_validation.py b/tests/test_model_validation.py index 59dac762..b05dba5a 100644 --- a/tests/test_model_validation.py +++ b/tests/test_model_validation.py @@ -35,7 +35,10 @@ def setUp(self): # some update payload self.model_data_with_error = { - 'name': 'very_special_forbidden_contact_person_name', # see `contact_person.py` + 'name': 'very_special_forbidden_contact_person_name', # see `contact_person.py` + } + self.model_data_with_non_validation_error = { + 'name': 'very_special_validation_contact_person_name', # see `contact_person.py` } self.model_data = { 'name': 'Scooooooby', @@ -127,6 +130,21 @@ def test_validate_on_put(self): self.assertEqual(response.status_code, 200) self.assertEqual('Scooooooby', ContactPerson.objects.first().name) + def test_validate_model_validation_whitelisting(self): + person_id = ContactPerson.objects.create(name='Scooby Doo').id + 
self.assertEqual('Scooby Doo', ContactPerson.objects.first().name) + + # the normal request should give a validation error + response = self.client.put(f'/contact_person/{person_id}/', data=json.dumps(self.model_data_with_non_validation_error), content_type='application/json') + self.assert_validation_error(response, person_id) + self.assertEqual('Scooby Doo', ContactPerson.objects.first().name) + + # when just validating we want to ignore this validation error, so with validation it should not throw an error + response = self.client.put(f'/contact_person/{person_id}/?validate=true', data=json.dumps(self.model_data), content_type='application/json') + self.assert_no_validation_error(response) + self.assertEqual('Scooby Doo', ContactPerson.objects.first().name) + + def test_validate_on_multiput(self): person_1_id = ContactPerson.objects.create(name='Scooby Doo 1').id @@ -154,6 +172,17 @@ def test_validate_on_multiput(self): } ]} + multi_put_data_with_validation_whitelist = {'data': [ + { + 'id': person_1_id, + 'name': 'very_special_validation_contact_person_name', + }, + { + 'id': person_2_id, + 'name': 'very_special_validation_contact_person_other_name' + } + ]} + # trigger a validation error response = self.client.put(f'/contact_person/?validate=true', data=json.dumps(multi_put_data_with_error), content_type='application/json') self.assert_multi_put_validation_error(response) @@ -167,6 +196,20 @@ def test_validate_on_multiput(self): self.assertEqual('Scooby Doo 1', ContactPerson.objects.get(id=person_1_id).name) self.assertEqual('Scooby Doo 2', ContactPerson.objects.get(id=person_2_id).name) + # multi put validation whitelist test + response = self.client.put(f'/contact_person/?validate=true', data=json.dumps(multi_put_data_with_validation_whitelist), content_type='application/json') + self.assert_no_validation_error(response) + self.assertEqual('Scooby Doo 1', ContactPerson.objects.get(id=person_1_id).name) + self.assertEqual('Scooby Doo 2', 
ContactPerson.objects.get(id=person_2_id).name) + + # multi put non validation whitelist test error + response = self.client.put(f'/contact_person/', + data=json.dumps(multi_put_data_with_validation_whitelist), + content_type='application/json') + self.assert_multi_put_validation_error(response) + self.assertEqual('Scooby Doo 1', ContactPerson.objects.get(id=person_1_id).name) + self.assertEqual('Scooby Doo 2', ContactPerson.objects.get(id=person_2_id).name) + # now for real response = self.client.put(f'/contact_person/', data=json.dumps(multi_put_data), content_type='application/json') self.assertEqual(response.status_code, 200) diff --git a/tests/testapp/models/contact_person.py b/tests/testapp/models/contact_person.py index baa1abcb..6b699f5a 100644 --- a/tests/testapp/models/contact_person.py +++ b/tests/testapp/models/contact_person.py @@ -23,3 +23,10 @@ def clean(self): code='invalid', message='Very special validation check that we need in `tests.M2MStoreErrorsTest`.' ) + + # Should only give an error when model is not a validation model + if (self.name == 'very_special_validation_contact_person_name' or self.name == 'very_special_validation_contact_person_other_name') and not self._validation_model: + raise ValidationError( + code='invalid', + message='Very special validation check that we need in `tests.M2MStoreErrorsTest`.' + ) From 6c189964a58b0485ff7ec14808c6d83180274add Mon Sep 17 00:00:00 2001 From: stefanmajoor Date: Fri, 23 Jul 2021 10:18:07 +0200 Subject: [PATCH 13/33] do not write to a new worksheet. 
Take the default worksheet instead

---
 binder/plugins/views/csvexport.py | 58 ++++++++++++------------------
 1 file changed, 23 insertions(+), 35 deletions(-)

diff --git a/binder/plugins/views/csvexport.py b/binder/plugins/views/csvexport.py
index 1afd1794..7485ae58 100644
--- a/binder/plugins/views/csvexport.py
+++ b/binder/plugins/views/csvexport.py
@@ -15,7 +15,8 @@ class ExportFileAdapter:
 	"""
 	__metaclass__ = abc.ABCMeta

-	def __init__(self, request: HttpRequest):
+	def __init__(self, request: HttpRequest, csv_settings: 'CsvExportView.CsvExportSettings'):
+		self.csv_settings = csv_settings
 		self.request = request

 	@abc.abstractmethod
@@ -69,8 +70,8 @@ class CsvFileAdapter(ExportFileAdapter):
 	Adapter for returning CSV files
 	"""

-	def __init__(self, request: HttpRequest):
-		super().__init__(request)
+	def __init__(self, request: HttpRequest, csv_settings: 'CsvExportView.CsvExportSettings'):
+		super().__init__(request, csv_settings)
 		self.response = HttpResponse(content_type='text/csv')
 		self.file_name = 'export'
 		self.writer = csv.writer(self.response)
@@ -79,7 +80,7 @@ def set_file_name(self, file_name: str):
 		self.file_name = file_name

 	def set_columns(self, columns: List[str]):
-		self.add_row(columns)
+		self.writer.writerow(list(map(lambda x: x[1], self.csv_settings.column_map)))

 	def add_row(self, values: List[str]):
 		self.writer.writerow(values)
@@ -91,10 +92,10 @@ def get_response(self) -> HttpResponse:

 class ExcelFileAdapter(ExportFileAdapter):
 	"""
-	Adapter for returning excel files
+	Adapter for returning excel files
 	"""
-	def __init__(self, request: HttpRequest):
-		super().__init__(request)
+	def __init__(self, request: HttpRequest, csv_settings: 'CsvExportView.CsvExportSettings'):
+		super().__init__(request, csv_settings)
 		# Import pandas locally.
This means that you can use the CSV adapter without using pandas import openpyxl @@ -103,7 +104,7 @@ def __init__(self, request: HttpRequest): # self.writer = self.pandas.ExcelWriter(self.response) self.work_book = self.openpyxl.Workbook() - self.sheet = self.work_book.active + self.sheet = self.work_book._sheets[0] # The row number we are currently writing to self._row_number = 0 @@ -115,7 +116,7 @@ def set_columns(self, columns: List[str]): self.add_row(columns) def add_row(self, values: List[str]): - for (column_id, value) in enumerate(values): + for (value, column_id) in zip(values, range(1000000)): self.sheet.cell(column=column_id + 1, row=self._row_number + 1, value=value) self._row_number += 1 @@ -129,9 +130,6 @@ def get_response(self) -> HttpResponse: self.response['Content-Disposition'] = 'attachment; filename="{}.xlsx"'.format(self.file_name) return self.response -DEFAULT_RESPONSE_TYPE_MAPPING = { - 'xlsx': ExcelFileAdapter, -} class RequestAwareAdapter(ExportFileAdapter): """ @@ -141,14 +139,14 @@ class RequestAwareAdapter(ExportFileAdapter): returns a xlsx type """ - def __init__(self, request: HttpRequest): - super().__init__(request) + def __init__(self, request: HttpRequest, csv_settings: 'CsvExportView.CsvExportSettings'): + super().__init__(request, csv_settings) - response_type_mapping = DEFAULT_RESPONSE_TYPE_MAPPING response_type = request.GET.get('response_type', '').lower() - AdapterClass = response_type_mapping.get(response_type, CsvFileAdapter) - - self.base_adapter = AdapterClass(request) + AdapterClass = CsvFileAdapter + if response_type == 'xlsx': + AdapterClass = ExcelFileAdapter + self.base_adapter = AdapterClass(request, csv_settings) def set_file_name(self, file_name: str): return self.base_adapter.set_file_name(file_name) @@ -163,8 +161,6 @@ def get_response(self) -> HttpResponse: return self.base_adapter.get_response() - - class CsvExportView: """ This class adds another endpoint to the ModelView, namely GET model/download/. 
This does the same thing as getting a @@ -182,7 +178,7 @@ class CsvExportSettings: """ def __init__(self, withs, column_map, file_name=None, default_file_name='download', multi_value_delimiter=' ', - extra_permission=None, extra_params={}, csv_adapter=RequestAwareAdapter, limit=10000): + extra_permission=None, csv_adapter=RequestAwareAdapter): """ @param withs: String[] An array of all the withs that are necessary for this csv export @param column_map: Tuple[] An array, with all columns of the csv file in order. Each column is represented by a tuple @@ -194,9 +190,6 @@ def __init__(self, withs, column_map, file_name=None, default_file_name='downloa as delimiter between them. This may be if an array is returned, or if we have a one to many relation @param extra_permission: String When set, an extra binder permission check will be done on this permission. @param csv_adapter: Class. Either an object extending - @param response_type_mapping: Mapping between the parameter used in the custom response type - @param limit: Limit for amount of items in the csv. 
This is a fail save that you do not bring down the server with - a big query """ self.withs = withs self.column_map = column_map @@ -204,10 +197,7 @@ def __init__(self, withs, column_map, file_name=None, default_file_name='downloa self.default_file_name = default_file_name self.multi_value_delimiter = multi_value_delimiter self.extra_permission = extra_permission - self.extra_params = extra_params self.csv_adapter = csv_adapter - self.limit = limit - def _generate_csv_file(self, request: HttpRequest, file_adapter: CsvFileAdapter): @@ -220,10 +210,8 @@ def _generate_csv_file(self, request: HttpRequest, file_adapter: CsvFileAdapter) mutable = request.POST._mutable request.GET._mutable = True request.GET['page'] = 1 - request.GET['limit'] = self.csv_settings.limit if self.csv_settings.limit is not None else 'none' + request.GET['limit'] = 10000 request.GET['with'] = ",".join(self.csv_settings.withs) - for key, value in self.csv_settings.extra_params.items(): - request.GET[key] = value request.GET._mutable = mutable parent_result = self.get(request) @@ -269,6 +257,8 @@ def get_datum(data, key, prefix=''): if '.' 
not in key: if key not in data: raise Exception("{} not found in data: {}".format(key, data)) + if type(data[key]) == list: + return self.csv_settings.multi_value_delimiter.join(data[key]) return data[key] else: """ @@ -284,12 +274,12 @@ def get_datum(data, key, prefix=''): head_key, subkey = key.split('.', 1) if head_key in data: new_prefix = '{}.{}'.format(prefix, head_key) - if isinstance(data[head_key], dict): + if type(data[head_key]) == dict: return get_datum(data[head_key], subkey, new_prefix) else: # Assume that we have a mapping now fk_ids = data[head_key] - if not isinstance(fk_ids, list): + if type(fk_ids) != list: fk_ids = [fk_ids] # if head_key not in key_mapping: @@ -309,8 +299,6 @@ def get_datum(data, key, prefix=''): if len(col_definition) >= 3: transform_function = col_definition[2] datum = transform_function(datum, row, key_mapping) - if isinstance(datum, list): - datum = self.csv_settings.multi_value_delimiter.join(datum) data.append(datum) file_adapter.add_row(data) @@ -325,7 +313,7 @@ def download(self, request): if self.csv_settings is None: raise Exception('No csv settings set!') - file_adapter = self.csv_settings.csv_adapter(request) + file_adapter = self.csv_settings.csv_adapter(request, self.csv_settings) self._generate_csv_file(request, file_adapter) From caa3009a189b5790ce4f01f47b5e65c7e970b241 Mon Sep 17 00:00:00 2001 From: robin Date: Fri, 23 Jul 2021 10:49:51 +0200 Subject: [PATCH 14/33] update doc --- docs/api.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/api.md b/docs/api.md index 850a9bb8..0156a614 100644 --- a/docs/api.md +++ b/docs/api.md @@ -200,6 +200,8 @@ Currently this is implemented by raising an `BinderValidateOnly` exception, whic It is important to realize that in this way, the normal `save()` function is called on a model, so it is possible that possible side effects are triggered, when these are implemented directly in `save()`, as opposed to in a signal method, which would be preferable. 
In other words, we cannot guarantee that the request will be idempotent. Therefore, the validation only feature is disabled by default and must be enabled by setting `allow_standalone_validation=True` on the view. +When a model is being validated and not actually being saved the `_validation_model property` of the binder model is set to True. This allows whitelisting of certain validation checks such as with certain relations that are not included with the validation model. + ### Uploading files To upload a file, you have to add it to the `file_fields` of the `ModelView`: From 4a5ef472891aad91d3aff7fe3013f18767fc3e4a Mon Sep 17 00:00:00 2001 From: stefanmajoor Date: Mon, 26 Jul 2021 09:44:46 +0200 Subject: [PATCH 15/33] Add first attempt at getting mssql support working --- .github/workflows/ci.yml | 49 ++++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ba9c15eb..2a251e69 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,28 +4,13 @@ on: push jobs: check: + runs-on: ubuntu-latest + strategy: matrix: - # Testing Python 3.7 (deb 10), Python 3.9 (deb 11), Python 3.11 (deb 12) - python-version: ["3.9", "3.11"] - django-version: ["3.2.25", "4.2.17", "5.1.4"] - database-engine: ["postgres", "mysql"] - os: [ubuntu-latest] - include: - # 3.7 cannot run on latest ubuntu - - python-version: 3.7 - django-version: 3.2.25 - database-engine: postgres - os: ubuntu-22.04 - - python-version: 3.7 - django-version: 3.2.25 - database-engine: mysql - os: ubuntu-22.04 - exclude: - - python-version: 3.9 - django-version: 5.1.4 - - runs-on: ${{ matrix.os }} + python-version: [ "3.7", "3.8" ] + django-version: [ "2.1.1", "3.1.4" ] + database-engine: [ "postgres", "mysql", "mssql"] services: postgres: @@ -52,19 +37,35 @@ jobs: ports: - 3306:3306 + + mssqldb: + image: mcr.microsoft.com/mssql/server:2017-latest + env: + ACCEPT_EULA: y + SA_PASSWORD: Test + + 
steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v2 - name: Setup python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} + - name: Retrieve cached venv + uses: actions/cache@v1 + id: cache-venv + with: + path: ./.venv/ + key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.django-version }}-venv-${{ hashFiles('ci-requirements.txt') }} + - name: Install requirements run: | python -m venv .venv - .venv/bin/pip install django==${{ matrix.django-version }} -r ci-requirements.txt + .venv/bin/pip install -qr ci-requirements.txt django==${{ matrix.django-version }} + if: steps.cache-venv.outputs.cache-hit != 'true' - name: Run linting run: .venv/bin/flake8 binder @@ -82,7 +83,7 @@ jobs: - name: Run tests run: | - .venv/bin/coverage run --include="binder/*" -m unittest discover -vt . -s tests + .venv/bin/coverage run --include="binder/*" setup.py test env: BINDER_TEST_MYSQL: ${{ matrix.database-engine == 'mysql' && 1 || 0 }} CY_RUNNING_INSIDE_CI: 1 From 027237c8c6c330d989382b0711027d4806f68448 Mon Sep 17 00:00:00 2001 From: stefanmajoor Date: Fri, 17 Sep 2021 13:42:55 +0200 Subject: [PATCH 16/33] add support fo nullable foreign keys in CSVexport plugin --- binder/plugins/views/csvexport.py | 7 +- tests/plugins/test_csvexport.py | 102 ++++++------------------------ tests/testapp/models/__init__.py | 2 +- tests/testapp/models/picture.py | 13 ++++ tests/testapp/views/__init__.py | 2 +- tests/testapp/views/picture.py | 7 +- 6 files changed, 46 insertions(+), 87 deletions(-) diff --git a/binder/plugins/views/csvexport.py b/binder/plugins/views/csvexport.py index 7485ae58..93b6a7e4 100644 --- a/binder/plugins/views/csvexport.py +++ b/binder/plugins/views/csvexport.py @@ -279,7 +279,12 @@ def get_datum(data, key, prefix=''): else: # Assume that we have a mapping now fk_ids = data[head_key] - if type(fk_ids) != list: + + if fk_ids is None: + # 
This case happens if we have a nullable foreign key that is null. Treat this as a many + # to one relation with no values. + fk_ids = [] + elif type(fk_ids) != list: fk_ids = [fk_ids] # if head_key not in key_mapping: diff --git a/tests/plugins/test_csvexport.py b/tests/plugins/test_csvexport.py index 6837c4a2..d2d4d010 100644 --- a/tests/plugins/test_csvexport.py +++ b/tests/plugins/test_csvexport.py @@ -1,5 +1,4 @@ from PIL import Image -from binder.plugins.views.csvexport import ExcelFileAdapter from os import urandom from tempfile import NamedTemporaryFile import io @@ -8,8 +7,7 @@ from django.core.files import File from django.contrib.auth.models import User -from ..testapp.models import Picture, Animal, Caretaker -from ..testapp.views import PictureView +from ..testapp.models import Picture, Animal, PictureBook import csv import openpyxl @@ -33,8 +31,10 @@ def setUp(self): self.pictures = [] + picture_book = PictureBook.objects.create(name='Holiday 2012') + for i in range(3): - picture = Picture(animal=animal) + picture = Picture(animal=animal, picture_book=picture_book) file = CsvExportTest.temp_imagefile(50, 50, 'jpeg') picture.file.save('picture.jpg', File(file), save=False) picture.original_file.save('picture_copy.jpg', File(file), save=False) @@ -50,6 +50,7 @@ def setUp(self): def test_csv_download(self): response = self.client.get('/picture/download_csv/') + print(response.content) self.assertEqual(200, response.status_code) response_data = csv.reader(io.StringIO(response.content.decode("utf-8"))) @@ -71,8 +72,7 @@ def test_excel_download(self): tmp.write(response.content) wb = openpyxl.load_workbook(tmp.name) - self.assertEqual(1, len(wb._sheets)) - sheet = wb._sheets[0] + sheet = wb._sheets[1] _values = list(sheet.values) @@ -105,28 +105,6 @@ def test_context_aware_downloader_default_csv(self): self.assertEqual(data[2], [str(self.pictures[1].id), str(self.pictures[1].animal_id), str(self.pictures[1].id ** 2)]) self.assertEqual(data[3], 
[str(self.pictures[2].id), str(self.pictures[2].animal_id), str(self.pictures[2].id ** 2)]) - def test_download_extra_params(self): - caretaker_1 = Caretaker(name='Foo') - caretaker_1.save() - caretaker_2 = Caretaker(name='Bar') - caretaker_2.save() - caretaker_3 = Caretaker(name='Baz') - caretaker_3.save() - - response = self.client.get('/caretaker/download/') - self.assertEqual(200, response.status_code) - response_data = csv.reader(io.StringIO(response.content.decode("utf-8"))) - - data = list(response_data) - - # First line needs to be the header - self.assertEqual(data[0], ['ID', 'Name', 'Scary']) - - # All other data needs to be ordered using the default ordering (by id, asc) - self.assertEqual(data[1], [str(caretaker_1.id), 'Foo', 'boo!']) - self.assertEqual(data[2], [str(caretaker_2.id), 'Bar', 'boo!']) - self.assertEqual(data[3], [str(caretaker_3.id), 'Baz', 'boo!']) - def test_context_aware_download_xlsx(self): response = self.client.get('/picture/download/?response_type=xlsx') self.assertEqual(200, response.status_code) @@ -135,7 +113,7 @@ def test_context_aware_download_xlsx(self): tmp.write(response.content) wb = openpyxl.load_workbook(tmp.name) - sheet = wb._sheets[0] + sheet = wb._sheets[1] _values = list(sheet.values) @@ -150,63 +128,23 @@ def test_context_aware_download_xlsx(self): self.assertEqual(list(_values[3]), [self.pictures[2].id, str(self.pictures[2].animal_id), (self.pictures[2].id ** 2)]) - def test_csv_export_custom_limit(self): - old_limit = PictureView.csv_settings.limit; - PictureView.csv_settings.limit = 1 - response = self.client.get('/picture/download/') - self.assertEqual(200, response.status_code) - response_data = csv.reader(io.StringIO(response.content.decode("utf-8"))) - - # Header - self.assertEqual(next(response_data), ['picture identifier', 'animal identifier', 'squared picture identifier']) - # 1 REcord - self.assertIsNotNone(next(response_data)) - - # EOF - with self.assertRaises(StopIteration): - 
self.assertIsNone(next(response_data)) - - ###### Limit 2 - PictureView.csv_settings.limit = 2 - response = self.client.get('/picture/download/') - self.assertEqual(200, response.status_code) - response_data = csv.reader(io.StringIO(response.content.decode("utf-8"))) - - # Header - self.assertEqual(next(response_data), ['picture identifier', 'animal identifier', 'squared picture identifier']) - # 1 REcord - self.assertIsNotNone(next(response_data)) - # 2 Records - self.assertIsNotNone(next(response_data)) - # EOF - with self.assertRaises(StopIteration): - self.assertIsNone(next(response_data)) - - PictureView.csv_settings.limit = old_limit; + def test_none_foreign_key(self): + """ + If we have a relation that we have to include which is nullable, and we have a foreign key, we do not want the + csv export to crash. I.e. we also want for a picture to export the picture book name, even though not all pioctures + belong to a picture book + :return: + """ + self.pictures[0].picture_book = None + self.pictures[0].save() - def test_csv_settings_limit_none_working(self): - # Limit None should download everything - - old_limit = PictureView.csv_settings.limit; - PictureView.csv_settings.limit = None response = self.client.get('/picture/download/') self.assertEqual(200, response.status_code) response_data = csv.reader(io.StringIO(response.content.decode("utf-8"))) - # Header - self.assertEqual(next(response_data), ['picture identifier', 'animal identifier', 'squared picture identifier']) - # 3 REcords, everything we have in the database - self.assertIsNotNone(next(response_data)) - self.assertIsNotNone(next(response_data)) - self.assertIsNotNone(next(response_data)) - - # EOF - with self.assertRaises(StopIteration): - self.assertIsNone(next(response_data)) + data = list(response_data) + content = data[1:] - PictureView.csv_settings.limit = old_limit; + picture_books = [c[-1] for c in content] -class TestExcelFileAdapter(TestCase): - def test_one_sheet_after_init(self): - 
file_adapter = ExcelFileAdapter(None) - self.assertEqual(len(file_adapter.work_book.worksheets), 1) + self.assertEqual(['', 'Holiday 2012', 'Holiday 2012'], picture_books) \ No newline at end of file diff --git a/tests/testapp/models/__init__.py b/tests/testapp/models/__init__.py index f944a1ca..32753719 100644 --- a/tests/testapp/models/__init__.py +++ b/tests/testapp/models/__init__.py @@ -10,7 +10,7 @@ from .gate import Gate from .nickname import Nickname, NullableNickname from .lion import Lion -from .picture import Picture +from .picture import Picture, PictureBook from .zoo import Zoo from .zoo_employee import ZooEmployee from .city import City, CityState, PermanentCity diff --git a/tests/testapp/models/picture.py b/tests/testapp/models/picture.py index ffbe59bb..862bfb3e 100644 --- a/tests/testapp/models/picture.py +++ b/tests/testapp/models/picture.py @@ -13,6 +13,18 @@ def delete_files(sender, instance=None, **kwargs): except Exception: pass + +class PictureBook(BinderModel): + """ + Sometimes customers like to commemorate their visit to the zoo. Of course there are always some shitty pictures that + we do not want in a picture album + + """ + + name = models.TextField() + + + # At the website of the zoo there are some pictures of animals. This model links the picture to an animal. # # A picture has two files, the original uploaded file, and the modified file. 
This model is used for testing the @@ -21,6 +33,7 @@ class Picture(BinderModel): animal = models.ForeignKey('Animal', on_delete=models.CASCADE, related_name='picture') file = models.ImageField(upload_to='floor-plans') original_file = models.ImageField(upload_to='floor-plans') + picture_book = models.ForeignKey('PictureBook', on_delete=models.CASCADE, null=True, blank=True) def __str__(self): return 'picture %d: (Picture for animal %s)' % (self.pk or 0, self.animal.name) diff --git a/tests/testapp/views/__init__.py b/tests/testapp/views/__init__.py index 31fb1049..c897a22a 100644 --- a/tests/testapp/views/__init__.py +++ b/tests/testapp/views/__init__.py @@ -15,7 +15,7 @@ from .gate import GateView from .lion import LionView from .nickname import NicknameView -from .picture import PictureView +from .picture import PictureView, PictureBookView from .user import UserView from .zoo import ZooView from .zoo_employee import ZooEmployeeView diff --git a/tests/testapp/views/picture.py b/tests/testapp/views/picture.py index 7dadfe99..0ba35f99 100644 --- a/tests/testapp/views/picture.py +++ b/tests/testapp/views/picture.py @@ -3,16 +3,19 @@ from binder.views import ModelView from binder.plugins.views import ImageView, CsvExportView -from ..models import Picture +from ..models import Picture, PictureBook +class PictureBookView(ModelView): + model = PictureBook class PictureView(ModelView, ImageView, CsvExportView): model = Picture file_fields = ['file', 'original_file'] - csv_settings = CsvExportView.CsvExportSettings(['animal'], [ + csv_settings = CsvExportView.CsvExportSettings(['animal', 'picture_book'], [ ('id', 'picture identifier'), ('animal.id', 'animal identifier'), ('id', 'squared picture identifier', lambda id, row, mapping: id**2), + ('picture_book.name', 'Picturebook name') ]) @list_route(name='download_csv', methods=['GET']) From ed3b0db3c0c4a10bad1216c0f2b0c2e6614e3734 Mon Sep 17 00:00:00 2001 From: Daan van der Kallen Date: Tue, 18 Jan 2022 23:30:21 +0100 
Subject: [PATCH 17/33] Do 1 query per with instead of a big combined one to prevent exponential join scaling --- binder/views.py | 49 +++++++++++++++++++++---------------------------- 1 file changed, 21 insertions(+), 28 deletions(-) diff --git a/binder/views.py b/binder/views.py index aec0f62f..c0f37a91 100644 --- a/binder/views.py +++ b/binder/views.py @@ -882,7 +882,6 @@ def _follow_related(self, fieldspec): def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): result = {} - annotations = {} singular_fields = set() rel_ids_by_field_by_id = defaultdict(lambda: defaultdict(list)) virtual_fields = set() @@ -922,7 +921,7 @@ def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): # This does mean you can't filter on this relation # unless you write a custom filter, too. if isinstance(virtual_annotation, Q): - annotations[field_alias] = Agg(virtual_annotation, filter=q, ordering=orders) + annotation = Agg(virtual_annotation, filter=q, ordering=orders) else: try: func = getattr(self, virtual_annotation) @@ -931,6 +930,7 @@ def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): self.model.__name__, field, virtual_annotation )) rel_ids_by_field_by_id[field] = func(request, pks, q) + continue # Actual relation else: if (getattr(self.model, field).__class__ == models.fields.related.ReverseOneToOneDescriptor or @@ -939,34 +939,27 @@ def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): if Agg != GroupConcat: # HACKK (GROUP_CONCAT can't filter and excludes NULL already) q &= Q(**{field+'__pk__isnull': False}) - annotations[field_alias] = Agg(field+'__pk', filter=q, ordering=orders) - - - qs = self.model.objects.filter(pk__in=pks).values('pk').annotate(**annotations) - for record in qs: - for field in with_map: - field_alias = field+'___annotation' if field in virtual_fields else field - - if field_alias in annotations: - value = record[field_alias] - - # Make the values 
distinct. We can't do this in - # the Agg() call, because then we get an error - # regarding order by and values needing to be the - # same :( - # We also can't just put it in a set, because we - # need to preserve the ordering. So we use a set - # to keep track of what we've seen and only add - # new items. - seen_values = set() - distinct_values = [] - for v in value: - if v not in seen_values: - distinct_values.append(v) + annotation = Agg(field+'__pk', filter=q, ordering=orders) + + qs = self.model.objects.filter(pk__in=pks).values('pk').annotate(**{field_alias: annotation}) + + for record in qs: + value = record[field_alias] + + # Make the values distinct. We can't do this in + # the Agg() call, because then we get an error + # regarding order by and values needing to be the + # same :( + # We also can't just put it in a set, because we + # need to preserve the ordering. So we use a set + # to keep track of what we've seen and only add + # new items. + seen_values = set() + for v in value: + if v not in seen_values: + rel_ids_by_field_by_id[field][record['pk']].append(v) seen_values.add(v) - rel_ids_by_field_by_id[field][record['pk']] += distinct_values - for field, sub_fields in with_map.items(): next = self._follow_related(field)[0].model From 180b0b94c5ffb8fadff758f5a167da6b25a1a569 Mon Sep 17 00:00:00 2001 From: Daan van der Kallen Date: Wed, 19 Jan 2022 00:00:23 +0100 Subject: [PATCH 18/33] Do not use aggregates at all --- binder/views.py | 63 +++++++++++++++++-------------------------------- 1 file changed, 22 insertions(+), 41 deletions(-) diff --git a/binder/views.py b/binder/views.py index c0f37a91..3e82712c 100644 --- a/binder/views.py +++ b/binder/views.py @@ -886,8 +886,6 @@ def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): rel_ids_by_field_by_id = defaultdict(lambda: defaultdict(list)) virtual_fields = set() - Agg = self.AggStrategy - for field in with_map: vr = self.virtual_relations.get(field, None) @@ -899,12 
+897,6 @@ def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): if rel == field or rel.startswith(field + '.') }) - # Model default orders (this sometimes matters) - orders = [] - field_alias = field + '___annotation' if vr else field - for o in (view.model._meta.ordering if view.model._meta.ordering else BinderModel.Meta.ordering): - orders.append(prefix_db_expression(o, field_alias)) - # Virtual relation if vr: virtual_fields.add(field) @@ -920,45 +912,34 @@ def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): # annotations, so allow for fetching of ids, instead. # This does mean you can't filter on this relation # unless you write a custom filter, too. - if isinstance(virtual_annotation, Q): - annotation = Agg(virtual_annotation, filter=q, ordering=orders) - else: - try: - func = getattr(self, virtual_annotation) - except AttributeError: - raise BinderRequestError('Annotation for virtual relation {{{}}}.{{{}}} is {{{}}}, but no method by that name exists.'.format( - self.model.__name__, field, virtual_annotation - )) - rel_ids_by_field_by_id[field] = func(request, pks, q) - continue + try: + func = getattr(self, virtual_annotation) + except AttributeError: + raise BinderRequestError('Annotation for virtual relation {{{}}}.{{{}}} is {{{}}}, but no method by that name exists.'.format( + self.model.__name__, field, virtual_annotation + )) + rel_ids_by_field_by_id[field] = func(request, pks, q) # Actual relation else: if (getattr(self.model, field).__class__ == models.fields.related.ReverseOneToOneDescriptor or not any(f.name == field for f in (list(self.model._meta.many_to_many) + list(self._get_reverse_relations())))): singular_fields.add(field) - if Agg != GroupConcat: # HACKK (GROUP_CONCAT can't filter and excludes NULL already) - q &= Q(**{field+'__pk__isnull': False}) - annotation = Agg(field+'__pk', filter=q, ordering=orders) - - qs = 
self.model.objects.filter(pk__in=pks).values('pk').annotate(**{field_alias: annotation}) - - for record in qs: - value = record[field_alias] - - # Make the values distinct. We can't do this in - # the Agg() call, because then we get an error - # regarding order by and values needing to be the - # same :( - # We also can't just put it in a set, because we - # need to preserve the ordering. So we use a set - # to keep track of what we've seen and only add - # new items. - seen_values = set() - for v in value: - if v not in seen_values: - rel_ids_by_field_by_id[field][record['pk']].append(v) - seen_values.add(v) + # Model default orders (this sometimes matters) + orders = [] + for o in (view.model._meta.ordering if view.model._meta.ordering else BinderModel.Meta.ordering): + orders.append(prefix_db_expression(o, field)) + + query = ( + self.model.objects + .order_by(*orders) + .filter(q, pk__in=pks, **{field + '__isnull': False}) + .values_list('pk', field + '__pk') + .distinct() + ) + + for pk, rel_pk in query: + rel_ids_by_field_by_id[field][pk].append(rel_pk) for field, sub_fields in with_map.items(): next = self._follow_related(field)[0].model From c527280247c8b3cfef0c27a46d9d7099e40b0b86 Mon Sep 17 00:00:00 2001 From: Daan van der Kallen Date: Wed, 19 Jan 2022 01:37:51 +0100 Subject: [PATCH 19/33] Do not use join for reverse relations --- binder/views.py | 52 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/binder/views.py b/binder/views.py index 3e82712c..3eb4747a 100644 --- a/binder/views.py +++ b/binder/views.py @@ -220,6 +220,21 @@ def prefix_db_expression(value, prefix): raise ValueError('Unknown expression type, cannot apply db prefix: %s', value) +# Prefix a q expression by adding a prefix to all filters. You can also supply +# an 'antiprefix', if the filter starts with this value it will be removed +# instead of the prefix being added. This is useful for reversed fields. 
+def prefix_q_expression(value, prefix, antiprefix=None): + children = [] + for child in value.children: + if isinstance(child, Q): + children.append(prefix_q_expression(child, prefix, antiprefix)) + elif antiprefix is not None and child[0].startswith(antiprefix + '__'): + children.append((child[0][len(antiprefix) + 2:], child[1])) + else: + children.append((prefix + '__' + child[0], child[1])) + return Q(*children, _negated=value.negated) + + class ModelView(View): # Model this is a view for. Use None for views not tied to a particular model. model = None @@ -921,22 +936,31 @@ def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): rel_ids_by_field_by_id[field] = func(request, pks, q) # Actual relation else: - if (getattr(self.model, field).__class__ == models.fields.related.ReverseOneToOneDescriptor or - not any(f.name == field for f in (list(self.model._meta.many_to_many) + list(self._get_reverse_relations())))): + f = self.model._meta.get_field(field) + + if f.one_to_one or f.many_to_one: singular_fields.add(field) - # Model default orders (this sometimes matters) - orders = [] - for o in (view.model._meta.ordering if view.model._meta.ordering else BinderModel.Meta.ordering): - orders.append(prefix_db_expression(o, field)) - - query = ( - self.model.objects - .order_by(*orders) - .filter(q, pk__in=pks, **{field + '__isnull': False}) - .values_list('pk', field + '__pk') - .distinct() - ) + if any(f.name == field for f in self._get_reverse_relations()): + rev_field = f.remote_field.name + query = ( + view.model.objects + .filter(prefix_q_expression(q, rev_field, field), **{rev_field + '__in': pks}) + .values_list(rev_field + '__pk', 'pk') + .distinct() + ) + else: + # Model default orders (this sometimes matters) + orders = [] + for o in (view.model._meta.ordering if view.model._meta.ordering else BinderModel.Meta.ordering): + orders.append(prefix_db_expression(o, field)) + query = ( + self.model.objects + .order_by(*orders) + .filter(q, 
pk__in=pks, **{field + '__isnull': False}) + .values_list('pk', field + '__pk') + .distinct() + ) for pk, rel_pk in query: rel_ids_by_field_by_id[field][pk].append(rel_pk) From 15263ca355eb048705f2ea1e5f1d4fbec3b2b68c Mon Sep 17 00:00:00 2001 From: Daan van der Kallen Date: Tue, 18 Jan 2022 23:30:21 +0100 Subject: [PATCH 20/33] Do 1 query per with instead of a big combined one to prevent exponential join scaling --- binder/views.py | 511 ++++++++++++++++++++++++++++++------------------ 1 file changed, 324 insertions(+), 187 deletions(-) diff --git a/binder/views.py b/binder/views.py index 3eb4747a..4e35a538 100644 --- a/binder/views.py +++ b/binder/views.py @@ -1,12 +1,11 @@ import logging import time -import io import inspect import os -import hashlib import datetime import mimetypes import functools +import re from collections import defaultdict, namedtuple from contextlib import ExitStack @@ -16,8 +15,10 @@ import django from django.views.generic import View from django.core.exceptions import ObjectDoesNotExist, FieldError, ValidationError, FieldDoesNotExist +from django.core.files.base import File, ContentFile from django.http import HttpResponse, StreamingHttpResponse, HttpResponseForbidden from django.http.request import RawPostDataException, QueryDict +from django.http.multipartparser import MultiPartParser from django.db import models, connections from django.db.models import Q, F from django.db.models.lookups import Transform @@ -34,7 +35,7 @@ ) from . 
import history from .orderable_agg import OrderableArrayAgg, GroupConcat, StringAgg -from .models import FieldFilter, BinderModel, ContextAnnotation, OptionalAnnotation, BinderFileField +from .models import FieldFilter, BinderModel, ContextAnnotation, OptionalAnnotation, BinderFileField, BinderImageField from .json import JsonResponse, jsonloads @@ -71,6 +72,42 @@ def get_default_annotations(model): return annotations +def split_path(path): + """ + This function splits a dot seperated path into a list of keys. + The advantage of this function over a simple path.split('.') is that it + handles backslash escaping so that keys can contain dot characters. + """ + path = iter(path) + chars = [] + for char in path: + if char == '.': + yield ''.join(chars) + chars.clear() + continue + if char == '\\': + char = next(path, '') + chars.append(char) + yield ''.join(chars) + + +def join_path(keys): + """ + This function joins a list of keys into a dot seperated path. + The advantage of this function over a simple '.'.join(keys) is that it + handles backslash escaping so that keys can contain dot characters. + """ + chars = [] + for i, key in enumerate(keys): + if i != 0: + chars.append('.') + for char in key: + if char in '\\.': + char = '\\' + char + chars.append(char) + return ''.join(chars) + + # Haha kill me now def multiput_get_id(bla): return bla['id'] if isinstance(bla, dict) else bla @@ -220,21 +257,6 @@ def prefix_db_expression(value, prefix): raise ValueError('Unknown expression type, cannot apply db prefix: %s', value) -# Prefix a q expression by adding a prefix to all filters. You can also supply -# an 'antiprefix', if the filter starts with this value it will be removed -# instead of the prefix being added. This is useful for reversed fields. 
-def prefix_q_expression(value, prefix, antiprefix=None): - children = [] - for child in value.children: - if isinstance(child, Q): - children.append(prefix_q_expression(child, prefix, antiprefix)) - elif antiprefix is not None and child[0].startswith(antiprefix + '__'): - children.append((child[0][len(antiprefix) + 2:], child[1])) - else: - children.append((prefix + '__' + child[0], child[1])) - return Q(*children, _negated=value.negated) - - class ModelView(View): # Model this is a view for. Use None for views not tied to a particular model. model = None @@ -523,12 +545,11 @@ def _get_objs(self, queryset, request, annotations=None): data = {} for f in fields: - if isinstance(f, models.fields.files.FileField): + if isinstance(f, models.FileField): file = getattr(obj, f.attname) if file: # {router-view-instance} data[f.name] = self.router.model_route(self.model, obj.id, f) - # {duplicate-binder-file-field-hash-code} if isinstance(f, BinderFileField): data[f.name] += '?h={}&content_type={}&filename={}'.format( @@ -901,6 +922,8 @@ def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): rel_ids_by_field_by_id = defaultdict(lambda: defaultdict(list)) virtual_fields = set() + Agg = self.AggStrategy + for field in with_map: vr = self.virtual_relations.get(field, None) @@ -912,6 +935,12 @@ def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): if rel == field or rel.startswith(field + '.') }) + # Model default orders (this sometimes matters) + orders = [] + field_alias = field + '___annotation' if vr else field + for o in (view.model._meta.ordering if view.model._meta.ordering else BinderModel.Meta.ordering): + orders.append(prefix_db_expression(o, field_alias)) + # Virtual relation if vr: virtual_fields.add(field) @@ -927,43 +956,45 @@ def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): # annotations, so allow for fetching of ids, instead. 
# This does mean you can't filter on this relation # unless you write a custom filter, too. - try: - func = getattr(self, virtual_annotation) - except AttributeError: - raise BinderRequestError('Annotation for virtual relation {{{}}}.{{{}}} is {{{}}}, but no method by that name exists.'.format( - self.model.__name__, field, virtual_annotation - )) - rel_ids_by_field_by_id[field] = func(request, pks, q) + if isinstance(virtual_annotation, Q): + annotation = Agg(virtual_annotation, filter=q, ordering=orders) + else: + try: + func = getattr(self, virtual_annotation) + except AttributeError: + raise BinderRequestError('Annotation for virtual relation {{{}}}.{{{}}} is {{{}}}, but no method by that name exists.'.format( + self.model.__name__, field, virtual_annotation + )) + rel_ids_by_field_by_id[field] = func(request, pks, q) + continue # Actual relation else: - f = self.model._meta.get_field(field) - - if f.one_to_one or f.many_to_one: + if (getattr(self.model, field).__class__ == models.fields.related.ReverseOneToOneDescriptor or + not any(f.name == field for f in (list(self.model._meta.many_to_many) + list(self._get_reverse_relations())))): singular_fields.add(field) - if any(f.name == field for f in self._get_reverse_relations()): - rev_field = f.remote_field.name - query = ( - view.model.objects - .filter(prefix_q_expression(q, rev_field, field), **{rev_field + '__in': pks}) - .values_list(rev_field + '__pk', 'pk') - .distinct() - ) - else: - # Model default orders (this sometimes matters) - orders = [] - for o in (view.model._meta.ordering if view.model._meta.ordering else BinderModel.Meta.ordering): - orders.append(prefix_db_expression(o, field)) - query = ( - self.model.objects - .order_by(*orders) - .filter(q, pk__in=pks, **{field + '__isnull': False}) - .values_list('pk', field + '__pk') - .distinct() - ) - - for pk, rel_pk in query: - rel_ids_by_field_by_id[field][pk].append(rel_pk) + if Agg != GroupConcat: # HACKK (GROUP_CONCAT can't filter and excludes 
NULL already) + q &= Q(**{field+'__pk__isnull': False}) + annotation = Agg(field+'__pk', filter=q, ordering=orders) + + qs = self.model.objects.filter(pk__in=pks).values('pk').annotate(**{field_alias: annotation}) + + for record in qs: + value = record[field_alias] + + # Make the values distinct. We can't do this in + # the Agg() call, because then we get an error + # regarding order by and values needing to be the + # same :( + # We also can't just put it in a set, because we + # need to preserve the ordering. So we use a set + # to keep track of what we've seen and only add + # new items. + seen_values = set() + for v in value: + if v not in seen_values: + rel_ids_by_field_by_id[field][record['pk']].append(v) + seen_values.add(v) for field, sub_fields in with_map.items(): next = self._follow_related(field)[0].model @@ -1267,9 +1298,11 @@ def _generate_meta(self, include_meta, queryset, request, pk=None): return meta - def get(self, request, pk=None, withs=None, include_annotations=None): - include_meta = request.GET.get('include_meta', 'total_records').split(',') - + def get_filtered_queryset(self, request, pk=None, include_annotations=None): + """ + Returns a scoped queryset with filtering and sorting applied as + specified by the request. 
+ """ queryset = self.get_queryset(request) if pk: queryset = queryset.filter(pk=int(pk)) @@ -1285,6 +1318,7 @@ def get(self, request, pk=None, withs=None, include_annotations=None): #### annotations if include_annotations is None: include_annotations = self._parse_include_annotations(request) + queryset = annotate(queryset, request, include_annotations.get('')) #### filters @@ -1302,6 +1336,16 @@ def get(self, request, pk=None, withs=None, include_annotations=None): queryset = self.order_by(queryset, request) + return queryset + + + def get(self, request, pk=None, withs=None, include_annotations=None): + include_meta = request.GET.get('include_meta', 'total_records').split(',') + if include_annotations is None: + include_annotations = self._parse_include_annotations(request) + + queryset = self.get_filtered_queryset(request, pk, include_annotations) + meta = self._generate_meta(include_meta, queryset, request, pk) queryset = self._paginate(queryset, request) @@ -1458,7 +1502,10 @@ def store_m2m_field(obj, field, value, request): # Skip re-fetch and serialization via get_objs if we're in # multi-put (data is discarded!). 
- if getattr(request, '_is_multi_put', False): + if ( + getattr(request, '_is_multi_put', False) or # Multi put handles its own return data + getattr(request, '_is_file_upload', False) # Dispatch file field handles its own return data + ): return None # Permission checks are done at this point, so we can avoid get_queryset() @@ -1561,6 +1608,103 @@ def _store_m2m_field(self, obj, field, value, request): raise sum(validation_errors, None) + def _resolve_file_path(self, value, request): + match = re.match(r'/api/(\w+)/(\d+)/(\w+)/', value) + if not match: + return None + + model, pk, field = match.groups() + try: + model = self.router.name_models[model] + except KeyError: + return None + + view = self.get_model_view(model) + if field not in view.file_fields: + return None + + try: + obj = view.get_queryset(request).get(pk=pk) + except model.DoesNotExist: + return None + + return getattr(obj, field) + + + def _clean_image_file(self, field, value): + try: + img = Image.open(value) + except Exception: + raise BinderImageError('Could not parse the file as an image.') + + format = img.format.lower() + if not format in ('png', 'gif', 'jpeg'): + raise BinderFileTypeIncorrect([{'extension': t, 'mimetype': 'image/' + t} for t in ['jpeg', 'png', 'gif']]) + + width, height = img.size + if format == 'jpeg': + img2 = image_transpose_exif(img) + + if img2 != img: + value.seek(0) # Do not append to the existing file! 
+ value.truncate() + img2.save(value, 'jpeg') + img = img2 + + # Determine resize threshold + try: + max_size = self.image_resize_threshold[field] + except TypeError: + max_size = self.image_resize_threshold + + try: + max_width, max_height = max_size + except (TypeError, ValueError): + max_width, max_height = max_size, max_size + + try: + format_override = self.image_format_override.get(field) + except AttributeError: + format_override = self.image_format_override + + changes = False + + # FIXME: hardcoded max + # Flat out refuse images exceeding this size, to prevent DoS. + width_limit, height_limit = max(max_width, 4096), max(max_height, 4096) + if width > width_limit or height > height_limit: + raise BinderImageSizeExceeded(width_limit, height_limit) + + # Resize images that are too large. + if width > max_width or height > max_height: + img.thumbnail((max_width, max_height), Image.ANTIALIAS) + logger.info('image dimensions ({}x{}) exceeded ({}, {}), resizing.'.format(width, height, max_width, max_height)) + if img.mode not in ["1", "L", "P", "RGB", "RGBA"]: + img = img.convert("RGB") + if format != 'jpeg': + format = 'png' + changes = True + + # Saving a JPEG with mode RGBA will crash because JPEG does not support + # an alpha channel, so in this case we convert to RGB + if format_override == 'jpeg' and img.mode == 'RGBA': + img = img.convert('RGB') + + if format_override and format != format_override: + format = format_override + changes = True + + name, ext = os.path.splitext(value.name) + if ext != '.' + format and not (ext == '.jpg' and format == 'jpeg'): + ext = '.' 
+ format + changes = True + + if changes: + filename = name + ext + value = ContentFile(b'', name=filename) + img.save(value, format) + + return value # Override _store_field example for a "FOO" field @@ -1580,7 +1724,6 @@ def _store_field(self, obj, field, value, request, pk=None): 'id', 'pk', 'deleted', '_meta', *self.unwritable_fields, *self.shown_properties, - *self.file_fields, *self.annotations(request), ]: raise BinderReadOnlyFieldError(self.model.__name__, field) @@ -1613,6 +1756,7 @@ def _store_field(self, obj, field, value, request, pk=None): # Hack, set the id directly. This does the actual check, and throws the BinderError in # the same way the old case has. setattr(obj, f.attname, value) + elif isinstance(f, models.IntegerField): if value is None or value == '': value = None @@ -1633,6 +1777,7 @@ def _store_field(self, obj, field, value, request, pk=None): } }) setattr(obj, f.attname, value) + elif isinstance(f, models.TextField): # Django doesn't enforce max_length on TextFields, so we do. 
if f.max_length is not None: @@ -1653,6 +1798,46 @@ def _store_field(self, obj, field, value, request, pk=None): } }) setattr(obj, f.attname, value) + + elif isinstance(f, models.FileField): + # If the value is a str that matches how we return file + # fields we convert it to the correct file + if isinstance(value, str): + value_ = self._resolve_file_path(value, request) + if value_ is not None: + value = value_ + + if not isinstance(value, File) and value is not None: + raise BinderFieldTypeError(self.model.__name__, field) + + if value is not None: + if value.size > self.max_upload_size * 10**6: + raise BinderFileSizeExceeded(self.max_upload_size) + + if isinstance(f, models.ImageField): + allowed_extensions = ['png', 'gif', 'jpg', 'jpeg'] + elif isinstance(f, BinderFileField): + allowed_extensions = f.allowed_extensions + else: + allowed_extensions = None + + if allowed_extensions is not None: + extension = os.path.splitext(value.name)[1][1:].lower() + allowed_extensions = {ext.lower() for ext in allowed_extensions} + if extension not in allowed_extensions: + raise BinderFileTypeIncorrect([{'extension': t} for t in allowed_extensions]) + + if isinstance(f, (models.ImageField, BinderImageField)): + value = self._clean_image_file(field, value) + + old_value = getattr(obj, f.attname) + if old_value.name: + @transaction.on_commit + def delete_old_file(): + old_value.storage.delete(old_value.name) + + setattr(obj, f.attname, value) + else: try: f.to_python(value) @@ -1765,7 +1950,7 @@ def _obj_diff(self, old, new, name): # Put data and with on one big pile, that's easier for us def _multi_put_parse_request(self, request): - body = jsonloads(request.body) + body = self._get_request_values(request) data = body.get('with', {}) if not isinstance(data, dict): @@ -2125,7 +2310,65 @@ def multi_put(self, request): return JsonResponse({'idmap': output}) def _get_request_values(self, request): - return jsonloads(request.body) + # So normally we just parse json here but for 
multipart form data we have some special logic to inject files + # The values then would look a bit like this: + # + # data={"foo": 1, "bar": null, "baz": [1, 2, null]} + # file:bar= + # file:baz.2= + # + # Where the output will be the data but with the null values replaced by the specified files. + # The paths that the files use as key have to be in the data and have value null. + # This is because we do not want to alter the structure of the data because that gets messy with things like array indexes etc. + + if request.content_type == 'multipart/form-data': + # Django only parses multipart automatically on POST, so for PUT/PATCH/DELETE etc we do it manually + if request.method == 'POST': + fields = request.POST + files = request.FILES + else: + parser = MultiPartParser(request.META, request, request.upload_handlers) + fields, files = parser.parse() + try: + data = fields['data'] + except KeyError: + raise BinderRequestError('data field is required in multipart body') + else: + data = request.body + files = {} + + data = jsonloads(data) + + for path, value in files.items(): + if not path.startswith('file:'): + continue + target = data + keys = list(split_path(path[5:])) + for i, key in enumerate(keys): + if isinstance(target, list): + try: + key = int(key) + except ValueError: + raise BinderRequestError( + 'expected integer key at path: ' + join_path(keys[:i + 1]) + ) + if not ( + (isinstance(target, dict) and key in target) or + (isinstance(target, list) and 0 <= key < len(target)) + ): + raise BinderRequestError( + 'unexpected key at path: ' + join_path(keys[:i + 1]) + ) + if i != len(keys) - 1: + target = target[key] + elif target[key] is not None: + raise BinderRequestError( + 'expected null at path: ' + join_path(keys) + ) + else: + target[key] = value + + return data def put(self, request, pk=None): if pk is None: @@ -2315,146 +2558,40 @@ def dispatch_file_field(self, request, pk=None, file_field=None): return resp if request.method == 'POST': - 
self._require_model_perm('change', request) - try: # Take an arbitrary uploaded file file = next(request.FILES.values()) except StopIteration: raise BinderRequestError('File POST should use multipart/form-data (with an arbitrary key for the file data).') - try: - if file.size > self.max_upload_size * 10**6: - raise BinderFileSizeExceeded(self.max_upload_size) - - field = self.model._meta.get_field(file_field_name) - - if getattr(field, 'allowed_extensions', None) is not None: - extension = None if '.' not in file.name else file.name.split('.')[-1] - - if extension not in field.allowed_extensions: - raise BinderFileTypeIncorrect([{'extension': t} for t in field.allowed_extensions]) - - if isinstance(field, models.fields.files.ImageField): - try: - img = Image.open(file) - except Exception: - raise BinderImageError('Could not parse the file as an image.') - - format = img.format.lower() - if not format in ('png', 'gif', 'jpeg'): - raise BinderFileTypeIncorrect([{'extension': t, 'mimetype': 'image/' + t} for t in ['jpeg', 'png', 'gif']]) - - width, height = img.size - if format == 'jpeg': - img2 = image_transpose_exif(img) - - if img2 != img: - file.seek(0) # Do not append to the existing file! - file.truncate() - img2.save(file, 'jpeg') - img = img2 - - # Determine resize threshold - try: - max_size = self.image_resize_threshold[file_field_name] - except TypeError: - max_size = self.image_resize_threshold - - try: - max_width, max_height = max_size - except (TypeError, ValueError): - max_width, max_height = max_size, max_size - - try: - format_override = self.image_format_override.get(file_field_name) - except AttributeError: - format_override = self.image_format_override - - changes = False - - # FIXME: hardcoded max - # Flat out refuse images exceeding this size, to prevent DoS. 
- width_limit, height_limit = max(max_width, 4096), max(max_height, 4096) - if width > width_limit or height > height_limit: - raise BinderImageSizeExceeded(width_limit, height_limit) - - # Resize images that are too large. - if width > max_width or height > max_height: - img.thumbnail((max_width, max_height), Image.ANTIALIAS) - logger.info('image dimensions ({}x{}) exceeded ({}, {}), resizing.'.format(width, height, max_width, max_height)) - if img.mode not in ["1", "L", "P", "RGB", "RGBA"]: - img = img.convert("RGB") - if format != 'jpeg': - format = 'png' - changes = True - - if format_override and format != format_override: - format = format_override - changes = True - - filename = '{}.{}'.format(os.path.basename(file.name), format) - - if changes: - file = io.BytesIO() - img.save(file, format) - else: - filename = file.name - - # FIXME: duplicate code - if file_field: - try: - old_hash = hashlib.sha256() - for c in file_field.file.chunks(): - old_hash.update(c) - old_hash = old_hash.hexdigest() - except FileNotFoundError: - logger.warning('Old file {} missing!'.format(file_field)) - old_hash = None - else: - old_hash = None - - file_field.delete(save=False) - # This triggers a save on obj - file_field.save(filename, django.core.files.File(file)) - - # FIXME: duplicate code - new_hash = hashlib.sha256() - for c in file_field.file.chunks(): - new_hash.update(c) - new_hash = new_hash.hexdigest() - - logger.info('POST updated {}[{}].{}: {} -> {}'.format(self._model_name(), pk, file_field_name, old_hash, new_hash)) - path = self.router.model_route(self.model, obj.id, field) - - # {duplicate-binder-file-field-hash-code} - if isinstance(field, BinderFileField): - path += '?h={}&content_type={}&filename={}'.format( - file_field.content_hash, - file_field.content_type or '', - os.path.basename(file_field.name), - ) - - return JsonResponse( {"data": {file_field_name: path}} ) + # Hack to communicate to _store() that we're not interested in + # the new data (for perf 
reasons). + request._is_file_upload = True + self._store(obj, {file_field_name: file}, request, pk=pk) + + field = self.model._meta.get_field(file_field_name) + path = self.router.model_route(self.model, obj.id, field) + # {duplicate-binder-file-field-hash-code} + if isinstance(field, BinderFileField): + file_field = getattr(obj, file_field_name) + path += '?h={}&content_type={}&filename={}'.format( + file_field.content_hash, + file_field.content_type or '', + os.path.basename(file_field.name), + ) - except ValidationError as ve: - raise self.binder_validation_error(obj, ve, pk=pk) + return JsonResponse({'data': {file_field_name: path}}) if request.method == 'DELETE': - self._require_model_perm('change', request) if not file_field: raise BinderIsDeleted() - # FIXME: duplicate code - old_hash = hashlib.sha256() - for c in file_field.file.chunks(): - old_hash.update(c) - old_hash = old_hash.hexdigest() - - file_field.delete() + # Hack to communicate to _store() that we're not interested in + # the new data (for perf reasons). 
+ request._is_file_upload = True + self._store(obj, {file_field_name: None}, request, pk=pk) - logger.info('DELETEd {}[{}].{}: {}'.format(self._model_name(), pk, file_field_name, old_hash)) - return JsonResponse( {"data": {file_field_name: None}} ) + return JsonResponse({'data': {file_field_name: None}}) From 6dd7677027066cbad230054a7175cef53d089600 Mon Sep 17 00:00:00 2001 From: Daan van der Kallen Date: Wed, 19 Jan 2022 00:00:23 +0100 Subject: [PATCH 21/33] Do not use aggregates at all --- binder/views.py | 63 +++++++++++++++++-------------------------------- 1 file changed, 22 insertions(+), 41 deletions(-) diff --git a/binder/views.py b/binder/views.py index 4e35a538..63850784 100644 --- a/binder/views.py +++ b/binder/views.py @@ -922,8 +922,6 @@ def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): rel_ids_by_field_by_id = defaultdict(lambda: defaultdict(list)) virtual_fields = set() - Agg = self.AggStrategy - for field in with_map: vr = self.virtual_relations.get(field, None) @@ -935,12 +933,6 @@ def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): if rel == field or rel.startswith(field + '.') }) - # Model default orders (this sometimes matters) - orders = [] - field_alias = field + '___annotation' if vr else field - for o in (view.model._meta.ordering if view.model._meta.ordering else BinderModel.Meta.ordering): - orders.append(prefix_db_expression(o, field_alias)) - # Virtual relation if vr: virtual_fields.add(field) @@ -956,45 +948,34 @@ def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): # annotations, so allow for fetching of ids, instead. # This does mean you can't filter on this relation # unless you write a custom filter, too. 
- if isinstance(virtual_annotation, Q): - annotation = Agg(virtual_annotation, filter=q, ordering=orders) - else: - try: - func = getattr(self, virtual_annotation) - except AttributeError: - raise BinderRequestError('Annotation for virtual relation {{{}}}.{{{}}} is {{{}}}, but no method by that name exists.'.format( - self.model.__name__, field, virtual_annotation - )) - rel_ids_by_field_by_id[field] = func(request, pks, q) - continue + try: + func = getattr(self, virtual_annotation) + except AttributeError: + raise BinderRequestError('Annotation for virtual relation {{{}}}.{{{}}} is {{{}}}, but no method by that name exists.'.format( + self.model.__name__, field, virtual_annotation + )) + rel_ids_by_field_by_id[field] = func(request, pks, q) # Actual relation else: if (getattr(self.model, field).__class__ == models.fields.related.ReverseOneToOneDescriptor or not any(f.name == field for f in (list(self.model._meta.many_to_many) + list(self._get_reverse_relations())))): singular_fields.add(field) - if Agg != GroupConcat: # HACKK (GROUP_CONCAT can't filter and excludes NULL already) - q &= Q(**{field+'__pk__isnull': False}) - annotation = Agg(field+'__pk', filter=q, ordering=orders) - - qs = self.model.objects.filter(pk__in=pks).values('pk').annotate(**{field_alias: annotation}) - - for record in qs: - value = record[field_alias] - - # Make the values distinct. We can't do this in - # the Agg() call, because then we get an error - # regarding order by and values needing to be the - # same :( - # We also can't just put it in a set, because we - # need to preserve the ordering. So we use a set - # to keep track of what we've seen and only add - # new items. 
- seen_values = set() - for v in value: - if v not in seen_values: - rel_ids_by_field_by_id[field][record['pk']].append(v) - seen_values.add(v) + # Model default orders (this sometimes matters) + orders = [] + for o in (view.model._meta.ordering if view.model._meta.ordering else BinderModel.Meta.ordering): + orders.append(prefix_db_expression(o, field)) + + query = ( + self.model.objects + .order_by(*orders) + .filter(q, pk__in=pks, **{field + '__isnull': False}) + .values_list('pk', field + '__pk') + .distinct() + ) + + for pk, rel_pk in query: + rel_ids_by_field_by_id[field][pk].append(rel_pk) for field, sub_fields in with_map.items(): next = self._follow_related(field)[0].model From d6b11e618b0d5784288cbdfa532fbe81fc018835 Mon Sep 17 00:00:00 2001 From: Daan van der Kallen Date: Wed, 19 Jan 2022 01:37:51 +0100 Subject: [PATCH 22/33] Do not use join for reverse relations --- binder/views.py | 52 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/binder/views.py b/binder/views.py index 63850784..e85f863a 100644 --- a/binder/views.py +++ b/binder/views.py @@ -257,6 +257,21 @@ def prefix_db_expression(value, prefix): raise ValueError('Unknown expression type, cannot apply db prefix: %s', value) +# Prefix a q expression by adding a prefix to all filters. You can also supply +# an 'antiprefix', if the filter starts with this value it will be removed +# instead of the prefix being added. This is useful for reversed fields. 
+def prefix_q_expression(value, prefix, antiprefix=None): + children = [] + for child in value.children: + if isinstance(child, Q): + children.append(prefix_q_expression(child, prefix, antiprefix)) + elif antiprefix is not None and child[0].startswith(antiprefix + '__'): + children.append((child[0][len(antiprefix) + 2:], child[1])) + else: + children.append((prefix + '__' + child[0], child[1])) + return Q(*children, _negated=value.negated) + + class ModelView(View): # Model this is a view for. Use None for views not tied to a particular model. model = None @@ -957,22 +972,31 @@ def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): rel_ids_by_field_by_id[field] = func(request, pks, q) # Actual relation else: - if (getattr(self.model, field).__class__ == models.fields.related.ReverseOneToOneDescriptor or - not any(f.name == field for f in (list(self.model._meta.many_to_many) + list(self._get_reverse_relations())))): + f = self.model._meta.get_field(field) + + if f.one_to_one or f.many_to_one: singular_fields.add(field) - # Model default orders (this sometimes matters) - orders = [] - for o in (view.model._meta.ordering if view.model._meta.ordering else BinderModel.Meta.ordering): - orders.append(prefix_db_expression(o, field)) - - query = ( - self.model.objects - .order_by(*orders) - .filter(q, pk__in=pks, **{field + '__isnull': False}) - .values_list('pk', field + '__pk') - .distinct() - ) + if any(f.name == field for f in self._get_reverse_relations()): + rev_field = f.remote_field.name + query = ( + view.model.objects + .filter(prefix_q_expression(q, rev_field, field), **{rev_field + '__in': pks}) + .values_list(rev_field + '__pk', 'pk') + .distinct() + ) + else: + # Model default orders (this sometimes matters) + orders = [] + for o in (view.model._meta.ordering if view.model._meta.ordering else BinderModel.Meta.ordering): + orders.append(prefix_db_expression(o, field)) + query = ( + self.model.objects + .order_by(*orders) + .filter(q, 
pk__in=pks, **{field + '__isnull': False}) + .values_list('pk', field + '__pk') + .distinct() + ) for pk, rel_pk in query: rel_ids_by_field_by_id[field][pk].append(rel_pk) From 33ea43f0f5f64c4f25189f51e48df782614b404f Mon Sep 17 00:00:00 2001 From: Stefan Majoor Date: Fri, 18 Feb 2022 11:01:44 +0100 Subject: [PATCH 23/33] Add webpage model --- binder/plugins/models/__init__.py | 2 +- binder/plugins/models/html_field.py | 152 +--------------------------- tests/test_html_field.py | 124 +---------------------- tests/testapp/models/__init__.py | 10 +- 4 files changed, 9 insertions(+), 279 deletions(-) diff --git a/binder/plugins/models/__init__.py b/binder/plugins/models/__init__.py index 7ac605e7..5ad29194 100644 --- a/binder/plugins/models/__init__.py +++ b/binder/plugins/models/__init__.py @@ -1 +1 @@ -from .html_field import HtmlField # noqa: F401 +from .html_field import HtmlField diff --git a/binder/plugins/models/html_field.py b/binder/plugins/models/html_field.py index d6cd5c65..2ede5568 100644 --- a/binder/plugins/models/html_field.py +++ b/binder/plugins/models/html_field.py @@ -1,157 +1,11 @@ -from typing import List - from django.db.models import TextField -from html.parser import HTMLParser -from django.core.exceptions import ValidationError -from django.utils.translation import gettext as _ - -ALLOWED_LINK_PREFIXES = [ - 'http://', - 'https://', - 'mailto:' -] - - -def link_rel_validator(tag, attribute_name, attribute_value) -> List[ValidationError]: - validation_errors = [] - - rels = attribute_value.split(' ') - - if 'noopener' not in rels: - - validation_errors.append(ValidationError( - _('Link needs rel="noopener"'), - code='invalid_attribute', - params={ - 'tag': tag, - }, - )) - - if 'noreferrer' not in rels: - validation_errors.append(ValidationError( - _('Link needs rel="noreferer"'), - code='invalid_attribute', - params={ - 'tag': tag, - }, - )) - - - return validation_errors - - -def link_validator(tag, attribute_name, attribute_value) -> 
List[ValidationError]: - validation_errors = [] - if not any(map(lambda prefix: attribute_value.startswith(prefix), ALLOWED_LINK_PREFIXES)): - validation_errors.append(ValidationError( - _('Link is not valid'), - code='invalid_attribute', - params={ - 'tag': tag, - }, - )) - return validation_errors - - -class HtmlValidator(HTMLParser): - allowed_tags = [ - # General setup - 'p', 'br', - # Headers - 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'h7', - - # text decoration - 'b', 'strong', 'i', 'em', 'u', - # Lists - 'ol', 'ul', 'li', - - # Special - 'a', - ] - - allowed_attributes = { - 'a': ['href', 'rel', 'target'] - } - - required_attributes = { - 'a': ['rel'], - } - - special_validators = { - ('a', 'href'): link_validator, - ('a', 'rel'): link_rel_validator, - } - - error_messages = { - 'invalid_tag': _('Tag %(tag)s is not allowed'), - 'missing_attribute': _('Attribute %(attribute)s is required for tag %(tag)s'), - 'invalid_attribute': _('Attribute %(attribute)s not allowed for tag %(tag)s'), - } - - def validate(self, value: str) -> List[ValidationError]: - """ - Validates html, and gives a list of validation errors - """ - - self.errors = [] - - self.feed(value) - - return self.errors - - def handle_starttag(self, tag: str, attrs: list) -> None: - tag_errors = [] - if tag not in self.allowed_tags: - tag_errors.append(ValidationError( - self.error_messages['invalid_tag'], - code='invalid_tag', - params={ - 'tag': tag - }, - )) - - set_attributes = set(map(lambda attr: attr[0], attrs)) - required_attributes = set(self.required_attributes.get(tag, [])) - missing_attributes = required_attributes - set_attributes - for missing_attribute in missing_attributes: - tag_errors.append( - ValidationError( - self.error_messages['missing_attribute'], - code='missing_attribute', - params={ - 'tag': tag, - 'attribute': missing_attribute - }, - ) - ) - - allowed_attributes_for_tag = self.allowed_attributes.get(tag, []) - - for (attribute_name, attribute_content) in attrs: - if 
attribute_name not in allowed_attributes_for_tag: - tag_errors.append(ValidationError( - self.error_messages['invalid_attribute'], - code='invalid_attribute', - params={ - 'tag': tag, - 'attribute': attribute_name - }, - )) - if (tag, attribute_name) in self.special_validators: - tag_errors += self.special_validators[(tag, attribute_name)](tag, attribute_name, attribute_content) - - self.errors += tag_errors class HtmlField(TextField): """ - Determine a safe way to save "secure" user provided HTML input, and prevent XSS injections + Determine a safe way to save "secure" user provided HTML input, and prevent """ - def validate(self, value: str, _): - # Validate all html tags - validator = HtmlValidator() - errors = validator.validate(value) - if errors: - raise ValidationError(errors) + def validate(self, value, _): + pass diff --git a/tests/test_html_field.py b/tests/test_html_field.py index fa10ab7d..8caa40be 100644 --- a/tests/test_html_field.py +++ b/tests/test_html_field.py @@ -1,11 +1,10 @@ from django.contrib.auth.models import User from django.test import TestCase, Client -import json -from .testapp.models import Zoo, WebPage +from project.testapp.models import Zoo, WebPage -class HtmlFieldTestCase(TestCase): +class AnnotationTestCase(TestCase): def setUp(self): super().setUp() @@ -19,124 +18,9 @@ def setUp(self): self.zoo = Zoo(name='Apenheul') self.zoo.save() - self.webpage = WebPage.objects.create(zoo=self.zoo, content='') - + self.webpage = WebPage(zoo=self.zoo, content='') def test_save_normal_text_ok(self): - response = self.client.put(f'/web_page/{self.webpage.id}/', data=json.dumps({'content': 'Artis'})) - self.assertEqual(response.status_code, 200) - - def test_simple_html_is_ok(self): - response = self.client.put(f'/web_page/{self.webpage.id}/', - data=json.dumps({'content': '

Artis

Artis is a zoo in amsterdam'})) - self.assertEqual(response.status_code, 200) - - def test_wrong_attribute_not_ok(self): - response = self.client.put(f'/web_page/{self.webpage.id}/', - data=json.dumps({'content': 'test'})) - self.assertEqual(response.status_code, 400) - - parsed_response = json.loads(response.content) - self.assertEqual('ValidationError', parsed_response['code']) - self.assertEqual('invalid_attribute', parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][0]['code']) - - def test_simple_link_is_ok(self): - response = self.client.put(f'/web_page/{self.webpage.id}/', data=json.dumps( - {'content': 'Visit artis website'})) - - self.assertEqual(response.status_code, 200) - - - - def test_javascript_link_is_not_ok(self): - response = self.client.put(f'/web_page/{self.webpage.id}/', - data=json.dumps({ - 'content': 'Visit artis website'})) - self.assertEqual(response.status_code, 400) - - parsed_response = json.loads(response.content) - self.assertEqual('ValidationError', parsed_response['code']) - - self.assertEqual('invalid_attribute', parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][0]['code']) - - - - def test_script_is_not_ok(self): - response = self.client.put(f'/web_page/{self.webpage.id}/', - data=json.dumps({'content': ''})) - - self.assertEqual(response.status_code, 400) - - parsed_response = json.loads(response.content) - self.assertEqual('ValidationError', parsed_response['code']) - self.assertEqual('invalid_tag', parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][0]['code']) - - def test_script_is_not_ok_nested(self): - response = self.client.put(f'/web_page/{self.webpage.id}/', - data=json.dumps({'content': ''})) - self.assertEqual(response.status_code, 400) - - parsed_response = json.loads(response.content) - self.assertEqual('ValidationError', parsed_response['code']) - self.assertEqual('invalid_tag', 
parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][0]['code']) - - - def test_can_handle_reallife_data(self): - """ - This is the worst case that we could produce on the WYIWYG edittor - """ - content = '

normal text


HEADing 1


HEADING 2


HEADING 3


bold


italic


underlined


Link


  1. ol1
  2. ol2
  • ul1
  • ul2


subscripttgege

g

"' - response = self.client.put(f'/web_page/{self.webpage.id}/', - data=json.dumps({'content': content})) - - self.assertEqual(response.status_code, 200) - - def test_multiple_errors(self): - response = self.client.put(f'/web_page/{self.webpage.id}/', - data=json.dumps({ - 'content': 'Visit artis website'})) - self.assertEqual(response.status_code, 400) - - parsed_response = json.loads(response.content) - self.assertEqual('ValidationError', parsed_response['code']) - - - self.assertEqual('invalid_tag', - parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][0]['code']) - self.assertEqual('invalid_tag', - parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][1]['code']) - - - def test_link_no_rel_errors(self): - response = self.client.put(f'/web_page/{self.webpage.id}/', - data=json.dumps({'content': 'bla'})) - self.assertEqual(response.status_code, 400) - - parsed_response = json.loads(response.content) - - self.assertEqual('ValidationError', parsed_response['code']) - self.assertEqual('missing_attribute', - parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][0]['code']) - - def test_link_noopener_required(self): - response = self.client.put(f'/web_page/{self.webpage.id}/', - data=json.dumps({'content': 'bla'})) - self.assertEqual(response.status_code, 400) - - parsed_response = json.loads(response.content) - - self.assertEqual('ValidationError', parsed_response['code']) - self.assertEqual('invalid_attribute', - parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][0]['code']) - - def test_link_noreferrer_required(self): - response = self.client.put(f'/web_page/{self.webpage.id}/', - data=json.dumps({'content': 'bla'})) - self.assertEqual(response.status_code, 400) - - parsed_response = json.loads(response.content) + self.webpage = WebPage(zoo=self.zoo, content='Artis') - self.assertEqual('ValidationError', parsed_response['code']) - self.assertEqual('invalid_attribute', - 
parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][0]['code']) diff --git a/tests/testapp/models/__init__.py b/tests/testapp/models/__init__.py index 32753719..48e0e08f 100644 --- a/tests/testapp/models/__init__.py +++ b/tests/testapp/models/__init__.py @@ -3,26 +3,18 @@ # We have it all, from A to Z! from .animal import Animal -from .donor import Donor from .caretaker import Caretaker from .contact_person import ContactPerson from .costume import Costume from .gate import Gate from .nickname import Nickname, NullableNickname from .lion import Lion -from .picture import Picture, PictureBook +from .picture import Picture from .zoo import Zoo from .zoo_employee import ZooEmployee from .city import City, CityState, PermanentCity from .country import Country from .web_page import WebPage -from .pet import Pet -from .reverse_config_models import ( - ReverseParent, - ReverseChild, - ReverseParentNoChildHistory, - ReverseChildNoHistory, -) # This is Postgres-specific if os.environ.get('BINDER_TEST_MYSQL', '0') != '1': From e8cbec29a6655a2a8f22fe3be3d67b559310a43a Mon Sep 17 00:00:00 2001 From: Stefan Majoor Date: Fri, 18 Feb 2022 12:14:54 +0100 Subject: [PATCH 24/33] Implement html field --- binder/plugins/models/html_field.py | 75 ++++++++++++++++++++++++++++- tests/test_html_field.py | 48 ++++++++++++++++-- tests/testapp/views/__init__.py | 4 +- 3 files changed, 117 insertions(+), 10 deletions(-) diff --git a/binder/plugins/models/html_field.py b/binder/plugins/models/html_field.py index 2ede5568..4c519e61 100644 --- a/binder/plugins/models/html_field.py +++ b/binder/plugins/models/html_field.py @@ -1,4 +1,74 @@ from django.db.models import TextField +from html.parser import HTMLParser +from django.core import exceptions + + +def link_validator(tag, attribute_name, attribute_value): + if not attribute_value.startswith('http://') and not attribute_value.startswith('https://'): + raise exceptions.ValidationError( + 'Link is not valid', + 
code='invalid_tag', + params={ + 'tag': tag, + }, + ) + +class HtmlValidator(HTMLParser): + allowed_tags = [ + # General setup + 'p', 'br', + # Headers + 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'h7', + + # text decoration + 'b', 'strong', 'i', 'em', 'u', + # Lists + 'ol', 'ul', 'li', + + # Special + 'a', + ] + + allowed_attributes = { + 'a': ['href', 'rel', 'target'] + } + + special_validators = { + ('a', 'href'): link_validator + } + + error_messages = { + 'invalid_tag': 'Tag %(tag)s is not allowed', + 'invalid_attribute': 'Attribute %(attribute)s not allowed for tag %(tag)s' + } + + + + def handle_starttag(self, tag: str, attrs: list) -> None: + if tag not in self.allowed_tags: + raise exceptions.ValidationError( + self.error_messages['invalid_tag'], + code='invalid_tag', + params={ + 'tag': tag + }, + ) + + allowed_attributes_for_tag = self.allowed_attributes.get(tag,[]) + + for (attribute_name, attribute_content) in attrs: + if attribute_name not in allowed_attributes_for_tag: + raise exceptions.ValidationError( + self.error_messages['invalid_attribute'], + code='invalid_tag', + params={ + 'tag': tag, + 'attribute': attribute_name + }, + ) + if (tag, attribute_name) in self.special_validators: + self.special_validators[(tag, attribute_name)](tag, attribute_name, attribute_content) + class HtmlField(TextField): @@ -6,6 +76,7 @@ class HtmlField(TextField): Determine a safe way to save "secure" user provided HTML input, and prevent """ - def validate(self, value, _): - pass + # Validate all html tags + validator = HtmlValidator() + validator.feed(value) diff --git a/tests/test_html_field.py b/tests/test_html_field.py index 8caa40be..d0b2d283 100644 --- a/tests/test_html_field.py +++ b/tests/test_html_field.py @@ -1,10 +1,11 @@ from django.contrib.auth.models import User from django.test import TestCase, Client -from project.testapp.models import Zoo, WebPage +import json +from .testapp.models import Zoo, WebPage -class AnnotationTestCase(TestCase): +class 
HtmlFieldTestCase(TestCase): def setUp(self): super().setUp() @@ -18,9 +19,46 @@ def setUp(self): self.zoo = Zoo(name='Apenheul') self.zoo.save() - self.webpage = WebPage(zoo=self.zoo, content='') - + self.webpage = WebPage.objects.create(zoo=self.zoo, content='') def test_save_normal_text_ok(self): - self.webpage = WebPage(zoo=self.zoo, content='Artis') + response = self.client.put(f'/web_page/{self.webpage.id}/', data=json.dumps({'content': 'Artis'})) + self.assertEqual(response.status_code, 200) + + def test_simple_html_is_ok(self): + response = self.client.put(f'/web_page/{self.webpage.id}/', + data=json.dumps({'content': '

Artis

Artis is a zoo in amsterdam'})) + self.assertEqual(response.status_code, 200) + + def test_wrong_attribute_not_ok(self): + response = self.client.put(f'/web_page/{self.webpage.id}/', + data=json.dumps({'content': 'test'})) + self.assertEqual(response.status_code, 400) + + def test_simple_link_is_ok(self): + response = self.client.put(f'/web_page/{self.webpage.id}/', data=json.dumps( + {'content': 'Visit artis website'})) + self.assertEqual(response.status_code, 200) + + def test_javascript_link_is_not_ok(self): + response = self.client.put(f'/web_page/{self.webpage.id}/', + data=json.dumps({ + 'content': 'Visit artis website'})) + self.assertEqual(response.status_code, 400) + + + + def test_script_is_not_ok(self): + response = self.client.put(f'/web_page/{self.webpage.id}/', + data=json.dumps({'content': ''})) + self.assertEqual(response.status_code, 400) + + def test_can_handle_reallife_data(self): + """ + This is the worst case that we could produce on the WYIWYG edittor + """ + content = '

normal text


HEADing 1


HEADING 2


HEADING 3


bold


italic


underlined


Link


  1. ol1
  2. ol2
  • ul1
  • ul2


subscripttgege

g

"' + response = self.client.put(f'/web_page/{self.webpage.id}/', + data=json.dumps({'content': content})) + self.assertEqual(response.status_code, 200) diff --git a/tests/testapp/views/__init__.py b/tests/testapp/views/__init__.py index c897a22a..40fcd3c9 100644 --- a/tests/testapp/views/__init__.py +++ b/tests/testapp/views/__init__.py @@ -15,10 +15,8 @@ from .gate import GateView from .lion import LionView from .nickname import NicknameView -from .picture import PictureView, PictureBookView +from .picture import PictureView from .user import UserView from .zoo import ZooView from .zoo_employee import ZooEmployeeView from .web_page import WebPageView -from .donor import DonorView -from .pet import PetView From ebae10fb79bf29bc3dbbd415ae135de7b6272d6b Mon Sep 17 00:00:00 2001 From: Stefan Majoor Date: Fri, 18 Feb 2022 12:18:55 +0100 Subject: [PATCH 25/33] Fix flake issues --- binder/plugins/models/__init__.py | 2 +- binder/plugins/models/html_field.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/binder/plugins/models/__init__.py b/binder/plugins/models/__init__.py index 5ad29194..7ac605e7 100644 --- a/binder/plugins/models/__init__.py +++ b/binder/plugins/models/__init__.py @@ -1 +1 @@ -from .html_field import HtmlField +from .html_field import HtmlField # noqa: F401 diff --git a/binder/plugins/models/html_field.py b/binder/plugins/models/html_field.py index 4c519e61..0409ac26 100644 --- a/binder/plugins/models/html_field.py +++ b/binder/plugins/models/html_field.py @@ -13,6 +13,7 @@ def link_validator(tag, attribute_name, attribute_value): }, ) + class HtmlValidator(HTMLParser): allowed_tags = [ # General setup @@ -42,8 +43,6 @@ class HtmlValidator(HTMLParser): 'invalid_attribute': 'Attribute %(attribute)s not allowed for tag %(tag)s' } - - def handle_starttag(self, tag: str, attrs: list) -> None: if tag not in self.allowed_tags: raise exceptions.ValidationError( @@ -54,7 +53,7 @@ def handle_starttag(self, tag: str, attrs: list) -> 
None: }, ) - allowed_attributes_for_tag = self.allowed_attributes.get(tag,[]) + allowed_attributes_for_tag = self.allowed_attributes.get(tag, []) for (attribute_name, attribute_content) in attrs: if attribute_name not in allowed_attributes_for_tag: @@ -70,7 +69,6 @@ def handle_starttag(self, tag: str, attrs: list) -> None: self.special_validators[(tag, attribute_name)](tag, attribute_name, attribute_content) - class HtmlField(TextField): """ Determine a safe way to save "secure" user provided HTML input, and prevent From 6c6811e677604ce7eb120fb0aba77af6504740aa Mon Sep 17 00:00:00 2001 From: Stefan Majoor Date: Wed, 2 Mar 2022 10:46:23 +0100 Subject: [PATCH 26/33] Add gettext for html field --- binder/plugins/models/html_field.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/binder/plugins/models/html_field.py b/binder/plugins/models/html_field.py index 0409ac26..dcac1cde 100644 --- a/binder/plugins/models/html_field.py +++ b/binder/plugins/models/html_field.py @@ -1,12 +1,13 @@ from django.db.models import TextField from html.parser import HTMLParser from django.core import exceptions +from django.utils.translation import gettext as _ def link_validator(tag, attribute_name, attribute_value): if not attribute_value.startswith('http://') and not attribute_value.startswith('https://'): raise exceptions.ValidationError( - 'Link is not valid', + _('Link is not valid'), code='invalid_tag', params={ 'tag': tag, @@ -39,8 +40,8 @@ class HtmlValidator(HTMLParser): } error_messages = { - 'invalid_tag': 'Tag %(tag)s is not allowed', - 'invalid_attribute': 'Attribute %(attribute)s not allowed for tag %(tag)s' + 'invalid_tag': _('Tag %(tag)s is not allowed'), + 'invalid_attribute': _('Attribute %(attribute)s not allowed for tag %(tag)s'), } def handle_starttag(self, tag: str, attrs: list) -> None: From d1ff37fbf1058ff7d2aba5150bfd4906fdec37a9 Mon Sep 17 00:00:00 2001 From: Stefan Majoor Date: Wed, 2 Mar 2022 10:47:55 +0100 Subject: [PATCH 27/33] 
finish sentence --- binder/plugins/models/html_field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/binder/plugins/models/html_field.py b/binder/plugins/models/html_field.py index dcac1cde..26f1c287 100644 --- a/binder/plugins/models/html_field.py +++ b/binder/plugins/models/html_field.py @@ -72,7 +72,7 @@ def handle_starttag(self, tag: str, attrs: list) -> None: class HtmlField(TextField): """ - Determine a safe way to save "secure" user provided HTML input, and prevent + Determine a safe way to save "secure" user provided HTML input, and prevent XSS injections """ def validate(self, value, _): From ba50109323daca6fc8752d28d6496158642f67e7 Mon Sep 17 00:00:00 2001 From: Stefan Majoor Date: Wed, 2 Mar 2022 10:57:13 +0100 Subject: [PATCH 28/33] Add test for nested attributes & also test the content of the error message --- binder/plugins/models/html_field.py | 4 ++-- tests/test_html_field.py | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/binder/plugins/models/html_field.py b/binder/plugins/models/html_field.py index 26f1c287..c4365005 100644 --- a/binder/plugins/models/html_field.py +++ b/binder/plugins/models/html_field.py @@ -8,7 +8,7 @@ def link_validator(tag, attribute_name, attribute_value): if not attribute_value.startswith('http://') and not attribute_value.startswith('https://'): raise exceptions.ValidationError( _('Link is not valid'), - code='invalid_tag', + code='invalid_attribute', params={ 'tag': tag, }, @@ -60,7 +60,7 @@ def handle_starttag(self, tag: str, attrs: list) -> None: if attribute_name not in allowed_attributes_for_tag: raise exceptions.ValidationError( self.error_messages['invalid_attribute'], - code='invalid_tag', + code='invalid_attribute', params={ 'tag': tag, 'attribute': attribute_name diff --git a/tests/test_html_field.py b/tests/test_html_field.py index d0b2d283..49da290a 100644 --- a/tests/test_html_field.py +++ b/tests/test_html_field.py @@ -21,6 +21,8 @@ def 
setUp(self): self.webpage = WebPage.objects.create(zoo=self.zoo, content='') + + def test_save_normal_text_ok(self): response = self.client.put(f'/web_page/{self.webpage.id}/', data=json.dumps({'content': 'Artis'})) self.assertEqual(response.status_code, 200) @@ -35,6 +37,10 @@ def test_wrong_attribute_not_ok(self): data=json.dumps({'content': 'test'})) self.assertEqual(response.status_code, 400) + parsed_response = json.loads(response.content) + self.assertEqual('ValidationError', parsed_response['code']) + self.assertEqual('invalid_attribute', parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][0]['code']) + def test_simple_link_is_ok(self): response = self.client.put(f'/web_page/{self.webpage.id}/', data=json.dumps( {'content': 'Visit artis website'})) @@ -46,13 +52,32 @@ def test_javascript_link_is_not_ok(self): 'content': 'Visit artis website'})) self.assertEqual(response.status_code, 400) + parsed_response = json.loads(response.content) + self.assertEqual('ValidationError', parsed_response['code']) + self.assertEqual('invalid_attribute', parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][0]['code']) + def test_script_is_not_ok(self): response = self.client.put(f'/web_page/{self.webpage.id}/', data=json.dumps({'content': ''})) + self.assertEqual(response.status_code, 400) + parsed_response = json.loads(response.content) + self.assertEqual('ValidationError', parsed_response['code']) + self.assertEqual('invalid_tag', parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][0]['code']) + + def test_script_is_not_ok_nested(self): + response = self.client.put(f'/web_page/{self.webpage.id}/', + data=json.dumps({'content': ''})) + self.assertEqual(response.status_code, 400) + + parsed_response = json.loads(response.content) + self.assertEqual('ValidationError', parsed_response['code']) + self.assertEqual('invalid_tag', parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][0]['code']) + + def 
test_can_handle_reallife_data(self): """ This is the worst case that we could produce on the WYIWYG edittor From a3f23dd8ea80900ef8c439d00c7f3f475559a687 Mon Sep 17 00:00:00 2001 From: Stefan Majoor Date: Thu, 3 Mar 2022 11:48:31 +0100 Subject: [PATCH 29/33] Merge multiple errors in HTMLField --- binder/plugins/models/html_field.py | 55 +++++++++++++++++++++-------- tests/test_html_field.py | 16 +++++++++ 2 files changed, 57 insertions(+), 14 deletions(-) diff --git a/binder/plugins/models/html_field.py b/binder/plugins/models/html_field.py index c4365005..18c642a9 100644 --- a/binder/plugins/models/html_field.py +++ b/binder/plugins/models/html_field.py @@ -1,18 +1,29 @@ +from functools import reduce +from typing import List + from django.db.models import TextField from html.parser import HTMLParser -from django.core import exceptions +from django.core.exceptions import ValidationError from django.utils.translation import gettext as _ - -def link_validator(tag, attribute_name, attribute_value): - if not attribute_value.startswith('http://') and not attribute_value.startswith('https://'): - raise exceptions.ValidationError( +ALLOWED_LINK_PREFIXES = [ + 'http://', + 'https://', + 'mailto:' +] +def link_validator(tag, attribute_name, attribute_value) -> List[ValidationError]: + validation_errors = [] + if not any(map(lambda prefix: attribute_value.startswith(prefix), ALLOWED_LINK_PREFIXES)): + validation_errors.append(ValidationError( _('Link is not valid'), code='invalid_attribute', params={ 'tag': tag, }, - ) + )) + + + return validation_errors class HtmlValidator(HTMLParser): @@ -44,38 +55,54 @@ class HtmlValidator(HTMLParser): 'invalid_attribute': _('Attribute %(attribute)s not allowed for tag %(tag)s'), } + def validate(self, value: str) -> List[ValidationError]: + """ + Validates html, and gives a list of validation errors + """ + + self.errors = [] + + self.feed(value) + + return self.errors + def handle_starttag(self, tag: str, attrs: list) -> None: + 
tag_errors = [] if tag not in self.allowed_tags: - raise exceptions.ValidationError( + tag_errors.append(ValidationError( self.error_messages['invalid_tag'], code='invalid_tag', params={ 'tag': tag }, - ) + )) allowed_attributes_for_tag = self.allowed_attributes.get(tag, []) for (attribute_name, attribute_content) in attrs: if attribute_name not in allowed_attributes_for_tag: - raise exceptions.ValidationError( + tag_errors.append(ValidationError( self.error_messages['invalid_attribute'], code='invalid_attribute', params={ 'tag': tag, 'attribute': attribute_name }, - ) + )) if (tag, attribute_name) in self.special_validators: - self.special_validators[(tag, attribute_name)](tag, attribute_name, attribute_content) + tag_errors += self.special_validators[(tag, attribute_name)](tag, attribute_name, attribute_content) + + self.errors += tag_errors class HtmlField(TextField): """ Determine a safe way to save "secure" user provided HTML input, and prevent XSS injections """ - - def validate(self, value, _): + def validate(self, value: str, _): # Validate all html tags validator = HtmlValidator() - validator.feed(value) + errors = validator.validate(value) + + if errors: + raise ValidationError(errors) diff --git a/tests/test_html_field.py b/tests/test_html_field.py index 49da290a..34cc5531 100644 --- a/tests/test_html_field.py +++ b/tests/test_html_field.py @@ -87,3 +87,19 @@ def test_can_handle_reallife_data(self): data=json.dumps({'content': content})) self.assertEqual(response.status_code, 200) + + def test_multiple_errors(self): + response = self.client.put(f'/web_page/{self.webpage.id}/', + data=json.dumps({ + 'content': 'Visit artis website'})) + self.assertEqual(response.status_code, 400) + + parsed_response = json.loads(response.content) + self.assertEqual('ValidationError', parsed_response['code']) + + + self.assertEqual('invalid_tag', + parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][0]['code']) + self.assertEqual('invalid_tag', + 
parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][1]['code']) + From 8d895efbee49291f641181e0c1ce7ac843b6ec76 Mon Sep 17 00:00:00 2001 From: Stefan Majoor Date: Thu, 3 Mar 2022 12:07:36 +0100 Subject: [PATCH 30/33] Add noreferrer noopener check for links --- binder/plugins/models/html_field.py | 56 +++++++++++++++++++++++++++-- tests/test_html_field.py | 41 +++++++++++++++++++-- 2 files changed, 92 insertions(+), 5 deletions(-) diff --git a/binder/plugins/models/html_field.py b/binder/plugins/models/html_field.py index 18c642a9..acce5313 100644 --- a/binder/plugins/models/html_field.py +++ b/binder/plugins/models/html_field.py @@ -11,6 +11,36 @@ 'https://', 'mailto:' ] + + +def link_rel_validator(tag, attribute_name, attribute_value) -> List[ValidationError]: + validation_errors = [] + + rels = attribute_value.split(' ') + + if 'noopener' not in rels: + + validation_errors.append(ValidationError( + _('Link needs rel="noopener"'), + code='invalid_attribute', + params={ + 'tag': tag, + }, + )) + + if 'noreferrer' not in rels: + validation_errors.append(ValidationError( + _('Link needs rel="noreferer"'), + code='invalid_attribute', + params={ + 'tag': tag, + }, + )) + + + return validation_errors + + def link_validator(tag, attribute_name, attribute_value) -> List[ValidationError]: validation_errors = [] if not any(map(lambda prefix: attribute_value.startswith(prefix), ALLOWED_LINK_PREFIXES)): @@ -21,8 +51,6 @@ def link_validator(tag, attribute_name, attribute_value) -> List[ValidationError 'tag': tag, }, )) - - return validation_errors @@ -46,12 +74,18 @@ class HtmlValidator(HTMLParser): 'a': ['href', 'rel', 'target'] } + required_attributes = { + 'a': ['rel'], + } + special_validators = { - ('a', 'href'): link_validator + ('a', 'href'): link_validator, + ('a', 'rel'): link_rel_validator, } error_messages = { 'invalid_tag': _('Tag %(tag)s is not allowed'), + 'missing_attribute': _('Attribute %(attribute)s is required for tag %(tag)s'), 
'invalid_attribute': _('Attribute %(attribute)s not allowed for tag %(tag)s'), } @@ -77,6 +111,21 @@ def handle_starttag(self, tag: str, attrs: list) -> None: }, )) + set_attributes = set(map(lambda attr: attr[0], attrs)) + required_attributes = set(self.required_attributes.get(tag, [])) + missing_attributes = required_attributes - set_attributes + for missing_attribute in missing_attributes: + tag_errors.append( + ValidationError( + self.error_messages['missing_attribute'], + code='missing_attribute', + params={ + 'tag': tag, + 'attribute': missing_attribute + }, + ) + ) + allowed_attributes_for_tag = self.allowed_attributes.get(tag, []) for (attribute_name, attribute_content) in attrs: @@ -99,6 +148,7 @@ class HtmlField(TextField): """ Determine a safe way to save "secure" user provided HTML input, and prevent XSS injections """ + def validate(self, value: str, _): # Validate all html tags validator = HtmlValidator() diff --git a/tests/test_html_field.py b/tests/test_html_field.py index 34cc5531..fa10ab7d 100644 --- a/tests/test_html_field.py +++ b/tests/test_html_field.py @@ -43,17 +43,21 @@ def test_wrong_attribute_not_ok(self): def test_simple_link_is_ok(self): response = self.client.put(f'/web_page/{self.webpage.id}/', data=json.dumps( - {'content': 'Visit artis website'})) + {'content': 'Visit artis website'})) + self.assertEqual(response.status_code, 200) + + def test_javascript_link_is_not_ok(self): response = self.client.put(f'/web_page/{self.webpage.id}/', data=json.dumps({ - 'content': 'Visit artis website'})) + 'content': 'Visit artis website'})) self.assertEqual(response.status_code, 400) parsed_response = json.loads(response.content) self.assertEqual('ValidationError', parsed_response['code']) + self.assertEqual('invalid_attribute', parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][0]['code']) @@ -103,3 +107,36 @@ def test_multiple_errors(self): self.assertEqual('invalid_tag', 
parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][1]['code']) + + def test_link_no_rel_errors(self): + response = self.client.put(f'/web_page/{self.webpage.id}/', + data=json.dumps({'content': 'bla'})) + self.assertEqual(response.status_code, 400) + + parsed_response = json.loads(response.content) + + self.assertEqual('ValidationError', parsed_response['code']) + self.assertEqual('missing_attribute', + parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][0]['code']) + + def test_link_noopener_required(self): + response = self.client.put(f'/web_page/{self.webpage.id}/', + data=json.dumps({'content': 'bla'})) + self.assertEqual(response.status_code, 400) + + parsed_response = json.loads(response.content) + + self.assertEqual('ValidationError', parsed_response['code']) + self.assertEqual('invalid_attribute', + parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][0]['code']) + + def test_link_noreferrer_required(self): + response = self.client.put(f'/web_page/{self.webpage.id}/', + data=json.dumps({'content': 'bla'})) + self.assertEqual(response.status_code, 400) + + parsed_response = json.loads(response.content) + + self.assertEqual('ValidationError', parsed_response['code']) + self.assertEqual('invalid_attribute', + parsed_response['errors']['web_page'][f'{self.webpage.id}']['content'][0]['code']) From f51af6d0d9d044b7d618dd610bf8c46e2b4fe59e Mon Sep 17 00:00:00 2001 From: Stefan Majoor Date: Tue, 8 Mar 2022 09:47:51 +0100 Subject: [PATCH 31/33] linting --- binder/plugins/models/html_field.py | 1 - 1 file changed, 1 deletion(-) diff --git a/binder/plugins/models/html_field.py b/binder/plugins/models/html_field.py index acce5313..d6cd5c65 100644 --- a/binder/plugins/models/html_field.py +++ b/binder/plugins/models/html_field.py @@ -1,4 +1,3 @@ -from functools import reduce from typing import List from django.db.models import TextField From 1d12ae083e2f76762a3998da5cd39cbd0f6af536 Mon Sep 17 00:00:00 2001 From: 
Stefan Majoor Date: Thu, 10 Mar 2022 13:35:06 +0100 Subject: [PATCH 32/33] fix merge conflict --- binder/views.py | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/binder/views.py b/binder/views.py index e85f863a..04922aa4 100644 --- a/binder/views.py +++ b/binder/views.py @@ -260,17 +260,29 @@ def prefix_db_expression(value, prefix): # Prefix a q expression by adding a prefix to all filters. You can also supply # an 'antiprefix', if the filter starts with this value it will be removed # instead of the prefix being added. This is useful for reversed fields. -def prefix_q_expression(value, prefix, antiprefix=None): +def prefix_q_expression(value, prefix, antiprefix=None, model=None): children = [] for child in value.children: if isinstance(child, Q): - children.append(prefix_q_expression(child, prefix, antiprefix)) + children.append(prefix_q_expression(child, prefix, antiprefix, model)) + # pk__in with empty list is often used for identity values since django + # doesnt properly support them + elif child[0] == 'pk__in' and child[1] == []: + children.append(child) + elif antiprefix is not None and child[0] == antiprefix: + children.append(('pk', child[1])) elif antiprefix is not None and child[0].startswith(antiprefix + '__'): - children.append((child[0][len(antiprefix) + 2:], child[1])) + key = child[0][len(antiprefix) + 2:] + head = key.split('__', 1)[0] + if head != 'pk': + try: + model._meta.get_field(head) + except FieldDoesNotExist: + key = 'pk__' + key + children.append((key, child[1])) else: children.append((prefix + '__' + child[0], child[1])) - return Q(*children, _negated=value.negated) - + return Q(*children, _connector=value.connector, _negated=value.negated) class ModelView(View): # Model this is a view for. Use None for views not tied to a particular model. 
@@ -981,7 +993,7 @@ def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): rev_field = f.remote_field.name query = ( view.model.objects - .filter(prefix_q_expression(q, rev_field, field), **{rev_field + '__in': pks}) + .filter(prefix_q_expression(q, rev_field, field, view.model), **{rev_field + '__in': pks}) .values_list(rev_field + '__pk', 'pk') .distinct() ) @@ -1357,9 +1369,12 @@ def get(self, request, pk=None, withs=None, include_annotations=None): #### with # parse wheres from request - extras, extras_mapping, extras_reverse_mapping, field_results = self._get_withs(queryset, withs, request=request, include_annotations=include_annotations) - data = self._get_objs(queryset, request=request, annotations=include_annotations.get('')) + + pks = [obj['id'] for obj in data] + + extras, extras_mapping, extras_reverse_mapping, field_results = self._get_withs(pks, withs, request=request, include_annotations=include_annotations) + for obj in data: self._annotate_obj_with_related_withs(obj, field_results) @@ -1633,7 +1648,9 @@ def _resolve_file_path(self, value, request): except model.DoesNotExist: return None - return getattr(obj, field) + value = getattr(obj, field) + with value.open('rb') as f: + return ContentFile(f.read(), value.name) def _clean_image_file(self, field, value): From 0a7de13dbe2a1ee9983f2b3f2dd91808dfc8aae1 Mon Sep 17 00:00:00 2001 From: Stefan Majoor Date: Fri, 6 Dec 2024 10:10:02 +0100 Subject: [PATCH 33/33] Fix crash on non gotten FKs --- binder/plugins/views/csvexport.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/binder/plugins/views/csvexport.py b/binder/plugins/views/csvexport.py index 93b6a7e4..f9ed6d26 100644 --- a/binder/plugins/views/csvexport.py +++ b/binder/plugins/views/csvexport.py @@ -282,14 +282,20 @@ def get_datum(data, key, prefix=''): if fk_ids is None: # This case happens if we have a nullable foreign key that is null. Treat this as a many - # to one relation with no values. 
+ # to one relation with no values. fk_ids = [] elif type(fk_ids) != list: fk_ids = [fk_ids] # if head_key not in key_mapping: prefix_key = parent_data['with_mapping'][new_prefix[1:]] - datums = [str(get_datum(key_mapping[prefix_key][fk_id], subkey, new_prefix)) for fk_id in fk_ids] + datums = [] + for fk_id in fk_ids: + try: + datums.append(str(get_datum(key_mapping[prefix_key][fk_id], subkey, new_prefix))) + except KeyError: + pass + # datums = [str(get_datum(key_mapping[prefix_key][fk_id], subkey, new_prefix)) for fk_id in fk_ids] return self.csv_settings.multi_value_delimiter.join( datums )