Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
af849d5
better validation of inputs
konstpieper May 12, 2026
f620bfd
draft: allow for multidimensional output.
konstpieper May 13, 2026
0225425
add todo
konstpieper May 13, 2026
db6f1ed
more validation
konstpieper May 13, 2026
a845af9
Merge remote-tracking branch 'origin/develop' into sable
konstpieper May 21, 2026
cb7c59f
backward compatible extensions to dial_dataclass
konstpieper May 27, 2026
bbc61c3
Merge remote-tracking branch 'origin/main' into sable
konstpieper May 27, 2026
ef866f0
fix dial_dataclass validation of bounds
konstpieper May 27, 2026
41cfa5d
Merge branch 'develop' into sable
konstpieper May 28, 2026
78c7678
add autogenerated files to gitignore
konstpieper May 29, 2026
bf1d6e8
formatting changes from ruff
konstpieper May 29, 2026
2a042d7
extract y and yerr from dataset_y unsing statistics_y
konstpieper May 29, 2026
b514dac
remove leftover method Y_stddev
konstpieper May 29, 2026
9ae0173
migrate sable_backend and client to the new (y,yerr) logic
konstpieper May 29, 2026
233081b
Merge remote-tracking branch 'origin/develop' into sable
konstpieper Jun 8, 2026
952061b
make sk_learn backend respect output_statistics
konstpieper Jun 8, 2026
8b21e7f
fixes for clients
konstpieper Jun 8, 2026
4a7724a
re-enable extra_args length_per_dimension
konstpieper Jun 8, 2026
f9961e3
do not rely on type comparision, but use name instead
konstpieper Jun 15, 2026
169fe86
updates to sinosidal_growth test
konstpieper Jun 15, 2026
0c0be40
Merge remote-tracking branch 'origin/develop' into sable
Lance-Drane Jun 23, 2026
5f3f0dd
fix Distribution pydantic typing
Lance-Drane Jun 23, 2026
4d01020
remove transformed_stddevs and use the regular stddevs
konstpieper Jun 24, 2026
5563019
fix and document all currently supported strategies
konstpieper Jun 24, 2026
e0653f9
fix unit test for fixed EI criterion
konstpieper Jun 25, 2026
d36c765
update clients to test more settings
konstpieper Jun 25, 2026
99e4e71
make DialInput repeated arguments from DialWorkflow optional
konstpieper Jun 25, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -212,5 +212,6 @@ pyrightconfig.json
# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode

.venv*

Taskfile.yml
.DS_Store
*~
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ license = { text = "BSD-3-Clause" }
classifiers = ["Programming Language :: Python :: 3"]
# TODO move some dependencies into optional dependencies
dependencies = [
"intersect_sdk>=0.9.0,<0.10.0",
"intersect_sdk>=0.9.3,<0.10.0",
"numpy",
"scikit-learn>=1.4.0,<2.0.0", # TODO consider making an optional dependency group
"scipy>=1.12.0,<2.0.0",
Expand Down
14 changes: 8 additions & 6 deletions scripts/1d_sable_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
DialInputSingleOtherStrategy,
DialWorkflowCreationParamsClient,
DialWorkflowDatasetUpdate,
Normal,
)

mpl.use('agg')
Expand Down Expand Up @@ -94,6 +95,8 @@ def __init__(self, service_destination: str):

self.dataset_x = self.x_raw.reshape(-1, 1).tolist()
self.dataset_y = self.y_raw.reshape(-1).tolist()
self.labels_y = ['y']
self.statistics_y = Normal(loc='y', scale=self.noise_level)
self.test_points = self.x_test.reshape(-1, 1).tolist()

