diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 3c7c3f98..f7cd2394 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -400,8 +400,8 @@ def update_production_status( max_sample_idx = max(global_indices) if global_indices else -1 required_samples = max_sample_idx + 1 - # Ensure we have enough rows with self.data_status_lock: + # Ensure we have enough rows self.ensure_samples_capacity(required_samples) # Register new fields if needed @@ -415,10 +415,11 @@ def update_production_status( with self.data_status_lock: self.ensure_fields_capacity(required_fields) - # Update production status - if self.production_status is not None and global_indices and field_names: - field_indices = [self.field_name_mapping.get(field) for field in field_names] - self.production_status[torch.tensor(global_indices)[:, None], torch.tensor(field_indices)] = 1 + with self.data_status_lock: + # Update production status + if self.production_status is not None and global_indices and field_names: + field_indices = [self.field_name_mapping.get(field) for field in field_names] + self.production_status[torch.tensor(global_indices)[:, None], torch.tensor(field_indices)] = 1 # Update field metadata self._update_field_metadata(global_indices, dtypes, shapes, custom_backend_meta)