diff --git a/backend/core/migrations/0058_productfile_storage_kind_metadata.py b/backend/core/migrations/0058_productfile_storage_kind_metadata.py new file mode 100644 index 0000000..9570154 --- /dev/null +++ b/backend/core/migrations/0058_productfile_storage_kind_metadata.py @@ -0,0 +1,31 @@ +# Generated by Codex on 2026-04-30 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("core", "0057_delete_groupmembership"), + ] + + operations = [ + migrations.AddField( + model_name="productfile", + name="metadata", + field=models.JSONField(blank=True, default=dict), + ), + migrations.AddField( + model_name="productfile", + name="storage_kind", + field=models.CharField( + choices=[ + ("file", "File"), + ("hats_collection", "HATS Collection"), + ], + default="file", + max_length=32, + verbose_name="Storage Kind", + ), + ), + ] diff --git a/backend/core/models/__init__.py b/backend/core/models/__init__.py index de13bb5..de8b87e 100644 --- a/backend/core/models/__init__.py +++ b/backend/core/models/__init__.py @@ -2,7 +2,7 @@ from core.models.product_type import ProductType from core.models.product import Product, ProductStatus from core.models.product_content import ProductContent -from core.models.product_file import ProductFile, FileRoles +from core.models.product_file import ProductFile, FileRoles, FileStorageKind from core.models.user_profile import Profile from core.models.pipeline import Pipeline from core.models.process import Process diff --git a/backend/core/models/product_file.py b/backend/core/models/product_file.py index b99639f..2b399e9 100644 --- a/backend/core/models/product_file.py +++ b/backend/core/models/product_file.py @@ -1,4 +1,6 @@ import os +import pathlib +import shutil from core.models import Product from django.db import models @@ -14,6 +16,11 @@ class FileRoles(models.IntegerChoices): AUXILIARY = 2, "Auxiliary" +class FileStorageKind(models.TextChoices): + FILE = "file", "File" + 
def delete(self, *args, **kwargs):
    """Remove the stored payload from disk, then delete the database row.

    A HATS collection is stored as a directory tree rooted at
    ``self.file.path``, so it is removed recursively; every other storage
    kind is a single file and is removed through the file field itself.

    Returns:
        Whatever the parent ``delete()`` returns, so callers that inspect
        the deletion result keep working.
    """
    if self.file:
        if self.storage_kind == FileStorageKind.HATS_COLLECTION:
            path = pathlib.Path(self.file.path)
            if path.is_symlink() or path.is_file():
                # rmtree refuses symlinks and regular files; unlink the
                # entry itself instead of failing the whole delete.
                path.unlink()
            elif path.is_dir():
                shutil.rmtree(path)
            # A missing path means there is nothing left to clean up.
        else:
            self.file.delete()
    return super().delete(*args, **kwargs)
