diff --git a/llm_bench/load_test.py b/llm_bench/load_test.py index 0371d8a..34b9910 100644 --- a/llm_bench/load_test.py +++ b/llm_bench/load_test.py @@ -473,9 +473,17 @@ def notify_init(cls, environment, logging_params): if cls.logging_params is None: cls.logging_params = logging_params else: + # Multi-target runs intentionally use a different model per user; + # drop fields that are expected to differ across targets before comparing. + multi_target = bool(getattr(environment.parsed_options, "targets", None)) + if multi_target: + existing = {k: v for k, v in cls.logging_params.items() if k != "model"} + incoming = {k: v for k, v in logging_params.items() if k != "model"} + else: + existing, incoming = cls.logging_params, logging_params assert ( - cls.logging_params == logging_params - ), f"Inconsistent settings between workers: {cls.logging_params} != {logging_params}" + existing == incoming + ), f"Inconsistent settings between workers: {existing} != {incoming}" @classmethod def notify_first_request(cls): @@ -627,6 +635,30 @@ def _defer_run_time_to_after_spawn(environment, **_kwargs): logger.info(f"Will stop after {max_requests} requests complete") +@events.init.add_listener +def _scale_users_by_targets(environment, **_kwargs): + """Multiply -u/--users and --spawn-rate by the number of --targets entries. + + With this, -u 100 means "100 users per target": the user specifies per-target + load and the script scales the total locust user count up to match. spawn-rate + is multiplied proportionally so the ramp-up rate per target stays the same. + """ + targets = getattr(environment.parsed_options, "targets", None) or [] + n = len(targets) + if n <= 1: + return + num_users = getattr(environment.parsed_options, "num_users", None) + if num_users: + environment.parsed_options.num_users = num_users * n + logger.info( + f"Scaling --users by {n} targets: {num_users} -> {num_users * n} " + f"(stays {num_users} per target)" + ) + spawn_rate = getattr(environment.parsed_options, "spawn_rate", None) + if spawn_rate: + environment.parsed_options.spawn_rate = spawn_rate * n + + @dataclass class ChunkMetadata: text: str @@ -1164,6 +1196,58 @@ def _load_curl_like_data(text): class LLMUser(HttpUser): # no wait time, so every user creates a continuous load, sending requests as quickly as possible + _target_counter = 0 + _target_counter_lock = threading.Lock() + + @staticmethod + def _parse_targets(environment): + raw = environment.parsed_options.targets or [] + default_host = environment.host + default_model = environment.parsed_options.model + default_api_key = environment.parsed_options.api_key + def _label(url, model): + # Make the label distinguish targets even when they share a URL but + # differ by model (common with Fireworks deployment-pinned model + # strings of the form 'accounts/x/models/y#accounts/x/deployments/z'). + if model and "#" in model: + # Use the deployment id (right of '#') — short and human-recognizable. + return model.rsplit("/", 1)[-1] + if model: + return f"{url} {model}" + return url or "default" + + if not raw: + return [{ + "url": default_host, + "model": default_model, + "api_key": default_api_key, + "label": _label(default_host, default_model), + }] + parsed = [] + for spec in raw: + parts = spec.split("|") + url = parts[0] or default_host + model = (parts[1] if len(parts) > 1 and parts[1] else default_model) + api_key = (parts[2] if len(parts) > 2 and parts[2] else default_api_key) + parsed.append({ + "url": url, + "model": model, + "api_key": api_key, + "label": _label(url, model), + }) + return parsed + + def __init__(self, environment): + targets = self._parse_targets(environment) + with LLMUser._target_counter_lock: + idx = LLMUser._target_counter % len(targets) + LLMUser._target_counter += 1 + self._target = targets[idx] + # Override self.host before HttpUser.__init__ creates the HttpSession + if self._target["url"]: + self.host = self._target["url"] + super().__init__(environment) + def on_start(self): try: self._on_start() @@ -1173,7 +1257,7 @@ def on_start(self): sys.exit(1) def _guess_provider(self): - self.model = self.environment.parsed_options.model + self.model = self._target.get("model") or self.environment.parsed_options.model self.provider = self.environment.parsed_options.provider # guess based on URL if self.provider is None: @@ -1223,8 +1307,9 @@ def _guess_provider(self): def _on_start(self): self.client.headers["Content-Type"] = "application/json" - if self.environment.parsed_options.api_key: - self.client.headers["Authorization"] = "Bearer " + self.environment.parsed_options.api_key + api_key = self._target.get("api_key") or self.environment.parsed_options.api_key + if api_key: + self.client.headers["Authorization"] = "Bearer " + api_key if self.environment.parsed_options.header: for header in self.environment.parsed_options.header: key, val = header.split(":", 1) @@ -1444,6 +1529,7 @@ def _do_generate_text(self): stream=True, catch_response=True, timeout=60, + name=f"[{self._target['label']}] {self.provider_formatter.get_url()}", ) as response: combined_text = "" done = False @@ -1675,6 +1761,16 @@ def init_parser(parser): type=str, help="The model to use for generating text. If not specified we will pick the first model from the service as returned by /v1/models", ) + parser.add_argument( + "--targets", + action="append", + default=[], + help=( + "Target URL. Repeat for multiple targets. Format: 'url[|model][|api_key]' " + "(model and api_key optional, fall back to --model/--api-key). " + "With multiple targets, -u/--users is per-target (total = users * num_targets)." + ), + ) parser.add_argument( "--tokenizer", env_var="TOKENIZER",