diff --git a/transfer_queue/config.yaml b/transfer_queue/config.yaml index 4759c36..8a6df3d 100644 --- a/transfer_queue/config.yaml +++ b/transfer_queue/config.yaml @@ -23,8 +23,9 @@ backend: # SimpleStorage: ZMQ-based in-memory storage for out-of-the-box usage SimpleStorage: - # Maximum number of experience samples to hold across all storage units - total_storage_size: 100000 + # Maximum number of experience samples to hold across all storage units. + # Set to null for unlimited capacity (no capacity check). + total_storage_size: null # Number of distributed storage units. # Recommended: >= 2 x number of nodes for load balancing. num_data_storage_units: 2 diff --git a/transfer_queue/storage/bootstrap/simple_storage_bootstrap.py b/transfer_queue/storage/bootstrap/simple_storage_bootstrap.py index 1ab2f6b..adbd539 100644 --- a/transfer_queue/storage/bootstrap/simple_storage_bootstrap.py +++ b/transfer_queue/storage/bootstrap/simple_storage_bootstrap.py @@ -33,16 +33,21 @@ def initialize_simple_storage(conf: DictConfig) -> dict[str, Any]: simple_storage_handles = {} num_data_storage_units = conf.backend.SimpleStorage.num_data_storage_units - total_storage_size = conf.backend.SimpleStorage.total_storage_size + total_storage_size = conf.backend.SimpleStorage.get("total_storage_size", None) storage_placement_group = get_placement_group(num_data_storage_units, num_cpus_per_actor=1) + # Compute per-unit capacity: None means unlimited + storage_unit_size = ( + math.ceil(total_storage_size / num_data_storage_units) if total_storage_size is not None else None + ) + for storage_unit_rank in range(num_data_storage_units): storage_node = SimpleStorageUnit.options( # type: ignore[attr-defined] placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank, name=f"TransferQueueStorageUnit#{storage_unit_rank}", ).remote( - storage_unit_size=math.ceil(total_storage_size / num_data_storage_units), + storage_unit_size=storage_unit_size, ) simple_storage_handles[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.") diff --git a/transfer_queue/storage/simple_storage.py b/transfer_queue/storage/simple_storage.py index e70648e..6865a60 100644 --- a/transfer_queue/storage/simple_storage.py +++ b/transfer_queue/storage/simple_storage.py @@ -60,10 +60,10 @@ class StorageUnitData: } """ - def __init__(self, storage_size: int): + def __init__(self, storage_size: int | None = None): # field_name -> {global_index: data} nested dict self.field_data: dict[str, dict] = {} - # Capacity upper bound (not pre-allocated list length) + # Capacity upper bound (None means unlimited) self.storage_size = storage_size # Track active global_index keys for O(1) capacity checks self._active_keys: set = set() @@ -103,12 +103,13 @@ def put_data(self, field_data: dict[str, Any], global_indexes: list) -> None: global_indexes: Global indexes to use as dict keys. """ # Capacity is enforced per unique sample key, not counted per-field - new_global_keys = [k for k in global_indexes if k not in self._active_keys] - if len(self._active_keys) + len(new_global_keys) > self.storage_size: - raise ValueError( - f"Storage capacity exceeded: {len(self._active_keys)} existing + " - f"{len(new_global_keys)} new > {self.storage_size}" - ) + if self.storage_size is not None: + new_global_keys = [k for k in global_indexes if k not in self._active_keys] + if len(self._active_keys) + len(new_global_keys) > self.storage_size: + raise ValueError( + f"Storage capacity exceeded: {len(self._active_keys)} existing + " + f"{len(new_global_keys)} new > {self.storage_size}" + ) for f, values in field_data.items(): if len(values) != len(global_indexes): raise ValueError( @@ -152,11 +153,12 @@ class SimpleStorageUnit: zmq_server_info: ZMQ connection information for clients. """ - def __init__(self, storage_unit_size: int): + def __init__(self, storage_unit_size: int | None = None): """Initialize a SimpleStorageUnit with the specified size. Args: storage_unit_size: Maximum number of elements that can be stored in this storage unit. + If None, the storage unit has unlimited capacity. """ self.storage_unit_id = f"TQ_STORAGE_UNIT_{uuid4().hex[:8]}" self.storage_unit_size = storage_unit_size