a/backend/core/services/hats_collection.py b/backend/core/services/hats_collection.py new file mode 100644 index 0000000..c104179 --- /dev/null +++ b/backend/core/services/hats_collection.py @@ -0,0 +1,283 @@ +import logging +import os +import pathlib +import shutil +import stat +import tarfile +import tempfile +import zipfile + +from django.conf import settings + +LOGGER = logging.getLogger("django") + +HATS_ARCHIVE_EXTENSIONS = (".tar", ".tar.gz", ".tgz", ".zip") +HATS_REJECTED_PRODUCT_TYPES = {"training_results"} +DEFAULT_MAX_MEMBERS = 200000 +DEFAULT_MAX_EXTRACTED_SIZE = 20 * 1024 * 1024 * 1024 + + +class HatsArchiveError(ValueError): + pass + + +class NotHatsCollectionError(HatsArchiveError): + pass + + +class UnsafeArchiveError(HatsArchiveError): + pass + + +def is_hats_archive_name(filename): + name = filename.lower() + return name.endswith(HATS_ARCHIVE_EXTENSIONS) + + +def product_type_accepts_hats(product_type_name): + return product_type_name not in HATS_REJECTED_PRODUCT_TYPES + + +def _is_relative_safe_path(path): + candidate = pathlib.PurePosixPath(path.replace("\\", "/")) + return not candidate.is_absolute() and ".." not in candidate.parts + + +def _safe_target(root, member_name): + target = pathlib.Path(root, member_name).resolve() + root = pathlib.Path(root).resolve() + try: + target.relative_to(root) + except ValueError as exc: + raise UnsafeArchiveError(f"Unsafe archive path: {member_name}") from exc + return target + + +def _check_member_count(count): + max_members = getattr(settings, "HATS_MAX_ARCHIVE_MEMBERS", DEFAULT_MAX_MEMBERS) + if count > max_members: + raise UnsafeArchiveError( + f"Archive has too many members ({count}; maximum is {max_members})." 
+ ) + + +def _check_total_size(total_size): + max_size = getattr( + settings, "HATS_MAX_EXTRACTED_SIZE", DEFAULT_MAX_EXTRACTED_SIZE + ) + if total_size > max_size: + raise UnsafeArchiveError( + "Archive extracted size exceeds the configured limit " + f"({total_size} bytes; maximum is {max_size})." + ) + + +def _extract_tar(archive_path, destination): + total_size = 0 + with tarfile.open(archive_path) as tar: + members = tar.getmembers() + _check_member_count(len(members)) + + for member in members: + if not _is_relative_safe_path(member.name): + raise UnsafeArchiveError(f"Unsafe archive path: {member.name}") + if member.issym() or member.islnk() or member.isdev(): + raise UnsafeArchiveError( + f"Archive member is not a regular file or directory: {member.name}" + ) + if member.isfile(): + total_size += member.size + _check_total_size(total_size) + + for member in members: + target = _safe_target(destination, member.name) + if member.isdir(): + target.mkdir(parents=True, exist_ok=True) + continue + + if member.isfile(): + target.parent.mkdir(parents=True, exist_ok=True) + source = tar.extractfile(member) + if source is None: + raise UnsafeArchiveError( + f"Could not read archive member: {member.name}" + ) + with source, open(target, "wb") as output: + shutil.copyfileobj(source, output) + + +def _zip_member_is_symlink(info): + mode = info.external_attr >> 16 + return stat.S_ISLNK(mode) + + +def _extract_zip(archive_path, destination): + total_size = 0 + with zipfile.ZipFile(archive_path) as zip_archive: + members = zip_archive.infolist() + _check_member_count(len(members)) + + for member in members: + if not _is_relative_safe_path(member.filename): + raise UnsafeArchiveError(f"Unsafe archive path: {member.filename}") + if _zip_member_is_symlink(member): + raise UnsafeArchiveError( + f"Archive member is not a regular file or directory: {member.filename}" + ) + if not member.is_dir(): + total_size += member.file_size + _check_total_size(total_size) + + for member in 
def _extract_archive(archive_path, destination):
    """Extract a supported archive into *destination*, chosen by extension.

    Args:
        archive_path: ``pathlib.Path`` of the archive on disk; only its
            lower-cased name is used to pick the extractor.
        destination: Directory that receives the extracted members.

    Raises:
        NotHatsCollectionError: If the extension is unsupported, or the file
            cannot be read as the archive type its name claims (corrupt,
            truncated, or using an unsupported compression method).
        UnsafeArchiveError: Propagated unchanged from the extractors.
    """
    name = archive_path.name.lower()
    if name.endswith((".tar", ".tar.gz", ".tgz")):
        extractor = _extract_tar
    elif name.endswith(".zip"):
        extractor = _extract_zip
    else:
        raise NotHatsCollectionError("Unsupported HATS archive extension.")

    try:
        extractor(archive_path, destination)
    except (
        tarfile.TarError,
        zipfile.BadZipFile,
        EOFError,
        NotImplementedError,
    ) as exc:
        # An unreadable archive is by definition not a HATS collection.
        # Reporting it through the module's own exception hierarchy lets
        # callers handle it like any other non-HATS upload instead of
        # receiving a raw tarfile/zipfile error.
        raise NotHatsCollectionError(f"Archive could not be read: {exc}") from exc