self.kernel = 'rbf'
Expand All @@ -104,7 +107,6 @@ def __init__(self, service_destination: str):
'alpha': 0.05,
'p': 1.25,
'n_iter_irls': 100,
'noise_level': self.noise_level,
}
self.strategy = 'upper_confidence_bound'
self.strategy_args = {'exploit': 0.0, 'explore': 1.0}
Expand Down Expand Up @@ -160,6 +162,8 @@ def callback_message(self, operation: str, **kwargs) -> IntersectClientCallback:
next_payload = DialWorkflowCreationParamsClient(
dataset_x=self.dataset_x,
dataset_y=self.dataset_y,
labels_y=self.labels_y,
statistics_y=self.statistics_y,
bounds=self.bounds.tolist(),
kernel=self.kernel,
kernel_args=self.kernel_args,
Expand Down Expand Up @@ -216,14 +220,12 @@ def callback_message(self, operation: str, **kwargs) -> IntersectClientCallback:

def handle_surrogate_values(self, payload):
means = payload['values']
transformed_stddevs = payload['transformed_stddevs']
stddevs = payload['stddevs']
if self.at_grids:
self.stddev_grid = np.array(transformed_stddevs).reshape(
(self.meshgrid_size,) * self.num_dims
)
self.stddev_grid = np.array(stddevs).reshape((self.meshgrid_size,) * self.num_dims)
self.mean_grid = np.array(means).reshape((self.meshgrid_size,) * self.num_dims)
else:
self.stddev_test = np.array(transformed_stddevs)
self.stddev_test = np.array(stddevs)
self.mean_test = np.array(means)
print(
f'Values at testing points {self.x_test.reshape(-1)}: Mean: {self.mean_test}, Stddev: {self.stddev_test}'
Expand Down
101 changes: 68 additions & 33 deletions scripts/1d_sinusoidal_growth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
DialInputSingleOtherStrategy,
DialWorkflowCreationParamsClient,
DialWorkflowDatasetUpdate,
Normal,
)

mpl.use('agg')
Expand Down Expand Up @@ -52,14 +53,14 @@ def __init__(self):

class ActiveLearningOrchestrator:
def __init__(self, service_destination: str):
self.bounds = np.array([[-2, 2]])
self.bounds = [[-2.0, 2.0]]
self.num_dims = len(self.bounds)

self.x_raw = np.array([[1], [2.0]])
self.x_test = np.array([[-1], [0.5]])
self.x_raw = np.array([[1.0], [2.0]])
self.x_test = np.array([[-1.0], [0.5]])
self.y_raw = sinusoidal_growth(self.x_raw)

self.meshgrid_size = 100
self.meshgrid_size = 200
self.grid_points = [
np.linspace(dim_bounds[0], dim_bounds[1], self.meshgrid_size)
for dim_bounds in self.bounds
Expand All @@ -72,18 +73,47 @@ def __init__(self, service_destination: str):
self.dataset_y = self.y_raw.reshape(-1).tolist()
self.test_points = self.x_test.reshape(-1, 1).tolist()

self.kernel = 'rbf'
self.kernel_args = {'length_scale': 0.12, 'length_scale_bounds': (0.1, 1.0)}
# Assume that there is some small noise in the measurements to stabilize the fit
self.statistics_y = Normal(loc='y', scale=1e-6)

self.backend = 'sklearn'
self.backend_args = None
if self.backend == 'sklearn':
# configure kernel_hyperparameters
self.kernel = 'matern' # 'rbf' or 'matern'
self.kernel_args = {
'length_scale': 0.1,
'constant_value': 1.0,
}
self.optimize_lengthscale = False
if self.optimize_lengthscale:
self.kernel_args.update(
{
'length_scale_bounds': (0.02, 0.2),
}
)
self.backend_args = {}
elif self.backend == 'sable':
self.kernel = 'rbf'
self.kernel_args = {
'x_range': [-2.0, 2.0],
'sigma_range': [1.0e-3, 1.0],
'gamma': 0.1,
}
self.backend_args = {
'n_features': 10000,
'alpha': 0.0005,
'p': 1.25,
'n_iter_irls': 20,
}

self.strategy = 'upper_confidence_bound'
self.strategy_args = {'exploit': 0.4, 'explore': 1}
self.niter = 0
self.max_iter = 20
self.at_grids = True
self.variance_grid = None
self.stddev_grid = None
self.mean_grid = None
self.variance_test = None
self.stddev_test = None
self.mean_test = None
self.x_next = None

Expand All @@ -92,18 +122,23 @@ def __init__(self, service_destination: str):
self.service_destination = service_destination

def __call__(
self, _source: str, operation: str, _has_error: bool, payload: INTERSECT_RESPONSE_VALUE
self,
_source: str,
operation: str,
_has_error: bool,
payload: INTERSECT_RESPONSE_VALUE,
) -> IntersectClientCallback:
print(
f'Received message from {_source} with operation {operation} and payload {payload}',
file=sys.stderr,
)
if _has_error:
print('============ERROR==============', file=sys.stderr)
print(operation, file=sys.stderr)
print(payload, file=sys.stderr)
raise IntersectCallbackError(operation, payload)

# print(
# f'Received message from {_source} with operation {operation} and payload {payload}',
# file=sys.stderr,
# )

if operation == 'dial.initialize_workflow':
self.workflow_id = payload
return self.callback_message('dial.get_surrogate_values')
Expand Down Expand Up @@ -142,10 +177,8 @@ def callback_message(self, operation: str, **kwargs) -> IntersectClientCallback:
kernel_args=self.kernel_args,
backend=self.backend,
backend_args=self.backend_args,
extra_args={'length_per_dimension': True},
preprocess_standardize=True,
y_is_good=True,
seed=20,
)

elif operation == 'dial.get_surrogate_values':
Expand All @@ -164,7 +197,8 @@ def callback_message(self, operation: str, **kwargs) -> IntersectClientCallback:
workflow_id=self.workflow_id,
strategy=self.strategy,
strategy_args=self.strategy_args,
bounds=self.bounds.tolist(),
bounds=self.bounds,
y_is_good=True,
)

