Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies = [
"openpyxl >=3.0.7, !=3.1.1",
"GDX2py >=2.2.0",
"ijson >=3.1.4",
"chardet >=4.0.0",
"chardet >=7",
"PyMySQL[rsa] >=1.0.2",
"psycopg2-binary",
"pyarrow >= 19.0",
Expand Down
5 changes: 5 additions & 0 deletions spinedb_api/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ def __init__(self, msg, rank=None, key=None):
self.rank = rank
self.key = key

def __eq__(self, other):
    """Compares this error to another one attribute by attribute.

    Args:
        other: object to compare against

    Returns:
        NotImplemented if ``other`` is not an InvalidMappingComponent;
        otherwise the result of comparing ``msg``, ``rank`` and ``key``, in that order
    """
    # NOTE(review): a class that defines __eq__ without __hash__ becomes unhashable
    # unless __hash__ is defined elsewhere in the class - confirm this is intended.
    if isinstance(other, InvalidMappingComponent):
        messages_match = self.msg == other.msg
        return messages_match and self.rank == other.rank and self.key == other.key
    return NotImplemented


class ReaderError(SpineDBAPIError):
"""Failure in import reader."""
Expand Down
197 changes: 114 additions & 83 deletions spinedb_api/import_mapping/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,32 @@
"""
from collections.abc import Callable
from copy import deepcopy
from operator import itemgetter
from itertools import dropwhile
from typing import Any, Optional
from ..exception import ParameterValueFormatError
from ..helpers import string_to_bool
from ..import_functions import UnparseCallable
from ..mapping import Position, is_pivoted
from ..parameter_value import (
Array,
IndexedValue,
Map,
TimePattern,
TimeSeriesVariableResolution,
convert_leaf_maps_to_specialized_containers,
from_database,
split_value_and_type,
)
from .import_mapping import ImportMapping, check_validity
from .import_mapping import (
ArrayValueRecord,
ImportMapping,
MapValueRecord,
SemiMappedData,
TimePatternValueRecord,
TimeSeriesValueRecord,
ValueRecord,
check_validity,
)
from .import_mapping_compat import import_mapping_from_dict

_NO_VALUE = object()
Expand Down Expand Up @@ -178,6 +189,7 @@ def get_mapped_data(
_make_entities(mapped_data)
_make_entity_metadata(mapped_data)
_make_entity_alternatives(mapped_data, errors)
_make_parameter_definitions(mapped_data, unparse_value)
_make_parameter_values(mapped_data, unparse_value)
_make_parameter_value_metadata(mapped_data)
return mapped_data, errors
Expand Down Expand Up @@ -295,23 +307,34 @@ def _unpivot_rows(
return unpivoted_rows, pivoted_pos, non_pivoted_pos, unpivoted_column_pos


def _make_entity_classes(mapped_data):
rows = mapped_data.get("entity_classes")
if rows is None:
def _make_entity_classes(mapped_data: dict) -> None:
try:
rows = mapped_data.pop("entity_classes")
except KeyError:
return
rows = [(class_name, tuple(dimension_names)) for class_name, dimension_names in rows.items()]
rows.sort(key=itemgetter(1))
mapped_data["entity_classes"] = final_rows = []
for class_name, dimension_names in rows:
row = (class_name, tuple(dimension_names)) if dimension_names else (class_name,)
final_rows.append(row)
final_rows = []
for name, record in rows.items():
item = [name, record.dimensions]
if record.description:
item.append(record.description)
final_rows.append(item)
if final_rows:
mapped_data["entity_classes"] = final_rows


def _make_entities(mapped_data):
rows = mapped_data.get("entities")
if rows is None:
try:
rows = mapped_data.pop("entities")
except KeyError:
return
mapped_data["entities"] = list(rows)
final_rows = []
for (class_name, name), record in rows.items():
item = [class_name, name if not record.elements else record.elements]
if record.description:
item.append(record.description)
final_rows.append(item)
if final_rows:
mapped_data["entities"] = final_rows


def _make_entity_alternatives(mapped_data, errors):
Expand All @@ -332,35 +355,59 @@ def _make_entity_alternatives(mapped_data, errors):
mapped_data["entity_alternatives"] = rows


def _make_parameter_definitions(mapped_data: SemiMappedData, unparse_value: UnparseCallable) -> None:
    """Replaces the "parameter_definitions" entry of mapped data by final rows.

    Each final row is a list that starts with the entity class name and
    the parameter name, followed by default value, value list name and description.
    Trailing None items are left out, so a row may contain just the two names.
    The entry is removed from ``mapped_data`` and put back only if there is at least one row.

    Args:
        mapped_data: mapped data; modified in place
        unparse_value: callable applied to default values built from value records
    """
    key = "parameter_definitions"
    try:
        records = mapped_data.pop(key)
    except KeyError:
        return
    definitions = []
    for (entity_class_name, parameter_name), record in records.items():
        default = record.default_value
        if isinstance(default, ValueRecord):
            # A record that reports having no value yields no default value at all.
            default = unparse_value(_make_value(default)) if default.has_value() else None
        elif isinstance(default, str):
            try:
                default = from_database(*split_value_and_type(default))
            except ParameterValueFormatError:
                # Parsing failed; the string is kept unchanged.
                pass
        description = record.description
        value_list_name = record.value_list_name
        optional_items = [default, value_list_name, description]
        # Strip None items from the end only; a None in the middle must stay
        # to keep the remaining items in their positions.
        while optional_items and optional_items[-1] is None:
            optional_items.pop()
        definitions.append([entity_class_name, parameter_name, *optional_items])
    if definitions:
        mapped_data[key] = definitions


def _make_parameter_values(mapped_data, unparse_value):
value_pos = 3
key = "parameter_values"
rows = mapped_data.get(key)
if rows is not None:
valued_rows = []
for row in rows:
raw_value = _make_value(row, value_pos)
if raw_value is _NO_VALUE:
continue
value = unparse_value(raw_value)
if value is not None:
row[value_pos] = value
valued_rows.append(row)
mapped_data[key] = valued_rows
value_pos = 0
key = "parameter_definitions"
rows = mapped_data.get(key)
if rows is not None:
full_rows = []
for entity_definition, extras in rows.items():
if extras:
value = unparse_value(_make_value(extras, value_pos))
if value is not None:
extras[value_pos] = value
full_rows.append(entity_definition + tuple(extras))
try:
rows = mapped_data.pop(key)
except KeyError:
return
final_rows = []
for (entity_class_name, entity_byname, parameter_name, alternative_name), value in rows.items():
if isinstance(value, ValueRecord):
if value.has_value():
value = unparse_value(_make_value(value))
else:
full_rows.append(entity_definition)
mapped_data[key] = full_rows
value = None
elif isinstance(value, str):
try:
value = from_database(*split_value_and_type(value))
except ParameterValueFormatError:
pass
if value is None:
continue
value_data = [entity_class_name, entity_byname, parameter_name, value]
if alternative_name is not None:
value_data.append(alternative_name)
final_rows.append(value_data)
if final_rows:
mapped_data[key] = final_rows


def _make_parameter_value_metadata(mapped_data):
Expand All @@ -377,42 +424,28 @@ def _make_entity_metadata(mapped_data):
mapped_data["entity_metadata"] = list(rows)


def _make_value(row, value_pos):
    """Extracts the value at given position of a row and converts it if possible.

    Args:
        row: indexable row of mapped data
        value_pos (int): position of the value in the row

    Returns:
        None if the row has no item at ``value_pos``;
        ``_NO_VALUE`` if the item is a dict without a "data" key;
        the result of ``_parameter_value_from_dict()`` for other dicts;
        the result of ``from_database()`` for strings that parse successfully;
        the item itself in all other cases
    """
    try:
        raw = row[value_pos]
    except IndexError:
        return None
    if isinstance(raw, str):
        try:
            return from_database(*split_value_and_type(raw))
        except ParameterValueFormatError:
            # Parsing failed; the string is returned unchanged.
            return raw
    if not isinstance(raw, dict):
        return raw
    return _parameter_value_from_dict(raw) if "data" in raw else _NO_VALUE


def _parameter_value_from_dict(d):
    """Builds an indexed parameter value from its dict description.

    Reads the "type" and "data" keys and, optionally, "index_names",
    "compress" (maps only) and "options" (time series only).

    Args:
        d (dict): value description

    Returns:
        Map, TimePattern, TimeSeriesVariableResolution or Array according to "type";
        None if the type is none of "map", "time_pattern", "time_series" or "array"
    """
    sparse_names = d.get("index_names", {0: ""})
    # Expand the position -> name dict into a dense list; positions without a name get "".
    index_names = [""] * (max(sparse_names) + 1)
    for position, index_name in sparse_names.items():
        index_names[position] = index_name
    value_type = d["type"]
    if value_type == "map":
        nested_map = _table_to_map(d["data"], compress=d.get("compress", False))
        if index_names != [""]:
            _apply_index_names(nested_map, index_names)
        return nested_map
    elif value_type == "time_pattern":
        return TimePattern(*zip(*d["data"]), index_name=index_names[0])
    elif value_type == "time_series":
        options = d.get("options", {})
        ignore_year = options.get("ignore_year", False)
        repeat = options.get("repeat", False)
        return TimeSeriesVariableResolution(*zip(*d["data"]), ignore_year, repeat, index_name=index_names[0])
    elif value_type == "array":
        return Array(d["data"], index_name=index_names[0])
    return None
def _make_value(record: ValueRecord) -> IndexedValue:
    """Converts a value record into the corresponding indexed parameter value.

    Args:
        record: value record to convert

    Returns:
        Array, TimePattern, TimeSeriesVariableResolution or Map depending on the record's type

    Raises:
        RuntimeError: if the record is of none of the supported record types
    """

    def first_index_name() -> str:
        # One-dimensional values use only the first index name; default is empty.
        names = record.index_names
        return names[0] if names else ""

    if isinstance(record, ArrayValueRecord):
        index_name = first_index_name()
        return Array(record.values, index_name=index_name)
    if isinstance(record, TimePatternValueRecord):
        index_name = first_index_name()
        # Each entry of record.indexes is a sequence; only its first item is used.
        patterns = [entry[0] for entry in record.indexes]
        return TimePattern(patterns, record.values, index_name)
    if isinstance(record, TimeSeriesValueRecord):
        index_name = first_index_name()
        stamps = [entry[0] for entry in record.indexes]
        return TimeSeriesVariableResolution(stamps, record.values, record.ignore_year, record.repeat, index_name)
    if isinstance(record, MapValueRecord):
        # Table rows consist of the row's indexes followed by its value.
        table = ([*row_indexes, row_value] for row_indexes, row_value in zip(record.indexes, record.values))
        nested_map = _table_to_map(table, record.compress)
        if record.index_names:
            _apply_index_names(nested_map, record.index_names)
        return nested_map
    raise RuntimeError(f"logic error: unknown record type '{type(record).__name__}'")


def _table_to_map(table, compress=False):
Expand Down Expand Up @@ -456,24 +489,22 @@ def _apply_index_names(map_, index_names):
"""
name = index_names[0]
if name:
map_.index_name = index_names[0]
map_.index_name = name
if len(index_names) == 1:
return
for v in map_.values:
if isinstance(v, Map):
_apply_index_names(v, index_names[1:])


def _ensure_mapping_name_consistency(mappings, mapping_names):
def _ensure_mapping_name_consistency(mappings: list[ImportMapping], mapping_names: list[str]) -> None:
"""Makes sure that there are as many mapping names as actual mappings.

Args:
mappings (list(ImportMapping)): list of mappings
mapping_names (list(str)): list of mapping names
mappings: list of mappings
mapping_names: list of mapping names
"""
n_mappings = len(mappings)
n_mapping_names = len(mapping_names)
if n_mapping_names > n_mappings:
mapping_names = mapping_names[:n_mappings]
elif n_mapping_names < n_mappings:
if n_mapping_names < n_mappings:
mapping_names += [""] * (n_mappings - n_mapping_names)
Loading