+ ) from exc + + try: + catalog = lsdb.open_catalog(candidate) + except Exception as exc: + raise NotHatsCollectionError(f"Archive is not a valid HATS collection: {exc}") + + columns = [str(column) for column in getattr(catalog, "all_columns", [])] + if not columns: + columns = [str(column) for column in getattr(catalog, "columns", [])] + + metadata = { + "columns": columns, + "dtypes": _to_jsonable_dtypes(getattr(catalog, "dtypes", {})), + "npartitions": getattr(catalog, "npartitions", None), + "name": getattr(catalog, "name", None), + "hats": _read_properties_file(pathlib.Path(candidate, "hats.properties")), + "collection": _read_properties_file( + pathlib.Path(candidate, "collection.properties") + ), + } + + n_rows = metadata["hats"].get("hats_nrows") + if n_rows is not None: + try: + metadata["n_rows"] = int(n_rows) + except ValueError: + metadata["n_rows"] = None + + return metadata + + +def _find_hats_root(extracted_dir): + validation_errors = [] + for candidate in _candidate_roots(extracted_dir): + try: + metadata = _validate_with_lsdb(candidate) + return candidate, metadata + except NotHatsCollectionError as exc: + validation_errors.append(str(exc)) + + message = "Archive does not contain a valid HATS collection." 
+ if validation_errors: + message = f"{message} {'; '.join(validation_errors)}" + raise NotHatsCollectionError(message) + + +def _write_upload_to_temp(uploaded_file, tmpdir): + suffix = pathlib.Path(uploaded_file.name).suffix + name = uploaded_file.name.lower() + if name.endswith(".tar.gz"): + suffix = ".tar.gz" + elif name.endswith(".tgz"): + suffix = ".tgz" + + archive_path = pathlib.Path(tmpdir, f"upload{suffix}") + with open(archive_path, "wb") as output: + for chunk in uploaded_file.chunks(): + output.write(chunk) + return archive_path + + +def validate_and_store_hats_archive(uploaded_file, product, target_name="main"): + with tempfile.TemporaryDirectory() as tmpdir: + archive_path = _write_upload_to_temp(uploaded_file, tmpdir) + extracted_dir = pathlib.Path(tmpdir, "extracted") + extracted_dir.mkdir() + + _extract_archive(archive_path, extracted_dir) + hats_root, metadata = _find_hats_root(extracted_dir) + + product_root = pathlib.Path(settings.MEDIA_ROOT, product.path) + target = pathlib.Path(product_root, target_name) + if target.exists(): + raise HatsArchiveError( + f"A HATS collection already exists for this product at '{target_name}'." 
+ ) + + target.parent.mkdir(parents=True, exist_ok=True) + shutil.move(str(hats_root), str(target)) + + relative_path = pathlib.Path(product.path, target_name) + return str(relative_path), metadata + + +def validate_hats_archive(uploaded_file): + with tempfile.TemporaryDirectory() as tmpdir: + archive_path = _write_upload_to_temp(uploaded_file, tmpdir) + extracted_dir = pathlib.Path(tmpdir, "extracted") + extracted_dir.mkdir() + + _extract_archive(archive_path, extracted_dir) + _, metadata = _find_hats_root(extracted_dir) + return metadata diff --git a/backend/core/test/test_product_file.py b/backend/core/test/test_product_file.py index fa7276c..148f655 100644 --- a/backend/core/test/test_product_file.py +++ b/backend/core/test/test_product_file.py @@ -1,9 +1,18 @@ import json import mimetypes import os +from io import BytesIO from pathlib import Path - -from core.models import Product, ProductFile, ProductType, Release, FileRoles +from unittest.mock import patch + +from core.models import ( + FileRoles, + FileStorageKind, + Product, + ProductFile, + ProductType, + Release, +) from core.serializers import ProductFileSerializer from core.test.util import sample_product_file from django.contrib.auth.models import User @@ -31,6 +40,7 @@ def setUp(self): # Get Product Types previous created by fixtures self.product_type = ProductType.objects.get(name="validation_results") + self.training_results = ProductType.objects.get(name="training_results") self.product_dict = { "product_type": self.product_type.pk, @@ -56,6 +66,18 @@ def create_product(self): product = Product.objects.get(id=data["id"]) return product + def create_training_results_product(self): + url = reverse("products-list") + response = self.client.post( + url, + { + **self.product_dict, + "product_type": self.training_results.pk, + }, + ) + data = json.loads(response.content) + return Product.objects.get(id=data["id"]) + def test_upload_main_file(self): filename = sample_product_file("csv") @@ -115,6 +137,61 
@@ def test_list_product_file(self): record = ProductFile.objects.get(id=product_file["id"]) self.assertTrue(str(record).startswith(self.product.display_name)) + @patch("core.views.product_file.validate_and_store_hats_archive") + def test_upload_hats_main_file(self, validate_and_store): + validate_and_store.return_value = ( + f"{self.product.path}/main", + { + "columns": ["ra", "dec", "z"], + "dtypes": {"ra": "float64", "dec": "float64", "z": "float64"}, + "npartitions": 1, + "n_rows": 3, + }, + ) + archive = BytesIO(b"fake hats archive") + archive.name = "catalog.tar.gz" + + response = self.client.post( + self.url, + { + "product": self.product.pk, + "file": archive, + "role": 0, + "type": "application/gzip", + }, + format="multipart", + ) + + data = json.loads(response.content) + self.assertEqual(201, response.status_code) + product_file = ProductFile.objects.get(pk=data["id"]) + self.assertEqual(product_file.storage_kind, FileStorageKind.HATS_COLLECTION) + self.assertEqual(product_file.file.name, f"{self.product.path}/main") + self.assertEqual(product_file.metadata["columns"], ["ra", "dec", "z"]) + self.assertEqual(product_file.n_rows, 3) + + @patch("core.views.product_file.validate_hats_archive") + def test_training_results_rejects_hats_main_file(self, validate_hats): + product = self.create_training_results_product() + validate_hats.return_value = {"columns": ["ra", "dec"]} + archive = BytesIO(b"fake hats archive") + archive.name = "catalog.tar.gz" + + response = self.client.post( + self.url, + { + "product": product.pk, + "file": archive, + "role": 0, + "type": "application/gzip", + }, + format="multipart", + ) + + data = json.loads(response.content) + self.assertEqual(400, response.status_code) + self.assertIn("HATS collections are not accepted", data["error"]) + class ProductFileDetailAPIViewTestCase(APITestCase): @@ -220,6 +297,8 @@ def test_product_file_serialized_format(self): "size": self.product_file.file.size, "n_rows": None, "extension": 
os.path.splitext(self.product_file.file.name)[1], + "storage_kind": FileStorageKind.FILE, + "metadata": {}, "created": self.product_file.created.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), "updated": self.product_file.updated.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), "can_delete": True, diff --git a/backend/core/test/test_product_registry.py b/backend/core/test/test_product_registry.py index 9195fdd..7dfce03 100644 --- a/backend/core/test/test_product_registry.py +++ b/backend/core/test/test_product_registry.py @@ -3,7 +3,14 @@ from pathlib import Path from unittest import mock -from core.models import Product, ProductFile, ProductType, Release +from core.models import ( + FileStorageKind, + Product, + ProductContent, + ProductFile, + ProductType, + Release, +) from core.product_handle import ProductHandle from core.product_steps import RegistryProduct from core.test.util import sample_product_file @@ -156,7 +163,10 @@ def guarded_df_from_file(instance, filepath, **kwargs): return original_df_from_file(instance, filepath, **kwargs) with mock.patch.object( - ProductHandle, "df_from_file", autospec=True, side_effect=guarded_df_from_file + ProductHandle, + "df_from_file", + autospec=True, + side_effect=guarded_df_from_file, ): response = self.client.post(url) @@ -180,6 +190,29 @@ def test_registry_retry(self): response = self.client.post(url) self.assertEqual(response.status_code, 200) + def test_registry_hats_collection(self): + product = self.create_product(specz=True) + ProductFile.objects.create( + product=product, + file=f"{product.path}/main", + role=0, + name="catalog.tar.gz", + size=123, + extension=".tar.gz", + storage_kind=FileStorageKind.HATS_COLLECTION, + metadata={ + "columns": ["ra", "dec", "z"], + "n_rows": 10, + "npartitions": 1, + }, + ) + url = reverse("products-registry", kwargs={"pk": product.pk}) + + response = self.client.post(url) + + self.assertEqual(response.status_code, 200) + self.assertEqual(ProductContent.objects.filter(product=product).count(), 3) + def 
test_registry_generates_table_preview_file(self): product = self.create_product(specz=True) self.upload_main_file(product, extension="csv") @@ -197,7 +230,9 @@ def test_registry_generates_table_preview_file(self): self.assertIn("count", payload) self.assertIn("columns", payload) self.assertIn("results", payload) - self.assertLessEqual(len(payload["results"]), RegistryProduct.TABLE_PREVIEW_ROWS) + self.assertLessEqual( + len(payload["results"]), RegistryProduct.TABLE_PREVIEW_ROWS + ) def test_read_data_uses_cached_table_preview(self): product = self.create_product(specz=True) @@ -227,7 +262,9 @@ def test_read_data_starts_background_preview_when_missing(self): preview_path.unlink(missing_ok=True) processing_path.unlink(missing_ok=True) - with mock.patch("core.views.product.build_product_table_preview.delay") as delay_mock: + with mock.patch( + "core.views.product.build_product_table_preview.delay" + ) as delay_mock: response = self.client.get( reverse("products-read-data", kwargs={"pk": product.pk}), {"page": 1, "page_size": 10}, @@ -252,7 +289,9 @@ def test_read_data_does_not_restart_preview_when_already_processing(self): processing_path.unlink(missing_ok=True) RegistryProduct.start_table_preview_processing(product) - with mock.patch("core.views.product.build_product_table_preview.delay") as delay_mock: + with mock.patch( + "core.views.product.build_product_table_preview.delay" + ) as delay_mock: response = self.client.get( reverse("products-read-data", kwargs={"pk": product.pk}), {"page": 1, "page_size": 10}, diff --git a/backend/core/views/product.py b/backend/core/views/product.py index e1bc757..1c9c31b 100644 --- a/backend/core/views/product.py +++ b/backend/core/views/product.py @@ -7,11 +7,12 @@ import yaml from core.models import ( + FileStorageKind, Product, ProductContent, + ProductStatus, ProductDownloadArchive, ProductDownloadArchiveStatus, - ProductStatus, ) from core.permissions import AccessControlMixin, ProductAccessPermission from core.product_handle 
import FileHandle, NotTableError @@ -469,6 +470,20 @@ def download_main_file(self, request, **kwargs): product = self.get_object() main_file = product.files.get(role=0) main_file_path = Path(main_file.file.path) + if main_file.storage_kind == FileStorageKind.HATS_COLLECTION: + with tempfile.TemporaryDirectory() as tmpdirname: + zip_file = self.zip_directory( + main_file_path, product.internal_name, tmpdirname + ) + size = zip_file.stat().st_size + file_handle = open(zip_file, "rb") + response = FileResponse(file_handle, content_type="application/zip") + response["Content-Length"] = size + response["Content-Disposition"] = ( + f"attachment; filename={zip_file.name}" + ) + return response + product_path = pathlib.Path( settings.MEDIA_ROOT, product.path, main_file_path ) @@ -494,6 +509,12 @@ def read_data(self, request, **kwargs): page_size = int(request.GET.get("page_size", 100)) product = self.get_object() + + product_file = product.files.get(role=0) + if product_file.storage_kind == FileStorageKind.HATS_COLLECTION: + content = {"message": "Table preview is not available for HATS products."} + return Response(content, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + preview_path = RegistryProduct.get_table_preview_path(product) try: @@ -565,6 +586,22 @@ def main_file_info(self, request, **kwargs): product = self.get_object() product_file = product.files.get(role=0) main_file_path = Path(product_file.file.path) + if product_file.storage_kind == FileStorageKind.HATS_COLLECTION: + data = self.get_serializer(instance=product).data + data["main_file"] = { + "name": product_file.name, + "type": product_file.type, + "extension": product_file.extension, + "size": product_file.size, + "n_rows": product_file.n_rows, + "storage_kind": product_file.storage_kind, + "metadata": product_file.metadata, + } + product_contents = self.__get_product_contents(product) + if product_contents: + data["main_file"]["associated_columns"] = product_contents + return Response(data) + 
def zip_directory(self, directory, internal_name, tmpdir):
    """Pack every file under *directory* into a new zip inside *tmpdir*.

    The archive is named ``<internal_name>_<5 hex chars>.zip``. Members are
    stored under paths relative to *directory*. Only files are written, so
    empty directories do not appear in the archive.

    Args:
        directory: Root of the tree to compress.
        internal_name: Prefix for the archive file name.
        tmpdir: Directory in which the archive is created.

    Returns:
        pathlib.Path: Path of the archive that was written.
    """
    directory = pathlib.Path(directory)
    # One token is enough; five random hex characters keep names distinct.
    suffix = secrets.token_hex(3)[:5]
    zip_path = pathlib.Path(tmpdir, f"{internal_name}_{suffix}.zip")

    with zipfile.ZipFile(
        zip_path,
        "w",
        compression=zipfile.ZIP_DEFLATED,
        compresslevel=9,
    ) as archive:
        for root, dirs, files in os.walk(directory):
            # Sorting in place fixes the traversal order, so member order
            # does not depend on the filesystem's listing order.
            dirs.sort()
            for file_name in sorted(files):
                origin = pathlib.Path(root, file_name)
                archive.write(origin, arcname=origin.relative_to(directory))

    return zip_path
diff --git a/backend/core/views/product_file.py b/backend/core/views/product_file.py index 0b9f9e1..ca4851b 100644 --- a/backend/core/views/product_file.py +++ b/backend/core/views/product_file.py @@ -1,8 +1,17 @@ import logging import os -from core.models import Product, ProductFile +from core.models import FileRoles, FileStorageKind, Product, ProductFile from core.serializers import ProductFileSerializer +from core.services.hats_collection import ( + HatsArchiveError, + NotHatsCollectionError, + UnsafeArchiveError, + is_hats_archive_name, + product_type_accepts_hats, + validate_and_store_hats_archive, + validate_hats_archive, +) from rest_framework import exceptions, mixins, status, viewsets from rest_framework.response import Response @@ -48,7 +57,51 @@ def perform_create(self, serializer): file = self.request.data.get("file") size = file.size - filename, extension = os.path.splitext(file.name) + _, extension = os.path.splitext(file.name) + + if file.name.lower().endswith(".tar.gz"): + extension = ".tar.gz" + elif file.name.lower().endswith(".tgz"): + extension = ".tgz" + + product = serializer.validated_data["product"] + role = serializer.validated_data.get("role", FileRoles.MAIN) + + if role == FileRoles.MAIN and is_hats_archive_name(file.name): + if product_type_accepts_hats(product.product_type.name): + try: + relative_path, metadata = validate_and_store_hats_archive( + file, product + ) + return serializer.save( + file=relative_path, + size=size, + name=file.name, + extension=extension, + storage_kind=FileStorageKind.HATS_COLLECTION, + metadata=metadata, + n_rows=metadata.get("n_rows"), + ) + except NotHatsCollectionError: + file.seek(0) + except (HatsArchiveError, UnsafeArchiveError) as exc: + raise exceptions.ValidationError({"error": str(exc)}) from exc + else: + try: + validate_hats_archive(file) + except NotHatsCollectionError: + file.seek(0) + except (HatsArchiveError, UnsafeArchiveError) as exc: + raise exceptions.ValidationError({"error": str(exc)}) 
from exc + else: + raise exceptions.ValidationError( + { + "error": ( + "HATS collections are not accepted for " + f"{product.product_type.display_name} products." + ) + } + ) return serializer.save( size=size, diff --git a/backend/requirements.txt b/backend/requirements.txt index a53f2e4..b90b4e6 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,4 +1,4 @@ -astropy==6.1.3 +astropy==6.1.5 celery==5.4.0 coverage==7.6.1 Django==5.1.15 @@ -15,8 +15,9 @@ drf-social-oauth2==3.1.0 drf-spectacular==0.27.2 gunicorn==23.0.0 h5py==3.11.0 +lsdb==0.7.3 numpy==2.1.1 -pandas==2.2.2 +pandas==2.2.3 psycopg2-binary==2.9.9 pyarrow==17.0.0 pydantic==2.9.2 diff --git a/frontend/components/newProduct/Step2.js b/frontend/components/newProduct/Step2.js index d38bd31..1b58470 100644 --- a/frontend/components/newProduct/Step2.js +++ b/frontend/components/newProduct/Step2.js @@ -20,6 +20,7 @@ import { MAX_UPLOAD_SIZE, createProductFile, deleteProductFile, + getProduct, getProductFiles, registryProduct } from '../../services/product' @@ -36,6 +37,12 @@ export default function NewProductStep2({ productId, onNext, onPrev }) { const [isLoading, setLoading] = useState(false) const [progress, setProgress] = useState(null) const [formError, setFormError] = React.useState('') + const [productType, setProductType] = useState(null) + const acceptsHats = productType !== 'training_results' + + const mainFileAccept = acceptsHats + ? 
'.csv,.fits,.fit,.hf5,.hdf5,.h5,.pq,.parquet,.vo,.vot,.xml,.txt,.tar,.tar.gz,.tgz,.zip' + : undefined const loadFiles = React.useCallback(async () => { setFormError('') @@ -74,6 +81,16 @@ export default function NewProductStep2({ productId, onNext, onPrev }) { loadFiles() }, [loadFiles]) + useEffect(() => { + getProduct(productId) + .then(product => { + setProductType(product.product_type_internal_name) + }) + .catch(() => { + setProductType(null) + }) + }, [productId]) + const handleNext = () => { setFormError('') setLoading(true) @@ -233,6 +250,12 @@ export default function NewProductStep2({ productId, onNext, onPrev }) { Validation Results, the file format is free (in case of multiple files, please provide them compressed in a .zip or .tar file). + {acceptsHats && ( + + This product type also accepts HATS collections packaged as .tar, + .tar.gz, .tgz, or .zip files. + + )} The maximum upload size is {MAX_UPLOAD_SIZE}MB. If your dataset is larger, please contact the Photo-z Server team to request an exception. @@ -274,6 +297,7 @@ export default function NewProductStep2({ productId, onNext, onPrev }) { {mainFile === false && ( { handleUploadFile(file, 0) }}