elif operation == 'dial.update_workflow_with_data':
Expand All @@ -181,23 +215,23 @@ def callback_message(self, operation: str, **kwargs) -> IntersectClientCallback:
return IntersectClientCallback(
messages_to_send=[
IntersectDirectMessageParams(
destination=self.service_destination, operation=operation, payload=next_payload
destination=self.service_destination,
operation=operation,
payload=next_payload,
)
]
)

def handle_surrogate_values(self, payload):
means = payload['values']
transformed_stddevs = payload['transformed_stddevs']
stddevs = payload['stddevs']
if self.at_grids:
self.variance_grid = np.array(transformed_stddevs).reshape(
(self.meshgrid_size,) * self.num_dims
)
self.stddev_grid = np.array(stddevs).reshape((self.meshgrid_size,) * self.num_dims)
self.mean_grid = np.array(means).reshape((self.meshgrid_size,) * self.num_dims)
else:
self.variance_test = np.array(transformed_stddevs)
self.stddev_test = np.array(stddevs)
self.mean_test = np.array(means)
print(f'Test Mean: {self.mean_test}, Variance: {self.variance_test}')
print(f'Test Mean: {self.mean_test}, Std.Dev.: {self.stddev_test}')

# end of active learning loop after max_iter
if self.niter > self.max_iter:
Expand All @@ -224,12 +258,12 @@ def graph(self):

fig, axs = plt.subplots(2, 1, figsize=(10, 8), sharex=True)

# First subplot: Mean and variance with training data
# First subplot: Mean and standard deviation with training data
axs[0].plot(self.x_grid, self.mean_grid, label='Mean Prediction')
axs[0].fill_between(
self.x_grid[:, 0],
self.mean_grid + 2 * self.variance_grid,
self.mean_grid - 2 * self.variance_grid,
self.mean_grid + 2 * self.stddev_grid,
self.mean_grid - 2 * self.stddev_grid,
alpha=0.5,
label='Confidence Interval',
)
Expand All @@ -248,27 +282,28 @@ def graph(self):

# Second subplot: Acquisition function
if self.strategy_args is not None:
if self.mean_grid is not None and self.variance_grid is not None:
if self.mean_grid is not None and self.stddev_grid is not None:
exploit = self.strategy_args.get('exploit', 0.0)
explore = self.strategy_args.get('explore', 1.0)
acquisition_values = exploit * self.mean_grid + explore * np.sqrt(
self.variance_grid
)
acquisition_values = exploit * self.mean_grid + explore * self.stddev_grid
else:
acquisition_values = np.zeros_like(self.x_grid)

axs[1].plot(self.x_grid, acquisition_values)
if self.x_next is not None:
axs[1].axvline(
x=self.x_next[0], color='red', linestyle='--', label='Next Point (x_next)'
x=self.x_next[0],
color='red',
linestyle='--',
label='Next Point (x_next)',
)
axs[1].set_xlabel('Features, x')
axs[1].set_ylabel('Acquisition Value')
axs[1].legend()
axs[1].grid(True)

plt.tight_layout()
plt.savefig('graph.png')
plt.savefig('graph_sinusoidal.png')
plt.close(fig)


Expand Down
Loading
Loading