From af849d5c4d24c0f0d9257cafba012418cd1fa9eb Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Tue, 12 May 2026 18:35:05 -0400 Subject: [PATCH 01/22] better validation of inputs --- src/dial_dataclass/dial_dataclass.py | 67 +++++++++++++++++++++++++--- src/dial_service/dial_service.py | 3 +- src/dial_service/serverside_data.py | 9 +++- 3 files changed, 72 insertions(+), 7 deletions(-) diff --git a/src/dial_dataclass/dial_dataclass.py b/src/dial_dataclass/dial_dataclass.py index ea8cceb..79d6c5b 100644 --- a/src/dial_dataclass/dial_dataclass.py +++ b/src/dial_dataclass/dial_dataclass.py @@ -1,6 +1,11 @@ from typing import Annotated, Literal -from pydantic import BaseModel, Field, field_validator +from pydantic import ( + BaseModel, + Field, + field_validator, + model_validator, +) from .pydantic_helpers import ValidatedObjectId @@ -11,10 +16,39 @@ BackendType = Literal[_POSSIBLE_BACKENDS] +def validate_dims_and_length(data, x_name, y_name, dim_x=None): + """validate the lengths of datasets""" + + #print(data) + + dataset_x = data[x_name] + dataset_y = data[y_name] + + len_x = len(dataset_x) + len_y = len(dataset_y) + + if len_y != len_x: + msg = (f'Unequal number of points in {x_name} {len_x=}' + f' and {y_name} {len_y=}.') + raise ValueError(msg) + + if dim_x is None and len_x == 0: + msg = ('Can not infer dim_x from empty dataset.' + 'Set dim_x to the correct dimension.') + raise ValueError(msg) + + if dim_x is not None and len_x > 0 and len(dataset_x[0]) != dim_x: + msg = (f'Vectors in {x_name} must be of length {dim_x=}.' + 'Set dim_x to the correct dimension.') + raise ValueError(msg) + + class _DialWorkflowCreationParams(BaseModel): """This comprises the information needed to create a DIAL workflow. - This is a base class which should not be directly imported, clients should use "DialWorkflowCreationParamsClient" (in this file) and services should use "DialWorkflowCreationParamsService" (exported from the service) + This is a base class which should not be directly imported, clients should use + "DialWorkflowCreationParamsClient" (in this file) and services should use + "DialWorkflowCreationParamsService" (exported from the service) """ dataset_x: Annotated[ @@ -29,14 +63,17 @@ class _DialWorkflowCreationParams(BaseModel): dataset_y: Annotated[ list[float], Field( - description='The output values of the training data. Length should equal dataset_x', + description=('The output values of the training data.' + ' Length should equal dataset_x'), ), ] y_is_good: Annotated[ bool, Field( default=True, # <-- Set default here - description='If true, treat higher y values as better (e.g. y represents yield or profit). If false, opposite (e.g. y represents error or waste)', + description=('If true, treat higher y values as better' + ' (e.g. y represents yield or profit).' + ' If false, opposite (e.g. y represents error or waste)'), ), ] kernel: Literal['rbf', 'matern', 'linear'] @@ -55,7 +92,12 @@ class _DialWorkflowCreationParams(BaseModel): description='Specific RNG seed - use -1 to use system default', ), ] - dim_x: Annotated[int, Field(default=1)] + dim_x: Annotated[int | None, Field( + default=None, + description=('Provide the dimension of x explicitly,' + ' e.g. if the initial dataset is empty.' + ' If None, it will be inferred from dataset_x if possible.'), + )] preprocess_log: bool = Field(default=False) preprocess_standardize: bool = Field(default=False) @@ -93,6 +135,12 @@ def order_bounds(cls, bounds: list[list[float]]): row.sort() return bounds + @model_validator(mode='after') + def validate_dims_and_length(self, values): + validate_dims_and_length(vars(self), + 'dataset_x', 'dataset_y', dim_x=self.dim_x) + return self + # this class is specific to clients; they have no way of knowing which backends the Service supports, so we allow all of them class DialWorkflowCreationParamsClient(_DialWorkflowCreationParams): @@ -102,6 +150,7 @@ class DialWorkflowCreationParamsClient(_DialWorkflowCreationParams): class DialWorkflowDatasetUpdate(BaseModel): + """This class is used to send a single update to the dataset.""" workflow_id: ValidatedObjectId next_x: list[float] = Field( description='The next collection of X values you want to append to your overall data', @@ -125,6 +174,7 @@ class DialWorkflowDatasetUpdate(BaseModel): class DialWorkflowDatasetUpdates(BaseModel): + """This class is used to send multiple updates to the dataset.""" workflow_id: ValidatedObjectId next_x_list: list[list[float]] = Field(min_length=1) next_y_list: list[float] = Field(min_length=1) @@ -132,6 +182,11 @@ class DialWorkflowDatasetUpdates(BaseModel): backend_args: dict[str, float | int | bool | str | list[float] | tuple] | None = None extra_args: dict[str, float | int | bool | str | list[float] | tuple] | None = None + @model_validator(mode='after') + def validate_dims_and_length(self, values): + validate_dims_and_length(vars(self), + 'next_x_list', 'next_y_list') + return self class DialInputSingleConfidenceBound(BaseModel): workflow_id: ValidatedObjectId @@ -212,6 +267,7 @@ class DialInputSingleOtherStrategy(BaseModel): class DialInputMultipleOtherStrategy(BaseModel): + """TODO: document this""" workflow_id: ValidatedObjectId points: PositiveIntType strategy: Literal[ @@ -264,6 +320,7 @@ class DialInputPredictions(BaseModel): """This is the input dataclass for Dial for requesting a surrogate evaluation at a given number of points.""" workflow_id: ValidatedObjectId + points_to_predict: list[list[float]] extra_args: dict[str, float | int | bool | str | list[float] | tuple] | None = Field( default=None diff --git a/src/dial_service/dial_service.py b/src/dial_service/dial_service.py index af5ce68..f6d93d3 100644 --- a/src/dial_service/dial_service.py +++ b/src/dial_service/dial_service.py @@ -87,7 +87,8 @@ def update_workflow_with_data( ) -> ValidatedObjectId: """Updates the DB with the provided params. Success of operation is based off whether or not the INTERSECT response is an error.""" - # TODO - all exceptions should realistically provide error information to the client. INTERSECT-SDK v0.9 will introduce a specific exception we can throw which will allow us to do this. + # TODO - all exceptions should realistically provide error information to the client. + # INTERSECT-SDK v0.9 will introduce a specific exception we can throw which will allow us to do this. try: db_get_result = self.mongo_handler.get_workflow(update_params.workflow_id) except Exception: diff --git a/src/dial_service/serverside_data.py b/src/dial_service/serverside_data.py index 74405ec..d871587 100644 --- a/src/dial_service/serverside_data.py +++ b/src/dial_service/serverside_data.py @@ -14,6 +14,14 @@ # this is an extended version of ActiveLearningInputData. This allows us to add on properties and methods to this class without impacting the client side class ServersideInputBase: def __init__(self, data: DialWorkflowCreationParamsService): + self.dim_x = data.dim_x + + if self.dim_x is None: + if len(data.dataset_x) > 0: + self.dim_x = len(data.dataset_x[0]) + else: + raise ValueError('dim_x not set, and no dataset_x provided. Can not deduce dim_x.') + self.X_raw = np.array(data.dataset_x) self.Y_raw = np.array(data.dataset_y) # it seems like there should be a smarter way to do this, but stuff involving loops doesn't work with static autocompleters: @@ -28,7 +36,6 @@ def __init__(self, data: DialWorkflowCreationParamsService): self.backend_args = data.backend_args self.kernel_args = data.kernel_args self.extra_args = data.extra_args - self.dim_x = data.dim_x @cached_property def stddev(self) -> float: From f620bfd76ab58ca39e9f5be21ccb312c18075fcc Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Tue, 12 May 2026 20:04:25 -0400 Subject: [PATCH 02/22] draft: allow for multidimensional output. --- src/dial_dataclass/dial_dataclass.py | 125 ++++++++++++++++++++------- src/dial_service/serverside_data.py | 21 ++--- 2 files changed, 104 insertions(+), 42 deletions(-) diff --git a/src/dial_dataclass/dial_dataclass.py b/src/dial_dataclass/dial_dataclass.py index 79d6c5b..3cbca2b 100644 --- a/src/dial_dataclass/dial_dataclass.py +++ b/src/dial_dataclass/dial_dataclass.py @@ -1,4 +1,4 @@ -from typing import Annotated, Literal +from typing import Annotated, Literal, Optional from pydantic import ( BaseModel, @@ -16,11 +16,27 @@ BackendType = Literal[_POSSIBLE_BACKENDS] -def validate_dims_and_length(data, x_name, y_name, dim_x=None): +def _validate_dataset_lengths(dataset: list[any]) -> bool: + """validate the lengths of dataset entries""" + if len(dataset) > 1: + data = dataset[0] + if data is list: + target_length = len(dataset[0]) + for row in dataset[1:]: + if len(row) != target_length: + return False + return True + + +def _validate_dims_and_length(data: dict, + x_name: str, + y_name: str) -> tuple[int, int]: """validate the lengths of datasets""" - #print(data) + print(data) + dim_x = data.get("dim_x") + dim_y = data.get("dim_y") dataset_x = data[x_name] dataset_y = data[y_name] @@ -32,17 +48,33 @@ def validate_dims_and_length(data, x_name, y_name, dim_x=None): f' and {y_name} {len_y=}.') raise ValueError(msg) - if dim_x is None and len_x == 0: - msg = ('Can not infer dim_x from empty dataset.' - 'Set dim_x to the correct dimension.') - raise ValueError(msg) + def compute_dim(dim, dataset, name): + lenn = len(dataset) + if dim is None and lenn == 0: + msg = (f'Can not infer dim from empty dataset {name}.' + 'Set dim to the correct dimension.') + raise ValueError(msg) + else: + inferred_len = len(dataset[0]) if dataset[0] is list else 1 + if dim is not None and inferred_len != dim_x: + msg = (f'Vectors in {name} must be of length {dim=}.' + 'Set dim to the correct dimension.') + raise ValueError(msg) + else: + dim = inferred_len + return dim - if dim_x is not None and len_x > 0 and len(dataset_x[0]) != dim_x: - msg = (f'Vectors in {x_name} must be of length {dim_x=}.' - 'Set dim_x to the correct dimension.') - raise ValueError(msg) + dim_x = compute_dim(dim_x, dataset_x, x_name) + dim_y = compute_dim(dim_y, dataset_y, y_name) + + return dim_x, dim_y +def _validate_labels(data: dict, + x_name: str, + y_name: str) -> tuple[str, str]: + """validate the lengths of datasets""" + class _DialWorkflowCreationParams(BaseModel): """This comprises the information needed to create a DIAL workflow. @@ -61,12 +93,41 @@ class _DialWorkflowCreationParams(BaseModel): Field(description='The input vectors of the training data'), ] dataset_y: Annotated[ - list[float], + list[float | Annotated[ + # TODO: this could be the default + list[float], + Field(description='Field lengths of all subarrays should be equal'), + ] + ], Field( description=('The output values of the training data.' ' Length should equal dataset_x'), ), ] + x_labels: Annotated[ + str | list[str], + Field(default='x', + description='Labels for input variables x.') + ] + y_labels: Annotated[ + str | list[str], + Field(default='y', + description='Labels for output variables y.') + ] + dim_x: Annotated[int | None, Field( + default=None, + description=('Provide the dimension of entries in dataset_x explicitly,' + ' e.g. if the initial dataset is empty.' + ' If None, it will be inferred from dataset_x if possible.'), + )] + dim_y: Annotated[int | None, Field( + default=1, + description=('Provide the dimension of entries in dataset_y explicitly,' + ' e.g. if the initial dataset is empty.' + ' If None, it will be inferred from dataset_y if possible.'), + )] + # maximize_y: TODO, provide a language to determine which y to maximize + y_is_good: Annotated[ bool, Field( @@ -92,12 +153,6 @@ class _DialWorkflowCreationParams(BaseModel): description='Specific RNG seed - use -1 to use system default', ), ] - dim_x: Annotated[int | None, Field( - default=None, - description=('Provide the dimension of x explicitly,' - ' e.g. if the initial dataset is empty.' - ' If None, it will be inferred from dataset_x if possible.'), - )] preprocess_log: bool = Field(default=False) preprocess_standardize: bool = Field(default=False) @@ -115,17 +170,14 @@ class _DialWorkflowCreationParams(BaseModel): ) """Miscellaneous additional arguments.""" - @field_validator('dataset_x') + @field_validator('dataset_x', 'dataset_y') @classmethod - def ensure_consistent_dataset_x_lengths(cls, x): - if len(x) < 2: - return x - target_length = len(x[0]) - for row in x[1:]: - if len(row) != target_length: - msg = 'Unequal vector lengths in dataset_x' - raise ValueError(msg) - return x + def ensure_consistent_dataset_x_lengths(cls, dataset, ctx): + is_valid = _validate_dataset_lengths(dataset) + if not is_valid: + msg = f'Unequal vector lengths in {ctx.field_name}' + raise ValueError(msg) + return dataset # order rows as [low, high] - do NOT error out here, we can efficiently handle normalization @field_validator('bounds') @@ -137,8 +189,8 @@ def order_bounds(cls, bounds: list[list[float]]): @model_validator(mode='after') def validate_dims_and_length(self, values): - validate_dims_and_length(vars(self), - 'dataset_x', 'dataset_y', dim_x=self.dim_x) + _validate_dims_and_length(vars(self), + 'dataset_x', 'dataset_y') return self @@ -182,10 +234,19 @@ class DialWorkflowDatasetUpdates(BaseModel): backend_args: dict[str, float | int | bool | str | list[float] | tuple] | None = None extra_args: dict[str, float | int | bool | str | list[float] | tuple] | None = None + @field_validator('next_x_list', 'next_y_list') + @classmethod + def ensure_consistent_dataset_x_lengths(cls, dataset, ctx): + is_valid = _validate_dataset_lengths(dataset) + if not is_valid: + msg = f'Unequal vector lengths in {ctx.field_name}' + raise ValueError(msg) + return dataset + @model_validator(mode='after') def validate_dims_and_length(self, values): - validate_dims_and_length(vars(self), - 'next_x_list', 'next_y_list') + _validate_dims_and_length(vars(self), + 'next_x_list', 'next_y_list') return self class DialInputSingleConfidenceBound(BaseModel): diff --git a/src/dial_service/serverside_data.py b/src/dial_service/serverside_data.py index d871587..df35637 100644 --- a/src/dial_service/serverside_data.py +++ b/src/dial_service/serverside_data.py @@ -38,7 +38,17 @@ def __init__(self, data: DialWorkflowCreationParamsService): self.extra_args = data.extra_args @cached_property - def stddev(self) -> float: + def X_train(self) -> np.ndarray: + """ + Return X scaled to [0, 1] per dimension based on self.bounds. + + dataset_x: list[list[float]], shape (N, D) + bounds: list[[low, high], ...], shape (D, 2) + """ + return self._scale_X(self.X_raw) + + @cached_property + def Y_stddev(self) -> float: return np.std(self.Y_train) @cached_property @@ -72,15 +82,6 @@ def _scale_X(self, X: np.ndarray) -> np.ndarray: return (X - lows) / span - @cached_property - def X_train(self) -> np.ndarray: - """ - Return X scaled to [0, 1] per dimension based on self.bounds. - - dataset_x: list[list[float]], shape (N, D) - bounds: list[[low, high], ...], shape (D, 2) - """ - return self._scale_X(self.X_raw) # undoes the preprocessing. def inverse_transform(self, data: np.ndarray, is_stddev: bool = False): From 022542522d2db478d5cb8a150dbfd9a2eea0a6d6 Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Tue, 12 May 2026 20:05:34 -0400 Subject: [PATCH 03/22] add todo --- src/dial_service/backends/gpax_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dial_service/backends/gpax_backend.py b/src/dial_service/backends/gpax_backend.py index ee0b064..2cbfd74 100644 --- a/src/dial_service/backends/gpax_backend.py +++ b/src/dial_service/backends/gpax_backend.py @@ -46,7 +46,8 @@ def predict(model, data): # mean, y_var = model.predict(rng_key_predict, data.x_predict) # TODO check why model.predict().reshape() fails mean, y_var = model.predict(rng_key_predict, x.reshape(1, -1)) - return mean[0], data.stddev * y_var[0] + # TODO: why do we scale the variance by stddev, but not mean? + return mean[0], data.Y_stddev * y_var[0] @staticmethod def get_kernel(data): From db6f1ed650a4c143ddc5134b4019a0e9588203d3 Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Tue, 12 May 2026 23:30:26 -0400 Subject: [PATCH 04/22] more validation --- src/dial_dataclass/dial_dataclass.py | 69 +++++++++++++++++++--------- src/dial_service/core.py | 1 + src/dial_service/serverside_data.py | 10 ++-- 3 files changed, 52 insertions(+), 28 deletions(-) diff --git a/src/dial_dataclass/dial_dataclass.py b/src/dial_dataclass/dial_dataclass.py index 3cbca2b..96aea12 100644 --- a/src/dial_dataclass/dial_dataclass.py +++ b/src/dial_dataclass/dial_dataclass.py @@ -5,6 +5,7 @@ Field, field_validator, model_validator, + ValidationInfo, ) from .pydantic_helpers import ValidatedObjectId @@ -20,7 +21,7 @@ def _validate_dataset_lengths(dataset: list[any]) -> bool: """validate the lengths of dataset entries""" if len(dataset) > 1: data = dataset[0] - if data is list: + if isinstance(data, list): target_length = len(dataset[0]) for row in dataset[1:]: if len(row) != target_length: @@ -31,7 +32,7 @@ def _validate_dataset_lengths(dataset: list[any]) -> bool: def _validate_dims_and_length(data: dict, x_name: str, y_name: str) -> tuple[int, int]: - """validate the lengths of datasets""" + """validate the lengths of datasets, and compte dim_x and dim_y""" print(data) @@ -70,10 +71,29 @@ def compute_dim(dim, dataset, name): return dim_x, dim_y -def _validate_labels(data: dict, - x_name: str, - y_name: str) -> tuple[str, str]: - """validate the lengths of datasets""" +def _validate_labels(cls) -> tuple[str, str]: + """validate the lengths of labels""" + labels_x, labels_y = cls.labels_x, cls.labels_y + dim_x, dim_y = cls.dim_x, cls.dim_y + + def compute_labels(dim, labels): + if isinstance(labels, list): + if dim is not None and dim != len(labels): + msg = f'Labels {labels} ar not consistent with data dimension {dim=}' + raise ValueError(msg) + else: + if dim > 1: + # give each parameter a unique label by appending a number + labels = [f'{labels}_{i+1}' for i in range(dim)] + else: + labels = [labels] + return labels + + labels_x = compute_labels(dim_x, labels_x) + labels_y = compute_labels(dim_y, labels_y) + + return labels_x, labels_y + class _DialWorkflowCreationParams(BaseModel): """This comprises the information needed to create a DIAL workflow. @@ -104,23 +124,23 @@ class _DialWorkflowCreationParams(BaseModel): ' Length should equal dataset_x'), ), ] - x_labels: Annotated[ + labels_x: Annotated[ str | list[str], Field(default='x', description='Labels for input variables x.') ] - y_labels: Annotated[ + labels_y: Annotated[ str | list[str], Field(default='y', description='Labels for output variables y.') ] - dim_x: Annotated[int | None, Field( + dim_x: Annotated[PositiveIntType | None, Field( default=None, description=('Provide the dimension of entries in dataset_x explicitly,' ' e.g. if the initial dataset is empty.' ' If None, it will be inferred from dataset_x if possible.'), )] - dim_y: Annotated[int | None, Field( + dim_y: Annotated[PositiveIntType | None, Field( default=1, description=('Provide the dimension of entries in dataset_y explicitly,' ' e.g. if the initial dataset is empty.' @@ -140,7 +160,7 @@ class _DialWorkflowCreationParams(BaseModel): kernel: Literal['rbf', 'matern', 'linear'] bounds: list[ Annotated[ - Annotated[list[float], Field(min_length=2, max_length=2)], + list[float], Field(min_length=2, max_length=2), ] ] @@ -172,27 +192,34 @@ class _DialWorkflowCreationParams(BaseModel): @field_validator('dataset_x', 'dataset_y') @classmethod - def ensure_consistent_dataset_x_lengths(cls, dataset, ctx): + def ensure_consistent_dataset_lengths(cls, dataset, info: ValidationInfo): is_valid = _validate_dataset_lengths(dataset) if not is_valid: - msg = f'Unequal vector lengths in {ctx.field_name}' + msg = f'Unequal vector lengths in {info.field_name}' raise ValueError(msg) return dataset + @model_validator(mode='after') + def validate_dims_and_length(self, values): + # compute the dimensions and validate consistency + self.dim_x, self.dim_y = _validate_dims_and_length(vars(self), + 'dataset_x', 'dataset_y') + # compute or validate labels + self.labels_x, self.labels_y = _validate_labels(self) + return self + # order rows as [low, high] - do NOT error out here, we can efficiently handle normalization - @field_validator('bounds') + @field_validator('bounds', mode='after') @classmethod - def order_bounds(cls, bounds: list[list[float]]): + def order_bounds(cls, bounds: list[list[float]], info: ValidationInfo): + dim_x = info.data.get('dim_x') + if len(bounds) != dim_x: + msg = f'Bounds have incorrect length {len(bounds)} != {dim_x=}' + raise ValueError(msg) for row in bounds: row.sort() return bounds - @model_validator(mode='after') - def validate_dims_and_length(self, values): - _validate_dims_and_length(vars(self), - 'dataset_x', 'dataset_y') - return self - # this class is specific to clients; they have no way of knowing which backends the Service supports, so we allow all of them class DialWorkflowCreationParamsClient(_DialWorkflowCreationParams): diff --git a/src/dial_service/core.py b/src/dial_service/core.py index 6b6418f..175f878 100644 --- a/src/dial_service/core.py +++ b/src/dial_service/core.py @@ -87,6 +87,7 @@ def get_surrogate_values(data: ServersideInputPrediction, model: Any) -> list[li module = get_backend_module(backend) means, stddevs = module.predict(model, data) means = data.inverse_transform(means) + # TODO: adjust inverse transform. transformed_stddevs = data.inverse_transform(stddevs, is_stddev=True) return [means.tolist(), transformed_stddevs.tolist(), stddevs.tolist()] diff --git a/src/dial_service/serverside_data.py b/src/dial_service/serverside_data.py index df35637..493f327 100644 --- a/src/dial_service/serverside_data.py +++ b/src/dial_service/serverside_data.py @@ -15,13 +15,9 @@ class ServersideInputBase: def __init__(self, data: DialWorkflowCreationParamsService): self.dim_x = data.dim_x - - if self.dim_x is None: - if len(data.dataset_x) > 0: - self.dim_x = len(data.dataset_x[0]) - else: - raise ValueError('dim_x not set, and no dataset_x provided. Can not deduce dim_x.') - + self.dim_y = data.dim_y + self.labels_x = data.labels_x + self.labels_y = data.labels_y self.X_raw = np.array(data.dataset_x) self.Y_raw = np.array(data.dataset_y) # it seems like there should be a smarter way to do this, but stuff involving loops doesn't work with static autocompleters: From cb7c59f8bccbd502bccb77f0dc7af7e65a31a567 Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Wed, 27 May 2026 10:46:44 -0400 Subject: [PATCH 05/22] backward compatible extensions to dial_dataclass add statistics_y to encode the output and error model --- src/dial_dataclass/dial_dataclass.py | 96 ++++++++++++++++++++-------- 1 file changed, 69 insertions(+), 27 deletions(-) diff --git a/src/dial_dataclass/dial_dataclass.py b/src/dial_dataclass/dial_dataclass.py index 96aea12..fb59116 100644 --- a/src/dial_dataclass/dial_dataclass.py +++ b/src/dial_dataclass/dial_dataclass.py @@ -1,3 +1,4 @@ +from abc import ABC from typing import Annotated, Literal, Optional from pydantic import ( @@ -10,12 +11,40 @@ from .pydantic_helpers import ValidatedObjectId -PositiveIntType = Annotated[int, Field(ge=0)] - _POSSIBLE_BACKENDS = ('sklearn', 'gpax', 'sable') BackendType = Literal[_POSSIBLE_BACKENDS] +PositiveIntType = Annotated[int, Field(ge=0)] + +Label = Annotated[str, Field( + max_length=50, + description='Label for a dataset entry.' +)] +FloatOrLabel = Annotated[float | Label, Field( + description='A constant float, or the label of a dataset entry.', +)] + + +class Distribution(BaseModel, ABC): + """Base class for a statistical distribution.""" + loc: Annotated[Label, Field( + description='The location (mean, or center) of the distribution', + )] + scale: Annotated[FloatOrLabel, Field( + description='The scale or standard deviation of the distribution', + )] + + +class Delta(Distribution): + """The Delta distribution is deterministic and equal to its mean.""" + scale: float = Field(gt=0., lt=0., default=0., frozen=True) + + +class Normal(Distribution): + """The normal distribution is determined by loc (mean) and scale (standard deviation).""" + + def _validate_dataset_lengths(dataset: list[any]) -> bool: """validate the lengths of dataset entries""" @@ -32,12 +61,12 @@ def _validate_dataset_lengths(dataset: list[any]) -> bool: def _validate_dims_and_length(data: dict, x_name: str, y_name: str) -> tuple[int, int]: - """validate the lengths of datasets, and compte dim_x and dim_y""" + """validate the lengths of datasets, and compute dim_x and dim_y""" - print(data) + print('\n'.join(f"{k}: {v}" for k, v in data.items())) - dim_x = data.get("dim_x") - dim_y = data.get("dim_y") + dim_x = data.get('dim_x') + dim_y = data.get('dim_y') dataset_x = data[x_name] dataset_y = data[y_name] @@ -55,14 +84,15 @@ def compute_dim(dim, dataset, name): msg = (f'Can not infer dim from empty dataset {name}.' 'Set dim to the correct dimension.') raise ValueError(msg) - else: + + if lenn > 0: inferred_len = len(dataset[0]) if dataset[0] is list else 1 if dim is not None and inferred_len != dim_x: msg = (f'Vectors in {name} must be of length {dim=}.' 'Set dim to the correct dimension.') raise ValueError(msg) - else: - dim = inferred_len + dim = inferred_len + return dim dim_x = compute_dim(dim_x, dataset_x, x_name) @@ -81,12 +111,12 @@ def compute_labels(dim, labels): if dim is not None and dim != len(labels): msg = f'Labels {labels} ar not consistent with data dimension {dim=}' raise ValueError(msg) + elif dim > 1: + # give each parameter a unique label by appending a number + labels = [f'{labels}_{i+1}' for i in range(dim)] else: - if dim > 1: - # give each parameter a unique label by appending a number - labels = [f'{labels}_{i+1}' for i in range(dim)] - else: - labels = [labels] + # normalize to single element list + labels = [labels] return labels labels_x = compute_labels(dim_x, labels_x) @@ -125,28 +155,40 @@ class _DialWorkflowCreationParams(BaseModel): ), ] labels_x: Annotated[ - str | list[str], + Label | list[Label], Field(default='x', description='Labels for input variables x.') ] labels_y: Annotated[ - str | list[str], + Label | list[Label], Field(default='y', description='Labels for output variables y.') ] - dim_x: Annotated[PositiveIntType | None, Field( - default=None, - description=('Provide the dimension of entries in dataset_x explicitly,' - ' e.g. if the initial dataset is empty.' - ' If None, it will be inferred from dataset_x if possible.'), + dim_x: Annotated[ + PositiveIntType | None, + Field( + default=None, + description=('Provide the dimension of entries in dataset_x explicitly,' + ' e.g. if the initial dataset is empty.' + ' If None, it will be inferred from dataset_x if possible.'), )] - dim_y: Annotated[PositiveIntType | None, Field( - default=1, - description=('Provide the dimension of entries in dataset_y explicitly,' - ' e.g. if the initial dataset is empty.' - ' If None, it will be inferred from dataset_y if possible.'), + dim_y: Annotated[ + PositiveIntType | None, + Field( + default=1, + description=('Provide the dimension of entries in dataset_y explicitly,' + ' e.g. if the initial dataset is empty.' + ' If None, it will be inferred from dataset_y if possible.'), + )] + statistics_y: Annotated[ + Distribution, + Field( + default=Delta(loc='y'), + description=('Provide the statistical model underlying the y data: For example:' + " Delta(loc='y') means that the y data is without error," + " Normal(loc='y', scale=0.1) is a standard error with mean y and standard deviation 0.1" + " Normal(loc='y', scale='yerr') takes the std.dev. from the data column yerr.") )] - # maximize_y: TODO, provide a language to determine which y to maximize y_is_good: Annotated[ bool, From ef866f0b61dcb2cbc7fe40d317f2c23de5ff90ca Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Wed, 27 May 2026 13:56:54 -0400 Subject: [PATCH 06/22] fix dial_dataclass validation of bounds --- src/dial_dataclass/dial_dataclass.py | 37 +++++++++++++++++----------- tests/unit/test_internals.py | 1 + 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/dial_dataclass/dial_dataclass.py b/src/dial_dataclass/dial_dataclass.py index fb59116..527fe86 100644 --- a/src/dial_dataclass/dial_dataclass.py +++ b/src/dial_dataclass/dial_dataclass.py @@ -86,18 +86,29 @@ def compute_dim(dim, dataset, name): raise ValueError(msg) if lenn > 0: - inferred_len = len(dataset[0]) if dataset[0] is list else 1 - if dim is not None and inferred_len != dim_x: + print(dataset[0]) + inferred_dim = len(dataset[0]) if isinstance(dataset[0], list) else 1 + if dim is not None and inferred_dim != dim: msg = (f'Vectors in {name} must be of length {dim=}.' 'Set dim to the correct dimension.') raise ValueError(msg) - dim = inferred_len + dim = inferred_dim return dim + # compute dimensions and validate consistency dim_x = compute_dim(dim_x, dataset_x, x_name) dim_y = compute_dim(dim_y, dataset_y, y_name) + print(dim_x, dataset_x, x_name) + print(dim_y, dataset_y, y_name) + + # validate bounds, if they exist + bounds = data.get('bounds') + if bounds is not None and len(bounds) != dim_x: + msg = f'Bounds have incorrect length {len(bounds)} != {dim_x=}' + raise ValueError(msg) + return dim_x, dim_y @@ -241,6 +252,14 @@ def ensure_consistent_dataset_lengths(cls, dataset, info: ValidationInfo): raise ValueError(msg) return dataset + # order rows as [low, high] - do NOT error out here, we can efficiently handle normalization + @field_validator('bounds', mode='after') + @classmethod + def order_bounds(cls, bounds: list[list[float]], info: ValidationInfo): + for row in bounds: + row.sort() + return bounds + @model_validator(mode='after') def validate_dims_and_length(self, values): # compute the dimensions and validate consistency @@ -250,18 +269,6 @@ def validate_dims_and_length(self, values): self.labels_x, self.labels_y = _validate_labels(self) return self - # order rows as [low, high] - do NOT error out here, we can efficiently handle normalization - @field_validator('bounds', mode='after') - @classmethod - def order_bounds(cls, bounds: list[list[float]], info: ValidationInfo): - dim_x = info.data.get('dim_x') - if len(bounds) != dim_x: - msg = f'Bounds have incorrect length {len(bounds)} != {dim_x=}' - raise ValueError(msg) - for row in bounds: - row.sort() - return bounds - # this class is specific to clients; they have no way of knowing which backends the Service supports, so we allow all of them class DialWorkflowCreationParamsClient(_DialWorkflowCreationParams): diff --git a/tests/unit/test_internals.py b/tests/unit/test_internals.py index 75ee119..d4ee811 100644 --- a/tests/unit/test_internals.py +++ b/tests/unit/test_internals.py @@ -139,6 +139,7 @@ def multiple_2D(backend, strategy): workflow_state = DialWorkflowCreationParamsService( dataset_x=[], dataset_y=[], + dim_x=2, # provide dim_x for empty dataset y_is_good=False, kernel='rbf', length_per_dimension=False, From 78c76782e0cf8a2b9579718c64b8fa76bea073bd Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Fri, 29 May 2026 14:14:46 -0400 Subject: [PATCH 07/22] add autogenerated files to gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 390f1e3..a52249c 100644 --- a/.gitignore +++ b/.gitignore @@ -210,5 +210,6 @@ pyrightconfig.json # End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode .venv* - Taskfile.yml +.DS_Store +*~ From bf1d6e8f828691151a11048721fe6c1ed059a76f Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Fri, 29 May 2026 17:03:11 -0400 Subject: [PATCH 08/22] formatting changes from ruff --- src/dial_dataclass/dial_dataclass.py | 149 +++++++++++++++------------ src/dial_service/serverside_data.py | 1 - 2 files changed, 81 insertions(+), 69 deletions(-) diff --git a/src/dial_dataclass/dial_dataclass.py b/src/dial_dataclass/dial_dataclass.py index c382797..a93fcad 100644 --- a/src/dial_dataclass/dial_dataclass.py +++ b/src/dial_dataclass/dial_dataclass.py @@ -1,12 +1,12 @@ from abc import ABC -from typing import Annotated, Literal, Optional +from typing import Annotated, Literal from pydantic import ( BaseModel, Field, + ValidationInfo, field_validator, model_validator, - ValidationInfo, ) from .pydantic_helpers import ValidatedObjectId @@ -17,35 +17,42 @@ PositiveIntType = Annotated[int, Field(ge=0)] -Label = Annotated[str, Field( - max_length=50, - description='Label for a dataset entry.' -)] -FloatOrLabel = Annotated[float | Label, Field( - description='A constant float, or the label of a dataset entry.', -)] +Label = Annotated[str, Field(max_length=50, description='Label for a dataset entry.')] +FloatOrLabel = Annotated[ + float | Label, + Field( + description='A constant float, or the label of a dataset entry.', + ), +] class Distribution(BaseModel, ABC): """Base class for a statistical distribution.""" - loc: Annotated[Label, Field( - description='The location (mean, or center) of the distribution', - )] - scale: Annotated[FloatOrLabel, Field( - description='The scale or standard deviation of the distribution', - )] + + loc: Annotated[ + Label, + Field( + description='The location (mean, or center) of the distribution', + ), + ] + scale: Annotated[ + FloatOrLabel, + Field( + description='The scale or standard deviation of the distribution', + ), + ] class Delta(Distribution): """The Delta distribution is deterministic and equal to its mean.""" - scale: float = Field(gt=0., lt=0., default=0., frozen=True) + + scale: float = Field(gt=0.0, lt=0.0, default=0.0, frozen=True) class Normal(Distribution): """The normal distribution is determined by loc (mean) and scale (standard deviation).""" - def _validate_dataset_lengths(dataset: list[any]) -> bool: """validate the lengths of dataset entries""" if len(dataset) > 1: @@ -58,12 +65,10 @@ def _validate_dataset_lengths(dataset: list[any]) -> bool: return True -def _validate_dims_and_length(data: dict, - x_name: str, - y_name: str) -> tuple[int, int]: +def _validate_dims_and_length(data: dict, x_name: str, y_name: str) -> tuple[int, int]: """validate the lengths of datasets, and compute dim_x and dim_y""" - print('\n'.join(f"{k}: {v}" for k, v in data.items())) + # print('\n'.join(f'{k}: {v}' for k, v in data.items())) dim_x = data.get('dim_x') dim_y = data.get('dim_y') @@ -74,23 +79,21 @@ def _validate_dims_and_length(data: dict, len_y = len(dataset_y) if len_y != len_x: - msg = (f'Unequal number of points in {x_name} {len_x=}' - f' and {y_name} {len_y=}.') + msg = f'Unequal number of points in {x_name} {len_x=} and {y_name} {len_y=}.' raise ValueError(msg) def compute_dim(dim, dataset, name): lenn = len(dataset) if dim is None and lenn == 0: - msg = (f'Can not infer dim from empty dataset {name}.' - 'Set dim to the correct dimension.') + msg = f'Can not infer dim from empty dataset {name}.Set dim to the correct dimension.' raise ValueError(msg) if lenn > 0: - print(dataset[0]) inferred_dim = len(dataset[0]) if isinstance(dataset[0], list) else 1 if dim is not None and inferred_dim != dim: - msg = (f'Vectors in {name} must be of length {dim=}.' - 'Set dim to the correct dimension.') + msg = ( + f'Vectors in {name} must be of length {dim=}.Set dim to the correct dimension.' + ) raise ValueError(msg) dim = inferred_dim @@ -100,8 +103,8 @@ def compute_dim(dim, dataset, name): dim_x = compute_dim(dim_x, dataset_x, x_name) dim_y = compute_dim(dim_y, dataset_y, y_name) - print(dim_x, dataset_x, x_name) - print(dim_y, dataset_y, y_name) + # print(dim_x, dataset_x, x_name) + # print(dim_y, dataset_y, y_name) # validate bounds, if they exist bounds = data.get('bounds') @@ -124,7 +127,7 @@ def compute_labels(dim, labels): raise ValueError(msg) elif dim > 1: # give each parameter a unique label by appending a number - labels = [f'{labels}_{i+1}' for i in range(dim)] + labels = [f'{labels}_{i + 1}' for i in range(dim)] else: # normalize to single element list labels = [labels] @@ -154,60 +157,68 @@ class _DialWorkflowCreationParams(BaseModel): Field(description='The input vectors of the training data'), ] dataset_y: Annotated[ - list[float | Annotated[ - # TODO: this could be the default - list[float], - Field(description='Field lengths of all subarrays should be equal'), - ] - ], + list[ + float + | Annotated[ + # TODO: this could be the default + list[float], + Field(description='Field lengths of all subarrays should be equal'), + ] + ], Field( - description=('The output values of the training data.' - ' Length should equal dataset_x'), + description=('The output values of the training data. Length should equal dataset_x'), ), ] labels_x: Annotated[ - Label | list[Label], - Field(default='x', - description='Labels for input variables x.') + Label | list[Label], Field(default='x', description='Labels for input variables x.') ] labels_y: Annotated[ - Label | list[Label], - Field(default='y', - description='Labels for output variables y.') + Label | list[Label], Field(default='y', description='Labels for output variables y.') ] dim_x: Annotated[ PositiveIntType | None, Field( default=None, - description=('Provide the dimension of entries in dataset_x explicitly,' - ' e.g. if the initial dataset is empty.' - ' If None, it will be inferred from dataset_x if possible.'), - )] + description=( + 'Provide the dimension of entries in dataset_x explicitly,' + ' e.g. if the initial dataset is empty.' + ' If None, it will be inferred from dataset_x if possible.' + ), + ), + ] dim_y: Annotated[ PositiveIntType | None, Field( default=1, - description=('Provide the dimension of entries in dataset_y explicitly,' - ' e.g. if the initial dataset is empty.' - ' If None, it will be inferred from dataset_y if possible.'), - )] + description=( + 'Provide the dimension of entries in dataset_y explicitly,' + ' e.g. if the initial dataset is empty.' + ' If None, it will be inferred from dataset_y if possible.' + ), + ), + ] statistics_y: Annotated[ Distribution, Field( default=Delta(loc='y'), - description=('Provide the statistical model underlying the y data: For example:' - " Delta(loc='y') means that the y data is without error," - " Normal(loc='y', scale=0.1) is a standard error with mean y and standard deviation 0.1" - " Normal(loc='y', scale='yerr') takes the std.dev. from the data column yerr.") - )] + description=( + 'Provide the statistical model underlying the y data: For example:' + " Delta(loc='y') means that the y data is without error," + " Normal(loc='y', scale=0.1) is a standard error with mean y and standard deviation 0.1" + " Normal(loc='y', scale='yerr') takes the std.dev. from the data column yerr." + ), + ), + ] y_is_good: Annotated[ bool, Field( default=True, # <-- Set default here - description=('If true, treat higher y values as better' - ' (e.g. y represents yield or profit).' - ' If false, opposite (e.g. y represents error or waste)'), + description=( + 'If true, treat higher y values as better' + ' (e.g. y represents yield or profit).' + ' If false, opposite (e.g. y represents error or waste)' + ), ), ] kernel: Literal['rbf', 'matern', 'linear'] @@ -255,16 +266,15 @@ def ensure_consistent_dataset_lengths(cls, dataset, info: ValidationInfo): # order rows as [low, high] - do NOT error out here, we can efficiently handle normalization @field_validator('bounds', mode='after') @classmethod - def order_bounds(cls, bounds: list[list[float]], info: ValidationInfo): + def order_bounds(cls, bounds: list[list[float]]): for row in bounds: row.sort() return bounds @model_validator(mode='after') - def validate_dims_and_length(self, values): + def validate_dims_and_length(self): # compute the dimensions and validate consistency - self.dim_x, self.dim_y = _validate_dims_and_length(vars(self), - 'dataset_x', 'dataset_y') + self.dim_x, self.dim_y = _validate_dims_and_length(vars(self), 'dataset_x', 'dataset_y') # compute or validate labels self.labels_x, self.labels_y = _validate_labels(self) return self @@ -279,6 +289,7 @@ class DialWorkflowCreationParamsClient(_DialWorkflowCreationParams): class DialWorkflowDatasetUpdate(BaseModel): """This class is used to send a single update to the dataset.""" + workflow_id: ValidatedObjectId next_x: list[float] = Field( description='The next collection of X values you want to append to your overall data', @@ -303,6 +314,7 @@ class DialWorkflowDatasetUpdate(BaseModel): class DialWorkflowDatasetUpdates(BaseModel): """This class is used to send multiple updates to the dataset.""" + workflow_id: ValidatedObjectId next_x_list: list[list[float]] = Field(min_length=1) next_y_list: list[float] = Field(min_length=1) @@ -320,11 +332,11 @@ def ensure_consistent_dataset_x_lengths(cls, dataset, ctx): return dataset @model_validator(mode='after') - def validate_dims_and_length(self, values): - _validate_dims_and_length(vars(self), - 'next_x_list', 'next_y_list') + def validate_dims_and_length(self): + _validate_dims_and_length(vars(self), 'next_x_list', 'next_y_list') return self + class DialInputSingleConfidenceBound(BaseModel): workflow_id: ValidatedObjectId strategy: Literal['confidence_bound'] @@ -405,6 +417,7 @@ class DialInputSingleOtherStrategy(BaseModel): class DialInputMultipleOtherStrategy(BaseModel): """TODO: document this""" + workflow_id: ValidatedObjectId points: PositiveIntType strategy: Literal[ diff --git a/src/dial_service/serverside_data.py b/src/dial_service/serverside_data.py index 493f327..cad9209 100644 --- a/src/dial_service/serverside_data.py +++ b/src/dial_service/serverside_data.py @@ -78,7 +78,6 @@ def _scale_X(self, X: np.ndarray) -> np.ndarray: return (X - lows) / span - # undoes the preprocessing. def inverse_transform(self, data: np.ndarray, is_stddev: bool = False): if len(self.Y_raw) == 0: From 2a042d73f970a7806e3b563064cf8b0f90fce957 Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Fri, 29 May 2026 17:07:03 -0400 Subject: [PATCH 09/22] extract y and yerr from dataset_y unsing statistics_y - prepare the new way how to set yerr through statistics_y - unify and fix the output scaling --- src/dial_service/core.py | 5 +- src/dial_service/serverside_data.py | 161 +++++++++++++++++++++------- tests/unit/test_internals.py | 40 ++++--- 3 files changed, 150 insertions(+), 56 deletions(-) diff --git a/src/dial_service/core.py b/src/dial_service/core.py index 175f878..ce3712c 100644 --- a/src/dial_service/core.py +++ b/src/dial_service/core.py @@ -86,9 +86,8 @@ def get_surrogate_values(data: ServersideInputPrediction, model: Any) -> list[li backend = data.backend.lower() module = get_backend_module(backend) means, stddevs = module.predict(model, data) - means = data.inverse_transform(means) - # TODO: adjust inverse transform. - transformed_stddevs = data.inverse_transform(stddevs, is_stddev=True) + means, transformed_stddevs = data.inverse_transform_Y(means, stddevs) + # TODO: the third return argument is not needed. return [means.tolist(), transformed_stddevs.tolist(), stddevs.tolist()] diff --git a/src/dial_service/serverside_data.py b/src/dial_service/serverside_data.py index cad9209..4d139fc 100644 --- a/src/dial_service/serverside_data.py +++ b/src/dial_service/serverside_data.py @@ -18,8 +18,9 @@ def __init__(self, data: DialWorkflowCreationParamsService): self.dim_y = data.dim_y self.labels_x = data.labels_x self.labels_y = data.labels_y - self.X_raw = np.array(data.dataset_x) - self.Y_raw = np.array(data.dataset_y) + self.dataset_x = np.array(data.dataset_x) + self.dataset_y = np.array(data.dataset_y).reshape((-1, self.dim_y)) + self.statistics_y = data.statistics_y # it seems like there should be a smarter way to do this, but stuff involving loops doesn't work with static autocompleters: self.bounds = data.bounds self.y_is_good = data.y_is_good @@ -41,26 +42,9 @@ def X_train(self) -> np.ndarray: dataset_x: list[list[float]], shape (N, D) bounds: list[[low, high], ...], shape (D, 2) """ - return self._scale_X(self.X_raw) + return self.scale_X(self.dataset_x) - @cached_property - def Y_stddev(self) -> float: - return np.std(self.Y_train) - - @cached_property - def Y_best(self) -> float: - return self.Y_train.max() if self.y_is_good else self.Y_train.min() - - @cached_property - def Y_train(self) -> np.ndarray: - y = self.Y_raw - if self.preprocess_log: - y = np.log(y) - if self.preprocess_standardize: - y = (y - np.mean(y)) / np.std(y) - return y - - def _scale_X(self, X: np.ndarray) -> np.ndarray: + def scale_X(self, X: np.ndarray) -> np.ndarray: """ Scale X into [0, 1]^D using self.bounds. X: array of shape (N, D) @@ -78,23 +62,124 @@ def _scale_X(self, X: np.ndarray) -> np.ndarray: return (X - lows) / span - # undoes the preprocessing. - def inverse_transform(self, data: np.ndarray, is_stddev: bool = False): - if len(self.Y_raw) == 0: - return data + def _extract_y_train_from_dataset(self): + """ + Find output y and error values yerr in dataset_y, and save. + """ + if hasattr(self, 'y_train_raw') and hasattr(self, 'yerr_train_raw'): + # only compute this on first invocation + return + + y_label = self.statistics_y.loc + if not isinstance(y_label, str): + msg = 'statistics_y.loc must be a Label (str).' + raise TypeError(msg) + + # Use the label from self.statistics_y.loc to find the data column with the mean y data + # this may trigger a ValueError, if the label does not exist, but should be handled by dataclass validation + pos_y = self.labels_y.index(y_label) + self.y_train_raw = self.dataset_y[:, pos_y] + + yerr_label = self.statistics_y.scale + if isinstance(yerr_label, float): + self.yerr_train_raw = yerr_label + else: + # yerr_label is str + # this may trigger a ValueError, but should be handled by dataclass validation + pos_yerr = self.labels_y.index(yerr_label) + self.yerr_train_raw = self.dataset_y[:, pos_yerr] + + if np.any(self.yerr_train_raw < 0): + idxs = np.where(np.yerr_train_raw < 0) + msg = f'yerr values in statistics_y.scale must be non-negative, found {np.yerr_train_raw[idxs[0]]} at {idxs[0]}.' + raise ValueError(msg) - # not possible to un-log the standard deviations (-1 +- 1 in log space != .1 +- 10 in realspace) - if self.preprocess_log and is_stddev: - return np.repeat(-1, len(data)) + @cached_property + def Y_train(self) -> np.ndarray: + """ + Find output y and error values yerr in dataset_y, and apply transformation. + Return transformed y value. + """ + # ensure that self.y_train_raw, self.yerr_train_raw are populated + self._extract_y_train_from_dataset() + + y, _ = self.transform_Y(self.y_train_raw, self.yerr_train_raw) + + # return only y, to conform to interface + return y + + @cached_property + def Yerr_train(self) -> any: + """ + Find output y and error values in dataset y, and apply transformation. + Return transformed yerr value. + """ + # ensure that self.y_train_raw, self.yerr_train_raw are populated + self._extract_y_train_from_dataset() + + # recompute transformation, at some overhead (probably not worth to optimize) + _, yerr = self.transform_Y(self.y_train_raw, self.yerr_train_raw) + + # return only yerr, to conform to interface + return yerr + + def _transform_Y_params(self) -> tuple[float, float]: + """ + Return the appropriate mean and scaling of the raw y data for normalization + """ + # ensure that self.y_train_raw, self.yerr_train_raw are populated + self._extract_y_train_from_dataset() + + # find y_std from y_train_raw + y_train = self.y_train_raw + if len(y_train) > 0 and self.preprocess_standardize: + if self.preprocess_log: + y_train = np.log(y_train) + y_std = np.std(y_train) + y_mean = np.mean(y_train) + else: + y_std = 1.0 + y_mean = 0.0 + + return y_mean, y_std + + def transform_Y(self, y: np.ndarray, yerr: any) -> tuple[np.ndarray, any]: + """ + Transform y and yerr according to preprocess options + """ + if self.preprocess_log: + yerr = yerr / y + y = np.log(y) + + if self.preprocess_standardize: + y_mean, y_std = self._transform_Y_params() + yerr = yerr / y_std + y = (y - y_mean) / y_std + + return y, yerr + + def inverse_transform_Y(self, y: np.ndarray, yerr: any) -> tuple[np.ndarray, any]: + """ + Inverse transforms of y and yerr, in reverse order + """ if self.preprocess_standardize: - # the data that was used to calculate the standardization: - prestandardized_y = np.log(self.Y_raw) if self.preprocess_log else self.Y_raw - data = data * np.std(prestandardized_y) # not the same as *= (which is in-place) - if not is_stddev: - data = data + np.mean(prestandardized_y) + y_mean, y_std = self._transform_Y_params() + y = y_mean + y_std * y + yerr = y_std * yerr + if self.preprocess_log: - data = np.exp(data) - return data + y = np.exp(y) + yerr = y * yerr + + return y, yerr + + @cached_property + def Y_stddev(self) -> float: + return np.std(self.Y_train) + + @cached_property + def Y_best(self) -> float: + return self.Y_train.max() if self.y_is_good else self.Y_train.min() class ServersideInputSingle(ServersideInputBase): @@ -119,7 +204,7 @@ def set_x_predict(self, X_raw: np.ndarray) -> None: X_raw: shape (N, D) or (D,) for a single point. """ raw_vals = np.asarray(X_raw, dtype=float).reshape(-1, self.dim_x) - self.x_predict = self._scale_X(raw_vals) + self.x_predict = self.scale_X(raw_vals) class ServersideInputMultiple(ServersideInputBase): @@ -148,7 +233,7 @@ def set_x_predict(self, X_raw: np.ndarray) -> None: X_raw: shape (N, D) or (D,) for a single point. """ raw_vals = np.asarray(X_raw, dtype=float).reshape(-1, self.dim_x) - self.x_predict = self._scale_X(raw_vals) + self.x_predict = self.scale_X(raw_vals) class ServersideInputPrediction(ServersideInputBase): @@ -165,4 +250,4 @@ def set_x_predict(self, X_raw: np.ndarray) -> None: X_raw: shape (N, D) or (D,) for a single point. """ raw_vals = np.asarray(X_raw, dtype=float).reshape(-1, self.dim_x) - self.x_predict = self._scale_X(raw_vals) + self.x_predict = self.scale_X(raw_vals) diff --git a/tests/unit/test_internals.py b/tests/unit/test_internals.py index 7c0a1d1..33b4b3d 100644 --- a/tests/unit/test_internals.py +++ b/tests/unit/test_internals.py @@ -369,29 +369,39 @@ def test_surrogate(backend, expected_means, expected_stddevs, expected_raw_stdde ('backend'), [ ('sklearn'), - # ('gpax'), ], ) def test_inverse_transform(backend): data = prediction_1D(backend) - assert data.inverse_transform(np.array([-1, 0, 1])) == pytest.approx([-1, 0, 1]) - assert data.inverse_transform(np.array([-1, 0, 1]), True) == pytest.approx([-1, 0, 1]) + + test_y = np.array([-1, 0, 1]) + test_yerr = np.array([0.1, 1, 10]) + + def test_transform(inv_y, inv_yerr): + y, yerr = data.transform_Y(inv_y, inv_yerr) + assert y == pytest.approx(test_y) + assert yerr == pytest.approx(test_yerr) + + inv_y, inv_yerr = data.inverse_transform_Y(test_y, test_yerr) + assert inv_y == pytest.approx(test_y) + assert inv_yerr == pytest.approx(test_yerr) + test_transform(inv_y, inv_yerr) data.preprocess_log = True - assert data.inverse_transform(np.array([-1, 0, 1])) == pytest.approx( - [1 / E_CONSTANT, 1, E_CONSTANT] - ) - assert data.inverse_transform(np.array([-1, 0, 1]), True) == pytest.approx([-1, -1, -1]) + inv_y, inv_yerr = data.inverse_transform_Y(test_y, test_yerr) + assert inv_y == pytest.approx([1 / E_CONSTANT, 1, E_CONSTANT]) + assert inv_yerr == pytest.approx(inv_y * test_yerr) + test_transform(inv_y, inv_yerr) data.preprocess_log = False data.preprocess_standardize = True - assert data.inverse_transform(np.array([-1, 0, 1])) == pytest.approx([100, 150, 200]) - assert data.inverse_transform(np.array([-1, 0, 1]), True) == pytest.approx( - [-50, 0, 50] - ) # technically improper, as uncertainties can't be negative + inv_y, inv_yerr = data.inverse_transform_Y(test_y, test_yerr) + assert inv_y == pytest.approx([100, 150, 200]) + assert inv_yerr == pytest.approx(50 * test_yerr) + test_transform(inv_y, inv_yerr) data.preprocess_log = True - assert data.inverse_transform(np.array([-1, 0, 1])) == pytest.approx( - [100, 141.42135623730945, 200] - ) # TODO - assert data.inverse_transform(np.array([-1, 0, 1]), True) == pytest.approx([-1, -1, -1]) + inv_y, inv_yerr = data.inverse_transform_Y(test_y, test_yerr) + assert inv_y == pytest.approx([100, 141.42135623730945, 200]) + assert inv_yerr == pytest.approx(inv_y * 0.34657359027997243 * test_yerr) + test_transform(inv_y, inv_yerr) From b514dac0081d801018fffbacb76ffe40dce6bbbe Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Fri, 29 May 2026 17:35:26 -0400 Subject: [PATCH 10/22] remove leftover method Y_stddev - fix gpax stddev, untested --- src/dial_service/backends/gpax_backend.py | 15 ++++----------- src/dial_service/serverside_data.py | 4 ---- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/src/dial_service/backends/gpax_backend.py b/src/dial_service/backends/gpax_backend.py index 2cbfd74..0b45bba 100644 --- a/src/dial_service/backends/gpax_backend.py +++ b/src/dial_service/backends/gpax_backend.py @@ -14,9 +14,7 @@ class GpaxBackend(AbstractBackend[gpax.viGP, str, tuple[jnp.ndarray, jnp.ndarray @staticmethod def train_model(data): """Generate a trained model.""" - rng_key_train, rng_key_predict = gpax.utils.get_keys( - seed=data.seed if data.seed != -1 else None - ) + rng_key_train, _ = gpax.utils.get_keys(seed=data.seed if data.seed != -1 else None) gp_model = gpax.viGP(len(data.bounds), GpaxBackend.get_kernel(data), guide='delta') gp_model.fit( rng_key_train, @@ -32,22 +30,17 @@ def train_model(data): @staticmethod def initialize_model(data): """Generate an untrained model.""" - rng_key_train, rng_key_predict = gpax.utils.get_keys( - seed=data.seed if data.seed != -1 else None - ) return gpax.viGP(len(data.bounds), GpaxBackend.get_kernel(data), guide='delta') @staticmethod def predict(model, data): - rng_key_train, rng_key_predict = gpax.utils.get_keys( - seed=data.seed if data.seed != -1 else None - ) + _, rng_key_predict = gpax.utils.get_keys(seed=data.seed if data.seed != -1 else None) x = data.x_predict # mean, y_var = model.predict(rng_key_predict, data.x_predict) # TODO check why model.predict().reshape() fails mean, y_var = model.predict(rng_key_predict, x.reshape(1, -1)) - # TODO: why do we scale the variance by stddev, but not mean? - return mean[0], data.Y_stddev * y_var[0] + # return the square root of the variance (the standard deviation) + return mean[0], jnp.sqrt(y_var[0]) @staticmethod def get_kernel(data): diff --git a/src/dial_service/serverside_data.py b/src/dial_service/serverside_data.py index 4d139fc..69f721c 100644 --- a/src/dial_service/serverside_data.py +++ b/src/dial_service/serverside_data.py @@ -173,10 +173,6 @@ def inverse_transform_Y(self, y: np.ndarray, yerr: any) -> tuple[np.ndarray, any return y, yerr - @cached_property - def Y_stddev(self) -> float: - return np.std(self.Y_train) - @cached_property def Y_best(self) -> float: return self.Y_train.max() if self.y_is_good else self.Y_train.min() From 9ae0173adcbc208d46f6ce2d59fa53a938059b9a Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Fri, 29 May 2026 19:12:07 -0400 Subject: [PATCH 11/22] migrate sable_backend and client to the new (y,yerr) logic --- scripts/1d_sable_client.py | 6 ++++- src/dial_dataclass/__init__.py | 2 ++ src/dial_dataclass/dial_dataclass.py | 11 +++++--- src/dial_service/backends/sable_backend.py | 30 ++++++++++++++-------- 4 files changed, 34 insertions(+), 15 deletions(-) diff --git a/scripts/1d_sable_client.py b/scripts/1d_sable_client.py index 9660dbd..8b63d9e 100644 --- a/scripts/1d_sable_client.py +++ b/scripts/1d_sable_client.py @@ -23,6 +23,7 @@ DialInputSingleOtherStrategy, DialWorkflowCreationParamsClient, DialWorkflowDatasetUpdate, + Normal, ) mpl.use('agg') @@ -94,6 +95,8 @@ def __init__(self, service_destination: str): self.dataset_x = self.x_raw.reshape(-1, 1).tolist() self.dataset_y = self.y_raw.reshape(-1).tolist() + self.labels_y = ['y'] + self.statistics_y = Normal(loc='y', scale=self.noise_level) self.test_points = self.x_test.reshape(-1, 1).tolist() self.kernel = 'rbf' @@ -104,7 +107,6 @@ def __init__(self, service_destination: str): 'alpha': 0.05, 'p': 1.25, 'n_iter_irls': 100, - 'noise_level': self.noise_level, } self.strategy = 'upper_confidence_bound' self.strategy_args = {'exploit': 0.0, 'explore': 1.0} @@ -160,6 +162,8 @@ def callback_message(self, operation: str, **kwargs) -> IntersectClientCallback: next_payload = DialWorkflowCreationParamsClient( dataset_x=self.dataset_x, dataset_y=self.dataset_y, + labels_y=self.labels_y, + statistics_y=self.statistics_y, bounds=self.bounds.tolist(), kernel=self.kernel, kernel_args=self.kernel_args, diff --git a/src/dial_dataclass/__init__.py b/src/dial_dataclass/__init__.py index ecda389..99824eb 100644 --- a/src/dial_dataclass/__init__.py +++ b/src/dial_dataclass/__init__.py @@ -1,4 +1,5 @@ from .dial_dataclass import ( + Delta, DialInputMultiple, DialInputMultipleOtherStrategy, DialInputPredictions, @@ -8,6 +9,7 @@ DialWorkflowCreationParamsClient, DialWorkflowDatasetUpdate, DialWorkflowDatasetUpdates, + Normal, ) from .dial_dataclass_responses import ( DialDataResponse1D, diff --git a/src/dial_dataclass/dial_dataclass.py b/src/dial_dataclass/dial_dataclass.py index a93fcad..98bfdcb 100644 --- a/src/dial_dataclass/dial_dataclass.py +++ b/src/dial_dataclass/dial_dataclass.py @@ -26,7 +26,7 @@ ] -class Distribution(BaseModel, ABC): +class BaseDistribution(BaseModel, ABC): """Base class for a statistical distribution.""" loc: Annotated[ @@ -43,16 +43,19 @@ class Distribution(BaseModel, ABC): ] -class Delta(Distribution): - """The Delta distribution is deterministic and equal to its mean.""" +class Delta(BaseDistribution): + """The Delta distribution is deterministic and equal to its mean/loc.""" scale: float = Field(gt=0.0, lt=0.0, default=0.0, frozen=True) -class Normal(Distribution): +class Normal(BaseDistribution): """The normal distribution is determined by loc (mean) and scale (standard deviation).""" +Distribution = Annotated[Delta | Normal, Field(description='Union of all supported Distributions.')] + + def _validate_dataset_lengths(dataset: list[any]) -> bool: """validate the lengths of dataset entries""" if len(dataset) > 1: diff --git a/src/dial_service/backends/sable_backend.py b/src/dial_service/backends/sable_backend.py index b53e283..1a712c4 100644 --- a/src/dial_service/backends/sable_backend.py +++ b/src/dial_service/backends/sable_backend.py @@ -3,6 +3,8 @@ import numpy as np from sable import DiscretizedSurrogateModel, ScaledRBFModel +from dial_dataclass import Normal + from ..utilities import strategies from . import AbstractBackend @@ -30,17 +32,25 @@ def _get_model_kwargs(data) -> dict: def _get_observation_errors(data, n_observations: int) -> np.ndarray: backend_args = {} if data.backend_args is None else data.backend_args - y_err = backend_args.get('y_err', backend_args.get('noise_level', 1e-6)) - # TODO figure this out: We need a consistent way to configure alpha / noise_level / y_err - - y_err_arr = np.asarray(y_err, dtype=float).reshape(-1) - if y_err_arr.size == 1: - y_err_arr = np.full(n_observations, float(y_err_arr[0]), dtype=float) - if data.preprocess_standardize and len(data.Y_raw) > 0: - scale = np.std(np.asarray(data.Y_raw, dtype=float)) - if scale > 0: - y_err_arr = y_err_arr / scale + if isinstance(data.statistics_y, Normal): + y_err = data.Yerr_train + y_err_arr = np.asarray(y_err, dtype=float).reshape(-1) + if y_err_arr.size == 1: + y_err_arr = np.full(n_observations, float(y_err_arr[0]), dtype=float) + else: + # if y_err is not provided through the statistics, use the old fallback for compatibility + # TODO: remove if no longer needed + y_err = backend_args.get('y_err', backend_args.get('noise_level', 1e-6)) + + y_err_arr = np.asarray(y_err, dtype=float).reshape(-1) + if y_err_arr.size == 1: + y_err_arr = np.full(n_observations, float(y_err_arr[0]), dtype=float) + + if data.preprocess_standardize and len(data.Y_raw) > 0: + scale = np.std(np.asarray(data.Y_raw, dtype=float)) + if scale > 0: + y_err_arr = y_err_arr / scale return y_err_arr From 952061b1a9013485a44ad8357ae8455cba2f3947 Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Mon, 8 Jun 2026 18:06:22 -0400 Subject: [PATCH 12/22] make sk_learn backend respect output_statistics --- src/dial_service/backends/sklearn_backend.py | 100 +++++++++++-------- 1 file changed, 60 insertions(+), 40 deletions(-) diff --git a/src/dial_service/backends/sklearn_backend.py b/src/dial_service/backends/sklearn_backend.py index e70cde7..c4bab12 100644 --- a/src/dial_service/backends/sklearn_backend.py +++ b/src/dial_service/backends/sklearn_backend.py @@ -14,6 +14,8 @@ WhiteKernel, ) +from dial_dataclass import Normal + from ..utilities import strategies from . import AbstractBackend @@ -33,7 +35,8 @@ def _filter_kwargs_for(cls, params: dict) -> dict: """Keep only kwargs that `cls.__init__` actually accepts.""" sig = inspect.signature(cls.__init__) allowed = set(sig.parameters) - {'self', 'args', 'kwargs'} - return {k: v for k, v in params.items() if k in allowed} + params_filtered = {k: v for k, v in params.items() if k in allowed} + return params_filtered class SklearnBackend( @@ -47,68 +50,85 @@ def get_kernel(data): raise ValueError(msg) _params = {} if data.kernel_args is None else data.kernel_args - if 'length_scale' not in _params: - length_per_dimension = ( - data.extra_args.get('length_per_dimension') if data.extra_args else False - ) - # TODO check if necessary - # dim = data.X_train.shape[1] - # _params['length_scale'] = [1.0] * dim if length_per_dimension else 1.0 - _params['length_scale'] = [1.0] * data.dim_x if length_per_dimension else 1.0 + # TODO: test length_per_dimension and reenable + # if 'length_scale' not in _params: + # length_per_dimension = ( + # data.extra_args.get('length_per_dimension') if data.extra_args else False + # ) + # # TODO check if necessary + # # dim = data.X_train.shape[1] + # # _params['length_scale'] = [1.0] * dim if length_per_dimension else 1.0 + # _params['length_scale'] = [1.0] * data.dim_x if length_per_dimension else 1.0 base_kernel_cls = _KERNELS_SKLEARN[kernel_name] base_params = _filter_kwargs_for(base_kernel_cls, _params) - # Only do hyperparameter optimization if the user asks for it - # TODO make the default parameters for the kernels different from the sklearn defaults, but allow the user to customize it - const_params = {'constant_value_bounds': 'fixed', 'constant_value': 1.0} - const_params.update(_filter_kwargs_for(ConstantKernel, _params)) - white_params = {'noise_level_bounds': 'fixed', 'noise_level': 1e-6} - white_params.update(_filter_kwargs_for(WhiteKernel, _params)) - + # only do hyperparameter optimization if the user asks for it, use fixed defaults if base_kernel_cls == DotProduct: base_params = {'sigma_0': 1.0, 'sigma_0_bounds': 'fixed'} else: base_params = {'length_scale': 1.0, 'length_scale_bounds': 'fixed'} base_params.update(_filter_kwargs_for(base_kernel_cls, _params)) - - constant_kernel = ConstantKernel(**const_params) base_kernel = base_kernel_cls(**base_params) - white_kernel = WhiteKernel(**white_params) - return constant_kernel * base_kernel + white_kernel + # scale the prior variance by using a ConstantKernel + const_params = _filter_kwargs_for(ConstantKernel, _params) + if const_params: + # use a fixed value by default, unless bounds are explicitly provided + const_params = {'constant_value': 1.0, 'constant_value_bounds': 'fixed'} | const_params + constant_kernel = ConstantKernel(**const_params) + kernel = constant_kernel * base_kernel + else: + kernel = base_kernel + + # if requested, add a WhiteKernel with variance noise_level + white_params = _filter_kwargs_for(WhiteKernel, _params) + if white_params: + # use a fixed value by default, unless bounds are explicitly provided + white_params = {'noise_level': 1e-12, 'noise_level_bounds': 'fixed'} | white_params + white_kernel = WhiteKernel(**white_params) + kernel = kernel + white_kernel + + return kernel @staticmethod def train_model(data): """Create a model with training.""" - if data.backend_args is None: - _extra_args = {} - else: - _extra_args = data.backend_args.copy() # Ensure it's a dictionary - if 'alpha' in _extra_args and not isinstance(_extra_args['alpha'], np.ndarray): - # Process alpha as a numpy array - _extra_args['alpha'] = np.array(_extra_args['alpha']) - # print(_extra_args['alpha']) - model = GaussianProcessRegressor( - kernel=SklearnBackend.get_kernel(data), n_restarts_optimizer=1000, **_extra_args - ) + model = SklearnBackend.initialize_model(data) + + # print(f"pre-training kernel: {SklearnBackend.get_kernel(data)}") model.fit(data.X_train, data.Y_train) + # print(f"obtained trained kernel: {model.kernel_}") + return model @staticmethod def initialize_model(data): """Create a model without training.""" - if data.backend_args is None: - _extra_args = {} - else: - _extra_args = data.backend_args.copy() # Ensure it's a dictionary - if 'alpha' in _extra_args and not isinstance(_extra_args['alpha'], np.ndarray): + kernel = SklearnBackend.get_kernel(data) + + # if the output statistics are a normal distribution, configure the alpha argument for sklearn + _statistics_args = {} + if isinstance(data.statistics_y, Normal): + # set alpha to the variance associated to Y_err_train + y_variance_train = data.Yerr_train**2 + _statistics_args['alpha'] = y_variance_train + + _extra_args = {} + if data.backend_args is not None: + # copy backend_args + _extra_args = data.backend_args.copy() + + # backwards compatible way to set alpha + if 'alpha' in _extra_args: # Process alpha as a numpy array - _extra_args['alpha'] = np.array(_extra_args['alpha']) + _extra_args['alpha'] = np.asarray(_extra_args['alpha']) + + # update alpha from statistics_args, if present + # TODO: should raise a warning, if already present, or if WhiteKernel is present + _extra_args.update(_statistics_args) - return GaussianProcessRegressor( - kernel=SklearnBackend.get_kernel(data), n_restarts_optimizer=1000, **_extra_args - ) + return GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=1000, **_extra_args) @staticmethod def predict(model, data): From 8b21e7f9864f3c3870f6b3b1626de6cccc71f897 Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Mon, 8 Jun 2026 18:11:06 -0400 Subject: [PATCH 13/22] fixes for clients - manual_client: adapt to new payload structure - sinusoidal_growth_client: handle std.dev. vs variance and configure length_scale --- scripts/1d_sinusoidal_growth_client.py | 43 ++++++++++++++++---------- scripts/manual_client.py | 23 +++++++++----- 2 files changed, 42 insertions(+), 24 deletions(-) diff --git a/scripts/1d_sinusoidal_growth_client.py b/scripts/1d_sinusoidal_growth_client.py index 88bf688..12f20c8 100644 --- a/scripts/1d_sinusoidal_growth_client.py +++ b/scripts/1d_sinusoidal_growth_client.py @@ -52,11 +52,11 @@ def __init__(self): class ActiveLearningOrchestrator: def __init__(self, service_destination: str): - self.bounds = np.array([[-2, 2]]) + self.bounds = np.array([[-2.0, 2.0]]) self.num_dims = len(self.bounds) - self.x_raw = np.array([[1], [2.0]]) - self.x_test = np.array([[-1], [0.5]]) + self.x_raw = np.array([[1.0], [2.0]]) + self.x_test = np.array([[-1.0], [0.5]]) self.y_raw = sinusoidal_growth(self.x_raw) self.meshgrid_size = 100 @@ -72,8 +72,20 @@ def __init__(self, service_destination: str): self.dataset_y = self.y_raw.reshape(-1).tolist() self.test_points = self.x_test.reshape(-1, 1).tolist() + # configure kernel_hyperparameters self.kernel = 'rbf' - self.kernel_args = {'length_scale': 0.12, 'length_scale_bounds': (0.1, 1.0)} + self.kernel_args = { + 'length_scale': 0.1, + 'constant_value': 2.0, + } + self.optimize_lengthscale = False + if self.optimize_lengthscale: + self.kernel_args.update( + { + 'length_scale_bounds': (0.02, 0.2), + } + ) + self.backend = 'sklearn' self.backend_args = None self.strategy = 'upper_confidence_bound' @@ -81,9 +93,9 @@ def __init__(self, service_destination: str): self.niter = 0 self.max_iter = 20 self.at_grids = True - self.variance_grid = None + self.stddev_grid = None self.mean_grid = None - self.variance_test = None + self.stddev_test = None self.mean_test = None self.x_next = None @@ -142,7 +154,6 @@ def callback_message(self, operation: str, **kwargs) -> IntersectClientCallback: kernel_args=self.kernel_args, backend=self.backend, backend_args=self.backend_args, - extra_args={'length_per_dimension': True}, preprocess_standardize=True, y_is_good=True, seed=20, @@ -189,16 +200,16 @@ def callback_message(self, operation: str, **kwargs) -> IntersectClientCallback: def handle_surrogate_values(self, payload): response_data = payload['data'] if self.at_grids: - self.variance_grid = np.array(response_data[1]).reshape( + self.stddev_grid = np.array(response_data[1]).reshape( (self.meshgrid_size,) * self.num_dims ) self.mean_grid = np.array(response_data[0]).reshape( (self.meshgrid_size,) * self.num_dims ) else: - self.variance_test = np.array(response_data[1]) + self.stddev_test = np.array(response_data[1]) self.mean_test = np.array(response_data[0]) - print(f'Test Mean: {self.mean_test}, Variance: {self.variance_test}') + print(f'Test Mean: {self.mean_test}, Std.Dev.: {self.stddev_test}') # end of active learning loop after max_iter if self.niter > self.max_iter: @@ -225,12 +236,12 @@ def graph(self): fig, axs = plt.subplots(2, 1, figsize=(10, 8), sharex=True) - # First subplot: Mean and variance with training data + # First subplot: Mean and standard deviation with training data axs[0].plot(self.x_grid, self.mean_grid, label='Mean Prediction') axs[0].fill_between( self.x_grid[:, 0], - self.mean_grid + 2 * self.variance_grid, - self.mean_grid - 2 * self.variance_grid, + self.mean_grid + 2 * self.stddev_grid, + self.mean_grid - 2 * self.stddev_grid, alpha=0.5, label='Confidence Interval', ) @@ -249,12 +260,10 @@ def graph(self): # Second subplot: Acquisition function if self.strategy_args is not None: - if self.mean_grid is not None and self.variance_grid is not None: + if self.mean_grid is not None and self.stddev_grid is not None: exploit = self.strategy_args.get('exploit', 0.0) explore = self.strategy_args.get('explore', 1.0) - acquisition_values = exploit * self.mean_grid + explore * np.sqrt( - self.variance_grid - ) + acquisition_values = exploit * self.mean_grid + explore * self.stddev_grid else: acquisition_values = np.zeros_like(self.x_grid) diff --git a/scripts/manual_client.py b/scripts/manual_client.py index 8b436d9..3c55d3c 100644 --- a/scripts/manual_client.py +++ b/scripts/manual_client.py @@ -149,6 +149,7 @@ def assemble_message(self, operation: str, **kwargs: Any) -> IntersectClientCall payload = DialInputSingleOtherStrategy( workflow_id=self.workflow_id, strategy='expected_improvement', + bounds=BOUNDS, ) elif operation == 'get_surrogate_values': payload = DialInputPredictions( @@ -179,20 +180,28 @@ def __call__( print(payload, file=sys.stderr) print(file=sys.stderr) raise Exception # noqa: TRY002 (break INTERSECT loop) + if operation == 'dial.initialize_workflow': self.workflow_id: str = payload return self.assemble_message('get_surrogate_values') + if operation == 'dial.update_workflow_with_data': return self.assemble_message('get_surrogate_values') - if ( - operation == 'dial.get_surrogate_values' - ): # if we receive a grid of surrogate values, record it for graphing, then ask for the next recommended point - self.mean_grid = np.array(payload[0]).reshape(XX.shape) + + if operation == 'dial.get_surrogate_values': + # if we receive a grid of surrogate values, record it for graphing, then ask for the next recommended point + data = payload['data'] + mean_grid = data[0] + self.mean_grid = np.array(mean_grid).reshape(XX.shape) return self.assemble_message('get_next_point') + if operation == 'dial.get_next_point': - # if we receive an EI recommendation, record it, show the user the current graph, and ask the user for the results of their experiment: - self.graph(payload) - return self.add_data(payload) + # if we receive an EI recommendation, record it, show the user the current graph, + # and ask the user for the results of their experiment: + data = payload['data'] + self.graph(data) + update_message = self.add_data(data) + return update_message err_msg = f'Unknown operation received: {operation}' raise Exception(err_msg) # noqa: TRY002 (INTERSECT interaction mechanism) From 4a7724aa18bac27051d14ff4fdb982867cc28152 Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Mon, 8 Jun 2026 19:03:06 -0400 Subject: [PATCH 14/22] re-enable extra_args length_per_dimension --- scripts/2d_rosenbrock_client.py | 1 - scripts/manual_client.py | 4 ++- src/dial_service/backends/sklearn_backend.py | 31 ++++++++++---------- tests/unit/test_internals.py | 13 -------- 4 files changed, 18 insertions(+), 31 deletions(-) diff --git a/scripts/2d_rosenbrock_client.py b/scripts/2d_rosenbrock_client.py index 8abcb5b..1e9ff0e 100644 --- a/scripts/2d_rosenbrock_client.py +++ b/scripts/2d_rosenbrock_client.py @@ -103,7 +103,6 @@ def assemble_message(self, operation: str, **kwargs: Any) -> IntersectClientCall 'constant_value': CONSTANT_VALUE, 'constant_value_bounds': 'fixed', }, - length_per_dimension=False, # allow the matern to use separate length scales for the two parameters y_is_good=False, # we wish to minimize y (the error) backend='sklearn', # "sklearn" or "gpax" seed=-1, # Use seed = -1 for random results diff --git a/scripts/manual_client.py b/scripts/manual_client.py index 3c55d3c..37295ff 100644 --- a/scripts/manual_client.py +++ b/scripts/manual_client.py @@ -135,7 +135,9 @@ def assemble_message(self, operation: str, **kwargs: Any) -> IntersectClientCall dataset_y=INITIAL_DATASET_Y, bounds=BOUNDS, kernel='matern', - length_per_dimension=True, # allow the matern to use separate length scales for the two parameters + extra_args={ + 'length_per_dimension': True + }, # allow the matern to use separate length scales for the two parameters y_is_good=False, # we wish to minimize y (the error) backend='sklearn', # "sklearn" or "gpax" seed=-1, # Use seed = -1 for random results diff --git a/src/dial_service/backends/sklearn_backend.py b/src/dial_service/backends/sklearn_backend.py index c4bab12..32e605c 100644 --- a/src/dial_service/backends/sklearn_backend.py +++ b/src/dial_service/backends/sklearn_backend.py @@ -48,27 +48,26 @@ def get_kernel(data): if kernel_name not in _KERNELS_SKLEARN: msg = f'Unknown kernel {kernel_name}' raise ValueError(msg) - _params = {} if data.kernel_args is None else data.kernel_args - - # TODO: test length_per_dimension and reenable - # if 'length_scale' not in _params: - # length_per_dimension = ( - # data.extra_args.get('length_per_dimension') if data.extra_args else False - # ) - # # TODO check if necessary - # # dim = data.X_train.shape[1] - # # _params['length_scale'] = [1.0] * dim if length_per_dimension else 1.0 - # _params['length_scale'] = [1.0] * data.dim_x if length_per_dimension else 1.0 + + _params = {} if data.kernel_args is None else data.kernel_args.copy() + + # if length_scale is not provided, but extra_args['length_per_direction'], + # configure a default learnable dimension dependent length_scale + if 'length_scale' not in _params: + length_per_dimension = ( + data.extra_args.get('length_per_dimension') if data.extra_args else False + ) + _params['length_scale'] = [1.0] * data.dim_x if length_per_dimension else 1.0 + _params['length_scale_bounds'] = (1e-05, 100000.0) base_kernel_cls = _KERNELS_SKLEARN[kernel_name] base_params = _filter_kwargs_for(base_kernel_cls, _params) # only do hyperparameter optimization if the user asks for it, use fixed defaults if base_kernel_cls == DotProduct: - base_params = {'sigma_0': 1.0, 'sigma_0_bounds': 'fixed'} + base_params = {'sigma_0': 1.0, 'sigma_0_bounds': 'fixed'} | base_params else: - base_params = {'length_scale': 1.0, 'length_scale_bounds': 'fixed'} - base_params.update(_filter_kwargs_for(base_kernel_cls, _params)) + base_params = {'length_scale': 1.0, 'length_scale_bounds': 'fixed'} | base_params base_kernel = base_kernel_cls(**base_params) # scale the prior variance by using a ConstantKernel @@ -96,9 +95,9 @@ def train_model(data): """Create a model with training.""" model = SklearnBackend.initialize_model(data) - # print(f"pre-training kernel: {SklearnBackend.get_kernel(data)}") + # print(f'pre-training kernel: {SklearnBackend.get_kernel(data)}') model.fit(data.X_train, data.Y_train) - # print(f"obtained trained kernel: {model.kernel_}") + # print(f'obtained trained kernel: {model.kernel_}') return model diff --git a/tests/unit/test_internals.py b/tests/unit/test_internals.py index 5cc5032..7fd1839 100644 --- a/tests/unit/test_internals.py +++ b/tests/unit/test_internals.py @@ -35,8 +35,6 @@ def single_1D(backend, strategy, strategy_args): kernel_args={ 'length_scale': 0.5, 'length_scale_bounds': 'fixed', - 'noise_level': 0.0, - 'noise_level_bounds': 'fixed', 'constant_value': 1.0, 'constant_value_bounds': 'fixed', }, @@ -65,8 +63,6 @@ def single_1D_discrete_grid(backend, strategy, strategy_args, discrete_measureme kernel_args={ 'length_scale': 0.5, 'length_scale_bounds': 'fixed', - 'noise_level': 0.0, - 'noise_level_bounds': 'fixed', 'constant_value': 1.0, 'constant_value_bounds': 'fixed', }, @@ -119,8 +115,6 @@ def single_2D(backend, strategy, strategy_args, discrete_measurement_grid_size=N kernel_args={ 'length_scale': 0.15, 'length_scale_bounds': 'fixed', - 'noise_level': 0.0, - 'noise_level_bounds': 'fixed', 'constant_value': 1.0, 'constant_value_bounds': 'fixed', }, @@ -173,15 +167,12 @@ def single_3D(backend, strategy, strategy_args, discrete_measurement_grid_size=N kernel_args={ 'length_scale': 0.15, 'length_scale_bounds': 'fixed', - 'noise_level': 0.0, - 'noise_level_bounds': 'fixed', 'constant_value': 1.0, 'constant_value_bounds': 'fixed', }, backend=backend, preprocess_standardize=True, y_is_good=True, - extra_args={'length_per_dimension': True}, seed=42, ) params = DialInputSingleOtherStrategy( @@ -203,7 +194,6 @@ def multiple_2D(backend, strategy, discrete_measurement_grid_size=None): dim_x=2, # provide dim_x for empty dataset y_is_good=False, kernel='rbf', - length_per_dimension=False, bounds=[[0, 100], [-1, 1]], backend=backend, seed=42, @@ -228,15 +218,12 @@ def prediction_1D(backend): kernel_args={ 'length_scale': 0.5, 'length_scale_bounds': 'fixed', - 'noise_level': 0.0, - 'noise_level_bounds': 'fixed', 'constant_value': 50.0**2, 'constant_value_bounds': 'fixed', }, backend=backend, preprocess_standardize=False, y_is_good=True, - extra_args={'length_per_dimension': True}, seed=42, ) params = DialInputPredictions( From f9961e3394ce90e3f87b30d80e73f5d21a2644bb Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Mon, 15 Jun 2026 18:58:23 -0400 Subject: [PATCH 15/22] do not rely on type comparision, but use name instead - there appears to be an interaction between pydantic validation and type comparisons that makes them not be stable. Use a str with name instead. --- src/dial_dataclass/dial_dataclass.py | 4 ++++ src/dial_service/backends/sable_backend.py | 9 +-------- src/dial_service/backends/sklearn_backend.py | 10 ++++------ 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/src/dial_dataclass/dial_dataclass.py b/src/dial_dataclass/dial_dataclass.py index 98bfdcb..4786c1a 100644 --- a/src/dial_dataclass/dial_dataclass.py +++ b/src/dial_dataclass/dial_dataclass.py @@ -29,6 +29,7 @@ class BaseDistribution(BaseModel, ABC): """Base class for a statistical distribution.""" + name: str loc: Annotated[ Label, Field( @@ -46,12 +47,15 @@ class BaseDistribution(BaseModel, ABC): class Delta(BaseDistribution): """The Delta distribution is deterministic and equal to its mean/loc.""" + name: str = Field(default='Delta', frozen=True) scale: float = Field(gt=0.0, lt=0.0, default=0.0, frozen=True) class Normal(BaseDistribution): """The normal distribution is determined by loc (mean) and scale (standard deviation).""" + name: str = Field(default='Normal', frozen=True) + Distribution = Annotated[Delta | Normal, Field(description='Union of all supported Distributions.')] diff --git a/src/dial_service/backends/sable_backend.py b/src/dial_service/backends/sable_backend.py index 1a712c4..ba637b5 100644 --- a/src/dial_service/backends/sable_backend.py +++ b/src/dial_service/backends/sable_backend.py @@ -3,8 +3,6 @@ import numpy as np from sable import DiscretizedSurrogateModel, ScaledRBFModel -from dial_dataclass import Normal - from ..utilities import strategies from . import AbstractBackend @@ -33,7 +31,7 @@ def _get_model_kwargs(data) -> dict: def _get_observation_errors(data, n_observations: int) -> np.ndarray: backend_args = {} if data.backend_args is None else data.backend_args - if isinstance(data.statistics_y, Normal): + if data.statistics_y.name == 'Normal': y_err = data.Yerr_train y_err_arr = np.asarray(y_err, dtype=float).reshape(-1) if y_err_arr.size == 1: @@ -47,11 +45,6 @@ def _get_observation_errors(data, n_observations: int) -> np.ndarray: if y_err_arr.size == 1: y_err_arr = np.full(n_observations, float(y_err_arr[0]), dtype=float) - if data.preprocess_standardize and len(data.Y_raw) > 0: - scale = np.std(np.asarray(data.Y_raw, dtype=float)) - if scale > 0: - y_err_arr = y_err_arr / scale - return y_err_arr diff --git a/src/dial_service/backends/sklearn_backend.py b/src/dial_service/backends/sklearn_backend.py index 32e605c..2d649fe 100644 --- a/src/dial_service/backends/sklearn_backend.py +++ b/src/dial_service/backends/sklearn_backend.py @@ -14,8 +14,6 @@ WhiteKernel, ) -from dial_dataclass import Normal - from ..utilities import strategies from . import AbstractBackend @@ -108,7 +106,7 @@ def initialize_model(data): # if the output statistics are a normal distribution, configure the alpha argument for sklearn _statistics_args = {} - if isinstance(data.statistics_y, Normal): + if data.statistics_y.name == 'Normal': # set alpha to the variance associated to Y_err_train y_variance_train = data.Yerr_train**2 _statistics_args['alpha'] = y_variance_train @@ -123,9 +121,9 @@ def initialize_model(data): # Process alpha as a numpy array _extra_args['alpha'] = np.asarray(_extra_args['alpha']) - # update alpha from statistics_args, if present - # TODO: should raise a warning, if already present, or if WhiteKernel is present - _extra_args.update(_statistics_args) + # update alpha from statistics_args, if present + # TODO: should raise a warning, if already present, or if WhiteKernel is present + _extra_args.update(_statistics_args) return GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=1000, **_extra_args) From 169fe8641b6af7f68e47517a9826275ad33594a0 Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Mon, 15 Jun 2026 19:03:40 -0400 Subject: [PATCH 16/22] updates to sinosidal_growth test - test with different backend and tweak hyperparameters --- scripts/1d_sinusoidal_growth_client.py | 46 +++++++++++++++++--------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/scripts/1d_sinusoidal_growth_client.py b/scripts/1d_sinusoidal_growth_client.py index 12f20c8..db8da94 100644 --- a/scripts/1d_sinusoidal_growth_client.py +++ b/scripts/1d_sinusoidal_growth_client.py @@ -23,6 +23,7 @@ DialInputSingleOtherStrategy, DialWorkflowCreationParamsClient, DialWorkflowDatasetUpdate, + Normal, ) mpl.use('agg') @@ -72,22 +73,36 @@ def __init__(self, service_destination: str): self.dataset_y = self.y_raw.reshape(-1).tolist() self.test_points = self.x_test.reshape(-1, 1).tolist() - # configure kernel_hyperparameters - self.kernel = 'rbf' - self.kernel_args = { - 'length_scale': 0.1, - 'constant_value': 2.0, - } - self.optimize_lengthscale = False - if self.optimize_lengthscale: - self.kernel_args.update( - { - 'length_scale_bounds': (0.02, 0.2), - } - ) + # Assume that there is some small noise in the measurements to stabilize the fit + self.statistics_y = Normal(loc='y', scale=1e-6) self.backend = 'sklearn' - self.backend_args = None + + if self.backend == 'sklearn': + # configure kernel_hyperparameters + self.kernel = 'rbf' + self.kernel_args = { + 'length_scale': 0.1, + 'constant_value': 1.0, + } + self.optimize_lengthscale = False + if self.optimize_lengthscale: + self.kernel_args.update( + { + 'length_scale_bounds': (0.02, 0.2), + } + ) + self.backend_args = {} + elif self.backend == 'sable': + self.kernel = 'rbf' + self.kernel_args = {'x_range': [-2.0, 2.0], 'sigma_range': [1.0e-3, 1.0], 'gamma': 0.1} + self.backend_args = { + 'n_features': 10000, + 'alpha': 0.0005, + 'p': 1.25, + 'n_iter_irls': 20, + } + self.strategy = 'upper_confidence_bound' self.strategy_args = {'exploit': 0.4, 'explore': 1} self.niter = 0 @@ -176,6 +191,7 @@ def callback_message(self, operation: str, **kwargs) -> IntersectClientCallback: strategy=self.strategy, strategy_args=self.strategy_args, bounds=self.bounds.tolist(), + y_is_good=True, ) elif operation == 'dial.update_workflow_with_data': @@ -278,7 +294,7 @@ def graph(self): axs[1].grid(True) plt.tight_layout() - plt.savefig('graph.png') + plt.savefig('graph_sinusoidal.png') plt.close(fig) From 5f3f0dd8c10df0091431500a6a5212fff770eccb Mon Sep 17 00:00:00 2001 From: Lance-Drane Date: Tue, 23 Jun 2026 12:32:38 -0400 Subject: [PATCH 17/22] fix Distribution pydantic typing Signed-off-by: Lance-Drane --- pyproject.toml | 2 +- src/dial_dataclass/dial_dataclass.py | 14 ++++++++---- uv.lock | 33 +++++++++++++++------------- 3 files changed, 29 insertions(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c74e6db..b36ae29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ license = { text = "BSD-3-Clause" } classifiers = ["Programming Language :: Python :: 3"] # TODO move some dependencies into optional dependencies dependencies = [ - "intersect_sdk>=0.9.0,<0.10.0", + "intersect_sdk>=0.9.3,<0.10.0", "numpy", "scikit-learn>=1.4.0,<2.0.0", # TODO consider making an optional dependency group "scipy>=1.12.0,<2.0.0", diff --git a/src/dial_dataclass/dial_dataclass.py b/src/dial_dataclass/dial_dataclass.py index 04a8417..485f752 100644 --- a/src/dial_dataclass/dial_dataclass.py +++ b/src/dial_dataclass/dial_dataclass.py @@ -47,17 +47,23 @@ class BaseDistribution(BaseModel, ABC): class Delta(BaseDistribution): """The Delta distribution is deterministic and equal to its mean/loc.""" - name: str = Field(default='Delta', frozen=True) - scale: float = Field(gt=0.0, lt=0.0, default=0.0, frozen=True) + name: Literal['Delta'] = Field(default='Delta', frozen=True) + scale: float = Field(ge=0.0, le=0.0, default=0.0, frozen=True) class Normal(BaseDistribution): """The normal distribution is determined by loc (mean) and scale (standard deviation).""" - name: str = Field(default='Normal', frozen=True) + name: Literal['Normal'] = Field(default='Normal', frozen=True) -Distribution = Annotated[Delta | Normal, Field(description='Union of all supported Distributions.')] +Distribution = Annotated[ + Delta | Normal, + Field( + description='Union of all supported Distributions.', + discriminator='name', + ), +] def _validate_dataset_lengths(dataset: list[Any]) -> bool: diff --git a/uv.lock b/uv.lock index bc5466a..0e595b0 100644 --- a/uv.lock +++ b/uv.lock @@ -697,7 +697,7 @@ dev = [ requires-dist = [ { name = "furo", marker = "extra == 'docs'", specifier = ">=2023.3.27" }, { name = "gpax", specifier = ">=0.1.8" }, - { name = "intersect-sdk", specifier = ">=0.9.0,<0.10.0" }, + { name = "intersect-sdk", specifier = ">=0.9.3,<0.10.0" }, { name = "numpy" }, { name = "numpyro", specifier = "<0.20.1" }, { name = "pymongo", specifier = ">=4.12.1" }, @@ -1102,20 +1102,32 @@ wheels = [ [[package]] name = "intersect-sdk" -version = "0.9.0" +version = "0.9.3" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "intersect-sdk-common" }, { name = "jsonschema", extra = ["format-nongpl"] }, + { name = "psutil" }, + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ad/4a/4197405ffa4d42913d3358cf34d1b0ec15a841c994c5931b3785a975a86a/intersect_sdk-0.9.3.tar.gz", hash = "sha256:b0f563ae161f7851c1281fd4023222ce6509af33fae371f776470d1767dcfb7c", size = 54468, upload-time = "2026-06-23T16:02:18.534Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f0/53/a2fb84ca204d3b794c6c9e93487d38169a075b9422ca27d5b95a794528c2/intersect_sdk-0.9.3-py3-none-any.whl", hash = "sha256:d02b87b07134de96e060054750cd82445334a573f5df34a451d4d786118c904b", size = 66750, upload-time = "2026-06-23T16:02:17.437Z" }, +] + +[[package]] +name = "intersect-sdk-common" +version = "0.9.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ { name = "minio" }, { name = "paho-mqtt" }, { name = "pika" }, - { name = "psutil" }, { name = "pydantic" }, - { name = "retrying" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/26/4e/8671bcef8b222626aad9a615328c8b8823d577919d26f78b095018c0d715/intersect_sdk-0.9.0.tar.gz", hash = "sha256:76413ea6307e58b1e664a391e10cc2e31bcb899c4fee4e5e6f8c640fcebbd244", size = 108556, upload-time = "2026-02-11T19:32:24.078Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/65/83884f05d52241d0fe3362baaca4b08bfd447b9360daeb52dcfa9bfa8b63/intersect_sdk_common-0.9.6.tar.gz", hash = "sha256:b46e2a155f9f73c2fa621c0a1c212f97dd90f53f856e61d30faa5dc7986f7ecd", size = 32223, upload-time = "2026-06-02T21:23:19.529Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/dd/b7/4ce6222b6bffcae9413115298caea7b37c5b3b97f68746a27b4bc06ef9ff/intersect_sdk-0.9.0-py3-none-any.whl", hash = "sha256:6aa3c05d053c54b0f7c314bc8c618952a1c5c5d91427cd86e6c5c4851a2a60f3", size = 100771, upload-time = "2026-02-11T19:32:22.506Z" }, + { url = "https://files.pythonhosted.org/packages/0a/2a/3085cce8374b733c0255cc1d35889d42d1452bf33e681a996754bec77ebb/intersect_sdk_common-0.9.6-py3-none-any.whl", hash = "sha256:8947769d54eab04c4da3adbb403f0c224c07faf76a45f98e372ae112130adaf8", size = 45082, upload-time = "2026-06-02T21:23:20.407Z" }, ] [[package]] @@ -2992,15 +3004,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/f4/c67b0b3f1b9245e8d266f0f112c500d50e5b4e83cb6f3b71b6528104182a/requests-2.34.2-py3-none-any.whl", hash = "sha256:2a0d60c172f83ac6ab31e4554906c0f3b3588d37b5cb939b1c061f4907e278e0", size = 73075, upload-time = "2026-05-14T19:25:26.443Z" }, ] -[[package]] -name = "retrying" -version = "1.4.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c8/5a/b17e1e257d3e6f2e7758930e1256832c9ddd576f8631781e6a072914befa/retrying-1.4.2.tar.gz", hash = "sha256:d102e75d53d8d30b88562d45361d6c6c934da06fab31bd81c0420acb97a8ba39", size = 11411, upload-time = "2025-08-03T03:35:25.189Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/67/f3/6cd296376653270ac1b423bb30bd70942d9916b6978c6f40472d6ac038e7/retrying-1.4.2-py3-none-any.whl", hash = "sha256:bbc004aeb542a74f3569aeddf42a2516efefcdaff90df0eb38fbfbf19f179f59", size = 10859, upload-time = "2025-08-03T03:35:23.829Z" }, -] - [[package]] name = "rfc3339-validator" version = "0.1.4" From 4d010202b9401644bd8332ef405a9df9c5e152c1 Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Wed, 24 Jun 2026 13:05:35 -0400 Subject: [PATCH 18/22] remove transformed_stddevs and use the regular stddevs - clean up return values, now that the distinction between stddevs, and transformed stddevs is no longer necesssary --- scripts/1d_sable_client.py | 8 +++----- scripts/1d_sinusoidal_growth_client.py | 8 +++----- scripts/2d_rosenbrock_client.py | 2 +- src/dial_dataclass/dial_dataclass_responses.py | 10 ++++------ src/dial_service/core.py | 14 +++++++------- src/dial_service/dial_service.py | 18 +++++++++--------- tests/benchmarks/test_strainmap.py | 2 +- tests/unit/test_internals.py | 9 +++------ 8 files changed, 31 insertions(+), 40 deletions(-) diff --git a/scripts/1d_sable_client.py b/scripts/1d_sable_client.py index 1f4bb4d..727365b 100644 --- a/scripts/1d_sable_client.py +++ b/scripts/1d_sable_client.py @@ -220,14 +220,12 @@ def callback_message(self, operation: str, **kwargs) -> IntersectClientCallback: def handle_surrogate_values(self, payload): means = payload['values'] - transformed_stddevs = payload['transformed_stddevs'] + stddevs = payload['stddevs'] if self.at_grids: - self.stddev_grid = np.array(transformed_stddevs).reshape( - (self.meshgrid_size,) * self.num_dims - ) + self.stddev_grid = np.array(stddevs).reshape((self.meshgrid_size,) * self.num_dims) self.mean_grid = np.array(means).reshape((self.meshgrid_size,) * self.num_dims) else: - self.stddev_test = np.array(transformed_stddevs) + self.stddev_test = np.array(stddevs) self.mean_test = np.array(means) print( f'Values at testing points {self.x_test.reshape(-1)}: Mean: {self.mean_test}, Stddev: {self.stddev_test}' diff --git a/scripts/1d_sinusoidal_growth_client.py b/scripts/1d_sinusoidal_growth_client.py index 8ba9294..f141d0f 100644 --- a/scripts/1d_sinusoidal_growth_client.py +++ b/scripts/1d_sinusoidal_growth_client.py @@ -225,14 +225,12 @@ def callback_message(self, operation: str, **kwargs) -> IntersectClientCallback: def handle_surrogate_values(self, payload): means = payload['values'] - transformed_stddevs = payload['transformed_stddevs'] + stddevs = payload['stddevs'] if self.at_grids: - self.stddev_grid = np.array(transformed_stddevs).reshape( - (self.meshgrid_size,) * self.num_dims - ) + self.stddev_grid = np.array(stddevs).reshape((self.meshgrid_size,) * self.num_dims) self.mean_grid = np.array(means).reshape((self.meshgrid_size,) * self.num_dims) else: - self.stddev_test = np.array(transformed_stddevs) + self.stddev_test = np.array(stddevs) self.mean_test = np.array(means) print(f'Test Mean: {self.mean_test}, Std.Dev.: {self.stddev_test}') diff --git a/scripts/2d_rosenbrock_client.py b/scripts/2d_rosenbrock_client.py index f0a4d46..815e2e9 100644 --- a/scripts/2d_rosenbrock_client.py +++ b/scripts/2d_rosenbrock_client.py @@ -266,7 +266,7 @@ def graph(self, x_EI: list[float], final: bool = False): marker='o', s=300, ) - plt.savefig('graph.png') + plt.savefig('graph_rosenbrock.png') else: fig, ax = plt.subplots(figsize=(8, 6)) message = ( diff --git a/src/dial_dataclass/dial_dataclass_responses.py b/src/dial_dataclass/dial_dataclass_responses.py index b7cc4f8..cfcd806 100644 --- a/src/dial_dataclass/dial_dataclass_responses.py +++ b/src/dial_dataclass/dial_dataclass_responses.py @@ -43,10 +43,8 @@ class DialSurrogateValuesResponse(BaseModel): values: list[float] """The computed values (for example, from Gaussian backends, the means) from calling get_surrogate_values()""" - transformed_stddevs: list[float] - """The computed uncertainties from calling get_surrogate_values(), with an inverse transform. If inverse-transforming is not possible (due to log-preprocessing), this will be all -1""" - stddevs: list[float] # TODO will probably remove in future - """The computed raw uncertainties from calling get_surrogate_values(), without an inverse transform""" + stddevs: list[float] + """The computed uncertainties from calling get_surrogate_values(), with an inverse transform.""" dim_x: int """Number of dimensions of the associated data, derived from workflow""" bounds: list[list[float]] @@ -57,5 +55,5 @@ class DialSurrogateValuesResponse(BaseModel): """The same workflow ID that was used to get the data, to facilitate possible load balancing.""" dataset_x_size: int """Current length of dataset_x""" - transformed_stddevs_avg: float - """the average of the transformed stddevs being returned""" + stddevs_avg: float + """The average of the transformed stddevs being returned""" diff --git a/src/dial_service/core.py b/src/dial_service/core.py index 80f08df..d7b616d 100644 --- a/src/dial_service/core.py +++ b/src/dial_service/core.py @@ -102,25 +102,25 @@ def get_surrogate_values( """ Get surrogate model predictions for given input points. - Model parameter should be a pretrained model, you can usually call core.train_model with the same data parameter if you don't yet have a model. + Model parameter should be a pretrained model, + you can usually call core.train_model with the same data parameter if you don't yet have a model. Args: client_data (DialInputPredictions): Input data containing prediction points and model parameters. Returns: - tuple[list[float], list[float], list[float], float]: A tuple containing means, transformed standard deviations, raw standard deviations, and a float value. + tuple[list[float], list[float], float]: A tuple containing means, standard deviations, + standard deviations, and average standard deviation. """ backend = data.backend.lower() module = get_backend_module(backend) means, stddevs = module.predict(model, data) - means, transformed_stddevs = data.inverse_transform_Y(means, stddevs) - average = np.sqrt(np.mean(np.asarray(transformed_stddevs) ** 2)) - # TODO: the third return argument is not needed. + means, stddevs = data.inverse_transform_Y(means, stddevs) + average_stddev = np.sqrt(np.mean(np.asarray(stddevs) ** 2)) return ( means.tolist(), - transformed_stddevs.tolist(), stddevs.tolist(), - float(average), + float(average_stddev), ) diff --git a/src/dial_service/dial_service.py b/src/dial_service/dial_service.py index 3dc43bb..c8f89fe 100644 --- a/src/dial_service/dial_service.py +++ b/src/dial_service/dial_service.py @@ -308,10 +308,11 @@ def get_next_points(self, client_data: DialInputMultiple) -> DialDataResponse2D: def get_surrogate_values( self, client_data: DialInputPredictions ) -> DialSurrogateValuesResponse: - """Trains a model then returns 3 lists based on user-supplied points: - - Predicted values. These are inverse transformed (undoing the preprocessing to put them on the same scale as dataset_y) - - Inverse-transformed uncertainties. If inverse-transforming is not possible (due to log-preprocessing), this will be all -1 - - Uncertainties without inverse transformation + """Trains a model then returns two lists based on user-supplied points: + - Predicted values. + These are inverse transformed (undoing the preprocessing to put them on the same scale as dataset_y) + - Uncertainties. + These are inverse transformed standard errors, transformed according to the differential of the transform. Additional metadata is also returned in the response. """ @@ -339,17 +340,16 @@ def get_surrogate_values( validated_state.extra_args = client_data.extra_args data = ServersideInputPrediction(validated_state, client_data) - return_data = core.get_surrogate_values(data, model) + means, stddevs, average_stddev = core.get_surrogate_values(data, model) return DialSurrogateValuesResponse( - values=return_data[0], - transformed_stddevs=return_data[1], - stddevs=return_data[2], + values=means, + stddevs=stddevs, dim_x=validated_state.dim_x, points_to_predict=client_data.points_to_predict, bounds=validated_state.bounds, workflow_id=client_data.workflow_id, dataset_x_size=len(validated_state.dataset_x), - transformed_stddevs_avg=return_data[3], + stddevs_avg=average_stddev, ) except Exception as err: logger.exception( diff --git a/tests/benchmarks/test_strainmap.py b/tests/benchmarks/test_strainmap.py index 0b3cd4b..2bca51c 100644 --- a/tests/benchmarks/test_strainmap.py +++ b/tests/benchmarks/test_strainmap.py @@ -255,7 +255,7 @@ def run_simulation( points_to_predict=INITIAL_POINTS_TO_PREDICT, ), ) - surrogate_mean, surrogate_std, _, _ = dial_core.get_surrogate_values(data, model) + surrogate_mean, surrogate_std, _ = dial_core.get_surrogate_values(data, model) mean_grid = np.array(surrogate_mean).reshape((-1, 1)) # subtract the true values and save mean absolute error and standard deviation diff --git a/tests/unit/test_internals.py b/tests/unit/test_internals.py index 38a72c7..5d67c9b 100644 --- a/tests/unit/test_internals.py +++ b/tests/unit/test_internals.py @@ -605,7 +605,7 @@ def test_hypercube_multiple_points(backend): @pytest.mark.parametrize( - ('backend', 'expected_means', 'expected_stddevs', 'expected_raw_stddevs'), + ('backend', 'expected_means', 'expected_stddevs'), [ ( 'sklearn', @@ -617,7 +617,6 @@ def test_hypercube_multiple_points(backend): 199.99999999, ], [2.11126987e01, 2.96625069e01, 2.11126987e01], - [21.11269870647274, 29.662506906581378, 21.112698706472752], ), # ( # 'gpax', @@ -629,17 +628,15 @@ def test_hypercube_multiple_points(backend): # 82.26569221517353, # ], # [3335.7290084812175, 3327.202331393974, 3335.7290084812175], - # [3335.7290084812175, 3327.202331393974, 3335.7290084812175], # ), ], ) -def test_surrogate(backend, expected_means, expected_stddevs, expected_raw_stddevs): +def test_surrogate(backend, expected_means, expected_stddevs): data = prediction_1D(backend) model = core.train_model(data) - means, stddevs, raw_stddevs, _ = core.get_surrogate_values(data, model) + means, stddevs, _ = core.get_surrogate_values(data, model) assert means == pytest.approx(expected_means) assert stddevs[1:4] == pytest.approx(expected_stddevs) - assert raw_stddevs[1:4] == pytest.approx(expected_raw_stddevs) @pytest.mark.parametrize( From 55630190f14dc335c213678351be7532011fefc8 Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Wed, 24 Jun 2026 16:10:15 -0400 Subject: [PATCH 19/22] fix and document all currently supported strategies - fix sign mistake in EI and UncertaintyBound --- src/dial_dataclass/dial_dataclass.py | 11 +++++++++- src/dial_service/utilities/strategies.py | 28 +++++++++++++++++++----- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/dial_dataclass/dial_dataclass.py b/src/dial_dataclass/dial_dataclass.py index 485f752..feef85b 100644 --- a/src/dial_dataclass/dial_dataclass.py +++ b/src/dial_dataclass/dial_dataclass.py @@ -364,12 +364,21 @@ class DialInputSingleConfidenceBound(BaseModel): Field(min_length=2, max_length=2), ] ] + seed: Annotated[ + int, + Field( + default=-1, + ge=-1, + le=4294967295, + description='Specific RNG seed - use -1 to use system default', + ), + ] extra_args: dict[str, float | int | bool | str | list[float] | tuple] | None = Field( default=None ) """These extra arguments will be MERGED with the saved extra_args, with these arguments taking place over the saved values when applicable.""" optimization_points: PositiveIntType = Field(default=1000) - confidence_bound: float = Field(gt=0.5, lt=1) + confidence_bound: float = Field(gt=0.5, lt=1.0) discrete_measurements: bool = Field(default=False) discrete_measurement_grid_size: list[PositiveIntType] = Field(default=[20, 20]) point_index: Annotated[ diff --git a/src/dial_service/utilities/strategies.py b/src/dial_service/utilities/strategies.py index e0b42be..56899ac 100644 --- a/src/dial_service/utilities/strategies.py +++ b/src/dial_service/utilities/strategies.py @@ -16,10 +16,14 @@ def random_in_bounds(bounds: list[list[float]], rng: np.random.RandomState): def uncertainty_sampling(_mean, stddev, _data): + """Measure of uncertainty (stddev) for maximization""" return stddev def upper_confidence_bound(mean, stddev, data): + """Upper confidence bound for maximization + If y_is_good = False, multiply mean by -1. + """ _params = data.strategy_args y_is_good = data.y_is_good _direction = 1 if y_is_good else -1 @@ -31,6 +35,10 @@ def upper_confidence_bound(mean, stddev, data): def upper_confidence_bound_nomad(mean, stddev, data): + """Upper confidence bound (NOMAD specific version) for maximization + Masks the values around the last measurement point to force exploration. + If y_is_good = False, multiply mean by -1. + """ _params = data.strategy_args y_is_good = data.y_is_good _direction = 1 if y_is_good else -1 @@ -49,20 +57,30 @@ def upper_confidence_bound_nomad(mean, stddev, data): def expected_improvement(mean, stddev, data): + """Expected Improvement (EI) for maximization + If y_is_good = False, multiply mean and data value by -1. + """ _params = data.strategy_args y_is_good = data.y_is_good + _direction = 1 if y_is_good else -1 - if stddev < 1e-8: - return 0.0 - z = (mean - data.Y_best) / stddev * (1 if y_is_good else -1) - return -stddev * (z * norm.cdf(z) + norm.pdf(z)) + # guard against small or negative stddev + stddev = np.maximum(stddev, 1e-15) + + z = (mean - data.Y_best) / stddev * _direction + return stddev * (z * norm.cdf(z) + norm.pdf(z)) def confidence_bound(mean, stddev, data): + """Confidence bound for maximization + The same as upper_confidence_bound with exploit = 1., explore = norm.ppf(0.5 + data.confidence_bound / 2) + If y_is_good = False, multiply mean by -1. + """ y_is_good = data.y_is_good + _direction = 1 if y_is_good else -1 z_value = norm.ppf(0.5 + data.confidence_bound / 2) - return -z_value * stddev + mean * (-1 if y_is_good else 1) + return _direction * mean + z_value * stddev STRATEGIES = { From e0653f97e3b863f76276a347d0bfa239bcfaeade Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Wed, 24 Jun 2026 20:54:41 -0400 Subject: [PATCH 20/22] fix unit test for fixed EI criterion - increase number of greedy restarts for get_next_point, and add debug --- src/dial_dataclass/dial_dataclass.py | 1 - src/dial_service/utilities/strategies.py | 13 +++++++++---- tests/unit/test_internals.py | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/dial_dataclass/dial_dataclass.py b/src/dial_dataclass/dial_dataclass.py index feef85b..d67f2dd 100644 --- a/src/dial_dataclass/dial_dataclass.py +++ b/src/dial_dataclass/dial_dataclass.py @@ -168,7 +168,6 @@ class _DialWorkflowCreationParams(BaseModel): list[ float | Annotated[ - # TODO: this could be the default list[float], Field(description='Field lengths of all subarrays should be equal'), ] diff --git a/src/dial_service/utilities/strategies.py b/src/dial_service/utilities/strategies.py index 56899ac..51d155a 100644 --- a/src/dial_service/utilities/strategies.py +++ b/src/dial_service/utilities/strategies.py @@ -140,19 +140,18 @@ def to_minimize(_x: np.ndarray): if data.discrete_measurements: _measurement_grid = create_measurement_grid(data) - # TODO - commented line is known to fail for expected_improvement regarding discrete measurements - # response_surface = to_minimize(_measurement_grid) - response_surface = [to_minimize(point) for point in _measurement_grid] + response_surface = to_minimize(_measurement_grid) index = np.int64(np.argmin(response_surface)) selected_point = _measurement_grid[index] logger.debug('selected point with discrete measurements') logger.debug(selected_point) return selected_point - n_restarts = 10 + n_restarts = 25 init_array = np.array(hypercube(data.bounds, n_restarts, data.numpy_rng)) best_score = np.inf selected_point = None + # out_list = [] for x_init in init_array: res = minimize( to_minimize, @@ -164,6 +163,12 @@ def to_minimize(_x: np.ndarray): if res.fun < best_score: best_score = res.fun selected_point = res.x + # out_list.append((x_init.tolist(), res.x.tolist(), res.fun)) + + logger.debug('selected point with optimization') + logger.debug('score and point: %f, %s', best_score, str(selected_point)) + # print(f'optimized: {best_score}, {selected_point}:', '\n', + # '\n'.join([str(out) for out in out_list])) return selected_point.tolist() diff --git a/tests/unit/test_internals.py b/tests/unit/test_internals.py index 5d67c9b..a84d894 100644 --- a/tests/unit/test_internals.py +++ b/tests/unit/test_internals.py @@ -392,7 +392,7 @@ def test_uncertainty(backend, approx): @pytest.mark.parametrize( ('backend', 'approx'), [ - ('sklearn', [1.037454]), + ('sklearn', [1.790396262]), # ('gpax', [2.0]), ], ) From d36c7658f1af29324a4efd36cc153a0ca0e46c8f Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Wed, 24 Jun 2026 22:01:13 -0400 Subject: [PATCH 21/22] update clients to test more settings - test using different acquisition strategies for rosenbrock in different iterations --- scripts/1d_sinusoidal_growth_client.py | 19 ++++--- scripts/2d_rosenbrock_client.py | 74 ++++++++++++++++++++------ 2 files changed, 67 insertions(+), 26 deletions(-) diff --git a/scripts/1d_sinusoidal_growth_client.py b/scripts/1d_sinusoidal_growth_client.py index f141d0f..756f4e8 100644 --- a/scripts/1d_sinusoidal_growth_client.py +++ b/scripts/1d_sinusoidal_growth_client.py @@ -53,14 +53,14 @@ def __init__(self): class ActiveLearningOrchestrator: def __init__(self, service_destination: str): - self.bounds = np.array([[-2.0, 2.0]]) + self.bounds = [[-2.0, 2.0]] self.num_dims = len(self.bounds) self.x_raw = np.array([[1.0], [2.0]]) self.x_test = np.array([[-1.0], [0.5]]) self.y_raw = sinusoidal_growth(self.x_raw) - self.meshgrid_size = 100 + self.meshgrid_size = 200 self.grid_points = [ np.linspace(dim_bounds[0], dim_bounds[1], self.meshgrid_size) for dim_bounds in self.bounds @@ -77,10 +77,9 @@ def __init__(self, service_destination: str): self.statistics_y = Normal(loc='y', scale=1e-6) self.backend = 'sklearn' - if self.backend == 'sklearn': # configure kernel_hyperparameters - self.kernel = 'rbf' + self.kernel = 'matern' # 'rbf' or 'matern' self.kernel_args = { 'length_scale': 0.1, 'constant_value': 1.0, @@ -129,16 +128,17 @@ def __call__( _has_error: bool, payload: INTERSECT_RESPONSE_VALUE, ) -> IntersectClientCallback: - print( - f'Received message from {_source} with operation {operation} and payload {payload}', - file=sys.stderr, - ) if _has_error: print('============ERROR==============', file=sys.stderr) print(operation, file=sys.stderr) print(payload, file=sys.stderr) raise IntersectCallbackError(operation, payload) + # print( + # f'Received message from {_source} with operation {operation} and payload {payload}', + # file=sys.stderr, + # ) + if operation == 'dial.initialize_workflow': self.workflow_id = payload return self.callback_message('dial.get_surrogate_values') @@ -179,7 +179,6 @@ def callback_message(self, operation: str, **kwargs) -> IntersectClientCallback: backend_args=self.backend_args, preprocess_standardize=True, y_is_good=True, - seed=20, ) elif operation == 'dial.get_surrogate_values': @@ -198,7 +197,7 @@ def callback_message(self, operation: str, **kwargs) -> IntersectClientCallback: workflow_id=self.workflow_id, strategy=self.strategy, strategy_args=self.strategy_args, - bounds=self.bounds.tolist(), + bounds=self.bounds, y_is_good=True, ) diff --git a/scripts/2d_rosenbrock_client.py b/scripts/2d_rosenbrock_client.py index 815e2e9..9be7e92 100644 --- a/scripts/2d_rosenbrock_client.py +++ b/scripts/2d_rosenbrock_client.py @@ -3,6 +3,7 @@ import logging import os import sys +from collections import deque from pathlib import Path from typing import Any @@ -22,6 +23,7 @@ # from scipy.stats import qmc from dial_dataclass import ( DialInputPredictions, + DialInputSingleConfidenceBound, DialInputSingleOtherStrategy, DialWorkflowCreationParamsClient, DialWorkflowDatasetUpdate, @@ -65,11 +67,11 @@ def generate_dataset_x(num_dims): ) INITIAL_POINTS_TO_PREDICT = np.hstack([mg.reshape(-1, 1) for mg in INITIAL_MESHGRIDS]).tolist() -NUM_ITERATIONS = 35 +NUM_ITERATIONS = 200 # HYPERPARAMETERS -LENGTH_SCALE = 0.2 -NOISE_LEVEL = 10e-6 +LENGTH_SCALE = 0.1 +NOISE_LEVEL = 10e-8 CONSTANT_VALUE = 1.0 @@ -86,6 +88,23 @@ def __init__(self, service_destination: str, rosenbrock_destination: str): self.dataset_x = INITIAL_DATASET_X self.dataset_y: list[float] = [] + # we want to minimize + self.y_is_good = False + + # Initialize a deque of strategies with desired number of iterations + self.strategies = deque( + [ + (20, 'uncertainty', {}), + (40, 'upper_confidence_bound', {'exploit': 1.0, 'explore': 0.5}), + (40, 'confidence_bound', {'confidence_bound': 0.9}), + (200, 'expected_improvement', {}), + ] + ) + + # if we want to test discrete measurements + self.discrete_measurements = False + self.discrete_measurement_grid_size = [200, 200] + # create a message to send to the server def assemble_message(self, operation: str, **kwargs: Any) -> IntersectClientCallback: if operation == 'initialize_workflow': @@ -103,8 +122,9 @@ def assemble_message(self, operation: str, **kwargs: Any) -> IntersectClientCall 'constant_value': CONSTANT_VALUE, 'constant_value_bounds': 'fixed', }, - y_is_good=False, # we wish to minimize y (the error) - backend='sklearn', # "sklearn" or "gpax" + preprocess_standardize=True, + y_is_good=self.y_is_good, # we wish to minimize y (the error) + backend='sklearn', seed=-1, # Use seed = -1 for random results ) elif operation == 'update_workflow_with_data': @@ -113,11 +133,31 @@ def assemble_message(self, operation: str, **kwargs: Any) -> IntersectClientCall **kwargs, ) elif operation == 'get_next_point': - payload = DialInputSingleOtherStrategy( - workflow_id=self.workflow_id, - strategy='expected_improvement', - bounds=INITIAL_BOUNDS, - ) + # select strategy + Niter, strategy, strategy_args = self.strategies.popleft() + Niter -= 1 + if Niter > 0: + self.strategies.appendleft((Niter, strategy, strategy_args)) + if strategy == 'confidence_bound': + payload = DialInputSingleConfidenceBound( + workflow_id=self.workflow_id, + bounds=INITIAL_BOUNDS, + y_is_good=self.y_is_good, # we wish to minimize y (the error) + strategy='confidence_bound', + confidence_bound=strategy_args['confidence_bound'], + discrete_measurements=self.discrete_measurements, + discrete_measurement_grid_size=self.discrete_measurement_grid_size, + ) + else: + payload = DialInputSingleOtherStrategy( + workflow_id=self.workflow_id, + bounds=INITIAL_BOUNDS, + y_is_good=self.y_is_good, # we wish to minimize y (the error) + strategy=strategy, + strategy_args=strategy_args, + discrete_measurements=self.discrete_measurements, + discrete_measurement_grid_size=self.discrete_measurement_grid_size, + ) elif operation == 'get_surrogate_values': payload = DialInputPredictions( workflow_id=self.workflow_id, @@ -178,7 +218,9 @@ def __call__( if operation == 'Rosenbrock.rosenbrock': # this operation gets called periodically self.dataset_y.append(payload) - print(f'{payload:.3f}') + coord_str = ', '.join([f'{x:.2f}' for x in self.dataset_x[-1]]) + strategy = self.strategies[0][1] + print(f'got value {payload:.5f} at [{coord_str}] with strategy {strategy}') if len(self.dataset_x) == NUM_ITERATIONS: minpos = np.argmin(self.dataset_y) y_opt = self.dataset_y[minpos] @@ -186,7 +228,7 @@ def __call__( self.graph(optimal_coords, True) coord_str = ', '.join([f'{coord:.2f}' for coord in optimal_coords]) print( - f'Optimal simulated datapoint at ({coord_str}), y={y_opt:.3f}', + f'Optimal simulated datapoint at ({coord_str}), y={y_opt:.5f}', end='\n', flush=True, ) @@ -244,8 +286,8 @@ def graph(self, x_EI: list[float], final: bool = False): plt.ylabel('Simulation Parameter #2') # add black dots for data points and a red marker for the recommendation: X_train = np.array(self.dataset_x) - plt.scatter(X_train[:, 0], X_train[:, 1], color='black', marker='o') - plt.scatter(1.0, 1.0, s=300, color='None', edgecolors='black', marker='o') + plt.scatter(X_train[:, 0], X_train[:, 1], color='black', marker='.') + plt.scatter(1.0, 1.0, s=300, color='None', edgecolors='tab:orange', marker='o') minpos = np.argmin(self.dataset_y) optimal_coords = self.dataset_x[minpos] @@ -257,7 +299,7 @@ def graph(self, x_EI: list[float], final: bool = False): f'Best point estimate so far is x=({final_x}), y={self.dataset_y[minpos]:.3f}' ) else: - plt.scatter([x_EI[0]], [x_EI[1]], color='red', marker='o') + plt.scatter([x_EI[0]], [x_EI[1]], color='tab:red', marker='o') plt.scatter( [x_EI[0]], [x_EI[1]], @@ -266,7 +308,7 @@ def graph(self, x_EI: list[float], final: bool = False): marker='o', s=300, ) - plt.savefig('graph_rosenbrock.png') + plt.savefig('graph_rosenbrock.png', dpi=200) else: fig, ax = plt.subplots(figsize=(8, 6)) message = ( From 99e4e71d9d68c36a442343d881a10a21b8625bc4 Mon Sep 17 00:00:00 2001 From: Konstantin Pieper Date: Wed, 24 Jun 2026 22:45:40 -0400 Subject: [PATCH 22/22] make DialInput repeated arguments from DialWorkflow optional - if the repeat arguments bounds, y_is_good, extra_args are not None, they will replace the stored values in the initial workflow initialization --- src/dial_dataclass/dial_dataclass.py | 78 +++++++++++++++++++--------- src/dial_service/dial_service.py | 12 ----- src/dial_service/serverside_data.py | 42 ++++++++++++--- 3 files changed, 87 insertions(+), 45 deletions(-) diff --git a/src/dial_dataclass/dial_dataclass.py b/src/dial_dataclass/dial_dataclass.py index d67f2dd..8f45ebc 100644 --- a/src/dial_dataclass/dial_dataclass.py +++ b/src/dial_dataclass/dial_dataclass.py @@ -347,21 +347,31 @@ def validate_dims_and_length(self): class DialInputSingleConfidenceBound(BaseModel): + """This class is used to request a single next point using the confidence_bound strategy.""" + workflow_id: ValidatedObjectId strategy: Literal['confidence_bound'] strategy_args: dict[str, float | int | bool] | None = Field(default=None) y_is_good: Annotated[ - bool, + bool | None, Field( - default=True, # <-- Set default here - description='If true, treat higher y values as better (e.g. y represents yield or profit). If false, opposite (e.g. y represents error or waste)', + default=None, + description=( + 'If true, treat higher y values as better (e.g. y represents yield or profit).' + ' If false, opposite (e.g. y represents error or waste).' + ' If None, the value from DialWorkflowCreationParams is used.' + ), ), ] - bounds: list[ - Annotated[ - Annotated[list[float], Field(min_length=2, max_length=2)], - Field(min_length=2, max_length=2), - ] + bounds: Annotated[ + None + | list[ + Annotated[ + Annotated[list[float], Field(min_length=2, max_length=2)], + Field(min_length=2, max_length=2), + ] + ], + Field(default=None), ] seed: Annotated[ int, @@ -390,6 +400,8 @@ class DialInputSingleConfidenceBound(BaseModel): class DialInputSingleOtherStrategy(BaseModel): + """This class is used to request a single next point using a given strategy.""" + workflow_id: ValidatedObjectId strategy: Literal[ 'random', @@ -402,17 +414,25 @@ class DialInputSingleOtherStrategy(BaseModel): ] strategy_args: dict[str, float | int | bool] | None = Field(default=None) y_is_good: Annotated[ - bool, + bool | None, Field( - default=True, # <-- Set default here - description='If true, treat higher y values as better (e.g. y represents yield or profit). If false, opposite (e.g. y represents error or waste)', + default=None, + description=( + 'If true, treat higher y values as better (e.g. y represents yield or profit).' + ' If false, opposite (e.g. y represents error or waste).' + ' If None, the value from DialWorkflowCreationParams is used.' + ), ), ] - bounds: list[ - Annotated[ - Annotated[list[float], Field(min_length=2, max_length=2)], - Field(min_length=2, max_length=2), - ] + bounds: Annotated[ + None + | list[ + Annotated[ + Annotated[list[float], Field(min_length=2, max_length=2)], + Field(min_length=2, max_length=2), + ] + ], + Field(default=None), ] seed: Annotated[ int, @@ -449,7 +469,7 @@ class DialInputSingleOtherStrategy(BaseModel): class DialInputMultipleOtherStrategy(BaseModel): - """TODO: document this""" + """This class is used to request multiple next points (of number points) using a given strategy.""" workflow_id: ValidatedObjectId points: PositiveIntType @@ -464,17 +484,25 @@ class DialInputMultipleOtherStrategy(BaseModel): ] strategy_args: dict[str, float | int | bool] | None = Field(default=None) y_is_good: Annotated[ - bool, + bool | None, Field( - default=True, # <-- Set default here - description='If true, treat higher y values as better (e.g. y represents yield or profit). If false, opposite (e.g. y represents error or waste)', + default=None, + description=( + 'If true, treat higher y values as better (e.g. y represents yield or profit).' + ' If false, opposite (e.g. y represents error or waste).' + ' If None, the value from DialWorkflowCreationParams is used.' + ), ), ] - bounds: list[ - Annotated[ - Annotated[list[float], Field(min_length=2, max_length=2)], - Field(min_length=2, max_length=2), - ] + bounds: Annotated[ + None + | list[ + Annotated[ + Annotated[list[float], Field(min_length=2, max_length=2)], + Field(min_length=2, max_length=2), + ] + ], + Field(default=None), ] seed: Annotated[ int, diff --git a/src/dial_service/dial_service.py b/src/dial_service/dial_service.py index c8f89fe..bc1a5be 100644 --- a/src/dial_service/dial_service.py +++ b/src/dial_service/dial_service.py @@ -238,13 +238,7 @@ def get_next_point(self, client_data: DialInputSingle) -> DialDataResponse1D: try: model = pickle.loads(workflow_state['model']) # noqa: S301 (XXX - this is technically trusted data as long as the DB hasn't been modified) validated_state = DialWorkflowCreationParamsService(**workflow_state) - if client_data.extra_args: - if validated_state.extra_args: - validated_state.extra_args.update(client_data.extra_args) - else: - validated_state.extra_args = client_data.extra_args data = ServersideInputSingle(validated_state, client_data) - return_data = core.get_next_point(data, model) return DialDataResponse1D( data=return_data, @@ -284,13 +278,7 @@ def get_next_points(self, client_data: DialInputMultiple) -> DialDataResponse2D: try: model = pickle.loads(workflow_state['model']) # noqa: S301 (XXX - this is technically trusted data as long as the DB hasn't been modified) validated_state = DialWorkflowCreationParamsService(**workflow_state) - if client_data.extra_args: - if validated_state.extra_args: - validated_state.extra_args.update(client_data.extra_args) - else: - validated_state.extra_args = client_data.extra_args data = ServersideInputMultiple(validated_state, client_data) - return_data = core.get_next_points(data, model) return DialDataResponse2D( data=return_data, diff --git a/src/dial_service/serverside_data.py b/src/dial_service/serverside_data.py index 4168482..0602da2 100644 --- a/src/dial_service/serverside_data.py +++ b/src/dial_service/serverside_data.py @@ -185,20 +185,35 @@ def __init__( params: DialInputSingle, ): super().__init__(workflow_state) + # set new inputs self.strategy = params.strategy self.strategy_args = params.strategy_args - self.y_is_good = params.y_is_good - self.bounds = params.bounds - self.numpy_rng = np.random.RandomState(None if params.seed == -1 else params.seed) - self.optimization_points = params.optimization_points self.confidence_bound = ( params.confidence_bound if params.strategy == 'confidence_bound' else 0.0 ) + # if params.strategy == 'confidence_bound': + # self.confidence_bound = params.confidence_bound + # elif self.strategy_args is not None and 'confidence_bounds' in self.strategy_args: + # self.confidence_bound = params.strategy_args['confidence_bound'] self.discrete_measurements = params.discrete_measurements self.discrete_measurement_grid_size = params.discrete_measurement_grid_size self.point_index = params.point_index + # update values from workflow initialization, if provided + if params.extra_args is not None: + if self.extra_args is not None: + self.extra_args.update(params.extra_args) + else: + self.extra_args = params.extra_args + if params.y_is_good is not None: + self.y_is_good = params.y_is_good + if params.bounds is not None: + self.bounds = params.bounds + + # always reinit rng, since initial rng is not updated in db! + self.numpy_rng = np.random.RandomState(None if params.seed == -1 else params.seed) + def set_x_predict(self, X_raw: np.ndarray) -> None: """ Store raw prediction points and their scaled version. @@ -214,15 +229,12 @@ def __init__( workflow_state: DialWorkflowCreationParamsService, params: DialInputMultiple, ): + # set new inputs super().__init__(workflow_state) self.strategy = params.strategy self.points = params.points self.strategy = params.strategy self.strategy_args = params.strategy_args - self.y_is_good = params.y_is_good - self.bounds = params.bounds - self.numpy_rng = np.random.RandomState(None if params.seed == -1 else params.seed) - self.optimization_points = params.optimization_points self.confidence_bound = ( params.confidence_bound if params.strategy == 'confidence_bound' else 0.0 @@ -230,6 +242,20 @@ def __init__( self.discrete_measurements = params.discrete_measurements self.discrete_measurement_grid_size = params.discrete_measurement_grid_size + # update values from workflow initialization, if provided + if params.extra_args is not None: + if self.extra_args is not None: + self.extra_args.update(params.extra_args) + else: + self.extra_args = params.extra_args + if params.y_is_good is not None: + self.y_is_good = params.y_is_good + if params.bounds is not None: + self.bounds = params.bounds + + # always reinit rng, since initial rng is not updated in db! + self.numpy_rng = np.random.RandomState(None if params.seed == -1 else params.seed) + def set_x_predict(self, X_raw: np.ndarray) -> None: """ Store raw prediction points and their scaled version.