diff --git a/tests/test_mooncake_force_delete.py b/tests/test_mooncake_force_delete.py index 129a1de1..abb4907a 100644 --- a/tests/test_mooncake_force_delete.py +++ b/tests/test_mooncake_force_delete.py @@ -68,6 +68,56 @@ def test_enable_hard_pin_default_off(self): assert restored.enable_hard_pin is False +class TestMooncakeEnvDefaults: + def test_tcp_memcpy_default_is_applied_by_export_env(self): + config = MooncakeConfig(protocol="tcp") + + with patch.dict(os.environ, {}, clear=True): + config.export_env() + + assert os.environ["MC_STORE_MEMCPY"] == "0" + + def test_tcp_memcpy_default_preserves_user_override(self): + config = MooncakeConfig(protocol="tcp") + + with patch.dict(os.environ, {"MC_STORE_MEMCPY": "1"}, clear=True): + config.apply_env_defaults() + + assert os.environ["MC_STORE_MEMCPY"] == "1" + + def test_tcp_memcpy_default_not_applied_for_rdma(self): + config = MooncakeConfig(protocol="rdma") + + with patch.dict(os.environ, {}, clear=True): + config.apply_env_defaults() + + assert "MC_STORE_MEMCPY" not in os.environ + + def test_direct_store_setup_applies_tcp_memcpy_before_mooncake_client_setup(self): + config = MooncakeConfig(protocol="tcp", async_put_pool_size=0) + mock_raw_store = MagicMock() + mock_raw_store.setup.return_value = 0 + + class ConcreteStore(MooncakeHiddenStateStore): + pass + + def make_raw_store(): + assert os.environ["MC_STORE_MEMCPY"] == "0" + return mock_raw_store + + store = ConcreteStore(config) + with ( + patch.dict(os.environ, {}, clear=True), + patch("torchspec.transfer.mooncake.store.MooncakeDistributedStore", make_raw_store), + patch.object(ConcreteStore, "_verify_force_delete"), + patch.object(ConcreteStore, "_build_replicate_config"), + patch("torch.cuda.is_available", return_value=False), + ): + store.setup() + + mock_raw_store.setup.assert_called_once() + + # --------------------------------------------------------------------------- # Tests 2-3: _verify_force_delete # --------------------------------------------------------------------------- diff --git a/tests/test_placement_group.py b/tests/test_placement_group.py index 06a2ad3c..17c52ee6 100644 --- a/tests/test_placement_group.py +++ b/tests/test_placement_group.py @@ -1,13 +1,12 @@ -from argparse import Namespace import importlib.util import sys import types +from argparse import Namespace from pathlib import Path from unittest.mock import MagicMock, patch import pytest - repo_root = Path(__file__).resolve().parents[1] torchspec_pkg = sys.modules.get("torchspec") if torchspec_pkg is None and importlib.util.find_spec("torch") is None: @@ -67,8 +66,8 @@ def __init__(self, **kwargs): sys.modules["torchspec.ray.train_group"] = train_group_stub from torchspec.ray.placement_group import ( # noqa: E402 - _NodeConstraint, _build_custom_bundles, + _NodeConstraint, _sort_probed_bundle_infos, create_placement_groups, ) diff --git a/torchspec/config/mooncake_config.py b/torchspec/config/mooncake_config.py index 3b937838..9b1309f5 100644 --- a/torchspec/config/mooncake_config.py +++ b/torchspec/config/mooncake_config.py @@ -176,6 +176,7 @@ def export_env(self) -> None: os.environ["MOONCAKE_PROTOCOL"] = self.protocol os.environ["MOONCAKE_DEVICE_NAME"] = self.device_name os.environ["MOONCAKE_ENABLE_GPU_DIRECT"] = "1" if self.enable_gpu_direct else "0" + self.apply_env_defaults() if self.async_put_pool_size is not None: os.environ["MOONCAKE_ASYNC_PUT_POOL_SIZE"] = str(self.async_put_pool_size) os.environ["MOONCAKE_STORE_FULL_WAIT_SECONDS"] = str(self.store_full_wait_seconds) @@ -190,6 +191,14 @@ def export_env(self) -> None: os.environ["MOONCAKE_GET_RETRY_MAX_WAIT_SECONDS"] = str(self.get_retry_max_wait_seconds) os.environ["MOONCAKE_ENABLE_HARD_PIN"] = "1" if self.enable_hard_pin else "0" + def apply_env_defaults(self) -> None: + """Apply Mooncake process defaults that are needed before client setup.""" + # Fix: https://github.com/kvcache-ai/Mooncake/issues/1986 + if self.protocol.lower() == "tcp" and "MC_STORE_MEMCPY" not in os.environ: + # Mooncake's TCP-only memcpy fast path can segfault in same-host + # multi-process get paths. Preserve an explicit user override. + os.environ["MC_STORE_MEMCPY"] = "0" + @classmethod def from_env(cls) -> "MooncakeConfig": """Create config from environment variables.""" diff --git a/torchspec/ray/placement_group.py b/torchspec/ray/placement_group.py index 5a80d00a..0bff3791 100644 --- a/torchspec/ray/placement_group.py +++ b/torchspec/ray/placement_group.py @@ -34,7 +34,6 @@ from torchspec.ray.train_group import RayTrainGroup from torchspec.utils.logging import logger - # Ray exposes a tiny "node:" resource on each node. Requiring a fractional # amount pins a bundle to that node without consuming a full logical resource. _NODE_RESOURCE_EPSILON = 0.001 diff --git a/torchspec/transfer/mooncake/store.py b/torchspec/transfer/mooncake/store.py index 37219d98..363c2d32 100644 --- a/torchspec/transfer/mooncake/store.py +++ b/torchspec/transfer/mooncake/store.py @@ -73,6 +73,7 @@ def setup(self, device: torch.device | int | None = None) -> None: "Set mooncake.device_name to a specific RDMA device (e.g. 'mlx5_0')." ) + self.config.apply_env_defaults() self._store = MooncakeDistributedStore() logger.info( "Connecting to Mooncake master at %s, metadata server at %s",