Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions transfer_queue/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ backend:

# SimpleStorage: ZMQ-based in-memory storage for out-of-the-box usage
SimpleStorage:
# Maximum number of experience samples to hold across all storage units
total_storage_size: 100000
# Maximum number of experience samples to hold across all storage units.
# Set to null for unlimited capacity (no capacity check).
total_storage_size: null
# Number of distributed storage units.
# Recommended: >= 2 x number of nodes for load balancing.
num_data_storage_units: 2
Expand Down
9 changes: 7 additions & 2 deletions transfer_queue/storage/bootstrap/simple_storage_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,21 @@ def initialize_simple_storage(conf: DictConfig) -> dict[str, Any]:

simple_storage_handles = {}
num_data_storage_units = conf.backend.SimpleStorage.num_data_storage_units
total_storage_size = conf.backend.SimpleStorage.total_storage_size
total_storage_size = conf.backend.SimpleStorage.get("total_storage_size", None)
storage_placement_group = get_placement_group(num_data_storage_units, num_cpus_per_actor=1)

# Compute per-unit capacity: None means unlimited
storage_unit_size = (
math.ceil(total_storage_size / num_data_storage_units) if total_storage_size is not None else None
)

for storage_unit_rank in range(num_data_storage_units):
storage_node = SimpleStorageUnit.options( # type: ignore[attr-defined]
placement_group=storage_placement_group,
placement_group_bundle_index=storage_unit_rank,
name=f"TransferQueueStorageUnit#{storage_unit_rank}",
).remote(
storage_unit_size=math.ceil(total_storage_size / num_data_storage_units),
storage_unit_size=storage_unit_size,
)
simple_storage_handles[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node
logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.")
Expand Down
20 changes: 11 additions & 9 deletions transfer_queue/storage/simple_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ class StorageUnitData:
}
"""

def __init__(self, storage_size: int):
def __init__(self, storage_size: int | None = None):
# field_name -> {global_index: data} nested dict
self.field_data: dict[str, dict] = {}
# Capacity upper bound (not pre-allocated list length)
# Capacity upper bound (None means unlimited)
self.storage_size = storage_size
# Track active global_index keys for O(1) capacity checks
self._active_keys: set = set()
Expand Down Expand Up @@ -103,12 +103,13 @@ def put_data(self, field_data: dict[str, Any], global_indexes: list) -> None:
global_indexes: Global indexes to use as dict keys.
"""
# Capacity is enforced per unique sample key, not counted per-field
new_global_keys = [k for k in global_indexes if k not in self._active_keys]
if len(self._active_keys) + len(new_global_keys) > self.storage_size:
raise ValueError(
f"Storage capacity exceeded: {len(self._active_keys)} existing + "
f"{len(new_global_keys)} new > {self.storage_size}"
)
if self.storage_size is not None:
new_global_keys = [k for k in global_indexes if k not in self._active_keys]
if len(self._active_keys) + len(new_global_keys) > self.storage_size:
raise ValueError(
f"Storage capacity exceeded: {len(self._active_keys)} existing + "
f"{len(new_global_keys)} new > {self.storage_size}"
)
for f, values in field_data.items():
if len(values) != len(global_indexes):
raise ValueError(
Expand Down Expand Up @@ -152,11 +153,12 @@ class SimpleStorageUnit:
zmq_server_info: ZMQ connection information for clients.
"""

def __init__(self, storage_unit_size: int):
def __init__(self, storage_unit_size: int | None = None):
"""Initialize a SimpleStorageUnit with the specified size.

Args:
storage_unit_size: Maximum number of elements that can be stored in this storage unit.
If None, the storage unit has unlimited capacity.
"""
self.storage_unit_id = f"TQ_STORAGE_UNIT_{uuid4().hex[:8]}"
self.storage_unit_size = storage_unit_size
Expand Down
Loading