diff --git a/tests/unit_tests/test_inference.py b/tests/unit_tests/test_inference.py
index 8b3a4a64da4..3add3d1bc89 100644
--- a/tests/unit_tests/test_inference.py
+++ b/tests/unit_tests/test_inference.py
@@ -28,6 +28,7 @@ def gpt2_tiktoken_tokenizer():
 
 
 @pytest.fixture(scope="module")
 def static_inference_engine(gpt2_tiktoken_tokenizer):
+    Utils.initialize_model_parallel()
     engine_wrapper = StaticInferenceEngineTestHarness()
     engine_wrapper.setup_engine(vocab_size=gpt2_tiktoken_tokenizer.vocab_size, legacy=True)
@@ -63,8 +64,6 @@ def client(app):
     'megatron.core.inference.text_generation_server.text_generation_server.send_do_generate'
 )
 def test_generations_endpoint(mock_send_do_generate, client, gpt2_tiktoken_tokenizer):
-    Utils.initialize_distributed()
-
     prompts = ["twinkle twinkle little star, how I wonder what you are"]
     request_data = {"prompts": prompts, "tokens_to_generate": 10, "logprobs": True}
@@ -90,8 +89,6 @@ def test_generations_endpoint(mock_send_do_generate, client, gpt2_tiktoken_token
     "megatron.core.inference.text_generation_server.endpoints.completions.send_do_generate"
 )
 def test_completions_endpoint(mock_send_do_generate, client, gpt2_tiktoken_tokenizer):
-    Utils.initialize_distributed()
-
     twinkle = ("twinkle twinkle little star,", " how I wonder what you are")
     request_data = {"prompt": twinkle[0] + twinkle[1], "max_tokens": 0, "logprobs": 5, "echo": True}