diff --git a/.github/workflows/test-fair-database.yml b/.github/workflows/test-fair-database.yml new file mode 100644 index 000000000..9652a8377 --- /dev/null +++ b/.github/workflows/test-fair-database.yml @@ -0,0 +1,54 @@ +name: Tests + +on: + [push, pull_request] + +defaults: + run: + shell: bash + +jobs: + unit-test: + + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [macos-latest, ubuntu-latest, windows-latest] + python-version: ['3.12'] + fail-fast: false + + steps: + + - name: Obtain SasData source from git + uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: | + **/test.yml + **/requirements*.txt + + ### Installation of build-dependencies + + - name: Install Python dependencies + run: | + python -m pip install --upgrade pip + python -m pip install wheel setuptools + python -m pip install -r requirements.txt + python -m pip install -r sasdata/fair_database/requirements.txt + + ### Build and test sasdata + + - name: Build sasdata + run: | + # BUILD SASDATA + python -m pip install -e . 
+ + ### Build documentation (if enabled) + + - name: Test with Django tests + run: | + python sasdata/fair_database/manage.py test sasdata.fair_database diff --git a/.gitignore b/.gitignore index ff18e7a00..a24aa7e04 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,7 @@ **/build /dist .mplconfig +**/db.sqlite3 # INSTALL.md recommends a venv that should not be committed venv diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b92d96596..f1958b339 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,19 @@ -default_install_hook_types: [pre-commit, pre-push] - repos: - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.12.9 - hooks: - # Run the linter, applying any available fixes - - id: ruff-check - stages: [ pre-commit, pre-push ] - args: [ --fix-only ] +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + files: "sasdata/fair_database/.*" +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.9.2 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + files: "sasdata/fair_database/.*" + - id: ruff-format + files: "sasdata/fair_database/.*" +- repo: https://github.com/codespell-project/codespell + rev: v2.3.0 + hooks: + - id: codespell + files: "sasdata/fair_database/.*" diff --git a/requirements.txt b/requirements.txt index 64bfe29a4..a1609faa1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,4 +16,3 @@ html5lib # Other stuff matplotlib -pre-commit \ No newline at end of file diff --git a/sasdata/ascii_reader_metadata.py b/sasdata/ascii_reader_metadata.py index 850492a84..d507bbb3f 100644 --- a/sasdata/ascii_reader_metadata.py +++ b/sasdata/ascii_reader_metadata.py @@ -1,5 +1,5 @@ -import re from dataclasses import dataclass, field +from re import split as re_split from typing import TypeVar initial_metadata = { @@ -10,27 +10,11 @@ 'process': ['name', 'date', 'description', 'term', 'notes'], 'sample': ['name', 'sample_id', 
'thickness', 'transmission', 'temperature', 'position', 'orientation', 'details'], 'transmission_spectrum': ['name', 'timestamp', 'transmission', 'transmission_deviation'], - 'magnetic': ['demagnetizing_field', 'saturation_magnetization', 'applied_magnetic_field', 'counting_index'], 'other': ['title', 'run', 'definition'] } -CASING_REGEX = r'[A-Z][a-z]*' - -# First item has the highest precedence. -SEPARATOR_PRECEDENCE = [ - '_', - '-', -] -# If none of these characters exist in that string, use casing. See init_separator - T = TypeVar('T') -# TODO: There may be a better place for this. -pairings = {'I': 'dI', 'Q': 'dQ', 'Qx': 'dQx', 'Qy': 'dQy', 'Qz': 'dQz'} -pairing_error = {value: key for key, value in pairings.items()} -# Allows this to be bidirectional. -bidirectional_pairings = pairings | pairing_error - @dataclass class AsciiMetadataCategory[T]: values: dict[str, T] = field(default_factory=dict) @@ -42,68 +26,13 @@ def default_categories() -> dict[str, AsciiMetadataCategory[str | int]]: class AsciiReaderMetadata: # Key is the filename. filename_specific_metadata: dict[str, dict[str, AsciiMetadataCategory[str]]] = field(default_factory=dict) - # True instead of str means use the casing to separate the filename. - filename_separator: dict[str, str | bool] = field(default_factory=dict) + filename_separator: dict[str, str] = field(default_factory=dict) master_metadata: dict[str, AsciiMetadataCategory[int]] = field(default_factory=default_categories) - def init_separator(self, filename: str): - separator = next(filter(lambda c: c in SEPARATOR_PRECEDENCE, filename), True) - self.filename_separator[filename] = separator - - def filename_components(self, filename: str, cut_off_extension: bool = True, capture: bool = False) -> list[str]: - """Split the filename into several components based on the current separator for that file.""" - separator = self.filename_separator[filename] - # FIXME: This sort of string construction may be an issue. Might need an alternative. 
- base_str = '({})' if capture else '{}' - if isinstance(separator, str): - splitted = re.split(base_str.replace('{}', separator), filename) - else: - splitted = re.findall(base_str.replace('{}', CASING_REGEX), filename) - # If the last component has a file extensions, remove it. - last_component = splitted[-1] - if cut_off_extension and '.' in last_component: - pos = last_component.index('.') - last_component = last_component[:pos] - splitted[-1] = last_component - return splitted - - def purge_unreachable(self, filename: str): - """This is used when the separator has changed. If lets say we now have 2 components when there were 5 but the - 3rd component was selected, this will now produce an index out of range exception. Thus we'll need to purge this - to stop exceptions from happening.""" - components = self.filename_components(filename) - component_length = len(components) - # Converting to list as this mutates the dictionary as it goes through it. - for category_name, category in list(self.master_metadata.items()): - for key, value in list(category.values.items()): - if value >= component_length: - del self.master_metadata[category_name].values[key] + def filename_components(self, filename: str) -> list[str]: + return re_split(self.filename_separator[filename], filename) - def all_file_metadata(self, filename: str) -> dict[str, AsciiMetadataCategory[str]]: - """Return all of the metadata for known for the specified filename. This - will combin the master metadata specified for all files with the - metadata specific to that filename.""" - file_metadata = self.filename_specific_metadata[filename] - components = self.filename_components(filename) - # The ordering here is important. If there are conflicts, the second dictionary will override the first one. - # Conflicts shouldn't really be happening anyway but if they do, we're gonna go with the master metadata taking - # precedence for now. 
- return_metadata: dict[str, AsciiMetadataCategory[str]] = {} - for category_name, category in (file_metadata | self.master_metadata).items(): - combined_category_dict = category.values | self.master_metadata[category_name].values - new_category_dict: dict[str, str] = {} - for key, value in combined_category_dict.items(): - if isinstance(value, str): - new_category_dict[key] = value - elif isinstance(value, int): - new_category_dict[key] = components[value] - else: - raise TypeError(f'Invalid value for {key} in {category_name}') - new_category = AsciiMetadataCategory(new_category_dict) - return_metadata[category_name] = new_category - return return_metadata def get_metadata(self, category: str, value: str, filename: str, error_on_not_found=False) -> str | None: - """Get a particular piece of metadata for the filename.""" components = self.filename_components(filename) # We prioritise the master metadata. @@ -122,10 +51,6 @@ def get_metadata(self, category: str, value: str, filename: str, error_on_not_fo return None def update_metadata(self, category: str, key: str, filename: str, new_value: str | int): - """Update the metadata for a filename. If the new_value is a string, - then this new metadata will be specific to that file. Otherwise, if - new_value is an integer, then this will represent the component of the - filename that this metadata applies to all.""" if isinstance(new_value, str): self.filename_specific_metadata[filename][category].values[key] = new_value # TODO: What about the master metadata? Until that's gone, that still takes precedence. 
@@ -135,7 +60,6 @@ def update_metadata(self, category: str, key: str, filename: str, new_value: str raise TypeError('Invalid type for new_value') def clear_metadata(self, category: str, key: str, filename: str): - """Remove any metadata recorded for a certain filename.""" category_obj = self.filename_specific_metadata[filename][category] if key in category_obj.values: del category_obj.values[key] @@ -143,7 +67,5 @@ def clear_metadata(self, category: str, key: str, filename: str): del self.master_metadata[category].values[key] def add_file(self, new_filename: str): - """Add a filename to the metadata, filling it with some default - categories.""" # TODO: Fix typing here. Pyright is showing errors. self.filename_specific_metadata[new_filename] = default_categories() diff --git a/sasdata/data.py b/sasdata/data.py index 788f3d7f5..596d12097 100644 --- a/sasdata/data.py +++ b/sasdata/data.py @@ -1,150 +1,40 @@ -import typing -import h5py -import numpy as np -from h5py._hl.group import Group as HDF5Group -from sasdata import dataset_types -from sasdata.dataset_types import DatasetType -from sasdata.metadata import Metadata, MetadataEncoder -from sasdata.quantities.quantity import Quantity + +from sasdata.metadata import Metadata +from sasdata.quantities.quantity import NamedQuantity class SasData: - def __init__( - self, - name: str, - data_contents: dict[str, Quantity], - dataset_type: DatasetType, - metadata: Metadata, - verbose: bool = False, - ): + def __init__(self, name: str, + data_contents: list[NamedQuantity], + raw_metadata: Group, + verbose: bool=False): + self.name = name - # validate data contents - if not all([key in dataset_type.optional or key in dataset_type.required for key in data_contents]): - raise ValueError(f"Columns don't match the dataset type: {[key for key in data_contents]}") self._data_contents = data_contents + self._raw_metadata = raw_metadata self._verbose = verbose - self.metadata = metadata - - # TODO: Could this be optional? 
- self.dataset_type: DatasetType = dataset_type + self.metadata = Metadata(AccessorTarget(raw_metadata, verbose=verbose)) # Components that need to be organised after creation - self.mask = None # TODO: fill out - self.model_requirements = None # TODO: fill out - - # TODO: Handle the other data types. - @property - def ordinate(self) -> Quantity: - match self.dataset_type: - case dataset_types.one_dim | dataset_types.two_dim: - return self._data_contents["I"] - case dataset_types.sesans: - return self._data_contents["Depolarisation"] - case _: - return None + self.ordinate: NamedQuantity[np.ndarray] = None # TODO: fill out + self.abscissae: list[NamedQuantity[np.ndarray]] = None # TODO: fill out + self.mask = None # TODO: fill out + self.model_requirements = None # TODO: fill out - @property - def abscissae(self) -> Quantity: - match self.dataset_type: - case dataset_types.one_dim: - return self._data_contents["Q"] - case dataset_types.two_dim: - # Type hinting is a bit lacking. Assume each part of the zip is a scalar value. - data_contents = np.array( - list( - zip( - self._data_contents["Qx"].value, - self._data_contents["Qy"].value, - ) - ) - ) - # Use this value to extract units etc. Assume they will be the same for Qy. - reference_data_content = self._data_contents["Qx"] - # TODO: If this is a derived quantity then we are going to lose that - # information. - # - # TODO: Won't work when there's errors involved. On reflection, we - # probably want to avoid creating a new Quantity but at the moment I - # can't see a way around it. 
- return Quantity(data_contents, reference_data_content.units, name=self._data_contents["Qx"].name, id_header=self._data_contents["Qx"]._id_header) - case dataset_types.sesans: - return self._data_contents["SpinEchoLength"] - case _: - None - - def __getitem__(self, item: str): - return self._data_contents[item] - - def summary(self, indent=" "): + def summary(self, indent = " ", include_raw=False): s = f"{self.name}\n" - for data in sorted(self._data_contents, reverse=True): + for data in self._data_contents: s += f"{indent}{data}\n" s += "Metadata:\n" s += "\n" s += self.metadata.summary() - return s - - @staticmethod - def from_json(obj): - return SasData( - name=obj["name"], - dataset_type=DatasetType( - name=obj["type"]["name"], - required=obj["type"]["required"], - optional=obj["type"]["optional"], - expected_orders=obj["type"]["expected_orders"], - ), - data_contents=obj["data_contents"], - metadata=Metadata.from_json(obj["metadata"]), - ) + if include_raw: + s += key_tree(self._raw_metadata) - def _save_h5(self, sasentry: HDF5Group): - """Export data into HDF5 file""" - sasentry.attrs["name"] = self.name - self.metadata.as_h5(sasentry) - - # We export each data set into its own entry, so we only ever - # need sasdata01 - group = sasentry.create_group("sasdata01") - for idx, (key, sasdata) in enumerate(self._data_contents.items()): - sasdata.as_h5(group, key) - - - @staticmethod - def save_h5(data: dict[str, typing.Self], path: str | typing.BinaryIO): - with h5py.File(path, "w") as f: - for idx, (key, data) in enumerate(data.items()): - sasentry = f.create_group(f"sasentry{idx+1:02d}") - if not key.startswith("sasentry"): - sasentry.attrs["sasview_key"] = key - data._save_h5(sasentry) - - - -class SasDataEncoder(MetadataEncoder): - def default(self, obj): - match obj: - case DatasetType(): - return { - "name": obj.name, - "required": obj.required, - "optional": obj.optional, - "expected_orders": obj.expected_orders, - } - case SasData(): - return { - "name": 
obj.name, - "data_contents": obj._data_contents, - "type": obj.dataset_type, - "mask": obj.mask, - "metadata": obj.metadata, - "model_requirements": obj.model_requirements, - } - case _: - return super().default(obj) + return s diff --git a/sasdata/dataloader/readers/danse_reader.py b/sasdata/dataloader/readers/danse_reader.py index d1851c77d..10a1c515c 100644 --- a/sasdata/dataloader/readers/danse_reader.py +++ b/sasdata/dataloader/readers/danse_reader.py @@ -11,7 +11,6 @@ #This work benefited from DANSE software developed under NSF award DMR-0520547. #copyright 2008, University of Tennessee ############################################################################# -import logging import os import numpy as np diff --git a/sasdata/dataset_types.py b/sasdata/dataset_types.py index cbe36825d..c7d2f5927 100644 --- a/sasdata/dataset_types.py +++ b/sasdata/dataset_types.py @@ -19,7 +19,7 @@ class DatasetType: one_dim = DatasetType( name="1D I vs Q", required=["Q", "I"], - optional=["dI", "dQ", "Shadowfactor", "Qmean", "dQl", "dQw"], + optional=["dI", "dQ", "shadow"], expected_orders=[ ["Q", "I", "dI"], ["Q", "dQ", "I", "dI"]]) @@ -27,7 +27,7 @@ class DatasetType: two_dim = DatasetType( name="2D I vs Q", required=["Qx", "Qy", "I"], - optional=["dQx", "dQy", "dQz", "dI", "Qz", "ShadowFactor", "mask"], + optional=["dQx", "dQy", "dQz", "dI", "Qz", "shadow"], expected_orders=[ ["Qx", "Qy", "I"], ["Qx", "Qy", "I", "dI"], @@ -44,8 +44,8 @@ class DatasetType: sesans = DatasetType( name="SESANS", - required=["SpinEchoLength", "Depolarisation", "Wavelength"], - optional=["Transmission", "Polarisation"], + required=["z", "G"], + optional=["stuff", "other stuff", "more stuff"], expected_orders=[["z", "G"]]) dataset_types = {dataset.name for dataset in [one_dim, two_dim, sesans]} @@ -68,11 +68,8 @@ class DatasetType: "dQx": units.inverse_length, "dQy": units.inverse_length, "dQz": units.inverse_length, - "SpinEchoLength": units.length, - "Depolarisation": units.inverse_volume, - 
"Wavelength": units.length, - "Transmission": units.dimensionless, - "Polarisation": units.dimensionless, + "z": units.length, + "G": units.area, "shadow": units.dimensionless, "temperature": units.temperature, "magnetic field": units.magnetic_flux_density diff --git a/sasdata/default_units.py b/sasdata/default_units.py index b71fed42f..bbb275664 100644 --- a/sasdata/default_units.py +++ b/sasdata/default_units.py @@ -5,18 +5,13 @@ from sasdata.quantities.units import NamedUnit default_units = { - "Q": [unit.per_nanometer, unit.per_angstrom, unit.per_meter], - "I": [unit.per_centimeter, unit.per_meter], - "dQ": "Q", - "dI": "I", + 'Q': [unit.per_nanometer, unit.per_angstrom, unit.per_meter], + 'I': [unit.per_centimeter, unit.per_meter] } def defaults_or_fallback(column_name: str) -> list[NamedUnit]: - value = default_units.get(column_name, unit_kinds[column_name].units) - if isinstance(value, str): - return defaults_or_fallback(value) - return value + return default_units.get(column_name, unit_kinds[column_name].units) def first_default_for_fallback(column_name: str) -> NamedUnit: diff --git a/sasdata/fair_database/__init__.py b/sasdata/fair_database/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sasdata/fair_database/data/__init__.py b/sasdata/fair_database/data/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sasdata/fair_database/data/admin.py b/sasdata/fair_database/data/admin.py new file mode 100644 index 000000000..2134875a6 --- /dev/null +++ b/sasdata/fair_database/data/admin.py @@ -0,0 +1,11 @@ +from data import models +from django.contrib import admin + +admin.site.register(models.DataFile) +admin.site.register(models.Session) +admin.site.register(models.PublishedState) +admin.site.register(models.DataSet) +admin.site.register(models.MetaData) +admin.site.register(models.Quantity) +admin.site.register(models.OperationTree) +admin.site.register(models.ReferenceQuantity) diff --git 
a/sasdata/fair_database/data/apps.py b/sasdata/fair_database/data/apps.py new file mode 100644 index 000000000..b882be950 --- /dev/null +++ b/sasdata/fair_database/data/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class DataConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "data" diff --git a/sasdata/fair_database/data/forms.py b/sasdata/fair_database/data/forms.py new file mode 100644 index 000000000..519556e2f --- /dev/null +++ b/sasdata/fair_database/data/forms.py @@ -0,0 +1,9 @@ +from data.models import DataFile +from django import forms + + +# Create the form class. +class DataFileForm(forms.ModelForm): + class Meta: + model = DataFile + fields = ["file", "is_public"] diff --git a/sasdata/fair_database/data/migrations/0001_initial.py b/sasdata/fair_database/data/migrations/0001_initial.py new file mode 100644 index 000000000..e8f7219a6 --- /dev/null +++ b/sasdata/fair_database/data/migrations/0001_initial.py @@ -0,0 +1,332 @@ +# Generated by Django 5.1.6 on 2025-04-23 18:08 + +import data.models +import django.core.files.storage +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="DataFile", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "is_public", + models.BooleanField( + default=False, help_text="opt in to make your data public" + ), + ), + ( + "file_name", + models.CharField( + blank=True, + default=None, + help_text="File name", + max_length=200, + null=True, + ), + ), + ( + "file", + models.FileField( + default=None, + help_text="This is a file", + storage=django.core.files.storage.FileSystemStorage(), + upload_to="uploaded_files", + ), + ), + ( + 
"current_user", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "users", + models.ManyToManyField( + blank=True, related_name="+", to=settings.AUTH_USER_MODEL + ), + ), + ], + options={ + "abstract": False, + }, + ), + migrations.CreateModel( + name="DataSet", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "is_public", + models.BooleanField( + default=False, help_text="opt in to make your data public" + ), + ), + ("name", models.CharField(max_length=200)), + ( + "current_user", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to=settings.AUTH_USER_MODEL, + ), + ), + ("files", models.ManyToManyField(to="data.datafile")), + ( + "users", + models.ManyToManyField( + blank=True, related_name="+", to=settings.AUTH_USER_MODEL + ), + ), + ], + options={ + "abstract": False, + }, + ), + migrations.CreateModel( + name="MetaData", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("title", models.CharField(default="Title", max_length=500)), + ("run", models.JSONField(default=data.models.empty_list)), + ("definition", models.TextField(blank=True, null=True)), + ("instrument", models.JSONField(blank=True, null=True)), + ("process", models.JSONField(default=data.models.empty_list)), + ("sample", models.JSONField(blank=True, null=True)), + ( + "dataset", + models.OneToOneField( + on_delete=django.db.models.deletion.CASCADE, + related_name="metadata", + to="data.dataset", + ), + ), + ], + ), + migrations.CreateModel( + name="Quantity", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("value", models.JSONField()), + ("variance", 
models.JSONField()), + ("units", models.CharField(max_length=200)), + ("hash", models.IntegerField()), + ("label", models.CharField(max_length=50)), + ( + "dataset", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="data_contents", + to="data.dataset", + ), + ), + ], + ), + migrations.CreateModel( + name="OperationTree", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "operation", + models.CharField( + choices=[ + ("zero", "0 [Add.Id.]"), + ("one", "1 [Mul.Id.]"), + ("constant", "Constant"), + ("variable", "Variable"), + ("neg", "Neg"), + ("reciprocal", "Inv"), + ("add", "Add"), + ("sub", "Sub"), + ("mul", "Mul"), + ("div", "Div"), + ("pow", "Pow"), + ("transpose", "Transpose"), + ("dot", "Dot"), + ("matmul", "MatMul"), + ("tensor_product", "TensorProduct"), + ], + max_length=20, + ), + ), + ("parameters", models.JSONField(default=data.models.empty_dict)), + ("label", models.CharField(blank=True, max_length=10, null=True)), + ( + "child_operation", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="parent_operations", + to="data.operationtree", + ), + ), + ( + "quantity", + models.OneToOneField( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="operation_tree", + to="data.quantity", + ), + ), + ], + ), + migrations.CreateModel( + name="ReferenceQuantity", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("value", models.JSONField()), + ("variance", models.JSONField()), + ("units", models.CharField(max_length=200)), + ("hash", models.IntegerField()), + ( + "derived_quantity", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="references", + to="data.quantity", + ), + ), + ], + ), + migrations.CreateModel( + 
name="Session", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "is_public", + models.BooleanField( + default=False, help_text="opt in to make your data public" + ), + ), + ("title", models.CharField(max_length=200)), + ( + "current_user", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "users", + models.ManyToManyField( + blank=True, related_name="+", to=settings.AUTH_USER_MODEL + ), + ), + ], + options={ + "abstract": False, + }, + ), + migrations.CreateModel( + name="PublishedState", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("published", models.BooleanField(default=False)), + ("doi", models.URLField()), + ( + "session", + models.OneToOneField( + on_delete=django.db.models.deletion.CASCADE, + related_name="published_state", + to="data.session", + ), + ), + ], + ), + migrations.AddField( + model_name="dataset", + name="session", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="datasets", + to="data.session", + ), + ), + ] diff --git a/sasdata/fair_database/data/migrations/__init__.py b/sasdata/fair_database/data/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sasdata/fair_database/data/models.py b/sasdata/fair_database/data/models.py new file mode 100644 index 000000000..3a36dc831 --- /dev/null +++ b/sasdata/fair_database/data/models.py @@ -0,0 +1,232 @@ +from django.contrib.auth.models import User +from django.core.files.storage import FileSystemStorage +from django.db import models + + +# method for empty list default value +def empty_list(): + return [] + + +# method for empty dictionary default value +def empty_dict(): + return {} + + +class 
Data(models.Model): + """Base model for data with access-related information.""" + + # owner of the data + current_user = models.ForeignKey( + User, blank=True, null=True, on_delete=models.CASCADE, related_name="+" + ) + + # users that have been granted view access to the data + users = models.ManyToManyField(User, blank=True, related_name="+") + + # is the data public? + is_public = models.BooleanField( + default=False, help_text="opt in to make your data public" + ) + + class Meta: + abstract = True + + +class DataFile(Data): + """Database model for file contents.""" + + # file name + file_name = models.CharField( + max_length=200, default=None, blank=True, null=True, help_text="File name" + ) + + # imported data + # user can either import a file path or actual file + file = models.FileField( + blank=False, + default=None, + help_text="This is a file", + upload_to="uploaded_files", + storage=FileSystemStorage(), + ) + + +class DataSet(Data): + """Database model for a set of data and associated metadata.""" + + # dataset name + name = models.CharField(max_length=200) + + # associated files + files = models.ManyToManyField(DataFile) + + # session the dataset is a part of, if any + session = models.ForeignKey( + "Session", + on_delete=models.CASCADE, + related_name="datasets", + blank=True, + null=True, + ) + + # TODO: update based on SasData class in data.py + # type of dataset + # dataset_type = models.JSONField() + + +class Quantity(models.Model): + """Database model for data quantities such as the ordinate and abscissae.""" + + # data value + value = models.JSONField() + + # variance of the data + variance = models.JSONField() + + # units + units = models.CharField(max_length=200) + + # hash value + hash = models.IntegerField() + + # label, e.g. 
Q or I(Q) + label = models.CharField(max_length=50) + + # data set the quantity is a part of + dataset = models.ForeignKey( + DataSet, on_delete=models.CASCADE, related_name="data_contents" + ) + + +class ReferenceQuantity(models.Model): + """ + Database models for quantities referenced by variables in an OperationTree. + + Corresponds to the references dictionary in the QuantityHistory class in + sasdata/quantity.py. ReferenceQuantities should be essentially the same as + Quantities but with no operations performed on them and therefore no + OperationTree. + """ + + # data value + value = models.JSONField() + + # variance of the data + variance = models.JSONField() + + # units + units = models.CharField(max_length=200) + + # hash value + hash = models.IntegerField() + + # Quantity whose OperationTree this is a reference for + derived_quantity = models.ForeignKey( + Quantity, + related_name="references", + on_delete=models.CASCADE, + ) + + +# TODO: update based on changes in sasdata/metadata.py +class MetaData(models.Model): + """Database model for scattering metadata""" + + # title + title = models.CharField(max_length=500, default="Title") + + # run + run = models.JSONField(default=empty_list) + + # definition + definition = models.TextField(blank=True, null=True) + + # instrument + instrument = models.JSONField(blank=True, null=True) + + # process + process = models.JSONField(default=empty_list) + + # sample + sample = models.JSONField(blank=True, null=True) + + # associated dataset + dataset = models.OneToOneField( + DataSet, on_delete=models.CASCADE, related_name="metadata" + ) + + +class OperationTree(models.Model): + """Database model for tree of operations performed on a DataSet.""" + + # possible operations + OPERATION_CHOICES = { + "zero": "0 [Add.Id.]", + "one": "1 [Mul.Id.]", + "constant": "Constant", + "variable": "Variable", + "neg": "Neg", + "reciprocal": "Inv", + "add": "Add", + "sub": "Sub", + "mul": "Mul", + "div": "Div", + "pow": "Pow", + 
"transpose": "Transpose", + "dot": "Dot", + "matmul": "MatMul", + "tensor_product": "TensorProduct", + } + + # operation + operation = models.CharField(max_length=20, choices=OPERATION_CHOICES) + + # parameters + parameters = models.JSONField(default=empty_dict) + + # label (a or b) if the operation is a parameter of a child operation + # maintains ordering of binary operation parameters + label = models.CharField(max_length=10, blank=True, null=True) + + # operation this operation is a parameter for, if any + child_operation = models.ForeignKey( + "self", + on_delete=models.CASCADE, + related_name="parent_operations", + blank=True, + null=True, + ) + + # quantity the operation produces + # only set for base of tree (the quantity's most recent operation) + quantity = models.OneToOneField( + Quantity, + on_delete=models.CASCADE, + blank=True, + null=True, + related_name="operation_tree", + ) + + +class Session(Data): + """Database model for a project save state.""" + + # title + title = models.CharField(max_length=200) + + +class PublishedState(models.Model): + """Database model for a project published state.""" + + # published + published = models.BooleanField(default=False) + + # TODO: update doi as needed when DOI generation is implemented + # doi + doi = models.URLField() + + # session + session = models.OneToOneField( + Session, on_delete=models.CASCADE, related_name="published_state" + ) diff --git a/sasdata/fair_database/data/serializers.py b/sasdata/fair_database/data/serializers.py new file mode 100644 index 000000000..dfb7ece52 --- /dev/null +++ b/sasdata/fair_database/data/serializers.py @@ -0,0 +1,529 @@ +from data import models +from django.core.exceptions import ObjectDoesNotExist +from fair_database import permissions +from rest_framework import serializers + +# TODO: more custom validation, particularly for specific nested dictionary structures +# TODO: custom update methods for nested structures + + +# Determine if an operation does not have parent 
operations +def constant_or_variable(operation: str): + return operation in ["zero", "one", "constant", "variable"] + + +# Determine if an operation has two parent operations +def binary(operation: str): + return operation in ["add", "sub", "mul", "div", "dot", "matmul", "tensor_product"] + + +class DataFileSerializer(serializers.ModelSerializer): + """Serialization, deserialization, and validation for the DataFile model.""" + + class Meta: + model = models.DataFile + fields = "__all__" + + # TODO: check partial updates + # Check that private data has an owner + def validate(self, data): + if not self.context["is_public"] and not data["current_user"]: + raise serializers.ValidationError("private data must have an owner") + return data + + +class AccessManagementSerializer(serializers.Serializer): + """ + Serialization, deserialization, and validation for granting and revoking + access to instances of any exposed model. + """ + + # The username of a user + username = serializers.CharField(max_length=200, required=False) + # Whether that user has access + access = serializers.BooleanField() + + +class MetaDataSerializer(serializers.ModelSerializer): + """Serialization, deserialization, and validation for the MetaData model.""" + + # associated dataset + dataset = serializers.PrimaryKeyRelatedField( + queryset=models.DataSet, required=False, allow_null=True + ) + + class Meta: + model = models.MetaData + fields = "__all__" + + # Serialize an entry in MetaData + def to_representation(self, instance): + data = super().to_representation(instance) + if "dataset" in data: + data.pop("dataset") + return data + + +class OperationTreeSerializer(serializers.ModelSerializer): + """Serialization, deserialization, and validation for the OperationTree model.""" + + # associated quantity, for root operation + quantity = serializers.PrimaryKeyRelatedField( + queryset=models.Quantity, required=False, allow_null=True + ) + # operation this operation is a parameter for, for non-root 
operations + child_operation = serializers.PrimaryKeyRelatedField( + queryset=models.OperationTree, required=False, allow_null=True + ) + # parameter label, for non-root operations + label = serializers.CharField(max_length=10, required=False) + + class Meta: + model = models.OperationTree + fields = ["operation", "parameters", "quantity", "label", "child_operation"] + + # Validate parent operations + def validate_parameters(self, value): + if "a" in value: + serializer = OperationTreeSerializer(data=value["a"]) + serializer.is_valid(raise_exception=True) + if "b" in value: + serializer = OperationTreeSerializer(data=value["b"]) + serializer.is_valid(raise_exception=True) + return value + + # Check that the operation has the correct parameters + def validate(self, data): + expected_parameters = { + "zero": [], + "one": [], + "constant": ["value"], + "variable": ["hash_value", "name"], + "neg": ["a"], + "reciprocal": ["a"], + "add": ["a", "b"], + "sub": ["a", "b"], + "mul": ["a", "b"], + "div": ["a", "b"], + "pow": ["a", "power"], + "transpose": ["a", "axes"], + "dot": ["a", "b"], + "matmul": ["a", "b"], + "tensor_product": ["a", "b", "a_index", "b_index"], + } + + for parameter in expected_parameters[data["operation"]]: + if parameter not in data["parameters"]: + raise serializers.ValidationError( + data["operation"] + " requires parameter " + parameter + ) + + return data + + # Serialize an OperationTree instance + def to_representation(self, instance): + data = {"operation": instance.operation, "parameters": instance.parameters} + for parent_operation in instance.parent_operations.all(): + data["parameters"][parent_operation.label] = self.to_representation( + parent_operation + ) + return data + + # Create an OperationTree instance + def create(self, validated_data): + parent_operation1 = None + parent_operation2 = None + if not constant_or_variable(validated_data["operation"]): + parent_operation1 = validated_data["parameters"].pop("a") + 
parent_operation1["label"] = "a" + if binary(validated_data["operation"]): + parent_operation2 = validated_data["parameters"].pop("b") + parent_operation2["label"] = "b" + operation_tree = models.OperationTree.objects.create(**validated_data) + if parent_operation1: + parent_operation1["child_operation"] = operation_tree + OperationTreeSerializer.create( + OperationTreeSerializer(), validated_data=parent_operation1 + ) + if parent_operation2: + parent_operation2["child_operation"] = operation_tree + OperationTreeSerializer.create( + OperationTreeSerializer(), validated_data=parent_operation2 + ) + return operation_tree + + +class ReferenceQuantitySerializer(serializers.ModelSerializer): + """Serialization, deserialization, and validation for the ReferenceQuantity model.""" + + # quantity whose operation tree this is a reference for + derived_quantity = serializers.PrimaryKeyRelatedField( + queryset=models.Quantity, required=False + ) + + class Meta: + model = models.ReferenceQuantity + fields = ["value", "variance", "units", "hash", "derived_quantity"] + + # serialize a ReferenceQuantity instance + def to_representation(self, instance): + data = super().to_representation(instance) + if "derived_quantity" in data: + data.pop("derived_quantity") + return data + + # create a ReferenceQuantity instance + def create(self, validated_data): + if "label" in validated_data: + validated_data.pop("label") + if "history" in validated_data: + validated_data.pop("history") + return models.ReferenceQuantity.objects.create(**validated_data) + + +class QuantitySerializer(serializers.ModelSerializer): + """Serialization, deserialization, and validation for the Quantity model.""" + + # associated operation tree + operation_tree = OperationTreeSerializer(read_only=False, required=False) + # references for the operation tree + references = ReferenceQuantitySerializer(many=True, read_only=False, required=False) + # quantity label + label = serializers.CharField(max_length=20) + # 
dataset this is a part of + dataset = serializers.PrimaryKeyRelatedField( + queryset=models.DataSet, required=False, allow_null=True + ) + # serialized JSON form of operation tree and references + history = serializers.JSONField(required=False, allow_null=True) + + class Meta: + model = models.Quantity + fields = [ + "value", + "variance", + "units", + "hash", + "operation_tree", + "references", + "label", + "dataset", + "history", + ] + + # validate references + def validate_history(self, value): + if "references" in value: + for ref in value["references"]: + serializer = ReferenceQuantitySerializer(data=ref) + serializer.is_valid(raise_exception=True) + + # TODO: should variable-only history be assumed to refer to the same Quantity and ignored? + # Extract operation tree from history + def to_internal_value(self, data): + if "history" in data: + data_copy = data.copy() + if "operation_tree" in data["history"]: + operations = data["history"]["operation_tree"] + if ( + "operation" in operations + and not operations["operation"] == "variable" + ): + data_copy["operation_tree"] = operations + return_data = super().to_internal_value(data_copy) + return_data["history"] = data["history"] + return return_data + else: + return super().to_internal_value(data_copy) + return super().to_internal_value(data) + + # Serialize a Quantity instance + def to_representation(self, instance): + data = super().to_representation(instance) + if "dataset" in data: + data.pop("dataset") + if "derived_quantity" in data: + data.pop("derived_quantity") + data["history"] = {} + data["history"]["operation_tree"] = data.pop("operation_tree") + data["history"]["references"] = data.pop("references") + return data + + # Create a Quantity instance + def create(self, validated_data): + operations_tree = None + references = None + if "operation_tree" in validated_data: + operations_tree = validated_data.pop("operation_tree") + if "history" in validated_data: + history = validated_data.pop("history") + 
if history and "references" in history: + references = history.pop("references") + quantity = models.Quantity.objects.create(**validated_data) + if operations_tree: + operations_tree["quantity"] = quantity + OperationTreeSerializer.create( + OperationTreeSerializer(), validated_data=operations_tree + ) + if references: + for ref in references: + ref["derived_quantity"] = quantity + ReferenceQuantitySerializer.create( + ReferenceQuantitySerializer(), validated_data=ref + ) + return quantity + + +class DataSetSerializer(serializers.ModelSerializer): + """Serialization, deserialization, and validation for the DataSet model.""" + + # associated metadata + metadata = MetaDataSerializer(read_only=False) + # associated files + files = serializers.PrimaryKeyRelatedField( + required=False, many=True, allow_null=True, queryset=models.DataFile.objects + ) + # quantities that make up the dataset + data_contents = QuantitySerializer(many=True, read_only=False) + # session the dataset is a part of, if any + session = serializers.PrimaryKeyRelatedField( + queryset=models.Session, required=False, allow_null=True + ) + # TODO: handle files better + + class Meta: + model = models.DataSet + fields = [ + "id", + "name", + "files", + "metadata", + "data_contents", + "is_public", + "current_user", + "users", + "session", + ] + + # Serialize a DataSet instance + def to_representation(self, instance): + data = super().to_representation(instance) + if "request" in self.context: + files = [ + file.id + for file in instance.files.all() + if ( + file.is_public + or permissions.has_access(self.context["request"], file) + ) + ] + data["files"] = files + return data + + # Check that files exist and user has access to them + def validate_files(self, value): + for file in value: + if not file.is_public and not permissions.has_access( + self.context["request"], file + ): + raise serializers.ValidationError( + "You do not have access to file " + str(file.id) + ) + return value + + # Check that 
private data has an owner + def validate(self, data): + if ( + not self.context["request"].user.is_authenticated + and "is_public" in data + and not data["is_public"] + ): + raise serializers.ValidationError("private data must have an owner") + if "current_user" in data and ( + data["current_user"] == "" or data["current_user"] is None + ): + if "is_public" in data: + if not data["is_public"]: + raise serializers.ValidationError("private data must have an owner") + else: + if not self.instance.is_public: + raise serializers.ValidationError("private data must have an owner") + return data + + # Create a DataSet instance + def create(self, validated_data): + files = [] + if self.context["request"].user.is_authenticated: + validated_data["current_user"] = self.context["request"].user + metadata_raw = validated_data.pop("metadata") + data_contents = validated_data.pop("data_contents") + if "files" in validated_data: + files = validated_data.pop("files") + dataset = models.DataSet.objects.create(**validated_data) + dataset.files.set(files) + metadata_raw["dataset"] = dataset + MetaDataSerializer.create(MetaDataSerializer(), validated_data=metadata_raw) + for d in data_contents: + d["dataset"] = dataset + QuantitySerializer.create(QuantitySerializer(), validated_data=d) + return dataset + + # TODO: account for updating other attributes + # Update a DataSet instance + def update(self, instance, validated_data): + if "metadata" in validated_data: + metadata_raw = validated_data.pop("metadata") + new_metadata = MetaDataSerializer.update( + MetaDataSerializer(), instance.metadata, validated_data=metadata_raw + ) + instance.metadata = new_metadata + instance.save() + return super().update(instance, validated_data) + + +class PublishedStateSerializer(serializers.ModelSerializer): + """Serialization, deserialization, and validation for the PublishedState model.""" + + # associated session + session = serializers.PrimaryKeyRelatedField( + queryset=models.Session.objects, 
required=False, allow_null=True + ) + + class Meta: + model = models.PublishedState + fields = "__all__" + + # check that session does not already have a published state + def validate_session(self, value): + try: + published = value.published_state + if published is not None: + raise serializers.ValidationError( + "Only one published state per session" + ) + except models.Session.published_state.RelatedObjectDoesNotExist: + return value + + # set a placeholder DOI + def to_internal_value(self, data): + data_copy = data.copy() + data_copy["doi"] = "http://127.0.0.1:8000/v1/data/session/" + return super().to_internal_value(data_copy) + + # create a PublishedState instance + def create(self, validated_data): + # TODO: generate DOI + validated_data["doi"] = ( + "http://127.0.0.1:8000/v1/data/session/" + + str(validated_data["session"].id) + + "/" + ) + return models.PublishedState.objects.create(**validated_data) + + +class PublishedStateUpdateSerializer(serializers.ModelSerializer): + """Serialization for PublishedState updates. 
Restricts changes to published field.""" + + class Meta: + model = models.PublishedState + fields = ["published"] + + +class SessionSerializer(serializers.ModelSerializer): + """Serialization, deserialization, and validation for the Session model.""" + + # datasets that make up the session + datasets = DataSetSerializer(read_only=False, many=True) + # associated published state, if any + published_state = PublishedStateSerializer(read_only=False, required=False) + + class Meta: + model = models.Session + fields = [ + "id", + "title", + "published_state", + "datasets", + "current_user", + "is_public", + "users", + ] + + # disallow private unowned sessions + def validate(self, data): + if ( + not self.context["request"].user.is_authenticated + and "is_public" in data + and not data["is_public"] + ): + raise serializers.ValidationError("private sessions must have an owner") + if "current_user" in data and data["current_user"] == "": + if "is_public" in data: + if not "is_public": + raise serializers.ValidationError( + "private sessions must have an owner" + ) + else: + if not self.instance.is_public: + raise serializers.ValidationError( + "private sessions must have an owner" + ) + return data + + # propagate is_public to datasets + def to_internal_value(self, data): + data_copy = data.copy() + if "is_public" in data: + if "datasets" in data: + for dataset in data_copy["datasets"]: + dataset["is_public"] = data["is_public"] + return super().to_internal_value(data_copy) + + # serialize a session instance + def to_representation(self, instance): + session = super().to_representation(instance) + for dataset in session["datasets"]: + dataset.pop("session") + return session + + # Create a Session instance + def create(self, validated_data): + published_state = None + if self.context["request"].user.is_authenticated: + validated_data["current_user"] = self.context["request"].user + if "published_state" in validated_data: + published_state = 
validated_data.pop("published_state") + datasets = validated_data.pop("datasets") + session = models.Session.objects.create(**validated_data) + if published_state: + published_state["session"] = session + PublishedStateSerializer.create( + PublishedStateSerializer(), validated_data=published_state + ) + for dataset in datasets: + dataset["session"] = session + DataSetSerializer.create( + DataSetSerializer(context=self.context), validated_data=dataset + ) + return session + + # update a session instance + def update(self, instance, validated_data): + if "is_public" in validated_data: + for dataset in instance.datasets.all(): + dataset.is_public = validated_data["is_public"] + dataset.save() + if "published_state" in validated_data: + pb_raw = validated_data.pop("published_state") + try: + PublishedStateUpdateSerializer.update( + PublishedStateUpdateSerializer(), + instance.published_state, + validated_data=pb_raw, + ) + except ObjectDoesNotExist: + pb_raw["session"] = instance + PublishedStateSerializer.create( + PublishedStateSerializer(), validated_data=pb_raw + ) + return super().update(instance, validated_data) diff --git a/sasdata/fair_database/data/test/__init__.py b/sasdata/fair_database/data/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sasdata/fair_database/data/test/test_datafile.py b/sasdata/fair_database/data/test/test_datafile.py new file mode 100644 index 000000000..ac8dbb480 --- /dev/null +++ b/sasdata/fair_database/data/test/test_datafile.py @@ -0,0 +1,443 @@ +import os +import shutil + +from data.models import DataFile +from django.conf import settings +from django.contrib.auth.models import User +from django.db.models import Max +from django.test import TestCase +from rest_framework import status +from rest_framework.test import APIClient, APITestCase + + +# path to a file in example_data/1d_data +def find(filename): + return os.path.join( + os.path.dirname(__file__), "../../../example_data/1d_data", filename + ) + + 
+class TestLists(TestCase): + """Test get methods for DataFile.""" + + @classmethod + def setUpTestData(cls): + cls.public_test_data = DataFile.objects.create( + id=1, file_name="cyl_400_40.txt", is_public=True + ) + cls.public_test_data.file.save( + "cyl_400_40.txt", open(find("cyl_400_40.txt"), "rb") + ) + cls.user = User.objects.create_user( + username="testUser", password="secret", id=2 + ) + cls.private_test_data = DataFile.objects.create( + id=3, current_user=cls.user, file_name="cyl_400_20.txt", is_public=False + ) + cls.private_test_data.file.save( + "cyl_400_20.txt", open(find("cyl_400_20.txt"), "rb") + ) + cls.client_authenticated = APIClient() + cls.client_authenticated.force_authenticate(user=cls.user) + + # Test list public data + def test_does_list_public(self): + request = self.client_authenticated.get("/v1/data/file/") + self.assertEqual( + request.data, + {"public_data_ids": {1: "cyl_400_40.txt", 3: "cyl_400_20.txt"}}, + ) + + # Test list a user's private data + def test_does_list_user(self): + request = self.client_authenticated.get( + "/v1/data/file/", data={"username": "testUser"}, user=self.user + ) + self.assertEqual(request.data, {"user_data_ids": {3: "cyl_400_20.txt"}}) + + # Test list another user's public data + def test_list_other_user(self): + client_unauthenticated = APIClient() + request = client_unauthenticated.get( + "/v1/data/file/", data={"username": "testUser"}, user=self.user + ) + self.assertEqual(request.data, {"user_data_ids": {}}) + + # Test list a nonexistent user's data + def test_list_nonexistent_user(self): + request = self.client_authenticated.get( + "/v1/data/file/", data={"username": "fakeUser"} + ) + self.assertEqual(request.status_code, status.HTTP_404_NOT_FOUND) + + # Test loading a public data file + def test_does_load_data_info_public(self): + request = self.client_authenticated.get("/v1/data/file/1/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + + # Test loading private data with authorization + 
def test_does_load_data_info_private(self): + request = self.client_authenticated.get("/v1/data/file/3/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + + # Test loading data that does not exist + def test_load_data_info_nonexistent(self): + request = self.client_authenticated.get("/v1/data/file/5/") + self.assertEqual(request.status_code, status.HTTP_404_NOT_FOUND) + + @classmethod + def tearDownClass(cls): + cls.public_test_data.delete() + cls.private_test_data.delete() + cls.user.delete() + shutil.rmtree(settings.MEDIA_ROOT) + + +class TestingDatabase(APITestCase): + """Test non-get methods for DataFile.""" + + @classmethod + def setUpTestData(cls): + cls.user = User.objects.create_user( + username="testUser", password="secret", id=1 + ) + cls.data = DataFile.objects.create( + id=1, current_user=cls.user, file_name="cyl_400_20.txt", is_public=False + ) + cls.data.file.save("cyl_400_20.txt", open(find("cyl_400_20.txt"), "rb")) + cls.client_authenticated = APIClient() + cls.client_authenticated.force_authenticate(user=cls.user) + cls.client_unauthenticated = APIClient() + + # Test data upload creates data in database + def test_is_data_being_created(self): + file = open(find("cyl_400_40.txt"), "rb") + data = {"is_public": False, "file": file} + request = self.client_authenticated.post("/v1/data/file/", data=data) + max_id = DataFile.objects.aggregate(Max("id"))["id__max"] + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + request.data, + { + "current_user": "testUser", + "authenticated": True, + "file_id": max_id, + "file_alternative_name": "cyl_400_40.txt", + "is_public": False, + }, + ) + DataFile.objects.get(id=max_id).delete() + + # Test data upload w/out authenticated user + def test_is_data_being_created_no_user(self): + file = open(find("cyl_testdata.txt"), "rb") + data = {"is_public": True, "file": file} + request = self.client_unauthenticated.post("/v1/data/file/", data=data) + max_id = 
DataFile.objects.aggregate(Max("id"))["id__max"] + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + request.data, + { + "current_user": "", + "authenticated": False, + "file_id": max_id, + "file_alternative_name": "cyl_testdata.txt", + "is_public": True, + }, + ) + DataFile.objects.get(id=max_id).delete() + + # Test whether a user can overwrite data by specifying an in-use id + def test_no_data_overwrite(self): + file = open(find("apoferritin.txt")) + data = {"is_public": True, "file": file, id: 1} + request = self.client_authenticated.post("/v1/data/file/", data=data) + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual(DataFile.objects.get(id=1).file_name, "cyl_400_20.txt") + max_id = DataFile.objects.aggregate(Max("id"))["id__max"] + self.assertEqual( + request.data, + { + "current_user": "testUser", + "authenticated": True, + "file_id": max_id, + "file_alternative_name": "apoferritin.txt", + "is_public": True, + }, + ) + DataFile.objects.get(id=max_id).delete() + + # Test updating file + def test_does_file_upload_update(self): + file = open(find("cyl_testdata1.txt")) + data = {"file": file, "is_public": False} + request = self.client_authenticated.put("/v1/data/file/1/", data=data) + self.assertEqual( + request.data, + { + "current_user": "testUser", + "authenticated": True, + "file_id": 1, + "file_alternative_name": "cyl_testdata1.txt", + "is_public": False, + }, + ) + self.data.file.save("cyl_400_20.txt", open(find("cyl_400_20.txt"), "rb")) + self.data.file_name = "cyl_400_20.txt" + + # Test updating a public file + def test_public_file_upload_update(self): + data_object = DataFile.objects.create( + id=3, current_user=self.user, file_name="cyl_testdata2.txt", is_public=True + ) + data_object.file.save( + "cyl_testdata2.txt", open(find("cyl_testdata2.txt"), "rb") + ) + file = open(find("conalbumin.txt")) + data = {"file": file, "is_public": True} + request = 
self.client_authenticated.put("/v1/data/file/3/", data=data) + self.assertEqual( + request.data, + { + "current_user": "testUser", + "authenticated": True, + "file_id": 3, + "file_alternative_name": "conalbumin.txt", + "is_public": True, + }, + ) + data_object.delete() + + # Test file upload update fails when unauthorized + def test_unauthorized_file_upload_update(self): + file = open(find("cyl_400_40.txt")) + data = {"file": file, "is_public": False} + request = self.client_unauthenticated.put("/v1/data/file/1/", data=data) + self.assertEqual(request.status_code, status.HTTP_401_UNAUTHORIZED) + + # Test update nonexistent file fails + def test_file_upload_update_not_found(self): + file = open(find("cyl_400_40.txt")) + data = {"file": file, "is_public": False} + request = self.client_unauthenticated.put("/v1/data/file/5/", data=data) + self.assertEqual(request.status_code, status.HTTP_404_NOT_FOUND) + + # Test file download + def test_does_download(self): + request = self.client_authenticated.get( + "/v1/data/file/1/", data={"download": True} + ) + file_contents = b"".join(request.streaming_content) + test_file = open(find("cyl_400_20.txt"), "rb") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(file_contents, test_file.read()) + + # Test file download fails when unauthorized + def test_unauthorized_download(self): + request2 = self.client_unauthenticated.get( + "/v1/data/file/1/", data={"download": True} + ) + self.assertEqual(request2.status_code, status.HTTP_401_UNAUTHORIZED) + + # Test download nonexistent file + def test_download_nonexistent(self): + request = self.client_authenticated.get( + "/v1/data/file/5/", data={"download": True} + ) + self.assertEqual(request.status_code, status.HTTP_404_NOT_FOUND) + + # Test deleting a file + def test_delete(self): + DataFile.objects.create( + id=6, current_user=self.user, file_name="test.txt", is_public=False + ) + request = self.client_authenticated.delete("/v1/data/file/6/") + 
self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertFalse(DataFile.objects.filter(pk=6).exists()) + + # Test deleting a file fails when unauthorized + def test_delete_unauthorized(self): + request = self.client_unauthenticated.delete("/v1/data/file/1/") + self.assertEqual(request.status_code, status.HTTP_401_UNAUTHORIZED) + + @classmethod + def tearDownClass(cls): + cls.user.delete() + cls.data.delete() + shutil.rmtree(settings.MEDIA_ROOT) + + +class TestAccessManagement(TestCase): + """Test viewing and managing access for a file.""" + + @classmethod + def setUpTestData(cls): + cls.user1 = User.objects.create_user(username="testUser", password="secret") + cls.user2 = User.objects.create_user(username="testUser2", password="secret2") + cls.private_test_data = DataFile.objects.create( + id=1, current_user=cls.user1, file_name="cyl_400_40.txt", is_public=False + ) + cls.private_test_data.file.save( + "cyl_400_40.txt", open(find("cyl_400_40.txt"), "rb") + ) + cls.shared_test_data = DataFile.objects.create( + id=2, current_user=cls.user1, file_name="cyl_400_20.txt", is_public=False + ) + cls.shared_test_data.file.save( + "cyl_400_20.txt", open(find("cyl_400_20.txt"), "rb") + ) + cls.shared_test_data.users.add(cls.user2) + cls.client_owner = APIClient() + cls.client_owner.force_authenticate(cls.user1) + cls.client_other = APIClient() + cls.client_other.force_authenticate(cls.user2) + + # test viewing no one with access + def test_view_no_access(self): + request = self.client_owner.get("/v1/data/file/1/users/") + data = { + "file": 1, + "file_name": "cyl_400_40.txt", + "is_public": False, + "users": [], + } + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(request.data, data) + + # test viewing list of users with access + def test_view_access(self): + request = self.client_owner.get("/v1/data/file/2/users/") + data = { + "file": 2, + "file_name": "cyl_400_20.txt", + "is_public": False, + "users": ["testUser2"], + } + 
self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(request.data, data) + + # test granting another user access to private data + def test_grant_access(self): + data = {"username": "testUser2", "access": True} + request1 = self.client_owner.put("/v1/data/file/1/users/", data=data) + request2 = self.client_other.get("/v1/data/file/1/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertEqual( + request1.data, + { + "username": "testUser2", + "file": 1, + "file_name": "cyl_400_40.txt", + "access": True, + }, + ) + + # test removing another user's access to private data + def test_remove_access(self): + data = {"username": "testUser2", "access": False} + request1 = self.client_other.get("/v1/data/file/2/") + request2 = self.client_owner.put("/v1/data/file/2/users/", data=data) + request3 = self.client_other.get("/v1/data/file/2/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertEqual(request3.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual( + request2.data, + { + "username": "testUser2", + "file": 2, + "file_name": "cyl_400_20.txt", + "access": False, + }, + ) + + # test removing access from a user that already lacks access + def test_remove_no_access(self): + data = {"username": "testUser2", "access": False} + request1 = self.client_other.get("/v1/data/file/1/") + request2 = self.client_owner.put("/v1/data/file/1/users/", data=data) + request3 = self.client_other.get("/v1/data/file/1/") + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertEqual(request3.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual( + request2.data, + { + "username": "testUser2", + "file": 1, + "file_name": "cyl_400_40.txt", + "access": False, + }, + ) + + # test owner's access cannot be removed + 
def test_cant_revoke_own_access(self): + data = {"username": "testUser", "access": False} + request1 = self.client_owner.put("/v1/data/file/1/users/", data=data) + request2 = self.client_owner.get("/v1/data/file/1/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertEqual( + request1.data, + { + "username": "testUser", + "file": 1, + "file_name": "cyl_400_40.txt", + "access": True, + }, + ) + + # test giving access to a user that already has access + def test_grant_existing_access(self): + data = {"username": "testUser2", "access": True} + request1 = self.client_other.get("/v1/data/file/2/") + request2 = self.client_owner.put("/v1/data/file/2/users/", data=data) + request3 = self.client_other.get("/v1/data/file/2/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertEqual(request3.status_code, status.HTTP_200_OK) + self.assertEqual( + request2.data, + { + "username": "testUser2", + "file": 2, + "file_name": "cyl_400_20.txt", + "access": True, + }, + ) + + # test that access is read-only for the file + def test_no_edit_access(self): + data = {"is_public": True} + request = self.client_other.put("/v1/data/file/2/", data=data) + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + self.assertFalse(self.shared_test_data.is_public) + + # test that only the owner can view who has access + def test_only_view_access_to_owned_file(self): + request1 = self.client_other.get("/v1/data/file/1/users/") + request2 = self.client_other.get("/v1/data/file/2/users/") + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + + # test that only the owner can change access + def test_only_edit_access_to_owned_file(self): + data1 = {"username": "testUser2", "access": True} + data2 = {"username": "testUser1", "access": False} + 
request1 = self.client_other.put("/v1/data/file/1/users/", data=data1) + request2 = self.client_other.put("/v1/data/file/2/users/", data=data2) + request3 = self.client_other.get("/v1/data/file/1/") + request4 = self.client_owner.get("/v1/data/file/2/") + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request3.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request4.status_code, status.HTTP_200_OK) + + @classmethod + def tearDownClass(cls): + cls.user1.delete() + cls.user2.delete() + cls.private_test_data.delete() + cls.shared_test_data.delete() + shutil.rmtree(settings.MEDIA_ROOT) diff --git a/sasdata/fair_database/data/test/test_dataset.py b/sasdata/fair_database/data/test/test_dataset.py new file mode 100644 index 000000000..aecfb7247 --- /dev/null +++ b/sasdata/fair_database/data/test/test_dataset.py @@ -0,0 +1,736 @@ +import os +import shutil + +from data.models import DataFile, DataSet, MetaData, OperationTree, Quantity +from django.conf import settings +from django.contrib.auth.models import User +from django.db.models import Max +from rest_framework import status +from rest_framework.test import APIClient, APITestCase + + +# path to a file in example_data/1d_data +def find(filename): + return os.path.join( + os.path.dirname(__file__), "../../../example_data/1d_data", filename + ) + + +class TestDataSet(APITestCase): + """Test HTTP methods of DataSetView.""" + + @classmethod + def setUpTestData(cls): + cls.empty_metadata = { + "title": "New Metadata", + "run": ["X"], + "description": "test", + "instrument": {}, + "process": {}, + "sample": {}, + } + cls.empty_data = [ + { + "value": 0, + "variance": 0, + "units": "no", + "hash": 0, + "label": "test", + "history": {"operation_tree": {}, "references": []}, + } + ] + cls.user1 = User.objects.create_user( + id=1, username="testUser1", password="secret" + ) + cls.user2 = User.objects.create_user( + 
id=2, username="testUser2", password="secret" + ) + cls.user3 = User.objects.create_user( + id=3, username="testUser3", password="secret" + ) + cls.public_dataset = DataSet.objects.create( + id=1, + current_user=cls.user1, + is_public=True, + name="Dataset 1", + ) + cls.private_dataset = DataSet.objects.create( + id=2, current_user=cls.user1, name="Dataset 2" + ) + cls.unowned_dataset = DataSet.objects.create( + id=3, is_public=True, name="Dataset 3" + ) + cls.private_dataset.users.add(cls.user3) + cls.auth_client1 = APIClient() + cls.auth_client2 = APIClient() + cls.auth_client3 = APIClient() + cls.auth_client1.force_authenticate(cls.user1) + cls.auth_client2.force_authenticate(cls.user2) + cls.auth_client3.force_authenticate(cls.user3) + + # Test a user can list their own private data + def test_list_private(self): + request = self.auth_client1.get("/v1/data/set/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + {"dataset_ids": {1: "Dataset 1", 2: "Dataset 2", 3: "Dataset 3"}}, + ) + + # Test a user can see others' public but not private data in list + def test_list_public(self): + request = self.auth_client2.get("/v1/data/set/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, {"dataset_ids": {1: "Dataset 1", 3: "Dataset 3"}} + ) + + # Test a user can see private data they have been granted access to + def test_list_granted_access(self): + request = self.auth_client3.get("/v1/data/set/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + {"dataset_ids": {1: "Dataset 1", 2: "Dataset 2", 3: "Dataset 3"}}, + ) + + # Test an unauthenticated user can list public data + def test_list_unauthenticated(self): + request = self.client.get("/v1/data/set/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, {"dataset_ids": {1: "Dataset 1", 3: "Dataset 3"}} + ) + + # Test a user can see all 
data listed by their username + def test_list_username(self): + request = self.auth_client1.get("/v1/data/set/", data={"username": "testUser1"}) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, {"dataset_ids": {1: "Dataset 1", 2: "Dataset 2"}} + ) + + # Test a user can list public data by another user's username + def test_list_username_2(self): + request = self.auth_client1.get("/v1/data/set/", {"username": "testUser2"}) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(request.data, {"dataset_ids": {}}) + + # Test an unauthenticated user can list public data by a username + def test_list_username_unauthenticated(self): + request = self.client.get("/v1/data/set/", {"username": "testUser1"}) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(request.data, {"dataset_ids": {1: "Dataset 1"}}) + + # Test listing by a username that doesn't exist + def test_list_wrong_username(self): + request = self.auth_client1.get("/v1/data/set/", {"username": "fakeUser1"}) + self.assertEqual(request.status_code, status.HTTP_404_NOT_FOUND) + + # TODO: test listing by other parameters if functionality is added for that + + # Test creating a dataset with associated metadata + def test_dataset_created(self): + dataset = { + "name": "New Dataset", + "metadata": self.empty_metadata, + "data_contents": self.empty_data, + } + request = self.auth_client1.post("/v1/data/set/", data=dataset, format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + new_metadata = new_dataset.metadata + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + request.data, + { + "dataset_id": max_id, + "name": "New Dataset", + "authenticated": True, + "current_user": "testUser1", + "is_public": False, + }, + ) + self.assertEqual(new_dataset.name, "New Dataset") + self.assertEqual(new_metadata.title, "New Metadata") + 
self.assertEqual(new_dataset.current_user.username, "testUser1") + new_dataset.delete() + new_metadata.delete() + + # Test creating a dataset while unauthenticated + def test_dataset_created_unauthenticated(self): + dataset = { + "name": "New Dataset", + "metadata": self.empty_metadata, + "is_public": True, + "data_contents": self.empty_data, + } + request = self.client.post("/v1/data/set/", data=dataset, format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + new_metadata = new_dataset.metadata + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + request.data, + { + "dataset_id": max_id, + "name": "New Dataset", + "authenticated": False, + "current_user": "", + "is_public": True, + }, + ) + self.assertEqual(new_dataset.name, "New Dataset") + self.assertIsNone(new_dataset.current_user) + new_dataset.delete() + new_metadata.delete() + + # Test creating a database with associated files + def test_dataset_created_with_files(self): + file = DataFile.objects.create( + id=1, file_name="cyl_testdata.txt", is_public=True + ) + file.file.save("cyl_testdata.txt", open(find("cyl_testdata.txt"))) + dataset = { + "name": "Dataset with file", + "metadata": self.empty_metadata, + "is_public": True, + "data_contents": self.empty_data, + "files": [1], + } + request = self.client.post("/v1/data/set/", data=dataset, format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + request.data, + { + "dataset_id": max_id, + "name": "Dataset with file", + "authenticated": False, + "current_user": "", + "is_public": True, + }, + ) + self.assertTrue(file in new_dataset.files.all()) + new_dataset.delete() + file.delete() + + # Test that a dataset cannot be associated with inaccessible files + def test_no_dataset_with_private_files(self): + file = 
DataFile.objects.create( + id=1, file_name="cyl_testdata.txt", is_public=False, current_user=self.user2 + ) + file.file.save("cyl_testdata.txt", open(find("cyl_testdata.txt"))) + dataset = { + "name": "Dataset with file", + "metadata": self.empty_metadata, + "is_public": True, + "data_contents": self.empty_data, + "files": [1], + } + request = self.client.post("/v1/data/set/", data=dataset, format="json") + file.delete() + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + + # Test that a dataset cannot be associated with nonexistent files + def test_no_dataset_with_nonexistent_files(self): + dataset = { + "name": "Dataset with file", + "metadata": self.empty_metadata, + "is_public": True, + "data_contents": self.empty_data, + "files": [2], + } + request = self.client.post("/v1/data/set/", data=dataset, format="json") + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + + # Test that a dataset cannot be created without metadata + def test_metadata_required(self): + dataset = { + "name": "No metadata", + "is_public": True, + "data_contents": self.empty_data, + } + request = self.auth_client1.post("/v1/data/set/", data=dataset, format="json") + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + + # Test that a private dataset cannot be created without an owner + def test_no_private_unowned_dataset(self): + dataset = { + "name": "Disallowed Dataset", + "metadata": self.empty_metadata, + "is_public": False, + "data_contents": self.empty_data, + } + request = self.client.post("/v1/data/set/", data=dataset, format="json") + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + + # Test whether a user can overwrite data by specifying an in-use id + def test_no_data_overwrite(self): + dataset = { + "id": 2, + "name": "Overwrite Dataset", + "metadata": self.empty_metadata, + "is_public": True, + "data_contents": self.empty_data, + } + request = self.auth_client2.post("/v1/data/set/", data=dataset, 
format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual(DataSet.objects.get(id=2).name, "Dataset 2") + self.assertEqual( + request.data, + { + "dataset_id": max_id, + "name": "Overwrite Dataset", + "authenticated": True, + "current_user": "testUser2", + "is_public": True, + }, + ) + DataSet.objects.get(id=max_id).delete() + + @classmethod + def tearDownClass(cls): + cls.public_dataset.delete() + cls.private_dataset.delete() + cls.unowned_dataset.delete() + cls.user1.delete() + cls.user2.delete() + cls.user3.delete() + shutil.rmtree(settings.MEDIA_ROOT) + + +class TestSingleDataSet(APITestCase): + """Tests for HTTP methods of SingleDataSetView.""" + + @classmethod + def setUpTestData(cls): + cls.user1 = User.objects.create_user( + id=1, username="testUser1", password="secret" + ) + cls.user2 = User.objects.create_user( + id=2, username="testUser2", password="secret" + ) + cls.user3 = User.objects.create_user( + id=3, username="testUser3", password="secret" + ) + cls.public_dataset = DataSet.objects.create( + id=1, + current_user=cls.user1, + is_public=True, + name="Dataset 1", + ) + cls.private_dataset = DataSet.objects.create( + id=2, current_user=cls.user1, name="Dataset 2" + ) + cls.unowned_dataset = DataSet.objects.create( + id=3, is_public=True, name="Dataset 3" + ) + cls.metadata = MetaData.objects.create( + id=1, + title="Metadata", + run=0, + definition="test", + instrument="none", + process="none", + sample="none", + dataset=cls.public_dataset, + ) + cls.file = DataFile.objects.create( + id=1, file_name="cyl_testdata.txt", is_public=False, current_user=cls.user1 + ) + cls.file.file.save("cyl_testdata.txt", open(find("cyl_testdata.txt"))) + cls.private_dataset.users.add(cls.user3) + cls.public_dataset.files.add(cls.file) + cls.auth_client1 = APIClient() + cls.auth_client2 = APIClient() + cls.auth_client3 = APIClient() + 
cls.auth_client1.force_authenticate(cls.user1) + cls.auth_client2.force_authenticate(cls.user2) + cls.auth_client3.force_authenticate(cls.user3) + + # TODO: change load return data + # Test successfully accessing a private dataset + def test_load_private_dataset(self): + request1 = self.auth_client1.get("/v1/data/set/2/") + request2 = self.auth_client3.get("/v1/data/set/2/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertEqual( + request1.data, + { + "id": 2, + "current_user": "testUser1", + "users": [3], + "is_public": False, + "name": "Dataset 2", + "files": [], + "metadata": None, + "data_contents": [], + "session": None, + }, + ) + + # Test successfully accessing a public dataset + def test_load_public_dataset(self): + request1 = self.client.get("/v1/data/set/1/") + request2 = self.auth_client2.get("/v1/data/set/1/") + request3 = self.auth_client1.get("/v1/data/set/1/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertEqual(request3.status_code, status.HTTP_200_OK) + self.assertDictEqual( + request1.data, + { + "id": 1, + "current_user": "testUser1", + "users": [], + "is_public": True, + "name": "Dataset 1", + "files": [], + "metadata": { + "id": 1, + "title": "Metadata", + "run": 0, + "definition": "test", + "instrument": "none", + "process": "none", + "sample": "none", + }, + "data_contents": [], + "session": None, + }, + ) + self.assertEqual(request1.data, request2.data) + self.assertEqual( + request3.data, + { + "id": 1, + "current_user": "testUser1", + "users": [], + "is_public": True, + "name": "Dataset 1", + "files": [1], + "metadata": { + "id": 1, + "title": "Metadata", + "run": 0, + "definition": "test", + "instrument": "none", + "process": "none", + "sample": "none", + }, + "data_contents": [], + "session": None, + }, + ) + + # Test successfully accessing an unowned public dataset 
+ def test_load_unowned_dataset(self): + request1 = self.auth_client1.get("/v1/data/set/3/") + request2 = self.client.get("/v1/data/set/3/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertDictEqual( + request1.data, + { + "id": 3, + "current_user": None, + "users": [], + "is_public": True, + "name": "Dataset 3", + "files": [], + "metadata": None, + "data_contents": [], + "session": None, + }, + ) + + # Test unsuccessfully accessing a private dataset + def test_load_private_dataset_unauthorized(self): + request1 = self.auth_client2.get("/v1/data/set/2/") + request2 = self.client.get("/v1/data/set/2/") + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_401_UNAUTHORIZED) + + # Test only owner can change a private dataset + def test_update_private_dataset(self): + request1 = self.auth_client1.put("/v1/data/set/2/", data={"is_public": True}) + request2 = self.auth_client3.put("/v1/data/set/2/", data={"is_public": False}) + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual( + request1.data, {"data_id": 2, "name": "Dataset 2", "is_public": True} + ) + self.assertTrue(DataSet.objects.get(id=2).is_public) + self.private_dataset.save() + self.assertFalse(DataSet.objects.get(id=2).is_public) + + # Test changing a public dataset + def test_update_public_dataset(self): + request1 = self.auth_client1.put( + "/v1/data/set/1/", data={"name": "Different name"} + ) + request2 = self.auth_client2.put("/v1/data/set/1/", data={"is_public": False}) + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual( + request1.data, {"data_id": 1, "name": "Different name", "is_public": True} + ) + self.assertEqual(DataSet.objects.get(id=1).name, "Different 
name") + self.public_dataset.save() + + # TODO: test invalid updates if and when those are figured out + + # Test changing an unowned dataset + def test_update_unowned_dataset(self): + request1 = self.auth_client1.put("/v1/data/set/3/", data={"current_user": 1}) + request2 = self.client.put("/v1/data/set/3/", data={"name": "Different name"}) + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_401_UNAUTHORIZED) + + # Test updating metadata + def test_update_dataset_metadata(self): + new_metadata = { + "title": "Updated Metadata", + "run": ["X"], + "definition": "update test", + "instrument": "none", + "process": "none", + "sample": "none", + } + request = self.auth_client1.put( + "/v1/data/set/1/", data={"metadata": new_metadata}, format="json" + ) + dataset = DataSet.objects.get(id=1) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(dataset.metadata.title, "Updated Metadata") + self.assertEqual(dataset.metadata.id, 1) + self.assertEqual(len(MetaData.objects.all()), 1) + dataset.metadata.delete() + self.metadata = MetaData.objects.create( + id=1, + title="Metadata", + run=0, + definition="test", + instrument="none", + process="none", + sample="none", + dataset=self.public_dataset, + ) + + # Test partially updating metadata + def test_update_dataset_partial_metadata(self): + request = self.auth_client1.put( + "/v1/data/set/1/", + data={"metadata": {"title": "Different Title"}}, + format="json", + ) + dataset = DataSet.objects.get(id=1) + metadata = dataset.metadata + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(metadata.title, "Different Title") + self.assertEqual(metadata.definition, "test") + self.assertEqual(metadata.id, 1) + metadata.title = "Metadata" + metadata.save() + + # Test updating a dataset's files + def test_update_dataset_files(self): + request = self.auth_client1.put("/v1/data/set/2/", data={"files": [1]}) + 
self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(len(DataSet.objects.get(id=2).files.all()), 1) + self.private_dataset.files.remove(self.file) + + # Test replacing a dataset's files + def test_update_dataset_replace_files(self): + file = DataFile.objects.create( + id=2, file_name="cyl_testdata1.txt", is_public=True, current_user=self.user1 + ) + file.file.save("cyl_testdata1.txt", open(find("cyl_testdata1.txt"))) + request = self.auth_client1.put("/v1/data/set/1/", data={"files": [2]}) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(len(DataSet.objects.get(id=1).files.all()), 1) + self.assertTrue(file in DataSet.objects.get(id=1).files.all()) + self.public_dataset.files.add(self.file) + self.public_dataset.files.remove(file) + + # Test updating a dataset to have no files + def test_update_dataset_clear_files(self): + request = self.auth_client1.put("/v1/data/set/1/", data={"files": [""]}) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(len(DataSet.objects.get(id=1).files.all()), 0) + self.public_dataset.files.add(self.file) + + # Test that a dataset cannot be updated to be private and unowned + def test_update_dataset_no_private_unowned(self): + request1 = self.auth_client1.put("/v1/data/set/2/", data={"current_user": ""}) + request2 = self.auth_client1.put( + "/v1/data/set/1/", data={"current_user": "", "is_public": False} + ) + public_dataset = DataSet.objects.get(id=1) + self.assertEqual(request1.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(request2.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(DataSet.objects.get(id=2).current_user, self.user1) + self.assertEqual(public_dataset.current_user, self.user1) + self.assertTrue(public_dataset.is_public) + + # Test deleting a dataset + def test_delete_dataset(self): + quantity = Quantity.objects.create( + id=1, + value=0, + variance=0, + units="none", + hash=0, + label="test", + 
dataset=self.private_dataset, + ) + neg = OperationTree.objects.create(id=1, operation="neg", quantity=quantity) + OperationTree.objects.create( + id=2, operation="zero", parameters={}, child_operation=neg + ) + request = self.auth_client1.delete("/v1/data/set/2/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(request.data, {"success": True}) + self.assertRaises(DataSet.DoesNotExist, DataSet.objects.get, id=2) + self.assertRaises(Quantity.DoesNotExist, Quantity.objects.get, id=1) + self.assertRaises(OperationTree.DoesNotExist, OperationTree.objects.get, id=1) + self.assertRaises(OperationTree.DoesNotExist, OperationTree.objects.get, id=2) + self.private_dataset = DataSet.objects.create( + id=2, current_user=self.user1, name="Dataset 2" + ) + + # Test cannot delete a public dataset + def test_delete_public_dataset(self): + request = self.auth_client1.delete("/v1/data/set/1/") + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + + # Test cannot delete an unowned dataset + def test_delete_unowned_dataset(self): + request = self.auth_client1.delete("/v1/data/set/3/") + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + + # Test cannot delete another user's dataset + def test_delete_dataset_unauthorized(self): + request1 = self.auth_client2.delete("/v1/data/set/1/") + request2 = self.auth_client3.delete("/v1/data/set/2/") + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + + @classmethod + def tearDownClass(cls): + cls.public_dataset.delete() + cls.private_dataset.delete() + cls.unowned_dataset.delete() + cls.user1.delete() + cls.user2.delete() + cls.user3.delete() + cls.file.delete() + shutil.rmtree(settings.MEDIA_ROOT) + + +class TestDataSetAccessManagement(APITestCase): + """Tests for HTTP methods of DataSetUsersView.""" + + @classmethod + def setUpTestData(cls): + cls.user1 = 
User.objects.create_user(username="testUser1", password="secret") + cls.user2 = User.objects.create_user(username="testUser2", password="secret") + cls.private_dataset = DataSet.objects.create( + id=1, current_user=cls.user1, name="Dataset 1" + ) + cls.shared_dataset = DataSet.objects.create( + id=2, current_user=cls.user1, name="Dataset 2" + ) + cls.shared_dataset.users.add(cls.user2) + cls.client_owner = APIClient() + cls.client_other = APIClient() + cls.client_owner.force_authenticate(cls.user1) + cls.client_other.force_authenticate(cls.user2) + + # Test listing no users with access + def test_list_access_private(self): + request1 = self.client_owner.get("/v1/data/set/1/users/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual( + request1.data, + {"data_id": 1, "name": "Dataset 1", "is_public": False, "users": []}, + ) + + # Test listing users with access + def test_list_access_shared(self): + request1 = self.client_owner.get("/v1/data/set/2/users/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual( + request1.data, + { + "data_id": 2, + "name": "Dataset 2", + "is_public": False, + "users": ["testUser2"], + }, + ) + + # Test only owner can view access + def test_list_access_unauthorized(self): + request = self.client_other.get("/v1/data/set/2/users/") + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + + # Test granting access to a dataset + def test_grant_access(self): + request1 = self.client_owner.put( + "/v1/data/set/1/users/", data={"username": "testUser2", "access": True} + ) + request2 = self.client_other.get("/v1/data/set/1/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertIn( # codespell:ignore + self.user2, DataSet.objects.get(id=1).users.all() + ) + self.assertEqual( + request1.data, + { + "username": "testUser2", + "data_id": 1, + "name": "Dataset 1", + "access": True, + }, + ) + 
self.private_dataset.users.remove(self.user2) + + # Test revoking access to a dataset + def test_revoke_access(self): + request1 = self.client_owner.put( + "/v1/data/set/2/users/", data={"username": "testUser2", "access": False} + ) + request2 = self.client_other.get("/v1/data/set/2/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + self.assertNotIn(self.user2, DataSet.objects.get(id=2).users.all()) + self.assertEqual( + request1.data, + { + "username": "testUser2", + "data_id": 2, + "name": "Dataset 2", + "access": False, + }, + ) + self.shared_dataset.users.add(self.user2) + + # Test only the owner can change access + def test_revoke_access_unauthorized(self): + request1 = self.client_other.put( + "/v1/data/set/2/users/", data={"username": "testUser2", "access": False} + ) + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + + @classmethod + def tearDownClass(cls): + cls.private_dataset.delete() + cls.shared_dataset.delete() + cls.user1.delete() + cls.user2.delete() diff --git a/sasdata/fair_database/data/test/test_operation_tree.py b/sasdata/fair_database/data/test/test_operation_tree.py new file mode 100644 index 000000000..90a26d81b --- /dev/null +++ b/sasdata/fair_database/data/test/test_operation_tree.py @@ -0,0 +1,798 @@ +from data.models import DataSet, MetaData, OperationTree, Quantity, ReferenceQuantity +from django.contrib.auth.models import User +from django.db.models import Max +from rest_framework import status +from rest_framework.test import APIClient, APITestCase + + +class TestCreateOperationTree(APITestCase): + """Tests for creating datasets with operation trees.""" + + @classmethod + def setUpTestData(cls): + cls.dataset = { + "name": "Test Dataset", + "metadata": { + "title": "test metadata", + "run": 1, + "definition": "test", + "instrument": {"source": {}, "collimation": {}, "detectors": {}}, + }, + "data_contents": [ + { + "label": "test", + 
"value": {"array_contents": [0, 0, 0, 0], "shape": (2, 2)}, + "variance": {"array_contents": [0, 0, 0, 0], "shape": (2, 2)}, + "units": "none", + "hash": 0, + } + ], + "is_public": True, + } + cls.user = User.objects.create_user( + id=1, username="testUser", password="sasview!" + ) + cls.client = APIClient() + cls.client.force_authenticate(cls.user) + + @staticmethod + def get_operation_tree(quantity): + return quantity.operation_tree + + # Test creating quantity with no operations performed (variable-only history) + def test_operation_tree_created_variable(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "variable", + "parameters": {"hash_value": 0, "name": "test"}, + }, + "references": [ + { + "label": "test", + "value": {"array_contents": [0, 0, 0, 0], "shape": (2, 2)}, + "variance": {"array_contents": [0, 0, 0, 0], "shape": (2, 2)}, + "units": "none", + "hash": 0, + "history": {}, + } + ], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + new_quantity = new_dataset.data_contents.get(hash=0) + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertRaises( + Quantity.operation_tree.RelatedObjectDoesNotExist, + self.get_operation_tree, + quantity=new_quantity, + ) + self.assertEqual(len(new_quantity.references.all()), 0) + + # Test creating quantity with unary operation + def test_operation_tree_created_unary(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "reciprocal", + "parameters": { + "a": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + } + }, + }, + "references": [ + {"value": 5, "variance": 0, "units": "none", "hash": 111, "history": {}} + ], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + max_id = 
DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + new_quantity = new_dataset.data_contents.get(hash=0) + reciprocal = new_quantity.operation_tree + variable = reciprocal.parent_operations.all().get(label="a") + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + new_quantity.value, {"array_contents": [0, 0, 0, 0], "shape": [2, 2]} + ) + self.assertEqual(reciprocal.operation, "reciprocal") + self.assertEqual(variable.operation, "variable") + self.assertEqual(len(reciprocal.parent_operations.all()), 1) + self.assertEqual(reciprocal.parameters, {}) + self.assertEqual(len(ReferenceQuantity.objects.all()), 1) + self.assertEqual(len(new_quantity.references.all()), 1) + self.assertEqual(new_quantity.references.get(hash=111).value, 5) + + # Test creating quantity with binary operation + def test_operation_tree_created_binary(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "add", + "parameters": { + "a": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + }, + "b": {"operation": "constant", "parameters": {"value": 5}}, + }, + }, + "references": [ + {"value": 5, "variance": 0, "units": "none", "hash": 111, "history": {}} + ], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + new_quantity = new_dataset.data_contents.get(hash=0) + add = new_quantity.operation_tree + variable = add.parent_operations.get(label="a") + constant = add.parent_operations.get(label="b") + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual(add.operation, "add") + self.assertEqual(add.parameters, {}) + self.assertEqual(variable.operation, "variable") + self.assertEqual(variable.parameters, {"hash_value": 111, "name": "x"}) + self.assertEqual(constant.operation, "constant") + 
self.assertEqual(constant.parameters, {"value": 5}) + self.assertEqual(len(add.parent_operations.all()), 2) + self.assertEqual(len(new_quantity.references.all()), 1) + self.assertEqual(new_quantity.references.get(hash=111).value, 5) + + # Test creating quantity with exponent + def test_operation_tree_created_pow(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "pow", + "parameters": { + "a": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + }, + "power": 2, + }, + }, + "references": [ + {"value": 5, "variance": 0, "units": "none", "hash": 111, "history": {}} + ], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + new_quantity = new_dataset.data_contents.get(hash=0) + pow = new_quantity.operation_tree + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual(pow.operation, "pow") + self.assertEqual(pow.parameters, {"power": 2}) + self.assertEqual(len(new_quantity.references.all()), 1) + self.assertEqual(new_quantity.references.get(hash=111).value, 5) + + # Test creating a transposed quantity + def test_operation_tree_created_transpose(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "transpose", + "parameters": { + "a": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + }, + "axes": [1, 0], + }, + }, + "references": [ + {"value": 5, "variance": 0, "units": "none", "hash": 111, "history": {}} + ], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + new_quantity = new_dataset.data_contents.get(hash=0) + transpose = new_quantity.operation_tree + variable = transpose.parent_operations.get() + 
self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual(transpose.operation, "transpose") + self.assertEqual(transpose.parameters, {"axes": [1, 0]}) + self.assertEqual(variable.operation, "variable") + self.assertEqual(variable.parameters, {"hash_value": 111, "name": "x"}) + self.assertEqual(len(new_quantity.references.all()), 1) + self.assertEqual(new_quantity.references.get(hash=111).value, 5) + + # Test creating a quantity with multiple operations + def test_operation_tree_created_nested(self): + self.dataset["data_contents"][0]["history"] = { + "operation_tree": { + "operation": "neg", + "parameters": { + "a": { + "operation": "mul", + "parameters": { + "a": { + "operation": "constant", + "parameters": {"value": {"type": "int", "value": 7}}, + }, + "b": { + "operation": "variable", + "parameters": {"hash_value": 111, "name": "x"}, + }, + }, + }, + }, + }, + "references": [ + {"value": 5, "variance": 0, "units": "none", "hash": 111, "history": {}} + ], + } + request = self.client.post("/v1/data/set/", data=self.dataset, format="json") + max_id = DataSet.objects.aggregate(Max("id"))["id__max"] + new_dataset = DataSet.objects.get(id=max_id) + new_quantity = new_dataset.data_contents.get(hash=0) + negate = new_quantity.operation_tree + multiply = negate.parent_operations.get() + constant = multiply.parent_operations.get(label="a") + variable = multiply.parent_operations.get(label="b") + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual(negate.operation, "neg") + self.assertEqual(negate.parameters, {}) + self.assertEqual(multiply.operation, "mul") + self.assertEqual(multiply.parameters, {}) + self.assertEqual(constant.operation, "constant") + self.assertEqual(constant.parameters, {"value": {"type": "int", "value": 7}}) + self.assertEqual(variable.operation, "variable") + self.assertEqual(variable.parameters, {"hash_value": 111, "name": "x"}) + self.assertEqual(len(new_quantity.references.all()), 1) + 
# Test creating a quantity with no history
def test_operation_tree_created_no_history(self):
    """POST a dataset whose quantity carries no ``history`` entry.

    The created quantity must have no operation tree and no references.
    """
    # Remove any history from the payload; the default makes this a no-op
    # when the key is absent, so the request below is always issued.
    self.dataset["data_contents"][0].pop("history", None)
    request = self.client.post("/v1/data/set/", data=self.dataset, format="json")
    # Check the response first so a failed creation reports its status code
    # rather than a DoesNotExist from the lookups below.
    self.assertEqual(request.status_code, status.HTTP_201_CREATED)
    max_id = DataSet.objects.aggregate(Max("id"))["id__max"]
    new_dataset = DataSet.objects.get(id=max_id)
    new_quantity = new_dataset.data_contents.get(hash=0)
    self.assertIsNone(new_quantity.operation_tree)
    self.assertEqual(len(new_quantity.references.all()), 0)
# Test creating a quantity with an invalid operation
def test_create_operation_tree_invalid(self):
    """An unknown operation name is rejected and nothing is stored."""
    bad_history = {
        "operation_tree": {"operation": "fix", "parameters": {}},
        "references": [],
    }
    self.dataset["data_contents"][0]["history"] = bad_history
    response = self.client.post("/v1/data/set/", data=self.dataset, format="json")
    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    # None of the models involved may have gained a row.
    for model in (DataSet, Quantity, OperationTree):
        self.assertEqual(len(model.objects.all()), 0)
# Test creating a unary operation with a missing parameter fails
def test_create_missing_parameter_unary(self):
    """A ``neg`` operation without its operand ``a`` is rejected."""
    self.dataset["data_contents"][0]["history"] = {
        "operation_tree": {"operation": "neg", "parameters": {}},
        # A list, as in every sibling test: the only invalid part of this
        # payload should be the missing operand.
        "references": [],
    }
    request = self.client.post("/v1/data/set/", data=self.dataset, format="json")
    self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST)
    self.assertEqual(len(DataSet.objects.all()), 0)
    self.assertEqual(len(Quantity.objects.all()), 0)
    self.assertEqual(len(OperationTree.objects.all()), 0)
# Test creating a constant with a missing parameter fails
def test_create_missing_parameter_constant(self):
    """A ``constant`` operand with empty parameters is rejected."""
    constant_without_value = {"operation": "constant", "parameters": {}}
    self.dataset["data_contents"][0]["history"] = {
        "operation_tree": {
            "operation": "neg",
            "parameters": {"a": constant_without_value},
        },
        "references": [],
    }
    response = self.client.post("/v1/data/set/", data=self.dataset, format="json")
    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    # The failed request must not leave any rows behind.
    for model in (DataSet, Quantity, OperationTree):
        self.assertEqual(len(model.objects.all()), 0)
@classmethod
def tearDownClass(cls):
    """Remove the class-level user, then run the framework's own teardown."""
    cls.user.delete()
    # Django requires overrides of tearDownClass to call the parent
    # implementation so its class-level cleanup still runs.
    super().tearDownClass()
# Test accessing a quantity with no operations performed
def test_get_operation_tree_none(self):
    """A quantity with no operation tree and no references serializes an empty history."""
    # Temporarily remove the reference quantity so the history is empty.
    self.ref_quantity.delete()
    response = self.client.get("/v1/data/set/1/")
    # Recreate it with the same field values as the fixture.
    restored_fields = {
        "id": 1,
        "value": 5,
        "variance": 0,
        "units": "none",
        "hash": 111,
        "derived_quantity": self.quantity,
    }
    self.ref_quantity = ReferenceQuantity.objects.create(**restored_fields)
    expected_quantity = {
        "label": "test",
        "value": 0,
        "variance": 0,
        "units": "none",
        "hash": 1,
        "history": {"operation_tree": None, "references": []},
    }
    self.assertEqual(response.status_code, status.HTTP_200_OK)
    self.assertEqual(response.data["data_contents"][0], expected_quantity)
# Test accessing a quantity with exponent
def test_get_operation_tree_pow(self):
    """A ``pow`` tree serializes its operand ``a`` alongside the ``power`` parameter."""
    exponent_op = OperationTree.objects.create(
        id=3,
        operation="pow",
        parameters={"power": 2},
        quantity=self.quantity,
    )
    operand = self.variable
    # Attach the fixture variable as operand "a" for the request only.
    operand.label = "a"
    operand.child_operation = exponent_op
    operand.save()
    response = self.client.get("/v1/data/set/1/")
    # Detach the operand and remove the temporary root before asserting.
    operand.child_operation = None
    operand.save()
    exponent_op.delete()
    expected_tree = {
        "operation": "pow",
        "parameters": {
            "a": {
                "operation": "variable",
                "parameters": {"hash_value": 111, "name": "x"},
            },
            "power": 2,
        },
    }
    self.assertEqual(response.status_code, status.HTTP_200_OK)
    self.assertEqual(
        response.data["data_contents"][0]["history"]["operation_tree"],
        expected_tree,
    )
# Test accessing a transposed quantity
def test_get_operation_tree_transpose(self):
    """A ``transpose`` tree serializes its operand ``a`` alongside ``axes``."""
    transpose_op = OperationTree.objects.create(
        id=3,
        operation="transpose",
        parameters={"axes": (1, 0)},
        quantity=self.quantity,
    )
    operand = self.variable
    # Attach the fixture variable as operand "a" for the request only.
    operand.label = "a"
    operand.child_operation = transpose_op
    operand.save()
    response = self.client.get("/v1/data/set/1/")
    # Detach the operand and remove the temporary root before asserting.
    operand.child_operation = None
    operand.save()
    transpose_op.delete()
    expected_tree = {
        "operation": "transpose",
        "parameters": {
            "a": {
                "operation": "variable",
                "parameters": {"hash_value": 111, "name": "x"},
            },
            "axes": [1, 0],
        },
    }
    self.assertEqual(response.status_code, status.HTTP_200_OK)
    self.assertEqual(
        response.data["data_contents"][0]["history"]["operation_tree"],
        expected_tree,
    )
# TODO: account for non-placeholder doi
def doi_generator(id: int):
    """Return the placeholder DOI for a session: the local session URL for ``id``."""
    return f"http://127.0.0.1:8000/v1/data/session/{id!s}/"
# Test listing published states including those of owned private sessions
def test_list_published_states_private(self):
    """user1 sees the public, private and unowned published states in the listing."""
    response = self.auth_client1.get("/v1/data/published/")
    expected_states = {
        1: {"title": "Public Session", "published": True, "doi": doi_generator(1)},
        2: {"title": "Private Session", "published": False, "doi": doi_generator(2)},
        3: {"title": "Unowned Session", "published": True, "doi": doi_generator(3)},
    }
    self.assertEqual(response.status_code, status.HTTP_200_OK)
    self.assertEqual(response.data, {"published_state_ids": expected_states})
# Test listing published states while unauthenticated
def test_list_published_states_unauthenticated(self):
    """An anonymous client is listed only the public and unowned published states."""
    response = self.client.get("/v1/data/published/")
    expected_states = {
        1: {"title": "Public Session", "published": True, "doi": doi_generator(1)},
        3: {"title": "Unowned Session", "published": True, "doi": doi_generator(3)},
    }
    self.assertEqual(response.status_code, status.HTTP_200_OK)
    self.assertEqual(response.data, {"published_state_ids": expected_states})
# Test listing a user's published states while unauthenticated
def test_list_user_published_states_unauthenticated(self):
    """Filtering by username anonymously returns only that user's public session."""
    query = {"username": "testUser1"}
    response = self.client.get("/v1/data/published/", data=query)
    expected_states = {
        1: {"title": "Public Session", "published": True, "doi": doi_generator(1)},
    }
    self.assertEqual(response.status_code, status.HTTP_200_OK)
    self.assertEqual(response.data, {"published_state_ids": expected_states})
# Test that you can't create a published state for an unowned session
def test_published_state_created_unowned(self):
    """Publishing a session that has no owner is forbidden and creates nothing."""
    session = self.unpublished_session
    # Strip ownership for the duration of the request.
    session.current_user = None
    session.save()
    payload = {"published": True, "session": 4}
    response = self.auth_client1.post("/v1/data/published/", data=payload)
    self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
    # Only the three fixture published states may exist.
    self.assertEqual(len(PublishedState.objects.all()), 3)
    # Hand the session back to user1.
    session.current_user = self.user1
    session.save()
@classmethod
def tearDownClass(cls):
    """Remove the class-level fixtures, then run the framework's own teardown."""
    cls.public_session.delete()
    cls.private_session.delete()
    cls.unowned_session.delete()
    # The fourth session created in setUpTestData was previously left out.
    cls.unpublished_session.delete()
    cls.user1.delete()
    cls.user2.delete()
    # Django requires overrides of tearDownClass to call the parent
    # implementation so its class-level cleanup still runs.
    super().tearDownClass()
# Test viewing a published state of an unowned session
def test_get_unowned_published_state(self):
    """The published state of an ownerless public session reports an empty owner name."""
    response = self.auth_client1.get("/v1/data/published/3/")
    expected_state = {
        "id": 3,
        "doi": doi_generator(3),
        "published": True,
        "session": 3,
        "title": "Unowned Session",
        "current_user": "",
        "is_public": True,
    }
    self.assertEqual(response.status_code, status.HTTP_200_OK)
    self.assertEqual(response.data, expected_state)
# Test a user can't view a published state of a private session they don't own
def test_get_private_published_state_unauthorized(self):
    """A private session's published state is 401 for anonymous and 403 for other users."""
    url = "/v1/data/published/2/"
    anonymous_response = self.client.get(url)
    other_user_response = self.auth_client2.get(url)
    self.assertEqual(anonymous_response.status_code, status.HTTP_401_UNAUTHORIZED)
    self.assertEqual(other_user_response.status_code, status.HTTP_403_FORBIDDEN)
status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertTrue(PublishedState.objects.get(id=3).published) + + # Test a user can't update a public published state unauthorized + def test_update_public_published_state_unauthorized(self): + request1 = self.auth_client2.put( + "/v1/data/published/1/", data={"published": False} + ) + self.public_session.users.add(self.user2) + request2 = self.auth_client2.put( + "/v1/data/published/1/", data={"published": False} + ) + self.public_session.users.remove(self.user2) + request3 = self.client.put("/v1/data/published/1/", data={"published": False}) + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request3.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertTrue(PublishedState.objects.get(id=1).published) + + # Test a user can't update a private published state unauthorized + def test_update_private_published_state_unauthorized(self): + request1 = self.auth_client2.put( + "/v1/data/published/2/", data={"published": True} + ) + self.private_session.users.add(self.user2) + request2 = self.auth_client2.put( + "/v1/data/published/2/", data={"published": True} + ) + self.private_session.users.remove(self.user2) + request3 = self.client.put("/v1/data/published/2/", data={"published": True}) + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request3.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertFalse(PublishedState.objects.get(id=2).published) + + # Test deleting a published state of a private session + def test_delete_private_published_state(self): + request = self.auth_client1.delete("/v1/data/published/2/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(len(PublishedState.objects.all()), 2) + 
self.assertEqual(len(Session.objects.all()), 3) + self.assertRaises(PublishedState.DoesNotExist, PublishedState.objects.get, id=2) + self.private_ps = PublishedState.objects.create( + id=2, + doi=doi_generator(2), + published=False, + session=self.private_session, + ) + + # Test a user can't delete a private published state unauthorized + def test_delete_private_published_state_unauthorized(self): + request1 = self.auth_client2.delete("/v1/data/published/2/") + self.private_session.users.add(self.user2) + request2 = self.auth_client2.delete("/v1/data/published/2/") + self.private_session.users.remove(self.user2) + request3 = self.client.delete("/v1/data/published/2/") + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request3.status_code, status.HTTP_401_UNAUTHORIZED) + + # Test a user can't delete a published state of a public session + def test_cant_delete_public_published_state(self): + request = self.auth_client1.delete("/v1/data/published/1/") + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + + # Test a user can't delete an unowned published state + def test_delete_unowned_published_state(self): + request = self.auth_client1.delete("/v1/data/published/3/") + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + + @classmethod + def tearDownClass(cls): + cls.public_session.delete() + cls.private_session.delete() + cls.unowned_session.delete() + cls.user1.delete() + cls.user2.delete() diff --git a/sasdata/fair_database/data/test/test_session.py b/sasdata/fair_database/data/test/test_session.py new file mode 100644 index 000000000..fc185f8fd --- /dev/null +++ b/sasdata/fair_database/data/test/test_session.py @@ -0,0 +1,700 @@ +from data.models import DataSet, PublishedState, Session +from django.contrib.auth.models import User +from django.db.models import Max +from rest_framework import status +from rest_framework.test import 
APIClient, APITestCase + + +class TestSession(APITestCase): + """Test HTTP methods of SessionView.""" + + @classmethod + def setUpTestData(cls): + cls.user1 = User.objects.create_user( + id=1, username="testUser1", password="secret" + ) + cls.user2 = User.objects.create_user( + id=2, username="testUser2", password="secret" + ) + cls.public_session = Session.objects.create( + id=1, current_user=cls.user1, title="Public Session", is_public=True + ) + cls.private_session = Session.objects.create( + id=2, current_user=cls.user1, title="Private Session", is_public=False + ) + cls.unowned_session = Session.objects.create( + id=3, title="Unowned Session", is_public=True + ) + cls.auth_client1 = APIClient() + cls.auth_client2 = APIClient() + cls.auth_client1.force_authenticate(cls.user1) + cls.auth_client2.force_authenticate(cls.user2) + + # Test listing sessions + def test_list_private(self): + request = self.auth_client1.get("/v1/data/session/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "session_ids": { + 1: "Public Session", + 2: "Private Session", + 3: "Unowned Session", + } + }, + ) + + # Test listing public sessions + def test_list_public(self): + request = self.auth_client2.get("/v1/data/session/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, {"session_ids": {1: "Public Session", 3: "Unowned Session"}} + ) + + # Test listing sessions while unauthenticated + def test_list_unauthenticated(self): + request = self.client.get("/v1/data/session/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, {"session_ids": {1: "Public Session", 3: "Unowned Session"}} + ) + + # Test listing a session with access granted + def test_list_granted_access(self): + self.private_session.users.add(self.user2) + request = self.auth_client2.get("/v1/data/session/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + 
self.assertEqual( + request.data, + { + "session_ids": { + 1: "Public Session", + 2: "Private Session", + 3: "Unowned Session", + } + }, + ) + self.private_session.users.remove(self.user2) + + # Test listing by username + def test_list_username(self): + request = self.auth_client1.get( + "/v1/data/session/", data={"username": "testUser1"} + ) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, {"session_ids": {1: "Public Session", 2: "Private Session"}} + ) + + # Test listing by another user's username + def test_list_other_username(self): + request = self.auth_client2.get( + "/v1/data/session/", data={"username": "testUser1"} + ) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual(request.data, {"session_ids": {1: "Public Session"}}) + + # Test creating a public session + def test_session_created(self): + session = { + "title": "New session", + "datasets": [ + { + "name": "New dataset", + "metadata": { + "title": "New metadata", + "run": 0, + "description": "test", + "instrument": {}, + "process": {}, + "sample": {}, + }, + "data_contents": [], + } + ], + "is_public": True, + "published_state": {"published": False}, + } + request = self.auth_client1.post( + "/v1/data/session/", data=session, format="json" + ) + max_id = Session.objects.aggregate(Max("id"))["id__max"] + new_session = Session.objects.get(id=max_id) + new_dataset = new_session.datasets.get() + new_metadata = new_dataset.metadata + new_published_state = new_session.published_state + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + request.data, + { + "session_id": max_id, + "title": "New session", + "authenticated": True, + "current_user": "testUser1", + "is_public": True, + }, + ) + self.assertEqual(new_session.title, "New session") + self.assertEqual(new_dataset.name, "New dataset") + self.assertEqual(new_metadata.title, "New metadata") + self.assertEqual(new_session.current_user, self.user1) + 
self.assertEqual(new_dataset.current_user, self.user1) + self.assertTrue(all([new_session.is_public, new_dataset.is_public])) + self.assertFalse(new_published_state.published) + new_session.delete() + + # Test creating a private session + def test_session_created_private(self): + session = { + "title": "New session", + "datasets": [ + { + "name": "New dataset", + "metadata": { + "title": "New metadata", + "run": 0, + "description": "test", + "instrument": {}, + "process": {}, + "sample": {}, + }, + "data_contents": [], + } + ], + "is_public": False, + } + request = self.auth_client1.post( + "/v1/data/session/", data=session, format="json" + ) + max_id = Session.objects.aggregate(Max("id"))["id__max"] + new_session = Session.objects.get(id=max_id) + new_dataset = new_session.datasets.get() + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + request.data, + { + "session_id": max_id, + "title": "New session", + "authenticated": True, + "current_user": "testUser1", + "is_public": False, + }, + ) + self.assertEqual(new_session.current_user, self.user1) + self.assertEqual(new_dataset.current_user, self.user1) + self.assertFalse(any([new_session.is_public, new_dataset.is_public])) + new_session.delete() + + # Test creating a session while unauthenticated + def test_session_created_unauthenticated(self): + session = { + "title": "New session", + "datasets": [ + { + "name": "New dataset", + "metadata": { + "title": "New metadata", + "run": 0, + "description": "test", + "instrument": {}, + "process": {}, + "sample": {}, + }, + "data_contents": [], + } + ], + "is_public": True, + } + request = self.client.post("/v1/data/session/", data=session, format="json") + max_id = Session.objects.aggregate(Max("id"))["id__max"] + new_session = Session.objects.get(id=max_id) + new_dataset = new_session.datasets.get() + self.assertEqual(request.status_code, status.HTTP_201_CREATED) + self.assertEqual( + request.data, + { + "session_id": max_id, + 
"title": "New session", + "authenticated": False, + "current_user": "", + "is_public": True, + }, + ) + self.assertIsNone(new_session.current_user) + self.assertIsNone(new_dataset.current_user) + self.assertTrue(all([new_session.is_public, new_dataset.is_public])) + new_session.delete() + + # Test that a private session must have an owner + def test_no_private_unowned_session(self): + session = {"title": "New session", "datasets": [], "is_public": False} + request = self.client.post("/v1/data/session/", data=session, format="json") + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + + # Test post fails with dataset validation issue + def test_no_session_invalid_dataset(self): + session = { + "title": "New session", + "datasets": [ + { + "metadata": { + "title": "New metadata", + "run": 0, + "description": "test", + "instrument": {}, + "process": {}, + "sample": {}, + }, + "data_contents": [], + } + ], + "is_public": True, + } + request = self.auth_client1.post( + "/v1/data/session/", data=session, format="json" + ) + self.assertEqual(request.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(len(Session.objects.all()), 3) + self.assertEqual(len(DataSet.objects.all()), 0) + + @classmethod + def tearDownClass(cls): + cls.public_session.delete() + cls.private_session.delete() + cls.unowned_session.delete() + cls.user1.delete() + cls.user2.delete() + + +class TestSingleSession(APITestCase): + """Test HTTP methods of SingleSessionView.""" + + @classmethod + def setUpTestData(cls): + cls.user1 = User.objects.create_user( + id=1, username="testUser1", password="secret" + ) + cls.user2 = User.objects.create_user( + id=2, username="testUser2", password="secret" + ) + cls.public_session = Session.objects.create( + id=1, current_user=cls.user1, title="Public Session", is_public=True + ) + cls.private_session = Session.objects.create( + id=2, current_user=cls.user1, title="Private Session", is_public=False + ) + cls.unowned_session = 
Session.objects.create( + id=3, title="Unowned Session", is_public=True + ) + cls.public_dataset = DataSet.objects.create( + id=1, + current_user=cls.user1, + is_public=True, + name="Public Dataset", + session=cls.public_session, + ) + cls.private_dataset = DataSet.objects.create( + id=2, + current_user=cls.user1, + name="Private Dataset", + session=cls.private_session, + ) + cls.unowned_dataset = DataSet.objects.create( + id=3, is_public=True, name="Unowned Dataset", session=cls.unowned_session + ) + cls.private_published_state = PublishedState.objects.create( + id=2, + session=cls.private_session, + published=False, + doi="http://localhost:8000/v1/data/session/2/", + ) + cls.auth_client1 = APIClient() + cls.auth_client2 = APIClient() + cls.auth_client1.force_authenticate(cls.user1) + cls.auth_client2.force_authenticate(cls.user2) + + # Test loading another user's public session + def test_get_public_session(self): + request = self.auth_client2.get("/v1/data/session/1/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "id": 1, + "current_user": "testUser1", + "users": [], + "is_public": True, + "title": "Public Session", + "datasets": [ + { + "id": 1, + "current_user": 1, + "users": [], + "is_public": True, + "name": "Public Dataset", + "files": [], + "metadata": None, + "data_contents": [], + } + ], + "published_state": None, + }, + ) + + # Test loading a private session as the owner + def test_get_private_session(self): + request = self.auth_client1.get("/v1/data/session/2/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "id": 2, + "current_user": "testUser1", + "users": [], + "is_public": False, + "title": "Private Session", + "published_state": { + "id": 2, + "published": False, + "doi": "http://localhost:8000/v1/data/session/2/", + "session": 2, + }, + "datasets": [ + { + "id": 2, + "current_user": 1, + "users": [], + "is_public": False, + "name": 
"Private Dataset", + "files": [], + "metadata": None, + "data_contents": [], + } + ], + }, + ) + + # Test loading a private session as a user with granted access + def test_get_private_session_access_granted(self): + self.private_session.users.add(self.user2) + request = self.auth_client2.get("/v1/data/session/2/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.private_session.users.remove(self.user2) + + # Test loading an unowned session + def test_get_unowned_session(self): + request = self.auth_client1.get("/v1/data/session/3/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "id": 3, + "current_user": None, + "users": [], + "is_public": True, + "title": "Unowned Session", + "published_state": None, + "datasets": [ + { + "id": 3, + "current_user": None, + "users": [], + "is_public": True, + "name": "Unowned Dataset", + "files": [], + "metadata": None, + "data_contents": [], + } + ], + }, + ) + + # Test loading another user's private session + def test_get_private_session_unauthorized(self): + request1 = self.auth_client2.get("/v1/data/session/2/") + request2 = self.client.get("/v1/data/session/2/") + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_401_UNAUTHORIZED) + + # Test updating a public session + def test_update_public_session(self): + request = self.auth_client1.put( + "/v1/data/session/1/", data={"is_public": False} + ) + session = Session.objects.get(id=1) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + {"session_id": 1, "title": "Public Session", "is_public": False}, + ) + self.assertFalse(session.is_public) + session.is_public = True + session.save() + + # Test creating a published state by updating a session + def test_update_session_new_published_state(self): + request = self.auth_client1.put( + "/v1/data/session/1/", + data={"published_state": 
{"published": False}}, + format="json", + ) + new_published_state = Session.objects.get(id=1).published_state + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertFalse(new_published_state.published) + new_published_state.delete() + + # Test that another user's public session cannot be updated + def test_update_public_session_unauthorized(self): + request1 = self.auth_client2.put( + "/v1/data/session/1/", data={"is_public": False} + ) + request2 = self.client.put("/v1/data/session/1/", data={"is_public": False}) + session = Session.objects.get(id=1) + session.users.add(self.user2) + request3 = self.auth_client2.put( + "/v1/data/session/1/", data={"is_public": False} + ) + session.users.remove(self.user2) + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(request3.status_code, status.HTTP_403_FORBIDDEN) + self.assertTrue(Session.objects.get(id=1).is_public) + + # Test updating a private session + def test_update_private_session(self): + request1 = self.auth_client1.put( + "/v1/data/session/2/", data={"is_public": True} + ) + session = Session.objects.get(id=2) + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual( + request1.data, + {"session_id": 2, "title": "Private Session", "is_public": True}, + ) + self.assertTrue(session.is_public) + self.assertTrue(session.datasets.get().is_public) + session.is_public = False + session.save() + + # Test updating a published state through its session + def test_update_session_published_state(self): + request = self.auth_client1.put( + "/v1/data/session/2/", + data={"published_state": {"published": True}}, + format="json", + ) + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertTrue(PublishedState.objects.get(id=2).published) + self.private_published_state.save() + + # Test that another user's private session cannot be updated + def 
test_update_private_session_unauthorized(self): + request1 = self.auth_client2.put( + "/v1/data/session/2/", data={"is_public": True} + ) + request2 = self.client.put("/v1/data/session/2/", data={"is_public": True}) + session = Session.objects.get(id=2) + session.users.add(self.user2) + request3 = self.auth_client2.put( + "/v1/data/session/2/", data={"is_public": True} + ) + session.users.remove(self.user2) + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(request3.status_code, status.HTTP_403_FORBIDDEN) + self.assertFalse(Session.objects.get(id=2).is_public) + + # Test that an unowned session cannot be updated + def test_update_unowned_session(self): + request = self.auth_client1.put( + "/v1/data/session/3/", data={"is_public": False} + ) + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + self.assertTrue(Session.objects.get(id=3).is_public) + + # Test deleting a private session + def test_delete_private_session(self): + request = self.auth_client1.delete("/v1/data/session/2/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertRaises(Session.DoesNotExist, Session.objects.get, id=2) + self.assertRaises(DataSet.DoesNotExist, DataSet.objects.get, id=2) + self.assertRaises(PublishedState.DoesNotExist, PublishedState.objects.get, id=2) + self.private_session = Session.objects.create( + id=2, current_user=self.user1, title="Private Session", is_public=False + ) + self.private_dataset = DataSet.objects.create( + id=2, + current_user=self.user1, + name="Private Dataset", + session=self.private_session, + ) + self.private_published_state = PublishedState.objects.create( + id=2, + session=self.private_session, + published=False, + doi="http://localhost:8000/v1/data/session/2/", + ) + + # Test that another user's private session cannot be deleted + def test_delete_private_session_unauthorized(self): + request1 = 
self.auth_client2.delete("/v1/data/session/2/") + request2 = self.client.delete("/v1/data/session/2/") + self.private_session.users.add(self.user2) + request3 = self.auth_client2.delete("/v1/data/session/2/") + self.private_session.users.remove(self.user2) + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(request3.status_code, status.HTTP_403_FORBIDDEN) + + # Test that a public session cannot be deleted + def test_delete_public_session(self): + request = self.auth_client1.delete("/v1/data/session/1/") + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + + # Test that an unowned session cannot be deleted + def test_delete_unowned_session(self): + request = self.auth_client1.delete("/v1/data/session/3/") + self.assertEqual(request.status_code, status.HTTP_403_FORBIDDEN) + + @classmethod + def tearDownClass(cls): + cls.public_session.delete() + cls.private_session.delete() + cls.unowned_session.delete() + cls.user1.delete() + cls.user2.delete() + + +class TestSessionAccessManagement(APITestCase): + """Test HTTP methods of SessionUsersView.""" + + @classmethod + def setUpTestData(cls): + cls.user1 = User.objects.create_user(username="testUser1", password="secret") + cls.user2 = User.objects.create_user(username="testUser2", password="secret") + cls.private_session = Session.objects.create( + id=1, current_user=cls.user1, title="Private Session", is_public=False + ) + cls.shared_session = Session.objects.create( + id=2, current_user=cls.user1, title="Shared Session", is_public=False + ) + cls.private_dataset = DataSet.objects.create( + id=1, + current_user=cls.user1, + name="Private Dataset", + session=cls.private_session, + ) + cls.shared_dataset = DataSet.objects.create( + id=2, + current_user=cls.user1, + name="Shared Dataset", + session=cls.shared_session, + ) + cls.shared_session.users.add(cls.user2) + cls.shared_dataset.users.add(cls.user2) + 
cls.client_owner = APIClient() + cls.client_other = APIClient() + cls.client_owner.force_authenticate(cls.user1) + cls.client_other.force_authenticate(cls.user2) + + # Test listing access to an unshared session + def test_list_access_private(self): + request = self.client_owner.get("/v1/data/session/1/users/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "session_id": 1, + "title": "Private Session", + "is_public": False, + "users": [], + }, + ) + + # Test listing access to a shared session + def test_list_access_shared(self): + request = self.client_owner.get("/v1/data/session/2/users/") + self.assertEqual(request.status_code, status.HTTP_200_OK) + self.assertEqual( + request.data, + { + "session_id": 2, + "title": "Shared Session", + "is_public": False, + "users": ["testUser2"], + }, + ) + + # Test that only the owner can view access + def test_list_access_unauthorized(self): + request1 = self.client_other.get("/v1/data/session/1/users/") + request2 = self.client_other.get("/v1/data/session/2/users/") + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + + # Test granting access to a session + def test_grant_access(self): + request1 = self.client_owner.put( + "/v1/data/session/1/users/", {"username": "testUser2", "access": True} + ) + request2 = self.client_other.get("/v1/data/session/1/") + request3 = self.client_other.get("/v1/data/set/1/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual( + request1.data, + { + "username": "testUser2", + "session_id": 1, + "title": "Private Session", + "access": True, + }, + ) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertEqual(request3.status_code, status.HTTP_200_OK) + self.assertIn(self.user2, self.private_session.users.all()) # codespell:ignore + self.assertIn(self.user2, self.private_dataset.users.all()) # codespell:ignore + 
self.private_session.users.remove(self.user2) + self.private_dataset.users.remove(self.user2) + + # Test revoking access to a session + def test_revoke_access(self): + request1 = self.client_owner.put( + "/v1/data/session/2/users/", {"username": "testUser2", "access": False} + ) + request2 = self.client_other.get("/v1/data/session/2/") + request3 = self.client_other.get("/v1/data/set/2/") + self.assertEqual(request1.status_code, status.HTTP_200_OK) + self.assertEqual(request2.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request3.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual( + request1.data, + { + "username": "testUser2", + "session_id": 2, + "title": "Shared Session", + "access": False, + }, + ) + self.assertNotIn(self.user2, self.shared_session.users.all()) + self.assertNotIn(self.user2, self.shared_dataset.users.all()) + self.shared_session.users.add(self.user2) + self.shared_dataset.users.add(self.user2) + + # Test that only the owner can change access + def test_revoke_access_unauthorized(self): + request1 = self.client_other.put( + "/v1/data/session/2/users/", {"username": "testUser2", "access": False} + ) + request2 = self.client_other.get("/v1/data/session/2/") + self.assertEqual(request1.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(request2.status_code, status.HTTP_200_OK) + self.assertIn(self.user2, self.shared_session.users.all()) # codespell:ignore + + @classmethod + def tearDownClass(cls): + cls.private_session.delete() + cls.shared_session.delete() + cls.user1.delete() + cls.user2.delete() diff --git a/sasdata/fair_database/data/urls.py b/sasdata/fair_database/data/urls.py new file mode 100644 index 000000000..0e94f60c7 --- /dev/null +++ b/sasdata/fair_database/data/urls.py @@ -0,0 +1,49 @@ +from django.urls import path + +from . 
import views + +urlpatterns = [ + path("file/", views.DataFileView.as_view(), name="view and create files"), + path( + "file//", + views.SingleDataFileView.as_view(), + name="view, download, modify, delete files", + ), + path( + "file//users/", + views.DataFileUsersView.as_view(), + name="manage access to files", + ), + path("set/", views.DataSetView.as_view(), name="view and create datasets"), + path( + "set//", + views.SingleDataSetView.as_view(), + name="load, modify, delete datasets", + ), + path( + "set//users/", + views.DataSetUsersView.as_view(), + name="manage access to datasets", + ), + path("session/", views.SessionView.as_view(), name="view and create sessions"), + path( + "session//", + views.SingleSessionView.as_view(), + name="load, modify, delete sessions", + ), + path( + "session//users/", + views.SessionUsersView.as_view(), + name="manage access to sessions", + ), + path( + "published/", + views.PublishedStateView.as_view(), + name="view and create published states", + ), + path( + "published//", + views.SinglePublishedStateView.as_view(), + name="load, modify, delete published states", + ), +] diff --git a/sasdata/fair_database/data/views.py b/sasdata/fair_database/data/views.py new file mode 100644 index 000000000..fc1c547c1 --- /dev/null +++ b/sasdata/fair_database/data/views.py @@ -0,0 +1,707 @@ +import json +import os + +from data.forms import DataFileForm +from data.models import DataFile, DataSet, PublishedState, Session +from data.serializers import ( + AccessManagementSerializer, + DataFileSerializer, + DataSetSerializer, + PublishedStateSerializer, + PublishedStateUpdateSerializer, + SessionSerializer, +) +from django.contrib.auth.models import User +from django.http import ( + FileResponse, + Http404, + HttpResponse, + HttpResponseBadRequest, + HttpResponseForbidden, +) +from django.shortcuts import get_object_or_404 +from drf_spectacular.utils import extend_schema +from fair_database import permissions +from fair_database.permissions 
import DataPermission +from rest_framework import status +from rest_framework.response import Response +from rest_framework.views import APIView + +from sasdata.dataloader.loader import Loader + + +class DataFileView(APIView): + """ + View associated with the DataFile model. + + Functionality for viewing a list of files and uploading a new file. + """ + + # List of datafiles + @extend_schema( + description="Retrieve a list of accessible data files by id and filename." + ) + def get(self, request, version=None): + if "username" in request.GET: + search_user = get_object_or_404(User, username=request.GET["username"]) + data_list = {"user_data_ids": {}} + private_data = DataFile.objects.filter(current_user=search_user) + for x in private_data: + if permissions.check_permissions(request, x): + data_list["user_data_ids"][x.id] = x.file_name + else: + public_data = DataFile.objects.all() + data_list = {"public_data_ids": {}} + for x in public_data: + if permissions.check_permissions(request, x): + data_list["public_data_ids"][x.id] = x.file_name + return Response(data_list) + + # Create a datafile + @extend_schema(description="Upload a data file.") + def post(self, request, version=None): + form = DataFileForm(request.data, request.FILES) + if form.is_valid(): + form.save() + db = DataFile.objects.get(pk=form.instance.pk) + serializer = DataFileSerializer( + db, + data={ + "file_name": os.path.basename(form.instance.file.path), + "current_user": None, + "users": [], + }, + context={"is_public": db.is_public}, + ) + if request.user.is_authenticated: + serializer.initial_data["current_user"] = request.user.id + + if serializer.is_valid(raise_exception=True): + serializer.save() + return_data = { + "current_user": request.user.username, + "authenticated": request.user.is_authenticated, + "file_id": db.id, + "file_alternative_name": serializer.data["file_name"], + "is_public": serializer.data["is_public"], + } + return Response(return_data, status=status.HTTP_201_CREATED) + 
+ # Create a datafile + @extend_schema(description="Upload a data file.") + def put(self, request, version=None): + return self.post(request, version) + + +class SingleDataFileView(APIView): + """ + View associated with a single DataFile. + + Functionality for viewing, modifying, or deleting a DataFile. + """ + + # Load the contents of a datafile or download the file to a device + @extend_schema( + description="Retrieve the contents of a data file or download a file." + ) + def get(self, request, data_id, version=None): + data = get_object_or_404(DataFile, id=data_id) + if "download" in request.GET and request.GET["download"]: + if not permissions.check_permissions(request, data): + if not request.user.is_authenticated: + return HttpResponse("Must be authenticated to download", status=401) + return HttpResponseForbidden("data is private") + try: + file = open(data.file.path, "rb") + except Exception as e: + return HttpResponseBadRequest(str(e)) + if file is None: + raise Http404("File not found.") + return FileResponse(file, as_attachment=True) + else: + loader = Loader() + if not permissions.check_permissions(request, data): + if not request.user.is_authenticated: + return HttpResponse("Must be authenticated to view", status=401) + return HttpResponseForbidden( + "Data is either not public or wrong auth token" + ) + data_list = loader.load(data.file.path) + contents = [str(data) for data in data_list] + return_data = {data.file_name: contents} + return Response(return_data) + + # Modify a datafile + @extend_schema(description="Make changes to a data file that you own.") + def put(self, request, data_id, version=None): + db = get_object_or_404(DataFile, id=data_id) + if not permissions.check_permissions(request, db): + if not request.user.is_authenticated: + return HttpResponse("must be authenticated to modify", status=401) + return HttpResponseForbidden("must be the data owner to modify") + form = DataFileForm(request.data, request.FILES, instance=db) + if 
form.is_valid(): + form.save() + serializer = DataFileSerializer( + db, + data={ + "file_name": os.path.basename(form.instance.file.path), + "current_user": request.user.id, + }, + context={"is_public": db.is_public}, + partial=True, + ) + if serializer.is_valid(raise_exception=True): + serializer.save() + return_data = { + "current_user": request.user.username, + "authenticated": request.user.is_authenticated, + "file_id": db.id, + "file_alternative_name": serializer.data["file_name"], + "is_public": serializer.data["is_public"], + } + return Response(return_data) + + # Delete a datafile + @extend_schema(description="Delete a data file that you own.") + def delete(self, request, data_id, version=None): + db = get_object_or_404(DataFile, id=data_id) + if not permissions.is_owner(request, db): + if not request.user.is_authenticated: + return HttpResponse("Must be authenticated to delete", status=401) + return HttpResponseForbidden("Must be the data owner to delete") + db.delete() + return Response(data={"success": True}) + + +class DataFileUsersView(APIView): + """ + View for the users that have access to a datafile. + + Functionality for accessing a list of users with access and granting or + revoking access. + """ + + # View users with access to a datafile + @extend_schema( + description="Retrieve a list of users that have been granted access to" + " a data file and the file's publicity status." 
+ ) + def get(self, request, data_id, version=None): + db = get_object_or_404(DataFile, id=data_id) + if not permissions.is_owner(request, db): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to manage access", status=401 + ) + return HttpResponseForbidden("Must be the data owner to manage access") + response_data = { + "file": db.pk, + "file_name": db.file_name, + "is_public": db.is_public, + "users": [user.username for user in db.users.all()], + } + return Response(response_data) + + # Grant or revoke access to a datafile + @extend_schema(description="Grant or revoke a user's access to a data file.") + def put(self, request, data_id, version=None): + db = get_object_or_404(DataFile, id=data_id) + if not permissions.is_owner(request, db): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to manage access", status=401 + ) + return HttpResponseForbidden("Must be the data owner to manage access") + serializer = AccessManagementSerializer(data=request.data) + serializer.is_valid() + user = get_object_or_404(User, username=serializer.data["username"]) + if serializer.data["access"]: + db.users.add(user) + else: + db.users.remove(user) + response_data = { + "username": user.username, + "file": db.pk, + "file_name": db.file_name, + "access": (serializer.data["access"] or user == db.current_user), + } + return Response(response_data) + + +class DataSetView(APIView): + """ + View associated with the DataSet model. + + Functionality for viewing a list of datasets and creating a dataset. 
+ """ + + permission_classes = [DataPermission] + + # get a list of accessible datasets + @extend_schema(description="Retrieve a list of accessible datasets by id and name.") + def get(self, request, version=None): + data_list = {"dataset_ids": {}} + data = DataSet.objects.all() + if "username" in request.GET: + user = get_object_or_404(User, username=request.GET["username"]) + data = DataSet.objects.filter(current_user=user) + for dataset in data: + if permissions.check_permissions(request, dataset): + data_list["dataset_ids"][dataset.id] = dataset.name + return Response(data=data_list) + + # TODO: enable uploading files as part of dataset creation, not just associating dataset with existing files + # create a dataset + @extend_schema(description="Upload a dataset.") + def post(self, request, version=None): + # TODO: revisit request data format + if isinstance(request.data, str): + serializer = DataSetSerializer( + data=json.loads(request.data), context={"request": request} + ) + else: + serializer = DataSetSerializer( + data=request.data, context={"request": request} + ) + if serializer.is_valid(raise_exception=True): + serializer.save() + db = serializer.instance + response = { + "dataset_id": db.id, + "name": db.name, + "authenticated": request.user.is_authenticated, + "current_user": request.user.username, + "is_public": db.is_public, + } + return Response(data=response, status=status.HTTP_201_CREATED) + + # create a dataset + @extend_schema(description="Upload a dataset.") + def put(self, request, version=None): + return self.post(request, version) + + +class SingleDataSetView(APIView): + """ + View associated with single datasets. + + Functionality for accessing a dataset in a format intended to be loaded + into SasView, modifying a dataset, or deleting a dataset. 
+ """ + + permission_classes = [DataPermission] + + # get a specific dataset + @extend_schema(description="Retrieve a dataset.") + def get(self, request, data_id, version=None): + db = get_object_or_404(DataSet, id=data_id) + if not permissions.check_permissions(request, db): + if not request.user.is_authenticated: + return HttpResponse("Must be authenticated to view dataset", status=401) + return HttpResponseForbidden( + "You do not have permission to view this dataset." + ) + serializer = DataSetSerializer(db, context={"request": request}) + response_data = serializer.data + if db.current_user: + response_data["current_user"] = db.current_user.username + return Response(response_data) + + # edit a specific dataset + @extend_schema(description="Make changes to a dataset that you own.") + def put(self, request, data_id, version=None): + db = get_object_or_404(DataSet, id=data_id) + if not permissions.check_permissions(request, db): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to modify dataset", status=401 + ) + return HttpResponseForbidden("Cannot modify a dataset you do not own") + serializer = DataSetSerializer( + db, request.data, context={"request": request}, partial=True + ) + clear_files = "files" in request.data and not request.data["files"] + if clear_files: + data_copy = request.data.copy() + data_copy.pop("files") + serializer = DataSetSerializer( + db, data_copy, context={"request": request}, partial=True + ) + if serializer.is_valid(raise_exception=True): + serializer.save() + if clear_files: + db.files.clear() + db.save() + data = {"data_id": db.id, "name": db.name, "is_public": db.is_public} + return Response(data) + + # delete a dataset + @extend_schema(description="Delete a dataset that you own.") + def delete(self, request, data_id, version=None): + db = get_object_or_404(DataSet, id=data_id) + if not permissions.check_permissions(request, db): + if not request.user.is_authenticated: + return 
HttpResponse( + "Must be authenticated to delete a dataset", status=401 + ) + return HttpResponseForbidden("Not authorized to delete") + db.delete() + return Response({"success": True}) + + +class DataSetUsersView(APIView): + """ + View for the users that have access to a dataset. + + Functionality for accessing a list of users with access and granting or + revoking access. + """ + + permission_classes = [DataPermission] + + # get a list of users with access to dataset data_id + @extend_schema( + description="Retrieve a list of users that have been granted access to" + " a dataset and the dataset's publicity status." + ) + def get(self, request, data_id, version=None): + db = get_object_or_404(DataSet, id=data_id) + if not permissions.is_owner(request, db): + if not request.user.is_authenticated: + return HttpResponse("Must be authenticated to view access", status=401) + return HttpResponseForbidden("Must be the dataset owner to view access") + response_data = { + "data_id": db.id, + "name": db.name, + "is_public": db.is_public, + "users": [user.username for user in db.users.all()], + } + return Response(response_data) + + # grant or revoke a user's access to dataset data_id + @extend_schema(description="Grant or revoke a user's access to a dataset.") + def put(self, request, data_id, version=None): + db = get_object_or_404(DataSet, id=data_id) + if not permissions.is_owner(request, db): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to manage access", status=401 + ) + return HttpResponseForbidden("Must be the dataset owner to manage access") + serializer = AccessManagementSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + user = get_object_or_404(User, username=serializer.data["username"]) + if serializer.data["access"]: + db.users.add(user) + else: + db.users.remove(user) + response_data = { + "username": user.username, + "data_id": db.id, + "name": db.name, + "access": serializer.data["access"], + 
class SessionView(APIView):
    """
    View associated with the Session model.

    Functionality for viewing a list of sessions and for creating a session.
    """

    # View a list of accessible sessions
    @extend_schema(
        description="Retrieve a list of accessible sessions by name and title."
    )
    def get(self, request, version=None):
        """Map session id to title for every session the requester may see.

        A ``username`` query parameter restricts the candidates to sessions
        whose ``current_user`` is that user (404 if the user does not exist).
        """
        if "username" in request.GET:
            owner = get_object_or_404(User, username=request.GET["username"])
            candidates = Session.objects.filter(current_user=owner)
        else:
            candidates = Session.objects.all()
        visible = {
            entry.id: entry.title
            for entry in candidates
            if permissions.check_permissions(request, entry)
        }
        return Response(data={"session_ids": visible})

    # Create a session
    # TODO: revisit response data
    @extend_schema(description="Upload a session.")
    def post(self, request, version=None):
        """Create a session from the request payload.

        A payload that arrives as a string is decoded as JSON first.
        Validation failures are raised by the serializer; on success a 201
        response summarises the new session.
        """
        payload = request.data
        if isinstance(payload, str):
            payload = json.loads(payload)
        serializer = SessionSerializer(data=payload, context={"request": request})
        serializer.is_valid(raise_exception=True)
        serializer.save()
        created = serializer.instance
        return Response(
            data={
                "session_id": created.id,
                "title": created.title,
                "authenticated": request.user.is_authenticated,
                "current_user": request.user.username,
                "is_public": created.is_public,
            },
            status=status.HTTP_201_CREATED,
        )

    # Create a session
    @extend_schema(description="Upload a session.")
    def put(self, request, version=None):
        """Handle PUT exactly like POST."""
        return self.post(request, version)
+ """ + + # get a specific session + @extend_schema(description="Retrieve a session.") + def get(self, request, data_id, version=None): + db = get_object_or_404(Session, id=data_id) + if not permissions.check_permissions(request, db): + if not request.user.is_authenticated: + return HttpResponse("Must be authenticated to view session", status=401) + return HttpResponseForbidden( + "You do not have permission to view this session." + ) + serializer = SessionSerializer(db) + response_data = serializer.data + if db.current_user: + response_data["current_user"] = db.current_user.username + return Response(response_data) + + # modify a session + @extend_schema(description="Make changes to a session that you own.") + def put(self, request, data_id, version=None): + db = get_object_or_404(Session, id=data_id) + if not permissions.check_permissions(request, db): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to modify session", status=401 + ) + return HttpResponseForbidden("Cannot modify a session you do not own") + serializer = SessionSerializer( + db, request.data, context={"request": request}, partial=True + ) + if serializer.is_valid(raise_exception=True): + serializer.save() + data = {"session_id": db.id, "title": db.title, "is_public": db.is_public} + return Response(data) + + # delete a session + @extend_schema(description="Delete a session that you own.") + def delete(self, request, data_id, version=None): + db = get_object_or_404(Session, id=data_id) + if not permissions.check_permissions(request, db): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to delete a session", status=401 + ) + return HttpResponseForbidden("Not authorized to delete") + db.delete() + return Response({"success": True}) + + +class SessionUsersView(APIView): + """ + View for the users that have access to a session. + + Functionality for accessing a list of users with access and granting or + revoking access. 
+ """ + + # view the users that have access to a specific session + @extend_schema( + description="Retrieve a list of users that have been granted access to" + " a session and the session's publicity status." + ) + def get(self, request, data_id, version=None): + db = get_object_or_404(Session, id=data_id) + if not permissions.is_owner(request, db): + if not request.user.is_authenticated: + return HttpResponse("Must be authenticated to view access", status=401) + return HttpResponseForbidden("Must be the session owner to view access") + response_data = { + "session_id": db.id, + "title": db.title, + "is_public": db.is_public, + "users": [user.username for user in db.users.all()], + } + return Response(response_data) + + # grant or revoke access to a session + @extend_schema(description="Grant or revoke a user's access to a data file.") + def put(self, request, data_id, version=None): + db = get_object_or_404(Session, id=data_id) + if not permissions.is_owner(request, db): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to manage access", status=401 + ) + return HttpResponseForbidden("Must be the dataset owner to manage access") + serializer = AccessManagementSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + user = get_object_or_404(User, username=serializer.data["username"]) + if serializer.data["access"]: + db.users.add(user) + for dataset in db.datasets.all(): + dataset.users.add(user) + else: + db.users.remove(user) + for dataset in db.datasets.all(): + dataset.users.remove(user) + response_data = { + "username": user.username, + "session_id": db.id, + "title": db.title, + "access": serializer.data["access"], + } + return Response(response_data) + + +class PublishedStateView(APIView): + """ + View associated with the PublishedState model. + + Functionality for viewing a list of session published states and for + creating a published state. 
+ """ + + # View a list of accessible sessions' published states + @extend_schema( + description="Retrieve a list of published states of accessible sessions." + ) + def get(self, request, version=None): + ps_list = {"published_state_ids": {}} + published_states = PublishedState.objects.all() + if "username" in request.GET: + user = get_object_or_404(User, username=request.GET["username"]) + published_states = PublishedState.objects.filter(session__current_user=user) + for ps in published_states: + if permissions.check_permissions(request, ps.session): + ps_list["published_state_ids"][ps.id] = { + "title": ps.session.title, + "published": ps.published, + "doi": ps.doi, + } + return Response(data=ps_list) + + # Create a published state for an existing session + @extend_schema(description="Create a published state for an existing session.") + def post(self, request, version=None): + if isinstance(request.data, str): + serializer = PublishedStateSerializer( + data=json.loads(request.data), context={"request": request} + ) + else: + serializer = PublishedStateSerializer( + data=request.data, context={"request": request} + ) + if serializer.is_valid(raise_exception=True): + if not permissions.is_owner(request, serializer.validated_data["session"]): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to create a published state for a session", + status=401, + ) + return HttpResponseForbidden( + "Must be the session owner to create a published state for a session" + ) + serializer.save() + db = serializer.instance + response = { + "published_state_id": db.id, + "session_id": db.session.id, + "title": db.session.title, + "doi": db.doi, + "published": db.published, + "current_user": request.user.username, + "is_public": db.session.is_public, + } + return Response(data=response, status=status.HTTP_201_CREATED) + + # Create a published state for an existing session + @extend_schema(description="Create a published state for an existing 
session.") + def put(self, request, version=None): + return self.post(request, version) + + +class SinglePublishedStateView(APIView): + """ + View associated with specific session published states. + + Functionality for viewing, modifying, and deleting individual published states. + """ + + # View a specific published state + @extend_schema(description="Retrieve a published state.") + def get(self, request, ps_id, version=None): + db = get_object_or_404(PublishedState, id=ps_id) + if not permissions.check_permissions(request, db.session): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to view published state", status=401 + ) + return HttpResponseForbidden( + "You do not have permission to view this published state." + ) + serializer = PublishedStateSerializer(db) + response_data = serializer.data + response_data["title"] = db.session.title + if db.session.current_user: + response_data["current_user"] = db.session.current_user.username + else: + response_data["current_user"] = "" + response_data["is_public"] = db.session.is_public + return Response(response_data) + + # Modify a published state + @extend_schema( + description="Make changes to the published state of a session that you own." 
+ ) + def put(self, request, ps_id, version=None): + db = get_object_or_404(PublishedState, id=ps_id) + if not permissions.check_permissions(request, db.session): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to modify published state", status=401 + ) + return HttpResponseForbidden( + "Cannot modify a published state you do not own" + ) + serializer = PublishedStateUpdateSerializer( + db, request.data, context={"request": request}, partial=True + ) + if serializer.is_valid(raise_exception=True): + serializer.save() + data = { + "published_state_id": db.id, + "session_id": db.session.id, + "title": db.session.title, + "published": db.published, + "is_public": db.session.is_public, + } + return Response(data) + + # Delete a published state + @extend_schema(description="Delete the published state of a session that you own.") + def delete(self, request, ps_id, version=None): + db = get_object_or_404(PublishedState, id=ps_id) + if not permissions.check_permissions(request, db.session): + if not request.user.is_authenticated: + return HttpResponse( + "Must be authenticated to delete a published state", status=401 + ) + return HttpResponseForbidden("Not authorized to delete") + db.delete() + return Response({"success": True}) diff --git a/sasdata/fair_database/documentation.yaml b/sasdata/fair_database/documentation.yaml new file mode 100644 index 000000000..22a0488fb --- /dev/null +++ b/sasdata/fair_database/documentation.yaml @@ -0,0 +1,1172 @@ +openapi: 3.0.3 +info: + title: SasView Database + version: 0.1.0 + description: A database following the FAIR data principles for SasView, a small + angle scattering analysis application. +paths: + /{version}/data/file/: + get: + operationId: data_file_list + description: Retrieve a list of accessible data files by id and filename. 
+ parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/DataFileList" + post: + operationId: data_file_create + description: Upload a data file. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: "#/components/schemas/DataFileCreate" + responses: + '201': + description: CREATED + content: + application/json: + schema: + $ref: "#/components/schemas/DataFileCreated" + put: + operationId: data_file_create_2 + description: Upload a data file. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: "#/components/schemas/DataFileCreate" + responses: + '201': + description: CREATED + content: + application/json: + schema: + $ref: "#/components/schemas/DataFileCreated" + /{version}/data/file/{data_id}/: + get: + operationId: data_file_retrieve + description: Retrieve the contents of a data file or download a file. 
+ parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + requestBlock: + schema: + $ref: "#/components/schemas/DataFileGet" + responses: + '200': + description: OK + content: + application/json: + schema: + oneOf: + - $ref: "#/components/schemas/DataFile" + - $ref: "#/components/schemas/DataFileDownload" + put: + operationId: data_file_update_2 + description: Make changes to a data file that you own. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: "#/components/schemas/DataFileCreate" + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/DataFileCreated" + delete: + operationId: data_file_destroy + description: Delete a data file that you own. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/Delete" + /{version}/data/file/{data_id}/users/: + get: + operationId: data_file_users_retrieve + description: Retrieve a list of users that have been granted access to a data + file and the file's publicity status. 
+ parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + content: + application/json: + schema: + allOf: + - $ref: "#/components/schemas/UsersList" + - type: object + properties: + file: + type: integer + file_name: + type: string + put: + operationId: data_file_users_update + description: Grant or revoke a user's access to a data file. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ManageAccess" + responses: + '200': + description: OK + content: + application/json: + schema: + allOf: + - $ref: "#/components/schemas/ManageAccess" + - type: object + properties: + file: + type: integer + file_name: + type: string + /{version}/data/published/: + get: + operationId: data_published_retrieve + description: Retrieve a list of published states of accessible sessions. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + post: + operationId: data_published_create + description: Create a published state for an existing session. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '201': + description: CREATED + put: + operationId: data_published_update + description: Create a published state for an existing session. 
+ parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '201': + description: CREATED + /{version}/data/published/{ps_id}/: + get: + operationId: data_published_retrieve_2 + description: Retrieve a published state. + parameters: + - in: path + name: ps_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + put: + operationId: data_published_update_2 + description: Make changes to the published state of a session that you own. + parameters: + - in: path + name: ps_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + delete: + operationId: data_published_destroy + description: Delete the published state of a session that you own. + parameters: + - in: path + name: ps_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + /{version}/data/session/: + get: + operationId: data_session_retrieve + description: Retrieve a list of accessible sessions by name and title. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + post: + operationId: data_session_create + description: Upload a session. 
+ parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '201': + description: CREATED + put: + operationId: data_session_update + description: Upload a session. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '201': + description: CREATED + /{version}/data/session/{data_id}/: + get: + operationId: data_session_retrieve_2 + description: Retrieve a session. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + put: + operationId: data_session_update_2 + description: Make changes to a session that you own. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + delete: + operationId: data_session_destroy + description: Delete a session that you own. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + /{version}/data/session/{data_id}/users/: + get: + operationId: data_session_users_retrieve + description: Retrieve a list of users that have been granted access to a session + and the session's publicity status. 
+ parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + put: + operationId: data_session_users_update + description: Grant or revoke a user's access to a data file. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + description: OK + /{version}/data/set/: + get: + operationId: data_set_retrieve + description: Retrieve a list of accessible datasets by id and name. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + description: OK + post: + operationId: data_set_create + description: Upload a dataset. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '201': + description: CREATED + put: + operationId: data_set_update + description: Upload a dataset. + parameters: + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '201': + description: CREATED + /{version}/data/set/{data_id}/: + get: + operationId: data_set_retrieve_2 + description: Retrieve a dataset. 
+ parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + description: OK + put: + operationId: data_set_update_2 + description: Make changes to a dataset that you own. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + description: OK + delete: + operationId: data_set_destroy + description: Delete a dataset that you own. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + description: OK + /{version}/data/set/{data_id}/users/: + get: + operationId: data_set_users_retrieve + description: Retrieve a list of users that have been granted access to a dataset + and the dataset's publicity status. + parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + description: OK + put: + operationId: data_set_users_update + description: Grant or revoke a user's access to a dataset. 
+ parameters: + - in: path + name: data_id + schema: + type: integer + required: true + - in: path + name: version + schema: + type: string + pattern: ^(v1)$ + required: true + tags: + - data + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + description: OK + /auth/login/: + post: + operationId: auth_login_create + description: |- + Check the credentials and return the REST Token + if the credentials are valid and authenticated. + Calls Django Auth login method to register User ID + in Django session framework + + Accept the following POST parameters: username, password + Return the REST Framework Token Object's key. + tags: + - auth + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/Login' + application/x-www-form-urlencoded: + schema: + $ref: '#/components/schemas/Login' + multipart/form-data: + schema: + $ref: '#/components/schemas/Login' + required: true + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/Login' + description: '' + /auth/logout/: + post: + operationId: auth_logout_create + description: |- + Calls Django logout method and delete the Token object + assigned to the current User object. + + Accepts/Returns nothing. + tags: + - auth + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/RestAuthDetail' + description: '' + /auth/password/change/: + post: + operationId: auth_password_change_create + description: |- + Calls Django Auth SetPasswordForm save method. + + Accepts the following POST parameters: new_password1, new_password2 + Returns the success/fail message. 
+ tags: + - auth + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/PasswordChange' + application/x-www-form-urlencoded: + schema: + $ref: '#/components/schemas/PasswordChange' + multipart/form-data: + schema: + $ref: '#/components/schemas/PasswordChange' + required: true + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/RestAuthDetail' + description: '' + /auth/register/: + post: + operationId: auth_register_create + description: |- + Registers a new user. + + Accepts the following POST parameters: username, email, password1, password2. + tags: + - auth + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/Register' + application/x-www-form-urlencoded: + schema: + $ref: '#/components/schemas/Register' + multipart/form-data: + schema: + $ref: '#/components/schemas/Register' + required: true + security: + - knoxApiToken: [] + - cookieAuth: [] + - {} + responses: + '201': + content: + application/json: + schema: + $ref: '#/components/schemas/Register' + description: '' + /auth/user/: + get: + operationId: auth_user_retrieve + description: |- + Reads and updates UserModel fields + Accepts GET, PUT, PATCH methods. + + Default accepted fields: username, first_name, last_name + Default display fields: pk, username, email, first_name, last_name + Read-only fields: pk, email + + Returns UserModel fields. + tags: + - auth + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/UserDetails' + description: '' + put: + operationId: auth_user_update + description: |- + Reads and updates UserModel fields + Accepts GET, PUT, PATCH methods. + + Default accepted fields: username, first_name, last_name + Default display fields: pk, username, email, first_name, last_name + Read-only fields: pk, email + + Returns UserModel fields. 
+ tags: + - auth + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/UserDetails' + application/x-www-form-urlencoded: + schema: + $ref: '#/components/schemas/UserDetails' + multipart/form-data: + schema: + $ref: '#/components/schemas/UserDetails' + required: true + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/UserDetails' + description: '' + patch: + operationId: auth_user_partial_update + description: |- + Reads and updates UserModel fields + Accepts GET, PUT, PATCH methods. + + Default accepted fields: username, first_name, last_name + Default display fields: pk, username, email, first_name, last_name + Read-only fields: pk, email + + Returns UserModel fields. + tags: + - auth + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/PatchedUserDetails' + application/x-www-form-urlencoded: + schema: + $ref: '#/components/schemas/PatchedUserDetails' + multipart/form-data: + schema: + $ref: '#/components/schemas/PatchedUserDetails' + security: + - knoxApiToken: [] + - cookieAuth: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/UserDetails' + description: '' +components: + schemas: + Delete: + type: object + properties: + success: + type: boolean + UsersList: + type: object + properties: + is_public: + type: boolean + users: + type: array + items: + type: string + ManageAccess: + type: object + properties: + username: + type: string + access: + type: boolean + DataFileList: + type: object + properties: + data_ids: + type: object + additionalProperties: + filename: + type: string + DataFileCreate: + type: object + properties: + filename: + type: string + file: + type: string + format: binary + DataFileCreated: + type: object + properties: + current_user: + type: string + authenticated: + type: boolean + file_id: + type: integer + file_alternative_name: + type: 
string + is_public: + type: boolean + DataFileGet: + type: object + properties: + download: + type: boolean + DataFile: + type: object + properties: + filename: + type: object + additionalProperties: + type: array + items: + type: string + DataFileDownload: + type: string + format: binary + Login: + type: object + properties: + username: + type: string + email: + type: string + format: email + password: + type: string + required: + - password + DataSetList: + type: object + properties: + dataset_ids: + type: object + additionalProperties: + name: + type: string + DataSetCreate: + type: object + properties: + name: + type: string + is_public: + type: boolean + data_contents: + type: array + items: + type: object + properties: + label: + type: string + value: + type: object + additionalProperties: true + PasswordChange: + type: object + properties: + new_password1: + type: string + maxLength: 128 + new_password2: + type: string + maxLength: 128 + required: + - new_password1 + - new_password2 + PatchedUserDetails: + type: object + description: User model w/o password + properties: + pk: + type: integer + readOnly: true + title: ID + username: + type: string + description: Required. 150 characters or fewer. Letters, digits and @/./+/-/_ + only. 
+ pattern: ^[\w.@+-]+$ + maxLength: 150 + email: + type: string + format: email + readOnly: true + title: Email address + first_name: + type: string + maxLength: 150 + last_name: + type: string + maxLength: 150 + Register: + type: object + properties: + username: + type: string + maxLength: 150 + minLength: 1 + email: + type: string + format: email + password1: + type: string + writeOnly: true + password2: + type: string + writeOnly: true + required: + - password1 + - password2 + - username + RestAuthDetail: + type: object + properties: + detail: + type: string + readOnly: true + required: + - detail + UserDetails: + type: object + description: User model w/o password + properties: + pk: + type: integer + readOnly: true + title: ID + username: + type: string + description: Required. 150 characters or fewer. Letters, digits and @/./+/-/_ + only. + pattern: ^[\w.@+-]+$ + maxLength: 150 + email: + type: string + format: email + readOnly: true + title: Email address + first_name: + type: string + maxLength: 150 + last_name: + type: string + maxLength: 150 + required: + - email + - pk + - username + securitySchemes: + cookieAuth: + type: apiKey + in: cookie + name: sessionid + knoxApiToken: + type: apiKey + in: header + name: Authorization + description: Token-based authentication with required prefix "Token" diff --git a/sasdata/fair_database/fair_database/__init__.py b/sasdata/fair_database/fair_database/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sasdata/fair_database/fair_database/asgi.py b/sasdata/fair_database/fair_database/asgi.py new file mode 100644 index 000000000..a10c9b212 --- /dev/null +++ b/sasdata/fair_database/fair_database/asgi.py @@ -0,0 +1,16 @@ +""" +ASGI config for fair_database project. + +It exposes the ASGI callable as a module-level variable named ``application``. 
+ +For more information on this file, see +https://docs.djangoproject.com/en/5.1/howto/deployment/asgi/ +""" + +import os + +from django.core.asgi import get_asgi_application + +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "fair_database.settings") + +application = get_asgi_application() diff --git a/sasdata/fair_database/fair_database/create_example_session.py b/sasdata/fair_database/fair_database/create_example_session.py new file mode 100644 index 000000000..6c10d9905 --- /dev/null +++ b/sasdata/fair_database/fair_database/create_example_session.py @@ -0,0 +1,97 @@ +import requests + +session = { + "title": "Example Session", + "datasets": [ + { + "name": "Dataset 1", + "metadata": { + "title": "Metadata 1", + "run": 1, + "description": "test", + "instrument": {}, + "process": {}, + "sample": {}, + }, + "data_contents": [ + { + "value": 0, + "variance": 0, + "units": "no", + "hash": 0, + "label": "Quantity 1", + "history": {"operation_tree": {}, "references": []}, + } + ], + }, + { + "name": "Dataset 2", + "metadata": { + "title": "Metadata 2", + "run": 2, + "description": "test", + "instrument": {}, + "process": {}, + "sample": {}, + }, + "data_contents": [ + { + "label": "Quantity 2", + "value": {"array_contents": [0, 0, 0, 0], "shape": (2, 2)}, + "variance": {"array_contents": [0, 0, 0, 0], "shape": (2, 2)}, + "units": "none", + "hash": 0, + "history": { + "operation_tree": { + "operation": "neg", + "parameters": { + "a": { + "operation": "mul", + "parameters": { + "a": { + "operation": "constant", + "parameters": { + "value": {"type": "int", "value": 7} + }, + }, + "b": { + "operation": "variable", + "parameters": { + "hash_value": 111, + "name": "x", + }, + }, + }, + }, + }, + }, + "references": [ + { + "value": 5, + "variance": 0, + "units": "none", + "hash": 111, + "history": {}, + } + ], + }, + } + ], + }, + ], + "is_public": False, +} + +url = "http://127.0.0.1:8000/v1/data/session/" +login_data = {"email": "test@test.org", "username": "testUser", 
"password": "sasview!"} +response = requests.post("http://127.0.0.1:8000/auth/login/", data=login_data) +if response.status_code != 200: + register_data = { + "email": "test@test.org", + "username": "testUser", + "password1": "sasview!", + "password2": "sasview!", + } + response = requests.post("http://127.0.0.1:8000/auth/register/", data=register_data) +token = response.json()["token"] +requests.request("POST", url, json=session, headers={"Authorization": "Token " + token}) diff --git a/sasdata/fair_database/fair_database/permissions.py b/sasdata/fair_database/fair_database/permissions.py new file mode 100644 index 000000000..74be5f33a --- /dev/null +++ b/sasdata/fair_database/fair_database/permissions.py @@ -0,0 +1,29 @@ +from rest_framework.permissions import BasePermission + + +# check if a request is made by an object's owner +def is_owner(request, obj): + return request.user.is_authenticated and request.user == obj.current_user + + +# check if a request is made by a user with read access +def has_access(request, obj): + return is_owner(request, obj) or ( + request.user.is_authenticated and request.user in obj.users.all() + ) + + +class DataPermission(BasePermission): + # check if a request has the correct permissions for a specific object + def has_object_permission(self, request, view, obj): + if request.method == "GET": + return obj.is_public or has_access(request, obj) + elif request.method == "DELETE": + return not obj.is_public and is_owner(request, obj) + else: + return is_owner(request, obj) + + +# check if a request has the correct permissions for a specific object +def check_permissions(request, obj): + return DataPermission().has_object_permission(request, None, obj) diff --git a/sasdata/fair_database/fair_database/settings.py b/sasdata/fair_database/fair_database/settings.py new file mode 100644 index 000000000..f3ec69ed9 --- /dev/null +++ b/sasdata/fair_database/fair_database/settings.py @@ -0,0 +1,202 @@ +""" +Django settings for fair_database 
project. + +Generated by 'django-admin startproject' using Django 5.1.5. + +For more information on this file, see +https://docs.djangoproject.com/en/5.1/topics/settings/ + +For the full list of settings and their values, see +https://docs.djangoproject.com/en/5.1/ref/settings/ +""" + +import os +from pathlib import Path + +# Build paths inside the project like this: BASE_DIR / 'subdir'. +BASE_DIR = Path(__file__).resolve().parent.parent + + +# Quick-start development settings - unsuitable for production +# See https://docs.djangoproject.com/en/5.1/howto/deployment/checklist/ + +# SECURITY WARNING: keep the secret key used in production secret! +SECRET_KEY = "django-insecure--f-t5!pdhq&4)^&xenr^k0e8n%-h06jx9d0&2kft(!+1$xzig)" + +# SECURITY WARNING: don't run with debug turned on in production! +DEBUG = True + +ALLOWED_HOSTS = [] + + +# Application definition + +INSTALLED_APPS = [ + "data.apps.DataConfig", + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + "django.contrib.sites", + "rest_framework", + "rest_framework.authtoken", + "allauth", + "allauth.account", + "allauth.socialaccount", + "allauth.socialaccount.providers.orcid", + "dj_rest_auth", + "dj_rest_auth.registration", + "knox", + "user_app.apps.UserAppConfig", + "drf_spectacular", +] + +SITE_ID = 1 + +MIDDLEWARE = [ + "django.middleware.security.SecurityMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.middleware.clickjacking.XFrameOptionsMiddleware", + "allauth.account.middleware.AccountMiddleware", +] + +ROOT_URLCONF = "fair_database.urls" + +TEMPLATES = [ + { + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [], + 
"APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", + ], + }, + }, +] + +WSGI_APPLICATION = "fair_database.wsgi.application" + +# Authentication +AUTHENTICATION_BACKENDS = ( + "django.contrib.auth.backends.ModelBackend", + "allauth.account.auth_backends.AuthenticationBackend", +) + +REST_FRAMEWORK = { + "DEFAULT_AUTHENTICATION_CLASSES": [ + "knox.auth.TokenAuthentication", + "rest_framework.authentication.SessionAuthentication", + ], + "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema", +} + +REST_AUTH = { + "TOKEN_SERIALIZER": "user_app.serializers.KnoxSerializer", + "USER_DETAILS_SERIALIZER": "dj_rest_auth.serializers.UserDetailsSerializer", + "TOKEN_MODEL": "knox.models.AuthToken", + "TOKEN_CREATOR": "user_app.util.create_knox_token", +} + +SPECTACULAR_SETTINGS = { + "TITLE": "SasView Database", + "DESCRIPTION": "A database following the FAIR data principles for SasView," + " a small angle scattering analysis application.", + "VERSION": "0.1.0", + "SERVE_INCLUDE_SCHEMA": False, +} + +# allauth settings +HEADLESS_ONLY = True +ACCOUNT_EMAIL_VERIFICATION = "none" + +# to enable ORCID, register for credentials through ORCID and fill out client_id and secret +# https://info.orcid.org/documentation/integration-guide/ +# https://docs.allauth.org/en/latest/socialaccount/index.html +SOCIALACCOUNT_PROVIDERS = { + "orcid": { + "APPS": [ + { + "client_id": "", + "secret": "", + "key": "", + } + ], + "SCOPE": [ + "profile", + "email", + ], + "AUTH_PARAMETERS": {"access_type": "online"}, + # Base domain of the API. Default value: 'orcid.org', for the production API + "BASE_DOMAIN": "sandbox.orcid.org", # for the sandbox API + # Member API or Public API? 
Default: False (for the public API) + "MEMBER_API": False, + } +} + +# Database +# https://docs.djangoproject.com/en/5.1/ref/settings/#databases + +DATABASES = { + "default": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": BASE_DIR / "db.sqlite3", + } +} + + +# Password validation +# https://docs.djangoproject.com/en/5.1/ref/settings/#auth-password-validators + +AUTH_PASSWORD_VALIDATORS = [ + { + "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator", + }, + { + "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator", + }, + { + "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator", + }, + { + "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator", + }, +] + + +# Internationalization +# https://docs.djangoproject.com/en/5.1/topics/i18n/ + +LANGUAGE_CODE = "en-us" + +TIME_ZONE = "UTC" + +USE_I18N = True + +USE_TZ = True + + +# Static files (CSS, JavaScript, Images) +# https://docs.djangoproject.com/en/4.2/howto/static-files/ + + +STATIC_ROOT = os.path.join(BASE_DIR, "static") +STATIC_URL = "/static/" + +# instead of doing this, create a new media_root +MEDIA_ROOT = os.path.join(BASE_DIR, "media") +MEDIA_URL = "/media/" + +# Default primary key field type +# https://docs.djangoproject.com/en/5.1/ref/settings/#default-auto-field + +DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" diff --git a/sasdata/fair_database/fair_database/test_permissions.py b/sasdata/fair_database/fair_database/test_permissions.py new file mode 100644 index 000000000..ffb55dbd7 --- /dev/null +++ b/sasdata/fair_database/fair_database/test_permissions.py @@ -0,0 +1,292 @@ +import os +import shutil + +from data.models import DataFile +from django.conf import settings +from django.contrib.auth.models import User +from rest_framework import status +from rest_framework.test import APITestCase + + +def find(filename): + return os.path.join( + os.path.dirname(__file__), "../../example_data/1d_data", 
filename + ) + + +def auth_header(response): + return {"Authorization": "Token " + response.data["token"]} + + +class DataListPermissionsTests(APITestCase): + """Test permissions of data views using user_app for authentication.""" + + @classmethod + def setUpTestData(cls): + cls.user = User.objects.create_user( + username="testUser", password="secret", id=1, email="email@domain.com" + ) + cls.user2 = User.objects.create_user( + username="testUser2", password="secret", id=2, email="email2@domain.com" + ) + cls.unowned_test_data = DataFile.objects.create( + id=1, file_name="cyl_400_40.txt", is_public=True + ) + cls.unowned_test_data.file.save( + "cyl_400_40.txt", open(find("cyl_400_40.txt"), "rb") + ) + cls.private_test_data = DataFile.objects.create( + id=2, current_user=cls.user, file_name="cyl_400_20.txt", is_public=False + ) + cls.private_test_data.file.save( + "cyl_400_20.txt", open(find("cyl_400_20.txt"), "rb") + ) + cls.public_test_data = DataFile.objects.create( + id=3, current_user=cls.user, file_name="cyl_testdata.txt", is_public=True + ) + cls.public_test_data.file.save( + "cyl_testdata.txt", open(find("cyl_testdata.txt"), "rb") + ) + cls.login_data_1 = { + "username": "testUser", + "password": "secret", + "email": "email@domain.com", + } + cls.login_data_2 = { + "username": "testUser2", + "password": "secret", + "email": "email2@domain.com", + } + + # Authenticated user can view list of data + def test_list_authenticated(self): + token = self.client.post("/auth/login/", data=self.login_data_1) + response = self.client.get("/v1/data/file/", headers=auth_header(token)) + response2 = self.client.get( + "/v1/data/file/", data={"username": "testUser"}, headers=auth_header(token) + ) + self.assertEqual( + response.data, + { + "public_data_ids": { + 1: "cyl_400_40.txt", + 2: "cyl_400_20.txt", + 3: "cyl_testdata.txt", + } + }, + ) + self.assertEqual( + response2.data, + {"user_data_ids": {2: "cyl_400_20.txt", 3: "cyl_testdata.txt"}}, + ) + + # Authenticated user 
cannot view other users' private data on list + def test_list_authenticated_2(self): + token = self.client.post("/auth/login/", data=self.login_data_2) + response = self.client.get("/v1/data/file/", headers=auth_header(token)) + response2 = self.client.get( + "/v1/data/file/", data={"username": "testUser"}, headers=auth_header(token) + ) + response3 = self.client.get( + "/v1/data/file/", data={"username": "testUser2"}, headers=auth_header(token) + ) + self.assertEqual( + response.data, + {"public_data_ids": {1: "cyl_400_40.txt", 3: "cyl_testdata.txt"}}, + ) + self.assertEqual(response2.status_code, status.HTTP_200_OK) + self.assertEqual(response2.data, {"user_data_ids": {3: "cyl_testdata.txt"}}) + self.assertEqual(response3.data, {"user_data_ids": {}}) + + # Unauthenticated user can view list of public data + def test_list_unauthenticated(self): + response = self.client.get("/v1/data/file/") + response2 = self.client.get("/v1/data/file/", data={"username": "testUser"}) + self.assertEqual( + response.data, + {"public_data_ids": {1: "cyl_400_40.txt", 3: "cyl_testdata.txt"}}, + ) + self.assertEqual(response2.status_code, status.HTTP_200_OK) + self.assertEqual(response2.data, {"user_data_ids": {3: "cyl_testdata.txt"}}) + + # Authenticated user can load public data and owned private data + def test_load_authenticated(self): + token = self.client.post("/auth/login/", data=self.login_data_1) + response = self.client.get("/v1/data/file/1/", headers=auth_header(token)) + response2 = self.client.get("/v1/data/file/2/", headers=auth_header(token)) + response3 = self.client.get("/v1/data/file/3/", headers=auth_header(token)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response2.status_code, status.HTTP_200_OK) + self.assertEqual(response3.status_code, status.HTTP_200_OK) + + # Authenticated user cannot load others' private data + def test_load_unauthorized(self): + token = self.client.post("/auth/login/", data=self.login_data_2) + response = 
self.client.get("/v1/data/file/2/", headers=auth_header(token)) + response2 = self.client.get("/v1/data/file/3/", headers=auth_header(token)) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response2.status_code, status.HTTP_200_OK) + + # Unauthenticated user can load public data only + def test_load_unauthenticated(self): + response = self.client.get("/v1/data/file/1/") + response2 = self.client.get("/v1/data/file/2/") + response3 = self.client.get("/v1/data/file/3/") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response2.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(response3.status_code, status.HTTP_200_OK) + + # Authenticated user can upload data + def test_upload_authenticated(self): + token = self.client.post("/auth/login/", data=self.login_data_1) + file = open(find("cyl_testdata1.txt"), "rb") + data = {"file": file, "is_public": False} + response = self.client.post( + "/v1/data/file/", data=data, headers=auth_header(token) + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual( + response.data, + { + "current_user": "testUser", + "authenticated": True, + "file_id": 4, + "file_alternative_name": "cyl_testdata1.txt", + "is_public": False, + }, + ) + DataFile.objects.get(id=4).delete() + + # Unauthenticated user can upload public data only + def test_upload_unauthenticated(self): + file = open(find("cyl_testdata2.txt"), "rb") + file2 = open(find("cyl_testdata2.txt"), "rb") + data = {"file": file, "is_public": True} + data2 = {"file": file2, "is_public": False} + response = self.client.post("/v1/data/file/", data=data) + response2 = self.client.post("/v1/data/file/", data=data2) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual( + response.data, + { + "current_user": "", + "authenticated": False, + "file_id": 4, + "file_alternative_name": "cyl_testdata2.txt", + "is_public": True, + }, + ) + 
self.assertEqual(response2.status_code, status.HTTP_400_BAD_REQUEST) + + # Authenticated user can update own data + def test_upload_put_authenticated(self): + token = self.client.post("/auth/login/", data=self.login_data_1) + data = {"is_public": False} + response = self.client.put( + "/v1/data/file/2/", data=data, headers=auth_header(token) + ) + response2 = self.client.put( + "/v1/data/file/3/", data=data, headers=auth_header(token) + ) + self.assertEqual( + response.data, + { + "current_user": "testUser", + "authenticated": True, + "file_id": 2, + "file_alternative_name": "cyl_400_20.txt", + "is_public": False, + }, + ) + self.assertEqual( + response2.data, + { + "current_user": "testUser", + "authenticated": True, + "file_id": 3, + "file_alternative_name": "cyl_testdata.txt", + "is_public": False, + }, + ) + DataFile.objects.get(id=3).is_public = True + + # Authenticated user cannot update unowned data + def test_upload_put_unauthorized(self): + token = self.client.post("/auth/login/", data=self.login_data_2) + file = open(find("cyl_400_40.txt")) + data = {"file": file, "is_public": False} + response = self.client.put( + "/v1/data/file/1/", data=data, headers=auth_header(token) + ) + response2 = self.client.put( + "/v1/data/file/2/", data=data, headers=auth_header(token) + ) + response3 = self.client.put( + "/v1/data/file/3/", data=data, headers=auth_header(token) + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response2.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response3.status_code, status.HTTP_403_FORBIDDEN) + + # Unauthenticated user cannot update data + def test_upload_put_unauthenticated(self): + file = open(find("cyl_400_40.txt")) + data = {"file": file, "is_public": False} + response = self.client.put("/v1/data/file/1/", data=data) + response2 = self.client.put("/v1/data/file/2/", data=data) + response3 = self.client.put("/v1/data/file/3/", data=data) + self.assertEqual(response.status_code, 
status.HTTP_401_UNAUTHORIZED) + self.assertEqual(response2.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(response3.status_code, status.HTTP_401_UNAUTHORIZED) + + # Authenticated user can download public and own data + def test_download_authenticated(self): + token = self.client.post("/auth/login/", data=self.login_data_1) + response = self.client.get( + "/v1/data/file/1/", data={"download": True}, headers=auth_header(token) + ) + response2 = self.client.get( + "/v1/data/file/2/", data={"download": True}, headers=auth_header(token) + ) + response3 = self.client.get( + "/v1/data/file/3/", data={"download": True}, headers=auth_header(token) + ) + b"".join(response.streaming_content) + b"".join(response2.streaming_content) + b"".join(response3.streaming_content) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response2.status_code, status.HTTP_200_OK) + self.assertEqual(response3.status_code, status.HTTP_200_OK) + + # Authenticated user cannot download others' data + def test_download_unauthorized(self): + token = self.client.post("/auth/login/", data=self.login_data_2) + response = self.client.get( + "/v1/data/file/2/", data={"download": True}, headers=auth_header(token) + ) + response2 = self.client.get( + "/v1/data/file/3/", data={"download": True}, headers=auth_header(token) + ) + b"".join(response2.streaming_content) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response2.status_code, status.HTTP_200_OK) + + # Unauthenticated user cannot download private data + def test_download_unauthenticated(self): + response = self.client.get("/v1/data/file/1/", data={"download": True}) + response2 = self.client.get("/v1/data/file/2/", data={"download": True}) + response3 = self.client.get("/v1/data/file/3/", data={"download": True}) + b"".join(response.streaming_content) + b"".join(response3.streaming_content) + self.assertEqual(response.status_code, status.HTTP_200_OK) + 
EXAMPLE_DATA_DIR = os.environ.get("EXAMPLE_DATA_DIR", "../../example_data")

# Endpoint that receives the uploads (a server running on localhost:8000).
UPLOAD_URL = "http://localhost:8000/v1/data/file/"


def _upload_directory(subdirectory, label):
    """Upload every regular file directly inside ``EXAMPLE_DATA_DIR/subdirectory``.

    Logs an error and returns without uploading when the directory is missing.
    *label* is only used in that log message.
    """
    directory = os.path.join(EXAMPLE_DATA_DIR, subdirectory)
    if not os.path.isdir(directory):
        logging.error(f"{label} Data directory not found at: {directory}")
        return
    for file_path in sorted(glob(os.path.join(directory, "*"))):
        # ``*`` also matches sub-directories, which cannot be opened as files.
        if os.path.isfile(file_path):
            upload_file(file_path)


def parse_1D():
    """Upload the example files found in the ``1d_data`` directory."""
    _upload_directory("1d_data", "1D")


def parse_2D():
    """Upload the example files found in the ``2d_data`` directory."""
    _upload_directory("2d_data", "2D")


def parse_sesans():
    """Upload the example files found in the ``sesans_data`` directory."""
    _upload_directory("sesans_data", "Sesans")


def upload_file(file_path):
    """POST *file_path* to ``UPLOAD_URL`` as a public file and return the response.

    The file is opened in binary mode and closed as soon as the request
    has completed.
    """
    with open(file_path, "rb") as file:
        return requests.request(
            "POST", UPLOAD_URL, data={"is_public": True}, files={"file": file}
        )


if __name__ == "__main__":
    parse_1D()
    parse_2D()
    parse_sesans()
# Top-level routes: versioned data API, admin, authentication and the
# OpenAPI schema with its two browsable front ends.
urlpatterns = [
    # ``(?P(v1))`` is not a valid regular expression: a named group needs a
    # name.  NOTE(review): ``version`` is assumed here because it is the
    # keyword DRF's URL-path versioning looks for; confirm against settings.
    re_path(r"^(?P<version>(v1))/data/", include("data.urls")),
    path("admin/", admin.site.urls),
    path("accounts/", include("allauth.urls")),  # needed for social auth
    path("auth/", include("user_app.urls")),
    path("api/schema/", SpectacularAPIView.as_view(), name="schema"),
    path(
        "api/schema/swagger-ui/",
        SpectacularSwaggerView.as_view(url_name="schema"),
        name="swagger-ui",
    ),
    path(
        "api/schema/redoc/",
        SpectacularRedocView.as_view(url_name="schema"),
        name="redoc",
    ),
]
def main():
    """Select the project settings module, then hand ``sys.argv`` to Django's CLI.

    Raises:
        ImportError: if Django cannot be imported; the original import
            error is chained as the cause.
    """
    os.environ.setdefault("DJANGO_SETTINGS_MODULE", "fair_database.settings")
    install_hint = (
        "Couldn't import Django. Are you sure it's installed and "
        "available on your PYTHONPATH environment variable? Did you "
        "forget to activate a virtual environment?"
    )
    try:
        from django.core.management import execute_from_command_line as run_command
    except ImportError as import_failure:
        raise ImportError(install_hint) from import_failure
    run_command(sys.argv)
class KnoxSerializer(serializers.Serializer):
    """
    Serializer for Knox authentication: a token plus the user's details.
    """

    token = serializers.SerializerMethodField()
    user = UserDetailsSerializer()

    def get_token(self, obj):
        """Return element 1 of the ``token`` entry of *obj*."""
        token_entry = obj["token"]
        return token_entry[1]
class AuthTests(TestCase):
    """Tests for authentication endpoints."""

    @classmethod
    def setUpTestData(cls):
        cls.client1 = APIClient()
        cls.client2 = APIClient()
        cls.register_data = {
            "email": "email@domain.org",
            "username": "testUser",
            "password1": "sasview!",
            "password2": "sasview!",
        }
        cls.login_data = {
            "username": "testUser",
            "email": "email@domain.org",
            "password": "sasview!",
        }
        cls.login_data_2 = {
            "username": "testUser2",
            "email": "email2@domain.org",
            "password": "sasview!",
        }
        cls.user = User.objects.create_user(
            id=1, username="testUser2", password="sasview!", email="email2@domain.org"
        )
        cls.client_authenticated = APIClient()
        cls.client_authenticated.force_authenticate(user=cls.user)

    # Create an authentication header for a given token
    def auth_header(self, response):
        return {"Authorization": "Token " + response.data["token"]}

    # Test if registration successfully creates a new user and logs in
    def test_register(self):
        response = self.client1.post("/auth/register/", data=self.register_data)
        # Check the status first: a failed registration should surface as an
        # assertion failure, not as User.DoesNotExist from the lookup below.
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        user = User.objects.get(username="testUser")
        response2 = self.client1.get("/auth/user/", headers=self.auth_header(response))
        self.assertEqual(user.email, self.register_data["email"])
        self.assertEqual(response2.status_code, status.HTTP_200_OK)
        user.delete()

    # Test if login successful
    def test_login(self):
        response = self.client1.post("/auth/login/", data=self.login_data_2)
        response2 = self.client1.get("/auth/user/", headers=self.auth_header(response))
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response2.status_code, status.HTTP_200_OK)

    # Test simultaneous login by multiple clients
    def test_multiple_login(self):
        response = self.client1.post("/auth/login/", data=self.login_data_2)
        response2 = self.client2.post("/auth/login/", data=self.login_data_2)
        response3 = self.client1.get("/auth/user/", headers=self.auth_header(response))
        response4 = self.client2.get("/auth/user/", headers=self.auth_header(response2))
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response2.status_code, status.HTTP_200_OK)
        self.assertEqual(response3.status_code, status.HTTP_200_OK)
        self.assertEqual(response4.status_code, status.HTTP_200_OK)
        self.assertNotEqual(response.content, response2.content)

    # Test get user information
    def test_user_get(self):
        response = self.client_authenticated.get("/auth/user/")
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(
            response.content,
            b'{"pk":1,"username":"testUser2","email":"email2@domain.org","first_name":"","last_name":""}',
        )

    # Test changing username
    def test_user_put_username(self):
        data = {"username": "newName"}
        response = self.client_authenticated.put("/auth/user/", data=data)
        self.user.username = "testUser2"
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(
            response.content,
            b'{"pk":1,"username":"newName","email":"email2@domain.org","first_name":"","last_name":""}',
        )

    # Test changing username and first and last name
    def test_user_put_name(self):
        data = {"username": "newName", "first_name": "Clark", "last_name": "Kent"}
        response = self.client_authenticated.put("/auth/user/", data=data)
        self.user.first_name = ""
        self.user.last_name = ""
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(
            response.content,
            b'{"pk":1,"username":"newName","email":"email2@domain.org","first_name":"Clark","last_name":"Kent"}',
        )

    # Test user info inaccessible when unauthenticated
    def test_user_unauthenticated(self):
        response = self.client1.get("/auth/user/")
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
        self.assertEqual(
            response.content,
            b'{"detail":"Authentication credentials were not provided."}',
        )

    # Test logout is successful after login
    def test_login_logout(self):
        self.client1.post("/auth/login/", data=self.login_data_2)
        response = self.client1.post("/auth/logout/")
        response2 = self.client1.get("/auth/user/")
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.content, b'{"detail":"Successfully logged out."}')
        self.assertEqual(response2.status_code, status.HTTP_401_UNAUTHORIZED)

    # Test logout is successful after registration
    def test_register_logout(self):
        self.client1.post("/auth/register/", data=self.register_data)
        response = self.client1.post("/auth/logout/")
        response2 = self.client1.get("/auth/user/")
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.content, b'{"detail":"Successfully logged out."}')
        self.assertEqual(response2.status_code, status.HTTP_401_UNAUTHORIZED)
        User.objects.get(username="testUser").delete()

    # Test multiple logins for the same account log out independently
    def test_multiple_logout(self):
        self.client1.post("/auth/login/", data=self.login_data_2)
        token = self.client2.post("/auth/login/", data=self.login_data_2)
        response = self.client1.post("/auth/logout/")
        response2 = self.client2.get("/auth/user/", headers=self.auth_header(token))
        response3 = self.client2.post("/auth/logout/")
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response2.status_code, status.HTTP_200_OK)
        self.assertEqual(response3.status_code, status.HTTP_200_OK)

    # Test login is successful after registering then logging out
    def test_register_login(self):
        register_response = self.client1.post(
            "/auth/register/", data=self.register_data
        )
        logout_response = self.client1.post("/auth/logout/")
        login_response = self.client1.post("/auth/login/", data=self.login_data)
        self.assertEqual(register_response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(logout_response.status_code, status.HTTP_200_OK)
        self.assertEqual(login_response.status_code, status.HTTP_200_OK)
        User.objects.get(username="testUser").delete()

    # Test password is successfully changed
    def test_password_change(self):
        data = {
            "new_password1": "sasview?",
            "new_password2": "sasview?",
            "old_password": "sasview!",
        }
        # Build the post-change credentials locally instead of editing the
        # shared ``login_data_2`` dict and restoring it afterwards.
        new_login_data = {**self.login_data_2, "password": "sasview?"}
        response = self.client_authenticated.post("/auth/password/change/", data=data)
        login_response = self.client1.post("/auth/login/", data=new_login_data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(login_response.status_code, status.HTTP_200_OK)

    @classmethod
    def tearDownClass(cls):
        cls.user.delete()
        # Required when overriding: the base class ends the class-level
        # transaction opened for setUpTestData and runs the remaining
        # class cleanup.
        super().tearDownClass()
# create an authentication token
def create_knox_token(token_model, user, serializer):
    """Create a Knox ``AuthToken`` for *user* and return what ``create`` returns.

    *token_model* and *serializer* are accepted but not used here;
    presumably the signature is dictated by the caller -- TODO confirm.
    """
    return AuthToken.objects.create(user=user)
class KnoxRegisterView(RegisterView):
    """Registration view whose response carries a Knox token for the new user."""

    def get_response_data(self, user):
        """Serialize *user* together with the token stored by ``perform_create``."""
        payload = {"user": user, "token": self.token}
        return KnoxSerializer(payload).data

    def perform_create(self, serializer):
        """Save the new user, create its Knox token, and complete the allauth signup."""
        new_user = serializer.save(self.request)
        self.token = create_knox_token(None, new_user, None)
        complete_signup(
            self.request._request,
            new_user,
            allauth_settings.EMAIL_VERIFICATION,
            None,
        )
        return new_user
from dataclasses import dataclass
from typing import Optional

import numpy as np
from numpy.typing import ArrayLike

import sasdata.quantities.units as units
from sasdata.quantities.absolute_temperature import AbsoluteTemperatureAccessor
from sasdata.quantities.accessors import (
    AccessorTarget,
    AngleAccessor,
    FloatAccessor,
    LengthAccessor,
    QuantityAccessor,
    StringAccessor,
)
from sasdata.quantities.quantity import Quantity
self.yaw.as_h5(f, "yaw") - - -@dataclass(kw_only=True) class Detector: """ Detector information """ - name: str | None - distance: Quantity[float] | None - offset: Vec3 | None - orientation: Rot3 | None - beam_center: Vec3 | None - pixel_size: Vec3 | None - slit_length: Quantity[float] | None + def __init__(self, target_object: AccessorTarget): + + # Name of the instrument [string] + self.name = StringAccessor(target_object, "name") + + # Sample to detector distance [float] [mm] + self.distance = LengthAccessor[float](target_object, + "distance", + "distance.units", + default_unit=units.millimeters) + + # Offset of this detector position in X, Y, + # (and Z if necessary) [Vector] [mm] + self.offset = LengthAccessor[ArrayLike](target_object, + "offset", + "offset.units", + default_unit=units.millimeters) + + self.orientation = AngleAccessor[ArrayLike](target_object, + "orientation", + "orientation.units", + default_unit=units.degrees) + + self.beam_center = LengthAccessor[ArrayLike](target_object, + "beam_center", + "beam_center.units", + default_unit=units.millimeters) + + # Pixel size in X, Y, (and Z if necessary) [Vector] [mm] + self.pixel_size = LengthAccessor[ArrayLike](target_object, + "pixel_size", + "pixel_size.units", + default_unit=units.millimeters) + + # Slit length of the instrument for this detector.[float] [mm] + self.slit_length = LengthAccessor[float](target_object, + "slit_length", + "slit_length.units", + default_unit=units.millimeters) def summary(self): - return ( - f"Detector:\n" - f" Name: {self.name}\n" - f" Distance: {self.distance}\n" - f" Offset: {self.offset}\n" - f" Orientation: {self.orientation}\n" - f" Beam center: {self.beam_center}\n" - f" Pixel size: {self.pixel_size}\n" - f" Slit length: {self.slit_length}\n" - ) + return (f"Detector:\n" + f" Name: {self.name.value}\n" + f" Distance: {self.distance.value}\n" + f" Offset: {self.offset.value}\n" + f" Orientation: {self.orientation.value}\n" + f" Beam center: 
{self.beam_center.value}\n" + f" Pixel size: {self.pixel_size.value}\n" + f" Slit length: {self.slit_length.value}\n") - @staticmethod - def from_json(obj): - return Detector( - name=obj["name"], - distance=from_json_quantity(obj["distance"]), - offset=Vec3.from_json(obj["offset"]), - orientation=Rot3.from_json(obj["orientation"]), - beam_center=Vec3.from_json(obj["beam_center"]), - pixel_size=Vec3.from_json(obj["pixel_size"]), - slit_length=from_json_quantity(obj["slit_length"]), - ) +class Aperture: - def as_h5(self, group: h5py.Group): - """Export data onto an HDF5 group""" - if self.name is not None: - group.create_dataset("name", data=[self.name]) - if self.distance: - self.distance.as_h5(group, "SDD") - if self.offset: - self.offset.as_h5(group.create_group("offset")) - if self.orientation: - self.orientation.as_h5(group.create_group("orientation")) - if self.beam_center: - self.beam_center.as_h5(group.create_group("beam_center")) - if self.pixel_size: - self.pixel_size.as_h5(group.create_group("pixel_size")) - if self.slit_length: - self.slit_length.as_h5(group, "slit_length") + def __init__(self, target_object: AccessorTarget): -@dataclass(kw_only=True) -class Aperture: - distance: Quantity[float] | None - size: Vec3 | None - size_name: str | None - name: str | None - type_: str | None + # Name + self.name = StringAccessor(target_object, "name") - def summary(self): - return ( - f" Aperture:\n" - f" Name: {self.name}\n" - f" Aperture size: {self.size}\n" - f" Aperture distance: {self.distance}\n" - ) + # Type + self.type = StringAccessor(target_object, "type") - @staticmethod - def from_json(obj): - return Aperture( - distance=from_json_quantity(obj["distance"]), - size=Vec3.from_json(obj["size"]), - size_name=obj["size_name"], - name=obj["name"], - type_=obj["type"], - ) + # Size name - TODO: What is the name of a size + self.size_name = StringAccessor(target_object, "size_name") + # Aperture size [Vector] # TODO: Wat!?! 
+ self.size = QuantityAccessor[ArrayLike](target_object, + "size", + "size.units", + default_unit=units.millimeters) - def as_h5(self, group: h5py.Group): - """Export data onto an HDF5 group""" - if self.distance is not None: - self.distance.as_h5(group, "distance") - if self.name is not None: - group.attrs["name"] = self.name - if self.type_ is not None: - group.attrs["type"] = self.type_ - if self.size: - size_group = group.create_group("size") - self.size.as_h5(size_group) - if self.size_name is not None: - size_group.attrs["name"] = self.size_name + # Aperture distance [float] + self.distance = LengthAccessor[float](target_object, + "distance", + "distance.units", + default_unit=units.millimeters) + def summary(self): + return (f"Aperture:\n" + f" Name: {self.name.value}\n" + f" Aperture size: {self.size.value}\n" + f" Aperture distance: {self.distance.value}\n") -@dataclass(kw_only=True) class Collimation: """ Class to hold collimation information """ - length: Quantity[float] | None - apertures: list[Aperture] - - def summary(self): - return f"Collimation:\n Length: {self.length}\n" + "".join([a.summary() for a in self.apertures]) + def __init__(self, target_object: AccessorTarget): - @staticmethod - def from_json(obj): - return Collimation( - length=from_json_quantity(obj["length"]) if obj["length"] else None, - apertures=list(map(Aperture.from_json, obj["apertures"])), - ) + # Name + self.name = StringAccessor(target_object, "name") + # Length [float] [mm] + self.length = LengthAccessor[float](target_object, + "length", + "length.units", + default_unit=units.millimeters) - def as_h5(self, group: h5py.Group): - """Export data onto an HDF5 group""" - if self.length: - self.length.as_h5(group, "length") - for idx, a in enumerate(self.apertures): - a.as_h5(group.create_group(f"sasaperture{idx:02d}")) + # Todo - how do we handle this + # self.collimator = Collimation(target_object) -@dataclass(kw_only=True) -class BeamSize: - name: str | None - size: Vec3 | None 
+ def summary(self): - @staticmethod - def from_json(obj): - return BeamSize(name=obj["name"], size=Vec3.from_json(obj["size"])) + #TODO collimation stuff + return ( + f"Collimation:\n" + f" Length: {self.length.value}\n") - def as_h5(self, group: h5py.Group): - """Export data onto an HDF5 group""" - if self.name: - group.attrs["name"] = self.name - if self.size: - self.size.as_h5(group) +@dataclass +class BeamSize: + name: Optional[str] + x: Optional[Quantity[float]] + y: Optional[Quantity[float]] + z: Optional[Quantity[float]] -@dataclass(kw_only=True) +@dataclass class Source: - radiation: str | None - beam_shape: str | None - beam_size: BeamSize | None - wavelength: Quantity[float] | None - wavelength_min: Quantity[float] | None - wavelength_max: Quantity[float] | None - wavelength_spread: Quantity[float] | None + radiation: str + beam_shape: str + beam_size: Optional[BeamSize] + wavelength : Quantity[float] + wavelength_min : Quantity[float] + wavelength_max : Quantity[float] + wavelength_spread : Quantity[float] def summary(self) -> str: + if self.radiation is None and self.type.value and self.probe_particle.value: + radiation = f"{self.type.value} {self.probe_particle.value}" + else: + radiation = f"{self.radiation}" + return ( f"Source:\n" - f" Radiation: {self.radiation}\n" + f" Radiation: {radiation}\n" f" Shape: {self.beam_shape}\n" f" Wavelength: {self.wavelength}\n" f" Min. 
Wavelength: {self.wavelength_min}\n" @@ -251,599 +162,177 @@ def summary(self) -> str: f" Beam Size: {self.beam_size}\n" ) - @staticmethod - def from_json(obj): - return Source( - radiation=obj["radiation"], - beam_shape=obj["beam_shape"], - beam_size=BeamSize.from_json(obj["beam_size"]) if obj["beam_size"] else None, - wavelength=obj["wavelength"], - wavelength_min=obj["wavelength_min"], - wavelength_max=obj["wavelength_max"], - wavelength_spread=obj["wavelength_spread"], - ) - - def as_h5(self, group: h5py.Group): - """Export data onto an HDF5 group""" - if self.radiation: - group.create_dataset("radiation", data=[self.radiation]) - if self.beam_shape: - group.create_dataset("beam_shape", data=[self.beam_shape]) - if self.beam_size: - self.beam_size.as_h5(group.create_group("beam_size")) - if self.wavelength: - self.wavelength.as_h5(group, "wavelength") - if self.wavelength_min: - self.wavelength_min.as_h5(group, "wavelength_min") - if self.wavelength_max: - self.wavelength_max.as_h5(group, "wavelength_max") - if self.wavelength_spread: - self.wavelength_spread.as_h5(group, "wavelength_spread") - - +""" +Definitions of radiation types +""" +NEUTRON = 'neutron' +XRAY = 'x-ray' +MUON = 'muon' +ELECTRON = 'electron' -@dataclass(kw_only=True) class Sample: """ Class to hold the sample description """ + def __init__(self, target_object: AccessorTarget): - name: str | None - sample_id: str | None - thickness: Quantity[float] | None - transmission: float | None - temperature: Quantity[float] | None - position: Vec3 | None - orientation: Rot3 | None - details: list[str] + # Short name for sample + self.name = StringAccessor(target_object, "name") + # ID - def summary(self) -> str: - return ( - f"Sample:\n" - f" ID: {self.sample_id}\n" - f" Transmission: {self.transmission}\n" - f" Thickness: {self.thickness}\n" - f" Temperature: {self.temperature}\n" - f" Position: {self.position}\n" - f" Orientation: {self.orientation}\n" - ) + self.sample_id = 
StringAccessor(target_object, "id") - @staticmethod - def from_json(obj): - return Sample( - name=obj["name"], - sample_id=obj["sample_id"], - thickness=obj["thickness"], - transmission=obj["transmission"], - temperature=obj["temperature"], - position=obj["position"], - orientation=obj["orientation"], - details=obj["details"], - ) + # Thickness [float] [mm] + self.thickness = LengthAccessor(target_object, + "thickness", + "thickness.units", + default_unit=units.millimeters) - def as_h5(self, f: h5py.Group): - """Export data onto an HDF5 group""" - if self.name is not None: - f.attrs["name"] = self.name - if self.sample_id is not None: - f.create_dataset("ID", data=[self.sample_id]) - if self.thickness: - self.thickness.as_h5(f, "thickness") - if self.transmission is not None: - f.create_dataset("transmission", data=[self.transmission]) - if self.temperature: - self.temperature.as_h5(f, "temperature") - if self.position: - self.position.as_h5(f.create_group("position")) - if self.orientation: - self.orientation.as_h5(f.create_group("orientation")) - if self.details: - f.create_dataset("details", data=self.details) + # Transmission [float] [fraction] + self.transmission = FloatAccessor(target_object,"transmission") + + # Temperature [float] [No Default] + self.temperature = AbsoluteTemperatureAccessor(target_object, + "temperature", + "temperature.unit", + default_unit=units.kelvin) + # Position [Vector] [mm] + self.position = LengthAccessor[ArrayLike](target_object, + "position", + "position.unit", + default_unit=units.millimeters) + + # Orientation [Vector] [degrees] + self.orientation = AngleAccessor[ArrayLike](target_object, + "orientation", + "orientation.unit", + default_unit=units.degrees) + + # Details + self.details = StringAccessor(target_object, "details") + + + # SESANS zacceptance + zacceptance = (0,"") + yacceptance = (0,"") + + def summary(self) -> str: + return (f"Sample:\n" + f" ID: {self.sample_id.value}\n" + f" Transmission: 
{self.transmission.value}\n" + f" Thickness: {self.thickness.value}\n" + f" Temperature: {self.temperature.value}\n" + f" Position: {self.position.value}\n" + f" Orientation: {self.orientation.value}\n") + # + # _str += " Details:\n" + # for item in self.details: + # _str += " %s\n" % item + # + # return _str -@dataclass(kw_only=True) class Process: """ Class that holds information about the processes performed on the data. """ + def __init__(self, target_object: AccessorTarget): + self.name = StringAccessor(target_object, "name") + self.date = StringAccessor(target_object, "date") + self.description = StringAccessor(target_object, "description") + + #TODO: It seems like these might be lists of strings, this should be checked - name: str | None - date: str | None - description: str | None - terms: dict[str, str | Quantity[float]] - notes: list[str] + self.term = StringAccessor(target_object, "term") + self.notes = StringAccessor(target_object, "notes") def single_line_desc(self): """ - Return a single line string representing the process + Return a single line string representing the process """ return f"{self.name.value} {self.date.value} {self.description.value}" def summary(self): - if self.terms: - termInfo = " Terms:\n" + "\n".join([f" {k}: {v}" for k, v in self.terms.items()]) + "\n" - else: - termInfo = "" - - if self.notes: - noteInfo = " Notes:\n" + "\n".join([f" {note}" for note in self.notes]) + "\n" - else: - noteInfo = "" - - return ( - f"Process:\n" - f" Name: {self.name}\n" - f" Date: {self.date}\n" - f" Description: {self.description}\n" - f"{termInfo}" - f"{noteInfo}" - ) - - @staticmethod - def from_json(obj): - return Process( - name=obj["name"], - date=obj["date"], - description=obj["description"], - terms=obj["terms"], - notes=obj["notes"], - ) - - def as_h5(self, group: h5py.Group): - """Export data onto an HDF5 group""" - if self.name is not None: - group.create_dataset("name", data=[self.name]) - if self.date is not None: - 
group.create_dataset("date", data=[self.date]) - if self.description is not None: - group.create_dataset("description", data=[self.description]) - if self.terms: - for idx, (term, value) in enumerate(self.terms.items()): - node = group.create_group(f"term{idx:02d}") - node.attrs["name"] = term - if type(value) is Quantity: - node.attrs["value"] = value.value - node.attrs["unit"] = value.units.symbol - else: - node.attrs["value"] = value - for idx, note in enumerate(self.notes): - group.create_dataset(f"note{idx:02d}", data=[note]) + return (f"Process:\n" + f" Name: {self.name.value}\n" + f" Date: {self.date.value}\n" + f" Description: {self.description.value}\n" + f" Term: {self.term.value}\n" + f" Notes: {self.notes.value}\n" + ) @dataclass class Instrument: - collimations: list[Collimation] - source: Source | None - detector: list[Detector] + collimations : list[Collimation] + source : Source + detector : list[Detector] def summary(self): return ( - "\n".join([c.summary() for c in self.collimations]) - + "".join([d.summary() for d in self.detector]) - + (self.source.summary() if self.source is not None else "") - ) + self.aperture.summary() + + self.collimation.summary() + + self.detector.summary() + + self.source.summary()) - @staticmethod - def from_json(obj): - return Instrument( - collimations=list(map(Collimation.from_json, obj["collimations"])), - source=Source.from_json(obj["source"]), - detector=list(map(Detector.from_json, obj["detector"])), - ) +def decode_string(data): + """ This is some crazy stuff""" - def as_h5(self, group: h5py.Group): - """Export data onto an HDF5 group""" - if self.source: - self.source.as_h5(group.create_group("sassource")) - for idx, c in enumerate(self.collimations): - c.as_h5(group.create_group(f"sascollimation{idx:02d}")) - for idx, d in enumerate(self.detector): - d.as_h5(group.create_group(f"sasdetector{idx:02d}")) + if isinstance(data, str): + return data + elif isinstance(data, np.ndarray): -@dataclass(kw_only=True) 
-class MetaNode: - name: str - attrs: dict[str, str] - contents: str | Quantity | ndarray | list["MetaNode"] - - def to_string(self, header=""): - """Convert node to pretty printer string""" - if self.attrs: - attributes = f"\n{header} Attributes:\n" + "\n".join( - [f"{header} {k}: {v}" for k, v in self.attrs.items()] - ) - else: - attributes = "" - if self.contents: - if type(self.contents) is str: - children = f"\n{header} {self.contents}" - else: - children = "".join([n.to_string(header + " ") for n in self.contents]) - else: - children = "" - - return f"\n{header}{self.name}:{attributes}{children}" - - def filter(self, name: str) -> list[ndarray | Quantity | str]: - match self.contents: - case str() | ndarray() | Quantity(): - if name == self.name: - return [self.contents] - case list(): - return [y for x in self.contents for y in x.filter(name)] - case _: - raise RuntimeError(f"Cannot filter contents of type {type(self.contents)}: {self.contents}") - return [] - - def __eq__(self, other) -> bool: - """Custom equality overload needed since numpy arrays don't - play nicely with equality""" - match self.contents: - case ndarray(): - if not np.all(self.contents == other.contents): - return False - case Quantity(): - result = self.contents == other.contents - if type(result) is ndarray and not np.all(result): - return False - if type(result) is bool and not result: - return False - case _: - if self.contents != other.contents: - return False - for k, v in self.attrs.items(): - if k not in other.attrs: - return False - if type(v) is np.ndarray and np.any(v != other.attrs[k]): - return False - if type(v) is not np.ndarray and v != other.attrs[k]: - return False - return self.name == other.name - - @staticmethod - def from_json(obj): - def from_content(con): - match con: - case list(): - return list(map(MetaNode.from_json, con)) - case { - "type": "ndarray", - "dtype": dtype, - "encoding": "base64", - "contents": contents, - "shape": shape, - }: - return 
np.frombuffer(base64.b64decode(contents), dtype=dtype).reshape(shape) - case {"value": value, "units": units}: - return from_json_quantity({"value": from_content(value), "units": from_content(units)}) - case _: - return con - - return MetaNode( - name=obj["name"], - attrs={k: from_content(v) for k, v in obj["attrs"].items()}, - contents=from_content(obj["contents"]), - ) + if data.dtype == object: + data = data.reshape(-1) + data = data[0] -@dataclass(kw_only=True, eq=True) -class Metadata: - title: str | None - run: list[str] - definition: str | None - process: list[Process] - sample: Sample | None - instrument: Instrument | None - raw: MetaNode | None - - def summary(self): - run_string = str(self.run[0] if len(self.run) == 1 else self.run) - return ( - f" {self.title}, Run: {run_string}\n" - + " " - + "=" * len(str(self.title)) - + "=======" - + "=" * len(run_string) - + "\n\n" - + f"Definition: {self.title}\n" - + "".join([p.summary() for p in self.process]) - + (self.sample.summary() if self.sample else "") - + (self.instrument.summary() if self.instrument else "") - ) - - @staticmethod - def from_json(obj): - return Metadata( - title=obj["title"] if obj["title"] else None, - run=obj["run"], - definition=obj["definition"] if obj["definition"] else None, - process=[Process.from_json(p) for p in obj["process"]], - sample=Sample.from_json(obj["sample"]) if obj["sample"] else None, - instrument=Instrument.from_json(obj["instrument"]) if obj["instrument"] else None, - raw=MetaNode.from_json(obj["raw"]), - ) + if isinstance(data, bytes): + return data.decode("utf-8") - @property - def id_header(self): - """Generate a header for used in the unique_id for datasets""" - title = "" - if self.title is not None: - title = self.title - return f"{title}:{",".join(self.run)}" - - def as_h5(self, f: h5py.Group): - """Export data onto an HDF5 group""" - for idx, run in enumerate(self.run): - f.create_dataset(f"run{idx:02d}", data=[run]) - if self.title is not None: - 
f.create_dataset("title", data=[self.title]) - if self.definition is not None: - f.create_dataset("definition", data=[self.definition]) - if self.process: - for idx, process in enumerate(self.process): - name = f"sasprocess{idx:02d}" - process.as_h5(f.create_group(name)) - if self.sample: - self.sample.as_h5(f.create_group("sassample")) - if self.instrument: - self.instrument.as_h5(f.create_group("sasinstrument")) - # self.raw.as_h5(meta) if self.raw else None - - -class MetadataEncoder(json.JSONEncoder): - def default(self, obj): - match obj: - case None: - return None - case bytes(): - return obj.decode("utf-8") - case NamedUnit(): - return obj.name - case Quantity(): - return {"value": obj.value, "units": obj.units.ascii_symbol} - case ndarray(): - return { - "type": "ndarray", - "encoding": "base64", - "contents": base64.b64encode(obj.tobytes()).decode("utf-8"), - "dtype": obj.dtype.str, - "shape": obj.shape, - } - case Vec3(): - return { - "x": obj.x, - "y": obj.y, - "z": obj.z, - } - case Rot3(): - return { - "roll": obj.roll, - "pitch": obj.pitch, - "yaw": obj.yaw, - } - case Sample(): - return { - "name": obj.name, - "sample_id": obj.sample_id, - "thickness": obj.thickness, - "transmission": obj.transmission, - "temperature": obj.temperature, - "position": obj.position, - "orientation": obj.orientation, - "details": obj.details, - } - case Process(): - return { - "name": obj.name, - "date": obj.date, - "description": obj.description, - "terms": {k: obj.terms[k] for k in obj.terms}, - "notes": obj.notes, - } - case Aperture(): - return { - "distance": obj.distance, - "size": obj.size, - "size_name": obj.size_name, - "name": obj.name, - "type": obj.type_, - } - case Collimation(): - return { - "length": obj.length, - "apertures": [a for a in obj.apertures], - } - case BeamSize(): - return {"name": obj.name, "size": obj.size} - case Source(): - return { - "radiation": obj.radiation, - "beam_shape": obj.beam_shape, - "beam_size": obj.beam_size, - "wavelength": 
obj.wavelength, - "wavelength_min": obj.wavelength_min, - "wavelength_max": obj.wavelength_max, - "wavelength_spread": obj.wavelength_spread, - } - case Detector(): - return { - "name": obj.name, - "distance": obj.distance, - "offset": obj.offset, - "orientation": obj.orientation, - "beam_center": obj.beam_center, - "pixel_size": obj.pixel_size, - "slit_length": obj.slit_length, - } - case Instrument(): - return { - "collimations": [c for c in obj.collimations], - "source": obj.source, - "detector": [d for d in obj.detector], - } - case MetaNode(): - return {"name": obj.name, "attrs": obj.attrs, "contents": obj.contents} - case Metadata(): - return { - "title": obj.title, - "run": obj.run, - "definition": obj.definition, - "process": [p for p in obj.process], - "sample": obj.sample, - "instrument": obj.instrument, - "raw": obj.raw, - } - case _: - return super().default(obj) - - -def access_meta(obj: dataclass, key: str) -> Any | None: - """Use a string accessor to locate a key from within the data - object. - - The basic grammar of these accessors explicitly match the python - syntax for accessing the data. For example, to access the `name` - field within the object `person`, you would call - `access_meta(person, ".name")`. Similarly, lists and dicts are - access with square brackets. - - > assert access_meta(person, '.name') == person.name - > assert access_meta(person, '.phone.home') == person.phone.home - > assert access_meta(person, '.addresses[0].postal_code') == person.address[0].postal_code - > assert access_meta(person, '.children["Taylor"]') == person.children["Taylor"] - - Obviously, when the accessor is know ahead of time, `access_meta` - provides no benefit over directly retrieving the data. However, - when a data structure is loaded at runtime (e.g. the metadata of a - neutron scattering file), then it isn't possible to know in - advance the location of the specific value that the user desires. 
- `access_meta` allows the user to provide the location at runtime. - - This function returns `None` when the key is not a valid address - for any data within the structure. Since the leaf could be any - type that is not a list, dict, or dataclass, the return type of - the function is `Any | None`. - - The list of locations within a structure is given by the - `meta_tags` function. + return str(data) - """ - result = obj - while key != "": - match key: - case accessor if accessor.startswith("."): - for fld in fields(result): - field_string = f".{fld.name}" - if accessor.startswith(field_string): - key = accessor[len(field_string) :] - result = getattr(result, fld.name) - break - case index if (type(result) is list) and (matches := re.match(r"\[(\d+?)\](.*)", index)): - result = result[int(matches[1])] - key = matches[2] - case name if (type(result) is dict) and (matches := re.match(r'\["(.+)"\](.*)', name)): - result = result[matches[1]] - key = matches[2] - case _: - return None - return result - - -def meta_tags(obj: dataclass) -> list[str]: - """Find all leaf accessors from a data object. - - The function treats the passed in object as a tree. Lists, dicts, - and dataclasses are all treated as branches on the tree and any - other type is treated as a leaf. The function then returns a list - of strings, where each string is a "path" from the root of the - tree to one leaf. The structure of the path is designed to mimic - the python code to access that specific leaf value. - - These accessors allow us to treat accessing entries within a - structure as first class values. This list can then be presented - to the user to allow them to select specific information within - the larger structure. This is particularly important when plotting - against a specific date value within the structure. 
- - Example: - - >@dataclass - class Thermometer: - temperature: float - units: str - params: list - > item = Example() - > item.temperature = 273 - > item.units = "K" - > item.old_values = [{'date': '2025-08-12', 'temperature': 300'}] - > assert meta_tags(item) = ['.temperature', '.units', '.old_values[0]["date"]', '.old_values[0]["temperature"]'] - - The actual value of the leaf object specified by a path can be - retrieved with the `access_meta` function. - - """ - result = [] - items = [("", obj)] - while items: - path, item = items.pop() - match item: - case list(xs): - for idx, x in enumerate(xs): - items.append((f"{path}[{idx}]", x)) - case dict(xs): - for k, v in xs.items(): - items.append((f'{path}["{k}"]', v)) - case n if is_dataclass(n): - for fld in fields(item): - items.append((f"{path}.{fld.name}", getattr(item, fld.name))) - case _: - result.append(path) - return result + else: + return data.tobytes().decode("utf-8") + else: + return str(data) @dataclass(kw_only=True) -class TagCollection: - """The collected tags and their variability.""" - - singular: set[str] = field(default_factory=set) - variable: set[str] = field(default_factory=set) - - -def collect_tags(objs: list[dataclass]) -> TagCollection: - """Identify uniform and varying data within a groups of data objects - - The resulting TagCollection contains every accessor string that is - valid for every object in the `objs` list. For example, if - `obj.name` is a string for every `obj` in `objs`, then the string - ".name" will be present in one of the two sets in the tags - collection. - - To be more specific, if `obj.name` exists and has the same value - for every `obj` in `objs`, the string ".name" will be included in - the `singular` set. If there are at least two distinct values for - `obj.name`, then ".name" will be in the `variable` set. 
- - """ - if not objs: - return ([], []) - first = objs.pop() - terms = set(meta_tags(first)) - for obj in objs: - terms = terms.intersection(set(meta_tags(obj))) - - objs.append(first) - - result = TagCollection() - - for term in terms: - values = set([access_meta(obj, term) for obj in objs]) - if len(values) == 1: - result.singular.add(term) - else: - result.variable.add(term) +class Metadata: + def __init__(self, target: AccessorTarget): + self._target = target + + self.instrument = Instrument(target.with_path_prefix("sasinstrument|instrument")) + self.process = Process(target.with_path_prefix("sasprocess|process")) + self.sample = Sample(target.with_path_prefix("sassample|sample")) + self.transmission_spectrum = TransmissionSpectrum(target.with_path_prefix("sastransmission_spectrum|transmission_spectrum")) + + self._title = StringAccessor(target, "title") + self._run = StringAccessor(target, "run") + self._definition = StringAccessor(target, "definition") + + self.title: str = decode_string(self._title.value) + self.run: str = decode_string(self._run.value) + self.definition: str = decode_string(self._definition.value) + title: Optional[str] + run: list[str] + definition: Optional[str] + process: list[str] + sample: Optional[Sample] + instrument: Optional[Instrument] - return result + def summary(self): + return ( + f" {self.title}, Run: {self.run}\n" + + " " + "="*len(self.title) + + "=======" + + "="*len(self.run) + "\n\n" + + f"Definition: {self.title}\n" + + self.process.summary() + + self.sample.summary() + + (self.instrument.summary() if self.instrument else "")) diff --git a/sasdata/model_requirements.py b/sasdata/model_requirements.py index 3773c86ee..262f662e8 100644 --- a/sasdata/model_requirements.py +++ b/sasdata/model_requirements.py @@ -1,581 +1,23 @@ -from abc import ABC, abstractmethod -from functools import singledispatch -from typing import Self +from dataclasses import dataclass import numpy as np -from scipy.special import erf, j0 +from 
transforms.operation import Operation -from sasdata import dataset_types -from sasdata.data import SasData -from sasdata.quantities import units -from sasdata.quantities.quantity import Operation, Quantity +from sasdata.metadata import Metadata -class ModellingRequirements(ABC): - """Requirements that need to be passed to any modelling step""" - +@dataclass +class ModellingRequirements: + """ Requirements that need to be passed to any modelling step """ dimensionality: int operation: Operation - def __add__(self, other: Self) -> Self: - return self.compose(other) - - @singledispatch - def compose(self, other: Self) -> Self: - # Compose uses the reversed order - return compose(other, self) - - @abstractmethod - def preprocess_q(self, data: np.ndarray, full_data: SasData) -> np.ndarray: - """Transform the Q values before processing in the model""" - pass - - @abstractmethod - def postprocess_iq(self, data: np.ndarray, full_data: SasData) -> np.ndarray: - """Transform the I(Q) values after running the model""" + def from_qi_transformation(self, data: np.ndarray, metadata: Metadata) -> np.ndarray: + """ Transformation for going from qi to this data""" pass -class ComposeRequirements(ModellingRequirements): - """Composition of two models""" - - first: ModellingRequirements - second: ModellingRequirements - - def __init__(self, fst, snd): - self.first = fst - self.second = snd - - def preprocess_q(self, data: np.ndarray, full_data: SasData) -> np.ndarray: - """Perform both transformations in order""" - return self.second.preprocess_q( - self.first.preprocess_q(data, full_data), full_data - ) - - def postprocess_iq(self, data: np.ndarray, full_data: SasData) -> np.ndarray: - """Perform both transformations in order""" - return self.second.postprocess_iq( - self.first.postprocess_iq(data, full_data), full_data - ) - - -class SesansModel(ModellingRequirements): - """Perform Hankel transform for SESANS""" - - def preprocess_q( - self, spin_echo_length: np.ndarray, full_data: 
SasData - ) -> np.ndarray: - """Calculate the q values needed to perform the Hankel transform - - Note: this is undefined for the case when spin_echo_lengths contains - exactly one element and that values is zero. - - """ - if len(spin_echo_length) == 1: - q_min, q_max = ( - 0.01 * 2 * np.pi / spin_echo_length[-1], - 10 * 2 * np.pi / spin_echo_length[0], - ) - else: - # TODO: Why does q_min depend on the number of correlation lengths? - # TODO: Why does q_max depend on the correlation step size? - q_min = 0.1 * 2 * np.pi / (np.size(spin_echo_length) * spin_echo_length[-1]) - q_max = 2 * np.pi / (spin_echo_length[1] - spin_echo_length[0]) - - # TODO: Possibly make this adjustable - log_spacing = 1.0003 - self.q = np.exp(np.arange(np.log(q_min), np.log(q_max), np.log(log_spacing))) - - dq = np.diff(self.q) - dq = np.insert(dq, 0, dq[0]) - - self.H0 = dq / (2 * np.pi) * self.q - - self.H = np.outer(self.q, spin_echo_length) - j0(self.H, out=self.H) - self.H *= (dq * self.q / (2 * np.pi)).reshape((-1, 1)) - - reptheta = np.outer( - self.q, - full_data._data_contents["Wavelength"].in_units_of(units.angstroms) - / (2 * np.pi), - ) - # Note: Using inplace update with reptheta => arcsin(reptheta). - # When q L / 2 pi > 1 that means wavelength is too large to - # reach that q value at any angle. These should produce theta = NaN - # without any warnings. - # - # Reverse the condition to protect against NaN. We can't use - # theta > zaccept since all comparisons with NaN return False. 
- zaccept = [ - x.terms["zmax"] for x in full_data.metadata.process if "zmax" in x.terms - ][0] - with np.errstate(invalid="ignore"): - mask = ~(np.arcsin(reptheta) <= zaccept.in_units_of(units.radians)) - self.H[mask] = 0 - - return self.q - - def postprocess_iq(self, data: np.ndarray, full_data: SasData) -> np.ndarray: - """ - Apply the SESANS transform to the computed I(q) - """ - G0 = self.H0 @ data - G = self.H.T @ data - P = G - G0 - - return P - - -MINIMUM_RESOLUTION = 1e-8 -MINIMUM_ABSOLUTE_Q = 0.02 # relative to the minimum q in the data -# According to (Barker & Pedersen 1995 JAC), 2.5 sigma is a good limit. -# According to simulations with github.com:scattering/sansresolution.git -# it is better to use asymmetric bounds (2.5, 3.0) -PINHOLE_N_SIGMA = (2.5, 3.0) - - -class PinholeModel(ModellingRequirements): - """Perform a pin hole smearing""" - - def __init__(self, q_width: np.ndarray, nsigma: (float, float) = PINHOLE_N_SIGMA): - self.q_width = q_width - self.nsigma_low, self.nsigma_high = nsigma - - def preprocess_q(self, q: np.ndarray, full_data: SasData) -> np.ndarray: - """Perform smearing transform""" - self.q = q - q_min = np.min(self.q - self.nsigma_low * self.q_width) - q_max = np.max(self.q + self.nsigma_high * self.q_width) - - self.q_calc = linear_extrapolation(self.q, q_min, q_max) - - # Protect against models which are not defined for very low q. 
Limit - # the smallest q value evaluated (in absolute) to 0.02*min - cutoff = MINIMUM_ABSOLUTE_Q * np.min(self.q) - self.q_calc = self.q_calc[abs(self.q_calc) >= cutoff] - - # Build weight matrix from calculated q values - self.weight_matrix = pinhole_resolution( - self.q_calc, - self.q, - np.maximum(self.q_width, MINIMUM_RESOLUTION), - nsigma=(self.nsigma_low, self.nsigma_high), - ) - - return np.abs(self.q_calc) - - def postprocess_iq(self, data: np.ndarray, full_data: SasData) -> np.ndarray: - """Perform smearing transform""" - return self.weight_matrix.T @ data - - -class SlitModel(ModellingRequirements): - """Perform a slit smearing""" - - def __init__( - self, - q_length: np.ndarray, - q_width: np.ndarray, - nsigma: (float, float) = PINHOLE_N_SIGMA, - ): - self.q_length = q_length - self.q_width = q_width - self.nsigma_low, self.nsigma_high = nsigma - - def preprocess_q(self, q: np.ndarray, full_data: SasData) -> np.ndarray: - """Perform smearing transform""" - self.q = q - q_min = np.min(self.q - self.nsigma_low * self.q_width) - q_max = np.max(self.q + self.nsigma_high * self.q_width) - - self.q_calc = slit_extend_q(self.q, self.q_width, self.q_length) - - # Protect against models which are not defined for very low q. 
Limit - # the smallest q value evaluated (in absolute) to 0.02*min - cutoff = MINIMUM_ABSOLUTE_Q * np.min(self.q) - self.q_calc = self.q_calc[abs(self.q_calc) >= cutoff] - - # Build weight matrix from calculated q values - self.weight_matrix = slit_resolution( - self.q_calc, self.q, self.q_length, self.q_width - ) - - return np.abs(self.q_calc) - - def postprocess_iq(self, data: np.ndarray, full_data: SasData) -> np.ndarray: - """Perform smearing transform""" - return self.weight_matrix.T @ data - - -class NullModel(ModellingRequirements): - """A model that does nothing""" - - def compose(self, other: ModellingRequirements) -> ModellingRequirements: - return other - - def preprocess_q( - self, data: Quantity[np.ndarray], _full_data: SasData - ) -> np.ndarray: - """Do nothing""" - return data - - def postprocess_iq(self, data: np.ndarray, _full_data: SasData) -> np.ndarray: - """Do nothing""" - return data - - -def guess_requirements(data: SasData) -> ModellingRequirements: - """Use names of axes and units to guess what kind of processing needs to be done""" - if data.dataset_type == dataset_types.sesans: - return SesansModel() - pass - - -@singledispatch -def compose( - second: ModellingRequirements, first: ModellingRequirements -) -> ModellingRequirements: - """Compose to models together - - This function uses a reverse order so that it can perform dispatch on - the *second* term, since the classes already had a chance to dispatch - on the first parameter - - """ - return ComposeRequirements(first, second) - - -@compose.register -def _(second: NullModel, first: ModellingRequirements) -> ModellingRequirements: - """Null model is the identity element of composition""" - return first - - -@compose.register -def _(second: SesansModel, first: ModellingRequirements) -> ModellingRequirements: - match first: - case PinholeModel() | SlitModel(): - # To the first approximation, there is no slit smearing in SESANS data - return second - case _: - return 
ComposeRequirements(first, second) - - -def linear_extrapolation(q, q_min, q_max): - """ - Extrapolate *q* out to [*q_min*, *q_max*] using the step size in *q* as - a guide. Extrapolation below uses about the same size as the first - interval. Extrapolation above uses about the same size as the final - interval. - - Note that extrapolated values may be negative. - """ - q = np.sort(q) - delta_low = q[1] - q[0] if len(q) > 1 else 0 - n_low = int(np.ceil((q[0] - q_min) / delta_low)) if delta_low > 0 else 15 - q_low = np.linspace(q_min, q[0], n_low + 1)[:-1] if q_min + 2 * MINIMUM_RESOLUTION >= q[0] else [] - - delta_high = q[-1] - q[-2] if len(q) > 1 else 0 - n_high = int(np.ceil((q_max - q[-1]) / delta_high)) if delta_high > 0 else 15 - q_high = np.linspace(q[-1], q_max, n_high + 1)[1:] if q_max - 2 * MINIMUM_RESOLUTION <= q[-1] else [] - return np.concatenate([q_low, q, q_high]) - - -def pinhole_resolution( - q_calc: np.ndarray, q: np.ndarray, q_width: np.ndarray, nsigma=PINHOLE_N_SIGMA -) -> np.ndarray: - r""" - Compute the convolution matrix *W* for pinhole resolution 1-D data. - - Each row *W[i]* determines the normalized weight that the corresponding - points *q_calc* contribute to the resolution smeared point *q[i]*. Given - *W*, the resolution smearing can be computed using *dot(W,q)*. - - Note that resolution is limited to $\pm 2.5 \sigma$.[1] The true resolution - function is a broadened triangle, and does not extend over the entire - range $(-\infty, +\infty)$. It is important to impose this limitation - since some models fall so steeply that the weighted value in gaussian - tails would otherwise dominate the integral. - - *q_calc* must be increasing. *q_width* must be greater than zero. - - [1] Barker, J. G., and J. S. Pedersen. 1995. Instrumental Smearing Effects - in Radially Symmetric Small-Angle Neutron Scattering by Numerical and - Analytical Methods. Journal of Applied Crystallography 28 (2): 105--14. - https://doi.org/10.1107/S0021889894010095. 
- """ - # The current algorithm is a midpoint rectangle rule. In the test case, - # neither trapezoid nor Simpson's rule improved the accuracy. - edges = bin_edges(q_calc) - # edges[edges < 0.0] = 0.0 # clip edges below zero - cdf = erf((edges[:, None] - q[None, :]) / (np.sqrt(2.0) * q_width)[None, :]) - weights = cdf[1:] - cdf[:-1] - # Limit q range to (-2.5,+3) sigma - try: - nsigma_low, nsigma_high = nsigma - except TypeError: - nsigma_low = nsigma_high = nsigma - qhigh = q + nsigma_high * q_width - qlow = q - nsigma_low * q_width # linear limits - ##qlow = q*q/qhigh # log limits - weights[q_calc[:, None] < qlow[None, :]] = 0.0 - weights[q_calc[:, None] > qhigh[None, :]] = 0.0 - weights /= np.sum(weights, axis=0)[None, :] - return weights - - -def bin_edges(x): - """ - Determine bin edges from bin centers, assuming that edges are centered - between the bins. - - Note: this uses the arithmetic mean, which may not be appropriate for - log-scaled data. - """ - if len(x) < 2 or (np.diff(x) < 0).any(): - raise ValueError("Expected bins to be an increasing set") - edges = np.hstack( - [ - x[0] - 0.5 * (x[1] - x[0]), # first point minus half first interval - 0.5 * (x[1:] + x[:-1]), # mid points of all central intervals - x[-1] + 0.5 * (x[-1] - x[-2]), # last point plus half last interval - ] - ) - return edges - - -def slit_resolution(q_calc, q, width, length, n_length=30): - r""" - Build a weight matrix to compute *I_s(q)* from *I(q_calc)*, given - $q_\perp$ = *width* (in the high-resolution axis) and $q_\parallel$ - = *length* (in the low resolution axis). *n_length* is the number - of steps to use in the integration over $q_\parallel$ when both - $q_\perp$ and $q_\parallel$ are non-zero. - - Each $q$ can have an independent width and length value even though - current instruments use the same slit setting for all measured points. - - If slit length is large relative to width, use: - - .. 
math:: - - I_s(q_i) = \frac{1}{\Delta q_\perp} - \int_0^{\Delta q_\perp} - I\left(\sqrt{q_i^2 + q_\perp^2}\right) \,dq_\perp - - If slit width is large relative to length, use: - - .. math:: - - I_s(q_i) = \frac{1}{2 \Delta q_\parallel} - \int_{-\Delta q_\parallel}^{\Delta q_\parallel} - I\left(|q_i + q_\parallel|\right) \,dq_\parallel - - For a mixture of slit width and length use: - - .. math:: - - I_s(q_i) = \frac{1}{2 \Delta q_\parallel \Delta q_\perp} - \int_{-\Delta q_\parallel}^{\Delta q_\parallel} - \int_0^{\Delta q_\perp} - I\left(\sqrt{(q_i + q_\parallel)^2 + q_\perp^2}\right) - \,dq_\perp dq_\parallel - - **Definition** - - We are using the mid-point integration rule to assign weights to each - element of a weight matrix $W$ so that - - .. math:: - - I_s(q) = W\,I(q_\text{calc}) - - If *q_calc* is at the mid-point, we can infer the bin edges from the - pairwise averages of *q_calc*, adding the missing edges before - *q_calc[0]* and after *q_calc[-1]*. - - For $q_\parallel = 0$, the smeared value can be computed numerically - using the $u$ substitution - - .. math:: - - u_j = \sqrt{q_j^2 - q^2} - - This gives - - .. math:: - - I_s(q) \approx \sum_j I(u_j) \Delta u_j - - where $I(u_j)$ is the value at the mid-point, and $\Delta u_j$ is the - difference between consecutive edges which have been first converted - to $u$. Only $u_j \in [0, \Delta q_\perp]$ are used, which corresponds - to $q_j \in \left[q, \sqrt{q^2 + \Delta q_\perp}\right]$, so - - .. math:: - - W_{ij} = \frac{1}{\Delta q_\perp} \Delta u_j - = \frac{1}{\Delta q_\perp} \left( - \sqrt{q_{j+1}^2 - q_i^2} - \sqrt{q_j^2 - q_i^2} \right) - \ \text{if}\ q_j \in \left[q_i, \sqrt{q_i^2 + q_\perp^2}\right] - - where $I_s(q_i)$ is the theory function being computed and $q_j$ are the - mid-points between the calculated values in *q_calc*. We tweak the - edges of the initial and final intervals so that they lie on integration - limits. 
- - (To be precise, the transformed midpoint $u(q_j)$ is not necessarily the - midpoint of the edges $u((q_{j-1}+q_j)/2)$ and $u((q_j + q_{j+1})/2)$, - but it is at least in the interval, so the approximation is going to be - a little better than the left or right Riemann sum, and should be - good enough for our purposes.) - - For $q_\perp = 0$, the $u$ substitution is simpler: - - .. math:: - - u_j = \left|q_j - q\right| - - so - - .. math:: - - W_{ij} = \frac{1}{2 \Delta q_\parallel} \Delta u_j - = \frac{1}{2 \Delta q_\parallel} (q_{j+1} - q_j) - \ \text{if}\ q_j \in - \left[q-\Delta q_\parallel, q+\Delta q_\parallel\right] - - However, we need to support cases were $u_j < 0$, which means using - $2 (q_{j+1} - q_j)$ when $q_j \in \left[0, q_\parallel-q_i\right]$. - This is not an issue for $q_i > q_\parallel$. - - For both $q_\perp > 0$ and $q_\parallel > 0$ we perform a 2 dimensional - integration with - - .. math:: - - u_{jk} = \sqrt{q_j^2 - (q + (k\Delta q_\parallel/L))^2} - \ \text{for}\ k = -L \ldots L - - for $L$ = *n_length*. This gives - - .. math:: - - W_{ij} = \frac{1}{2 \Delta q_\perp q_\parallel} - \sum_{k=-L}^L \Delta u_{jk} - \left(\frac{\Delta q_\parallel}{2 L + 1}\right) - - - """ - - # The current algorithm is a midpoint rectangle rule. - q_edges = bin_edges(q_calc) # Note: requires q > 0 - weights = np.zeros((len(q), len(q_calc)), "d") - - for i, (qi, w, l) in enumerate(zip(q, width, length)): - if w == 0.0 and l == 0.0: - # Perfect resolution, so return the theory value directly. - # Note: assumes that q is a subset of q_calc. If qi need not be - # in q_calc, then we can do a weighted interpolation by looking - # up qi in q_calc, then weighting the result by the relative - # distance to the neighbouring points. 
- weights[i, :] = q_calc == qi - elif l == 0: - weights[i, :] = _q_perp_weights(q_edges, qi, w) - elif w == 0 and qi >= l: - in_x = 1.0 * ((q_calc >= qi - l) & (q_calc <= qi + l)) - weights[i, :] = in_x * np.diff(q_edges) / (2 * l) - elif w == 0: - in_x = 1.0 * ((q_calc >= qi - l) & (q_calc <= qi + l)) - abs_x = 1.0 * (q_calc < abs(qi - l)) - weights[i, :] = (in_x + abs_x) * np.diff(q_edges) / (2 * l) - else: - weights[i, :] = _q_perp_weights( - q_edges, qi + np.arange(-n_length, n_length + 1) * l / n_length, w - ) - weights[i, :] /= 2 * n_length + 1 - - return weights.T - - -def _q_perp_weights(q_edges, qi, w): - q_edges = np.reshape(q_edges, (1, -1)) - qi = np.reshape(qi, (-1, 1)) - # Convert bin edges from q to u - u_limit = np.sqrt(qi**2 + w**2) - u_edges = q_edges**2 - qi**2 - u_edges[q_edges < abs(qi)] = 0.0 - u_edges[q_edges > u_limit] = np.repeat( - u_limit**2 - qi**2, u_edges.shape[1], axis=1 - )[q_edges > u_limit] - return (np.diff(np.sqrt(u_edges), axis=1) / w).sum(axis=0) - - -def slit_extend_q(q, width, length): - """ - Given *q*, *width* and *length*, find a set of sampling points *q_calc* so - that each point I(q) has sufficient support from the underlying - function. - """ - q_min, q_max = np.min(q - length), np.max(np.sqrt((q + length) ** 2 + width**2)) - - return geometric_extrapolation(q, q_min, q_max) - - -def geometric_extrapolation(q, q_min, q_max, points_per_decade=None): - r""" - Extrapolate *q* to [*q_min*, *q_max*] using geometric steps, with the - average geometric step size in *q* as the step size. - - if *q_min* is zero or less then *q[0]/10* is used instead. - - *points_per_decade* sets the ratio between consecutive steps such - that there will be $n$ points used for every factor of 10 increase - in *q*. - - If *points_per_decade* is not given, it will be estimated as follows. - Starting at $q_1$ and stepping geometrically by $\Delta q$ to $q_n$ - in $n$ points gives a geometric average of: - - .. 
math:: - - \log \Delta q = (\log q_n - \log q_1) / (n - 1) - - From this we can compute the number of steps required to extend $q$ - from $q_n$ to $q_\text{max}$ by $\Delta q$ as: - - .. math:: - - n_\text{extend} = (\log q_\text{max} - \log q_n) / \log \Delta q - - Substituting: - .. math:: - n_\text{extend} = (n-1) (\log q_\text{max} - \log q_n) - / (\log q_n - \log q_1) - """ - DEFAULT_POINTS_PER_DECADE = 10 - q = np.sort(q) - data_min, data_max = q[0], q[-1] - if points_per_decade is None: - if data_max > data_min: - log_delta_q = (len(q) - 1) / (np.log(data_max) - np.log(data_min)) - else: - log_delta_q = np.log(10.0) / DEFAULT_POINTS_PER_DECADE - else: - log_delta_q = np.log(10.0) / points_per_decade - if q_min <= 0: - q_min = data_min * MINIMUM_ABSOLUTE_Q - if q_min < data_min: - n_low = int(np.ceil(log_delta_q * (np.log(data_min) - np.log(q_min)))) - q_low = np.logspace(np.log10(q_min), np.log10(data_min), n_low + 1)[:-1] - else: - q_low = [] - if q_max > data_max: - n_high = int(np.ceil(log_delta_q * (np.log(q_max) - np.log(data_max)))) - q_high = np.logspace(np.log10(data_max), np.log10(q_max), n_high + 1)[1:] - else: - q_high = [] - return np.concatenate([q_low, q, q_high]) +def guess_requirements(abscissae, ordinate) -> ModellingRequirements: + """ Use names of axes and units to guess what kind of processing needs to be done """ diff --git a/sasdata/postprocess.py b/sasdata/postprocess.py index 1d7866098..82ce61420 100644 --- a/sasdata/postprocess.py +++ b/sasdata/postprocess.py @@ -4,11 +4,6 @@ """ -import numpy as np - -from sasdata.data import SasData - - def fix_mantid_units_error(data: SasData) -> SasData: pass @@ -19,39 +14,3 @@ def apply_fixes(data: SasData, mantid_unit_error=True): data = fix_mantid_units_error(data) return data - - -def deduce_qz(data: SasData): - """Calculates and appends Qz to SasData if Qx, Qy, and wavelength are all present""" - # if Qz is not already in the dataset, but Qx and Qy are - if 'Qz' not in data._data_contents 
and 'Qx' in data._data_contents and 'Qy' in data._data_contents: - # we start by making the approximation that qz=0 - data._data_contents['Qz'] = 0*data._data_contents['Qx'] - - # now check if metadata has wavelength information - wavelength = getattr( - getattr( - getattr( - getattr(data, "metadata", None), - "instrument", - None - ), - "source", - None - ), - "wavelength", - None - ) - - if wavelength is not None: - # we can deduce the value of qz from qx and qy - # if we have the wavelength - qx = data._data_contents['Qx'] - qy = data._data_contents['Qy'] - - # this is how you convert qx, qy, and wavelength to qz - k0 = 2*np.pi/wavelength - qz = k0-(k0**2-qx**2-qy**2)**(0.5) - - data._data_contents['Qz'] = qz - diff --git a/sasdata/quantities/_autogen_warning.py b/sasdata/quantities/_autogen_warning.py index 9a8c9372e..fc8be67dc 100644 --- a/sasdata/quantities/_autogen_warning.py +++ b/sasdata/quantities/_autogen_warning.py @@ -5,75 +5,75 @@ Do not edit by hand, instead edit the files that build it (%s) - - -DDDDDDDDDDDDD NNNNNNNN NNNNNNNN tttt -D::::::::::::DDD N:::::::N N::::::N ttt:::t -D:::::::::::::::DD N::::::::N N::::::N t:::::t -DDD:::::DDDDD:::::D N:::::::::N N::::::N t:::::t - D:::::D D:::::D ooooooooooo N::::::::::N N::::::N ooooooooooo ttttttt:::::ttttttt - D:::::D D:::::D oo:::::::::::oo N:::::::::::N N::::::N oo:::::::::::oo t:::::::::::::::::t - D:::::D D:::::Do:::::::::::::::o N:::::::N::::N N::::::No:::::::::::::::ot:::::::::::::::::t - D:::::D D:::::Do:::::ooooo:::::o N::::::N N::::N N::::::No:::::ooooo:::::otttttt:::::::tttttt - D:::::D D:::::Do::::o o::::o N::::::N N::::N:::::::No::::o o::::o t:::::t - D:::::D D:::::Do::::o o::::o N::::::N N:::::::::::No::::o o::::o t:::::t - D:::::D D:::::Do::::o o::::o N::::::N N::::::::::No::::o o::::o t:::::t - D:::::D D:::::D o::::o o::::o N::::::N N:::::::::No::::o o::::o t:::::t tttttt -DDD:::::DDDDD:::::D o:::::ooooo:::::o N::::::N N::::::::No:::::ooooo:::::o t::::::tttt:::::t -D:::::::::::::::DD 
o:::::::::::::::o N::::::N N:::::::No:::::::::::::::o tt::::::::::::::t -D::::::::::::DDD oo:::::::::::oo N::::::N N::::::N oo:::::::::::oo tt:::::::::::tt -DDDDDDDDDDDDD ooooooooooo NNNNNNNN NNNNNNN ooooooooooo ttttttttttt - - - - - - - - - dddddddd -EEEEEEEEEEEEEEEEEEEEEE d::::::d iiii tttt BBBBBBBBBBBBBBBBB -E::::::::::::::::::::E d::::::d i::::i ttt:::t B::::::::::::::::B -E::::::::::::::::::::E d::::::d iiii t:::::t B::::::BBBBBB:::::B -EE::::::EEEEEEEEE::::E d:::::d t:::::t BB:::::B B:::::B + + +DDDDDDDDDDDDD NNNNNNNN NNNNNNNN tttt +D::::::::::::DDD N:::::::N N::::::N ttt:::t +D:::::::::::::::DD N::::::::N N::::::N t:::::t +DDD:::::DDDDD:::::D N:::::::::N N::::::N t:::::t + D:::::D D:::::D ooooooooooo N::::::::::N N::::::N ooooooooooo ttttttt:::::ttttttt + D:::::D D:::::D oo:::::::::::oo N:::::::::::N N::::::N oo:::::::::::oo t:::::::::::::::::t + D:::::D D:::::Do:::::::::::::::o N:::::::N::::N N::::::No:::::::::::::::ot:::::::::::::::::t + D:::::D D:::::Do:::::ooooo:::::o N::::::N N::::N N::::::No:::::ooooo:::::otttttt:::::::tttttt + D:::::D D:::::Do::::o o::::o N::::::N N::::N:::::::No::::o o::::o t:::::t + D:::::D D:::::Do::::o o::::o N::::::N N:::::::::::No::::o o::::o t:::::t + D:::::D D:::::Do::::o o::::o N::::::N N::::::::::No::::o o::::o t:::::t + D:::::D D:::::D o::::o o::::o N::::::N N:::::::::No::::o o::::o t:::::t tttttt +DDD:::::DDDDD:::::D o:::::ooooo:::::o N::::::N N::::::::No:::::ooooo:::::o t::::::tttt:::::t +D:::::::::::::::DD o:::::::::::::::o N::::::N N:::::::No:::::::::::::::o tt::::::::::::::t +D::::::::::::DDD oo:::::::::::oo N::::::N N::::::N oo:::::::::::oo tt:::::::::::tt +DDDDDDDDDDDDD ooooooooooo NNNNNNNN NNNNNNN ooooooooooo ttttttttttt + + + + + + + + + dddddddd +EEEEEEEEEEEEEEEEEEEEEE d::::::d iiii tttt BBBBBBBBBBBBBBBBB +E::::::::::::::::::::E d::::::d i::::i ttt:::t B::::::::::::::::B +E::::::::::::::::::::E d::::::d iiii t:::::t B::::::BBBBBB:::::B +EE::::::EEEEEEEEE::::E d:::::d t:::::t BB:::::B B:::::B E:::::E EEEEEE 
ddddddddd:::::d iiiiiiittttttt:::::ttttttt B::::B B:::::Byyyyyyy yyyyyyy - E:::::E dd::::::::::::::d i:::::it:::::::::::::::::t B::::B B:::::B y:::::y y:::::y - E::::::EEEEEEEEEE d::::::::::::::::d i::::it:::::::::::::::::t B::::BBBBBB:::::B y:::::y y:::::y - E:::::::::::::::E d:::::::ddddd:::::d i::::itttttt:::::::tttttt B:::::::::::::BB y:::::y y:::::y - E:::::::::::::::E d::::::d d:::::d i::::i t:::::t B::::BBBBBB:::::B y:::::y y:::::y - E::::::EEEEEEEEEE d:::::d d:::::d i::::i t:::::t B::::B B:::::B y:::::y y:::::y - E:::::E d:::::d d:::::d i::::i t:::::t B::::B B:::::B y:::::y:::::y - E:::::E EEEEEEd:::::d d:::::d i::::i t:::::t tttttt B::::B B:::::B y:::::::::y -EE::::::EEEEEEEE:::::Ed::::::ddddd::::::ddi::::::i t::::::tttt:::::t BB:::::BBBBBB::::::B y:::::::y -E::::::::::::::::::::E d:::::::::::::::::di::::::i tt::::::::::::::t B:::::::::::::::::B y:::::y -E::::::::::::::::::::E d:::::::::ddd::::di::::::i tt:::::::::::tt B::::::::::::::::B y:::::y -EEEEEEEEEEEEEEEEEEEEEE ddddddddd dddddiiiiiiii ttttttttttt BBBBBBBBBBBBBBBBB y:::::y - y:::::y - y:::::y - y:::::y - y:::::y - yyyyyyy - - - - dddddddd -HHHHHHHHH HHHHHHHHH d::::::d -H:::::::H H:::::::H d::::::d -H:::::::H H:::::::H d::::::d -HH::::::H H::::::HH d:::::d - H:::::H H:::::H aaaaaaaaaaaaa nnnn nnnnnnnn ddddddddd:::::d - H:::::H H:::::H a::::::::::::a n:::nn::::::::nn dd::::::::::::::d - H::::::HHHHH::::::H aaaaaaaaa:::::an::::::::::::::nn d::::::::::::::::d - H:::::::::::::::::H a::::ann:::::::::::::::nd:::::::ddddd:::::d - H:::::::::::::::::H aaaaaaa:::::a n:::::nnnn:::::nd::::::d d:::::d - H::::::HHHHH::::::H aa::::::::::::a n::::n n::::nd:::::d d:::::d - H:::::H H:::::H a::::aaaa::::::a n::::n n::::nd:::::d d:::::d - H:::::H H:::::H a::::a a:::::a n::::n n::::nd:::::d d:::::d -HH::::::H H::::::HHa::::a a:::::a n::::n n::::nd::::::ddddd::::::dd -H:::::::H H:::::::Ha:::::aaaa::::::a n::::n n::::n d:::::::::::::::::d -H:::::::H H:::::::H a::::::::::aa:::a n::::n n::::n d:::::::::ddd::::d -HHHHHHHHH 
HHHHHHHHH aaaaaaaaaa aaaa nnnnnn nnnnnn ddddddddd ddddd - + E:::::E dd::::::::::::::d i:::::it:::::::::::::::::t B::::B B:::::B y:::::y y:::::y + E::::::EEEEEEEEEE d::::::::::::::::d i::::it:::::::::::::::::t B::::BBBBBB:::::B y:::::y y:::::y + E:::::::::::::::E d:::::::ddddd:::::d i::::itttttt:::::::tttttt B:::::::::::::BB y:::::y y:::::y + E:::::::::::::::E d::::::d d:::::d i::::i t:::::t B::::BBBBBB:::::B y:::::y y:::::y + E::::::EEEEEEEEEE d:::::d d:::::d i::::i t:::::t B::::B B:::::B y:::::y y:::::y + E:::::E d:::::d d:::::d i::::i t:::::t B::::B B:::::B y:::::y:::::y + E:::::E EEEEEEd:::::d d:::::d i::::i t:::::t tttttt B::::B B:::::B y:::::::::y +EE::::::EEEEEEEE:::::Ed::::::ddddd::::::ddi::::::i t::::::tttt:::::t BB:::::BBBBBB::::::B y:::::::y +E::::::::::::::::::::E d:::::::::::::::::di::::::i tt::::::::::::::t B:::::::::::::::::B y:::::y +E::::::::::::::::::::E d:::::::::ddd::::di::::::i tt:::::::::::tt B::::::::::::::::B y:::::y +EEEEEEEEEEEEEEEEEEEEEE ddddddddd dddddiiiiiiii ttttttttttt BBBBBBBBBBBBBBBBB y:::::y + y:::::y + y:::::y + y:::::y + y:::::y + yyyyyyy + + + + dddddddd +HHHHHHHHH HHHHHHHHH d::::::d +H:::::::H H:::::::H d::::::d +H:::::::H H:::::::H d::::::d +HH::::::H H::::::HH d:::::d + H:::::H H:::::H aaaaaaaaaaaaa nnnn nnnnnnnn ddddddddd:::::d + H:::::H H:::::H a::::::::::::a n:::nn::::::::nn dd::::::::::::::d + H::::::HHHHH::::::H aaaaaaaaa:::::an::::::::::::::nn d::::::::::::::::d + H:::::::::::::::::H a::::ann:::::::::::::::nd:::::::ddddd:::::d + H:::::::::::::::::H aaaaaaa:::::a n:::::nnnn:::::nd::::::d d:::::d + H::::::HHHHH::::::H aa::::::::::::a n::::n n::::nd:::::d d:::::d + H:::::H H:::::H a::::aaaa::::::a n::::n n::::nd:::::d d:::::d + H:::::H H:::::H a::::a a:::::a n::::n n::::nd:::::d d:::::d +HH::::::H H::::::HHa::::a a:::::a n::::n n::::nd::::::ddddd::::::dd +H:::::::H H:::::::Ha:::::aaaa::::::a n::::n n::::n d:::::::::::::::::d +H:::::::H H:::::::H a::::::::::aa:::a n::::n n::::n d:::::::::ddd::::d +HHHHHHHHH HHHHHHHHH 
aaaaaaaaaa aaaa nnnnnn nnnnnn ddddddddd ddddd + """ diff --git a/sasdata/quantities/_build_tables.py b/sasdata/quantities/_build_tables.py index ded60c5d3..5af6a92bb 100644 --- a/sasdata/quantities/_build_tables.py +++ b/sasdata/quantities/_build_tables.py @@ -61,15 +61,12 @@ non_si_dimensioned_units: list[tuple[str, str | None, str, str, float, int, int, int, int, int, int, int, list]] = [ UnitData("Ang", "Å", r"\AA", "angstrom", "angstroms", 1e-10, 1, 0, 0, 0, 0, 0, 0, []), - UnitData("micron", None, None, "micron", "microns", 1e-6, 1, 0, 0, 0, 0, 0, 0, []), UnitData("min", None, None, "minute", "minutes", 60, 0, 1, 0, 0, 0, 0, 0, []), - UnitData("rpm", None, None, "revolutions per minute", "revolutions per minute", 1/60, 0, -1, 0, 0, 0, 0, 0, []), - UnitData("h", None, None, "hour", "hours", 3600, 0, 1, 0, 0, 0, 0, 0, []), - UnitData("d", None, None, "day", "days", 3600*24, 0, 1, 0, 0, 0, 0, 0, []), - UnitData("y", None, None, "year", "years", 3600*24*365.2425, 0, 1, 0, 0, 0, 0, 0, []), + UnitData("h", None, None, "hour", "hours", 360, 0, 1, 0, 0, 0, 0, 0, []), + UnitData("d", None, None, "day", "days", 360*24, 0, 1, 0, 0, 0, 0, 0, []), + UnitData("y", None, None, "year", "years", 360*24*365.2425, 0, 1, 0, 0, 0, 0, 0, []), UnitData("deg", None, None, "degree", "degrees", 180/np.pi, 0, 0, 0, 0, 0, 0, 1, []), UnitData("rad", None, None, "radian", "radians", 1, 0, 0, 0, 0, 0, 0, 1, []), - UnitData("rot", None, None, "rotation", "rotations", 2*np.pi, 0, 0, 0, 0, 0, 0, 1, []), UnitData("sr", None, None, "stradian", "stradians", 1, 0, 0, 0, 0, 0, 0, 2, []), UnitData("l", None, None, "litre", "litres", 1e-3, 3, 0, 0, 0, 0, 0, 0, []), UnitData("eV", None, None, "electronvolt", "electronvolts", 1.602176634e-19, 2, -2, 1, 0, 0, 0, 0, all_magnitudes), @@ -112,7 +109,7 @@ "Ang": ["A", "Å"], "au": ["amu"], "percent": ["%"], - "deg": ["degr", "Deg", "degree", "degrees", "Degrees"], + "deg": ["degr", "Deg", "degrees", "Degrees"], "none": ["Counts", "counts", "cnts", "Cnts", 
"a.u.", "fraction", "Fraction"], "K": ["C"] # Ugh, cansas } @@ -144,7 +141,7 @@ def format_name(name: str): # when called from units.py. This condition patches the # line when the copy is made. if line.startswith("from unicode_superscript"): - fid.write(line.replace("from unicode_superscript", "\nfrom sasdata.quantities.unicode_superscript")) + fid.write(line.replace("from unicode_superscript", "from sasdata.quantities.unicode_superscript")) else: fid.write(line) @@ -433,22 +430,15 @@ def format_name(name: str): f"\n") fid.write("\n") - with open("si.py", 'w') as fid: - si_unit_names = [values.plural for values in base_si_units + derived_si_units if values.plural != "grams"] + ["kilograms"] - si_unit_names.sort() - fid.write('"""'+(warning_text%"_build_tables.py")+'"""\n\n') - fid.write("from sasdata.quantities.units import (\n") + si_unit_names = [values.plural for values in base_si_units + derived_si_units if values.plural != "grams"] + ["kilograms"] for name in si_unit_names: - fid.write(f" {name},\n") - fid.write(")\n") + fid.write(f"from sasdata.quantities.units import {name}\n") fid.write("\nall_si = [\n") - for name in si_unit_names: fid.write(f" {name},\n") - fid.write("]\n") diff --git a/sasdata/quantities/_units_base.py b/sasdata/quantities/_units_base.py index 5543f1d4e..d8ec39b45 100644 --- a/sasdata/quantities/_units_base.py +++ b/sasdata/quantities/_units_base.py @@ -214,20 +214,17 @@ def _components(self, tokens: Sequence["UnitToken"]): pass def __mul__(self: Self, other: "Unit"): - if isinstance(other, Unit): - return Unit(self.scale * other.scale, self.dimensions * other.dimensions) - elif isinstance(other, (int, float)): - return Unit(other * self.scale, self.dimensions) - return NotImplemented + if not isinstance(other, Unit): + return NotImplemented + + return Unit(self.scale * other.scale, self.dimensions * other.dimensions) def __truediv__(self: Self, other: "Unit"): - if isinstance(other, Unit): - return Unit(self.scale / other.scale, 
self.dimensions / other.dimensions) - elif isinstance(other, (int, float)): - return Unit(self.scale / other, self.dimensions) - else: + if not isinstance(other, Unit): return NotImplemented + return Unit(self.scale / other.scale, self.dimensions / other.dimensions) + def __rtruediv__(self: Self, other: "Unit"): if isinstance(other, Unit): return Unit(other.scale / self.scale, other.dimensions / self.dimensions) @@ -295,27 +292,6 @@ def __init__(self, def __repr__(self): return self.name - def __eq__(self, other): - """Match other units exactly or match strings against ANY of our names""" - match other: - case str(): - return self.name == other or self.name == f"{other}s" or self.ascii_symbol == other or self.symbol == other - case NamedUnit(): - return self.name == other.name \ - and self.ascii_symbol == other.ascii_symbol and self.symbol == other.symbol - case Unit(): - return self.equivalent(other) and np.abs(np.log(self.scale/other.scale)) < 1e-5 - case _: - return False - - - def startswith(self, prefix: str) -> bool: - """Check if any representation of the unit begins with the prefix string""" - prefix = prefix.lower() - return (self.name is not None and self.name.lower().startswith(prefix)) \ - or (self.ascii_symbol is not None and self.ascii_symbol.lower().startswith(prefix)) \ - or (self.symbol is not None and self.symbol.lower().startswith(prefix)) - # # Parsing plan: # Require unknown amounts of units to be explicitly positive or negative? 
diff --git a/sasdata/quantities/absolute_temperature.py b/sasdata/quantities/absolute_temperature.py index 045c7de97..108f30760 100644 --- a/sasdata/quantities/absolute_temperature.py +++ b/sasdata/quantities/absolute_temperature.py @@ -1,14 +1,14 @@ -from typing import TypeVar - -from sasdata.quantities.accessors import TemperatureAccessor -from sasdata.quantities.quantity import Quantity - -DataType = TypeVar("DataType") -class AbsoluteTemperatureAccessor(TemperatureAccessor[DataType]): - """ Parsing for absolute temperatures """ - @property - def value(self) -> Quantity[DataType] | None: - if self._numerical_part() is None: - return None - else: - return Quantity.parse(self._numerical_part(), self._unit_part(), absolute_temperature=True) +from typing import TypeVar + +from sasdata.quantities.accessors import TemperatureAccessor +from sasdata.quantities.quantity import Quantity + +DataType = TypeVar("DataType") +class AbsoluteTemperatureAccessor(TemperatureAccessor[DataType]): + """ Parsing for absolute temperatures """ + @property + def value(self) -> Quantity[DataType] | None: + if self._numerical_part() is None: + return None + else: + return Quantity.parse(self._numerical_part(), self._unit_part(), absolute_temperature=True) diff --git a/sasdata/quantities/accessors.py b/sasdata/quantities/accessors.py index d23268544..a7d23fd27 100644 --- a/sasdata/quantities/accessors.py +++ b/sasdata/quantities/accessors.py @@ -5,75 +5,75 @@ Do not edit by hand, instead edit the files that build it (_build_tables.py, _accessor_base.py) - - -DDDDDDDDDDDDD NNNNNNNN NNNNNNNN tttt -D::::::::::::DDD N:::::::N N::::::N ttt:::t -D:::::::::::::::DD N::::::::N N::::::N t:::::t -DDD:::::DDDDD:::::D N:::::::::N N::::::N t:::::t - D:::::D D:::::D ooooooooooo N::::::::::N N::::::N ooooooooooo ttttttt:::::ttttttt - D:::::D D:::::D oo:::::::::::oo N:::::::::::N N::::::N oo:::::::::::oo t:::::::::::::::::t - D:::::D D:::::Do:::::::::::::::o N:::::::N::::N 
N::::::No:::::::::::::::ot:::::::::::::::::t - D:::::D D:::::Do:::::ooooo:::::o N::::::N N::::N N::::::No:::::ooooo:::::otttttt:::::::tttttt - D:::::D D:::::Do::::o o::::o N::::::N N::::N:::::::No::::o o::::o t:::::t - D:::::D D:::::Do::::o o::::o N::::::N N:::::::::::No::::o o::::o t:::::t - D:::::D D:::::Do::::o o::::o N::::::N N::::::::::No::::o o::::o t:::::t - D:::::D D:::::D o::::o o::::o N::::::N N:::::::::No::::o o::::o t:::::t tttttt -DDD:::::DDDDD:::::D o:::::ooooo:::::o N::::::N N::::::::No:::::ooooo:::::o t::::::tttt:::::t -D:::::::::::::::DD o:::::::::::::::o N::::::N N:::::::No:::::::::::::::o tt::::::::::::::t -D::::::::::::DDD oo:::::::::::oo N::::::N N::::::N oo:::::::::::oo tt:::::::::::tt -DDDDDDDDDDDDD ooooooooooo NNNNNNNN NNNNNNN ooooooooooo ttttttttttt - - - - - - - - - dddddddd -EEEEEEEEEEEEEEEEEEEEEE d::::::d iiii tttt BBBBBBBBBBBBBBBBB -E::::::::::::::::::::E d::::::d i::::i ttt:::t B::::::::::::::::B -E::::::::::::::::::::E d::::::d iiii t:::::t B::::::BBBBBB:::::B -EE::::::EEEEEEEEE::::E d:::::d t:::::t BB:::::B B:::::B + + +DDDDDDDDDDDDD NNNNNNNN NNNNNNNN tttt +D::::::::::::DDD N:::::::N N::::::N ttt:::t +D:::::::::::::::DD N::::::::N N::::::N t:::::t +DDD:::::DDDDD:::::D N:::::::::N N::::::N t:::::t + D:::::D D:::::D ooooooooooo N::::::::::N N::::::N ooooooooooo ttttttt:::::ttttttt + D:::::D D:::::D oo:::::::::::oo N:::::::::::N N::::::N oo:::::::::::oo t:::::::::::::::::t + D:::::D D:::::Do:::::::::::::::o N:::::::N::::N N::::::No:::::::::::::::ot:::::::::::::::::t + D:::::D D:::::Do:::::ooooo:::::o N::::::N N::::N N::::::No:::::ooooo:::::otttttt:::::::tttttt + D:::::D D:::::Do::::o o::::o N::::::N N::::N:::::::No::::o o::::o t:::::t + D:::::D D:::::Do::::o o::::o N::::::N N:::::::::::No::::o o::::o t:::::t + D:::::D D:::::Do::::o o::::o N::::::N N::::::::::No::::o o::::o t:::::t + D:::::D D:::::D o::::o o::::o N::::::N N:::::::::No::::o o::::o t:::::t tttttt +DDD:::::DDDDD:::::D o:::::ooooo:::::o N::::::N N::::::::No:::::ooooo:::::o 
t::::::tttt:::::t +D:::::::::::::::DD o:::::::::::::::o N::::::N N:::::::No:::::::::::::::o tt::::::::::::::t +D::::::::::::DDD oo:::::::::::oo N::::::N N::::::N oo:::::::::::oo tt:::::::::::tt +DDDDDDDDDDDDD ooooooooooo NNNNNNNN NNNNNNN ooooooooooo ttttttttttt + + + + + + + + + dddddddd +EEEEEEEEEEEEEEEEEEEEEE d::::::d iiii tttt BBBBBBBBBBBBBBBBB +E::::::::::::::::::::E d::::::d i::::i ttt:::t B::::::::::::::::B +E::::::::::::::::::::E d::::::d iiii t:::::t B::::::BBBBBB:::::B +EE::::::EEEEEEEEE::::E d:::::d t:::::t BB:::::B B:::::B E:::::E EEEEEE ddddddddd:::::d iiiiiiittttttt:::::ttttttt B::::B B:::::Byyyyyyy yyyyyyy - E:::::E dd::::::::::::::d i:::::it:::::::::::::::::t B::::B B:::::B y:::::y y:::::y - E::::::EEEEEEEEEE d::::::::::::::::d i::::it:::::::::::::::::t B::::BBBBBB:::::B y:::::y y:::::y - E:::::::::::::::E d:::::::ddddd:::::d i::::itttttt:::::::tttttt B:::::::::::::BB y:::::y y:::::y - E:::::::::::::::E d::::::d d:::::d i::::i t:::::t B::::BBBBBB:::::B y:::::y y:::::y - E::::::EEEEEEEEEE d:::::d d:::::d i::::i t:::::t B::::B B:::::B y:::::y y:::::y - E:::::E d:::::d d:::::d i::::i t:::::t B::::B B:::::B y:::::y:::::y - E:::::E EEEEEEd:::::d d:::::d i::::i t:::::t tttttt B::::B B:::::B y:::::::::y -EE::::::EEEEEEEE:::::Ed::::::ddddd::::::ddi::::::i t::::::tttt:::::t BB:::::BBBBBB::::::B y:::::::y -E::::::::::::::::::::E d:::::::::::::::::di::::::i tt::::::::::::::t B:::::::::::::::::B y:::::y -E::::::::::::::::::::E d:::::::::ddd::::di::::::i tt:::::::::::tt B::::::::::::::::B y:::::y -EEEEEEEEEEEEEEEEEEEEEE ddddddddd dddddiiiiiiii ttttttttttt BBBBBBBBBBBBBBBBB y:::::y - y:::::y - y:::::y - y:::::y - y:::::y - yyyyyyy - - - - dddddddd -HHHHHHHHH HHHHHHHHH d::::::d -H:::::::H H:::::::H d::::::d -H:::::::H H:::::::H d::::::d -HH::::::H H::::::HH d:::::d - H:::::H H:::::H aaaaaaaaaaaaa nnnn nnnnnnnn ddddddddd:::::d - H:::::H H:::::H a::::::::::::a n:::nn::::::::nn dd::::::::::::::d - H::::::HHHHH::::::H aaaaaaaaa:::::an::::::::::::::nn d::::::::::::::::d 
- H:::::::::::::::::H a::::ann:::::::::::::::nd:::::::ddddd:::::d - H:::::::::::::::::H aaaaaaa:::::a n:::::nnnn:::::nd::::::d d:::::d - H::::::HHHHH::::::H aa::::::::::::a n::::n n::::nd:::::d d:::::d - H:::::H H:::::H a::::aaaa::::::a n::::n n::::nd:::::d d:::::d - H:::::H H:::::H a::::a a:::::a n::::n n::::nd:::::d d:::::d -HH::::::H H::::::HHa::::a a:::::a n::::n n::::nd::::::ddddd::::::dd -H:::::::H H:::::::Ha:::::aaaa::::::a n::::n n::::n d:::::::::::::::::d -H:::::::H H:::::::H a::::::::::aa:::a n::::n n::::n d:::::::::ddd::::d -HHHHHHHHH HHHHHHHHH aaaaaaaaaa aaaa nnnnnn nnnnnn ddddddddd ddddd - + E:::::E dd::::::::::::::d i:::::it:::::::::::::::::t B::::B B:::::B y:::::y y:::::y + E::::::EEEEEEEEEE d::::::::::::::::d i::::it:::::::::::::::::t B::::BBBBBB:::::B y:::::y y:::::y + E:::::::::::::::E d:::::::ddddd:::::d i::::itttttt:::::::tttttt B:::::::::::::BB y:::::y y:::::y + E:::::::::::::::E d::::::d d:::::d i::::i t:::::t B::::BBBBBB:::::B y:::::y y:::::y + E::::::EEEEEEEEEE d:::::d d:::::d i::::i t:::::t B::::B B:::::B y:::::y y:::::y + E:::::E d:::::d d:::::d i::::i t:::::t B::::B B:::::B y:::::y:::::y + E:::::E EEEEEEd:::::d d:::::d i::::i t:::::t tttttt B::::B B:::::B y:::::::::y +EE::::::EEEEEEEE:::::Ed::::::ddddd::::::ddi::::::i t::::::tttt:::::t BB:::::BBBBBB::::::B y:::::::y +E::::::::::::::::::::E d:::::::::::::::::di::::::i tt::::::::::::::t B:::::::::::::::::B y:::::y +E::::::::::::::::::::E d:::::::::ddd::::di::::::i tt:::::::::::tt B::::::::::::::::B y:::::y +EEEEEEEEEEEEEEEEEEEEEE ddddddddd dddddiiiiiiii ttttttttttt BBBBBBBBBBBBBBBBB y:::::y + y:::::y + y:::::y + y:::::y + y:::::y + yyyyyyy + + + + dddddddd +HHHHHHHHH HHHHHHHHH d::::::d +H:::::::H H:::::::H d::::::d +H:::::::H H:::::::H d::::::d +HH::::::H H::::::HH d:::::d + H:::::H H:::::H aaaaaaaaaaaaa nnnn nnnnnnnn ddddddddd:::::d + H:::::H H:::::H a::::::::::::a n:::nn::::::::nn dd::::::::::::::d + H::::::HHHHH::::::H aaaaaaaaa:::::an::::::::::::::nn d::::::::::::::::d + 
H:::::::::::::::::H a::::ann:::::::::::::::nd:::::::ddddd:::::d + H:::::::::::::::::H aaaaaaa:::::a n:::::nnnn:::::nd::::::d d:::::d + H::::::HHHHH::::::H aa::::::::::::a n::::n n::::nd:::::d d:::::d + H:::::H H:::::H a::::aaaa::::::a n::::n n::::nd:::::d d:::::d + H:::::H H:::::H a::::a a:::::a n::::n n::::nd:::::d d:::::d +HH::::::H H::::::HHa::::a a:::::a n::::n n::::nd::::::ddddd::::::dd +H:::::::H H:::::::Ha:::::aaaa::::::a n::::n n::::n d:::::::::::::::::d +H:::::::H H:::::::H a::::::::::aa:::a n::::n n::::n d:::::::::ddd::::d +HHHHHHHHH HHHHHHHHH aaaaaaaaaa aaaa nnnnnn nnnnnn ddddddddd ddddd + """ @@ -362,14 +362,6 @@ def angstroms(self) -> T: else: return quantity.in_units_of(units.angstroms) - @property - def microns(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns) - @property def miles(self) -> T: quantity = self.quantity @@ -535,14 +527,6 @@ def square_angstroms(self) -> T: else: return quantity.in_units_of(units.square_angstroms) - @property - def square_microns(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.square_microns) - @property def square_miles(self) -> T: quantity = self.quantity @@ -716,14 +700,6 @@ def cubic_angstroms(self) -> T: else: return quantity.in_units_of(units.cubic_angstroms) - @property - def cubic_microns(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.cubic_microns) - @property def cubic_miles(self) -> T: quantity = self.quantity @@ -889,14 +865,6 @@ def per_angstrom(self) -> T: else: return quantity.in_units_of(units.per_angstrom) - @property - def per_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.per_micron) - @property def per_mile(self) -> T: quantity = self.quantity @@ -1062,14 +1030,6 @@ def per_square_angstrom(self) -> T: 
else: return quantity.in_units_of(units.per_square_angstrom) - @property - def per_square_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.per_square_micron) - @property def per_square_mile(self) -> T: quantity = self.quantity @@ -1235,14 +1195,6 @@ def per_cubic_angstrom(self) -> T: else: return quantity.in_units_of(units.per_cubic_angstrom) - @property - def per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.per_cubic_micron) - @property def per_cubic_mile(self) -> T: quantity = self.quantity @@ -1477,14 +1429,6 @@ def attohertz(self) -> T: else: return quantity.in_units_of(units.attohertz) - @property - def revolutions_per_minute(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.revolutions_per_minute) - class SpeedAccessor[T](QuantityAccessor[T]): @@ -2898,94 +2842,6 @@ def angstroms_per_year(self) -> T: else: return quantity.in_units_of(units.angstroms_per_year) - @property - def microns_per_second(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_second) - - @property - def microns_per_millisecond(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_millisecond) - - @property - def microns_per_microsecond(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_microsecond) - - @property - def microns_per_nanosecond(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_nanosecond) - - @property - def microns_per_picosecond(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return 
quantity.in_units_of(units.microns_per_picosecond) - - @property - def microns_per_femtosecond(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_femtosecond) - - @property - def microns_per_attosecond(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_attosecond) - - @property - def microns_per_minute(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_minute) - - @property - def microns_per_hour(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_hour) - - @property - def microns_per_day(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_day) - - @property - def microns_per_year(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_year) - @property def miles_per_second(self) -> T: quantity = self.quantity @@ -4751,94 +4607,6 @@ def angstroms_per_square_year(self) -> T: else: return quantity.in_units_of(units.angstroms_per_square_year) - @property - def microns_per_square_second(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_square_second) - - @property - def microns_per_square_millisecond(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_square_millisecond) - - @property - def microns_per_square_microsecond(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_square_microsecond) - - @property - def microns_per_square_nanosecond(self) -> T: - 
quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_square_nanosecond) - - @property - def microns_per_square_picosecond(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_square_picosecond) - - @property - def microns_per_square_femtosecond(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_square_femtosecond) - - @property - def microns_per_square_attosecond(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_square_attosecond) - - @property - def microns_per_square_minute(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_square_minute) - - @property - def microns_per_square_hour(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_square_hour) - - @property - def microns_per_square_day(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_square_day) - - @property - def microns_per_square_year(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.microns_per_square_year) - @property def miles_per_square_second(self) -> T: quantity = self.quantity @@ -7244,134 +7012,6 @@ def ounces_per_cubic_angstrom(self) -> T: else: return quantity.in_units_of(units.ounces_per_cubic_angstrom) - @property - def grams_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.grams_per_cubic_micron) - - @property - def exagrams_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is 
None: - return None - else: - return quantity.in_units_of(units.exagrams_per_cubic_micron) - - @property - def petagrams_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.petagrams_per_cubic_micron) - - @property - def teragrams_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.teragrams_per_cubic_micron) - - @property - def gigagrams_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.gigagrams_per_cubic_micron) - - @property - def megagrams_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.megagrams_per_cubic_micron) - - @property - def kilograms_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.kilograms_per_cubic_micron) - - @property - def milligrams_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.milligrams_per_cubic_micron) - - @property - def micrograms_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.micrograms_per_cubic_micron) - - @property - def nanograms_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.nanograms_per_cubic_micron) - - @property - def picograms_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.picograms_per_cubic_micron) - - @property - def femtograms_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return 
quantity.in_units_of(units.femtograms_per_cubic_micron) - - @property - def attograms_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.attograms_per_cubic_micron) - - @property - def atomic_mass_units_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.atomic_mass_units_per_cubic_micron) - - @property - def pounds_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.pounds_per_cubic_micron) - - @property - def ounces_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.ounces_per_cubic_micron) - @property def grams_per_cubic_mile(self) -> T: quantity = self.quantity @@ -10454,62 +10094,6 @@ def attomoles_per_cubic_angstrom(self) -> T: else: return quantity.in_units_of(units.attomoles_per_cubic_angstrom) - @property - def moles_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.moles_per_cubic_micron) - - @property - def millimoles_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.millimoles_per_cubic_micron) - - @property - def micromoles_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.micromoles_per_cubic_micron) - - @property - def nanomoles_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.nanomoles_per_cubic_micron) - - @property - def picomoles_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return 
quantity.in_units_of(units.picomoles_per_cubic_micron) - - @property - def femtomoles_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.femtomoles_per_cubic_micron) - - @property - def attomoles_per_cubic_micron(self) -> T: - quantity = self.quantity - if quantity is None: - return None - else: - return quantity.in_units_of(units.attomoles_per_cubic_micron) - @property def moles_per_cubic_mile(self) -> T: quantity = self.quantity diff --git a/sasdata/quantities/numerical_encoding.py b/sasdata/quantities/numerical_encoding.py index 6e2e53265..64264aab7 100644 --- a/sasdata/quantities/numerical_encoding.py +++ b/sasdata/quantities/numerical_encoding.py @@ -1,3 +1,5 @@ + + import base64 import struct diff --git a/sasdata/quantities/quantity.py b/sasdata/quantities/quantity.py index 08e0be6fe..5ac546024 100644 --- a/sasdata/quantities/quantity.py +++ b/sasdata/quantities/quantity.py @@ -1,50 +1,50 @@ + + + import hashlib import json -import math from typing import Any, Self, TypeVar, Union -import h5py import numpy as np from numpy._typing import ArrayLike from sasdata.quantities import units from sasdata.quantities.numerical_encoding import numerical_decode, numerical_encode -from sasdata.quantities.unit_parser import parse_unit from sasdata.quantities.units import NamedUnit, Unit T = TypeVar("T") -################### Quantity based operations, need to be here to avoid cyclic dependencies ##################### + +################### Quantity based operations, need to be here to avoid cyclic dependencies ##################### + def transpose(a: Union["Quantity[ArrayLike]", ArrayLike], axes: tuple | None = None): - """Transpose an array or an array based quantity, can also do reordering of axes""" + """ Transpose an array or an array based quantity, can also do reordering of axes""" if isinstance(a, Quantity): + if axes is None: - return DerivedQuantity( - value=np.transpose(a.value, 
axes=axes), - units=a.units, - history=QuantityHistory.apply_operation(Transpose, a.history), - ) + return DerivedQuantity(value=np.transpose(a.value, axes=axes), + units=a.units, + history=QuantityHistory.apply_operation(Transpose, a.history)) else: - return DerivedQuantity( - value=np.transpose(a.value, axes=axes), - units=a.units, - history=QuantityHistory.apply_operation(Transpose, a.history, axes=axes), - ) + return DerivedQuantity(value=np.transpose(a.value, axes=axes), + units=a.units, + history=QuantityHistory.apply_operation(Transpose, a.history, axes=axes)) else: return np.transpose(a, axes=axes) def dot(a: Union["Quantity[ArrayLike]", ArrayLike], b: Union["Quantity[ArrayLike]", ArrayLike]): - """Dot product of two arrays or two array based quantities""" + """ Dot product of two arrays or two array based quantities """ a_is_quantity = isinstance(a, Quantity) b_is_quantity = isinstance(b, Quantity) if a_is_quantity or b_is_quantity: + # If its only one of them that is a quantity, convert the other one if not a_is_quantity: @@ -56,20 +56,13 @@ def dot(a: Union["Quantity[ArrayLike]", ArrayLike], b: Union["Quantity[ArrayLike return DerivedQuantity( value=np.dot(a.value, b.value), units=a.units * b.units, - history=QuantityHistory.apply_operation(Dot, a.history, b.history), - ) + history=QuantityHistory.apply_operation(Dot, a.history, b.history)) else: return np.dot(a, b) - -def tensordot( - a: Union["Quantity[ArrayLike]", ArrayLike] | ArrayLike, - b: Union["Quantity[ArrayLike]", ArrayLike], - a_index: int, - b_index: int, -): - """Tensor dot product - equivalent to contracting two tensors, such as +def tensordot(a: Union["Quantity[ArrayLike]", ArrayLike] | ArrayLike, b: Union["Quantity[ArrayLike]", ArrayLike], a_index: int, b_index: int): + """ Tensor dot product - equivalent to contracting two tensors, such as A_{i0, i1, i2, i3...} and B_{j0, j1, j2...} @@ -85,6 +78,7 @@ def tensordot( b_is_quantity = isinstance(b, Quantity) if a_is_quantity or 
b_is_quantity: + # If its only one of them that is a quantity, convert the other one if not a_is_quantity: @@ -96,8 +90,12 @@ def tensordot( return DerivedQuantity( value=np.tensordot(a.value, b.value, axes=(a_index, b_index)), units=a.units * b.units, - history=QuantityHistory.apply_operation(TensorDot, a.history, b.history, a_index=a_index, b_index=b_index), - ) + history=QuantityHistory.apply_operation( + TensorDot, + a.history, + b.history, + a_index=a_index, + b_index=b_index)) else: return np.tensordot(a, b, axes=(a_index, b_index)) @@ -105,9 +103,8 @@ def tensordot( ################### Operation Definitions ####################################### - def hash_and_name(hash_or_name: int | str): - """Infer the name of a variable from a hash, or the hash from the name + """ Infer the name of a variable from a hash, or the hash from the name Note: hash_and_name(hash_and_name(number)[1]) is not the identity however: hash_and_name(hash_and_name(number)) is @@ -131,35 +128,34 @@ def hash_and_name(hash_or_name: int | str): else: raise TypeError("Variable name_or_hash_value must be either str or int") - class Operation: - serialisation_name = "unknown" - def summary(self, indent_amount: int = 0, indent: str = " "): - """Summary of the operation tree""" + serialisation_name = "unknown" + def summary(self, indent_amount: int = 0, indent: str=" "): + """ Summary of the operation tree""" - s = f"{indent_amount * indent}{self.__class__.__name__}(\n" + s = f"{indent_amount*indent}{self._summary_open()}(\n" for chunk in self._summary_components(): - s += chunk.summary(indent_amount + 1, indent) + "\n" + s += chunk.summary(indent_amount+1, indent) + "\n" - s += f"{indent_amount * indent})" + s += f"{indent_amount*indent})" return s + def _summary_open(self): + """ First line of summary """ def _summary_components(self) -> list["Operation"]: return [] - def evaluate(self, variables: dict[int, T]) -> T: - """Evaluate this operation""" - pass + + """ Evaluate this operation """ 
def _derivative(self, hash_value: int) -> "Operation": - """Get the derivative of this operation""" - pass + """ Get the derivative of this operation """ def _clean(self): - """Clean up this operation - i.e. remove silly things like 1*x""" + """ Clean up this operation - i.e. remove silly things like 1*x """ return self def derivative(self, variable: Union[str, int, "Variable"], simplify=True): @@ -181,7 +177,8 @@ def derivative(self, variable: Union[str, int, "Variable"], simplify=True): # print(derivative.summary()) # Inefficient way of doing repeated simplification, but it will work - for i in range(100): # set max iterations + for i in range(100): # set max iterations + derivative = derivative._clean() # # print("-------------------") @@ -206,14 +203,16 @@ def deserialise(data: str) -> "Operation": @staticmethod def deserialise_json(json_data: dict) -> "Operation": + operation = json_data["operation"] parameters = json_data["parameters"] - class_ = _serialisation_lookup[operation] + cls = _serialisation_lookup[operation] try: - return class_._deserialise(parameters) + return cls._deserialise(parameters) + except NotImplementedError: - raise NotImplementedError(f"No method to deserialise {operation} with {parameters} (class={class_})") + raise NotImplementedError(f"No method to deserialise {operation} with {parameters} (cls={cls})") @staticmethod def _deserialise(parameters: dict) -> "Operation": @@ -223,26 +222,25 @@ def serialise(self) -> str: return json.dumps(self._serialise_json()) def _serialise_json(self) -> dict[str, Any]: - return {"operation": self.serialisation_name, "parameters": self._serialise_parameters()} + return {"operation": self.serialisation_name, + "parameters": self._serialise_parameters()} def _serialise_parameters(self) -> dict[str, Any]: - raise NotImplementedError("_serialise_parameters not implemented for this class") + raise NotImplementedError("_serialise_parameters not implemented") def __eq__(self, other: "Operation"): return 
NotImplemented - class ConstantBase(Operation): pass - class AdditiveIdentity(ConstantBase): - serialisation_name = "zero" + serialisation_name = "zero" def evaluate(self, variables: dict[int, T]) -> T: return 0 - def _derivative(self, hash_value: int) -> "Operation": + def _derivative(self, hash_value: int) -> Operation: return AdditiveIdentity() @staticmethod @@ -252,8 +250,8 @@ def _deserialise(parameters: dict) -> "Operation": def _serialise_parameters(self) -> dict[str, Any]: return {} - def summary(self, indent_amount: int = 0, indent=" "): - return f"{indent_amount * indent}0 [Add.Id.]" + def summary(self, indent_amount: int=0, indent=" "): + return f"{indent_amount*indent}0 [Add.Id.]" def __eq__(self, other): if isinstance(other, AdditiveIdentity): @@ -265,7 +263,9 @@ def __eq__(self, other): return False + class MultiplicativeIdentity(ConstantBase): + serialisation_name = "one" def evaluate(self, variables: dict[int, T]) -> T: @@ -278,11 +278,13 @@ def _derivative(self, hash_value: int): def _deserialise(parameters: dict) -> "Operation": return MultiplicativeIdentity() + def _serialise_parameters(self) -> dict[str, Any]: return {} - def summary(self, indent_amount: int = 0, indent=" "): - return f"{indent_amount * indent}1 [Mul.Id.]" + + def summary(self, indent_amount: int=0, indent=" "): + return f"{indent_amount*indent}1 [Mul.Id.]" def __eq__(self, other): if isinstance(other, MultiplicativeIdentity): @@ -295,11 +297,14 @@ def __eq__(self, other): class Constant(ConstantBase): - serialisation_name = "constant" + serialisation_name = "constant" def __init__(self, value): self.value = value + def summary(self, indent_amount: int = 0, indent: str=" "): + return repr(self.value) + def evaluate(self, variables: dict[int, T]) -> T: return self.value @@ -307,6 +312,7 @@ def _derivative(self, hash_value: int): return AdditiveIdentity() def _clean(self): + if self.value == 0: return AdditiveIdentity() @@ -321,11 +327,12 @@ def _deserialise(parameters: dict) -> 
"Operation": value = numerical_decode(parameters["value"]) return Constant(value) + def _serialise_parameters(self) -> dict[str, Any]: return {"value": numerical_encode(self.value)} - def summary(self, indent_amount: int = 0, indent=" "): - return f"{indent_amount * indent}{self.value}" + def summary(self, indent_amount: int=0, indent=" "): + return f"{indent_amount*indent}{self.value}" def __eq__(self, other): if isinstance(other, AdditiveIdentity): @@ -335,14 +342,15 @@ def __eq__(self, other): return self.value == 1 elif isinstance(other, Constant): - return other.value == self.value + if other.value == self.value: + return True return False class Variable(Operation): - serialisation_name = "variable" + serialisation_name = "variable" def __init__(self, name_or_hash_value: int | str | tuple[int, str]): self.hash_value, self.name = hash_and_name(name_or_hash_value) @@ -366,40 +374,35 @@ def _deserialise(parameters: dict) -> "Operation": return Variable((hash_value, name)) def _serialise_parameters(self) -> dict[str, Any]: - return {"hash_value": self.hash_value, "name": self.name} + return {"hash_value": self.hash_value, + "name": self.name} - def summary(self, indent_amount: int = 0, indent: str = " "): - return f"{indent_amount * indent}{self.name}" + def summary(self, indent_amount: int = 0, indent: str=" "): + return f"{indent_amount*indent}{self.name}" def __eq__(self, other): if isinstance(other, Variable): return self.hash_value == other.hash_value - return False + return False class UnaryOperation(Operation): + def __init__(self, a: Operation): self.a = a def _serialise_parameters(self) -> dict[str, Any]: return {"a": self.a._serialise_json()} - @classmethod - def _deserialise(cls, parameters: dict) -> "UnaryOperation": - return cls(Operation.deserialise_json(parameters["a"])) - def _summary_components(self) -> list["Operation"]: return [self.a] - def __eq__(self, other): - if isinstance(other, self.__class__): - return self.a == other.a - return False + 
class Neg(UnaryOperation): - serialisation_name = "neg" + serialisation_name = "neg" def evaluate(self, variables: dict[int, T]) -> T: return -self.a.evaluate(variables) @@ -407,6 +410,7 @@ def _derivative(self, hash_value: int): return Neg(self.a._derivative(hash_value)) def _clean(self): + clean_a = self.a._clean() if isinstance(clean_a, Neg): @@ -419,12 +423,25 @@ def _clean(self): else: return Neg(clean_a) + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Neg(Operation.deserialise_json(parameters["a"])) + + + def _summary_open(self): + return "Neg" + + def __eq__(self, other): + if isinstance(other, Neg): + return other.a == self.a + class Inv(UnaryOperation): + serialisation_name = "reciprocal" def evaluate(self, variables: dict[int, T]) -> T: - return 1.0 / self.a.evaluate(variables) + return 1/self.a.evaluate(variables) def _derivative(self, hash_value: int) -> Operation: return Neg(Div(self.a._derivative(hash_value), Mul(self.a, self.a))) @@ -433,7 +450,7 @@ def _clean(self): clean_a = self.a._clean() if isinstance(clean_a, Inv): - # Removes double inversions + # Removes double negations return clean_a.a elif isinstance(clean_a, Neg): @@ -442,217 +459,23 @@ def _clean(self): return Neg(Inv(clean_a.a)) elif isinstance(clean_a, Constant): - return Constant(1.0 / clean_a.value)._clean() + return Constant(1/clean_a.value)._clean() else: return Inv(clean_a) -class Ln(UnaryOperation): - serialisation_name = "ln" - - def evaluate(self, variables: dict[int, T]) -> T: - return math.log(self.a.evaluate(variables)) - - def _derivative(self, hash_value: int) -> Operation: - return Div(self.a._derivative(hash_value), self.a) - - def _clean(self): - clean_a = self.a._clean() - - if isinstance(clean_a, Exp): - # Convert ln(exp(x)) to x - return clean_a.a - - elif isinstance(clean_a, MultiplicativeIdentity): - # Convert ln(1) to 0 - return AdditiveIdentity() - - elif clean_a == math.e: - # Convert ln(e) to 1 - return MultiplicativeIdentity() - - 
else: - return Ln(clean_a) - - -class Exp(UnaryOperation): - serialisation_name = "exp" - - def evaluate(self, variables: dict[int, T]) -> T: - return math.exp(self.a.evaluate(variables)) - - def _derivative(self, hash_value: int) -> Operation: - return Mul(self.a._derivative(hash_value), Exp(self.a)) - - def _clean(self): - clean_a = self.a._clean() - - if isinstance(clean_a, Ln): - # Convert exp(ln(x)) to x - return clean_a.a - - elif isinstance(clean_a, MultiplicativeIdentity): - # Convert e**1 to e - return math.e - - elif isinstance(clean_a, AdditiveIdentity): - # Convert e**0 to 1 - return 1 - - else: - return Exp(clean_a) - - -class Sin(UnaryOperation): - serialisation_name = "sin" - - def evaluate(self, variables: dict[int, T]) -> T: - return np.sin(self.a.evaluate(variables)) - - def _derivative(self, hash_value: int) -> Operation: - return Mul(self.a._derivative(hash_value), Cos(self.a)) - - def _clean(self): - clean_a = self.a._clean() - - if isinstance(clean_a, ArcSin): - return clean_a.a - - elif isinstance(clean_a, AdditiveIdentity): - # Convert sin(0) to 0 - return AdditiveIdentity() - - else: - return Sin(clean_a) - - -class ArcSin(UnaryOperation): - serialisation_name = "arcsin" - - def evaluate(self, variables: dict[int, T]) -> T: - return np.arcsin(self.a.evaluate(variables)) - - def _derivative(self, hash_value: int) -> Operation: - return Div(self.a._derivative(hash_value), Sqrt(Sub(MultiplicativeIdentity(), Mul(self.a, self.a)))) - - def _clean(self): - clean_a = self.a._clean() - - if isinstance(clean_a, Sin): - return clean_a.a - - elif isinstance(clean_a, AdditiveIdentity): - # Convert arcsin(0) to 0 - return AdditiveIdentity() - - elif isinstance(clean_a, MultiplicativeIdentity): - # Convert arcsin(1) to pi/2 - return Constant(0.5 * math.pi) - - else: - return ArcSin(clean_a) - - -class Cos(UnaryOperation): - serialisation_name = "cos" - - def evaluate(self, variables: dict[int, T]) -> T: - return np.cos(self.a.evaluate(variables)) - - def 
_derivative(self, hash_value: int) -> Operation: - return Mul(self.a._derivative(hash_value), Neg(Sin(self.a))) - - def _clean(self): - clean_a = self.a._clean() - - if isinstance(clean_a, ArcCos): - return clean_a.a - - elif isinstance(clean_a, AdditiveIdentity): - # Convert cos(0) to 1 - return MultiplicativeIdentity() - - else: - return Cos(clean_a) - - -class ArcCos(UnaryOperation): - serialisation_name = "arccos" - - def evaluate(self, variables: dict[int, T]) -> T: - return np.arccos(self.a.evaluate(variables)) - - def _derivative(self, hash_value: int) -> Operation: - return Neg(Div(self.a._derivative(hash_value), Sqrt(Sub(MultiplicativeIdentity(), Mul(self.a, self.a))))) - - def _clean(self): - clean_a = self.a._clean() - - if isinstance(clean_a, Cos): - return clean_a.a - - elif isinstance(clean_a, AdditiveIdentity): - # Convert arccos(0) to pi/2 - return Constant(0.5 * math.pi) - - elif isinstance(clean_a, MultiplicativeIdentity): - # Convert arccos(1) to 0 - return AdditiveIdentity() - - else: - return ArcCos(clean_a) - - -class Tan(UnaryOperation): - serialisation_name = "tan" - - def evaluate(self, variables: dict[int, T]) -> T: - return np.tan(self.a.evaluate(variables)) - - def _derivative(self, hash_value: int) -> Operation: - return Div(self.a._derivative(hash_value), Mul(Cos(self.a), Cos(self.a))) - - def _clean(self): - clean_a = self.a._clean() - - if isinstance(clean_a, ArcTan): - return clean_a.a - - elif isinstance(clean_a, AdditiveIdentity): - # Convert tan(0) to 0 - return AdditiveIdentity() - - else: - return Tan(clean_a) - - -class ArcTan(UnaryOperation): - serialisation_name = "arctan" - - def evaluate(self, variables: dict[int, T]) -> T: - return np.arctan(self.a.evaluate(variables)) - - def _derivative(self, hash_value: int) -> Operation: - return Div(self.a._derivative(hash_value), Add(MultiplicativeIdentity(), Mul(self.a, self.a))) - - def _clean(self): - clean_a = self.a._clean() - - if isinstance(clean_a, Tan): - return clean_a.a + 
@staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Inv(Operation.deserialise_json(parameters["a"])) - elif isinstance(clean_a, AdditiveIdentity): - # Convert arctan(0) to 0 - return AdditiveIdentity() + def _summary_open(self): + return "Inv" - elif isinstance(clean_a, MultiplicativeIdentity): - # Convert arctan(1) to pi/4 - return Constant(0.25 * math.pi) - - else: - return ArcTan(clean_a) + def __eq__(self, other): + if isinstance(other, Inv): + return other.a == self.a class BinaryOperation(Operation): def __init__(self, a: Operation, b: Operation): @@ -666,28 +489,30 @@ def _clean_ab(self, a, b): raise NotImplementedError("_clean_ab not implemented") def _serialise_parameters(self) -> dict[str, Any]: - return {"a": self.a._serialise_json(), "b": self.b._serialise_json()} - - @classmethod - def _deserialise(cls, parameters: dict) -> "BinaryOperation": - return cls(*BinaryOperation._deserialise_ab(parameters)) + return {"a": self.a._serialise_json(), + "b": self.b._serialise_json()} @staticmethod def _deserialise_ab(parameters) -> tuple[Operation, Operation]: - return (Operation.deserialise_json(parameters["a"]), Operation.deserialise_json(parameters["b"])) + return (Operation.deserialise_json(parameters["a"]), + Operation.deserialise_json(parameters["b"])) + def _summary_components(self) -> list["Operation"]: return [self.a, self.b] + def _self_cls(self) -> type: + """ Own class""" def __eq__(self, other): - if isinstance(other, self.__class__): - return self.a == other.a and self.b == other.b - return False - + if isinstance(other, self._self_cls()): + return other.a == self.a and self.b == other.b class Add(BinaryOperation): + serialisation_name = "add" + def _self_cls(self) -> type: + return Add def evaluate(self, variables: dict[int, T]) -> T: return self.a.evaluate(variables) + self.b.evaluate(variables) @@ -695,6 +520,7 @@ def _derivative(self, hash_value: int) -> Operation: return Add(self.a._derivative(hash_value), 
self.b._derivative(hash_value)) def _clean_ab(self, a, b): + if isinstance(a, AdditiveIdentity): # Convert 0 + b to b return b @@ -725,10 +551,20 @@ def _clean_ab(self, a, b): else: return Add(a, b) + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Add(*BinaryOperation._deserialise_ab(parameters)) + + def _summary_open(self): + return "Add" class Sub(BinaryOperation): + serialisation_name = "sub" + + def _self_cls(self) -> type: + return Sub def evaluate(self, variables: dict[int, T]) -> T: return self.a.evaluate(variables) - self.b.evaluate(variables) @@ -766,10 +602,21 @@ def _clean_ab(self, a, b): else: return Sub(a, b) + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Sub(*BinaryOperation._deserialise_ab(parameters)) + + + def _summary_open(self): + return "Sub" class Mul(BinaryOperation): + serialisation_name = "mul" + + def _self_cls(self) -> type: + return Mul def evaluate(self, variables: dict[int, T]) -> T: return self.a.evaluate(variables) * self.b.evaluate(variables) @@ -777,6 +624,7 @@ def _derivative(self, hash_value: int) -> Operation: return Add(Mul(self.a, self.b._derivative(hash_value)), Mul(self.a._derivative(hash_value), self.b)) def _clean_ab(self, a, b): + if isinstance(a, AdditiveIdentity) or isinstance(b, AdditiveIdentity): # Convert 0*b or a*0 to 0 return AdditiveIdentity() @@ -824,17 +672,28 @@ def _clean_ab(self, a, b): return Mul(a, b) + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Mul(*BinaryOperation._deserialise_ab(parameters)) + + + def _summary_open(self): + return "Mul" + class Div(BinaryOperation): + serialisation_name = "div" + + def _self_cls(self) -> type: + return Div + def evaluate(self, variables: dict[int, T]) -> T: return self.a.evaluate(variables) / self.b.evaluate(variables) def _derivative(self, hash_value: int) -> Operation: - return Div( - Sub(Mul(self.a.derivative(hash_value), self.b), Mul(self.a, self.b.derivative(hash_value))), - 
Mul(self.b, self.b), - ) + return Sub(Div(self.a.derivative(hash_value), self.b), + Div(Mul(self.a, self.b.derivative(hash_value)), Mul(self.b, self.b))) def _clean_ab(self, a, b): if isinstance(a, AdditiveIdentity): @@ -853,6 +712,7 @@ def _clean_ab(self, a, b): # Convert constants "a"/"b" to "a/b" return Constant(self.a.evaluate({}) / self.b.evaluate({}))._clean() + elif isinstance(a, Inv) and isinstance(b, Inv): return Div(b.a, a.a) @@ -878,56 +738,15 @@ def _clean_ab(self, a, b): return Div(a, b) -class Log(Operation): - serialisation_name = "log" - - def __init__(self, a: Operation, base: float): - self.a = a - self.base = base - - def evaluate(self, variables: dict[int, T]) -> T: - return math.log(self.a.evaluate(variables), self.base) - - def _derivative(self, hash_value: int) -> Operation: - return Div(self.a.derivative(hash_value), Mul(self.a, Ln(Constant(self.base)))) - - def _clean_ab(self) -> Operation: - a = self.a._clean() - - if isinstance(a, MultiplicativeIdentity): - # Convert log(1) to 0 - return AdditiveIdentity() - - elif a == self.base: - # Convert log(base) to 1 - return MultiplicativeIdentity() - - else: - return Log(a, self.base) - - def _serialise_parameters(self) -> dict[str, Any]: - return {"a": Operation._serialise_json(self.a), "base": self.base} - @staticmethod def _deserialise(parameters: dict) -> "Operation": - return Log(Operation.deserialise_json(parameters["a"]), parameters["base"]) - - def summary(self, indent_amount: int = 0, indent=" "): - return ( - f"{indent_amount * indent}Log(\n" - + self.a.summary(indent_amount + 1, indent) - + "\n" - + f"{(indent_amount + 1) * indent}{self.base}\n" - + f"{indent_amount * indent})" - ) - - def __eq__(self, other): - if isinstance(other, Log): - return self.a == other.a and self.base == other.base - return False + return Div(*BinaryOperation._deserialise_ab(parameters)) + def _summary_open(self): + return "Div" class Pow(Operation): + serialisation_name = "pow" def __init__(self, a: 
Operation, power: float): @@ -945,9 +764,9 @@ def _derivative(self, hash_value: int) -> Operation: return self.a._derivative(hash_value) else: - return Mul(Constant(self.power), Mul(Pow(self.a, self.power - 1), self.a._derivative(hash_value))) + return Mul(Constant(self.power), Mul(Pow(self.a, self.power-1), self.a._derivative(hash_value))) - def _clean(self): + def _clean(self) -> Operation: a = self.a._clean() if self.power == 1: @@ -962,35 +781,33 @@ def _clean(self): else: return Pow(a, self.power) + def _serialise_parameters(self) -> dict[str, Any]: - return {"a": Operation._serialise_json(self.a), "power": self.power} + return {"a": Operation._serialise_json(self.a), + "power": self.power} @staticmethod def _deserialise(parameters: dict) -> "Operation": return Pow(Operation.deserialise_json(parameters["a"]), parameters["power"]) - def summary(self, indent_amount: int = 0, indent=" "): - return ( - f"{indent_amount * indent}Pow(\n" - + self.a.summary(indent_amount + 1, indent) - + "\n" - + f"{(indent_amount + 1) * indent}{self.power}\n" - + f"{indent_amount * indent})" - ) + def summary(self, indent_amount: int=0, indent=" "): + return (f"{indent_amount*indent}Pow\n" + + self.a.summary(indent_amount+1, indent) + "\n" + + f"{(indent_amount+1)*indent}{self.power}\n" + + f"{indent_amount*indent})") def __eq__(self, other): if isinstance(other, Pow): return self.a == other.a and self.power == other.power - return False + # # Matrix operations # - class Transpose(Operation): - """Transpose operation - as per numpy""" + """ Transpose operation - as per numpy""" serialisation_name = "transpose" @@ -1002,50 +819,44 @@ def evaluate(self, variables: dict[int, T]) -> T: return np.transpose(self.a.evaluate(variables)) def _derivative(self, hash_value: int) -> Operation: - return Transpose(self.a.derivative(hash_value)) # TODO: Check! + return Transpose(self.a.derivative(hash_value)) # TODO: Check! 
def _clean(self): clean_a = self.a._clean() return Transpose(clean_a) + def _serialise_parameters(self) -> dict[str, Any]: if self.axes is None: - return {"a": self.a._serialise_json()} + return { "a": self.a._serialise_json() } else: - return {"a": self.a._serialise_json(), "axes": list(self.axes)} + return { + "a": self.a._serialise_json(), + "axes": list(self.axes) + } + @staticmethod def _deserialise(parameters: dict) -> "Operation": if "axes" in parameters: - return Transpose(a=Operation.deserialise_json(parameters["a"]), axes=tuple(parameters["axes"])) + return Transpose( + a=Operation.deserialise_json(parameters["a"]), + axes=tuple(parameters["axes"])) else: - return Transpose(a=Operation.deserialise_json(parameters["a"])) + return Transpose( + a=Operation.deserialise_json(parameters["a"])) - def summary(self, indent_amount: int = 0, indent=" "): - if self.axes is None: - return ( - f"{indent_amount * indent}Transpose(\n" - + self.a.summary(indent_amount + 1, indent) - + "\n" - + f"{indent_amount * indent})" - ) - else: - return ( - f"{indent_amount * indent}Transpose(\n" - + self.a.summary(indent_amount + 1, indent) - + "\n" - + f"{(indent_amount + 1) * indent}{self.axes}\n" - + f"{indent_amount * indent})" - ) + + def _summary_open(self): + return "Transpose" def __eq__(self, other): if isinstance(other, Transpose): return other.a == self.a - return False class Dot(BinaryOperation): - """Dot product - backed by numpy's dot method""" + """ Dot product - backed by numpy's dot method""" serialisation_name = "dot" @@ -1053,15 +864,27 @@ def evaluate(self, variables: dict[int, T]) -> T: return dot(self.a.evaluate(variables), self.b.evaluate(variables)) def _derivative(self, hash_value: int) -> Operation: - return Add(Dot(self.a, self.b._derivative(hash_value)), Dot(self.a._derivative(hash_value), self.b)) + return Add( + Dot(self.a, + self.b._derivative(hash_value)), + Dot(self.a._derivative(hash_value), + self.b)) def _clean_ab(self, a, b): - return Dot(a, b) 
# Do nothing for now + return Dot(a, b) # Do nothing for now + + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Dot(*BinaryOperation._deserialise_ab(parameters)) + + def _summary_open(self): + return "Dot" # TODO: Add to base operation class, and to quantities class MatMul(BinaryOperation): - """Matrix multiplication, using __matmul__ dunder""" + """ Matrix multiplication, using __matmul__ dunder""" serialisation_name = "matmul" @@ -1069,9 +892,14 @@ def evaluate(self, variables: dict[int, T]) -> T: return self.a.evaluate(variables) @ self.b.evaluate(variables) def _derivative(self, hash_value: int) -> Operation: - return Add(MatMul(self.a, self.b._derivative(hash_value)), MatMul(self.a._derivative(hash_value), self.b)) + return Add( + MatMul(self.a, + self.b._derivative(hash_value)), + MatMul(self.a._derivative(hash_value), + self.b)) def _clean_ab(self, a, b): + if isinstance(a, AdditiveIdentity) or isinstance(b, AdditiveIdentity): # Convert 0*b or a*0 to 0 return AdditiveIdentity() @@ -1089,6 +917,13 @@ def _clean_ab(self, a, b): return MatMul(a, b) + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return MatMul(*BinaryOperation._deserialise_ab(parameters)) + + def _summary_open(self): + return "MatMul" + class TensorDot(Operation): serialisation_name = "tensor_product" @@ -1101,59 +936,39 @@ def __init__(self, a: Operation, b: Operation, a_index: int, b_index: int): def evaluate(self, variables: dict[int, T]) -> T: return tensordot(self.a, self.b, self.a_index, self.b_index) + def _serialise_parameters(self) -> dict[str, Any]: return { "a": self.a._serialise_json(), "b": self.b._serialise_json(), "a_index": self.a_index, - "b_index": self.b_index, - } + "b_index": self.b_index } @staticmethod def _deserialise(parameters: dict) -> "Operation": - return TensorDot( - a=Operation.deserialise_json(parameters["a"]), - b=Operation.deserialise_json(parameters["b"]), - a_index=int(parameters["a_index"]), - 
b_index=int(parameters["b_index"]), - ) - - -_serialisable_classes = [ - AdditiveIdentity, - MultiplicativeIdentity, - Constant, - Variable, - Neg, - Inv, - Ln, - Exp, - Sin, - ArcSin, - Cos, - ArcCos, - Tan, - ArcTan, - Add, - Sub, - Mul, - Div, - Pow, - Log, - Transpose, - Dot, - MatMul, - TensorDot, -] - -_serialisation_lookup = {class_.serialisation_name: class_ for class_ in _serialisable_classes} + return TensorDot(a = Operation.deserialise_json(parameters["a"]), + b = Operation.deserialise_json(parameters["b"]), + a_index=int(parameters["a_index"]), + b_index=int(parameters["b_index"])) + + def _summary_open(self): + return "TensorProduct" + + +_serialisable_classes = [AdditiveIdentity, MultiplicativeIdentity, Constant, + Variable, + Neg, Inv, + Add, Sub, Mul, Div, Pow, + Transpose, Dot, MatMul, TensorDot] + +_serialisation_lookup = {cls.serialisation_name: cls for cls in _serialisable_classes} class UnitError(Exception): """Errors caused by unit specification not being correct""" - def hash_data_via_numpy(*data: ArrayLike): + md5_hash = hashlib.md5() for datum in data: @@ -1164,6 +979,7 @@ def hash_data_via_numpy(*data: ArrayLike): return int(md5_hash.hexdigest(), 16) + ##################################### # # # # @@ -1175,11 +991,12 @@ def hash_data_via_numpy(*data: ArrayLike): ##################################### + QuantityType = TypeVar("QuantityType") class QuantityHistory: - """Class that holds the information for keeping track of operations done on quantities""" + """ Class that holds the information for keeping track of operations done on quantities """ def __init__(self, operation_tree: Operation, references: dict[int, "Quantity"]): self.operation_tree = operation_tree @@ -1189,17 +1006,17 @@ def __init__(self, operation_tree: Operation, references: dict[int, "Quantity"]) self.si_reference_values = {key: self.references[key].in_si() for key in self.references} def jacobian(self) -> list[Operation]: - """Derivative of this quantity's operation 
history with respect to each of the references""" + """ Derivative of this quantity's operation history with respect to each of the references """ # Use the hash value to specify the variable of differentiation return [self.operation_tree.derivative(key) for key in self.reference_key_list] def _recalculate(self): - """Recalculate the value of this object - primary use case is for testing""" + """ Recalculate the value of this object - primary use case is for testing """ return self.operation_tree.evaluate(self.references) - def variance_propagate(self, quantity_units: Unit, covariances: dict[tuple[int, int] : "Quantity"] = {}): - """Do standard error propagation to calculate the uncertainties associated with this quantity + def variance_propagate(self, quantity_units: Unit, covariances: dict[tuple[int, int]: "Quantity"] = {}): + """ Do standard error propagation to calculate the uncertainties associated with this quantity :param quantity_units: units in which the output should be calculated :param covariances: off diagonal entries for the covariance matrix @@ -1223,16 +1040,15 @@ def variance_propagate(self, quantity_units: Unit, covariances: dict[tuple[int, return output + @staticmethod def variable(quantity: "Quantity"): - """Create a history that starts with the provided data""" + """ Create a history that starts with the provided data """ return QuantityHistory(Variable(quantity.hash_value), {quantity.hash_value: quantity}) @staticmethod - def apply_operation( - operation: type[Operation], *histories: "QuantityHistory", **extra_parameters - ) -> "QuantityHistory": - """Apply an operation to the history + def apply_operation(operation: type[Operation], *histories: "QuantityHistory", **extra_parameters) -> "QuantityHistory": + """ Apply an operation to the history This is slightly unsafe as it is possible to attempt to apply an n-ary operation to a number of trees other than n, but it is relatively concise. 
Because it is concise we'll go with this for now and see if it causes @@ -1248,8 +1064,8 @@ def apply_operation( references.update(history.references) return QuantityHistory( - operation(*[history.operation_tree for history in histories], **extra_parameters), references - ) + operation(*[history.operation_tree for history in histories], **extra_parameters), + references) def has_variance(self): for key in self.references: @@ -1259,9 +1075,10 @@ def has_variance(self): return False def summary(self): + variable_strings = [self.references[key].string_repr for key in self.references] - s = "Variables: " + ",".join(variable_strings) + s = "Variables: "+",".join(variable_strings) s += "\n" s += self.operation_tree.summary() @@ -1269,15 +1086,14 @@ def summary(self): class Quantity[QuantityType]: - def __init__( - self, - value: QuantityType, - units: Unit, - standard_error: QuantityType | None = None, - hash_seed="", - name="", - id_header="", - ): + + + def __init__(self, + value: QuantityType, + units: Unit, + standard_error: QuantityType | None = None, + hash_seed = ""): + self.value = value """ Numerical value of this data, in the specified units""" @@ -1296,29 +1112,22 @@ def __init__( if standard_error is None: self.hash_value = hash_data_via_numpy(hash_seed, value) else: - self._variance = standard_error**2 + self._variance = standard_error ** 2 self.hash_value = hash_data_via_numpy(hash_seed, value, standard_error) self.history = QuantityHistory.variable(self) - self._id_header = id_header - self.name = name - # TODO: Adding this method as a temporary measure but we need a single # method that does this. 
- def with_standard_error(self, standard_error: "Quantity"): + def with_standard_error(self, standard_error: Quantity): if standard_error.units.equivalent(self.units): - return Quantity( + return NamedQuantity( value=self.value, units=self.units, - standard_error=standard_error.in_units_of(self.units), - name=self.name, - id_header=self._id_header, - ) + standard_error=standard_error.in_units_of(self.units),) else: - raise UnitError( - f"Standard error units ({standard_error.units}) are not compatible with value units ({self.units})" - ) + raise UnitError(f"Standard error units ({standard_error.units}) " + f"are not compatible with value units ({self.units})") @property def has_variance(self): @@ -1326,37 +1135,17 @@ def has_variance(self): @property def variance(self) -> "Quantity": - """Get the variance of this object""" + """ Get the variance of this object""" if self._variance is None: - return Quantity(np.zeros_like(self.value), self.units**2, name=self.name, id_header=self._id_header) + return Quantity(np.zeros_like(self.value), self.units**2) else: return Quantity(self._variance, self.units**2) - def _base62_hash(self) -> str: - """Encode the hash_value in base62 for better readability""" - hashed = "" - current_hash = self.hash_value - while current_hash: - digit = current_hash % 62 - if digit < 10: - hashed = f"{digit}{hashed}" - elif digit < 36: - hashed = f"{chr(55 + digit)}{hashed}" - else: - hashed = f"{chr(61 + digit)}{hashed}" - current_hash = (current_hash - digit) // 62 - return hashed - - @property - def unique_id(self) -> str: - """Get a human readable unique id for a data set""" - return f"{self._id_header}:{self.name}:{self._base62_hash()}" - def standard_deviation(self) -> "Quantity": - return self.variance**0.5 + return self.variance ** 0.5 def in_units_of(self, units: Unit) -> QuantityType: - """Get this quantity in other units""" + """ Get this quantity in other units """ if self.units.equivalent(units): return (self.units.scale / 
units.scale) * self.value else: @@ -1364,16 +1153,13 @@ def in_units_of(self, units: Unit) -> QuantityType: def to_units_of(self, new_units: Unit) -> "Quantity[QuantityType]": new_value, new_error = self.in_units_of_with_standard_error(new_units) - return Quantity( - value=new_value, - units=new_units, - standard_error=new_error, - hash_seed=self._hash_seed, - id_header=self._id_header, - ) + return Quantity(value=new_value, + units=new_units, + standard_error=new_error, + hash_seed=self._hash_seed) def variance_in_units_of(self, units: Unit) -> QuantityType: - """Get the variance of quantity in other units""" + """ Get the variance of quantity in other units """ variance = self.variance if variance.units.equivalent(units): return (variance.units.scale / units.scale) * variance @@ -1389,6 +1175,7 @@ def in_units_of_with_standard_error(self, units): units_squared = units**2 if variance.units.equivalent(units_squared): + return self.in_units_of(units), np.sqrt(self.variance.in_units_of(units_squared)) else: raise UnitError(f"Target units ({units}) not compatible with existing units ({variance.units}).") @@ -1399,116 +1186,127 @@ def in_si_with_standard_error(self): else: return self.in_si(), None - def explicitly_formatted(self, unit_string: str) -> str: - """Returns quantity as a string with specific unit formatting - - Performs any necessary unit conversions, but maintains the exact unit - formatting provided by the user. 
This can be useful if you have a - power expressed in horsepower and you want it expressed as "745.7 N m/s" and not as "745.7 W".""" - unit = parse_unit(unit_string) - quantity = self.in_units_of(unit) - return f"{quantity} {unit_string}" - - def __eq__(self: Self, other: Self) -> bool | np.ndarray: - return self.value == other.in_units_of(self.units) - - def __mul__(self: Self, other: ArrayLike | Self) -> Self: + def __mul__(self: Self, other: ArrayLike | Self ) -> Self: if isinstance(other, Quantity): return DerivedQuantity( self.value * other.value, self.units * other.units, - history=QuantityHistory.apply_operation(Mul, self.history, other.history), - ) + history=QuantityHistory.apply_operation(Mul, self.history, other.history)) else: - return DerivedQuantity( - self.value * other, - self.units, - QuantityHistory(Mul(self.history.operation_tree, Constant(other)), self.history.references), - ) + return DerivedQuantity(self.value * other, self.units, + QuantityHistory( + Mul( + self.history.operation_tree, + Constant(other)), + self.history.references)) def __rmul__(self: Self, other: ArrayLike | Self): if isinstance(other, Quantity): return DerivedQuantity( - other.value * self.value, - other.units * self.units, - history=QuantityHistory.apply_operation(Mul, other.history, self.history), - ) + other.value * self.value, + other.units * self.units, + history=QuantityHistory.apply_operation( + Mul, + other.history, + self.history)) else: - return DerivedQuantity( - other * self.value, - self.units, - QuantityHistory(Mul(Constant(other), self.history.operation_tree), self.history.references), - ) + return DerivedQuantity(other * self.value, self.units, + QuantityHistory( + Mul( + Constant(other), + self.history.operation_tree), + self.history.references)) + def __matmul__(self, other: ArrayLike | Self): if isinstance(other, Quantity): return DerivedQuantity( self.value @ other.value, self.units * other.units, - history=QuantityHistory.apply_operation(MatMul, 
self.history, other.history), - ) + history=QuantityHistory.apply_operation( + MatMul, + self.history, + other.history)) else: return DerivedQuantity( - self.value @ other, - self.units, - QuantityHistory(MatMul(self.history.operation_tree, Constant(other)), self.history.references), - ) + self.value @ other, + self.units, + QuantityHistory( + MatMul( + self.history.operation_tree, + Constant(other)), + self.history.references)) def __rmatmul__(self, other: ArrayLike | Self): if isinstance(other, Quantity): return DerivedQuantity( - other.value @ self.value, - other.units * self.units, - history=QuantityHistory.apply_operation(MatMul, other.history, self.history), - ) + other.value @ self.value, + other.units * self.units, + history=QuantityHistory.apply_operation( + MatMul, + other.history, + self.history)) else: - return DerivedQuantity( - other @ self.value, - self.units, - QuantityHistory(MatMul(Constant(other), self.history.operation_tree), self.history.references), - ) + return DerivedQuantity(other @ self.value, self.units, + QuantityHistory( + MatMul( + Constant(other), + self.history.operation_tree), + self.history.references)) + def __truediv__(self: Self, other: float | Self) -> Self: if isinstance(other, Quantity): return DerivedQuantity( - self.value / other.value, - self.units / other.units, - history=QuantityHistory.apply_operation(Div, self.history, other.history), - ) + self.value / other.value, + self.units / other.units, + history=QuantityHistory.apply_operation( + Div, + self.history, + other.history)) else: - return DerivedQuantity( - self.value / other, - self.units, - QuantityHistory(Div(Constant(other), self.history.operation_tree), self.history.references), - ) + return DerivedQuantity(self.value / other, self.units, + QuantityHistory( + Div( + Constant(other), + self.history.operation_tree), + self.history.references)) def __rtruediv__(self: Self, other: float | Self) -> Self: if isinstance(other, Quantity): return DerivedQuantity( - 
other.value / self.value, - other.units / self.units, - history=QuantityHistory.apply_operation(Div, other.history, self.history), - ) + other.value / self.value, + other.units / self.units, + history=QuantityHistory.apply_operation( + Div, + other.history, + self.history + )) else: return DerivedQuantity( - other / self.value, - self.units**-1, - QuantityHistory(Div(Constant(other), self.history.operation_tree), self.history.references), - ) + other / self.value, + self.units ** -1, + QuantityHistory( + Div( + Constant(other), + self.history.operation_tree), + self.history.references)) def __add__(self: Self, other: Self | ArrayLike) -> Self: if isinstance(other, Quantity): if self.units.equivalent(other.units): return DerivedQuantity( - self.value + (other.value * other.units.scale) / self.units.scale, - self.units, - QuantityHistory.apply_operation(Add, self.history, other.history), - ) + self.value + (other.value * other.units.scale) / self.units.scale, + self.units, + QuantityHistory.apply_operation( + Add, + self.history, + other.history)) else: raise UnitError(f"Units do not have the same dimensionality: {self.units} vs {other.units}") @@ -1518,7 +1316,11 @@ def __add__(self: Self, other: Self | ArrayLike) -> Self: # Don't need __radd__ because only quantity/quantity operations should be allowed def __neg__(self): - return DerivedQuantity(-self.value, self.units, QuantityHistory.apply_operation(Neg, self.history)) + return DerivedQuantity(-self.value, self.units, + QuantityHistory.apply_operation( + Neg, + self.history + )) def __sub__(self: Self, other: Self | ArrayLike) -> Self: return self + (-other) @@ -1527,15 +1329,17 @@ def __rsub__(self: Self, other: Self | ArrayLike) -> Self: return (-self) + other def __pow__(self: Self, other: int | float): - return DerivedQuantity( - self.value**other, - self.units**other, - QuantityHistory(Pow(self.history.operation_tree, other), self.history.references), - ) + return DerivedQuantity(self.value ** other, + 
self.units ** other, + QuantityHistory( + Pow( + self.history.operation_tree, + other), + self.history.references)) @staticmethod def _array_repr_format(arr: np.ndarray): - """Format the array""" + """ Format the array """ order = len(arr.shape) reshaped = arr.reshape(-1) if len(reshaped) <= 2: @@ -1548,10 +1352,12 @@ def _array_repr_format(arr: np.ndarray): # else: # numbers = f"{reshaped[0]}, {reshaped[1]} ... {reshaped[-2]}, {reshaped[-1]}" - return "[" * order + numbers + "]" * order + return "["*order + numbers + "]"*order def __repr__(self): + if isinstance(self.units, NamedUnit): + value = self.value error = self.standard_deviation().in_units_of(self.units) unit_string = self.units.symbol @@ -1582,25 +1388,26 @@ def parse(number_or_string: str | ArrayLike, unit: str, absolute_temperature: Fa def string_repr(self): return str(self.hash_value) - def as_h5(self, group: h5py.Group, name: str): - """Add this data onto a group as a dataset under the given name""" - boxed = self.value if type(self.value) is np.ndarray else [self.value] - data = group.create_dataset(name, data=boxed) - data.attrs["units"] = self.units.ascii_symbol - class NamedQuantity[QuantityType](Quantity[QuantityType]): - def __init__( - self, name: str, value: QuantityType, units: Unit, standard_error: QuantityType | None = None, id_header="" - ): - super().__init__(value, units, standard_error=standard_error, hash_seed=name, name=name, id_header=id_header) + def __init__(self, + name: str, + value: QuantityType, + units: Unit, + standard_error: QuantityType | None = None): + + super().__init__(value, units, standard_error=standard_error, hash_seed=name) + self.name = name def __repr__(self): return f"[{self.name}] " + super().__repr__() def to_units_of(self, new_units: Unit) -> "NamedQuantity[QuantityType]": new_value, new_error = self.in_units_of_with_standard_error(new_units) - return NamedQuantity(value=new_value, units=new_units, standard_error=new_error, name=self.name) + return 
NamedQuantity(value=new_value, + units=new_units, + standard_error=new_error, + name=self.name) def with_standard_error(self, standard_error: Quantity): if standard_error.units.equivalent(self.units): @@ -1608,20 +1415,17 @@ def with_standard_error(self, standard_error: Quantity): value=self.value, units=self.units, standard_error=standard_error.in_units_of(self.units), - name=self.name, - id_header=self._id_header, - ) + name=self.name) else: - raise UnitError( - f"Standard error units ({standard_error.units}) are not compatible with value units ({self.units})" - ) + raise UnitError(f"Standard error units ({standard_error.units}) " + f"are not compatible with value units ({self.units})") + @property def string_repr(self): return self.name - class DerivedQuantity[QuantityType](Quantity[QuantityType]): def __init__(self, value: QuantityType, units: Unit, history: QuantityHistory): super().__init__(value, units, standard_error=None) @@ -1630,9 +1434,13 @@ def __init__(self, value: QuantityType, units: Unit, history: QuantityHistory): self._variance_cache = None self._has_variance = history.has_variance() + def to_units_of(self, new_units: Unit) -> "Quantity[QuantityType]": # TODO: Lots of tests needed for this - return DerivedQuantity(value=self.in_units_of(new_units), units=new_units, history=self.history) + return DerivedQuantity( + value=self.in_units_of(new_units), + units=new_units, + history=self.history) @property def has_variance(self): diff --git a/sasdata/quantities/si.py b/sasdata/quantities/si.py index 45961b0bf..6e2b77169 100644 --- a/sasdata/quantities/si.py +++ b/sasdata/quantities/si.py @@ -5,75 +5,75 @@ Do not edit by hand, instead edit the files that build it (_build_tables.py) - - -DDDDDDDDDDDDD NNNNNNNN NNNNNNNN tttt -D::::::::::::DDD N:::::::N N::::::N ttt:::t -D:::::::::::::::DD N::::::::N N::::::N t:::::t -DDD:::::DDDDD:::::D N:::::::::N N::::::N t:::::t - D:::::D D:::::D ooooooooooo N::::::::::N N::::::N ooooooooooo ttttttt:::::ttttttt - 
D:::::D D:::::D oo:::::::::::oo N:::::::::::N N::::::N oo:::::::::::oo t:::::::::::::::::t - D:::::D D:::::Do:::::::::::::::o N:::::::N::::N N::::::No:::::::::::::::ot:::::::::::::::::t - D:::::D D:::::Do:::::ooooo:::::o N::::::N N::::N N::::::No:::::ooooo:::::otttttt:::::::tttttt - D:::::D D:::::Do::::o o::::o N::::::N N::::N:::::::No::::o o::::o t:::::t - D:::::D D:::::Do::::o o::::o N::::::N N:::::::::::No::::o o::::o t:::::t - D:::::D D:::::Do::::o o::::o N::::::N N::::::::::No::::o o::::o t:::::t - D:::::D D:::::D o::::o o::::o N::::::N N:::::::::No::::o o::::o t:::::t tttttt -DDD:::::DDDDD:::::D o:::::ooooo:::::o N::::::N N::::::::No:::::ooooo:::::o t::::::tttt:::::t -D:::::::::::::::DD o:::::::::::::::o N::::::N N:::::::No:::::::::::::::o tt::::::::::::::t -D::::::::::::DDD oo:::::::::::oo N::::::N N::::::N oo:::::::::::oo tt:::::::::::tt -DDDDDDDDDDDDD ooooooooooo NNNNNNNN NNNNNNN ooooooooooo ttttttttttt - - - - - - - - - dddddddd -EEEEEEEEEEEEEEEEEEEEEE d::::::d iiii tttt BBBBBBBBBBBBBBBBB -E::::::::::::::::::::E d::::::d i::::i ttt:::t B::::::::::::::::B -E::::::::::::::::::::E d::::::d iiii t:::::t B::::::BBBBBB:::::B -EE::::::EEEEEEEEE::::E d:::::d t:::::t BB:::::B B:::::B + + +DDDDDDDDDDDDD NNNNNNNN NNNNNNNN tttt +D::::::::::::DDD N:::::::N N::::::N ttt:::t +D:::::::::::::::DD N::::::::N N::::::N t:::::t +DDD:::::DDDDD:::::D N:::::::::N N::::::N t:::::t + D:::::D D:::::D ooooooooooo N::::::::::N N::::::N ooooooooooo ttttttt:::::ttttttt + D:::::D D:::::D oo:::::::::::oo N:::::::::::N N::::::N oo:::::::::::oo t:::::::::::::::::t + D:::::D D:::::Do:::::::::::::::o N:::::::N::::N N::::::No:::::::::::::::ot:::::::::::::::::t + D:::::D D:::::Do:::::ooooo:::::o N::::::N N::::N N::::::No:::::ooooo:::::otttttt:::::::tttttt + D:::::D D:::::Do::::o o::::o N::::::N N::::N:::::::No::::o o::::o t:::::t + D:::::D D:::::Do::::o o::::o N::::::N N:::::::::::No::::o o::::o t:::::t + D:::::D D:::::Do::::o o::::o N::::::N N::::::::::No::::o o::::o t:::::t + D:::::D D:::::D 
o::::o o::::o N::::::N N:::::::::No::::o o::::o t:::::t tttttt +DDD:::::DDDDD:::::D o:::::ooooo:::::o N::::::N N::::::::No:::::ooooo:::::o t::::::tttt:::::t +D:::::::::::::::DD o:::::::::::::::o N::::::N N:::::::No:::::::::::::::o tt::::::::::::::t +D::::::::::::DDD oo:::::::::::oo N::::::N N::::::N oo:::::::::::oo tt:::::::::::tt +DDDDDDDDDDDDD ooooooooooo NNNNNNNN NNNNNNN ooooooooooo ttttttttttt + + + + + + + + + dddddddd +EEEEEEEEEEEEEEEEEEEEEE d::::::d iiii tttt BBBBBBBBBBBBBBBBB +E::::::::::::::::::::E d::::::d i::::i ttt:::t B::::::::::::::::B +E::::::::::::::::::::E d::::::d iiii t:::::t B::::::BBBBBB:::::B +EE::::::EEEEEEEEE::::E d:::::d t:::::t BB:::::B B:::::B E:::::E EEEEEE ddddddddd:::::d iiiiiiittttttt:::::ttttttt B::::B B:::::Byyyyyyy yyyyyyy - E:::::E dd::::::::::::::d i:::::it:::::::::::::::::t B::::B B:::::B y:::::y y:::::y - E::::::EEEEEEEEEE d::::::::::::::::d i::::it:::::::::::::::::t B::::BBBBBB:::::B y:::::y y:::::y - E:::::::::::::::E d:::::::ddddd:::::d i::::itttttt:::::::tttttt B:::::::::::::BB y:::::y y:::::y - E:::::::::::::::E d::::::d d:::::d i::::i t:::::t B::::BBBBBB:::::B y:::::y y:::::y - E::::::EEEEEEEEEE d:::::d d:::::d i::::i t:::::t B::::B B:::::B y:::::y y:::::y - E:::::E d:::::d d:::::d i::::i t:::::t B::::B B:::::B y:::::y:::::y - E:::::E EEEEEEd:::::d d:::::d i::::i t:::::t tttttt B::::B B:::::B y:::::::::y -EE::::::EEEEEEEE:::::Ed::::::ddddd::::::ddi::::::i t::::::tttt:::::t BB:::::BBBBBB::::::B y:::::::y -E::::::::::::::::::::E d:::::::::::::::::di::::::i tt::::::::::::::t B:::::::::::::::::B y:::::y -E::::::::::::::::::::E d:::::::::ddd::::di::::::i tt:::::::::::tt B::::::::::::::::B y:::::y -EEEEEEEEEEEEEEEEEEEEEE ddddddddd dddddiiiiiiii ttttttttttt BBBBBBBBBBBBBBBBB y:::::y - y:::::y - y:::::y - y:::::y - y:::::y - yyyyyyy - - - - dddddddd -HHHHHHHHH HHHHHHHHH d::::::d -H:::::::H H:::::::H d::::::d -H:::::::H H:::::::H d::::::d -HH::::::H H::::::HH d:::::d - H:::::H H:::::H aaaaaaaaaaaaa nnnn nnnnnnnn ddddddddd:::::d - 
H:::::H H:::::H a::::::::::::a n:::nn::::::::nn dd::::::::::::::d - H::::::HHHHH::::::H aaaaaaaaa:::::an::::::::::::::nn d::::::::::::::::d - H:::::::::::::::::H a::::ann:::::::::::::::nd:::::::ddddd:::::d - H:::::::::::::::::H aaaaaaa:::::a n:::::nnnn:::::nd::::::d d:::::d - H::::::HHHHH::::::H aa::::::::::::a n::::n n::::nd:::::d d:::::d - H:::::H H:::::H a::::aaaa::::::a n::::n n::::nd:::::d d:::::d - H:::::H H:::::H a::::a a:::::a n::::n n::::nd:::::d d:::::d -HH::::::H H::::::HHa::::a a:::::a n::::n n::::nd::::::ddddd::::::dd -H:::::::H H:::::::Ha:::::aaaa::::::a n::::n n::::n d:::::::::::::::::d -H:::::::H H:::::::H a::::::::::aa:::a n::::n n::::n d:::::::::ddd::::d -HHHHHHHHH HHHHHHHHH aaaaaaaaaa aaaa nnnnnn nnnnnn ddddddddd ddddd - + E:::::E dd::::::::::::::d i:::::it:::::::::::::::::t B::::B B:::::B y:::::y y:::::y + E::::::EEEEEEEEEE d::::::::::::::::d i::::it:::::::::::::::::t B::::BBBBBB:::::B y:::::y y:::::y + E:::::::::::::::E d:::::::ddddd:::::d i::::itttttt:::::::tttttt B:::::::::::::BB y:::::y y:::::y + E:::::::::::::::E d::::::d d:::::d i::::i t:::::t B::::BBBBBB:::::B y:::::y y:::::y + E::::::EEEEEEEEEE d:::::d d:::::d i::::i t:::::t B::::B B:::::B y:::::y y:::::y + E:::::E d:::::d d:::::d i::::i t:::::t B::::B B:::::B y:::::y:::::y + E:::::E EEEEEEd:::::d d:::::d i::::i t:::::t tttttt B::::B B:::::B y:::::::::y +EE::::::EEEEEEEE:::::Ed::::::ddddd::::::ddi::::::i t::::::tttt:::::t BB:::::BBBBBB::::::B y:::::::y +E::::::::::::::::::::E d:::::::::::::::::di::::::i tt::::::::::::::t B:::::::::::::::::B y:::::y +E::::::::::::::::::::E d:::::::::ddd::::di::::::i tt:::::::::::tt B::::::::::::::::B y:::::y +EEEEEEEEEEEEEEEEEEEEEE ddddddddd dddddiiiiiiii ttttttttttt BBBBBBBBBBBBBBBBB y:::::y + y:::::y + y:::::y + y:::::y + y:::::y + yyyyyyy + + + + dddddddd +HHHHHHHHH HHHHHHHHH d::::::d +H:::::::H H:::::::H d::::::d +H:::::::H H:::::::H d::::::d +HH::::::H H::::::HH d:::::d + H:::::H H:::::H aaaaaaaaaaaaa nnnn nnnnnnnn ddddddddd:::::d + H:::::H H:::::H 
a::::::::::::a n:::nn::::::::nn dd::::::::::::::d + H::::::HHHHH::::::H aaaaaaaaa:::::an::::::::::::::nn d::::::::::::::::d + H:::::::::::::::::H a::::ann:::::::::::::::nd:::::::ddddd:::::d + H:::::::::::::::::H aaaaaaa:::::a n:::::nnnn:::::nd::::::d d:::::d + H::::::HHHHH::::::H aa::::::::::::a n::::n n::::nd:::::d d:::::d + H:::::H H:::::H a::::aaaa::::::a n::::n n::::nd:::::d d:::::d + H:::::H H:::::H a::::a a:::::a n::::n n::::nd:::::d d:::::d +HH::::::H H::::::HHa::::a a:::::a n::::n n::::nd::::::ddddd::::::dd +H:::::::H H:::::::Ha:::::aaaa::::::a n::::n n::::n d:::::::::::::::::d +H:::::::H H:::::::H a::::::::::aa:::a n::::n n::::n d:::::::::ddd::::d +HHHHHHHHH HHHHHHHHH aaaaaaaaaa aaaa nnnnnn nnnnnn ddddddddd ddddd + """ @@ -100,22 +100,22 @@ ) all_si = [ + meters, + seconds, amperes, - coulombs, - farads, - henry, - hertz, - joules, kelvin, - kilograms, - meters, + hertz, newtons, - ohms, pascals, - seconds, - siemens, - tesla, - volts, + joules, watts, + coulombs, + volts, + ohms, + farads, + siemens, webers, + tesla, + henry, + kilograms, ] diff --git a/sasdata/quantities/test_numerical_encoding.py b/sasdata/quantities/test_numerical_encoding.py index b7fb7cfed..80cfbad9a 100644 --- a/sasdata/quantities/test_numerical_encoding.py +++ b/sasdata/quantities/test_numerical_encoding.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from sasdata.quantities.numerical_encoding import numerical_decode, numerical_encode +from sasdata.quantities.numerical_encoding import numerical_encode, numerical_decode @pytest.mark.parametrize("value", [-100.0, -10.0, -1.0, 0.0, 0.5, 1.0, 10.0, 100.0, 1e100]) @@ -63,4 +63,6 @@ def test_numpy_dtypes_encode_decode(dtype): ]) def test_coo_matrix_encode_decode(shape, n, m, dtype): - values = np.arange(10) + i_indices = + + values = np.arange(10) \ No newline at end of file diff --git a/sasdata/quantities/unit_formatting.py b/sasdata/quantities/unit_formatting.py index e63921329..904ce801f 100644 --- 
a/sasdata/quantities/unit_formatting.py +++ b/sasdata/quantities/unit_formatting.py @@ -1,4 +1,7 @@ + + + import numpy as np diff --git a/sasdata/quantities/unit_parser.py b/sasdata/quantities/unit_parser.py index a264489d6..26def035d 100644 --- a/sasdata/quantities/unit_parser.py +++ b/sasdata/quantities/unit_parser.py @@ -10,21 +10,16 @@ for group in all_units_groups: all_units.extend(group) - def split_unit_str(unit_str: str) -> list[str]: """Separate the letters from the numbers in unit_str""" - return findall(r"[A-Za-zΩ%Å]+|[-\d]+|/", unit_str) - + return findall(r'[A-Za-zΩ%Å]+|[-\d]+|/', unit_str) def validate_unit_str(unit_str: str) -> bool: """Validate whether unit_str is valid. This doesn't mean that the unit specified in unit_str exists but rather it only consists of letters, and numbers as a unit string should.""" - return fullmatch(r"[A-Za-zΩµ%Å^1-9⁻¹-⁹\-\+/\ \._]+", unit_str) is not None + return fullmatch(r'[A-Za-zΩ%Å^1-9\-\+/\ \.]+', unit_str) is not None - -def parse_single_unit( - unit_str: str, unit_group: UnitGroup | None = None, longest_unit: bool = True -) -> tuple[Unit | None, str]: +def parse_single_unit(unit_str: str, unit_group: UnitGroup | None = None, longest_unit: bool = True) -> tuple[Unit | None, str]: """Attempts to find a single unit for unit_str. Return this unit, and the remaining string in a tuple. If a unit cannot be parsed, the unit will be None, and the remaining string will be the entire unit_str. @@ -33,7 +28,7 @@ def parse_single_unit( If unit_group is set, it will only try to parse units within that group. This is useful for resolving ambiguities. 
""" - current_unit = "" + current_unit = '' string_pos = 0 if unit_group is None: lookup_dict = symbol_lookup @@ -41,45 +36,34 @@ def parse_single_unit( lookup_dict = dict([(name, unit) for name, unit in symbol_lookup.items() if unit in unit_group.units]) for next_char in unit_str: potential_unit_str = current_unit + next_char - potential_symbols = [ - symbol - for symbol, unit in lookup_dict.items() - if symbol.startswith(potential_unit_str) or unit.startswith(potential_unit_str) - ] + potential_symbols = [symbol for symbol in lookup_dict.keys() if symbol.startswith(potential_unit_str)] if len(potential_symbols) == 0: break string_pos += 1 current_unit = potential_unit_str - if not longest_unit and current_unit in lookup_dict: + if not longest_unit and current_unit in lookup_dict.keys(): break - if current_unit == "": + if current_unit == '': return None, unit_str - matching_types = [unit for symbol, unit in lookup_dict.items() if symbol == current_unit or unit == current_unit] - if not matching_types: - raise KeyError(f"No known type matching {current_unit}") - final_unit = matching_types[0] remaining_str = unit_str[string_pos::] - return final_unit, remaining_str - + return lookup_dict[current_unit], remaining_str -def parse_unit_strs(unit_str: str, current_units: list[Unit] | None = None, longest_unit: bool = True) -> list[Unit]: +def parse_unit_strs(unit_str: str, current_units: list[Unit] | None=None, longest_unit: bool = True) -> list[Unit]: """Recursively parse units from unit_str until no more characters are present.""" if current_units is None: current_units = [] - if unit_str == "": + if unit_str == '': return current_units parsed_unit, remaining_str = parse_single_unit(unit_str, longest_unit=longest_unit) if parsed_unit is not None: current_units += [parsed_unit] return parse_unit_strs(remaining_str, current_units, longest_unit) else: - raise ValueError(f"Could not interpret {remaining_str}") - + raise ValueError(f'Could not interpret {remaining_str}') 
# Its probably useful to work out the unit first, and then later work out if a named unit exists for it. Hence why there # are two functions. - def parse_unit_stack(unit_str: str, longest_unit: bool = True) -> list[Unit]: """Split unit_str into a stack of parsed units.""" unit_stack: list[Unit] = [] @@ -87,12 +71,12 @@ def parse_unit_stack(unit_str: str, longest_unit: bool = True) -> list[Unit]: inverse_next_unit = False for token in split_str: try: - if token == "/": + if token == '/': inverse_next_unit = True continue power = int(token) to_modify = unit_stack[-1] - modified = to_modify**power + modified = to_modify ** power # modified = unit_power(to_modify, power) unit_stack[-1] = modified except ValueError: @@ -109,34 +93,19 @@ def parse_unit_stack(unit_str: str, longest_unit: bool = True) -> list[Unit]: pass return unit_stack - -def known_mistake(unit_str: str) -> Unit | None: - """Take known broken units from historical files - and give them a reasonible parse""" - import sasdata.quantities.units as units - - mistakes = {"µm": units.micrometers, "per_centimeter": units.per_centimeter, "per_angstrom": units.per_angstrom} - if unit_str in mistakes: - return mistakes[unit_str] - return None - - def parse_unit(unit_str: str, longest_unit: bool = True) -> Unit: """Parse unit_str into a unit.""" - if result := known_mistake(unit_str): - return result try: if not validate_unit_str(unit_str): - raise ValueError(f"unit_str ({unit_str}) contains forbidden characters.") + raise ValueError('unit_str contains forbidden characters.') parsed_unit = Unit(1, Dimensions()) unit_stack = parse_unit_stack(unit_str, longest_unit) for unit in unit_stack: # parsed_unit = combine_units(parsed_unit, unit) parsed_unit *= unit return parsed_unit - except KeyError as ex: - raise ValueError(f"Unit string contains an unrecognised pattern: {unit_str}") from ex - + except KeyError: + raise ValueError('Unit string contains an unrecognised pattern.') def parse_unit_from_group(unit_str: str, 
from_group: UnitGroup) -> Unit | None: """Tries to use the given unit group to resolve ambiguities. Parse a unit twice with different options, and returns @@ -150,8 +119,7 @@ def parse_unit_from_group(unit_str: str, from_group: UnitGroup) -> Unit | None: else: return None - -def parse_named_unit(unit_string: str, rtol: float = 1e-14) -> NamedUnit: +def parse_named_unit(unit_string: str, rtol: float=1e-14) -> NamedUnit: """Parses unit into a named unit. Parses unit into a Unit if it is not already, and then finds an equivaelent named unit. Please note that this might not be the expected unit from the string itself. E.g. 'kgm/2' will become newtons. @@ -166,15 +134,14 @@ def parse_named_unit(unit_string: str, rtol: float = 1e-14) -> NamedUnit: else: return named_unit - -def find_named_unit(unit: Unit, rtol: float = 1e-14) -> NamedUnit | None: - """Find a named unit matching the one provided""" +def find_named_unit(unit: Unit, rtol: float=1e-14) -> NamedUnit | None: + """ Find a named unit matching the one provided """ dimension_hash = hash(unit.dimensions) if dimension_hash in unit_groups_by_dimension_hash: unit_group = unit_groups_by_dimension_hash[hash(unit.dimensions)] for named_unit in unit_group.units: - if abs(named_unit.scale - unit.scale) < rtol * named_unit.scale: + if abs(named_unit.scale - unit.scale) < rtol*named_unit.scale: return named_unit return None @@ -185,13 +152,14 @@ def parse_named_unit_from_group(unit_str: str, from_group: UnitGroup) -> NamedUn unit that is present in from_group is returned. 
This is useful in cases of ambiguities.""" parsed_unit = parse_unit_from_group(unit_str, from_group) if parsed_unit is None: - raise ValueError("That unit cannot be parsed from the specified group.") + raise ValueError('That unit cannot be parsed from the specified group.') return find_named_unit(parsed_unit) +def parse(string: str, + name_lookup: bool = True, + longest_unit: bool = True, + lookup_rtol: float = 1e-14): -def parse(string: str, name_lookup: bool = True, longest_unit: bool = True, lookup_rtol: float = 1e-14): - if type(string) is not str: - string = string.decode("utf-8") unit = parse_unit(string, longest_unit=longest_unit) if name_lookup: named = find_named_unit(unit, rtol=lookup_rtol) @@ -202,11 +170,11 @@ def parse(string: str, name_lookup: bool = True, longest_unit: bool = True, look if __name__ == "__main__": - to_parse = input("Enter a unit to parse: ") + to_parse = input('Enter a unit to parse: ') try: generic_unit = parse_unit(to_parse) - print(f"Generic Unit: {generic_unit}") + print(f'Generic Unit: {generic_unit}') named_unit = find_named_unit(generic_unit) - print(f"Named Unit: {named_unit}") + print(f'Named Unit: {named_unit}') except ValueError: - print("There is no named unit available.") + print('There is no named unit available.') diff --git a/sasdata/quantities/units.py b/sasdata/quantities/units.py index fe840ab85..7c2698bcd 100644 --- a/sasdata/quantities/units.py +++ b/sasdata/quantities/units.py @@ -5,75 +5,75 @@ Do not edit by hand, instead edit the files that build it (_build_tables.py, _units_base.py) - - -DDDDDDDDDDDDD NNNNNNNN NNNNNNNN tttt -D::::::::::::DDD N:::::::N N::::::N ttt:::t -D:::::::::::::::DD N::::::::N N::::::N t:::::t -DDD:::::DDDDD:::::D N:::::::::N N::::::N t:::::t - D:::::D D:::::D ooooooooooo N::::::::::N N::::::N ooooooooooo ttttttt:::::ttttttt - D:::::D D:::::D oo:::::::::::oo N:::::::::::N N::::::N oo:::::::::::oo t:::::::::::::::::t - D:::::D D:::::Do:::::::::::::::o N:::::::N::::N 
N::::::No:::::::::::::::ot:::::::::::::::::t - D:::::D D:::::Do:::::ooooo:::::o N::::::N N::::N N::::::No:::::ooooo:::::otttttt:::::::tttttt - D:::::D D:::::Do::::o o::::o N::::::N N::::N:::::::No::::o o::::o t:::::t - D:::::D D:::::Do::::o o::::o N::::::N N:::::::::::No::::o o::::o t:::::t - D:::::D D:::::Do::::o o::::o N::::::N N::::::::::No::::o o::::o t:::::t - D:::::D D:::::D o::::o o::::o N::::::N N:::::::::No::::o o::::o t:::::t tttttt -DDD:::::DDDDD:::::D o:::::ooooo:::::o N::::::N N::::::::No:::::ooooo:::::o t::::::tttt:::::t -D:::::::::::::::DD o:::::::::::::::o N::::::N N:::::::No:::::::::::::::o tt::::::::::::::t -D::::::::::::DDD oo:::::::::::oo N::::::N N::::::N oo:::::::::::oo tt:::::::::::tt -DDDDDDDDDDDDD ooooooooooo NNNNNNNN NNNNNNN ooooooooooo ttttttttttt - - - - - - - - - dddddddd -EEEEEEEEEEEEEEEEEEEEEE d::::::d iiii tttt BBBBBBBBBBBBBBBBB -E::::::::::::::::::::E d::::::d i::::i ttt:::t B::::::::::::::::B -E::::::::::::::::::::E d::::::d iiii t:::::t B::::::BBBBBB:::::B -EE::::::EEEEEEEEE::::E d:::::d t:::::t BB:::::B B:::::B + + +DDDDDDDDDDDDD NNNNNNNN NNNNNNNN tttt +D::::::::::::DDD N:::::::N N::::::N ttt:::t +D:::::::::::::::DD N::::::::N N::::::N t:::::t +DDD:::::DDDDD:::::D N:::::::::N N::::::N t:::::t + D:::::D D:::::D ooooooooooo N::::::::::N N::::::N ooooooooooo ttttttt:::::ttttttt + D:::::D D:::::D oo:::::::::::oo N:::::::::::N N::::::N oo:::::::::::oo t:::::::::::::::::t + D:::::D D:::::Do:::::::::::::::o N:::::::N::::N N::::::No:::::::::::::::ot:::::::::::::::::t + D:::::D D:::::Do:::::ooooo:::::o N::::::N N::::N N::::::No:::::ooooo:::::otttttt:::::::tttttt + D:::::D D:::::Do::::o o::::o N::::::N N::::N:::::::No::::o o::::o t:::::t + D:::::D D:::::Do::::o o::::o N::::::N N:::::::::::No::::o o::::o t:::::t + D:::::D D:::::Do::::o o::::o N::::::N N::::::::::No::::o o::::o t:::::t + D:::::D D:::::D o::::o o::::o N::::::N N:::::::::No::::o o::::o t:::::t tttttt +DDD:::::DDDDD:::::D o:::::ooooo:::::o N::::::N N::::::::No:::::ooooo:::::o 
t::::::tttt:::::t +D:::::::::::::::DD o:::::::::::::::o N::::::N N:::::::No:::::::::::::::o tt::::::::::::::t +D::::::::::::DDD oo:::::::::::oo N::::::N N::::::N oo:::::::::::oo tt:::::::::::tt +DDDDDDDDDDDDD ooooooooooo NNNNNNNN NNNNNNN ooooooooooo ttttttttttt + + + + + + + + + dddddddd +EEEEEEEEEEEEEEEEEEEEEE d::::::d iiii tttt BBBBBBBBBBBBBBBBB +E::::::::::::::::::::E d::::::d i::::i ttt:::t B::::::::::::::::B +E::::::::::::::::::::E d::::::d iiii t:::::t B::::::BBBBBB:::::B +EE::::::EEEEEEEEE::::E d:::::d t:::::t BB:::::B B:::::B E:::::E EEEEEE ddddddddd:::::d iiiiiiittttttt:::::ttttttt B::::B B:::::Byyyyyyy yyyyyyy - E:::::E dd::::::::::::::d i:::::it:::::::::::::::::t B::::B B:::::B y:::::y y:::::y - E::::::EEEEEEEEEE d::::::::::::::::d i::::it:::::::::::::::::t B::::BBBBBB:::::B y:::::y y:::::y - E:::::::::::::::E d:::::::ddddd:::::d i::::itttttt:::::::tttttt B:::::::::::::BB y:::::y y:::::y - E:::::::::::::::E d::::::d d:::::d i::::i t:::::t B::::BBBBBB:::::B y:::::y y:::::y - E::::::EEEEEEEEEE d:::::d d:::::d i::::i t:::::t B::::B B:::::B y:::::y y:::::y - E:::::E d:::::d d:::::d i::::i t:::::t B::::B B:::::B y:::::y:::::y - E:::::E EEEEEEd:::::d d:::::d i::::i t:::::t tttttt B::::B B:::::B y:::::::::y -EE::::::EEEEEEEE:::::Ed::::::ddddd::::::ddi::::::i t::::::tttt:::::t BB:::::BBBBBB::::::B y:::::::y -E::::::::::::::::::::E d:::::::::::::::::di::::::i tt::::::::::::::t B:::::::::::::::::B y:::::y -E::::::::::::::::::::E d:::::::::ddd::::di::::::i tt:::::::::::tt B::::::::::::::::B y:::::y -EEEEEEEEEEEEEEEEEEEEEE ddddddddd dddddiiiiiiii ttttttttttt BBBBBBBBBBBBBBBBB y:::::y - y:::::y - y:::::y - y:::::y - y:::::y - yyyyyyy - - - - dddddddd -HHHHHHHHH HHHHHHHHH d::::::d -H:::::::H H:::::::H d::::::d -H:::::::H H:::::::H d::::::d -HH::::::H H::::::HH d:::::d - H:::::H H:::::H aaaaaaaaaaaaa nnnn nnnnnnnn ddddddddd:::::d - H:::::H H:::::H a::::::::::::a n:::nn::::::::nn dd::::::::::::::d - H::::::HHHHH::::::H aaaaaaaaa:::::an::::::::::::::nn d::::::::::::::::d 
- H:::::::::::::::::H a::::ann:::::::::::::::nd:::::::ddddd:::::d - H:::::::::::::::::H aaaaaaa:::::a n:::::nnnn:::::nd::::::d d:::::d - H::::::HHHHH::::::H aa::::::::::::a n::::n n::::nd:::::d d:::::d - H:::::H H:::::H a::::aaaa::::::a n::::n n::::nd:::::d d:::::d - H:::::H H:::::H a::::a a:::::a n::::n n::::nd:::::d d:::::d -HH::::::H H::::::HHa::::a a:::::a n::::n n::::nd::::::ddddd::::::dd -H:::::::H H:::::::Ha:::::aaaa::::::a n::::n n::::n d:::::::::::::::::d -H:::::::H H:::::::H a::::::::::aa:::a n::::n n::::n d:::::::::ddd::::d -HHHHHHHHH HHHHHHHHH aaaaaaaaaa aaaa nnnnnn nnnnnn ddddddddd ddddd - + E:::::E dd::::::::::::::d i:::::it:::::::::::::::::t B::::B B:::::B y:::::y y:::::y + E::::::EEEEEEEEEE d::::::::::::::::d i::::it:::::::::::::::::t B::::BBBBBB:::::B y:::::y y:::::y + E:::::::::::::::E d:::::::ddddd:::::d i::::itttttt:::::::tttttt B:::::::::::::BB y:::::y y:::::y + E:::::::::::::::E d::::::d d:::::d i::::i t:::::t B::::BBBBBB:::::B y:::::y y:::::y + E::::::EEEEEEEEEE d:::::d d:::::d i::::i t:::::t B::::B B:::::B y:::::y y:::::y + E:::::E d:::::d d:::::d i::::i t:::::t B::::B B:::::B y:::::y:::::y + E:::::E EEEEEEd:::::d d:::::d i::::i t:::::t tttttt B::::B B:::::B y:::::::::y +EE::::::EEEEEEEE:::::Ed::::::ddddd::::::ddi::::::i t::::::tttt:::::t BB:::::BBBBBB::::::B y:::::::y +E::::::::::::::::::::E d:::::::::::::::::di::::::i tt::::::::::::::t B:::::::::::::::::B y:::::y +E::::::::::::::::::::E d:::::::::ddd::::di::::::i tt:::::::::::tt B::::::::::::::::B y:::::y +EEEEEEEEEEEEEEEEEEEEEE ddddddddd dddddiiiiiiii ttttttttttt BBBBBBBBBBBBBBBBB y:::::y + y:::::y + y:::::y + y:::::y + y:::::y + yyyyyyy + + + + dddddddd +HHHHHHHHH HHHHHHHHH d::::::d +H:::::::H H:::::::H d::::::d +H:::::::H H:::::::H d::::::d +HH::::::H H::::::HH d:::::d + H:::::H H:::::H aaaaaaaaaaaaa nnnn nnnnnnnn ddddddddd:::::d + H:::::H H:::::H a::::::::::::a n:::nn::::::::nn dd::::::::::::::d + H::::::HHHHH::::::H aaaaaaaaa:::::an::::::::::::::nn d::::::::::::::::d + 
H:::::::::::::::::H a::::ann:::::::::::::::nd:::::::ddddd:::::d + H:::::::::::::::::H aaaaaaa:::::a n:::::nnnn:::::nd::::::d d:::::d + H::::::HHHHH::::::H aa::::::::::::a n::::n n::::nd:::::d d:::::d + H:::::H H:::::H a::::aaaa::::::a n::::n n::::nd:::::d d:::::d + H:::::H H:::::H a::::a a:::::a n::::n n::::nd:::::d d:::::d +HH::::::H H::::::HHa::::a a:::::a n::::n n::::nd::::::ddddd::::::dd +H:::::::H H:::::::Ha:::::aaaa::::::a n::::n n::::n d:::::::::::::::::d +H:::::::H H:::::::H a::::::::::aa:::a n::::n n::::n d:::::::::ddd::::d +HHHHHHHHH HHHHHHHHH aaaaaaaaaa aaaa nnnnnn nnnnnn ddddddddd ddddd + """ @@ -299,20 +299,17 @@ def _components(self, tokens: Sequence["UnitToken"]): pass def __mul__(self: Self, other: "Unit"): - if isinstance(other, Unit): - return Unit(self.scale * other.scale, self.dimensions * other.dimensions) - elif isinstance(other, (int, float)): - return Unit(other * self.scale, self.dimensions) - return NotImplemented + if not isinstance(other, Unit): + return NotImplemented + + return Unit(self.scale * other.scale, self.dimensions * other.dimensions) def __truediv__(self: Self, other: "Unit"): - if isinstance(other, Unit): - return Unit(self.scale / other.scale, self.dimensions / other.dimensions) - elif isinstance(other, (int, float)): - return Unit(self.scale / other, self.dimensions) - else: + if not isinstance(other, Unit): return NotImplemented + return Unit(self.scale / other.scale, self.dimensions / other.dimensions) + def __rtruediv__(self: Self, other: "Unit"): if isinstance(other, Unit): return Unit(other.scale / self.scale, other.dimensions / self.dimensions) @@ -380,27 +377,6 @@ def __init__(self, def __repr__(self): return self.name - def __eq__(self, other): - """Match other units exactly or match strings against ANY of our names""" - match other: - case str(): - return self.name == other or self.name == f"{other}s" or self.ascii_symbol == other or self.symbol == other - case NamedUnit(): - return self.name == other.name \ - and 
self.ascii_symbol == other.ascii_symbol and self.symbol == other.symbol - case Unit(): - return self.equivalent(other) and np.abs(np.log(self.scale/other.scale)) < 1e-5 - case _: - return False - - - def startswith(self, prefix: str) -> bool: - """Check if any representation of the unit begins with the prefix string""" - prefix = prefix.lower() - return (self.name is not None and self.name.lower().startswith(prefix)) \ - or (self.ascii_symbol is not None and self.ascii_symbol.lower().startswith(prefix)) \ - or (self.symbol is not None and self.symbol.lower().startswith(prefix)) - # # Parsing plan: # Require unknown amounts of units to be explicitly positive or negative? @@ -688,15 +664,12 @@ def __init__(self, name: str, units: list[NamedUnit]): femtohenry = NamedUnit(1e-15, Dimensions(2, -2, 1, -2, 0, 0, 0),name='femtohenry',ascii_symbol='fH',symbol='fH') attohenry = NamedUnit(1e-18, Dimensions(2, -2, 1, -2, 0, 0, 0),name='attohenry',ascii_symbol='aH',symbol='aH') angstroms = NamedUnit(1e-10, Dimensions(1, 0, 0, 0, 0, 0, 0),name='angstroms',ascii_symbol='Ang',latex_symbol=r'\AA',symbol='Å') -microns = NamedUnit(1e-06, Dimensions(1, 0, 0, 0, 0, 0, 0),name='microns',ascii_symbol='micron',symbol='micron') minutes = NamedUnit(60, Dimensions(0, 1, 0, 0, 0, 0, 0),name='minutes',ascii_symbol='min',symbol='min') -revolutions_per_minute = NamedUnit(0.016666666666666666, Dimensions(0, -1, 0, 0, 0, 0, 0),name='revolutions_per_minute',ascii_symbol='rpm',symbol='rpm') -hours = NamedUnit(3600, Dimensions(0, 1, 0, 0, 0, 0, 0),name='hours',ascii_symbol='h',symbol='h') -days = NamedUnit(86400, Dimensions(0, 1, 0, 0, 0, 0, 0),name='days',ascii_symbol='d',symbol='d') -years = NamedUnit(31556952.0, Dimensions(0, 1, 0, 0, 0, 0, 0),name='years',ascii_symbol='y',symbol='y') +hours = NamedUnit(360, Dimensions(0, 1, 0, 0, 0, 0, 0),name='hours',ascii_symbol='h',symbol='h') +days = NamedUnit(8640, Dimensions(0, 1, 0, 0, 0, 0, 0),name='days',ascii_symbol='d',symbol='d') +years = 
NamedUnit(3155695.2, Dimensions(0, 1, 0, 0, 0, 0, 0),name='years',ascii_symbol='y',symbol='y') degrees = NamedUnit(57.29577951308232, Dimensions(0, 0, 0, 0, 0, 0, 1),name='degrees',ascii_symbol='deg',symbol='deg') radians = NamedUnit(1, Dimensions(0, 0, 0, 0, 0, 0, 1),name='radians',ascii_symbol='rad',symbol='rad') -rotations = NamedUnit(6.283185307179586, Dimensions(0, 0, 0, 0, 0, 0, 1),name='rotations',ascii_symbol='rot',symbol='rot') stradians = NamedUnit(1, Dimensions(0, 0, 0, 0, 0, 0, 2),name='stradians',ascii_symbol='sr',symbol='sr') litres = NamedUnit(0.001, Dimensions(3, 0, 0, 0, 0, 0, 0),name='litres',ascii_symbol='l',symbol='l') electronvolts = NamedUnit(1.602176634e-19, Dimensions(2, -2, 1, 0, 0, 0, 0),name='electronvolts',ascii_symbol='eV',symbol='eV') @@ -812,11 +785,6 @@ def __init__(self, name: str, units: list[NamedUnit]): per_angstrom = NamedUnit(10000000000.0, Dimensions(length=-1), name='per_angstrom', ascii_symbol='Ang^-1', symbol='Å⁻¹') per_square_angstrom = NamedUnit(1e+20, Dimensions(length=-2), name='per_square_angstrom', ascii_symbol='Ang^-2', symbol='Å⁻²') per_cubic_angstrom = NamedUnit(9.999999999999999e+29, Dimensions(length=-3), name='per_cubic_angstrom', ascii_symbol='Ang^-3', symbol='Å⁻³') -square_microns = NamedUnit(1e-12, Dimensions(length=2), name='square_microns', ascii_symbol='micron^2', symbol='micron²') -cubic_microns = NamedUnit(9.999999999999999e-19, Dimensions(length=3), name='cubic_microns', ascii_symbol='micron^3', symbol='micron³') -per_micron = NamedUnit(1000000.0, Dimensions(length=-1), name='per_micron', ascii_symbol='micron^-1', symbol='micron⁻¹') -per_square_micron = NamedUnit(1000000000000.0001, Dimensions(length=-2), name='per_square_micron', ascii_symbol='micron^-2', symbol='micron⁻²') -per_cubic_micron = NamedUnit(1.0000000000000001e+18, Dimensions(length=-3), name='per_cubic_micron', ascii_symbol='micron^-3', symbol='micron⁻³') square_miles = NamedUnit(2589988.110336, Dimensions(length=2), name='square_miles', 
ascii_symbol='miles^2', symbol='miles²') cubic_miles = NamedUnit(4168181825.44058, Dimensions(length=3), name='cubic_miles', ascii_symbol='miles^3', symbol='miles³') per_mile = NamedUnit(0.0006213711922373339, Dimensions(length=-1), name='per_mile', ascii_symbol='miles^-1', symbol='miles⁻¹') @@ -853,12 +821,12 @@ def __init__(self, name: str, units: list[NamedUnit]): meters_per_square_attosecond = NamedUnit(9.999999999999999e+35, Dimensions(length=1, time=-2), name='meters_per_square_attosecond', ascii_symbol='m/as^2', symbol='mas⁻²') meters_per_minute = NamedUnit(0.016666666666666666, Dimensions(length=1, time=-1), name='meters_per_minute', ascii_symbol='m/min', symbol='mmin⁻¹') meters_per_square_minute = NamedUnit(0.0002777777777777778, Dimensions(length=1, time=-2), name='meters_per_square_minute', ascii_symbol='m/min^2', symbol='mmin⁻²') -meters_per_hour = NamedUnit(0.0002777777777777778, Dimensions(length=1, time=-1), name='meters_per_hour', ascii_symbol='m/h', symbol='mh⁻¹') -meters_per_square_hour = NamedUnit(7.71604938271605e-08, Dimensions(length=1, time=-2), name='meters_per_square_hour', ascii_symbol='m/h^2', symbol='mh⁻²') -meters_per_day = NamedUnit(1.1574074074074073e-05, Dimensions(length=1, time=-1), name='meters_per_day', ascii_symbol='m/d', symbol='md⁻¹') -meters_per_square_day = NamedUnit(1.3395919067215363e-10, Dimensions(length=1, time=-2), name='meters_per_square_day', ascii_symbol='m/d^2', symbol='md⁻²') -meters_per_year = NamedUnit(3.168873850681143e-08, Dimensions(length=1, time=-1), name='meters_per_year', ascii_symbol='m/y', symbol='my⁻¹') -meters_per_square_year = NamedUnit(1.0041761481530735e-15, Dimensions(length=1, time=-2), name='meters_per_square_year', ascii_symbol='m/y^2', symbol='my⁻²') +meters_per_hour = NamedUnit(0.002777777777777778, Dimensions(length=1, time=-1), name='meters_per_hour', ascii_symbol='m/h', symbol='mh⁻¹') +meters_per_square_hour = NamedUnit(7.71604938271605e-06, Dimensions(length=1, time=-2), 
name='meters_per_square_hour', ascii_symbol='m/h^2', symbol='mh⁻²') +meters_per_day = NamedUnit(0.00011574074074074075, Dimensions(length=1, time=-1), name='meters_per_day', ascii_symbol='m/d', symbol='md⁻¹') +meters_per_square_day = NamedUnit(1.3395919067215363e-08, Dimensions(length=1, time=-2), name='meters_per_square_day', ascii_symbol='m/d^2', symbol='md⁻²') +meters_per_year = NamedUnit(3.168873850681143e-07, Dimensions(length=1, time=-1), name='meters_per_year', ascii_symbol='m/y', symbol='my⁻¹') +meters_per_square_year = NamedUnit(1.0041761481530735e-13, Dimensions(length=1, time=-2), name='meters_per_square_year', ascii_symbol='m/y^2', symbol='my⁻²') exameters_per_second = NamedUnit(1e+18, Dimensions(length=1, time=-1), name='exameters_per_second', ascii_symbol='Em/s', symbol='Ems⁻¹') exameters_per_square_second = NamedUnit(1e+18, Dimensions(length=1, time=-2), name='exameters_per_square_second', ascii_symbol='Em/s^2', symbol='Ems⁻²') exameters_per_millisecond = NamedUnit(1e+21, Dimensions(length=1, time=-1), name='exameters_per_millisecond', ascii_symbol='Em/ms', symbol='Emms⁻¹') @@ -875,12 +843,12 @@ def __init__(self, name: str, units: list[NamedUnit]): exameters_per_square_attosecond = NamedUnit(9.999999999999999e+53, Dimensions(length=1, time=-2), name='exameters_per_square_attosecond', ascii_symbol='Em/as^2', symbol='Emas⁻²') exameters_per_minute = NamedUnit(1.6666666666666666e+16, Dimensions(length=1, time=-1), name='exameters_per_minute', ascii_symbol='Em/min', symbol='Emmin⁻¹') exameters_per_square_minute = NamedUnit(277777777777777.78, Dimensions(length=1, time=-2), name='exameters_per_square_minute', ascii_symbol='Em/min^2', symbol='Emmin⁻²') -exameters_per_hour = NamedUnit(277777777777777.78, Dimensions(length=1, time=-1), name='exameters_per_hour', ascii_symbol='Em/h', symbol='Emh⁻¹') -exameters_per_square_hour = NamedUnit(77160493827.16049, Dimensions(length=1, time=-2), name='exameters_per_square_hour', ascii_symbol='Em/h^2', symbol='Emh⁻²') 
-exameters_per_day = NamedUnit(11574074074074.074, Dimensions(length=1, time=-1), name='exameters_per_day', ascii_symbol='Em/d', symbol='Emd⁻¹') -exameters_per_square_day = NamedUnit(133959190.67215364, Dimensions(length=1, time=-2), name='exameters_per_square_day', ascii_symbol='Em/d^2', symbol='Emd⁻²') -exameters_per_year = NamedUnit(31688738506.81143, Dimensions(length=1, time=-1), name='exameters_per_year', ascii_symbol='Em/y', symbol='Emy⁻¹') -exameters_per_square_year = NamedUnit(1004.1761481530735, Dimensions(length=1, time=-2), name='exameters_per_square_year', ascii_symbol='Em/y^2', symbol='Emy⁻²') +exameters_per_hour = NamedUnit(2777777777777778.0, Dimensions(length=1, time=-1), name='exameters_per_hour', ascii_symbol='Em/h', symbol='Emh⁻¹') +exameters_per_square_hour = NamedUnit(7716049382716.05, Dimensions(length=1, time=-2), name='exameters_per_square_hour', ascii_symbol='Em/h^2', symbol='Emh⁻²') +exameters_per_day = NamedUnit(115740740740740.73, Dimensions(length=1, time=-1), name='exameters_per_day', ascii_symbol='Em/d', symbol='Emd⁻¹') +exameters_per_square_day = NamedUnit(13395919067.215364, Dimensions(length=1, time=-2), name='exameters_per_square_day', ascii_symbol='Em/d^2', symbol='Emd⁻²') +exameters_per_year = NamedUnit(316887385068.1143, Dimensions(length=1, time=-1), name='exameters_per_year', ascii_symbol='Em/y', symbol='Emy⁻¹') +exameters_per_square_year = NamedUnit(100417.61481530734, Dimensions(length=1, time=-2), name='exameters_per_square_year', ascii_symbol='Em/y^2', symbol='Emy⁻²') petameters_per_second = NamedUnit(1000000000000000.0, Dimensions(length=1, time=-1), name='petameters_per_second', ascii_symbol='Pm/s', symbol='Pms⁻¹') petameters_per_square_second = NamedUnit(1000000000000000.0, Dimensions(length=1, time=-2), name='petameters_per_square_second', ascii_symbol='Pm/s^2', symbol='Pms⁻²') petameters_per_millisecond = NamedUnit(1e+18, Dimensions(length=1, time=-1), name='petameters_per_millisecond', ascii_symbol='Pm/ms', 
symbol='Pmms⁻¹') @@ -897,12 +865,12 @@ def __init__(self, name: str, units: list[NamedUnit]): petameters_per_square_attosecond = NamedUnit(9.999999999999998e+50, Dimensions(length=1, time=-2), name='petameters_per_square_attosecond', ascii_symbol='Pm/as^2', symbol='Pmas⁻²') petameters_per_minute = NamedUnit(16666666666666.666, Dimensions(length=1, time=-1), name='petameters_per_minute', ascii_symbol='Pm/min', symbol='Pmmin⁻¹') petameters_per_square_minute = NamedUnit(277777777777.7778, Dimensions(length=1, time=-2), name='petameters_per_square_minute', ascii_symbol='Pm/min^2', symbol='Pmmin⁻²') -petameters_per_hour = NamedUnit(277777777777.7778, Dimensions(length=1, time=-1), name='petameters_per_hour', ascii_symbol='Pm/h', symbol='Pmh⁻¹') -petameters_per_square_hour = NamedUnit(77160493.82716049, Dimensions(length=1, time=-2), name='petameters_per_square_hour', ascii_symbol='Pm/h^2', symbol='Pmh⁻²') -petameters_per_day = NamedUnit(11574074074.074074, Dimensions(length=1, time=-1), name='petameters_per_day', ascii_symbol='Pm/d', symbol='Pmd⁻¹') -petameters_per_square_day = NamedUnit(133959.19067215364, Dimensions(length=1, time=-2), name='petameters_per_square_day', ascii_symbol='Pm/d^2', symbol='Pmd⁻²') -petameters_per_year = NamedUnit(31688738.506811433, Dimensions(length=1, time=-1), name='petameters_per_year', ascii_symbol='Pm/y', symbol='Pmy⁻¹') -petameters_per_square_year = NamedUnit(1.0041761481530735, Dimensions(length=1, time=-2), name='petameters_per_square_year', ascii_symbol='Pm/y^2', symbol='Pmy⁻²') +petameters_per_hour = NamedUnit(2777777777777.778, Dimensions(length=1, time=-1), name='petameters_per_hour', ascii_symbol='Pm/h', symbol='Pmh⁻¹') +petameters_per_square_hour = NamedUnit(7716049382.716049, Dimensions(length=1, time=-2), name='petameters_per_square_hour', ascii_symbol='Pm/h^2', symbol='Pmh⁻²') +petameters_per_day = NamedUnit(115740740740.74074, Dimensions(length=1, time=-1), name='petameters_per_day', ascii_symbol='Pm/d', symbol='Pmd⁻¹') 
+petameters_per_square_day = NamedUnit(13395919.067215364, Dimensions(length=1, time=-2), name='petameters_per_square_day', ascii_symbol='Pm/d^2', symbol='Pmd⁻²') +petameters_per_year = NamedUnit(316887385.0681143, Dimensions(length=1, time=-1), name='petameters_per_year', ascii_symbol='Pm/y', symbol='Pmy⁻¹') +petameters_per_square_year = NamedUnit(100.41761481530735, Dimensions(length=1, time=-2), name='petameters_per_square_year', ascii_symbol='Pm/y^2', symbol='Pmy⁻²') terameters_per_second = NamedUnit(1000000000000.0, Dimensions(length=1, time=-1), name='terameters_per_second', ascii_symbol='Tm/s', symbol='Tms⁻¹') terameters_per_square_second = NamedUnit(1000000000000.0, Dimensions(length=1, time=-2), name='terameters_per_square_second', ascii_symbol='Tm/s^2', symbol='Tms⁻²') terameters_per_millisecond = NamedUnit(1000000000000000.0, Dimensions(length=1, time=-1), name='terameters_per_millisecond', ascii_symbol='Tm/ms', symbol='Tmms⁻¹') @@ -919,12 +887,12 @@ def __init__(self, name: str, units: list[NamedUnit]): terameters_per_square_attosecond = NamedUnit(9.999999999999999e+47, Dimensions(length=1, time=-2), name='terameters_per_square_attosecond', ascii_symbol='Tm/as^2', symbol='Tmas⁻²') terameters_per_minute = NamedUnit(16666666666.666666, Dimensions(length=1, time=-1), name='terameters_per_minute', ascii_symbol='Tm/min', symbol='Tmmin⁻¹') terameters_per_square_minute = NamedUnit(277777777.7777778, Dimensions(length=1, time=-2), name='terameters_per_square_minute', ascii_symbol='Tm/min^2', symbol='Tmmin⁻²') -terameters_per_hour = NamedUnit(277777777.7777778, Dimensions(length=1, time=-1), name='terameters_per_hour', ascii_symbol='Tm/h', symbol='Tmh⁻¹') -terameters_per_square_hour = NamedUnit(77160.49382716049, Dimensions(length=1, time=-2), name='terameters_per_square_hour', ascii_symbol='Tm/h^2', symbol='Tmh⁻²') -terameters_per_day = NamedUnit(11574074.074074075, Dimensions(length=1, time=-1), name='terameters_per_day', ascii_symbol='Tm/d', symbol='Tmd⁻¹') 
-terameters_per_square_day = NamedUnit(133.95919067215362, Dimensions(length=1, time=-2), name='terameters_per_square_day', ascii_symbol='Tm/d^2', symbol='Tmd⁻²') -terameters_per_year = NamedUnit(31688.73850681143, Dimensions(length=1, time=-1), name='terameters_per_year', ascii_symbol='Tm/y', symbol='Tmy⁻¹') -terameters_per_square_year = NamedUnit(0.0010041761481530736, Dimensions(length=1, time=-2), name='terameters_per_square_year', ascii_symbol='Tm/y^2', symbol='Tmy⁻²') +terameters_per_hour = NamedUnit(2777777777.7777777, Dimensions(length=1, time=-1), name='terameters_per_hour', ascii_symbol='Tm/h', symbol='Tmh⁻¹') +terameters_per_square_hour = NamedUnit(7716049.382716049, Dimensions(length=1, time=-2), name='terameters_per_square_hour', ascii_symbol='Tm/h^2', symbol='Tmh⁻²') +terameters_per_day = NamedUnit(115740740.74074075, Dimensions(length=1, time=-1), name='terameters_per_day', ascii_symbol='Tm/d', symbol='Tmd⁻¹') +terameters_per_square_day = NamedUnit(13395.919067215364, Dimensions(length=1, time=-2), name='terameters_per_square_day', ascii_symbol='Tm/d^2', symbol='Tmd⁻²') +terameters_per_year = NamedUnit(316887.38506811426, Dimensions(length=1, time=-1), name='terameters_per_year', ascii_symbol='Tm/y', symbol='Tmy⁻¹') +terameters_per_square_year = NamedUnit(0.10041761481530735, Dimensions(length=1, time=-2), name='terameters_per_square_year', ascii_symbol='Tm/y^2', symbol='Tmy⁻²') gigameters_per_second = NamedUnit(1000000000.0, Dimensions(length=1, time=-1), name='gigameters_per_second', ascii_symbol='Gm/s', symbol='Gms⁻¹') gigameters_per_square_second = NamedUnit(1000000000.0, Dimensions(length=1, time=-2), name='gigameters_per_square_second', ascii_symbol='Gm/s^2', symbol='Gms⁻²') gigameters_per_millisecond = NamedUnit(1000000000000.0, Dimensions(length=1, time=-1), name='gigameters_per_millisecond', ascii_symbol='Gm/ms', symbol='Gmms⁻¹') @@ -941,12 +909,12 @@ def __init__(self, name: str, units: list[NamedUnit]): gigameters_per_square_attosecond = 
NamedUnit(1e+45, Dimensions(length=1, time=-2), name='gigameters_per_square_attosecond', ascii_symbol='Gm/as^2', symbol='Gmas⁻²') gigameters_per_minute = NamedUnit(16666666.666666666, Dimensions(length=1, time=-1), name='gigameters_per_minute', ascii_symbol='Gm/min', symbol='Gmmin⁻¹') gigameters_per_square_minute = NamedUnit(277777.77777777775, Dimensions(length=1, time=-2), name='gigameters_per_square_minute', ascii_symbol='Gm/min^2', symbol='Gmmin⁻²') -gigameters_per_hour = NamedUnit(277777.77777777775, Dimensions(length=1, time=-1), name='gigameters_per_hour', ascii_symbol='Gm/h', symbol='Gmh⁻¹') -gigameters_per_square_hour = NamedUnit(77.1604938271605, Dimensions(length=1, time=-2), name='gigameters_per_square_hour', ascii_symbol='Gm/h^2', symbol='Gmh⁻²') -gigameters_per_day = NamedUnit(11574.074074074075, Dimensions(length=1, time=-1), name='gigameters_per_day', ascii_symbol='Gm/d', symbol='Gmd⁻¹') -gigameters_per_square_day = NamedUnit(0.13395919067215364, Dimensions(length=1, time=-2), name='gigameters_per_square_day', ascii_symbol='Gm/d^2', symbol='Gmd⁻²') -gigameters_per_year = NamedUnit(31.688738506811433, Dimensions(length=1, time=-1), name='gigameters_per_year', ascii_symbol='Gm/y', symbol='Gmy⁻¹') -gigameters_per_square_year = NamedUnit(1.0041761481530736e-06, Dimensions(length=1, time=-2), name='gigameters_per_square_year', ascii_symbol='Gm/y^2', symbol='Gmy⁻²') +gigameters_per_hour = NamedUnit(2777777.777777778, Dimensions(length=1, time=-1), name='gigameters_per_hour', ascii_symbol='Gm/h', symbol='Gmh⁻¹') +gigameters_per_square_hour = NamedUnit(7716.049382716049, Dimensions(length=1, time=-2), name='gigameters_per_square_hour', ascii_symbol='Gm/h^2', symbol='Gmh⁻²') +gigameters_per_day = NamedUnit(115740.74074074074, Dimensions(length=1, time=-1), name='gigameters_per_day', ascii_symbol='Gm/d', symbol='Gmd⁻¹') +gigameters_per_square_day = NamedUnit(13.395919067215363, Dimensions(length=1, time=-2), name='gigameters_per_square_day', 
ascii_symbol='Gm/d^2', symbol='Gmd⁻²') +gigameters_per_year = NamedUnit(316.88738506811427, Dimensions(length=1, time=-1), name='gigameters_per_year', ascii_symbol='Gm/y', symbol='Gmy⁻¹') +gigameters_per_square_year = NamedUnit(0.00010041761481530735, Dimensions(length=1, time=-2), name='gigameters_per_square_year', ascii_symbol='Gm/y^2', symbol='Gmy⁻²') megameters_per_second = NamedUnit(1000000.0, Dimensions(length=1, time=-1), name='megameters_per_second', ascii_symbol='Mm/s', symbol='Mms⁻¹') megameters_per_square_second = NamedUnit(1000000.0, Dimensions(length=1, time=-2), name='megameters_per_square_second', ascii_symbol='Mm/s^2', symbol='Mms⁻²') megameters_per_millisecond = NamedUnit(1000000000.0, Dimensions(length=1, time=-1), name='megameters_per_millisecond', ascii_symbol='Mm/ms', symbol='Mmms⁻¹') @@ -963,12 +931,12 @@ def __init__(self, name: str, units: list[NamedUnit]): megameters_per_square_attosecond = NamedUnit(9.999999999999999e+41, Dimensions(length=1, time=-2), name='megameters_per_square_attosecond', ascii_symbol='Mm/as^2', symbol='Mmas⁻²') megameters_per_minute = NamedUnit(16666.666666666668, Dimensions(length=1, time=-1), name='megameters_per_minute', ascii_symbol='Mm/min', symbol='Mmmin⁻¹') megameters_per_square_minute = NamedUnit(277.77777777777777, Dimensions(length=1, time=-2), name='megameters_per_square_minute', ascii_symbol='Mm/min^2', symbol='Mmmin⁻²') -megameters_per_hour = NamedUnit(277.77777777777777, Dimensions(length=1, time=-1), name='megameters_per_hour', ascii_symbol='Mm/h', symbol='Mmh⁻¹') -megameters_per_square_hour = NamedUnit(0.07716049382716049, Dimensions(length=1, time=-2), name='megameters_per_square_hour', ascii_symbol='Mm/h^2', symbol='Mmh⁻²') -megameters_per_day = NamedUnit(11.574074074074074, Dimensions(length=1, time=-1), name='megameters_per_day', ascii_symbol='Mm/d', symbol='Mmd⁻¹') -megameters_per_square_day = NamedUnit(0.00013395919067215364, Dimensions(length=1, time=-2), name='megameters_per_square_day', 
ascii_symbol='Mm/d^2', symbol='Mmd⁻²') -megameters_per_year = NamedUnit(0.031688738506811434, Dimensions(length=1, time=-1), name='megameters_per_year', ascii_symbol='Mm/y', symbol='Mmy⁻¹') -megameters_per_square_year = NamedUnit(1.0041761481530736e-09, Dimensions(length=1, time=-2), name='megameters_per_square_year', ascii_symbol='Mm/y^2', symbol='Mmy⁻²') +megameters_per_hour = NamedUnit(2777.777777777778, Dimensions(length=1, time=-1), name='megameters_per_hour', ascii_symbol='Mm/h', symbol='Mmh⁻¹') +megameters_per_square_hour = NamedUnit(7.716049382716049, Dimensions(length=1, time=-2), name='megameters_per_square_hour', ascii_symbol='Mm/h^2', symbol='Mmh⁻²') +megameters_per_day = NamedUnit(115.74074074074075, Dimensions(length=1, time=-1), name='megameters_per_day', ascii_symbol='Mm/d', symbol='Mmd⁻¹') +megameters_per_square_day = NamedUnit(0.013395919067215363, Dimensions(length=1, time=-2), name='megameters_per_square_day', ascii_symbol='Mm/d^2', symbol='Mmd⁻²') +megameters_per_year = NamedUnit(0.3168873850681143, Dimensions(length=1, time=-1), name='megameters_per_year', ascii_symbol='Mm/y', symbol='Mmy⁻¹') +megameters_per_square_year = NamedUnit(1.0041761481530735e-07, Dimensions(length=1, time=-2), name='megameters_per_square_year', ascii_symbol='Mm/y^2', symbol='Mmy⁻²') kilometers_per_second = NamedUnit(1000.0, Dimensions(length=1, time=-1), name='kilometers_per_second', ascii_symbol='km/s', symbol='kms⁻¹') kilometers_per_square_second = NamedUnit(1000.0, Dimensions(length=1, time=-2), name='kilometers_per_square_second', ascii_symbol='km/s^2', symbol='kms⁻²') kilometers_per_millisecond = NamedUnit(1000000.0, Dimensions(length=1, time=-1), name='kilometers_per_millisecond', ascii_symbol='km/ms', symbol='kmms⁻¹') @@ -985,12 +953,12 @@ def __init__(self, name: str, units: list[NamedUnit]): kilometers_per_square_attosecond = NamedUnit(1e+39, Dimensions(length=1, time=-2), name='kilometers_per_square_attosecond', ascii_symbol='km/as^2', symbol='kmas⁻²') 
kilometers_per_minute = NamedUnit(16.666666666666668, Dimensions(length=1, time=-1), name='kilometers_per_minute', ascii_symbol='km/min', symbol='kmmin⁻¹') kilometers_per_square_minute = NamedUnit(0.2777777777777778, Dimensions(length=1, time=-2), name='kilometers_per_square_minute', ascii_symbol='km/min^2', symbol='kmmin⁻²') -kilometers_per_hour = NamedUnit(0.2777777777777778, Dimensions(length=1, time=-1), name='kilometers_per_hour', ascii_symbol='km/h', symbol='kmh⁻¹') -kilometers_per_square_hour = NamedUnit(7.716049382716049e-05, Dimensions(length=1, time=-2), name='kilometers_per_square_hour', ascii_symbol='km/h^2', symbol='kmh⁻²') -kilometers_per_day = NamedUnit(0.011574074074074073, Dimensions(length=1, time=-1), name='kilometers_per_day', ascii_symbol='km/d', symbol='kmd⁻¹') -kilometers_per_square_day = NamedUnit(1.3395919067215364e-07, Dimensions(length=1, time=-2), name='kilometers_per_square_day', ascii_symbol='km/d^2', symbol='kmd⁻²') -kilometers_per_year = NamedUnit(3.168873850681143e-05, Dimensions(length=1, time=-1), name='kilometers_per_year', ascii_symbol='km/y', symbol='kmy⁻¹') -kilometers_per_square_year = NamedUnit(1.0041761481530736e-12, Dimensions(length=1, time=-2), name='kilometers_per_square_year', ascii_symbol='km/y^2', symbol='kmy⁻²') +kilometers_per_hour = NamedUnit(2.7777777777777777, Dimensions(length=1, time=-1), name='kilometers_per_hour', ascii_symbol='km/h', symbol='kmh⁻¹') +kilometers_per_square_hour = NamedUnit(0.007716049382716049, Dimensions(length=1, time=-2), name='kilometers_per_square_hour', ascii_symbol='km/h^2', symbol='kmh⁻²') +kilometers_per_day = NamedUnit(0.11574074074074074, Dimensions(length=1, time=-1), name='kilometers_per_day', ascii_symbol='km/d', symbol='kmd⁻¹') +kilometers_per_square_day = NamedUnit(1.3395919067215363e-05, Dimensions(length=1, time=-2), name='kilometers_per_square_day', ascii_symbol='km/d^2', symbol='kmd⁻²') +kilometers_per_year = NamedUnit(0.0003168873850681143, Dimensions(length=1, time=-1), 
name='kilometers_per_year', ascii_symbol='km/y', symbol='kmy⁻¹') +kilometers_per_square_year = NamedUnit(1.0041761481530735e-10, Dimensions(length=1, time=-2), name='kilometers_per_square_year', ascii_symbol='km/y^2', symbol='kmy⁻²') millimeters_per_second = NamedUnit(0.001, Dimensions(length=1, time=-1), name='millimeters_per_second', ascii_symbol='mm/s', symbol='mms⁻¹') millimeters_per_square_second = NamedUnit(0.001, Dimensions(length=1, time=-2), name='millimeters_per_square_second', ascii_symbol='mm/s^2', symbol='mms⁻²') millimeters_per_millisecond = NamedUnit(1.0, Dimensions(length=1, time=-1), name='millimeters_per_millisecond', ascii_symbol='mm/ms', symbol='mmms⁻¹') @@ -1007,12 +975,12 @@ def __init__(self, name: str, units: list[NamedUnit]): millimeters_per_square_attosecond = NamedUnit(1e+33, Dimensions(length=1, time=-2), name='millimeters_per_square_attosecond', ascii_symbol='mm/as^2', symbol='mmas⁻²') millimeters_per_minute = NamedUnit(1.6666666666666667e-05, Dimensions(length=1, time=-1), name='millimeters_per_minute', ascii_symbol='mm/min', symbol='mmmin⁻¹') millimeters_per_square_minute = NamedUnit(2.7777777777777776e-07, Dimensions(length=1, time=-2), name='millimeters_per_square_minute', ascii_symbol='mm/min^2', symbol='mmmin⁻²') -millimeters_per_hour = NamedUnit(2.7777777777777776e-07, Dimensions(length=1, time=-1), name='millimeters_per_hour', ascii_symbol='mm/h', symbol='mmh⁻¹') -millimeters_per_square_hour = NamedUnit(7.716049382716049e-11, Dimensions(length=1, time=-2), name='millimeters_per_square_hour', ascii_symbol='mm/h^2', symbol='mmh⁻²') -millimeters_per_day = NamedUnit(1.1574074074074074e-08, Dimensions(length=1, time=-1), name='millimeters_per_day', ascii_symbol='mm/d', symbol='mmd⁻¹') -millimeters_per_square_day = NamedUnit(1.3395919067215364e-13, Dimensions(length=1, time=-2), name='millimeters_per_square_day', ascii_symbol='mm/d^2', symbol='mmd⁻²') -millimeters_per_year = NamedUnit(3.168873850681143e-11, Dimensions(length=1, 
time=-1), name='millimeters_per_year', ascii_symbol='mm/y', symbol='mmy⁻¹') -millimeters_per_square_year = NamedUnit(1.0041761481530737e-18, Dimensions(length=1, time=-2), name='millimeters_per_square_year', ascii_symbol='mm/y^2', symbol='mmy⁻²') +millimeters_per_hour = NamedUnit(2.777777777777778e-06, Dimensions(length=1, time=-1), name='millimeters_per_hour', ascii_symbol='mm/h', symbol='mmh⁻¹') +millimeters_per_square_hour = NamedUnit(7.71604938271605e-09, Dimensions(length=1, time=-2), name='millimeters_per_square_hour', ascii_symbol='mm/h^2', symbol='mmh⁻²') +millimeters_per_day = NamedUnit(1.1574074074074074e-07, Dimensions(length=1, time=-1), name='millimeters_per_day', ascii_symbol='mm/d', symbol='mmd⁻¹') +millimeters_per_square_day = NamedUnit(1.3395919067215364e-11, Dimensions(length=1, time=-2), name='millimeters_per_square_day', ascii_symbol='mm/d^2', symbol='mmd⁻²') +millimeters_per_year = NamedUnit(3.168873850681143e-10, Dimensions(length=1, time=-1), name='millimeters_per_year', ascii_symbol='mm/y', symbol='mmy⁻¹') +millimeters_per_square_year = NamedUnit(1.0041761481530735e-16, Dimensions(length=1, time=-2), name='millimeters_per_square_year', ascii_symbol='mm/y^2', symbol='mmy⁻²') micrometers_per_second = NamedUnit(1e-06, Dimensions(length=1, time=-1), name='micrometers_per_second', ascii_symbol='um/s', symbol='µms⁻¹') micrometers_per_square_second = NamedUnit(1e-06, Dimensions(length=1, time=-2), name='micrometers_per_square_second', ascii_symbol='um/s^2', symbol='µms⁻²') micrometers_per_millisecond = NamedUnit(0.001, Dimensions(length=1, time=-1), name='micrometers_per_millisecond', ascii_symbol='um/ms', symbol='µmms⁻¹') @@ -1029,12 +997,12 @@ def __init__(self, name: str, units: list[NamedUnit]): micrometers_per_square_attosecond = NamedUnit(9.999999999999999e+29, Dimensions(length=1, time=-2), name='micrometers_per_square_attosecond', ascii_symbol='um/as^2', symbol='µmas⁻²') micrometers_per_minute = NamedUnit(1.6666666666666667e-08, 
Dimensions(length=1, time=-1), name='micrometers_per_minute', ascii_symbol='um/min', symbol='µmmin⁻¹') micrometers_per_square_minute = NamedUnit(2.7777777777777777e-10, Dimensions(length=1, time=-2), name='micrometers_per_square_minute', ascii_symbol='um/min^2', symbol='µmmin⁻²') -micrometers_per_hour = NamedUnit(2.7777777777777777e-10, Dimensions(length=1, time=-1), name='micrometers_per_hour', ascii_symbol='um/h', symbol='µmh⁻¹') -micrometers_per_square_hour = NamedUnit(7.71604938271605e-14, Dimensions(length=1, time=-2), name='micrometers_per_square_hour', ascii_symbol='um/h^2', symbol='µmh⁻²') -micrometers_per_day = NamedUnit(1.1574074074074074e-11, Dimensions(length=1, time=-1), name='micrometers_per_day', ascii_symbol='um/d', symbol='µmd⁻¹') -micrometers_per_square_day = NamedUnit(1.3395919067215363e-16, Dimensions(length=1, time=-2), name='micrometers_per_square_day', ascii_symbol='um/d^2', symbol='µmd⁻²') -micrometers_per_year = NamedUnit(3.168873850681143e-14, Dimensions(length=1, time=-1), name='micrometers_per_year', ascii_symbol='um/y', symbol='µmy⁻¹') -micrometers_per_square_year = NamedUnit(1.0041761481530736e-21, Dimensions(length=1, time=-2), name='micrometers_per_square_year', ascii_symbol='um/y^2', symbol='µmy⁻²') +micrometers_per_hour = NamedUnit(2.7777777777777776e-09, Dimensions(length=1, time=-1), name='micrometers_per_hour', ascii_symbol='um/h', symbol='µmh⁻¹') +micrometers_per_square_hour = NamedUnit(7.716049382716049e-12, Dimensions(length=1, time=-2), name='micrometers_per_square_hour', ascii_symbol='um/h^2', symbol='µmh⁻²') +micrometers_per_day = NamedUnit(1.1574074074074074e-10, Dimensions(length=1, time=-1), name='micrometers_per_day', ascii_symbol='um/d', symbol='µmd⁻¹') +micrometers_per_square_day = NamedUnit(1.3395919067215363e-14, Dimensions(length=1, time=-2), name='micrometers_per_square_day', ascii_symbol='um/d^2', symbol='µmd⁻²') +micrometers_per_year = NamedUnit(3.168873850681143e-13, Dimensions(length=1, time=-1), 
name='micrometers_per_year', ascii_symbol='um/y', symbol='µmy⁻¹') +micrometers_per_square_year = NamedUnit(1.0041761481530734e-19, Dimensions(length=1, time=-2), name='micrometers_per_square_year', ascii_symbol='um/y^2', symbol='µmy⁻²') nanometers_per_second = NamedUnit(1e-09, Dimensions(length=1, time=-1), name='nanometers_per_second', ascii_symbol='nm/s', symbol='nms⁻¹') nanometers_per_square_second = NamedUnit(1e-09, Dimensions(length=1, time=-2), name='nanometers_per_square_second', ascii_symbol='nm/s^2', symbol='nms⁻²') nanometers_per_millisecond = NamedUnit(1e-06, Dimensions(length=1, time=-1), name='nanometers_per_millisecond', ascii_symbol='nm/ms', symbol='nmms⁻¹') @@ -1051,12 +1019,12 @@ def __init__(self, name: str, units: list[NamedUnit]): nanometers_per_square_attosecond = NamedUnit(1e+27, Dimensions(length=1, time=-2), name='nanometers_per_square_attosecond', ascii_symbol='nm/as^2', symbol='nmas⁻²') nanometers_per_minute = NamedUnit(1.6666666666666667e-11, Dimensions(length=1, time=-1), name='nanometers_per_minute', ascii_symbol='nm/min', symbol='nmmin⁻¹') nanometers_per_square_minute = NamedUnit(2.777777777777778e-13, Dimensions(length=1, time=-2), name='nanometers_per_square_minute', ascii_symbol='nm/min^2', symbol='nmmin⁻²') -nanometers_per_hour = NamedUnit(2.777777777777778e-13, Dimensions(length=1, time=-1), name='nanometers_per_hour', ascii_symbol='nm/h', symbol='nmh⁻¹') -nanometers_per_square_hour = NamedUnit(7.71604938271605e-17, Dimensions(length=1, time=-2), name='nanometers_per_square_hour', ascii_symbol='nm/h^2', symbol='nmh⁻²') -nanometers_per_day = NamedUnit(1.1574074074074075e-14, Dimensions(length=1, time=-1), name='nanometers_per_day', ascii_symbol='nm/d', symbol='nmd⁻¹') -nanometers_per_square_day = NamedUnit(1.3395919067215365e-19, Dimensions(length=1, time=-2), name='nanometers_per_square_day', ascii_symbol='nm/d^2', symbol='nmd⁻²') -nanometers_per_year = NamedUnit(3.1688738506811435e-17, Dimensions(length=1, time=-1), 
name='nanometers_per_year', ascii_symbol='nm/y', symbol='nmy⁻¹') -nanometers_per_square_year = NamedUnit(1.0041761481530737e-24, Dimensions(length=1, time=-2), name='nanometers_per_square_year', ascii_symbol='nm/y^2', symbol='nmy⁻²') +nanometers_per_hour = NamedUnit(2.777777777777778e-12, Dimensions(length=1, time=-1), name='nanometers_per_hour', ascii_symbol='nm/h', symbol='nmh⁻¹') +nanometers_per_square_hour = NamedUnit(7.71604938271605e-15, Dimensions(length=1, time=-2), name='nanometers_per_square_hour', ascii_symbol='nm/h^2', symbol='nmh⁻²') +nanometers_per_day = NamedUnit(1.1574074074074076e-13, Dimensions(length=1, time=-1), name='nanometers_per_day', ascii_symbol='nm/d', symbol='nmd⁻¹') +nanometers_per_square_day = NamedUnit(1.3395919067215365e-17, Dimensions(length=1, time=-2), name='nanometers_per_square_day', ascii_symbol='nm/d^2', symbol='nmd⁻²') +nanometers_per_year = NamedUnit(3.1688738506811433e-16, Dimensions(length=1, time=-1), name='nanometers_per_year', ascii_symbol='nm/y', symbol='nmy⁻¹') +nanometers_per_square_year = NamedUnit(1.0041761481530736e-22, Dimensions(length=1, time=-2), name='nanometers_per_square_year', ascii_symbol='nm/y^2', symbol='nmy⁻²') picometers_per_second = NamedUnit(1e-12, Dimensions(length=1, time=-1), name='picometers_per_second', ascii_symbol='pm/s', symbol='pms⁻¹') picometers_per_square_second = NamedUnit(1e-12, Dimensions(length=1, time=-2), name='picometers_per_square_second', ascii_symbol='pm/s^2', symbol='pms⁻²') picometers_per_millisecond = NamedUnit(1e-09, Dimensions(length=1, time=-1), name='picometers_per_millisecond', ascii_symbol='pm/ms', symbol='pmms⁻¹') @@ -1073,12 +1041,12 @@ def __init__(self, name: str, units: list[NamedUnit]): picometers_per_square_attosecond = NamedUnit(9.999999999999998e+23, Dimensions(length=1, time=-2), name='picometers_per_square_attosecond', ascii_symbol='pm/as^2', symbol='pmas⁻²') picometers_per_minute = NamedUnit(1.6666666666666667e-14, Dimensions(length=1, time=-1), 
name='picometers_per_minute', ascii_symbol='pm/min', symbol='pmmin⁻¹') picometers_per_square_minute = NamedUnit(2.7777777777777775e-16, Dimensions(length=1, time=-2), name='picometers_per_square_minute', ascii_symbol='pm/min^2', symbol='pmmin⁻²') -picometers_per_hour = NamedUnit(2.7777777777777775e-16, Dimensions(length=1, time=-1), name='picometers_per_hour', ascii_symbol='pm/h', symbol='pmh⁻¹') -picometers_per_square_hour = NamedUnit(7.716049382716049e-20, Dimensions(length=1, time=-2), name='picometers_per_square_hour', ascii_symbol='pm/h^2', symbol='pmh⁻²') -picometers_per_day = NamedUnit(1.1574074074074074e-17, Dimensions(length=1, time=-1), name='picometers_per_day', ascii_symbol='pm/d', symbol='pmd⁻¹') -picometers_per_square_day = NamedUnit(1.3395919067215362e-22, Dimensions(length=1, time=-2), name='picometers_per_square_day', ascii_symbol='pm/d^2', symbol='pmd⁻²') -picometers_per_year = NamedUnit(3.168873850681143e-20, Dimensions(length=1, time=-1), name='picometers_per_year', ascii_symbol='pm/y', symbol='pmy⁻¹') -picometers_per_square_year = NamedUnit(1.0041761481530736e-27, Dimensions(length=1, time=-2), name='picometers_per_square_year', ascii_symbol='pm/y^2', symbol='pmy⁻²') +picometers_per_hour = NamedUnit(2.7777777777777776e-15, Dimensions(length=1, time=-1), name='picometers_per_hour', ascii_symbol='pm/h', symbol='pmh⁻¹') +picometers_per_square_hour = NamedUnit(7.716049382716049e-18, Dimensions(length=1, time=-2), name='picometers_per_square_hour', ascii_symbol='pm/h^2', symbol='pmh⁻²') +picometers_per_day = NamedUnit(1.1574074074074073e-16, Dimensions(length=1, time=-1), name='picometers_per_day', ascii_symbol='pm/d', symbol='pmd⁻¹') +picometers_per_square_day = NamedUnit(1.3395919067215364e-20, Dimensions(length=1, time=-2), name='picometers_per_square_day', ascii_symbol='pm/d^2', symbol='pmd⁻²') +picometers_per_year = NamedUnit(3.168873850681143e-19, Dimensions(length=1, time=-1), name='picometers_per_year', ascii_symbol='pm/y', symbol='pmy⁻¹') 
+picometers_per_square_year = NamedUnit(1.0041761481530734e-25, Dimensions(length=1, time=-2), name='picometers_per_square_year', ascii_symbol='pm/y^2', symbol='pmy⁻²') femtometers_per_second = NamedUnit(1e-15, Dimensions(length=1, time=-1), name='femtometers_per_second', ascii_symbol='fm/s', symbol='fms⁻¹') femtometers_per_square_second = NamedUnit(1e-15, Dimensions(length=1, time=-2), name='femtometers_per_square_second', ascii_symbol='fm/s^2', symbol='fms⁻²') femtometers_per_millisecond = NamedUnit(1e-12, Dimensions(length=1, time=-1), name='femtometers_per_millisecond', ascii_symbol='fm/ms', symbol='fmms⁻¹') @@ -1095,12 +1063,12 @@ def __init__(self, name: str, units: list[NamedUnit]): femtometers_per_square_attosecond = NamedUnit(1e+21, Dimensions(length=1, time=-2), name='femtometers_per_square_attosecond', ascii_symbol='fm/as^2', symbol='fmas⁻²') femtometers_per_minute = NamedUnit(1.6666666666666667e-17, Dimensions(length=1, time=-1), name='femtometers_per_minute', ascii_symbol='fm/min', symbol='fmmin⁻¹') femtometers_per_square_minute = NamedUnit(2.777777777777778e-19, Dimensions(length=1, time=-2), name='femtometers_per_square_minute', ascii_symbol='fm/min^2', symbol='fmmin⁻²') -femtometers_per_hour = NamedUnit(2.777777777777778e-19, Dimensions(length=1, time=-1), name='femtometers_per_hour', ascii_symbol='fm/h', symbol='fmh⁻¹') -femtometers_per_square_hour = NamedUnit(7.71604938271605e-23, Dimensions(length=1, time=-2), name='femtometers_per_square_hour', ascii_symbol='fm/h^2', symbol='fmh⁻²') -femtometers_per_day = NamedUnit(1.1574074074074075e-20, Dimensions(length=1, time=-1), name='femtometers_per_day', ascii_symbol='fm/d', symbol='fmd⁻¹') -femtometers_per_square_day = NamedUnit(1.3395919067215364e-25, Dimensions(length=1, time=-2), name='femtometers_per_square_day', ascii_symbol='fm/d^2', symbol='fmd⁻²') -femtometers_per_year = NamedUnit(3.1688738506811434e-23, Dimensions(length=1, time=-1), name='femtometers_per_year', ascii_symbol='fm/y', 
symbol='fmy⁻¹') -femtometers_per_square_year = NamedUnit(1.0041761481530736e-30, Dimensions(length=1, time=-2), name='femtometers_per_square_year', ascii_symbol='fm/y^2', symbol='fmy⁻²') +femtometers_per_hour = NamedUnit(2.777777777777778e-18, Dimensions(length=1, time=-1), name='femtometers_per_hour', ascii_symbol='fm/h', symbol='fmh⁻¹') +femtometers_per_square_hour = NamedUnit(7.71604938271605e-21, Dimensions(length=1, time=-2), name='femtometers_per_square_hour', ascii_symbol='fm/h^2', symbol='fmh⁻²') +femtometers_per_day = NamedUnit(1.1574074074074075e-19, Dimensions(length=1, time=-1), name='femtometers_per_day', ascii_symbol='fm/d', symbol='fmd⁻¹') +femtometers_per_square_day = NamedUnit(1.3395919067215363e-23, Dimensions(length=1, time=-2), name='femtometers_per_square_day', ascii_symbol='fm/d^2', symbol='fmd⁻²') +femtometers_per_year = NamedUnit(3.168873850681143e-22, Dimensions(length=1, time=-1), name='femtometers_per_year', ascii_symbol='fm/y', symbol='fmy⁻¹') +femtometers_per_square_year = NamedUnit(1.0041761481530735e-28, Dimensions(length=1, time=-2), name='femtometers_per_square_year', ascii_symbol='fm/y^2', symbol='fmy⁻²') attometers_per_second = NamedUnit(1e-18, Dimensions(length=1, time=-1), name='attometers_per_second', ascii_symbol='am/s', symbol='ams⁻¹') attometers_per_square_second = NamedUnit(1e-18, Dimensions(length=1, time=-2), name='attometers_per_square_second', ascii_symbol='am/s^2', symbol='ams⁻²') attometers_per_millisecond = NamedUnit(1e-15, Dimensions(length=1, time=-1), name='attometers_per_millisecond', ascii_symbol='am/ms', symbol='amms⁻¹') @@ -1117,12 +1085,12 @@ def __init__(self, name: str, units: list[NamedUnit]): attometers_per_square_attosecond = NamedUnit(1e+18, Dimensions(length=1, time=-2), name='attometers_per_square_attosecond', ascii_symbol='am/as^2', symbol='amas⁻²') attometers_per_minute = NamedUnit(1.6666666666666668e-20, Dimensions(length=1, time=-1), name='attometers_per_minute', ascii_symbol='am/min', 
symbol='ammin⁻¹') attometers_per_square_minute = NamedUnit(2.777777777777778e-22, Dimensions(length=1, time=-2), name='attometers_per_square_minute', ascii_symbol='am/min^2', symbol='ammin⁻²') -attometers_per_hour = NamedUnit(2.777777777777778e-22, Dimensions(length=1, time=-1), name='attometers_per_hour', ascii_symbol='am/h', symbol='amh⁻¹') -attometers_per_square_hour = NamedUnit(7.71604938271605e-26, Dimensions(length=1, time=-2), name='attometers_per_square_hour', ascii_symbol='am/h^2', symbol='amh⁻²') -attometers_per_day = NamedUnit(1.1574074074074075e-23, Dimensions(length=1, time=-1), name='attometers_per_day', ascii_symbol='am/d', symbol='amd⁻¹') -attometers_per_square_day = NamedUnit(1.3395919067215364e-28, Dimensions(length=1, time=-2), name='attometers_per_square_day', ascii_symbol='am/d^2', symbol='amd⁻²') -attometers_per_year = NamedUnit(3.1688738506811435e-26, Dimensions(length=1, time=-1), name='attometers_per_year', ascii_symbol='am/y', symbol='amy⁻¹') -attometers_per_square_year = NamedUnit(1.0041761481530737e-33, Dimensions(length=1, time=-2), name='attometers_per_square_year', ascii_symbol='am/y^2', symbol='amy⁻²') +attometers_per_hour = NamedUnit(2.7777777777777778e-21, Dimensions(length=1, time=-1), name='attometers_per_hour', ascii_symbol='am/h', symbol='amh⁻¹') +attometers_per_square_hour = NamedUnit(7.71604938271605e-24, Dimensions(length=1, time=-2), name='attometers_per_square_hour', ascii_symbol='am/h^2', symbol='amh⁻²') +attometers_per_day = NamedUnit(1.1574074074074074e-22, Dimensions(length=1, time=-1), name='attometers_per_day', ascii_symbol='am/d', symbol='amd⁻¹') +attometers_per_square_day = NamedUnit(1.3395919067215363e-26, Dimensions(length=1, time=-2), name='attometers_per_square_day', ascii_symbol='am/d^2', symbol='amd⁻²') +attometers_per_year = NamedUnit(3.1688738506811433e-25, Dimensions(length=1, time=-1), name='attometers_per_year', ascii_symbol='am/y', symbol='amy⁻¹') +attometers_per_square_year = 
NamedUnit(1.0041761481530734e-31, Dimensions(length=1, time=-2), name='attometers_per_square_year', ascii_symbol='am/y^2', symbol='amy⁻²') decimeters_per_second = NamedUnit(0.1, Dimensions(length=1, time=-1), name='decimeters_per_second', ascii_symbol='dm/s', symbol='dms⁻¹') decimeters_per_square_second = NamedUnit(0.1, Dimensions(length=1, time=-2), name='decimeters_per_square_second', ascii_symbol='dm/s^2', symbol='dms⁻²') decimeters_per_millisecond = NamedUnit(100.0, Dimensions(length=1, time=-1), name='decimeters_per_millisecond', ascii_symbol='dm/ms', symbol='dmms⁻¹') @@ -1139,12 +1107,12 @@ def __init__(self, name: str, units: list[NamedUnit]): decimeters_per_square_attosecond = NamedUnit(1e+35, Dimensions(length=1, time=-2), name='decimeters_per_square_attosecond', ascii_symbol='dm/as^2', symbol='dmas⁻²') decimeters_per_minute = NamedUnit(0.0016666666666666668, Dimensions(length=1, time=-1), name='decimeters_per_minute', ascii_symbol='dm/min', symbol='dmmin⁻¹') decimeters_per_square_minute = NamedUnit(2.777777777777778e-05, Dimensions(length=1, time=-2), name='decimeters_per_square_minute', ascii_symbol='dm/min^2', symbol='dmmin⁻²') -decimeters_per_hour = NamedUnit(2.777777777777778e-05, Dimensions(length=1, time=-1), name='decimeters_per_hour', ascii_symbol='dm/h', symbol='dmh⁻¹') -decimeters_per_square_hour = NamedUnit(7.71604938271605e-09, Dimensions(length=1, time=-2), name='decimeters_per_square_hour', ascii_symbol='dm/h^2', symbol='dmh⁻²') -decimeters_per_day = NamedUnit(1.1574074074074074e-06, Dimensions(length=1, time=-1), name='decimeters_per_day', ascii_symbol='dm/d', symbol='dmd⁻¹') -decimeters_per_square_day = NamedUnit(1.3395919067215364e-11, Dimensions(length=1, time=-2), name='decimeters_per_square_day', ascii_symbol='dm/d^2', symbol='dmd⁻²') -decimeters_per_year = NamedUnit(3.168873850681143e-09, Dimensions(length=1, time=-1), name='decimeters_per_year', ascii_symbol='dm/y', symbol='dmy⁻¹') -decimeters_per_square_year = 
NamedUnit(1.0041761481530736e-16, Dimensions(length=1, time=-2), name='decimeters_per_square_year', ascii_symbol='dm/y^2', symbol='dmy⁻²') +decimeters_per_hour = NamedUnit(0.0002777777777777778, Dimensions(length=1, time=-1), name='decimeters_per_hour', ascii_symbol='dm/h', symbol='dmh⁻¹') +decimeters_per_square_hour = NamedUnit(7.71604938271605e-07, Dimensions(length=1, time=-2), name='decimeters_per_square_hour', ascii_symbol='dm/h^2', symbol='dmh⁻²') +decimeters_per_day = NamedUnit(1.1574074074074075e-05, Dimensions(length=1, time=-1), name='decimeters_per_day', ascii_symbol='dm/d', symbol='dmd⁻¹') +decimeters_per_square_day = NamedUnit(1.3395919067215364e-09, Dimensions(length=1, time=-2), name='decimeters_per_square_day', ascii_symbol='dm/d^2', symbol='dmd⁻²') +decimeters_per_year = NamedUnit(3.168873850681143e-08, Dimensions(length=1, time=-1), name='decimeters_per_year', ascii_symbol='dm/y', symbol='dmy⁻¹') +decimeters_per_square_year = NamedUnit(1.0041761481530735e-14, Dimensions(length=1, time=-2), name='decimeters_per_square_year', ascii_symbol='dm/y^2', symbol='dmy⁻²') centimeters_per_second = NamedUnit(0.01, Dimensions(length=1, time=-1), name='centimeters_per_second', ascii_symbol='cm/s', symbol='cms⁻¹') centimeters_per_square_second = NamedUnit(0.01, Dimensions(length=1, time=-2), name='centimeters_per_square_second', ascii_symbol='cm/s^2', symbol='cms⁻²') centimeters_per_millisecond = NamedUnit(10.0, Dimensions(length=1, time=-1), name='centimeters_per_millisecond', ascii_symbol='cm/ms', symbol='cmms⁻¹') @@ -1161,12 +1129,12 @@ def __init__(self, name: str, units: list[NamedUnit]): centimeters_per_square_attosecond = NamedUnit(1e+34, Dimensions(length=1, time=-2), name='centimeters_per_square_attosecond', ascii_symbol='cm/as^2', symbol='cmas⁻²') centimeters_per_minute = NamedUnit(0.00016666666666666666, Dimensions(length=1, time=-1), name='centimeters_per_minute', ascii_symbol='cm/min', symbol='cmmin⁻¹') centimeters_per_square_minute = 
NamedUnit(2.777777777777778e-06, Dimensions(length=1, time=-2), name='centimeters_per_square_minute', ascii_symbol='cm/min^2', symbol='cmmin⁻²') -centimeters_per_hour = NamedUnit(2.777777777777778e-06, Dimensions(length=1, time=-1), name='centimeters_per_hour', ascii_symbol='cm/h', symbol='cmh⁻¹') -centimeters_per_square_hour = NamedUnit(7.71604938271605e-10, Dimensions(length=1, time=-2), name='centimeters_per_square_hour', ascii_symbol='cm/h^2', symbol='cmh⁻²') -centimeters_per_day = NamedUnit(1.1574074074074074e-07, Dimensions(length=1, time=-1), name='centimeters_per_day', ascii_symbol='cm/d', symbol='cmd⁻¹') -centimeters_per_square_day = NamedUnit(1.3395919067215364e-12, Dimensions(length=1, time=-2), name='centimeters_per_square_day', ascii_symbol='cm/d^2', symbol='cmd⁻²') -centimeters_per_year = NamedUnit(3.168873850681143e-10, Dimensions(length=1, time=-1), name='centimeters_per_year', ascii_symbol='cm/y', symbol='cmy⁻¹') -centimeters_per_square_year = NamedUnit(1.0041761481530737e-17, Dimensions(length=1, time=-2), name='centimeters_per_square_year', ascii_symbol='cm/y^2', symbol='cmy⁻²') +centimeters_per_hour = NamedUnit(2.777777777777778e-05, Dimensions(length=1, time=-1), name='centimeters_per_hour', ascii_symbol='cm/h', symbol='cmh⁻¹') +centimeters_per_square_hour = NamedUnit(7.71604938271605e-08, Dimensions(length=1, time=-2), name='centimeters_per_square_hour', ascii_symbol='cm/h^2', symbol='cmh⁻²') +centimeters_per_day = NamedUnit(1.1574074074074074e-06, Dimensions(length=1, time=-1), name='centimeters_per_day', ascii_symbol='cm/d', symbol='cmd⁻¹') +centimeters_per_square_day = NamedUnit(1.3395919067215363e-10, Dimensions(length=1, time=-2), name='centimeters_per_square_day', ascii_symbol='cm/d^2', symbol='cmd⁻²') +centimeters_per_year = NamedUnit(3.168873850681143e-09, Dimensions(length=1, time=-1), name='centimeters_per_year', ascii_symbol='cm/y', symbol='cmy⁻¹') +centimeters_per_square_year = NamedUnit(1.0041761481530735e-15, Dimensions(length=1, 
time=-2), name='centimeters_per_square_year', ascii_symbol='cm/y^2', symbol='cmy⁻²') angstroms_per_second = NamedUnit(1e-10, Dimensions(length=1, time=-1), name='angstroms_per_second', ascii_symbol='Ang/s', symbol='Ås⁻¹') angstroms_per_square_second = NamedUnit(1e-10, Dimensions(length=1, time=-2), name='angstroms_per_square_second', ascii_symbol='Ang/s^2', symbol='Ås⁻²') angstroms_per_millisecond = NamedUnit(1e-07, Dimensions(length=1, time=-1), name='angstroms_per_millisecond', ascii_symbol='Ang/ms', symbol='Åms⁻¹') @@ -1183,34 +1151,12 @@ def __init__(self, name: str, units: list[NamedUnit]): angstroms_per_square_attosecond = NamedUnit(9.999999999999999e+25, Dimensions(length=1, time=-2), name='angstroms_per_square_attosecond', ascii_symbol='Ang/as^2', symbol='Åas⁻²') angstroms_per_minute = NamedUnit(1.6666666666666668e-12, Dimensions(length=1, time=-1), name='angstroms_per_minute', ascii_symbol='Ang/min', symbol='Åmin⁻¹') angstroms_per_square_minute = NamedUnit(2.7777777777777778e-14, Dimensions(length=1, time=-2), name='angstroms_per_square_minute', ascii_symbol='Ang/min^2', symbol='Åmin⁻²') -angstroms_per_hour = NamedUnit(2.7777777777777778e-14, Dimensions(length=1, time=-1), name='angstroms_per_hour', ascii_symbol='Ang/h', symbol='Åh⁻¹') -angstroms_per_square_hour = NamedUnit(7.71604938271605e-18, Dimensions(length=1, time=-2), name='angstroms_per_square_hour', ascii_symbol='Ang/h^2', symbol='Åh⁻²') -angstroms_per_day = NamedUnit(1.1574074074074075e-15, Dimensions(length=1, time=-1), name='angstroms_per_day', ascii_symbol='Ang/d', symbol='Åd⁻¹') -angstroms_per_square_day = NamedUnit(1.3395919067215364e-20, Dimensions(length=1, time=-2), name='angstroms_per_square_day', ascii_symbol='Ang/d^2', symbol='Åd⁻²') -angstroms_per_year = NamedUnit(3.168873850681143e-18, Dimensions(length=1, time=-1), name='angstroms_per_year', ascii_symbol='Ang/y', symbol='Åy⁻¹') -angstroms_per_square_year = NamedUnit(1.0041761481530736e-25, Dimensions(length=1, time=-2), 
name='angstroms_per_square_year', ascii_symbol='Ang/y^2', symbol='Åy⁻²') -microns_per_second = NamedUnit(1e-06, Dimensions(length=1, time=-1), name='microns_per_second', ascii_symbol='micron/s', symbol='microns⁻¹') -microns_per_square_second = NamedUnit(1e-06, Dimensions(length=1, time=-2), name='microns_per_square_second', ascii_symbol='micron/s^2', symbol='microns⁻²') -microns_per_millisecond = NamedUnit(0.001, Dimensions(length=1, time=-1), name='microns_per_millisecond', ascii_symbol='micron/ms', symbol='micronms⁻¹') -microns_per_square_millisecond = NamedUnit(1.0, Dimensions(length=1, time=-2), name='microns_per_square_millisecond', ascii_symbol='micron/ms^2', symbol='micronms⁻²') -microns_per_microsecond = NamedUnit(1.0, Dimensions(length=1, time=-1), name='microns_per_microsecond', ascii_symbol='micron/us', symbol='micronµs⁻¹') -microns_per_square_microsecond = NamedUnit(1000000.0, Dimensions(length=1, time=-2), name='microns_per_square_microsecond', ascii_symbol='micron/us^2', symbol='micronµs⁻²') -microns_per_nanosecond = NamedUnit(999.9999999999999, Dimensions(length=1, time=-1), name='microns_per_nanosecond', ascii_symbol='micron/ns', symbol='micronns⁻¹') -microns_per_square_nanosecond = NamedUnit(999999999999.9999, Dimensions(length=1, time=-2), name='microns_per_square_nanosecond', ascii_symbol='micron/ns^2', symbol='micronns⁻²') -microns_per_picosecond = NamedUnit(1000000.0, Dimensions(length=1, time=-1), name='microns_per_picosecond', ascii_symbol='micron/ps', symbol='micronps⁻¹') -microns_per_square_picosecond = NamedUnit(1e+18, Dimensions(length=1, time=-2), name='microns_per_square_picosecond', ascii_symbol='micron/ps^2', symbol='micronps⁻²') -microns_per_femtosecond = NamedUnit(999999999.9999999, Dimensions(length=1, time=-1), name='microns_per_femtosecond', ascii_symbol='micron/fs', symbol='micronfs⁻¹') -microns_per_square_femtosecond = NamedUnit(9.999999999999998e+23, Dimensions(length=1, time=-2), name='microns_per_square_femtosecond', 
ascii_symbol='micron/fs^2', symbol='micronfs⁻²') -microns_per_attosecond = NamedUnit(999999999999.9999, Dimensions(length=1, time=-1), name='microns_per_attosecond', ascii_symbol='micron/as', symbol='micronas⁻¹') -microns_per_square_attosecond = NamedUnit(9.999999999999999e+29, Dimensions(length=1, time=-2), name='microns_per_square_attosecond', ascii_symbol='micron/as^2', symbol='micronas⁻²') -microns_per_minute = NamedUnit(1.6666666666666667e-08, Dimensions(length=1, time=-1), name='microns_per_minute', ascii_symbol='micron/min', symbol='micronmin⁻¹') -microns_per_square_minute = NamedUnit(2.7777777777777777e-10, Dimensions(length=1, time=-2), name='microns_per_square_minute', ascii_symbol='micron/min^2', symbol='micronmin⁻²') -microns_per_hour = NamedUnit(2.7777777777777777e-10, Dimensions(length=1, time=-1), name='microns_per_hour', ascii_symbol='micron/h', symbol='micronh⁻¹') -microns_per_square_hour = NamedUnit(7.71604938271605e-14, Dimensions(length=1, time=-2), name='microns_per_square_hour', ascii_symbol='micron/h^2', symbol='micronh⁻²') -microns_per_day = NamedUnit(1.1574074074074074e-11, Dimensions(length=1, time=-1), name='microns_per_day', ascii_symbol='micron/d', symbol='micrond⁻¹') -microns_per_square_day = NamedUnit(1.3395919067215363e-16, Dimensions(length=1, time=-2), name='microns_per_square_day', ascii_symbol='micron/d^2', symbol='micrond⁻²') -microns_per_year = NamedUnit(3.168873850681143e-14, Dimensions(length=1, time=-1), name='microns_per_year', ascii_symbol='micron/y', symbol='microny⁻¹') -microns_per_square_year = NamedUnit(1.0041761481530736e-21, Dimensions(length=1, time=-2), name='microns_per_square_year', ascii_symbol='micron/y^2', symbol='microny⁻²') +angstroms_per_hour = NamedUnit(2.777777777777778e-13, Dimensions(length=1, time=-1), name='angstroms_per_hour', ascii_symbol='Ang/h', symbol='Åh⁻¹') +angstroms_per_square_hour = NamedUnit(7.716049382716049e-16, Dimensions(length=1, time=-2), name='angstroms_per_square_hour', 
ascii_symbol='Ang/h^2', symbol='Åh⁻²') +angstroms_per_day = NamedUnit(1.1574074074074074e-14, Dimensions(length=1, time=-1), name='angstroms_per_day', ascii_symbol='Ang/d', symbol='Åd⁻¹') +angstroms_per_square_day = NamedUnit(1.3395919067215363e-18, Dimensions(length=1, time=-2), name='angstroms_per_square_day', ascii_symbol='Ang/d^2', symbol='Åd⁻²') +angstroms_per_year = NamedUnit(3.168873850681143e-17, Dimensions(length=1, time=-1), name='angstroms_per_year', ascii_symbol='Ang/y', symbol='Åy⁻¹') +angstroms_per_square_year = NamedUnit(1.0041761481530734e-23, Dimensions(length=1, time=-2), name='angstroms_per_square_year', ascii_symbol='Ang/y^2', symbol='Åy⁻²') miles_per_second = NamedUnit(1609.344, Dimensions(length=1, time=-1), name='miles_per_second', ascii_symbol='miles/s', symbol='miless⁻¹') miles_per_square_second = NamedUnit(1609.344, Dimensions(length=1, time=-2), name='miles_per_square_second', ascii_symbol='miles/s^2', symbol='miless⁻²') miles_per_millisecond = NamedUnit(1609344.0, Dimensions(length=1, time=-1), name='miles_per_millisecond', ascii_symbol='miles/ms', symbol='milesms⁻¹') @@ -1227,12 +1173,12 @@ def __init__(self, name: str, units: list[NamedUnit]): miles_per_square_attosecond = NamedUnit(1.609344e+39, Dimensions(length=1, time=-2), name='miles_per_square_attosecond', ascii_symbol='miles/as^2', symbol='milesas⁻²') miles_per_minute = NamedUnit(26.822400000000002, Dimensions(length=1, time=-1), name='miles_per_minute', ascii_symbol='miles/min', symbol='milesmin⁻¹') miles_per_square_minute = NamedUnit(0.44704, Dimensions(length=1, time=-2), name='miles_per_square_minute', ascii_symbol='miles/min^2', symbol='milesmin⁻²') -miles_per_hour = NamedUnit(0.44704, Dimensions(length=1, time=-1), name='miles_per_hour', ascii_symbol='miles/h', symbol='milesh⁻¹') -miles_per_square_hour = NamedUnit(0.00012417777777777778, Dimensions(length=1, time=-2), name='miles_per_square_hour', ascii_symbol='miles/h^2', symbol='milesh⁻²') -miles_per_day = 
NamedUnit(0.018626666666666666, Dimensions(length=1, time=-1), name='miles_per_day', ascii_symbol='miles/d', symbol='milesd⁻¹') -miles_per_square_day = NamedUnit(2.1558641975308643e-07, Dimensions(length=1, time=-2), name='miles_per_square_day', ascii_symbol='miles/d^2', symbol='milesd⁻²') -miles_per_year = NamedUnit(5.099808118350594e-05, Dimensions(length=1, time=-1), name='miles_per_year', ascii_symbol='miles/y', symbol='milesy⁻¹') -miles_per_square_year = NamedUnit(1.61606485897326e-12, Dimensions(length=1, time=-2), name='miles_per_square_year', ascii_symbol='miles/y^2', symbol='milesy⁻²') +miles_per_hour = NamedUnit(4.4704, Dimensions(length=1, time=-1), name='miles_per_hour', ascii_symbol='miles/h', symbol='milesh⁻¹') +miles_per_square_hour = NamedUnit(0.012417777777777778, Dimensions(length=1, time=-2), name='miles_per_square_hour', ascii_symbol='miles/h^2', symbol='milesh⁻²') +miles_per_day = NamedUnit(0.18626666666666666, Dimensions(length=1, time=-1), name='miles_per_day', ascii_symbol='miles/d', symbol='milesd⁻¹') +miles_per_square_day = NamedUnit(2.1558641975308643e-05, Dimensions(length=1, time=-2), name='miles_per_square_day', ascii_symbol='miles/d^2', symbol='milesd⁻²') +miles_per_year = NamedUnit(0.0005099808118350593, Dimensions(length=1, time=-1), name='miles_per_year', ascii_symbol='miles/y', symbol='milesy⁻¹') +miles_per_square_year = NamedUnit(1.61606485897326e-10, Dimensions(length=1, time=-2), name='miles_per_square_year', ascii_symbol='miles/y^2', symbol='milesy⁻²') yards_per_second = NamedUnit(0.9144000000000001, Dimensions(length=1, time=-1), name='yards_per_second', ascii_symbol='yrd/s', symbol='yrds⁻¹') yards_per_square_second = NamedUnit(0.9144000000000001, Dimensions(length=1, time=-2), name='yards_per_square_second', ascii_symbol='yrd/s^2', symbol='yrds⁻²') yards_per_millisecond = NamedUnit(914.4000000000001, Dimensions(length=1, time=-1), name='yards_per_millisecond', ascii_symbol='yrd/ms', symbol='yrdms⁻¹') @@ -1249,12 +1195,12 @@ 
def __init__(self, name: str, units: list[NamedUnit]): yards_per_square_attosecond = NamedUnit(9.144e+35, Dimensions(length=1, time=-2), name='yards_per_square_attosecond', ascii_symbol='yrd/as^2', symbol='yrdas⁻²') yards_per_minute = NamedUnit(0.015240000000000002, Dimensions(length=1, time=-1), name='yards_per_minute', ascii_symbol='yrd/min', symbol='yrdmin⁻¹') yards_per_square_minute = NamedUnit(0.00025400000000000005, Dimensions(length=1, time=-2), name='yards_per_square_minute', ascii_symbol='yrd/min^2', symbol='yrdmin⁻²') -yards_per_hour = NamedUnit(0.00025400000000000005, Dimensions(length=1, time=-1), name='yards_per_hour', ascii_symbol='yrd/h', symbol='yrdh⁻¹') -yards_per_square_hour = NamedUnit(7.055555555555557e-08, Dimensions(length=1, time=-2), name='yards_per_square_hour', ascii_symbol='yrd/h^2', symbol='yrdh⁻²') -yards_per_day = NamedUnit(1.0583333333333334e-05, Dimensions(length=1, time=-1), name='yards_per_day', ascii_symbol='yrd/d', symbol='yrdd⁻¹') -yards_per_square_day = NamedUnit(1.224922839506173e-10, Dimensions(length=1, time=-2), name='yards_per_square_day', ascii_symbol='yrd/d^2', symbol='yrdd⁻²') -yards_per_year = NamedUnit(2.8976182490628376e-08, Dimensions(length=1, time=-1), name='yards_per_year', ascii_symbol='yrd/y', symbol='yrdy⁻¹') -yards_per_square_year = NamedUnit(9.182186698711705e-16, Dimensions(length=1, time=-2), name='yards_per_square_year', ascii_symbol='yrd/y^2', symbol='yrdy⁻²') +yards_per_hour = NamedUnit(0.00254, Dimensions(length=1, time=-1), name='yards_per_hour', ascii_symbol='yrd/h', symbol='yrdh⁻¹') +yards_per_square_hour = NamedUnit(7.055555555555557e-06, Dimensions(length=1, time=-2), name='yards_per_square_hour', ascii_symbol='yrd/h^2', symbol='yrdh⁻²') +yards_per_day = NamedUnit(0.00010583333333333335, Dimensions(length=1, time=-1), name='yards_per_day', ascii_symbol='yrd/d', symbol='yrdd⁻¹') +yards_per_square_day = NamedUnit(1.224922839506173e-08, Dimensions(length=1, time=-2), name='yards_per_square_day', 
ascii_symbol='yrd/d^2', symbol='yrdd⁻²') +yards_per_year = NamedUnit(2.897618249062837e-07, Dimensions(length=1, time=-1), name='yards_per_year', ascii_symbol='yrd/y', symbol='yrdy⁻¹') +yards_per_square_year = NamedUnit(9.182186698711705e-14, Dimensions(length=1, time=-2), name='yards_per_square_year', ascii_symbol='yrd/y^2', symbol='yrdy⁻²') feet_per_second = NamedUnit(0.3048, Dimensions(length=1, time=-1), name='feet_per_second', ascii_symbol='ft/s', symbol='fts⁻¹') feet_per_square_second = NamedUnit(0.3048, Dimensions(length=1, time=-2), name='feet_per_square_second', ascii_symbol='ft/s^2', symbol='fts⁻²') feet_per_millisecond = NamedUnit(304.8, Dimensions(length=1, time=-1), name='feet_per_millisecond', ascii_symbol='ft/ms', symbol='ftms⁻¹') @@ -1271,12 +1217,12 @@ def __init__(self, name: str, units: list[NamedUnit]): feet_per_square_attosecond = NamedUnit(3.0479999999999997e+35, Dimensions(length=1, time=-2), name='feet_per_square_attosecond', ascii_symbol='ft/as^2', symbol='ftas⁻²') feet_per_minute = NamedUnit(0.00508, Dimensions(length=1, time=-1), name='feet_per_minute', ascii_symbol='ft/min', symbol='ftmin⁻¹') feet_per_square_minute = NamedUnit(8.466666666666667e-05, Dimensions(length=1, time=-2), name='feet_per_square_minute', ascii_symbol='ft/min^2', symbol='ftmin⁻²') -feet_per_hour = NamedUnit(8.466666666666667e-05, Dimensions(length=1, time=-1), name='feet_per_hour', ascii_symbol='ft/h', symbol='fth⁻¹') -feet_per_square_hour = NamedUnit(2.351851851851852e-08, Dimensions(length=1, time=-2), name='feet_per_square_hour', ascii_symbol='ft/h^2', symbol='fth⁻²') -feet_per_day = NamedUnit(3.527777777777778e-06, Dimensions(length=1, time=-1), name='feet_per_day', ascii_symbol='ft/d', symbol='ftd⁻¹') -feet_per_square_day = NamedUnit(4.083076131687243e-11, Dimensions(length=1, time=-2), name='feet_per_square_day', ascii_symbol='ft/d^2', symbol='ftd⁻²') -feet_per_year = NamedUnit(9.658727496876124e-09, Dimensions(length=1, time=-1), name='feet_per_year', 
ascii_symbol='ft/y', symbol='fty⁻¹') -feet_per_square_year = NamedUnit(3.060728899570568e-16, Dimensions(length=1, time=-2), name='feet_per_square_year', ascii_symbol='ft/y^2', symbol='fty⁻²') +feet_per_hour = NamedUnit(0.0008466666666666667, Dimensions(length=1, time=-1), name='feet_per_hour', ascii_symbol='ft/h', symbol='fth⁻¹') +feet_per_square_hour = NamedUnit(2.351851851851852e-06, Dimensions(length=1, time=-2), name='feet_per_square_hour', ascii_symbol='ft/h^2', symbol='fth⁻²') +feet_per_day = NamedUnit(3.527777777777778e-05, Dimensions(length=1, time=-1), name='feet_per_day', ascii_symbol='ft/d', symbol='ftd⁻¹') +feet_per_square_day = NamedUnit(4.083076131687243e-09, Dimensions(length=1, time=-2), name='feet_per_square_day', ascii_symbol='ft/d^2', symbol='ftd⁻²') +feet_per_year = NamedUnit(9.658727496876123e-08, Dimensions(length=1, time=-1), name='feet_per_year', ascii_symbol='ft/y', symbol='fty⁻¹') +feet_per_square_year = NamedUnit(3.060728899570568e-14, Dimensions(length=1, time=-2), name='feet_per_square_year', ascii_symbol='ft/y^2', symbol='fty⁻²') inches_per_second = NamedUnit(0.0254, Dimensions(length=1, time=-1), name='inches_per_second', ascii_symbol='in/s', symbol='ins⁻¹') inches_per_square_second = NamedUnit(0.0254, Dimensions(length=1, time=-2), name='inches_per_square_second', ascii_symbol='in/s^2', symbol='ins⁻²') inches_per_millisecond = NamedUnit(25.4, Dimensions(length=1, time=-1), name='inches_per_millisecond', ascii_symbol='in/ms', symbol='inms⁻¹') @@ -1293,12 +1239,12 @@ def __init__(self, name: str, units: list[NamedUnit]): inches_per_square_attosecond = NamedUnit(2.5399999999999998e+34, Dimensions(length=1, time=-2), name='inches_per_square_attosecond', ascii_symbol='in/as^2', symbol='inas⁻²') inches_per_minute = NamedUnit(0.00042333333333333334, Dimensions(length=1, time=-1), name='inches_per_minute', ascii_symbol='in/min', symbol='inmin⁻¹') inches_per_square_minute = NamedUnit(7.055555555555555e-06, Dimensions(length=1, time=-2), 
name='inches_per_square_minute', ascii_symbol='in/min^2', symbol='inmin⁻²') -inches_per_hour = NamedUnit(7.055555555555555e-06, Dimensions(length=1, time=-1), name='inches_per_hour', ascii_symbol='in/h', symbol='inh⁻¹') -inches_per_square_hour = NamedUnit(1.9598765432098764e-09, Dimensions(length=1, time=-2), name='inches_per_square_hour', ascii_symbol='in/h^2', symbol='inh⁻²') -inches_per_day = NamedUnit(2.939814814814815e-07, Dimensions(length=1, time=-1), name='inches_per_day', ascii_symbol='in/d', symbol='ind⁻¹') -inches_per_square_day = NamedUnit(3.402563443072702e-12, Dimensions(length=1, time=-2), name='inches_per_square_day', ascii_symbol='in/d^2', symbol='ind⁻²') -inches_per_year = NamedUnit(8.048939580730103e-10, Dimensions(length=1, time=-1), name='inches_per_year', ascii_symbol='in/y', symbol='iny⁻¹') -inches_per_square_year = NamedUnit(2.550607416308807e-17, Dimensions(length=1, time=-2), name='inches_per_square_year', ascii_symbol='in/y^2', symbol='iny⁻²') +inches_per_hour = NamedUnit(7.055555555555556e-05, Dimensions(length=1, time=-1), name='inches_per_hour', ascii_symbol='in/h', symbol='inh⁻¹') +inches_per_square_hour = NamedUnit(1.9598765432098765e-07, Dimensions(length=1, time=-2), name='inches_per_square_hour', ascii_symbol='in/h^2', symbol='inh⁻²') +inches_per_day = NamedUnit(2.9398148148148147e-06, Dimensions(length=1, time=-1), name='inches_per_day', ascii_symbol='in/d', symbol='ind⁻¹') +inches_per_square_day = NamedUnit(3.4025634430727023e-10, Dimensions(length=1, time=-2), name='inches_per_square_day', ascii_symbol='in/d^2', symbol='ind⁻²') +inches_per_year = NamedUnit(8.048939580730103e-09, Dimensions(length=1, time=-1), name='inches_per_year', ascii_symbol='in/y', symbol='iny⁻¹') +inches_per_square_year = NamedUnit(2.5506074163088065e-15, Dimensions(length=1, time=-2), name='inches_per_square_year', ascii_symbol='in/y^2', symbol='iny⁻²') grams_per_cubic_meter = NamedUnit(0.001, Dimensions(length=-3, mass=1), name='grams_per_cubic_meter', 
ascii_symbol='g m^-3', symbol='gm⁻³') exagrams_per_cubic_meter = NamedUnit(1000000000000000.0, Dimensions(length=-3, mass=1), name='exagrams_per_cubic_meter', ascii_symbol='Eg m^-3', symbol='Egm⁻³') petagrams_per_cubic_meter = NamedUnit(1000000000000.0, Dimensions(length=-3, mass=1), name='petagrams_per_cubic_meter', ascii_symbol='Pg m^-3', symbol='Pgm⁻³') @@ -1555,22 +1501,6 @@ def __init__(self, name: str, units: list[NamedUnit]): atomic_mass_units_per_cubic_angstrom = NamedUnit(1660.5389209999996, Dimensions(length=-3, mass=1), name='atomic_mass_units_per_cubic_angstrom', ascii_symbol='au Ang^-3', symbol='auÅ⁻³') pounds_per_cubic_angstrom = NamedUnit(4.5359237e+29, Dimensions(length=-3, mass=1), name='pounds_per_cubic_angstrom', ascii_symbol='lb Ang^-3', symbol='lbÅ⁻³') ounces_per_cubic_angstrom = NamedUnit(2.8349523125e+28, Dimensions(length=-3, mass=1), name='ounces_per_cubic_angstrom', ascii_symbol='oz Ang^-3', symbol='ozÅ⁻³') -grams_per_cubic_micron = NamedUnit(1000000000000000.1, Dimensions(length=-3, mass=1), name='grams_per_cubic_micron', ascii_symbol='g micron^-3', symbol='gmicron⁻³') -exagrams_per_cubic_micron = NamedUnit(1.0000000000000001e+33, Dimensions(length=-3, mass=1), name='exagrams_per_cubic_micron', ascii_symbol='Eg micron^-3', symbol='Egmicron⁻³') -petagrams_per_cubic_micron = NamedUnit(1.0000000000000002e+30, Dimensions(length=-3, mass=1), name='petagrams_per_cubic_micron', ascii_symbol='Pg micron^-3', symbol='Pgmicron⁻³') -teragrams_per_cubic_micron = NamedUnit(1.0000000000000002e+27, Dimensions(length=-3, mass=1), name='teragrams_per_cubic_micron', ascii_symbol='Tg micron^-3', symbol='Tgmicron⁻³') -gigagrams_per_cubic_micron = NamedUnit(1.0000000000000001e+24, Dimensions(length=-3, mass=1), name='gigagrams_per_cubic_micron', ascii_symbol='Gg micron^-3', symbol='Ggmicron⁻³') -megagrams_per_cubic_micron = NamedUnit(1.0000000000000001e+21, Dimensions(length=-3, mass=1), name='megagrams_per_cubic_micron', ascii_symbol='Mg micron^-3', 
symbol='Mgmicron⁻³') -kilograms_per_cubic_micron = NamedUnit(1.0000000000000001e+18, Dimensions(length=-3, mass=1), name='kilograms_per_cubic_micron', ascii_symbol='kg micron^-3', symbol='kgmicron⁻³') -milligrams_per_cubic_micron = NamedUnit(1000000000000.0001, Dimensions(length=-3, mass=1), name='milligrams_per_cubic_micron', ascii_symbol='mg micron^-3', symbol='mgmicron⁻³') -micrograms_per_cubic_micron = NamedUnit(1000000000.0000002, Dimensions(length=-3, mass=1), name='micrograms_per_cubic_micron', ascii_symbol='ug micron^-3', symbol='µgmicron⁻³') -nanograms_per_cubic_micron = NamedUnit(1000000.0000000003, Dimensions(length=-3, mass=1), name='nanograms_per_cubic_micron', ascii_symbol='ng micron^-3', symbol='ngmicron⁻³') -picograms_per_cubic_micron = NamedUnit(1000.0000000000002, Dimensions(length=-3, mass=1), name='picograms_per_cubic_micron', ascii_symbol='pg micron^-3', symbol='pgmicron⁻³') -femtograms_per_cubic_micron = NamedUnit(1.0000000000000002, Dimensions(length=-3, mass=1), name='femtograms_per_cubic_micron', ascii_symbol='fg micron^-3', symbol='fgmicron⁻³') -attograms_per_cubic_micron = NamedUnit(0.0010000000000000002, Dimensions(length=-3, mass=1), name='attograms_per_cubic_micron', ascii_symbol='ag micron^-3', symbol='agmicron⁻³') -atomic_mass_units_per_cubic_micron = NamedUnit(1.660538921e-09, Dimensions(length=-3, mass=1), name='atomic_mass_units_per_cubic_micron', ascii_symbol='au micron^-3', symbol='aumicron⁻³') -pounds_per_cubic_micron = NamedUnit(4.5359237000000006e+17, Dimensions(length=-3, mass=1), name='pounds_per_cubic_micron', ascii_symbol='lb micron^-3', symbol='lbmicron⁻³') -ounces_per_cubic_micron = NamedUnit(2.8349523125000004e+16, Dimensions(length=-3, mass=1), name='ounces_per_cubic_micron', ascii_symbol='oz micron^-3', symbol='ozmicron⁻³') grams_per_cubic_mile = NamedUnit(2.399127585789277e-13, Dimensions(length=-3, mass=1), name='grams_per_cubic_mile', ascii_symbol='g miles^-3', symbol='gmiles⁻³') exagrams_per_cubic_mile = 
NamedUnit(239912.7585789277, Dimensions(length=-3, mass=1), name='exagrams_per_cubic_mile', ascii_symbol='Eg miles^-3', symbol='Egmiles⁻³') petagrams_per_cubic_mile = NamedUnit(239.9127585789277, Dimensions(length=-3, mass=1), name='petagrams_per_cubic_mile', ascii_symbol='Pg miles^-3', symbol='Pgmiles⁻³') @@ -1747,13 +1677,6 @@ def __init__(self, name: str, units: list[NamedUnit]): picomoles_per_cubic_angstrom = NamedUnit(6.02214076e+41, Dimensions(length=-3, moles_hint=1), name='picomoles_per_cubic_angstrom', ascii_symbol='pmol Ang^-3', symbol='pmolÅ⁻³') femtomoles_per_cubic_angstrom = NamedUnit(6.022140759999999e+38, Dimensions(length=-3, moles_hint=1), name='femtomoles_per_cubic_angstrom', ascii_symbol='fmol Ang^-3', symbol='fmolÅ⁻³') attomoles_per_cubic_angstrom = NamedUnit(6.02214076e+35, Dimensions(length=-3, moles_hint=1), name='attomoles_per_cubic_angstrom', ascii_symbol='amol Ang^-3', symbol='amolÅ⁻³') -moles_per_cubic_micron = NamedUnit(6.022140760000001e+41, Dimensions(length=-3, moles_hint=1), name='moles_per_cubic_micron', ascii_symbol='mol micron^-3', symbol='molmicron⁻³') -millimoles_per_cubic_micron = NamedUnit(6.022140760000001e+38, Dimensions(length=-3, moles_hint=1), name='millimoles_per_cubic_micron', ascii_symbol='mmol micron^-3', symbol='mmolmicron⁻³') -micromoles_per_cubic_micron = NamedUnit(6.0221407600000004e+35, Dimensions(length=-3, moles_hint=1), name='micromoles_per_cubic_micron', ascii_symbol='umol micron^-3', symbol='µmolmicron⁻³') -nanomoles_per_cubic_micron = NamedUnit(6.022140760000001e+32, Dimensions(length=-3, moles_hint=1), name='nanomoles_per_cubic_micron', ascii_symbol='nmol micron^-3', symbol='nmolmicron⁻³') -picomoles_per_cubic_micron = NamedUnit(6.022140760000001e+29, Dimensions(length=-3, moles_hint=1), name='picomoles_per_cubic_micron', ascii_symbol='pmol micron^-3', symbol='pmolmicron⁻³') -femtomoles_per_cubic_micron = NamedUnit(6.022140760000001e+26, Dimensions(length=-3, moles_hint=1), 
name='femtomoles_per_cubic_micron', ascii_symbol='fmol micron^-3', symbol='fmolmicron⁻³') -attomoles_per_cubic_micron = NamedUnit(6.0221407600000005e+23, Dimensions(length=-3, moles_hint=1), name='attomoles_per_cubic_micron', ascii_symbol='amol micron^-3', symbol='amolmicron⁻³') moles_per_cubic_mile = NamedUnit(144478840228220.0, Dimensions(length=-3, moles_hint=1), name='moles_per_cubic_mile', ascii_symbol='mol miles^-3', symbol='molmiles⁻³') millimoles_per_cubic_mile = NamedUnit(144478840228.22003, Dimensions(length=-3, moles_hint=1), name='millimoles_per_cubic_mile', ascii_symbol='mmol miles^-3', symbol='mmolmiles⁻³') micromoles_per_cubic_mile = NamedUnit(144478840.22822002, Dimensions(length=-3, moles_hint=1), name='micromoles_per_cubic_mile', ascii_symbol='umol miles^-3', symbol='µmolmiles⁻³') @@ -2050,15 +1973,12 @@ def __init__(self, name: str, units: list[NamedUnit]): "aH": attohenry, "Ang": angstroms, "Å": angstroms, - "micron": microns, "min": minutes, - "rpm": revolutions_per_minute, "h": hours, "d": days, "y": years, "deg": degrees, "rad": radians, - "rot": rotations, "sr": stradians, "l": litres, "eV": electronvolts, @@ -2107,7 +2027,6 @@ def __init__(self, name: str, units: list[NamedUnit]): "amu": atomic_mass_units, "degr": degrees, "Deg": degrees, - "degree": degrees, "degrees": degrees, "Degrees": degrees, "Counts": none, @@ -2144,7 +2063,6 @@ def __init__(self, name: str, units: list[NamedUnit]): decimeters, centimeters, angstroms, - microns, miles, yards, feet, @@ -2170,7 +2088,6 @@ def __init__(self, name: str, units: list[NamedUnit]): square_decimeters, square_centimeters, square_angstroms, - square_microns, square_miles, square_yards, square_feet, @@ -2197,7 +2114,6 @@ def __init__(self, name: str, units: list[NamedUnit]): cubic_decimeters, cubic_centimeters, cubic_angstroms, - cubic_microns, cubic_miles, cubic_yards, cubic_feet, @@ -2223,7 +2139,6 @@ def __init__(self, name: str, units: list[NamedUnit]): per_decimeter, per_centimeter, 
per_angstrom, - per_micron, per_mile, per_yard, per_foot, @@ -2249,7 +2164,6 @@ def __init__(self, name: str, units: list[NamedUnit]): per_square_decimeter, per_square_centimeter, per_square_angstrom, - per_square_micron, per_square_mile, per_square_yard, per_square_foot, @@ -2275,7 +2189,6 @@ def __init__(self, name: str, units: list[NamedUnit]): per_cubic_decimeter, per_cubic_centimeter, per_cubic_angstrom, - per_cubic_micron, per_cubic_mile, per_cubic_yard, per_cubic_foot, @@ -2496,17 +2409,6 @@ def __init__(self, name: str, units: list[NamedUnit]): angstroms_per_hour, angstroms_per_day, angstroms_per_year, - microns_per_second, - microns_per_millisecond, - microns_per_microsecond, - microns_per_nanosecond, - microns_per_picosecond, - microns_per_femtosecond, - microns_per_attosecond, - microns_per_minute, - microns_per_hour, - microns_per_day, - microns_per_year, miles_per_second, miles_per_millisecond, miles_per_microsecond, @@ -2732,17 +2634,6 @@ def __init__(self, name: str, units: list[NamedUnit]): angstroms_per_square_hour, angstroms_per_square_day, angstroms_per_square_year, - microns_per_square_second, - microns_per_square_millisecond, - microns_per_square_microsecond, - microns_per_square_nanosecond, - microns_per_square_picosecond, - microns_per_square_femtosecond, - microns_per_square_attosecond, - microns_per_square_minute, - microns_per_square_hour, - microns_per_square_day, - microns_per_square_year, miles_per_square_second, miles_per_square_millisecond, miles_per_square_microsecond, @@ -3048,22 +2939,6 @@ def __init__(self, name: str, units: list[NamedUnit]): atomic_mass_units_per_cubic_angstrom, pounds_per_cubic_angstrom, ounces_per_cubic_angstrom, - grams_per_cubic_micron, - exagrams_per_cubic_micron, - petagrams_per_cubic_micron, - teragrams_per_cubic_micron, - gigagrams_per_cubic_micron, - megagrams_per_cubic_micron, - kilograms_per_cubic_micron, - milligrams_per_cubic_micron, - micrograms_per_cubic_micron, - nanograms_per_cubic_micron, - 
picograms_per_cubic_micron, - femtograms_per_cubic_micron, - attograms_per_cubic_micron, - atomic_mass_units_per_cubic_micron, - pounds_per_cubic_micron, - ounces_per_cubic_micron, grams_per_cubic_mile, exagrams_per_cubic_mile, petagrams_per_cubic_mile, @@ -3393,7 +3268,6 @@ def __init__(self, name: str, units: list[NamedUnit]): units = [ degrees, radians, - rotations, ]) solid_angle = UnitGroup( @@ -3529,13 +3403,6 @@ def __init__(self, name: str, units: list[NamedUnit]): picomoles_per_cubic_angstrom, femtomoles_per_cubic_angstrom, attomoles_per_cubic_angstrom, - moles_per_cubic_micron, - millimoles_per_cubic_micron, - micromoles_per_cubic_micron, - nanomoles_per_cubic_micron, - picomoles_per_cubic_micron, - femtomoles_per_cubic_micron, - attomoles_per_cubic_micron, moles_per_cubic_mile, millimoles_per_cubic_mile, micromoles_per_cubic_mile, diff --git a/sasdata/slicing/meshes/mesh.py b/sasdata/slicing/meshes/mesh.py index a3e8c0fa0..8cfd8a6a0 100644 --- a/sasdata/slicing/meshes/mesh.py +++ b/sasdata/slicing/meshes/mesh.py @@ -132,6 +132,8 @@ def locate_points(self, x: np.ndarray, y: np.ndarray): x = x.reshape(-1) y = y.reshape(-1) + xy = np.concatenate(([x], [y]), axis=1) + # The most simple implementation is not particularly fast, especially in python # # Less obvious, but hopefully faster strategy diff --git a/sasdata/slicing/meshes/meshmerge.py b/sasdata/slicing/meshes/meshmerge.py index 882699c0d..ae950803e 100644 --- a/sasdata/slicing/meshes/meshmerge.py +++ b/sasdata/slicing/meshes/meshmerge.py @@ -68,24 +68,13 @@ def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray] non_singular = np.linalg.det(deltas) != 0 - st = np.linalg.solve( - deltas[non_singular], - # Reshape is required because solve accepts matrices of shape - # (M) or (..., M, K) for the second parameter, but ours shape - # is (..., M). We add an extra dimension to force our matrix - # into the shape (..., M, 1), which meets the expectations. 
- # - # - # Due to the reshaping work mentioned above, the final result - # has an extra element of length 1. We then index this extra - # dimension to get back to the result we wanted. - np.expand_dims(start_point_diff[non_singular], axis=2))[:, :, 0] + st = np.linalg.solve(deltas[non_singular], start_point_diff[non_singular]) # Find the points where s and t are in (0, 1) intersection_inds = np.logical_and( - np.logical_and(0 < st[:, 0], st[:, 0] < 1), # noqa SIM300 - np.logical_and(0 < st[:, 1], st[:, 1] < 1)) # noqa SIM300 + np.logical_and(st[:, 0] > 0, st[:, 0] < 1), + np.logical_and(st[:, 1] > 0, st[:, 1] < 1)) start_points_for_intersections = p1[non_singular][intersection_inds, :] deltas_for_intersections = delta1[non_singular][intersection_inds, :] diff --git a/sasdata/slicing/transforms.py b/sasdata/slicing/transforms.py index 724c53ff4..03ee1968e 100644 --- a/sasdata/slicing/transforms.py +++ b/sasdata/slicing/transforms.py @@ -3,56 +3,55 @@ from matplotlib import cm from scipy.spatial import Voronoi -if __name__ == "__main__": - # Some test data +# Some test data - qx_base_values = np.linspace(-10, 10, 21) - qy_base_values = np.linspace(-10, 10, 21) +qx_base_values = np.linspace(-10, 10, 21) +qy_base_values = np.linspace(-10, 10, 21) - qx, qy = np.meshgrid(qx_base_values, qy_base_values) +qx, qy = np.meshgrid(qx_base_values, qy_base_values) - include = np.logical_not((np.abs(qx) < 2) & (np.abs(qy) < 2)) +include = np.logical_not((np.abs(qx) < 2) & (np.abs(qy) < 2)) - qx = qx[include] - qy = qy[include] +qx = qx[include] +qy = qy[include] - r = np.sqrt(qx**2 + qy**2) +r = np.sqrt(qx**2 + qy**2) - data = np.log((1+np.cos(3*r))*np.exp(-r*r)) +data = np.log((1+np.cos(3*r))*np.exp(-r*r)) - colormap = cm.get_cmap('winter', 256) +colormap = cm.get_cmap('winter', 256) - def get_data_mesh(x, y, data): +def get_data_mesh(x, y, data): - input_data = np.array((x, y)).T - voronoi = Voronoi(input_data) + input_data = np.array((x, y)).T + voronoi = Voronoi(input_data) - 
# plt.scatter(voronoi.vertices[:,0], voronoi.vertices[:,1]) - # plt.scatter(voronoi.points[:,0], voronoi.points[:,1]) + # plt.scatter(voronoi.vertices[:,0], voronoi.vertices[:,1]) + # plt.scatter(voronoi.points[:,0], voronoi.points[:,1]) - cmin = np.min(data) - cmax = np.max(data) + cmin = np.min(data) + cmax = np.max(data) - color_index_map = np.array(255 * (data - cmin) / (cmax - cmin), dtype=int) + color_index_map = np.array(255 * (data - cmin) / (cmax - cmin), dtype=int) - for point_index, points in enumerate(voronoi.points): + for point_index, points in enumerate(voronoi.points): - region_index = voronoi.point_region[point_index] - region = voronoi.regions[region_index] + region_index = voronoi.point_region[point_index] + region = voronoi.regions[region_index] - if len(region) > 0: + if len(region) > 0: - if -1 in region: + if -1 in region: - pass + pass - else: + else: - color = colormap(color_index_map[point_index]) + color = colormap(color_index_map[point_index]) - circly = region + [region[0]] - plt.fill(voronoi.vertices[circly, 0], voronoi.vertices[circly, 1], color=color, edgecolor="white") + circly = region + [region[0]] + plt.fill(voronoi.vertices[circly, 0], voronoi.vertices[circly, 1], color=color, edgecolor="white") - plt.show() + plt.show() - get_data_mesh(qx.reshape(-1), qy.reshape(-1), data) +get_data_mesh(qx.reshape(-1), qy.reshape(-1), data) diff --git a/sasdata/temp_ascii_reader.py b/sasdata/temp_ascii_reader.py index 96e8634bf..d1cb9ec7a 100644 --- a/sasdata/temp_ascii_reader.py +++ b/sasdata/temp_ascii_reader.py @@ -1,27 +1,8 @@ -import re -from dataclasses import dataclass, field, replace -from enum import Enum -from os import path +#!/usr/bin/env python -import numpy as np +from enum import Enum -from sasdata.ascii_reader_metadata import ( - AsciiMetadataCategory, - AsciiReaderMetadata, - bidirectional_pairings, - pairings, -) from sasdata.data import SasData -from sasdata.dataset_types import DatasetType, one_dim, unit_kinds -from 
sasdata.default_units import get_default_unit -from sasdata.guess import ( - guess_column_count, - guess_columns, - guess_dataset_type, - guess_starting_position, -) -from sasdata.metadata import Metadata, MetaNode -from sasdata.quantities.quantity import NamedQuantity, Quantity from sasdata.quantities.units import NamedUnit @@ -31,192 +12,5 @@ class AsciiSeparator(Enum): Tab = 2 -# TODO: Turn them all of for now so the caller can turn one of them on. But is this the desired behaviour? -def initialise_separator_dict(initial_value: bool = False) -> dict[str, bool]: - return {"Whitespace": initial_value, "Comma": initial_value, "Tab": initial_value} - - -@dataclass -class AsciiReaderParams: - """This object contains the parameters that are used to load a series of - ASCII files. These parameters can be generated by the ASCII Reader Dialog - when using SasView.""" - - # These will be the FULL file path. Will need to convert to basenames for some functions. - filenames: list[str] - # The unit object for the column should only be None if the column is ! 
- columns: list[tuple[str, NamedUnit | None]] - metadata: AsciiReaderMetadata = field(default_factory=AsciiReaderMetadata) - starting_line: int = 0 - excluded_lines: set[int] = field(default_factory=set) - separator_dict: dict[str, bool] = field(default_factory=initialise_separator_dict) - # Take a copy in case its mutated (which it shouldn't be) - dataset_type: DatasetType = field(default_factory=lambda: replace(one_dim)) - - def __post_init__(self): - self.initialise_metadata() - - def initialise_metadata(self): - for filename in self.filenames: - basename = path.basename(filename) - if basename not in self.metadata.filename_separator: - self.metadata.filename_separator[basename] = "_" - self.metadata.filename_specific_metadata[basename] = {} - - @property - def columns_included(self) -> list[tuple[str, NamedUnit]]: - return [ - column - for column in self.columns - if column[0] != "" and isinstance(column[1], NamedUnit) - ] - - -# TODO: Should I make this work on a list of filenames as well? -def guess_params_from_filename( - filename: str, dataset_type: DatasetType -) -> AsciiReaderParams: - # Lets just assume we want all of the seaprators on. This seems to work for most files. - separator_dict = initialise_separator_dict(True) - with open(filename) as file: - lines = file.readlines() - lines_split = [split_line(separator_dict, line) for line in lines] - startpos = guess_starting_position(lines_split) - colcount = guess_column_count(lines_split, startpos) - columns = [ - (x, get_default_unit(x, unit_kinds[x])) - for x in guess_columns(colcount, dataset_type) - if x in unit_kinds - ] - params = AsciiReaderParams( - [filename], - columns, - starting_line=startpos, - separator_dict=separator_dict, - dataset_type=guess_dataset_type(filename), - ) - return params - - -def split_line(separator_dict: dict[str, bool], line: str) -> list[str]: - """Split a line in a CSV file based on which seperators the user has - selected on the widget. 
- - """ - expr = "" - for seperator, isenabled in separator_dict.items(): - if isenabled: - if expr != r"": - expr += r"|" - match seperator: - case "Comma": - seperator_text = r"," - case "Whitespace": - seperator_text = r"\s+" - case "Tab": - seperator_text = r"\t" - expr += seperator_text - - return re.split(expr, line.strip()) - - -# TODO: Implement error handling. -def load_quantities(params: AsciiReaderParams, filename: str, metadata: Metadata) -> dict[str, Quantity]: - """Load a list of quantities from the filename based on the params.""" - with open(filename) as ascii_file: - lines = ascii_file.readlines() - arrays: list[np.ndarray] = [] - for _ in params.columns_included: - arrays.append(np.zeros(len(lines) - params.starting_line)) - for i, current_line in enumerate(lines): - if i < params.starting_line or current_line in params.excluded_lines: - continue - line_split = split_line(params.separator_dict, current_line) - try: - for j, token in enumerate(line_split): - # Sometimes in the split, there might be an extra column that doesn't need to be there (e.g. an empty - # string.) This won't convert to a float so we need to ignore it. - if j >= len(params.columns_included): - continue - # TODO: Data might not be floats. Maybe don't hard code this. - arrays[j][i - params.starting_line] = float(token) - except ValueError: - # If any of the lines contain non-numerical data, then this line can't be read in as a quantity so it - # should be ignored entirely. 
- print(f"Line {i + 1} skipped.") - continue - file_quantities = { - name: NamedQuantity(name, arrays[i], unit, id_header=metadata.id_header) - for i, (name, unit) in enumerate(params.columns_included) - } - return file_quantities - - -def import_metadata(metadata: dict[str, AsciiMetadataCategory[str]]) -> MetaNode: - root_contents = [] - for top_level_key, top_level_item in metadata.items(): - children = [] - for metadatum_name, metadatum in top_level_item.values.items(): - children.append(MetaNode(name=metadatum_name, attrs={}, contents=metadatum)) - if top_level_key == "other": - root_contents.extend(children) - else: - group = MetaNode(name=top_level_key, attrs={}, contents=children) - root_contents.append(group) - return MetaNode(name="root", attrs={}, contents=root_contents) - - -def merge_uncertainties(quantities: dict[str, Quantity]) -> dict[str, Quantity]: - """Data in the ASCII files will have the uncertainties in a separate column. - This function will merge columns of data with the columns containing their - uncertainties so that both are in one Quantity object.""" - new_quantities: dict[str, Quantity] = {} - error_quantity_names = pairings.values() - for name, quantity in quantities.items(): - if name in error_quantity_names: - continue - pairing = bidirectional_pairings.get(name, "") - error_quantity = None - for other_name, other_quantity in quantities.items(): - if other_name == pairing: - error_quantity = other_quantity - if error_quantity is not None: - to_add = quantity.with_standard_error(error_quantity) - else: - to_add = quantity - new_quantities[name] = to_add - return new_quantities - - -def load_data(params: AsciiReaderParams) -> list[SasData]: - """This loads a series of SasData objects based on the params. 
The amount of - SasData objects loaded will depend on how many filenames are present in the - list contained in the params.""" - loaded_data: list[SasData] = [] - for filename in params.filenames: - raw_metadata = import_metadata( - params.metadata.all_file_metadata(path.basename(filename)) - ) - metadata = Metadata( - title=None, - run=[], - definition=None, - sample=None, - instrument=None, - process=None, - raw=raw_metadata, - ) - quantities = load_quantities(params, filename, metadata) - data = SasData( - path.basename(filename), - merge_uncertainties(quantities), - params.dataset_type, - metadata, - ) - loaded_data.append(data) - return loaded_data - - -def load_data_default_params(filename: str) -> list[SasData]: - params = guess_params_from_filename(filename, guess_dataset_type(filename)) - return load_data(params) +def load_data(filename: str, starting_line: int, columns: list[tuple[str, NamedUnit]], separators: list[AsciiSeparator]) -> list[SasData]: + raise NotImplementedError() diff --git a/sasdata/temp_hdf5_reader.py b/sasdata/temp_hdf5_reader.py index e439486ad..435f597ad 100644 --- a/sasdata/temp_hdf5_reader.py +++ b/sasdata/temp_hdf5_reader.py @@ -1,5 +1,4 @@ import logging -from collections.abc import Callable import h5py import numpy as np @@ -9,31 +8,16 @@ from sasdata.data import SasData from sasdata.data_backing import Dataset as SASDataDataset from sasdata.data_backing import Group as SASDataGroup -from sasdata.dataset_types import one_dim, three_dim, two_dim -from sasdata.metadata import ( - Aperture, - BeamSize, - Collimation, - Detector, - Instrument, - Metadata, - MetaNode, - Process, - Rot3, - Sample, - Source, - Vec3, -) +from sasdata.metadata import Aperture, Collimation, Instrument, Source from sasdata.quantities import units -from sasdata.quantities.quantity import NamedQuantity, Quantity +from sasdata.quantities.quantity import NamedQuantity from sasdata.quantities.unit_parser import parse # test_file = 
"./example_data/1d_data/33837rear_1D_1.75_16.5_NXcanSAS_v3.h5" # test_file = "./example_data/1d_data/33837rear_1D_1.75_16.5_NXcanSAS.h5" -test_file = "./example_data/2d_data/BAM_2D.h5" -# test_file = "./example_data/2d_data/14250_2D_NoDetInfo_NXcanSAS_v3.h5" +# test_file = "./example_data/2d_data/BAM_2D.h5" +test_file = "./example_data/2d_data/14250_2D_NoDetInfo_NXcanSAS_v3.h5" # test_file = "./example_data/2d_data/33837rear_2D_1.75_16.5_NXcanSAS_v3.h5" -test_file = "./test/sasdataloader/data/nxcansas_1Dand2D_multisasdata.h5" logger = logging.getLogger(__name__) @@ -72,7 +56,7 @@ def recurse_hdf5(hdf5_entry): GET_UNITS_FROM_ELSEWHERE = units.meters -def connected_data(node: SASDataGroup, name_prefix="", metadata=None) -> dict[str, Quantity]: +def connected_data(node: SASDataGroup, name_prefix="") -> list[NamedQuantity]: """In the context of NeXus files, load a group of data entries that are organised together match up the units and errors with their values""" # Gather together data with its error terms @@ -84,12 +68,14 @@ def connected_data(node: SASDataGroup, name_prefix="", metadata=None) -> dict[st for name in node.children: child = node.children[name] - if "units" in child.attributes and child.attributes["units"]: + if "units" in child.attributes: units = parse(child.attributes["units"]) else: units = GET_UNITS_FROM_ELSEWHERE - quantity = NamedQuantity(name=child.name, value=child.data, units=units, id_header=metadata.id_header) + quantity = NamedQuantity( + name=name_prefix + child.name, value=child.data, units=units + ) # Turns out people can't be trusted to use the same keys here if "uncertainty" in child.attributes or "uncertainties" in child.attributes: @@ -102,146 +88,61 @@ def connected_data(node: SASDataGroup, name_prefix="", metadata=None) -> dict[st entries[name] = quantity - output : dict[str, Quantity] = {} + output = [] for name, entry in entries.items(): if name not in uncertainties: if name in uncertainty_map: uncertainty = 
entries[uncertainty_map[name]] new_entry = entry.with_standard_error(uncertainty) - output[name] = new_entry + output.append(new_entry) else: - output[name] = entry + output.append(entry) return output -### Begin metadata parsing code - -def get_canSAS_class(node : HDF5Group) -> str | None: - # Check if attribute exists - if "canSAS_class" in node.attrs: - cls = node.attrs["canSAS_class"] - return cls - elif "NX_class" in node.attrs: - cls = node.attrs["NX_class"] - cls = NX2SAS_class(cls) - # note that sastransmission groups have a - # NX_class of NXdata but a canSAS_class of SAStransmission_spectrum - # which is ambiguous because then how can one tell if it is a SASdata - # or a SAStransmission_spectrum object from the NX_class? - if node.name.lower().startswith("sastransmission"): - cls = 'SAStransmission_spectrum' - return cls - - return None - -def NX2SAS_class(cls : str) -> str | None: - # converts NX class names to canSAS class names - mapping = { - "NXentry": "SASentry", - "NXdata": "SASdata", - "NXdetector": "SASdetector", - "NXinstrument": "SASinstrument", - "NXnote": "SASnote", - "NXprocess": "SASprocess", - "NXcollection": "SASprocessnote", - "NXsample": "SASsample", - "NXsource": "SASsource", - "NXaperture": "SASaperture", - "NXcollimator": "SAScollimation", - - "SASentry": "SASentry", - "SASdata": "SASdata", - "SASdetector": "SASdetector", - "SASinstrument": "SASinstrument", - "SASnote": "SASnote", - "SASprocess": "SASprocess", - "SASprocessnote": "SASprocessnote", - "SAStransmission_spectrum": "SAStransmission_spectrum", - "SASsample": "SASsample", - "SASsource": "SASsource", - "SASaperture": "SASaperture", - "SAScollimation": "SAScollimation", - } - if isinstance(cls, bytes): - cls = cls.decode() - return mapping.get(cls, None) - -def find_canSAS_key(node: HDF5Group, canSAS_class: str): - matches = [] - - for key, item in node.items(): - if item.attrs.get("canSAS_class") == canSAS_class: - matches.append(key) - - return matches - -def 
parse_quantity(node : HDF5Group) -> Quantity[float]: - """Pull a single quantity with units out of an HDF5 node""" - magnitude = parse_float(node) - unit = node.attrs["units"] - return Quantity(magnitude, parse(unit)) - -def parse_string(node : HDF5Group) -> str: - """Access string data from a node""" - if node.shape == (): # scalar dataset - return node.asstr()[()] - else: # vector dataset - return node.asstr()[0] - -def parse_float(node: HDF5Group) -> float: - """Return the first element (or scalar) of a numeric dataset as float.""" - if node.shape == (): - return float(node[()].astype(str)) - else: - return float(node[0].astype(str)) - -def opt_parse[T](node: HDF5Group, key: str, subparser: Callable[[HDF5Group], T], ignore_case=False) -> T | None: - """Parse a subnode if it is present""" - if ignore_case: # ignore the case of the key - key = next((k for k in node.keys() if k.lower() == key.lower()), None) - if key in node: - return subparser(node[key]) - return None - -def attr_parse(node: HDF5Group, key: str) -> str | None: - """Parse an attribute if it is present""" - if key in node.attrs: - return node.attrs[key] - return None - - -def parse_aperture(node : HDF5Group) -> Aperture: - distance = opt_parse(node, "distance", parse_quantity) - name = attr_parse(node, "name") - size = opt_parse(node, "size", parse_vec3) - size_name = None - type_ = attr_parse(node, "type") - if size: - size_name = attr_parse(node["size"], "name") - else: - size_name = None - return Aperture(distance=distance, size=size, size_name=size_name, name=name, type_=type_) - -def parse_beam_size(node : HDF5Group) -> BeamSize: - name = attr_parse(node, "name") - size = parse_vec3(node) - return BeamSize(name=name, size=size) - -def parse_source(node : HDF5Group) -> Source: - radiation = opt_parse(node, "radiation", parse_string) - beam_shape = opt_parse(node, "beam_shape", parse_string) - beam_size = opt_parse(node, "beam_size", parse_beam_size) - wavelength = opt_parse(node, "wavelength", 
parse_quantity) - if wavelength is None: - wavelength = opt_parse(node, "incident_wavelength", parse_quantity) - wavelength_min = opt_parse(node, "wavelength_min", parse_quantity) - wavelength_max = opt_parse(node, "wavelength_max", parse_quantity) - wavelength_spread = opt_parse(node, "wavelength_spread", parse_quantity) - if wavelength_spread is None: - wavelength_spread = opt_parse(node, "incident_wavelength_spread", parse_quantity) + +def parse_apertures(node) -> list[Aperture]: + result = [] + aps = [a for a in node if "aperture" in a] + for ap in aps: + distance = None + size = None + if "distance" in node[ap]: + distance = node[ap]["distance"] + if "size" in node[ap]: + x = y = z = None + if "x" in node[ap]: + x = node[ap]["size"]["x"] + if "y" in node[ap]: + y = node[ap]["size"]["y"] + if "z" in node[ap]: + z = node[ap]["size"]["z"] + if x is not None or y is not None or z is not None: + size = (x, y, z) + result.append(Aperture(distance=distance, size=size, size_name=size_name, name=name, apType=apType)) + return result + + +def parse_source(node) -> Source: + beam_shape = None + beam_size = None + wavelength = None + wavelength_min = None + wavelength_max = None + wavelength_spread = None + if "beam_shape" in node: + beam_shape = node["beam_shape"] + if "wavelength" in node: + wavelength = node["wavelength"] + if "wavelength_min" in node: + wavelength = node["wavelength_min"] + if "wavelength_max" in node: + wavelength = node["wavelength_max"] + if "wavelength_spread" in node: + wavelength = node["wavelength_spread"] return Source( - radiation=radiation, + radiation=node["radiation"].asstr()[0], beam_shape=beam_shape, beam_size=beam_size, wavelength=wavelength, @@ -250,208 +151,75 @@ def parse_source(node : HDF5Group) -> Source: wavelength_spread=wavelength_spread, ) -def parse_vec3(node : HDF5Group) -> Vec3: - """Parse a measured 3-vector""" - x = opt_parse(node, "x", parse_quantity) - y = opt_parse(node, "y", parse_quantity) - z = opt_parse(node, "z", 
parse_quantity) - return Vec3(x=x, y=y, z=z) - -def parse_rot3(node : HDF5Group) -> Rot3: - """Parse a measured rotation""" - roll = opt_parse(node, "roll", parse_quantity) - pitch = opt_parse(node, "pitch", parse_quantity) - yaw = opt_parse(node, "yaw", parse_quantity) - return Rot3(roll=roll, pitch=pitch, yaw=yaw) - -def parse_detector(node : HDF5Group) -> Detector: - name = opt_parse(node, "name", parse_string) - distance = opt_parse(node, "SDD", parse_quantity) - offset = opt_parse(node, "offset", parse_vec3) - orientation = opt_parse(node, "orientation", parse_rot3) - beam_center = opt_parse(node, "beam_center", parse_vec3) - pixel_size = opt_parse(node, "pixel_size", parse_vec3) - slit_length = opt_parse(node, "slit_length", parse_quantity) - - return Detector(name=name, - distance=distance, - offset=offset, - orientation=orientation, - beam_center=beam_center, - pixel_size=pixel_size, - slit_length=slit_length) - - - -def parse_collimation(node : HDF5Group) -> Collimation: - length = opt_parse(node, "length", parse_quantity) - - keys = find_canSAS_key(node, "SASaperture") - keys = list(keys) if keys is not None else [] # list([1,2,3]) returns [1,2,3] and list("string") returns ["string"] - apertures = [parse_aperture(node[p]) for p in keys] # Empty list of keys will give an empty collimations list - - return Collimation(length=length, apertures=apertures) - - -def parse_instrument(node : HDF5Group) -> Instrument: - keys = find_canSAS_key(node, "SAScollimation") - keys = list(keys) if keys is not None else [] # list([1,2,3]) returns [1,2,3] and list("string") returns ["string"] - collimations = [parse_collimation(node[p]) for p in keys] # Empty list of keys will give an empty collimations list - keys = find_canSAS_key(node, "SASdetector") - keys = list(keys) if keys is not None else [] # list([1,2,3]) returns [1,2,3] and list("string") returns ["string"] - detector = [parse_detector(node[p]) for p in keys] # Empty list of keys will give an empty collimations 
list +def parse_collimation(node) -> Collimation: + if "length" in node: + length = node["length"] + else: + length = None + return Collimation(length=length, apertures=parse_apertures(node)) - keys = find_canSAS_key(node, "SASsource") - source = parse_source(node[keys[0]]) if keys is not None else None +def parse_instrument(raw, node) -> Instrument: return Instrument( - collimations=collimations, - detector=detector, - source=source, + collimations= [parse_collimation(node[x]) for x in node if "collimation" in x], + source=parse_source(node["sassource"]), ) -def parse_sample(node : HDF5Group) -> Sample: - name = attr_parse(node, "name") - sample_id = opt_parse(node, "ID", parse_string) - thickness = opt_parse(node, "thickness", parse_quantity) - transmission = opt_parse(node, "transmission", parse_float) - temperature = opt_parse(node, "temperature", parse_quantity) - position = opt_parse(node, "position", parse_vec3) - orientation = opt_parse(node, "orientation", parse_rot3) - details : list[str] = sum([list(node[d].asstr()[()]) for d in node if "details" in d], []) - return Sample(name=name, - sample_id=sample_id, - thickness=thickness, - transmission=transmission, - temperature=temperature, - position=position, - orientation=orientation, - details=details) - -def parse_term(node : HDF5Group) -> tuple[str, str | Quantity[float]] | None: - name = attr_parse(node, "name") - unit = attr_parse(node, "unit") - value = attr_parse(node, "value") - if name is None or value is None: - return None - if unit and unit.strip(): - return (name, Quantity(float(value), units.symbol_lookup[unit])) - return (name, value) - - -def parse_process(node : HDF5Group) -> Process: - name = opt_parse(node, "name", parse_string) - date = opt_parse(node, "date", parse_string) - description = opt_parse(node, "description", parse_string) - term_values = [parse_term(node[n]) for n in node if "term" in n] - terms = {tup[0]: tup[1] for tup in term_values if tup is not None} - notes = 
[parse_string(node[n]) for n in node if "note" in n] - return Process(name=name, date=date, description=description, terms=terms, notes=notes) - -def load_raw(node: HDF5Group | HDF5Dataset) -> MetaNode: - name = node.name.split("/")[-1] - match node: - case HDF5Group(): - attrib = {a: node.attrs[a] for a in node.attrs} - contents = [load_raw(node[v]) for v in node] - return MetaNode(name=name, attrs=attrib, contents=contents) - case HDF5Dataset(dtype=dt): - attrib = {a: node.attrs[a] for a in node.attrs} - if (str(dt).startswith("|S")): - if "units" in attrib: - contents = parse_string(node) - else: - contents = parse_string(node) - else: - if "units" in attrib and attrib["units"]: - data = node[()] if node.shape == () else node[:] - contents = Quantity(data, parse(attrib["units"]), id_header=node.name) - else: - contents = node[()] if node.shape == () else node[:] - return MetaNode(name=name, attrs=attrib, contents=contents) - case _: - raise RuntimeError(f"Cannot load raw data of type {type(node)}") - -def parse_metadata(node : HDF5Group) -> Metadata: - # parse the metadata groups - keys = find_canSAS_key(node, "SASinstrument") - keys = list(keys) if keys else [] # list([1,2,3]) returns [1,2,3] and list("string") returns ["string"] - instrument = parse_instrument(node[keys[0]]) if keys else None - - keys = find_canSAS_key(node, "SASsample") - keys = list(keys) if keys else [] # list([1,2,3]) returns [1,2,3] and list("string") returns ["string"] - sample = parse_sample(node[keys[0]]) if keys else None - - keys = find_canSAS_key(node, "SASprocess") - keys = list(keys) if keys else [] # list([1,2,3]) returns [1,2,3] and list("string") returns ["string"] - process = [parse_process(node[p]) for p in keys] # Empty list of keys will give an empty collimations list - - # parse the datasets - title = opt_parse(node, "title", parse_string) - run = [parse_string(node[r]) for r in node if "run" in r] - definition = opt_parse(node, "definition", parse_string) - - # load the 
entire node recursively into a raw object - raw = load_raw(node) - - return Metadata(process=process, - instrument=instrument, - sample=sample, - title=title, - run=run, - definition=definition, - raw=raw) +def load_data(filename) -> list[SasData]: + with h5py.File(filename, 'r') as f: -def load_data(filename: str) -> dict[str, SasData]: - with h5py.File(filename, "r") as f: - loaded_data: dict[str, SasData] = {} + loaded_data: list[SasData] = [] for root_key in f.keys(): + entry = f[root_key] - # if this is actually a SASentry - if not get_canSAS_class(entry) == 'SASentry': - continue - data_contents : dict[str, Quantity] = {} + data_contents = [] + raw_metadata = {} - entry_keys = entry.keys() + entry_keys = [key for key in entry.keys()] - if not [k for k in entry_keys if get_canSAS_class(entry[k])=='SASdata']: + if "sasdata" not in entry_keys and "data" not in entry_keys: logger.warning("No sasdata or data key") - logger.warning(f"Known keys: {[k for k in entry_keys]}") - - metadata = parse_metadata(f[root_key]) for key in entry_keys: component = entry[key] - if get_canSAS_class(entry[key])=='SASdata': + lower_key = key.lower() + if lower_key == "sasdata" or lower_key == "data": datum = recurse_hdf5(component) - data_contents = connected_data(datum, str(filename), metadata) + # TODO: Use named identifier + data_contents = connected_data(datum, "FILE_ID_HERE") - if "Qz" in data_contents: - dataset_type = three_dim - elif "Qy" in data_contents: - dataset_type = two_dim - else: - dataset_type = one_dim + else: + raw_metadata[key] = recurse_hdf5(component) + + instrument = opt_parse(f["sasentry01"], "sasinstrument", parse_instrument) + sample = opt_parse(f["sasentry01"], "sassample", parse_sample) + process = [parse_process(f["sasentry01"][p]) for p in f["sasentry01"] if "sasprocess" in p] + title = opt_parse(f["sasentry01"], "title", parse_string) + run = [parse_string(f["sasentry01"][r]) for r in f["sasentry01"] if "run" in r] + definition = 
opt_parse(f["sasentry01"], "definition", parse_string) - entry_key = entry.attrs["sasview_key"] if "sasview_key" in entry.attrs else root_key + metadata = Metadata(process=process, instrument=instrument, sample=sample, title=title, run=run, definition=definition) - loaded_data[entry_key] = SasData( + loaded_data.append( + SasData( name=root_key, - dataset_type=dataset_type, data_contents=data_contents, - metadata=metadata, + raw_metadata=SASDataGroup("root", raw_metadata), + instrument=instrument, verbose=False, ) + ) return loaded_data + if __name__ == "__main__": data = load_data(test_file) - for dataset in data.values(): - print(dataset.summary()) + for dataset in data: + print(dataset.summary(include_raw=False)) diff --git a/sasdata/temp_sesans_reader.py b/sasdata/temp_sesans_reader.py index 23d561593..b6cfb6861 100644 --- a/sasdata/temp_sesans_reader.py +++ b/sasdata/temp_sesans_reader.py @@ -3,67 +3,37 @@ """ import re -from collections import defaultdict from itertools import groupby -import numpy as np - from sasdata.data import SasData from sasdata.data_util.loader_exceptions import FileContentsException -from sasdata.dataset_types import sesans +from sasdata.dataset_types import one_dim from sasdata.metadata import ( Metadata, MetaNode, Process, Sample, ) -from sasdata.quantities import unit_parser, units +from sasdata.quantities import unit_parser from sasdata.quantities.quantity import Quantity def parse_version(lines: list[str]) -> tuple[str, list[str]]: + import re + header = lines[0] m = re.search(r"FileFormatVersion\s+(\S+)", header) if m is None: - raise FileContentsException( - "Sesans file does not contain File Format Version header" - ) + raise FileContentsException("Alleged Sesans file does not contain File Format Version header") return (m.group(0), lines[1:]) - -def parse_title(kvs: dict[str, str]) -> str: - """Get the title from the key value store""" - if "Title" in kvs: - return kvs["Title"] - elif "DataFileTitle" in kvs: - return 
kvs["DataFileTitle"] - for k, v in kvs.items(): - if "Title" in k: - return v - return "" - - -def parse_kvs_quantity(key: str, kvs: dict[str, str]) -> Quantity | None: - if key not in kvs or key + "_unit" not in kvs: - return None - return Quantity(value=float(kvs[key]), units=unit_parser.parse(kvs[key + "_unit"])) - - -def parse_sample(kvs: dict[str, str]) -> Sample: - """Get the sample info from the key value store""" - - thickness = parse_kvs_quantity("Thickness", kvs) - if thickness is None: - raise FileContentsException( - "SES format must include sample thickness to normalise calculations" - ) - - return Sample( - name=kvs.get("Sample"), +def parse_metadata(lines: list[str]) -> tuple[Metadata, list[str]]: + sample = Sample( + name=None, sample_id=None, - thickness=thickness, + thickness=None, transmission=None, temperature=None, position=None, @@ -75,7 +45,7 @@ def parse_sample(kvs: dict[str, str]) -> Sample: def parse_process(kvs: dict[str, str]) -> Process: ymax = parse_kvs_quantity("Theta_ymax", kvs) zmax = parse_kvs_quantity("Theta_zmax", kvs) - orientation = kvs.get("Orientation") + orientation = parse_kvs_text("Orientation", kvs) if ymax is None: raise FileContentsException("SES file must specify Theta_ymax") @@ -84,7 +54,7 @@ def parse_process(kvs: dict[str, str]) -> Process: if orientation is None: raise FileContentsException("SES file must include encoding orientation") - terms: dict[str, str | Quantity] = { + terms: dict[str, str | Quantity[float]] = { "ymax": ymax, "zmax": zmax, "orientation": orientation, @@ -136,7 +106,7 @@ def parse_metadata(lines: list[str]) -> tuple[Metadata, dict[str, str], list[str # Parse key value store kvs: dict[str, str] = {} for line in parts[0]: - m = re.search(r"(\S+)\s+(.+)\n", line) + m = re.search("(\\S+)\\s+(.+)\n", line) if not m: continue kvs[m.group(1)] = m.group(2) @@ -149,55 +119,23 @@ def parse_metadata(lines: list[str]) -> tuple[Metadata, dict[str, str], list[str title=parse_title(kvs), run=[], 
definition=None, - raw=parse_metanode(kvs), ), - kvs, - parts[2], + lines, ) -def parse_data(lines: list[str], kvs: dict[str, str]) -> dict[str, Quantity]: - +def parse_data(lines: list[str]) -> dict[str, Quantity]: data_contents: dict[str, Quantity] = {} - headers = lines[0].split() - points = defaultdict(list) - for line in lines[1:]: - values = line.split() - for idx, v in enumerate(values): - points[headers[idx]].append(float(v)) - - for h in points.keys(): - if h.endswith("_error") and h[:-6] in headers: - # This was an error line - continue - unit = units.none - if h + "_unit" in kvs: - unit = unit_parser.parse(kvs[h + "_unit"]) - - error = None - if h + "_error" in headers: - error = np.asarray(points[h + "_error"]) - - data_contents[h] = Quantity( - value=np.asarray(points[h]), - units=unit, - standard_error=error, - ) - - for required in ["SpinEchoLength", "Depolarisation", "Wavelength"]: - if required not in data_contents: - raise FileContentsException(f"SES file missing {required}") - return data_contents def parse_sesans(lines: list[str]) -> SasData: version, lines = parse_version(lines) - metadata, kvs, lines = parse_metadata(lines) - data_contents = parse_data(lines, kvs) + metadata, lines = parse_metadata(lines) + data_contents = parse_data(lines) return SasData( name="Sesans", - dataset_type=sesans, + dataset_type=one_dim, data_contents=data_contents, metadata=metadata, verbose=False, diff --git a/sasdata/temp_xml_reader.py b/sasdata/temp_xml_reader.py index 174fe7389..b6ea5f2cb 100644 --- a/sasdata/temp_xml_reader.py +++ b/sasdata/temp_xml_reader.py @@ -215,7 +215,7 @@ def parse_data(node: etree._Element, version: str, metadata: Metadata) -> dict[s struct = {} for value in idata.getchildren(): name = etree.QName(value).localname - if value.text is None or parse_string(value, version).strip() == "": + if value.text is None or value.text.strip() == "": continue if name not in us: unit = ( diff --git a/sasdata/transforms/rebinning.py 
b/sasdata/transforms/rebinning.py index dcb9a2ca6..e5f48b18c 100644 --- a/sasdata/transforms/rebinning.py +++ b/sasdata/transforms/rebinning.py @@ -123,69 +123,12 @@ def calculate_interpolation_matrix_1d(input_axis: Quantity[ArrayLike], case InterpolationOptions.CUBIC: # Cubic interpolation, much harder to implement because we can't just cheat and use numpy - - input_indices = np.arange(n_in, dtype=int) - output_indices = np.arange(n_out, dtype=int) - - # Find the location of the largest value in sorted_in that - # is less than every value of sorted_out - lower_bound = ( - np.sum(np.where(np.less.outer(sorted_in, sorted_out), 1, 0), axis=0) - 1 - ) - - # We're using the Finite Difference Cubic Hermite spline - # https://en.wikipedia.org/wiki/Cubic_Hermite_spline#Interpolation_on_an_arbitrary_interval - # https://en.wikipedia.org/wiki/Cubic_Hermite_spline#Finite_difference - - x1 = sorted_in[lower_bound] # xₖ on the wiki - x2 = sorted_in[lower_bound + 1] # xₖ₊₁ on the wiki - - x0 = sorted_in[lower_bound[lower_bound - 1 >= 0] - 1] # xpₖ₋₁ on the wiki - x0 = np.hstack([np.zeros(x1.size - x0.size), x0]) - - x3 = sorted_in[ - lower_bound[lower_bound + 2 < sorted_in.size] + 2 - ] # xₖ₊₂ on the wiki - x3 = np.hstack([x3, np.zeros(x2.size - x3.size)]) - - t = (sorted_out - x1) / (x2 - x1) # t on the wiki - - y0 = ( - -t * (x1 - x2) * (t**2 - 2 * t + 1) / (2 * x0 - 2 * x1) - ) # The coefficient to pₖ₋₁ on the wiki - y1 = ( - -t * (t**2 - 2 * t + 1) * (x0 - 2 * x1 + x2) - + (x0 - x1) * (3 * t**3 - 5 * t**2 + 2) - ) / (2 * (x0 - x1)) # The coefficient to pₖ - y2 = ( - t - * ( - -t * (t - 1) * (x1 - 2 * x2 + x3) - + (x2 - x3) * (-3 * t**2 + 4 * t + 1) - ) - / (2 * (x2 - x3)) - ) # The coefficient to pₗ₊₁ - y3 = t**2 * (t - 1) * (x1 - x2) / (2 * (x2 - x3)) # The coefficient to pₖ₊₂ - - conversion_matrix = np.zeros((n_in, n_out)) - - (row, column) = np.indices(conversion_matrix.shape) - - mask1 = row == lower_bound[column] - - conversion_matrix[np.roll(mask1, -1, axis=0)] = y0 - 
conversion_matrix[mask1] = y1 - conversion_matrix[np.roll(mask1, 1, axis=0)] = y2 - - # Special boundary condition for y3 - pick = np.roll(mask1, 2, axis=0) - pick[0:1, :] = 0 - if pick.any(): - conversion_matrix[pick] = y3 + raise NotImplementedError("Cubic interpolation not implemented yet") case _: raise InterpolationError(f"Unsupported interpolation order: {order}") + if mask is None: return conversion_matrix, None diff --git a/sasdata/trend.py b/sasdata/trend.py index 9b1a371a4..16d1c67f3 100644 --- a/sasdata/trend.py +++ b/sasdata/trend.py @@ -1,89 +1,25 @@ -from dataclasses import dataclass +#!/usr/bin/env python -import numpy as np +from dataclasses import dataclass from sasdata.data import SasData -from sasdata.data_backing import Dataset, Group -from sasdata.quantities.quantity import Quantity -from sasdata.transforms.rebinning import calculate_interpolation_matrix_1d # Axis strs refer to the name of their associated NamedQuantity. -# TODO: This probably shouldn't be here but will keep it here for now. -# TODO: Not sure how to type hint the return. -def get_metadatum_from_path(data: SasData, metadata_path: list[str]): - current_group = data._raw_metadata - for path_item in metadata_path: - current_item = current_group.children.get(path_item, None) - if current_item is None or (isinstance(current_item, Dataset) and path_item != metadata_path[-1]): - raise ValueError('Path does not lead to valid a metadatum.') - elif isinstance(current_item, Group): - current_group = current_item - else: - return current_item.data - raise ValueError('End of path without finding a dataset.') - - @dataclass class Trend: data: list[SasData] - # This is going to be a path to a specific metadatum. - # - # TODO: But what if the trend axis will be a particular NamedQuantity? Will probably need to think on this. - trend_axis: list[str] + trend_axis: str # Designed to take in a particular value of the trend axis, and return the SasData object that matches it. 
# TODO: Not exaclty sure what item's type will be. It could depend on where it is pointing to. def __getitem__(self, item) -> SasData: - for datum in self.data: - metadatum = get_metadatum_from_path(datum, self.trend_axis) - if metadatum == item: - return datum - raise KeyError() - @property - def trend_axes(self) -> list[float]: - return [get_metadatum_from_path(datum, self.trend_axis) for datum in self.data] + raise NotImplementedError() - # TODO: Assumes there are at least 2 items in data. Is this reasonable to assume? Should there be error handling for - # situations where this may not be the case? def all_axis_match(self, axis: str) -> bool: - reference_data = self.data[0] - data_axis = reference_data[axis] - for datum in self.data[1::]: - axis_datum = datum[axis] - # FIXME: Linter is complaining about typing. - if not np.all(np.isclose(axis_datum.value, data_axis.value)): - return False - return True - - # TODO: For now, return a new trend, but decide later. Shouldn't be too hard to change. - def interpolate(self, axis: str) -> "Trend": - new_data: list[SasData] = [] - reference_data = self.data[0] - # TODO: I don't like the repetition here. Can probably abstract a function for this ot make it clearer. - data_axis = reference_data[axis] - for i, datum in enumerate(self.data): - if i == 0: - # This is already the reference axis; no need to interpolate it. 
- continue - # TODO: Again, repetition - axis_datum = datum[axis] - # TODO: There are other options which may need to be filled (or become new params to this method) - mat, _ = calculate_interpolation_matrix_1d(axis_datum, data_axis) - new_quantities: dict[str, Quantity] = {} - for name, quantity in datum._data_contents.items(): - if name == axis: - new_quantities[name] = data_axis - continue - new_quantities[name] = quantity @ mat + raise NotImplementedError() - new_datum = SasData( - name=datum.name, - data_contents=new_quantities, - dataset_type=datum.dataset_type, - metadata=datum.metadata, - ) - new_data.append(new_datum) - new_trend = Trend(new_data, - self.trend_axis) - return new_trend + # TODO: Not sure if this should return a new trend, or just mutate the existing trend + # TODO: May be some details on the method as well. + def interpolate(self, axis: str) -> Self: + raise NotImplementedError() diff --git a/test/quantities/utest_math_operations.py b/test/quantities/utest_math_operations.py index 85e9ba1d6..67e10813c 100644 --- a/test/quantities/utest_math_operations.py +++ b/test/quantities/utest_math_operations.py @@ -30,7 +30,7 @@ def test_transpose_raw(order: list[int]): @pytest.mark.parametrize("order", order_list) -def test_transpose_raw_with_quantity(order: list[int]): +def test_transpose_raw(order: list[int]): """ Check that the transpose operation changes the order of indices correctly - uses sizes as way of tracking""" input_shape = tuple([i + 1 for i in range(len(order))]) expected_shape = tuple([i + 1 for i in order]) diff --git a/test/quantities/utest_operations.py b/test/quantities/utest_operations.py index 5b64eab0e..f4599f1b4 100644 --- a/test/quantities/utest_operations.py +++ b/test/quantities/utest_operations.py @@ -6,27 +6,15 @@ from sasdata.quantities.quantity import ( Add, AdditiveIdentity, - ArcCos, - ArcSin, - ArcTan, Constant, - Cos, Div, - Dot, - Exp, Inv, - Ln, - Log, - MatMul, Mul, MultiplicativeIdentity, Neg, Operation, Pow, - 
Sin, Sub, - Tan, - Transpose, Variable, ) diff --git a/test/quantities/utest_quantities.py b/test/quantities/utest_quantities.py index 7edfd14df..3ac18282d 100644 --- a/test/quantities/utest_quantities.py +++ b/test/quantities/utest_quantities.py @@ -83,7 +83,7 @@ def test_good_non_integer_unit_powers(unit_in, power, unit_out): def test_bad_non_integer_unit_powers(unit, power): """ Check that we get an error if we try and do something silly with powers""" with pytest.raises(units.DimensionError): - unit**power + x = unit**power @pytest.mark.parametrize("unit_1", si.all_si) @@ -131,11 +131,3 @@ def test_equality(): assert Quantity(1.0, units.angstroms) == Quantity(0.1, units.nanometers) assert Quantity(1.0, units.angstroms) != Quantity(0.1, units.angstroms) assert Quantity(1.0, units.angstroms) == Quantity(1.0e-10, units.meters) - -@pytest.mark.quantity -def test_explicit_format(): - value = Quantity(1.0, units.electronvolts) - assert value.explicitly_formatted("J") == "1.602176634e-19 J" - assert value.explicitly_formatted("N m") == "1.602176634e-19 N m" - assert value.explicitly_formatted("m N") == "1.602176634e-19 m N" - assert value.explicitly_formatted("m kilogram m / hour / year") == "1.8201532008477443e-08 m kilogram m / hour / year" diff --git a/test/quantities/utest_quantity_error.py b/test/quantities/utest_quantity_error.py index ce8be4cee..e8e55337d 100644 --- a/test/quantities/utest_quantity_error.py +++ b/test/quantities/utest_quantity_error.py @@ -20,7 +20,6 @@ def test_addition_propagation(x_err, y_err, x_units, y_units): _, err = (x + y).in_si_with_standard_error() assert err == pytest.approx(expected_err, abs=1e-8) - @pytest.mark.parametrize("x_val, y_val, x_units, y_units", [(1, 1, units.meters, units.meters), (1, 1, units.centimeters, units.centimeters), diff --git a/test/quantities/utest_units.py b/test/quantities/utest_units.py index 3bc775313..258143e2a 100644 --- a/test/quantities/utest_units.py +++ b/test/quantities/utest_units.py @@ -1,5 
+1,3 @@ -import math - import sasdata.quantities.units as units from sasdata.quantities.units import Unit @@ -11,7 +9,7 @@ def __init__(self, test_name: str, *units): def run_test(self): for i, unit_1 in enumerate(self.units): - for unit_2 in self.units[i + 1 :]: + for unit_2 in self.units[i+1:]: assert unit_1.equivalent(unit_2), "Units should be equivalent" assert unit_1 == unit_2, "Units should be equal" @@ -23,22 +21,11 @@ def __init__(self, test_name: str, *units): def run_test(self): for i, unit_1 in enumerate(self.units): - for unit_2 in self.units[i + 1 :]: + for unit_2 in self.units[i+1:]: assert unit_1.equivalent(unit_2), "Units should be equivalent" assert unit_1 != unit_2, "Units should not be equal" -class DissimilarUnits: - def __init__(self, test_name: str, *units): - self.test_name = "Dissimilar: " + test_name - self.units: list[Unit] = list(units) - - def run_test(self): - for i, unit_1 in enumerate(self.units): - for unit_2 in self.units[i + 1 :]: - assert not unit_1.equivalent(unit_2), "Units should not be equivalent" - - tests = [ EqualUnits("Pressure", @@ -49,19 +36,7 @@ def run_test(self): EqualUnits("Resistance", units.ohms, units.volts / units.amperes, - 1e-3/units.millisiemens), - - EquivalentButUnequalUnits("Angular frequency", - units.rotations / units.minutes, - units.degrees * units.hertz), - - EqualUnits("Angular frequency", - (units.rotations/units.minutes ), - (units.radians*units.hertz) * 2 * math.pi/60.0), - - DissimilarUnits("Frequency and Angular frequency", - (units.rotations/units.minutes), - (units.hertz)), + 1e-3/units.millisiemens) ] diff --git a/test/sasdataloader/reference/14250.txt b/test/sasdataloader/reference/14250.txt new file mode 100644 index 000000000..6f11aba78 --- /dev/null +++ b/test/sasdataloader/reference/14250.txt @@ -0,0 +1,51 @@ +sasentry01 + [FILE_ID_HERE/sasentry01/sasdata/I] [[0.0 ... 0.0]] ± [[0.0 ... 0.0]] cm⁻¹ + [FILE_ID_HERE/sasentry01/sasdata/Qx] [[-0.11925 ... 
0.11925000000000001]] Å⁻¹ + [FILE_ID_HERE/sasentry01/sasdata/Qy] [[-0.11925 ... 0.11925000000000001]] Å⁻¹ +Metadata: + + High C high V 63, 900oC, 10h 1.65T_SANS, Run: 14250 + =================================================== + +Definition: High C high V 63, 900oC, 10h 1.65T_SANS +Process: + Name: Mantid_generated_NXcanSAS + Date: 2016-12-06T17:15:48 + Description: None + Term: None + Notes: None +Sample: + ID: None + Transmission: None + Thickness: None + Temperature: None + Position: None + Orientation: None +Aperture: + Name: None + Aperture size: None + Aperture distance: None +Collimation: + Length: None +Detector: + Name: None + Distance: None + Offset: None + Orientation: None + Beam center: None + Pixel size: None + Slit length: None +Source: + Radiation: Spallation Neutron Source + Shape: None + Wavelength: None + Min. Wavelength: None + Max. Wavelength: None + Wavelength Spread: None + Beam Size: None +Transmission Spectrum: + Name: None + Timestamp: None + Wavelengths: None + Transmission: None + diff --git a/test/sasdataloader/reference/33837.txt b/test/sasdataloader/reference/33837.txt new file mode 100644 index 000000000..26f36b96b --- /dev/null +++ b/test/sasdataloader/reference/33837.txt @@ -0,0 +1,50 @@ +sasentry01 + [FILE_ID_HERE/sasentry01/sasdata/I] [5.416094671273121 ... 0.33697913143947616] ± [0.6152247543248875 ... 0.19365125082205084] m + [FILE_ID_HERE/sasentry01/sasdata/Q] [0.0041600000000000005 ... 
0.6189241619415587] m +Metadata: + + MH4_5deg_16T_SLOW, Run: 33837 + ============================= + +Definition: MH4_5deg_16T_SLOW +Process: + Name: Mantid_generated_NXcanSAS + Date: 11-May-2016 12:20:43 + Description: None + Term: None + Notes: None +Sample: + ID: None + Transmission: None + Thickness: None + Temperature: None + Position: None + Orientation: None +Aperture: + Name: None + Aperture size: None + Aperture distance: None +Collimation: + Length: None +Detector: + Name: None + Distance: None + Offset: None + Orientation: None + Beam center: None + Pixel size: None + Slit length: None +Source: + Radiation: Spallation Neutron Source + Shape: None + Wavelength: None + Min. Wavelength: None + Max. Wavelength: None + Wavelength Spread: None + Beam Size: None +Transmission Spectrum: + Name: None + Timestamp: None + Wavelengths: None + Transmission: None + diff --git a/test/sasdataloader/reference/33837_v3.txt b/test/sasdataloader/reference/33837_v3.txt new file mode 100644 index 000000000..f4d205b2a --- /dev/null +++ b/test/sasdataloader/reference/33837_v3.txt @@ -0,0 +1,50 @@ +sasentry01 + [FILE_ID_HERE/sasentry01/sasdata/I] [5.416094671273121 ... 0.33697913143947616] ± [0.6152247543248875 ... 0.19365125082205084] none + [FILE_ID_HERE/sasentry01/sasdata/Q] [0.0041600000000000005 ... 
0.6189241619415587] Å⁻¹ +Metadata: + + MH4_5deg_16T_SLOW, Run: 33837 + ============================= + +Definition: MH4_5deg_16T_SLOW +Process: + Name: Mantid_generated_NXcanSAS + Date: 2016-07-04T10:34:34 + Description: None + Term: None + Notes: None +Sample: + ID: None + Transmission: None + Thickness: None + Temperature: None + Position: None + Orientation: None +Aperture: + Name: None + Aperture size: None + Aperture distance: None +Collimation: + Length: None +Detector: + Name: None + Distance: None + Offset: None + Orientation: None + Beam center: None + Pixel size: None + Slit length: None +Source: + Radiation: Spallation Neutron Source + Shape: None + Wavelength: None + Min. Wavelength: None + Max. Wavelength: None + Wavelength Spread: None + Beam Size: None +Transmission Spectrum: + Name: None + Timestamp: None + Wavelengths: None + Transmission: None + diff --git a/test/sasdataloader/reference/BAM.txt b/test/sasdataloader/reference/BAM.txt new file mode 100644 index 000000000..bf6328b1c --- /dev/null +++ b/test/sasdataloader/reference/BAM.txt @@ -0,0 +1,51 @@ +sasentry01 + [FILE_ID_HERE/sasentry01/data/I] [[-4132.585671758142 ... -4139.954861346877]] ± [[1000000000.0 ... 1000000000.0]] m⁻¹ + [FILE_ID_HERE/sasentry01/data/Imask] [[True ... True]] m⁻¹ + [FILE_ID_HERE/sasentry01/data/Q] [[[-0.10733919639695527 ... 0.0]]] Å⁻¹ +Metadata: + + Qais oriented iron test 1, Run: 12345 + ===================================== + +Definition: Qais oriented iron test 1 +Process: + Name: None + Date: None + Description: None + Term: None + Notes: None +Sample: + ID: None + Transmission: None + Thickness: None + Temperature: None + Position: None + Orientation: None +Aperture: + Name: None + Aperture size: None + Aperture distance: None +Collimation: + Length: None +Detector: + Name: None + Distance: None + Offset: None + Orientation: None + Beam center: None + Pixel size: None + Slit length: None +Source: + Radiation: None + Shape: None + Wavelength: None + Min. 
Wavelength: None + Max. Wavelength: None + Wavelength Spread: None + Beam Size: None +Transmission Spectrum: + Name: None + Timestamp: None + Wavelengths: None + Transmission: None + diff --git a/test/sasdataloader/reference/ISIS_1_0.txt b/test/sasdataloader/reference/ISIS_1_0.txt index 57ff7ed28..14f071620 100644 --- a/test/sasdataloader/reference/ISIS_1_0.txt +++ b/test/sasdataloader/reference/ISIS_1_0.txt @@ -11,9 +11,7 @@ Process: Name: Mantid generated CanSAS1D XML Date: 02-Aug-2013 16:54:14 Description: None - Terms: - svn: 2.5.3 - user_file: K:/masks/MASKLOQ_MAN_132E_Lu_Banjo_12mm.txt + Term: {'svn': '2.5.3', 'user_file': 'K:/masks/MASKLOQ_MAN_132E_Lu_Banjo_12mm.txt'} Sample: ID: TK49 c10_SANS Transmission: None @@ -47,3 +45,4 @@ Source: Max. Wavelength: None Wavelength Spread: None Beam Size: None + diff --git a/test/sasdataloader/reference/ISIS_1_1.txt b/test/sasdataloader/reference/ISIS_1_1.txt index 48df868f5..731c1b6cf 100644 --- a/test/sasdataloader/reference/ISIS_1_1.txt +++ b/test/sasdataloader/reference/ISIS_1_1.txt @@ -11,9 +11,7 @@ Process: Name: Mantid generated CanSAS1D XML Date: 02-Aug-2013 16:53:56 Description: None - Terms: - svn: 2.5.3 - user_file: K:/masks/MASKLOQ_MAN_132E_Lu_Banjo_12mm.txt + Term: {'svn': '2.5.3', 'user_file': 'K:/masks/MASKLOQ_MAN_132E_Lu_Banjo_12mm.txt'} Sample: ID: TK49 c10_SANS Transmission: None @@ -47,3 +45,4 @@ Source: Max. 
Wavelength: None Wavelength Spread: None Beam Size: None + diff --git a/test/sasdataloader/reference/ISIS_1_1_doubletrans.txt b/test/sasdataloader/reference/ISIS_1_1_doubletrans.txt index 48df868f5..731c1b6cf 100644 --- a/test/sasdataloader/reference/ISIS_1_1_doubletrans.txt +++ b/test/sasdataloader/reference/ISIS_1_1_doubletrans.txt @@ -11,9 +11,7 @@ Process: Name: Mantid generated CanSAS1D XML Date: 02-Aug-2013 16:53:56 Description: None - Terms: - svn: 2.5.3 - user_file: K:/masks/MASKLOQ_MAN_132E_Lu_Banjo_12mm.txt + Term: {'svn': '2.5.3', 'user_file': 'K:/masks/MASKLOQ_MAN_132E_Lu_Banjo_12mm.txt'} Sample: ID: TK49 c10_SANS Transmission: None @@ -47,3 +45,4 @@ Source: Max. Wavelength: None Wavelength Spread: None Beam Size: None + diff --git a/test/sasdataloader/reference/ISIS_1_1_notrans.txt b/test/sasdataloader/reference/ISIS_1_1_notrans.txt index 48df868f5..731c1b6cf 100644 --- a/test/sasdataloader/reference/ISIS_1_1_notrans.txt +++ b/test/sasdataloader/reference/ISIS_1_1_notrans.txt @@ -11,9 +11,7 @@ Process: Name: Mantid generated CanSAS1D XML Date: 02-Aug-2013 16:53:56 Description: None - Terms: - svn: 2.5.3 - user_file: K:/masks/MASKLOQ_MAN_132E_Lu_Banjo_12mm.txt + Term: {'svn': '2.5.3', 'user_file': 'K:/masks/MASKLOQ_MAN_132E_Lu_Banjo_12mm.txt'} Sample: ID: TK49 c10_SANS Transmission: None @@ -47,3 +45,4 @@ Source: Max. 
Wavelength: None Wavelength Spread: None Beam Size: None + diff --git a/test/sasdataloader/reference/MAR07232_rest.txt b/test/sasdataloader/reference/MAR07232_rest.txt index e98b20bd2..8a9e16ce4 100644 --- a/test/sasdataloader/reference/MAR07232_rest.txt +++ b/test/sasdataloader/reference/MAR07232_rest.txt @@ -1,24 +1,31 @@ sasentry01 - Qy - Qx - I Metadata: MAR07232_rest_out.dat, Run: 2 ============================= Definition: MAR07232_rest_out.dat +Process: + Name: None + Date: None + Description: None + Term: None + Notes: None Sample: - ID: - Transmission: 0.84357 + ID: None + Transmission: [0.84357] Thickness: None Temperature: None Position: None Orientation: None +Aperture: + Name: None + Aperture size: None + Aperture distance: None Collimation: Length: None Detector: - Name: + Name: None Distance: None Offset: None Orientation: None @@ -33,3 +40,21 @@ Source: Max. Wavelength: None Wavelength Spread: None Beam Size: None +<<<<<<< HEAD +Transmission Spectrum: + Name: None + Timestamp: None + Wavelengths: None + Transmission: None + + +||||||| parent of 0b6ad967 (Metadata is a dataclass) +Transmission Spectrum: + Name: None + Timestamp: None + Wavelengths: None + Transmission: None + +======= + +>>>>>>> 0b6ad967 (Metadata is a dataclass) \ No newline at end of file diff --git a/test/sasdataloader/reference/TestExtensions.txt b/test/sasdataloader/reference/TestExtensions.txt index e833c9da7..d04c7ef45 100644 --- a/test/sasdataloader/reference/TestExtensions.txt +++ b/test/sasdataloader/reference/TestExtensions.txt @@ -1,4 +1,4 @@ -TK49 c10_SANS +None Q I Metadata: @@ -10,10 +10,8 @@ Definition: TK49 c10_SANS Process: Name: Mantid generated CanSAS1D XML Date: 02-Aug-2013 16:53:56 - Description: - Terms: - svn: 2.5.3 - user_file: K:/masks/MASKLOQ_MAN_132E_Lu_Banjo_12mm.txt + Description: None + Term: {'svn': '2.5.3', 'user_file': 'K:/masks/MASKLOQ_MAN_132E_Lu_Banjo_12mm.txt'} Sample: ID: TK49 c10_SANS Transmission: None @@ -47,3 +45,4 @@ Source: Max. 
Wavelength: None Wavelength Spread: None Beam Size: None + diff --git a/test/sasdataloader/reference/cansas1d.txt b/test/sasdataloader/reference/cansas1d.txt index e16366bba..7563b9b65 100644 --- a/test/sasdataloader/reference/cansas1d.txt +++ b/test/sasdataloader/reference/cansas1d.txt @@ -1,4 +1,4 @@ -Test title +None Q I Metadata: @@ -11,34 +11,12 @@ Process: Name: spol Date: 04-Sep-2007 18:35:02 Description: None - Terms: - radialstep: 10.0 mm - sector_width: 180.0 deg - sector_orient: 0.0 deg - MASK_file: USER:MASK.COM - Notes: - AvA1 0.0000E+00 AsA2 1.0000E+00 XvA3 1.0526E+03 XsA4 - 5.2200E-02 XfA5 0.0000E+00 - S... 13597 0 2.26E+02 2A 5mM 0%D2O Sbak 13594 0 1.13E+02 - H2O Buffer - V... 13552 3 1.00E+00 H2O5m + Term: {'radialstep': '10.000', 'sector_width': '180.0', 'sector_orient': '0.0', 'MASK_file': 'USER:MASK.COM'} Process: Name: NCNR-IGOR Date: 03-SEP-2006 11:42:47 - Description: - Terms: - average_type: Circular - SAM_file: SEP06064.SA3_AJJ_L205 - BKD_file: SEP06064.SA3_AJJ_L205 - EMP_file: SEP06064.SA3_AJJ_L205 - DIV_file: SEP06064.SA3_AJJ_L205 - MASK_file: SEP06064.SA3_AJJ_L205 - ABS:TSTAND: 1 - ABS:DSTAND: 1.0 mm - ABS:IZERO: 230.09 - ABS:XSECT: 1.0 mm - Notes: - No Information + Description: None + Term: {'average_type': 'Circular', 'SAM_file': 'SEP06064.SA3_AJJ_L205', 'BKD_file': 'SEP06064.SA3_AJJ_L205', 'EMP_file': 'SEP06064.SA3_AJJ_L205', 'DIV_file': 'SEP06064.SA3_AJJ_L205', 'MASK_file': 'SEP06064.SA3_AJJ_L205', 'ABS:TSTAND': '1', 'ABS:DSTAND': '1', 'ABS:IZERO': '230.09', 'ABS:XSECT': '1'} Sample: ID: SI600-new-long Transmission: 0.327 @@ -48,14 +26,6 @@ Sample: Orientation: Rot3(roll=22.5 deg, pitch=0.02 deg, yaw=None) Collimation: Length: 123.0 mm - Aperture: - Name: source - Aperture size: Vec3(x=50.0 mm, y=None, z=None) - Aperture distance: 11.0 m - Aperture: - Name: sample - Aperture size: Vec3(x=1.0 mm, y=None, z=None) - Aperture distance: None Detector: Name: fictional hybrid Distance: 4.15 m diff --git 
a/test/sasdataloader/reference/cansas1d_badunits.txt b/test/sasdataloader/reference/cansas1d_badunits.txt index 4c61f204e..3868aead0 100644 --- a/test/sasdataloader/reference/cansas1d_badunits.txt +++ b/test/sasdataloader/reference/cansas1d_badunits.txt @@ -1,4 +1,4 @@ -Test title +None Q I Metadata: @@ -11,32 +11,12 @@ Process: Name: spol Date: 04-Sep-2007 18:35:02 Description: None - Terms: - radialstep: 10.0 mm - sector_width: 4.1416 rad - sector_orient: 0.0 deg - MASK_file: USER:MASK.COM - Notes: - AvA1 0.0000E+00 AsA2 1.0000E+00 XvA3 1.0526E+03 XsA4 - 5.2200E-02 XfA5 0.0000E+00 - S... 13597 0 2.26E+02 2A 5mM 0%D2O Sbak 13594 0 1.13E+02 - H2O Buffer - V... 13552 3 1.00E+00 H2O5m + Term: {'radialstep': '10.000', 'sector_width': '4.1416', 'sector_orient': '0.0', 'MASK_file': 'USER:MASK.COM'} Process: Name: NCNR-IGOR Date: 03-SEP-2006 11:42:47 Description: - Terms: - average_type: Circular - SAM_file: SEP06064.SA3_AJJ_L205 - BKD_file: SEP06064.SA3_AJJ_L205 - EMP_file: SEP06064.SA3_AJJ_L205 - DIV_file: SEP06064.SA3_AJJ_L205 - MASK_file: SEP06064.SA3_AJJ_L205 - ABS:TSTAND: 1 - ABS:DSTAND: 1.0 mm - ABS:IZERO: 230.09 - ABS:XSECT: 1.0 mm + Term: {'average_type': 'Circular', 'SAM_file': 'SEP06064.SA3_AJJ_L205', 'BKD_file': 'SEP06064.SA3_AJJ_L205', 'EMP_file': 'SEP06064.SA3_AJJ_L205', 'DIV_file': 'SEP06064.SA3_AJJ_L205', 'MASK_file': 'SEP06064.SA3_AJJ_L205', 'ABS:TSTAND': '1', 'ABS:DSTAND': '1', 'ABS:IZERO': '230.09', 'ABS:XSECT': '1'} Sample: ID: SI600-new-long Transmission: 0.327 @@ -46,19 +26,11 @@ Sample: Orientation: Rot3(roll=0.39269908 rad, pitch=0.00034906585 rad, yaw=None) Collimation: Length: 0.123 m - Aperture: - Name: source - Aperture size: Vec3(x=50000.0 µm, y=None, z=None) - Aperture distance: 1100.0 cm - Aperture: - Name: sample - Aperture size: Vec3(x=0.1 cm, y=None, z=None) - Aperture distance: None Detector: Name: fictional hybrid Distance: 4.15 m - Offset: Vec3(x=1000000.0 nm, y=2000.0 µm, z=None) - Orientation: Rot3(roll=0.0174533 rad, pitch=0.0 rad, 
yaw=0.0 rad) + Offset: Vec3(x=1000000.0 nm, y=2000.0 none, z=None) + Orientation: Rot3(roll=0.0174533 none, pitch=0.0 rad, yaw=0.0 none) Beam center: Vec3(x=0.32264 m, y=0.32768 m, z=None) Pixel size: Vec3(x=0.5 cm, y=0.5 cm, z=None) Slit length: None @@ -69,4 +41,4 @@ Source: Min. Wavelength: 2.2 Å Max. Wavelength: 10.0 Å Wavelength Spread: 14.3 % - Beam Size: BeamSize(name='bm', size=Vec3(x=0.012 m, y=13000.0 µm, z=None)) + Beam Size: BeamSize(name='bm', size=Vec3(x=0.012 m, y=13000.0 none, z=None)) diff --git a/test/sasdataloader/reference/cansas1d_notitle.txt b/test/sasdataloader/reference/cansas1d_notitle.txt index ca8abac25..b21da686f 100644 --- a/test/sasdataloader/reference/cansas1d_notitle.txt +++ b/test/sasdataloader/reference/cansas1d_notitle.txt @@ -1,44 +1,22 @@ -SasData01 +None Q I Metadata: None, Run: 1234 - =============== + =========== Definition: None Process: Name: spol Date: 04-Sep-2007 18:35:02 Description: None - Terms: - radialstep: 10.0 mm - sector_width: 180.0 deg - sector_orient: 0.0 deg - MASK_file: USER:MASK.COM - Notes: - AvA1 0.0000E+00 AsA2 1.0000E+00 XvA3 1.0526E+03 XsA4 - 5.2200E-02 XfA5 0.0000E+00 - S... 13597 0 2.26E+02 2A 5mM 0%D2O Sbak 13594 0 1.13E+02 - H2O Buffer - V... 
13552 3 1.00E+00 H2O5m + Term: {'radialstep': '10.000', 'sector_width': '180.0', 'sector_orient': '0.0', 'MASK_file': 'USER:MASK.COM'} Process: Name: NCNR-IGOR Date: 03-SEP-2006 11:42:47 Description: - Terms: - average_type: Circular - SAM_file: SEP06064.SA3_AJJ_L205 - BKD_file: SEP06064.SA3_AJJ_L205 - EMP_file: SEP06064.SA3_AJJ_L205 - DIV_file: SEP06064.SA3_AJJ_L205 - MASK_file: SEP06064.SA3_AJJ_L205 - ABS:TSTAND: 1 - ABS:DSTAND: 1.0 mm - ABS:IZERO: 230.09 - ABS:XSECT: 1.0 mm - Notes: - No Information + Term: {'average_type': 'Circular', 'SAM_file': 'SEP06064.SA3_AJJ_L205', 'BKD_file': 'SEP06064.SA3_AJJ_L205', 'EMP_file': 'SEP06064.SA3_AJJ_L205', 'DIV_file': 'SEP06064.SA3_AJJ_L205', 'MASK_file': 'SEP06064.SA3_AJJ_L205', 'ABS:TSTAND': '1', 'ABS:DSTAND': '1', 'ABS:IZERO': '230.09', 'ABS:XSECT': '1'} Sample: ID: SI600-new-long Transmission: 0.327 @@ -48,14 +26,6 @@ Sample: Orientation: Rot3(roll=22.5 deg, pitch=0.02 deg, yaw=None) Collimation: Length: 123.0 mm - Aperture: - Name: source - Aperture size: Vec3(x=50.0 mm, y=None, z=None) - Aperture distance: 11.0 m - Aperture: - Name: sample - Aperture size: Vec3(x=1.0 mm, y=None, z=None) - Aperture distance: None Detector: Name: fictional hybrid Distance: 4.15 m diff --git a/test/sasdataloader/reference/cansas1d_slit.txt b/test/sasdataloader/reference/cansas1d_slit.txt index cbc9bdcfe..7ef9093f8 100644 --- a/test/sasdataloader/reference/cansas1d_slit.txt +++ b/test/sasdataloader/reference/cansas1d_slit.txt @@ -1,8 +1,8 @@ -Test title - dQw +None + I dQl Q - I + dQw Metadata: Test title, Run: 1234 @@ -13,32 +13,12 @@ Process: Name: spol Date: 04-Sep-2007 18:35:02 Description: None - Terms: - radialstep: 10.0 mm - sector_width: 180.0 deg - sector_orient: 0.0 deg - MASK_file: USER:MASK.COM - Notes: - AvA1 0.0000E+00 AsA2 1.0000E+00 XvA3 1.0526E+03 XsA4 - 5.2200E-02 XfA5 0.0000E+00 - S... 13597 0 2.26E+02 2A 5mM 0%D2O Sbak 13594 0 1.13E+02 - H2O Buffer - V... 
13552 3 1.00E+00 H2O5m + Term: {'radialstep': '10.000', 'sector_width': '180.0', 'sector_orient': '0.0', 'MASK_file': 'USER:MASK.COM'} Process: Name: NCNR-IGOR Date: 03-SEP-2006 11:42:47 Description: - Terms: - average_type: Circular - SAM_file: SEP06064.SA3_AJJ_L205 - BKD_file: SEP06064.SA3_AJJ_L205 - EMP_file: SEP06064.SA3_AJJ_L205 - DIV_file: SEP06064.SA3_AJJ_L205 - MASK_file: SEP06064.SA3_AJJ_L205 - ABS:TSTAND: 1 - ABS:DSTAND: 1.0 mm - ABS:IZERO: 230.09 - ABS:XSECT: 1.0 mm + Term: {'average_type': 'Circular', 'SAM_file': 'SEP06064.SA3_AJJ_L205', 'BKD_file': 'SEP06064.SA3_AJJ_L205', 'EMP_file': 'SEP06064.SA3_AJJ_L205', 'DIV_file': 'SEP06064.SA3_AJJ_L205', 'MASK_file': 'SEP06064.SA3_AJJ_L205', 'ABS:TSTAND': '1', 'ABS:DSTAND': '1', 'ABS:IZERO': '230.09', 'ABS:XSECT': '1'} Sample: ID: SI600-new-long Transmission: 0.327 @@ -48,14 +28,6 @@ Sample: Orientation: Rot3(roll=22.5 deg, pitch=0.02 deg, yaw=None) Collimation: Length: 123.0 mm - Aperture: - Name: source - Aperture size: Vec3(x=50.0 mm, y=None, z=None) - Aperture distance: 11.0 m - Aperture: - Name: sample - Aperture size: Vec3(x=1.0 mm, y=None, z=None) - Aperture distance: None Detector: Name: fictional hybrid Distance: 4.15 m diff --git a/test/sasdataloader/reference/cansas1d_units.txt b/test/sasdataloader/reference/cansas1d_units.txt index 68db7bc45..21845aa8c 100644 --- a/test/sasdataloader/reference/cansas1d_units.txt +++ b/test/sasdataloader/reference/cansas1d_units.txt @@ -1,4 +1,4 @@ -Test title +None Q I Metadata: @@ -11,32 +11,12 @@ Process: Name: spol Date: 04-Sep-2007 18:35:02 Description: None - Terms: - radialstep: 10.0 mm - sector_width: 4.1416 rad - sector_orient: 0.0 deg - MASK_file: USER:MASK.COM - Notes: - AvA1 0.0000E+00 AsA2 1.0000E+00 XvA3 1.0526E+03 XsA4 - 5.2200E-02 XfA5 0.0000E+00 - S... 13597 0 2.26E+02 2A 5mM 0%D2O Sbak 13594 0 1.13E+02 - H2O Buffer - V... 
13552 3 1.00E+00 H2O5m + Term: {'radialstep': '10.000', 'sector_width': '4.1416', 'sector_orient': '0.0', 'MASK_file': 'USER:MASK.COM'} Process: Name: NCNR-IGOR Date: 03-SEP-2006 11:42:47 Description: - Terms: - average_type: Circular - SAM_file: SEP06064.SA3_AJJ_L205 - BKD_file: SEP06064.SA3_AJJ_L205 - EMP_file: SEP06064.SA3_AJJ_L205 - DIV_file: SEP06064.SA3_AJJ_L205 - MASK_file: SEP06064.SA3_AJJ_L205 - ABS:TSTAND: 1 - ABS:DSTAND: 1.0 mm - ABS:IZERO: 230.09 - ABS:XSECT: 1.0 mm + Term: {'average_type': 'Circular', 'SAM_file': 'SEP06064.SA3_AJJ_L205', 'BKD_file': 'SEP06064.SA3_AJJ_L205', 'EMP_file': 'SEP06064.SA3_AJJ_L205', 'DIV_file': 'SEP06064.SA3_AJJ_L205', 'MASK_file': 'SEP06064.SA3_AJJ_L205', 'ABS:TSTAND': '1', 'ABS:DSTAND': '1', 'ABS:IZERO': '230.09', 'ABS:XSECT': '1'} Sample: ID: SI600-new-long Transmission: 0.327 @@ -46,19 +26,11 @@ Sample: Orientation: Rot3(roll=0.39269908 rad, pitch=0.00034906585 rad, yaw=None) Collimation: Length: 0.123 m - Aperture: - Name: source - Aperture size: Vec3(x=50000.0 µm, y=None, z=None) - Aperture distance: 1100.0 cm - Aperture: - Name: sample - Aperture size: Vec3(x=0.1 cm, y=None, z=None) - Aperture distance: None Detector: Name: fictional hybrid Distance: 4.15 m - Offset: Vec3(x=1000000.0 nm, y=2000.0 µm, z=None) - Orientation: Rot3(roll=0.0174533 rad, pitch=0.0 rad, yaw=0.0 rad) + Offset: Vec3(x=1000000.0 nm, y=2000.0 none, z=None) + Orientation: Rot3(roll=0.0174533 none, pitch=0.0 rad, yaw=0.0 none) Beam center: Vec3(x=0.32264 m, y=0.32768 m, z=None) Pixel size: Vec3(x=0.5 cm, y=0.5 cm, z=None) Slit length: None @@ -69,4 +41,4 @@ Source: Min. Wavelength: 2.2 Å Max. 
Wavelength: 10.0 Å Wavelength Spread: 14.3 % - Beam Size: BeamSize(name='bm', size=Vec3(x=0.012 m, y=13000.0 µm, z=None)) + Beam Size: BeamSize(name='bm', size=Vec3(x=0.012 m, y=13000.0 none, z=None)) diff --git a/test/sasdataloader/reference/cansas_test.txt b/test/sasdataloader/reference/cansas_test.txt index 60333cc05..46f1e4e87 100644 --- a/test/sasdataloader/reference/cansas_test.txt +++ b/test/sasdataloader/reference/cansas_test.txt @@ -1,4 +1,4 @@ -ILL-D11 example1: 2A 5mM 0%D2O +None Q I Metadata: @@ -11,17 +11,7 @@ Process: Name: spol Date: 04-Sep-2007 18:12:27 Description: None - Terms: - radialstep: 10.0 mm - sector_width: 180.0 deg - sector_orient: 0.0 deg - Flux_monitor: 1.00 - Count_time: 900.0 s - Q_resolution: estimated - Notes: - AvA1 0.0000E+00 AsA2 1.0000E+00 XvA3 1.1111E+01 XsA4 5.2200E-02 XfA5 0.0000E+00 - S... 13586 0 5.63E+01 2A 5mM 0%D2O Sbak 13567 0 1.88E+01 H2O buffer - Sbak 13533 0 7.49E+00 H2O buffer V... 13552 1 1.00E+00 Normal 10m from + Term: {'radialstep': '10.000', 'sector_width': '180.0', 'sector_orient': '0.0', 'Flux_monitor': '1.00', 'Count_time': '900.000', 'Q_resolution': 'estimated'} Sample: ID: None Transmission: 0.0 diff --git a/test/sasdataloader/reference/cansas_test_modified.txt b/test/sasdataloader/reference/cansas_test_modified.txt index 9ccd51e73..56d65ae88 100644 --- a/test/sasdataloader/reference/cansas_test_modified.txt +++ b/test/sasdataloader/reference/cansas_test_modified.txt @@ -1,4 +1,4 @@ -ILL-D11 example1: 2A 5mM 0%D2O +None Q I Metadata: @@ -11,17 +11,7 @@ Process: Name: spol Date: 04-Sep-2007 18:12:27 Description: None - Terms: - radialstep: 10.0 mm - sector_width: 180.0 deg - sector_orient: 0.0 deg - Flux_monitor: 1.00 - Count_time: 900.0 s - Q_resolution: estimated - Notes: - AvA1 0.0000E+00 AsA2 1.0000E+00 XvA3 1.1111E+01 XsA4 5.2200E-02 XfA5 0.0000E+00 - S... 13586 0 5.63E+01 2A 5mM 0%D2O Sbak 13567 0 1.88E+01 H2O buffer - Sbak 13533 0 7.49E+00 H2O buffer V... 
13552 1 1.00E+00 Normal 10m from + Term: {'radialstep': '10.000', 'sector_width': '180.0', 'sector_orient': '0.0', 'Flux_monitor': '1.00', 'Count_time': '900.000', 'Q_resolution': 'estimated'} Sample: ID: This is a test file Transmission: 0.0 diff --git a/test/sasdataloader/reference/nxcansas_1Dand2D_multisasdata.txt b/test/sasdataloader/reference/nxcansas_1Dand2D_multisasdata.txt index ad56786dd..5763cf71f 100644 --- a/test/sasdataloader/reference/nxcansas_1Dand2D_multisasdata.txt +++ b/test/sasdataloader/reference/nxcansas_1Dand2D_multisasdata.txt @@ -1,7 +1,4 @@ sasentry01 - Qy - Qx - I Metadata: MH4_5deg_16T_SLOW, Run: 33837 @@ -9,29 +6,27 @@ Metadata: Definition: MH4_5deg_16T_SLOW Process: - Name: Mantid generated CanSAS1D XML - Date: 11-May-2016 12:15:34 + Name: None + Date: None Description: None + Term: None + Notes: None Sample: - ID: + ID: None Transmission: None Thickness: None Temperature: None Position: None Orientation: None +Aperture: + Name: None + Aperture size: None + Aperture distance: None Collimation: Length: None Detector: - Name: front-detector - Distance: 2845.26 mm - Offset: None - Orientation: None - Beam center: None - Pixel size: None - Slit length: None -Detector: - Name: rear-detector - Distance: 4385.28 mm + Name: None + Distance: None Offset: None Orientation: None Beam center: None diff --git a/test/sasdataloader/reference/nxcansas_1Dand2D_multisasentry.txt b/test/sasdataloader/reference/nxcansas_1Dand2D_multisasentry.txt index 9353db922..769f5ee49 100644 --- a/test/sasdataloader/reference/nxcansas_1Dand2D_multisasentry.txt +++ b/test/sasdataloader/reference/nxcansas_1Dand2D_multisasentry.txt @@ -1,12 +1,56 @@ sasentry01 - dQ - Q - I Metadata: MH4_5deg_16T_SLOW, Run: 33837 ============================= +Definition: MH4_5deg_16T_SLOW +Process: + Name: None + Date: None + Description: None + Term: None + Notes: None +Sample: + ID: None + Transmission: None + Thickness: None + Temperature: None + Position: None + Orientation: None 
+Aperture: + Name: None + Aperture size: None + Aperture distance: None +Collimation: + Length: None +Detector: + Name: None + Distance: None + Offset: None + Orientation: None + Beam center: None + Pixel size: None + Slit length: None +Source: + Radiation: Spallation Neutron Source + Shape: None + Wavelength: None + Min. Wavelength: None + Max. Wavelength: None + Wavelength Spread: None + Beam Size: None +Transmission Spectrum: + Name: None + Timestamp: None + Wavelengths: None + Transmission: None + +<<<<<<< HEAD +<<<<<<< HEAD + MH4_5deg_16T_SLOW, Run: 33837 + ============================= + Definition: MH4_5deg_16T_SLOW Process: Name: Mantid generated CanSAS1D XML @@ -45,12 +89,8 @@ Source: Max. Wavelength: None Wavelength Spread: None Beam Size: None -sasentry02 - Qy - Qx - I -Metadata: +||||||| parent of cb5f3ffb (Add tests for dataset) MH4_5deg_16T_SLOW, Run: 33837 ============================= @@ -70,7 +110,7 @@ Collimation: Length: None Detector: Name: front-detector - Distance: 2845.26 mm + Distance: 2845.260009765625 mm Offset: None Orientation: None Beam center: None @@ -78,7 +118,56 @@ Detector: Slit length: None Detector: Name: rear-detector - Distance: 4385.28 mm + Distance: 4385.27978515625 mm + Offset: None + Orientation: None + Beam center: None + Pixel size: None + Slit length: None +Source: + Radiation: Spallation Neutron Source + Shape: None + Wavelength: None + Min. Wavelength: None + Max. 
Wavelength: None + Wavelength Spread: None + Beam Size: None +Transmission Spectrum: + Name: None + Timestamp: None + Wavelengths: None + Transmission: None + +||||||| original + MH4_5deg_16T_SLOW, Run: 33837 + ============================= + +Definition: MH4_5deg_16T_SLOW +Process: + Name: Mantid generated CanSAS1D XML + Date: 11-May-2016 12:15:34 + Description: None + Term: None +Sample: + ID: + Transmission: None + Thickness: None + Temperature: None + Position: None + Orientation: None +Collimation: + Length: None +Detector: + Name: front-detector + Distance: 2845.260009765625 mm + Offset: None + Orientation: None + Beam center: None + Pixel size: None + Slit length: None +Detector: + Name: rear-detector + Distance: 4385.27978515625 mm Offset: None Orientation: None Beam center: None @@ -92,3 +181,8 @@ Source: Max. Wavelength: None Wavelength Spread: None Beam Size: None +Transmission Spectrum: + Name: None + Timestamp: None + Wavelengths: None + Transmission: None diff --git a/test/sasdataloader/reference/simpleexamplefile.txt b/test/sasdataloader/reference/simpleexamplefile.txt index 58c9c66b6..a81384082 100644 --- a/test/sasdataloader/reference/simpleexamplefile.txt +++ b/test/sasdataloader/reference/simpleexamplefile.txt @@ -1,9 +1,48 @@ sasentry01 - Q - I Metadata: - None, Run: [] - ============= + None, Run: None + =============== Definition: None +Process: + Name: None + Date: None + Description: None + Term: None + Notes: None +Sample: + ID: None + Transmission: None + Thickness: None + Temperature: None + Position: None + Orientation: None +Aperture: + Name: None + Aperture size: None + Aperture distance: None +Collimation: + Length: None +Detector: + Name: None + Distance: None + Offset: None + Orientation: None + Beam center: None + Pixel size: None + Slit length: None +Source: + Radiation: None + Shape: None + Wavelength: None + Min. Wavelength: None + Max. 
Wavelength: None + Wavelength Spread: None + Beam Size: None +Transmission Spectrum: + Name: None + Timestamp: None + Wavelengths: None + Transmission: None + diff --git a/test/sasdataloader/reference/sphere2micron.txt b/test/sasdataloader/reference/sphere2micron.txt index 3df15c420..87f3b192d 100644 --- a/test/sasdataloader/reference/sphere2micron.txt +++ b/test/sasdataloader/reference/sphere2micron.txt @@ -1,12 +1,8 @@ Sesans - Wavelength - SpinEchoLength - Polarisation - Depolarisation Metadata: - Polystyrene of Markus Strobl, Full Sine, ++ only, Run: [] - ========================================================== + Title, Run: [] + ============ Definition: Polystyrene of Markus Strobl, Full Sine, ++ only Process: @@ -20,7 +16,7 @@ Process: Sample: ID: None Transmission: None - Thickness: 0.2 cm + Thickness: None Temperature: None Position: None Orientation: None diff --git a/test/sasdataloader/reference/sphere_isis.txt b/test/sasdataloader/reference/sphere_isis.txt index e92cfbf3a..53cd10147 100644 --- a/test/sasdataloader/reference/sphere_isis.txt +++ b/test/sasdataloader/reference/sphere_isis.txt @@ -1,11 +1,8 @@ Sesans - Wavelength - SpinEchoLength - Depolarisation Metadata: PMMA in Mixed Deuterated decalin, Run: [] - ========================================= + ======================================= Definition: PMMA in Mixed Deuterated decalin Process: diff --git a/test/sasdataloader/reference/valid_cansas_xml.txt b/test/sasdataloader/reference/valid_cansas_xml.txt index 3b3993aba..5e73076fa 100644 --- a/test/sasdataloader/reference/valid_cansas_xml.txt +++ b/test/sasdataloader/reference/valid_cansas_xml.txt @@ -11,9 +11,7 @@ Process: Name: Mantid generated CanSAS1D XML Date: 10-Oct-2013 16:00:29 Description: None - Terms: - svn: 2.6.20130902.1504 - user_file: K:/masks/MASKLOQ_MAN_133D_Xpress_8mm.txt + Term: {'svn': '2.6.20130902.1504', 'user_file': 'K:/masks/MASKLOQ_MAN_133D_Xpress_8mm.txt'} Sample: ID: LOQ_Standard_TK49_SANS Transmission: None @@ -47,3 
+45,4 @@ Source: Max. Wavelength: None Wavelength Spread: None Beam Size: None + diff --git a/test/sasdataloader/reference/x25000_no_di.txt b/test/sasdataloader/reference/x25000_no_di.txt index b0cad2ee3..a36a709bb 100644 --- a/test/sasdataloader/reference/x25000_no_di.txt +++ b/test/sasdataloader/reference/x25000_no_di.txt @@ -1,25 +1,31 @@ sasentry01 - mask - Qy - Qx - I Metadata: , Run: ======= Definition: +Process: + Name: None + Date: None + Description: None + Term: None + Notes: None Sample: - ID: + ID: None Transmission: None Thickness: None Temperature: None Position: None Orientation: None +Aperture: + Name: None + Aperture size: None + Aperture distance: None Collimation: Length: None Detector: - Name: + Name: None Distance: None Offset: None Orientation: None diff --git a/test/sasdataloader/utest_new_sesans.py b/test/sasdataloader/utest_new_sesans.py index d91f8bedf..c17155094 100644 --- a/test/sasdataloader/utest_new_sesans.py +++ b/test/sasdataloader/utest_new_sesans.py @@ -4,20 +4,9 @@ import os -import numpy as np import pytest -from sasdata.model_requirements import ( - ComposeRequirements, - ModellingRequirements, - NullModel, - PinholeModel, - SesansModel, - SlitModel, - guess_requirements, -) -from sasdata.quantities import unit_parser, units -from sasdata.quantities.quantity import Quantity +from sasdata.temp_hdf5_reader import load_data from sasdata.temp_sesans_reader import load_data test_file_names = ["sphere2micron", "sphere_isis"] @@ -36,100 +25,3 @@ def test_load_file(f): with open(local_load(f"reference/{f}.txt")) as infile: expected = "".join(infile.readlines()) assert data.summary() == expected - -@pytest.mark.sesans -def test_sesans_modelling(): - data = load_data(local_load("sesans_data/sphere2micron.ses")) - req = guess_requirements(data) - assert type(req) is SesansModel - - def form_volume(x): - return np.pi * 4.0 / 3.0 * x**3 - - radius = 10000 - contrast = 5.4e-7 # Contrast is hard coded for best fit - form = contrast * 
form_volume(radius) - f2 = 1.0e-4*form*form - - xi_squared = smear(req, data._data_contents["SpinEchoLength"].value, full_data=data, y = data._data_contents["Depolarisation"].in_units_of(unit_parser.parse("A-2 cm-1")) / f2, radius=radius) - assert 1.0 < xi_squared < 1.5 - -@pytest.mark.sesans -def test_pinhole_zero(): - assert pinhole_smear(0) == 0 - -@pytest.mark.sesans -def test_pinhole_smear(): - smearing = [3**x for x in range(-1, 6)] - smears = [pinhole_smear(x) for x in smearing] - old = 0 - for factor, smear in zip(smearing, smears): - print(factor, smear) - assert old < smear < 1.2 - old = smear - assert pinhole_smear(3**6) > 1.2 - -@pytest.mark.sesans -def test_slit_zero(): - assert slit_smear(0, 0) == 0 - -@pytest.mark.sesans -def test_slit_smear(): - smears = [(length, width, slit_smear(3**length, 3**width)) for length in range(-1, 6) for width in range(-1, 6) if length + width == 6] - for length, width, smear in smears: - print(length, width, smear) - assert smear < 1.2 - assert slit_smear(3**6, 1) > 1.2 - assert slit_smear(1, 3**6) > 1.2 - - -def slit_smear(length: float, width): - data = Quantity(np.linspace(1e-4, 1e-1, 1000), units.per_angstrom) - diff = np.diff(data.value, prepend=0) - req = SlitModel(diff * length, diff * width) - return smear(req, data.value) - - -def pinhole_smear(smearing: float): - data = Quantity(np.linspace(1e-4, 1e-1, 1000), units.per_angstrom) - req = PinholeModel(np.diff(data.value, prepend=0) * smearing) - return smear(req, data.value) - - -def smear(req: ModellingRequirements, data: np.ndarray, y=None, full_data=None, radius=200): - def sphere(q): - def sas_3j1x_x(x): - return (np.sin(x) - x * np.cos(x))/x**3 - - return sas_3j1x_x(q * radius)**2 - - inner_q = req.preprocess_q(data, full_data) - result = req.postprocess_iq(sphere(inner_q), data) - - if y is None: - y = sphere(data) - - xi_squared = np.sum( ((y - result)/result )**2 ) / len(y) - return xi_squared - - - -@pytest.mark.sesans -def test_model_algebra(): - ses 
= SesansModel() - pin = PinholeModel(np.linspace(1e-3, 1, 1000)) - slit = SlitModel(np.linspace(1e-3, 1, 1000), np.linspace(1e-3, 1, 1000)) - null = NullModel() - - # Ignore slit smearing if we perform a sesans transform afterwards - assert type(pin + ses) is SesansModel - assert type(slit + ses) is SesansModel - # However, it is possible for the spin echo lengths to have some - # smearing between them. - assert type(ses + pin) is ComposeRequirements - assert type(null + ses) is SesansModel - assert type(null + slit) is SlitModel - assert type(null + pin) is PinholeModel - assert type(ses + null) is SesansModel - assert type(pin + null) is PinholeModel - assert type(slit + null) is SlitModel diff --git a/test/sasdataloader/utest_sasdataload.py b/test/sasdataloader/utest_sasdataload.py index 735904352..7952d6a28 100644 --- a/test/sasdataloader/utest_sasdataload.py +++ b/test/sasdataloader/utest_sasdataload.py @@ -2,455 +2,60 @@ Unit tests for the new recursive cansas reader """ -import io -import json import os -from dataclasses import dataclass, field -from typing import Any -import numpy as np import pytest -import sasdata.quantities.units as units -from sasdata.data import SasData, SasDataEncoder -from sasdata.dataset_types import one_dim -from sasdata.guess import guess_columns -from sasdata.quantities.quantity import Quantity -from sasdata.quantities.units import per_angstrom -from sasdata.temp_ascii_reader import ( - AsciiMetadataCategory, - AsciiReaderMetadata, - AsciiReaderParams, - load_data_default_params, -) -from sasdata.temp_ascii_reader import load_data as ascii_load_data from sasdata.temp_hdf5_reader import load_data as hdf_load_data -from sasdata.temp_sesans_reader import load_data as sesans_load_data from sasdata.temp_xml_reader import load_data as xml_load_data - -def local_load(path: str): - """Get local file path""" - base = os.path.join(os.path.dirname(__file__), path) - if os.path.exists(f"{base}.h5"): - return f"{base}.h5" - if 
os.path.exists(f"{base}.xml"): - return f"{base}.xml" - return f"{base}" - - -def local_reference_load(path: str): - return local_load(f"{os.path.join('reference', path)}") - - -def local_data_load(path: str): - return local_load(f"{os.path.join('data', path)}") - - -def local_json_load(path: str): - return local_load(f"{os.path.join('json', path)}") - - -def local_sesans_load(path: str): - return local_load(f"{os.path.join('sesans_data', path)}") - - -@pytest.mark.sasdata -def test_filter_data(): - data = xml_load_data(local_load("data/cansas1d_notitle")) - for k, v in data.items(): - assert v.metadata.raw.filter("transmission") == ["0.327"] - assert v.metadata.raw.filter("wavelength")[0] == Quantity(6.0, units.angstroms) - assert v.metadata.raw.filter("SDD")[0] == Quantity(4.15, units.meters) - data = hdf_load_data(local_load("data/nxcansas_1Dand2D_multisasentry")) - for k, v in data.items(): - assert v.metadata.raw.filter("radiation") == ["Spallation Neutron Source"] - assert v.metadata.raw.filter("SDD") == [ - Quantity(np.array([2845.26], dtype=np.float32), units.millimeters), - Quantity(np.array([4385.28], dtype=np.float32), units.millimeters), - ] - - -@dataclass(kw_only=True) -class BaseTestCase: - expected_values: dict[int, dict[str, float]] - expected_metadata: dict[str, Any] = field(default_factory=dict) - metadata_file: None | str = None - json_file: None | str = None - round_trip: bool = False - - -@dataclass(kw_only=True) -class AsciiTestCase(BaseTestCase): - # If this is a string of strings then the other params will be guessed. 
- reader_params: AsciiReaderParams | str - - -@dataclass(kw_only=True) -class BulkAsciiTestCase(AsciiTestCase): - reader_params: AsciiReaderParams - expected_values: dict[str, dict[int, dict[str, float]]] - expected_metadata: dict[str, dict[str, Any]] = field(default_factory=dict) - - -@dataclass(kw_only=True) -class XmlTestCase(BaseTestCase): - filename: str - entry: str = "sasentry01" - round_trip: bool = True - - -@dataclass(kw_only=True) -class Hdf5TestCase(BaseTestCase): - filename: str - entry: str = "sasentry01" - round_trip: bool = True - - -@dataclass(kw_only=True) -class SesansTestCase(BaseTestCase): - filename: str - - -test_cases = [ - pytest.param( - AsciiTestCase( - reader_params=local_data_load("ascii_test_1.txt"), - expected_values={ - 0: {"Q": 0.002618, "I": 0.02198, "dI": 0.002704}, - -1: {"Q": 0.0497, "I": 8.346, "dI": 0.191}, - }, - ), - marks=pytest.mark.xfail(reason="The ASCII reader cannot make the right guesses for this file."), - ), - AsciiTestCase( - reader_params=local_data_load("test_3_columns.txt"), - expected_values={ - 0: {"Q": 0, "I": 2.83954, "dI": 0.6}, - -1: {"Q": 1.22449, "I": 7.47487, "dI": 1.05918}, - }, - ), - pytest.param( - AsciiTestCase( - reader_params=local_data_load("detector_rectangular.DAT"), - expected_values={ - 0: { - "Qx": -0.009160664, - "Qy": -0.1683881, - "I": 16806.79, - "dI": 0.01366757, - }, - -1: { - "Qx": 0.2908819, - "Qy": 0.1634992, - "I": 8147.779, - "dI": 0.05458562, - }, - }, - ), - marks=pytest.mark.xfail( - reason="Guesses for 2D ASCII files are currently wrong, so the data loaded won't be correct." 
- ), - ), - BulkAsciiTestCase( - reader_params=AsciiReaderParams( - filenames=[ - local_data_load(filename) - for filename in [ - "1_33_1640_22.874115.csv", - "2_42_1640_23.456895.csv", - "3_61_1640_23.748285.csv", - "4_103_1640_24.039675.csv", - "5_312_1640_24.331065.csv", - "6_1270_1640_24.331065.csv", - ] - ], - columns=[(column, per_angstrom) for column in guess_columns(3, one_dim)], - separator_dict={"Comma": True}, - metadata=AsciiReaderMetadata( - master_metadata={ - "magnetic": AsciiMetadataCategory( - values={ - "counting_index": 0, - "applied_magnetic_field": 1, - "saturation_magnetization": 2, - "demagnetizing_field": 3, - } - ) - } - ), - ), - expected_values={}, - expected_metadata={ - "1_33_1640_22.874115.csv": { - "counting_index": ["1"], - "applied_magnetic_field": ["33"], - "saturation_magnetization": ["1640"], - "demagnetizing_field": ["22"], - }, - "6_1270_1640_24.331065.csv": { - "counting_index": ["6"], - "applied_magnetic_field": ["1270"], - "saturation_magnetization": ["1640"], - "demagnetizing_field": ["24"], - }, - }, - ), - XmlTestCase( - filename=local_data_load("ISIS_1_0.xml"), - entry="79680main_1D_2.2_10.0", - expected_values={ - 0: {"Q": 0.009, "I": 85.3333, "dI": 0.852491, "dQ": 0}, - -2: {"Q": 0.281, "I": 0.408902, "dQ": 0}, - -1: {"Q": 0.283, "I": 0, "dI": 0, "dQ": 0}, - }, - expected_metadata={ - # TODO: Add more. 
- "radiation": "neutron" - }, - ), - Hdf5TestCase( - filename=local_data_load("simpleexamplefile.h5"), - metadata_file=local_reference_load("simpleexamplefile.txt"), - expected_values={ - 0: {"Q": 0.5488135039273248, "I": 0.6778165367962301}, - -1: {"Q": 0.004695476192547066, "I": 0.4344166255581208}, - }, - ), - Hdf5TestCase( - filename=local_data_load("MAR07232_rest.h5"), - metadata_file=local_reference_load("MAR07232_rest.txt"), - expected_values={}, - ), - Hdf5TestCase( - filename=local_data_load("x25000_no_di.h5"), - expected_values={}, - ), - Hdf5TestCase( - filename=local_data_load("nxcansas_1Dand2D_multisasentry.h5"), - metadata_file=local_reference_load("nxcansas_1Dand2D_multisasentry.txt"), - expected_values={}, - ), - Hdf5TestCase( - filename=local_data_load("nxcansas_1Dand2D_multisasdata.h5"), - metadata_file=local_reference_load("nxcansas_1Dand2D_multisasdata.txt"), - expected_values={}, - ), - XmlTestCase( - filename=local_data_load("ISIS_1_0.xml"), - entry="79680main_1D_2.2_10.0", - metadata_file=local_reference_load("ISIS_1_0.txt"), - json_file=local_json_load("ISIS_1_0.json"), - expected_values={}, - ), - XmlTestCase( - filename=local_data_load("ISIS_1_1.xml"), - entry="79680main_1D_2.2_10.0", - metadata_file=local_reference_load("ISIS_1_1.txt"), - json_file=local_json_load("ISIS_1_1.json"), - expected_values={}, - ), - XmlTestCase( - filename=local_data_load("ISIS_1_1_doubletrans.xml"), - entry="79680main_1D_2.2_10.0", - metadata_file=local_reference_load("ISIS_1_1_doubletrans.txt"), - json_file=local_json_load("ISIS_1_1_doubletrans.json"), - expected_values={}, - ), - XmlTestCase( - filename=local_data_load("ISIS_1_1_notrans.xml"), - entry="79680main_1D_2.2_10.0", - metadata_file=local_reference_load("ISIS_1_1_notrans.txt"), - json_file=local_json_load("ISIS_1_1_notrans.json"), - expected_values={}, - ), - XmlTestCase( - filename=local_data_load("TestExtensions.xml"), - entry="TK49 c10_SANS", - 
metadata_file=local_reference_load("TestExtensions.txt"), - json_file=local_json_load("TestExtensions.json"), - expected_values={}, - ), - XmlTestCase( - filename=local_data_load("cansas1d.xml"), - entry="Test title", - metadata_file=local_reference_load("cansas1d.txt"), - json_file=local_json_load("cansas1d.json"), - expected_values={}, - ), - XmlTestCase( - filename=local_data_load("cansas1d_badunits.xml"), - entry="Test title", - metadata_file=local_reference_load("cansas1d_badunits.txt"), - json_file=local_json_load("cansas1d_badunits.json"), - expected_values={}, - ), - XmlTestCase( - filename=local_data_load("cansas1d_notitle.xml"), - entry="SasData01", - metadata_file=local_reference_load("cansas1d_notitle.txt"), - json_file=local_json_load("cansas1d_notitle.json"), - expected_values={}, - ), - XmlTestCase( - filename=local_data_load("cansas1d_slit.xml"), - entry="Test title", - metadata_file=local_reference_load("cansas1d_slit.txt"), - json_file=local_json_load("cansas1d_slit.json"), - expected_values={}, - ), - XmlTestCase( - filename=local_data_load("cansas1d_units.xml"), - entry="Test title", - metadata_file=local_reference_load("cansas1d_units.txt"), - json_file=local_json_load("cansas1d_units.json"), - expected_values={}, - ), - XmlTestCase( - filename=local_data_load("cansas_test.xml"), - entry="ILL-D11 example1: 2A 5mM 0%D2O", - metadata_file=local_reference_load("cansas_test.txt"), - json_file=local_json_load("cansas_test.json"), - expected_values={}, - ), - XmlTestCase( - filename=local_data_load("cansas_test_modified.xml"), - entry="ILL-D11 example1: 2A 5mM 0%D2O", - metadata_file=local_reference_load("cansas_test_modified.txt"), - json_file=local_json_load("cansas_test_modified.json"), - expected_values={}, - ), - XmlTestCase( - filename=local_data_load("valid_cansas_xml.xml"), - entry="80514main_1D_2.2_10.0", - metadata_file=local_reference_load("valid_cansas_xml.txt"), - json_file=local_json_load("valid_cansas_xml.json"), - expected_values={}, 
- ), - SesansTestCase( - filename=local_sesans_load("sphere2micron.ses"), - metadata_file=local_reference_load("sphere2micron.txt"), - expected_values={ - 0: {"SpinEchoLength": 391.56, "Depolarisation": 0.0041929}, - -1: {"SpinEchoLength": 46099, "Depolarisation": -0.19956}, - }, - ), +test_hdf_file_names = [ + # "simpleexamplefile", + "nxcansas_1Dand2D_multisasentry", + "nxcansas_1Dand2D_multisasdata", + "MAR07232_rest", + "x25000_no_di", ] +test_xml_file_names = [ + "ISIS_1_0", + "ISIS_1_1", + "ISIS_1_1_doubletrans", + "ISIS_1_1_notrans", + "TestExtensions", + "cansas1d", + "cansas1d_badunits", + "cansas1d_notitle", + "cansas1d_slit", + "cansas1d_units", + "cansas_test", + "cansas_test_modified", + "cansas_xml_multisasentry_multisasdata", + "valid_cansas_xml", +] -def join_actual_expected( - actual: list[SasData], expected: dict[str, dict[int, dict[str, float]]] -) -> list[tuple[SasData, dict[int, dict[str, float]]]]: - return_value = [] - for actual_datum in actual: - matching_expected_datum = expected.get(actual_datum.name) - if matching_expected_datum is None: - continue - return_value.append((actual_datum, matching_expected_datum)) - return return_value - - -def is_uncertainty(column: str) -> bool: - for uncertainty_str in ["I", "Q", "Qx", "Qy"]: - if column == "d" + uncertainty_str: - return True - return False - - -@pytest.mark.dataload -@pytest.mark.parametrize("test_case", test_cases) -def test_load_file(test_case: BaseTestCase): - match test_case: - case BulkAsciiTestCase(): - loaded_data = ascii_load_data(test_case.reader_params) - case AsciiTestCase(): - if isinstance(test_case.reader_params, str): - loaded_data = load_data_default_params(test_case.reader_params)[0] - elif isinstance(test_case.reader_params, AsciiReaderParams): - loaded_data = ascii_load_data(test_case.reader_params)[0] - else: - raise TypeError("Invalid type for reader_params.") - case Hdf5TestCase(): - combined_data = hdf_load_data(test_case.filename) - loaded_data = 
combined_data[test_case.entry] - # TODO: Support SESANS - case XmlTestCase(): - # Not bulk, so just assume we get one dataset. - combined_data = xml_load_data(test_case.filename) - loaded_data = combined_data[test_case.entry] - case SesansTestCase(): - loaded_data = sesans_load_data(test_case.filename) - combined_data = {"only": loaded_data} - case _: - raise ValueError("Invalid loader") - if isinstance(test_case, BulkAsciiTestCase): - loaded_expected_pairs = join_actual_expected(loaded_data, test_case.expected_values) - metadata_filenames = test_case.expected_metadata.keys() - else: - loaded_expected_pairs = [(loaded_data, test_case.expected_values)] - metadata_filenames = [loaded_data.name] - for loaded, expected in loaded_expected_pairs: - for index, values in expected.items(): - for column, expected_value in values.items(): - if is_uncertainty(column): - assert loaded._data_contents[column[1::]]._variance[index] == pytest.approx(expected_value**2) - else: - assert loaded._data_contents[column].value[index] == pytest.approx(expected_value) - - for filename in metadata_filenames: - current_metadata_dict = test_case.expected_metadata.get(filename) - current_datum = ( - next(filter(lambda d: d.name == filename, loaded_data)) if isinstance(loaded_data, list) else loaded_data - ) - if current_metadata_dict is None: - continue - for metadata_key, value in current_metadata_dict.items(): - assert current_datum.metadata.raw.filter(metadata_key) == value - if test_case.metadata_file is not None: - with open(test_case.metadata_file, encoding="utf-8") as infile: - expected = "".join(infile.readlines()) - keys = sorted([d for d in combined_data]) - assert "".join(combined_data[k].summary() for k in keys) == expected +def local_load(path: str): + """Get local file path""" + return os.path.join(os.path.dirname(__file__), path) - if test_case.json_file is not None: - # Test serialisation - with open(test_case.json_file, encoding="utf-8") as infile: - expected = 
json.loads("".join(infile.readlines())) - assert json.loads(SasDataEncoder().encode(combined_data)) == expected - # Test deserialisation - with open(test_case.json_file, encoding="utf-8") as infile: - raw = json.loads("".join(infile.readlines())) - parsed = {} - for k in raw: - parsed[k] = SasData.from_json(raw[k]) +@pytest.mark.sasdata +@pytest.mark.parametrize("f", test_hdf_file_names) +def test_hdf_load_file(f): + data = hdf_load_data(local_load(f"data/{f}.h5")) - for k in combined_data: - expect = combined_data[k] - pars = parsed[k] - assert pars.name == expect.name - # assert pars._data_contents == expect._data_contents - assert pars.dataset_type == expect.dataset_type - assert pars.mask == expect.mask - assert pars.model_requirements == expect.model_requirements + with open(local_load(f"reference/{f}.txt")) as infile: + expected = "".join(infile.readlines()) + keys = sorted([d for d in data]) + assert "".join(data[k].summary() for k in keys) == expected - if test_case.round_trip: - bio = io.BytesIO() - SasData.save_h5(combined_data, bio) - bio.seek(0) - result = hdf_load_data(bio) - bio.close() +@pytest.mark.current +@pytest.mark.parametrize("f", test_xml_file_names) +def test_xml_load_file(f): + data = xml_load_data(local_load(f"data/{f}.xml")) - for name, entry in result.items(): - assert combined_data[name].metadata.title == entry.metadata.title - assert combined_data[name].metadata.run == entry.metadata.run - assert combined_data[name].metadata.definition == entry.metadata.definition - assert combined_data[name].metadata.process == entry.metadata.process - assert combined_data[name].metadata.instrument == entry.metadata.instrument - assert combined_data[name].metadata.sample == entry.metadata.sample - assert combined_data[name].ordinate.units == entry.ordinate.units - assert np.all(combined_data[name].ordinate.value == entry.ordinate.value) - assert combined_data[name].abscissae.units == entry.abscissae.units - assert 
np.all(combined_data[name].abscissae.value == entry.abscissae.value) + with open(local_load(f"reference/{f}.txt")) as infile: + expected = "".join(infile.readlines()) + assert data[0].summary() == expected diff --git a/test/slicers/utest_point_assignment.py b/test/slicers/utest_point_assignment.py index c648edde3..a82f1c790 100644 --- a/test/slicers/utest_point_assignment.py +++ b/test/slicers/utest_point_assignment.py @@ -1,4 +1,5 @@ + def test_location_assignment(): pass diff --git a/test/utest_new_sasdata.py b/test/utest_new_sasdata.py index cc497d9ae..62e5e4ab7 100644 --- a/test/utest_new_sasdata.py +++ b/test/utest_new_sasdata.py @@ -2,11 +2,9 @@ from sasdata.data import SasData from sasdata.data_backing import Group -from sasdata.dataset_types import one_dim, three_dim, two_dim -from sasdata.metadata import Instrument, Metadata, Source -from sasdata.postprocess import deduce_qz +from sasdata.dataset_types import one_dim from sasdata.quantities.quantity import Quantity -from sasdata.quantities.units import angstroms, per_angstrom, per_centimeter +from sasdata.quantities.units import per_angstrom, per_centimeter def test_1d(): @@ -23,85 +21,5 @@ def test_1d(): data = SasData('TestData', data_contents, one_dim, Group('root', {}), True) - assert all(data.abscissae.value == np.array(q)) - assert all(data.ordinate.value == np.array(i)) - - -def test_2d(): - # This could be autogenerated but I am hard coding to reduce the logic in - # the test. 
- qx = [1, 1, 1, 2, 2, 2, 3, 3, 3] - qy = [1, 2, 3, 1, 2, 3, 1, 2, 3] - i = [1, 2, 3] - - qx_quantity = Quantity(np.array(qx), per_angstrom) - qy_quantity = Quantity(np.array(qy), per_angstrom) - i_quantity = Quantity(np.array(i), per_centimeter) - - data_contents = { - 'Qx': qx_quantity, - 'Qy': qy_quantity, - 'I': i_quantity - } - - data = SasData('TestData', data_contents, two_dim, Group('root', {}), True) - - assert all(data.ordinate.value == np.array(i)) - assert (data.abscissae.value == np.array([[1, 1], [1, 2], [1, 3], [2, 1], [2, 2], [2, 3], [3, 1], [3, 2], [3, 3]])).all().all() - -def test_3d(): - # test base 3D class - qx = [1, 1, 1, 2, 2, 2, 3, 3, 3] - qy = [1, 2, 3, 1, 2, 3, 1, 2, 3] - qz = [0, 1, 0, 1, 0, 1, 0, 1, 0] - i = [1, 2, 3] - - qx_quantity = Quantity(np.array(qx), per_angstrom) - qy_quantity = Quantity(np.array(qy), per_angstrom) - qz_quantity = Quantity(np.array(qz), per_angstrom) - i_quantity = Quantity(np.array(i), per_centimeter) - - data_contents = { - 'Qx': qx_quantity, - 'Qy': qy_quantity, - 'Qz': qz_quantity, - 'I': i_quantity - } - - data = SasData('TestData', data_contents, three_dim, Group('root', {}), True) - - assert (data._data_contents['Qx'].value == np.array(qx)).all() - - - - # test autogenerated qz from qx, qy, and wavelength - wavelength = Quantity(1., angstroms) - source = Source(radiation=None, - beam_shape=None, - beam_size=None, - wavelength=wavelength, - wavelength_max=None, - wavelength_min=None, - wavelength_spread=None) - instrument = Instrument(collimations=[], - source=source, - detector=[]) - metadata=Metadata(title=None, - run=[], - definition=None, - process=[], - sample=None, - instrument=instrument, - raw=None) - - data_contents = { - 'Qx': qx_quantity, - 'Qy': qy_quantity, - 'I': i_quantity - } - - data = SasData('TestData', data_contents, two_dim, metadata, True) - - deduce_qz(data) - - assert (data._data_contents['Qz'].value != (0*data._data_contents['Qx'].value)).all() + assert all(data.abscissae == 
np.array(q_quantity)) + assert all(data.ordinate == np.array(i_quantity)) diff --git a/test/utest_sasdata.py b/test/utest_sasdata.py index cebc748e2..391768382 100644 --- a/test/utest_sasdata.py +++ b/test/utest_sasdata.py @@ -9,6 +9,13 @@ logging.config.fileConfig(LOGGER_CONFIG_FILE) logger = logging.getLogger(__name__) +try: + pass +except: + logger.error("xmlrunner needs to be installed to run these tests") + logger.error("Try easy_install unittest-xml-reporting") + sys.exit(1) + # Check whether we have matplotlib installed HAS_MPL_WX = True try: diff --git a/test/utest_temp_ascii_reader.py b/test/utest_temp_ascii_reader.py index 0361ca342..8ae481ac8 100644 --- a/test/utest_temp_ascii_reader.py +++ b/test/utest_temp_ascii_reader.py @@ -1,3 +1,4 @@ +import pytest import os from typing import Literal @@ -20,24 +21,11 @@ # TODO: Look into parameterizing this, although its not trivial due to the setup, and tests being a bit different. -def find(filename: str, locations: Literal["sasdataloader", "mumag"]) -> str: - # This match statement is here in case we want to pull data out of other locations. - match locations: - case "sasdataloader": - return os.path.join( - os.path.dirname(__file__), "sasdataloader", "data", filename - ) - case "mumag": - return os.path.join( - os.path.dirname(__file__), - "mumag", - "Nanoperm_perpendicular_Honecker_et_al", - filename, - ) - +def find(filename: str) -> str: + return os.path.join(os.path.dirname(__file__), 'sasdataloader', 'data', filename) def test_ascii_1(): - filename = find("ascii_test_1.txt", "sasdataloader") + filename = 'ascii_test_1.txt' params = guess_params_from_filename(filename, one_dim) # Need to change the columns as they won't be right. 
# TODO: unitless @@ -63,6 +51,7 @@ def test_ascii_1(): case "dI": assert datum.value[0] == pytest.approx(0.002704) assert datum.value[-1] == pytest.approx(0.191) +<<<<<<< HEAD def test_ascii_2(): @@ -152,3 +141,23 @@ def test_mumag_metadata(): assert datum.metadata.raw.filter("applied_magnetic_field") == ["1270"] assert datum.metadata.raw.filter("saturation_magnetization") == ["1640"] assert datum.metadata.raw.filter("demagnetizing_field") == ["24"] +||||||| parent of f7982fb1 (Wrote a test for the ASCII reader.) + +def test_ascii_2(): + filename = find('test_3_columns.txt', 'sasdataloader') + params = guess_params_from_filename(filename, one_dim) + loaded_data = load_data(params)[0] + + for datum in loaded_data._data_contents: + match datum.name: + case 'Q': + assert datum.value[0] == pytest.approx(0) + assert datum.value[-1] == pytest.approx(1.22449) + case 'I': + assert datum.value[0] == pytest.approx(2.83954) + assert datum.value[-1] == pytest.approx(7.47487) + case 'dI': + assert datum.value[0] == pytest.approx(0.6) + assert datum.value[-1] == pytest.approx(1.05918) +======= +>>>>>>> f7982fb1 (Wrote a test for the ASCII reader.) 
diff --git a/test/utest_trend.py b/test/utest_trend.py index b079bf53c..cae81131a 100644 --- a/test/utest_trend.py +++ b/test/utest_trend.py @@ -4,34 +4,24 @@ import sasdata.temp_ascii_reader as ascii_reader from sasdata.ascii_reader_metadata import AsciiMetadataCategory -from sasdata.quantities.units import per_angstrom, per_nanometer +from sasdata.quantities.units import per_nanometer from sasdata.temp_ascii_reader import AsciiReaderParams from sasdata.trend import Trend -mumag_test_directories = [ +test_directories = [ 'FeNiB_perpendicular_Bersweiler_et_al', 'Nanoperm_perpendicular_Honecker_et_al', 'NdFeB_parallel_Bick_et_al' ] -custom_test_directory = 'custom_test' +@pytest.mark.parametrize('directory_name', test_directories) +def test_trend_build(directory_name: str): + """Try to build a trend object on the MuMag datasets, and see if all the Q items match (as they should).""" + load_from = path.join(path.dirname(__file__), directory_name) + files_to_load = listdir(load_from) -def get_files_to_load(directory_name: str) -> list[str]: - load_from = path.join(path.dirname(__file__), 'trend_test_data', directory_name) - base_filenames_to_load = listdir(load_from) - files_to_load = [path.join(load_from, basename) for basename in base_filenames_to_load] - return files_to_load - -@pytest.mark.parametrize('directory_name', mumag_test_directories) -def test_trend_build_interpolate(directory_name: str): - """Try to build a trend object on the MuMag datasets""" - files_to_load = get_files_to_load(directory_name) - params = AsciiReaderParams( - filenames=files_to_load, - columns=[('Q', per_nanometer), ('I', per_nanometer), ('dI', per_nanometer)], - ) - params.separator_dict['Whitespace'] = True - params.metadata.master_metadata['magnetic'] = AsciiMetadataCategory( + metadata = AsciiReaderMetadata() + metadata.master_metadata['magnetic'] = AsciiMetadataCategory( values={ 'counting_index': 0, 'applied_magnetic_field': 1, @@ -39,31 +29,18 @@ def 
test_trend_build_interpolate(directory_name: str): 'demagnetizing_field': 3 } ) - data = ascii_reader.load_data(params) - trend = Trend( - data=data, - trend_axis=['magnetic', 'applied_magnetic_field'] - ) - # Initially, the q axes in this date don't exactly match - to_interpolate_on = 'Q' - assert not trend.all_axis_match(to_interpolate_on) - interpolated_trend = trend.interpolate(to_interpolate_on) - assert interpolated_trend.all_axis_match(to_interpolate_on) -def test_trend_q_axis_match(): - files_to_load = get_files_to_load(custom_test_directory) params = AsciiReaderParams( filenames=files_to_load, - columns=[('Q', per_angstrom), ('I', per_angstrom)] - ) - params.metadata.master_metadata['magnetic'] = AsciiMetadataCategory( - values={ - 'counting_index': 0, - } + starting_line=0, + columns=[('Q', per_nanometer), ('I', per_nanometer), ('dI', per_nanometer)], + excluded_lines=set(), + separator_dict={'Whitespace': True, 'Comma': False, 'Tab': False}, + metadata=metadata, ) data = ascii_reader.load_data(params) trend = Trend( data=data, - trend_axis=['magnetic', 'counting_index'] + trend_axis=['magnetic', 'applied_magnetic_field'] ) assert trend.all_axis_match('Q') diff --git a/test/utest_unit_parser.py b/test/utest_unit_parser.py index e757a0220..7bdcf997f 100644 --- a/test/utest_unit_parser.py +++ b/test/utest_unit_parser.py @@ -1,5 +1,3 @@ -import re - import pytest from sasdata.quantities import units @@ -7,75 +5,57 @@ from sasdata.quantities.units import Unit named_units_for_testing = [ - ("m", units.meters), - ("A-1", units.per_angstrom), - ("1/A", units.per_angstrom), - ("1/angstroms", units.per_angstrom), - ("micrometer", units.micrometers), - ("micron", units.micrometers), - ("kmh-2", units.kilometers_per_square_hour), - ("km/h2", units.kilometers_per_square_hour), - ("kgm/s2", units.newtons), - ("m m", units.square_meters), - ("mm", units.millimeters), - ("A^-1", units.per_angstrom), - ("V/Amps", units.ohms), - ("Ω", units.ohms), - ("Å", 
units.angstroms), - ("%", units.percent), - ("per_centimeter", units.per_centimeter), + ('m', units.meters), + ('A-1', units.per_angstrom), + ('1/A', units.per_angstrom), + ('kmh-2', units.kilometers_per_square_hour), + ('km/h2', units.kilometers_per_square_hour), + ('kgm/s2', units.newtons), + ('m m', units.square_meters), + ('mm', units.millimeters), + ('A^-1', units.per_angstrom), + ('V/Amps', units.ohms), + ('Ω', units.ohms), + ('Å', units.angstroms), + ('%', units.percent) ] -latex_units_for_testing = [ - (r"\Omega", units.ohms), # Test omega is Ω - (r"\AA", units.angstroms), # Test angstrom is Å - (r"\%", units.percent), # Test percent is NOT a comment - (r"{\mu}A", units.microamperes), # Test µ with an ASCII unit - (r"{\mu}\Omega", units.microohms), # Test µ with LaTeX unit - (r"mm", units.millimeters), # Test that most units just use ASCII in LaTeX +unnamed_units_for_testing = [ + ('m13', units.meters**13), + ('kW/sr', units.kilowatts/units.stradians) ] -unnamed_units_for_testing = [("m13", units.meters**13), ("kW/sr", units.kilowatts / units.stradians)] - - @pytest.mark.parametrize("string, expected_units", named_units_for_testing) def test_name_parse(string: str, expected_units: Unit): - """Test basic parsing""" + """ Test basic parsing""" assert parse_named_unit(string) == expected_units - @pytest.mark.parametrize("string, expected_units", named_units_for_testing + unnamed_units_for_testing) def test_equivalent(string: str, expected_units: Unit): - """Check dimensions of parsed units""" + """ Check dimensions of parsed units""" assert parse_unit(string).equivalent(expected_units) @pytest.mark.parametrize("string, expected_units", named_units_for_testing + unnamed_units_for_testing) def test_scale_same(string: str, expected_units: Unit): - """Test basic parsing""" + """ Test basic parsing""" assert parse_unit(string).scale == pytest.approx(expected_units.scale, rel=1e-14) -@pytest.mark.parametrize("latex_string, units", latex_units_for_testing) -def 
test_latex_parse(latex_string: str, units: Unit): - """Test that proper LaTeX formats for units are being generated""" - assert units.latex_symbol == latex_string - - def test_parse_from_group(): - """Test group based disambiguation""" - parsed_metres_per_second = parse_named_unit_from_group("ms-1", units.speed) + """ Test group based disambiguation""" + parsed_metres_per_second = parse_named_unit_from_group('ms-1', units.speed) assert parsed_metres_per_second == units.meters_per_second def test_parse_errors(): # Fails because the unit is not in that specific group. - with pytest.raises(ValueError, match="That unit cannot be parsed from the specified group."): - parse_named_unit_from_group("km", units.speed) + with pytest.raises(ValueError, match='That unit cannot be parsed from the specified group.'): + parse_named_unit_from_group('km', units.speed) # Fails because part of the unit matches but there is an unknown unit '@' - with pytest.raises(ValueError, match=re.escape("unit_str (km@-1) contains forbidden characters.")): - parse_unit("km@-1") + with pytest.raises(ValueError, match='unit_str contains forbidden characters.'): + parse_unit('km@-1') # Fails because 'da' is not a unit. - with pytest.raises(ValueError, match="Unit string contains an unrecognised pattern."): - parse_unit("mmda2") + with pytest.raises(ValueError, match='Unit string contains an unrecognised pattern.'): + parse_unit('mmda2')