diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d1258617..8761604a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -86,6 +86,7 @@ jobs: AWS_ACCESS_KEY: ${{ secrets.AWS_ACCESS_KEY }} CLEAN_UP_TMP_PATH: 1 run: | + wget https://uit.stanford.edu/sites/default/files/2023/10/11/incommon-rsa-ca2.pem #downloading cert manually [for CORAL] sudo cp incommon-rsa-ca2.pem /usr/local/share/ca-certificates/incommon-rsa-server-ca-2.crt # [cert for CORAL] sudo update-ca-certificates # [cert for CORAL] diff --git a/dataset_configs/ipl/config.yaml b/dataset_configs/ipl/config.yaml new file mode 100644 index 00000000..5d69c742 --- /dev/null +++ b/dataset_configs/ipl/config.yaml @@ -0,0 +1,49 @@ +documentation: | + TopIPL + ###### + + This config is used to run the `TopIPL: Iterative Pseudo-Labeling for ASR `_ training algorithm using NeMo-Run. + + TopIPL is a **semi-supervised training method** for automatic speech recognition (ASR) that iteratively alternates between model training and pseudo-label generation for unlabeled data. It uses a **top-N checkpoint averaging strategy** to create a strong teacher model and maintains a **dynamic cache** of pseudo-labels throughout the process. + + The pipeline is implemented as a processor compatible with the `nemo_run` framework. It generates an output manifest containing updated labels based on pseudo-labeling iterations. + + This config performs the following steps: + + 1. Runs training and inference commands using NeMo-Run. + 2. Periodically stops training to generate pseudo-labels with a top-N checkpoint ensemble. + 3. Maintains a dynamic cache of pseudo-labels for unlabeled data. + 4. Produces a new output manifest after each iteration. + + **Required arguments** + + - **output_manifest_file**: path where the final manifest with pseudo-labels will be saved. + - **nemo_run_config**: YAML config file specifying the training, inference, and IPL parameters. + + **Training config requirements** + + Your training config must include the following setting to enable IPL: + + .. code-block:: yaml + + exp_manager: + create_ipl_epoch_stopper_callback: True + + If you're not using Lhotse, also include: + + .. code-block:: yaml + + ipl_epoch_stopper_callback_params: + stop_every_n_epochs: 2 + + ### Prerequisites + + - nemo_run + - ``pip install -r ipl.txt`` + +processors_to_run: all + +processors: + - _target_: sdp.processors.IPL.nemo_run_processor.NemoRunIPLProcessor + config_path: ./nemo_run_config.yaml + output_manifest_file: ??? diff --git a/dataset_configs/ipl/nemo_run_config.yaml b/dataset_configs/ipl/nemo_run_config.yaml new file mode 100644 index 00000000..df968da1 --- /dev/null +++ b/dataset_configs/ipl/nemo_run_config.yaml @@ -0,0 +1,80 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The script to be run. +script: # Script path to run relative to directory +script_config: # Training config file for the script. ipl_epoch_stopper_callback should be provided in the config +inference_config: # Inference config file of unlabeled data for transcribe_speech_parallel + +exp_name: null # populated by exp_manager.name if not provided +results_dir: # Where to store the results of the run + +# Path to the local NeMo repository. This is used to locate scripts and configs from NeMo. +# To set this up: +# 1. Clone the NeMo repository: +# git clone https://github.com/NVIDIA/NeMo.git /your/desired/path/to/nemo +# 2. Set the path here: +# Make sure this path is valid and NeMo is up to date if you're using its scripts. +nemo_directory: # Nemo directory path +do_average: # Boolean value indicating whether to do average of checkpoints for pseudo-label generation +p_cache: # Probability with which update pseudo-labeled set +num_ipl_epochs: # How many epochs do pseudo-labeling + +# Optional arguments +num_runs: +num_gpus: +num_tasks_per_node: +max_runtime: # Specify for clusters + +######################################################################################################################## + +executor: slurm # or local + +USER: + +# Fields for cluster run +ssh_tunnel: + host: + # ------------------------------- Fill this up! ------------------------------- + user: "${USER}" # your username; or resolved from ${USER} environment variable ; or can be null which resolved from ${USER} environment variable + job_dir: "" # Job directory to keep created files + identity: "" + # ----------------------------------------------------------------------------- + +account: +partition: +job_name_prefix: + +containers: + asr: # Container image + + +env_vars: + - 'TOKENIZERS_PARALLELISM=' + - 'AIS_ENDPOINT=' + - 'LHOTSE_AUDIO_DURATION_MISMATCH_TOLERANCE=' + - 'TORCH_CUDNN_V8_API_ENABLED=' + - 'PYTORCH_CUDA_ALLOC_CONF=' + - 'HYDRA_FULL_ERROR=1' + +required_env_vars: + - 'HF_TOKEN=' + - 'WANDB_KEY=' + +mounts: + # Replace with your own paths in your cluster config + - /path/to/mount:/where/to/mount/ + +timeouts: + partition_name: # Specify time diff --git a/docs/src/sdp/api.rst b/docs/src/sdp/api.rst index bfa2bc62..dcdd13bc 100644 --- a/docs/src/sdp/api.rst +++ b/docs/src/sdp/api.rst @@ -379,6 +379,14 @@ Miscellaneous .. autodata:: sdp.processors.tts.prepare_tts_segments.PrepareTTSSegmentsProcessor :annotation: +.. autodata:: sdp.processors.ipl.nemo_run_processor.NemoRunIPLProcessor + :annotation: + +.. autodata:: sdp.processors.ipl.ipl_processors.TrainingCommandGenerator + :annotation: + +.. autodata:: sdp.processors.ipl.ipl_processors.InferenceCommandGenerator + :annotation: .. _sdp-base-classes: diff --git a/docs/src/sdp/existing_configs.rst b/docs/src/sdp/existing_configs.rst index 5e7b7c97..d0a3e64e 100644 --- a/docs/src/sdp/existing_configs.rst +++ b/docs/src/sdp/existing_configs.rst @@ -407,3 +407,21 @@ HiFiTTS-2 config-docs/english/hifitts2/config_22khz config-docs/english/hifitts2/config_44khz config-docs/english/hifitts2/config_bandwidth + +NemoRunIPL +~~~~~~~~~~ + +**Supported configs**. + +* **IPL**: + `config `__ | + :doc:`documentation ` +* **NeMoRun**: + `config `__ | + :doc:`documentation ` + +.. toctree:: + :hidden: + + config-docs/ipl/config + config-docs/ipl/nemo_run_config \ No newline at end of file diff --git a/requirements/ipl.txt b/requirements/ipl.txt new file mode 100644 index 00000000..de76dca4 --- /dev/null +++ b/requirements/ipl.txt @@ -0,0 +1,11 @@ +nemo_run + +# Nemo repository path is also required, it is used to locate scripts and configs from NeMo. +# +# To set this up: +# 1. Clone the NeMo repository: +# git clone https://github.com/NVIDIA/NeMo.git /your/desired/path/to/nemo +# 2. Set the path in nemo_run_config.yaml: +# nemo_directory: /your/desired/path/to/nemo +# +# Make sure this path is valid and NeMo is up to date if you're using its scripts. diff --git a/sdp/processors/ipl/README.md b/sdp/processors/ipl/README.md new file mode 100644 index 00000000..e7d9872c --- /dev/null +++ b/sdp/processors/ipl/README.md @@ -0,0 +1,47 @@ +# 🧠 TopIPL: Iterative Pseudo-Labeling for ASR + +TopIPL is an **iterative pseudo-labeling algorithm** designed for training ASR models using both labeled and unlabeled data. It maintains a **dynamic pseudo-label cache** and leverages **top-N averaged checkpoints** as a teacher model to generate high-quality pseudo-labels across training iterations. + +## 📦 Contents + +- `NemoRunIPLProcessor` — Command generator and job submitter for IPL runs, compatible with local and cluster environments. +- `nemo_run_config.yaml` — Main configuration file. Users should define all required paths and parameters here. + +## 🚀 Getting Started + +TopIPL runs like any other processor in the `nemo_run` framework. To use it, you must pass: + +- `output_manifest_file`: Path where the resulting manifest will be saved. +- `nemo_run_config`: YAML file containing IPL setup, training/inference configs, and NeMo-Run settings. + +### 🔧 Training Config Requirements + +Your training config must: + +```yaml +exp_manager: + create_ipl_epoch_stopper_callback: True +``` +If you're not using Lhotse, also include: + +```yaml +ipl_epoch_stopper_callback_params: +stop_every_n_epochs: 2 + +``` + +### Prerequisites + +Before using TopIPL, make sure the following are set up: + +- Clone the NeMo repository: + ```bash + git clone https://github.com/NVIDIA/NeMo.git /your/desired/path/to/nemo + +- Set the path to NeMo in your `nemo_run_config.yaml`: `nemo_directory: /your/desired/path/to/nemo` +- `pip install -r requirements/ipl.txt` + +### Running the Code + +```bash +python main.py --config-path=/path/to/directory/config --config-name=config.yaml \ No newline at end of file diff --git a/sdp/processors/ipl/__init__.py b/sdp/processors/ipl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sdp/processors/ipl/ipl_processors.py b/sdp/processors/ipl/ipl_processors.py new file mode 100644 index 00000000..0a29c656 --- /dev/null +++ b/sdp/processors/ipl/ipl_processors.py @@ -0,0 +1,341 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard library imports +import os +import subprocess +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +# Third-party imports +from omegaconf import DictConfig, OmegaConf, open_dict +import logging +import json +# Local imports +from sdp.processors.base_processor import BaseProcessor + + +class TrainingCommandGenerator(BaseProcessor): + """ + A processor that generates training commands for NeMo models with support for both local and cluster configurations. + Handles manifest file updates and tarred audio filepath management for training datasets. + + Args: + training_config_local (str): Path to the local machine configuration file + training_config_cluster (str): Path to the cluster configuration file + training_script_path (str): Path to the training script relative to nemo_directory + nemo_directory (str): Base directory for NeMo framework + new_manifest_files (str, Optional): New manifest files to add to the training configuration + new_tarred_audio_filepaths (str, Optional): New tarred audio filepaths to add to the training configuration + **kwargs: Additional arguments passed to the parent BaseProcessor class + """ + + def __init__( + self, + training_config_local: str, # Local machine config path + training_config_cluster: str, # Cluster config path + training_script_path: str, # Path to training script + nemo_directory: str, # Base directory for NeMo + new_manifest_files: str = None, # New manifest files to add + new_tarred_audio_filepaths: str = None, # New tarred audio paths + **kwargs + ): + super().__init__(**kwargs) + + # Paths on the current machine + self.training_config_local = training_config_local + self.training_config_cluster = training_config_cluster + self.training_script_path = os.path.join(nemo_directory, training_script_path) + self.nemo_directory = nemo_directory + self.new_manifest_files = new_manifest_files + self.new_tarred_audio_filepaths = new_tarred_audio_filepaths + + def process( + self, + new_manifest_files=None, + new_tarred_audio_filepaths=None + ) -> str: + """ + Generates the training command based on the processor's configuration. + If new manifest files are provided, updates the training configuration accordingly. + + Returns: + str: The complete training command to be executed on the cluster + """ + if new_manifest_files is None: + cmd = self.get_execution_script( + cluster_script_path=self.training_script_path, + local_config=self.training_config_local, + cluster_config_path=self.training_config_cluster + ) + else: + updated_manifest_filepaths, updated_tarred_audio_filepaths = self.update_training_sets( + config=self.training_config_local, + updated_manifest_filepaths=new_manifest_files, + updated_tarred_audio_filepaths=new_tarred_audio_filepaths + ) + cmd = self.get_execution_script( + cluster_script_path=self.training_script_path, + local_config=self.training_config_local, + cluster_config_path=self.training_config_cluster, + updated_manifest_filepaths=updated_manifest_filepaths, + updated_tarred_filepaths=updated_tarred_audio_filepaths + ) + return cmd + + def get_execution_script( + self, + cluster_script_path: str, + local_config: DictConfig, + cluster_config_path: str, + updated_manifest_filepaths: Optional[str] = None, + updated_tarred_filepaths: Optional[str] = None + ) -> str: + """ + Create the command to run the script on the cluster. + + Args: + cluster_script_path (str): Path to the script to run on the cluster + local_config (DictConfig): Local configuration loaded from training_config_local + cluster_config_path (str): Path to the cluster configuration file + updated_manifest_filepaths (str, Optional): Path to the updated manifest file + updated_tarred_filepaths (str, Optional): Path to the updated tarred audio filepaths + + Returns: + str: Command to run the script on the cluster + """ + # Get the WANDB API key from the environment variables + wandb_key = os.environ.get("WANDB_API_KEY") or os.environ.get("WANDB") or os.environ.get("WANDB_KEY", "") + if not wandb_key: + logging.warning("WANDB key not found in environment variables. WANDB logging will not work.") + + # Check if WANDB logging is enabled in the exp_manager config + if local_config.get('exp_manager', {}).get('create_wandb_logger', False): + raise ValueError( + "WANDB key is required for logging but was not found in environment variables. " + "Please set WANDB_API_KEY to enable WANDB logging." + ) + + + config_path = os.path.dirname(cluster_config_path) + config_name = os.path.basename(cluster_config_path) + cmd = ( + "nvidia-smi && " + f"cd {os.path.dirname(cluster_script_path)} && " + f"python -u -B {os.path.basename(cluster_script_path)} " + f"--config-path {config_path} --config-name \"{config_name}\"" + ) + + # Add additional parameters if provided + if updated_manifest_filepaths: + cmd += f" model.train_ds.manifest_filepath={updated_manifest_filepaths}" + if updated_tarred_filepaths: + cmd += f" model.train_ds.tarred_audio_filepaths={updated_tarred_filepaths}" + output_data = {"training_command": cmd} + + # with open(self.output_manifest_file, 'w') as f: + # json.dump(output_data, f, indent=4) + return cmd + + def get_transcribed_names(self, manifest_filepaths: List[str], is_tarred: bool=False) -> List[List[str]]: + """ + Generates a list of modified file paths by prepending 'transcribed_' to the filenames. + The use case is for non AIStore datasets + + Args: + manifest_filepaths (list of str): A list of file paths to be modified. + + Returns: + list of list of str: A list where each element is a single-item list containing the updated file path. + Example: + >>> manifest_filepaths = [ + ... "/path/to/manifest_1.json", + ... "/path/to/manifest_2.json" + ... ] + >>> get_transcribed_names(manifest_filepaths) + [ + ["/path/to/prefix_transcribed_manifest_1.json"], + ["/path/to/prefix_transcribed_manifest_2.json"] + ] + """ + # For manifest_filepath, modify the filenames by prepending 'prefix_transcribed_' + transcribed_paths = [] + + for file_path in manifest_filepaths: + directory, filename = os.path.split(file_path) + + new_filename = ( + f"transcribed_{filename}" if is_tarred + else f"transcribed_manifest.json" + ) + transcribed_paths.append([os.path.join(directory, new_filename)]) + + return transcribed_paths + + def update_training_sets( + self, + config: DictConfig, + updated_manifest_filepaths: List[str], + updated_tarred_audio_filepaths: Optional[List[str]] = None + ) -> Tuple[str, str]: + """ + Updates the training dataset configuration by adding pseudo-labeled datasets + to the training paths based on the dataset type. + + Args: + config (DictConfig): Training config file to be updated + updated_manifest_filepaths (List[str]): List of updated manifest file paths to be included + updated_tarred_audio_filepaths (Optional[List[str]]): List of updated tarred audio filepaths to be included + + Returns: + Tuple[str, str]: A tuple containing: + - Updated manifest file paths as a string, formatted for Omegaconf + - Updated tarred audio file paths as a string, formatted for Omegaconf + """ + updated_manifest_filepaths = self.get_transcribed_names(updated_manifest_filepaths,is_tarred=config.model.train_ds.get("is_tarred", False)) + manifest_filepath = config.model.train_ds.manifest_filepath + if updated_tarred_audio_filepaths: + updated_tarred_audio_filepaths = [[path] for path in updated_tarred_audio_filepaths] + + # Updating the configuration based on dataset types + if config.model.train_ds.get("is_tarred", False): + tarred_audio_filepaths = config.model.train_ds.tarred_audio_filepaths + if isinstance(tarred_audio_filepaths, str): + updated_tarred_audio_filepaths.append([tarred_audio_filepaths]) + updated_manifest_filepaths.append([manifest_filepath]) + else: + updated_tarred_audio_filepaths += tarred_audio_filepaths + updated_manifest_filepaths += manifest_filepath + else: + if config.model.train_ds.get("use_lhotse", False): + if isinstance(manifest_filepath, str): + updated_manifest_filepaths.append([manifest_filepath]) + else: + updated_manifest_filepaths += manifest_filepath + else: + updated_manifest_filepaths = [item for sublist in updated_manifest_filepaths for item in sublist] + if isinstance(manifest_filepath, str): + updated_manifest_filepaths.append(manifest_filepath) + else: + updated_manifest_filepaths += manifest_filepath + + # Returning strings formatted for Omegaconf + return ( + str(updated_manifest_filepaths).replace(", ", ","), + str(updated_tarred_audio_filepaths).replace(", ", ",") if updated_tarred_audio_filepaths else None, + ) + + +class InferenceCommandGenerator(BaseProcessor): + """ + A processor that generates inference commands for pseudo-labeling. + + Args: + nemo_directory (str): Base directory for NeMo framework + inference_local_config (str): Path to the local configuration file + inference_config_paths (str): Path to the inference configuration files + manifests (str): Path to the manifest files + p_cache (float): What part of pseudo-labels to update + num_gpus (int): Number of GPUs to use + is_tarred (bool): Whether the audio is tarred + first_run (bool): Whether this is the first run of pseudo-labeling + **kwargs: Additional arguments passed to the parent BaseProcessor class + """ + + def __init__( + self, + nemo_directory: str, + inference_config_paths: str, + manifests: str, + p_cache: float, + num_gpus: int, + is_tarred: bool = False, + **kwargs + ): + super().__init__(**kwargs) + + # Paths on the current machine + self.inference_config_paths = inference_config_paths + self.nemo_directory = nemo_directory + self.inference_script_path = os.path.join(nemo_directory, "examples/asr/transcribe_speech_parallel.py") + self.manifests = manifests + self.p_cache = p_cache + self.num_gpus = num_gpus + self.is_tarred = is_tarred + + def process(self, first_run=False): + """ + Generate the pseudo-labeling command for the given configuration and training parameters. + + Args: + first_run (bool, Optional): Whether this is the first run of pseudo-labeling. + + Returns: + str: The constructed pseudo-labeling command. + """ + cmd = "" + prediction_directories_str = " ".join([os.path.dirname(path) for path in self.manifests]) + inference_config_paths_str = " ".join(self.inference_config_paths) + write_transcription_path = os.path.join(self.nemo_directory, "scripts/pseudo_labeling/write_transcribed_files.py") + update_inference_config_path = os.path.join(self.nemo_directory, "scripts/pseudo_labeling/update_inference_config.py") + + if first_run: + cmd += f" && {self.get_pl_inference_command(self.inference_config_paths, shuffle=False)}" + cmd += ( + f" && python {write_transcription_path} " + f"--prediction_filepaths {prediction_directories_str} --full_pass" + ) + if self.is_tarred: + cmd += " --is_tarred" + cmd += ( + f" && python {update_inference_config_path} " + f"--inference_configs {inference_config_paths_str} --p_cache {self.p_cache} --num_gpus {self.num_gpus}" + ) + else: + cmd += f" && {self.get_pl_inference_command(self.inference_config_paths, shuffle=True)}" + cmd += ( + f" && python {write_transcription_path} " + f"--prediction_filepaths {prediction_directories_str} " + ) + if self.is_tarred: + cmd += " --is_tarred" + + output_data = {"inference_command": cmd} + with open(self.output_manifest_file, 'w') as f: + json.dump(output_data, f, indent=4) + + return cmd + + + def get_pl_inference_command(self, inference_configs, shuffle=None): + """ + Generate a command to run PL inference with multiple configuration files. + Args: + inference_configs (list): List of configuration file paths. + shuffle (bool, Optional): Whether to enable shuffling in predict_ds. + + Returns: + str: Combined command string to execute PL inference. + """ + cmd_list = [] + for config in inference_configs: + config_path = os.path.dirname(config) + config_name = os.path.basename(config) + cmd = f"python {self.inference_script_path} --config-path {config_path} --config-name {config_name}" + if shuffle is not None: + cmd += f" predict_ds.shuffle={shuffle}" + cmd_list.append(cmd) + + return " && ".join(cmd_list) + diff --git a/sdp/processors/ipl/nemo_run_processor.py b/sdp/processors/ipl/nemo_run_processor.py new file mode 100644 index 00000000..529a128c --- /dev/null +++ b/sdp/processors/ipl/nemo_run_processor.py @@ -0,0 +1,334 @@ + +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sdp.processors.ipl.ipl_processors import TrainingCommandGenerator, InferenceCommandGenerator +from sdp.processors.base_processor import BaseProcessor +from omegaconf import OmegaConf, open_dict +import os +from pathlib import Path +import logging +import datetime +import nemo_run as run +from sdp.utils import nemo_run_utils + +class NemoRunIPLProcessor(BaseProcessor): + """ + A processor that handles Iterative Pseudo-Labeling (IPL) training workflow. + + Args: + config_path (str): Path to the YAML configuration file containing IPL settings + output_manifest_file (str): Path where the output manifest file will be written + input_manifest_file (str, Optional): Path to the input manifest file + """ + + def __init__( + self, + config_path: str, + **kwargs + ): + super().__init__(**kwargs) + self.config_path = config_path + + def process(self): + """ + Main processing method that implements the IPL workflow. + This method: + 1. Loads and validates configurations + 2. Sets up training and inference command generators + 3. Executes the IPL training pipeline + """ + # Load the cluster config from YAML + cluster_cfg = OmegaConf.load(self.config_path) + + # Process the required arguments from the cluster config + script_path = cluster_cfg.script + script_config_path = cluster_cfg.script_config + results_dir = cluster_cfg.results_dir + nemo_root = cluster_cfg.nemo_directory + inference_config = cluster_cfg.inference_config + do_average = cluster_cfg.get('do_average', False) + inference_config_path = Path(inference_config).absolute() + + inference_config = OmegaConf.load(inference_config_path) + + script_config_path = Path(script_config_path).absolute() + + # Gather all mounts from the cluster config + self.gather_mounts(cluster_cfg) + + # Add the results directory to the cluster config as a mount path + nemo_run_utils.add_mount_path(results_dir, '/results', cluster_cfg) + + # Create results and logdir + log_dir = cluster_cfg.get('log_dir', os.path.join(results_dir, 'logs')) + nemo_run_utils.create_remote_directory([results_dir, log_dir], cluster_cfg) + + # Load the script config + script_config = OmegaConf.load(script_config_path) + + # Validate IPL training configuration + if "ipl_training" not in script_config.model: + raise KeyError("Parameters for `IPL` training are not provided.") + # Check all paths in configs are properly mounted + + self.check_config_mount_paths(script_config, cluster_cfg) + # Resolve experiment name + exp_name = cluster_cfg.exp_name + if exp_name is None: + if 'exp_manager' in script_config and 'name' in script_config['exp_manager']: + exp_name = script_config['exp_manager']['name'] + else: + raise ValueError( + "Experiment name not provided in the run config file (`exp_name`) or the cluster config (inside exp_manager.name)" + ) + + # Begin NeMo Run setup + with run.Experiment(exp_name) as exp: + # Create the config file name + timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + config_name = f"{exp_name}_{timestamp}_config.yaml" + + # Copy the merged config file to remote location's /results/configs directory + config_dir = os.path.join(results_dir, 'configs') + train_config_cluster = nemo_run_utils.create_remote_config(script_config, config_name, config_dir, cluster_cfg) + + # Get run parameters from the config + num_runs = cluster_cfg.num_runs + num_gpus = cluster_cfg.get('num_gpus', script_config['trainer']['devices']) + if isinstance(num_gpus, list): + num_gpus = len(num_gpus) + if num_gpus == -1: + num_gpus = 1 if cluster_cfg['executor'] == 'local' else 8 + logging.warning(f"\n\nSetting num_gpus to {num_gpus} as it was set to -1\n\n") + num_nodes = cluster_cfg.get('num_nodes', script_config['trainer'].get('num_nodes', 1)) + + # Set up checkpoint paths + checkpoint_dir = os.path.join( + os.path.join(script_config.exp_manager.exp_dir, script_config.exp_manager.name), "checkpoints" + ) + checkpoint_name = os.path.join(checkpoint_dir, script_config.exp_manager.name + ".nemo") + + # Create remote inference config + if do_average: + avg_cmd, averaged_checkpoint = self.average_checkpoints(checkpoint_name, nemo_root) + else: + avg_cmd = None + averaged_checkpoint = checkpoint_name + inference_config_paths, manifests, tarr_paths = nemo_run_utils.create_remote_inference_config( + cluster_cfg, config_dir, inference_config, averaged_checkpoint + ) + self.check_config_mount_paths(inference_config, cluster_cfg) + # Configure command generators + train_command_generator_config = { + "nemo_directory": nemo_root, + "training_config_local": script_config, + "training_config_cluster": train_config_cluster, + "training_script_path": script_path, + "output_manifest_file": "./train_output_manifest_filepath.json", + } + inference_command_generator_config = { + "nemo_directory": nemo_root, + "inference_config_paths": inference_config_paths, + "manifests": manifests, + "p_cache": cluster_cfg.p_cache, + "num_gpus": num_nodes * num_gpus, + "is_tarred": getattr(script_config.model.train_ds, "is_tarred", False), + "output_manifest_file": "./inference_output_manifest_filepath.json", + } + + print(f"cluster_cf {cluster_cfg}") + # Generate the complete IPL command + cmd = self.get_pseudo_labeling_command( + train_command_generator_config, + inference_command_generator_config, + num_ipl_epochs=cluster_cfg['num_ipl_epochs'], + new_manifest_files=manifests, + new_tarr_files=tarr_paths, + first_run=True, + avg_cmd=avg_cmd + ) + + # Cast the cluster config to a dictionary for compatibility with NeMo Run + cluster_cfg = OmegaConf.to_object(cluster_cfg) + + # Schedule tasks + task = None + for run_id in range(num_runs): + if run_id == 0: + task = None + else: + cmd = self.get_pseudo_labeling_command( + train_command_generator_config, + inference_command_generator_config, + num_ipl_epochs=cluster_cfg['num_ipl_epochs'], + new_manifest_files=manifests, + new_tarr_files=tarr_paths, + first_run=False + ) + task = [task] + + task = nemo_run_utils.add_task( + exp, + cmd=cmd, + task_name=f"{exp_name}_job", + cluster_config=cluster_cfg, + container=cluster_cfg['containers']['asr'], + num_tasks=cluster_cfg.get('num_tasks', cluster_cfg.get('num_tasks_per_node', 1)), + num_gpus=num_gpus, + num_nodes=num_nodes, + log_dir=nemo_run_utils.get_mounted_filepath(cluster_cfg, log_dir), + partition=cluster_cfg.get('partition', None), + task_dependencies=task, + ) + + # Run the experiment + nemo_run_utils.run_exp(exp, cluster_cfg) + + def gather_mounts(self, cluster_cfg): + """ + Gather all mounts from the cluster config including ones which are disjoint from the cluster_cfg.mounts list. + + Args: + cluster_cfg: Cluster config dictionary + """ + mounts = cluster_cfg.get('mounts', []) + mounts = [os.path.expanduser(m) for m in mounts] + + keys = list(cluster_cfg.keys()) + with open_dict(cluster_cfg): + for k in keys: + if k.startswith("mount_"): + logging.info(f"Found additional mount flag in the cluster config `{k}`. Adding it to the mounts list.") + mounts.append(cluster_cfg[k]) + del cluster_cfg[k] + + cluster_cfg['mounts'] = mounts + logging.info(f"Final Mounts: {mounts}") + + def check_config_mount_paths(self, script_config, cluster_config): + """ + Check if all path-like strings in the script config are mounted paths in the cluster config. + + Args: + script_config: Script config dictionary + cluster_config: Cluster config dictionary + """ + def filepath_check(v, cluster_cfg): + if v.startswith(os.path.sep): + logging.info(f"Checking if {v} is a mounted path") + nemo_run_utils.check_if_mounted(cluster_cfg, v) + unmounted_path = nemo_run_utils.get_unmounted_filepath(cluster_cfg, v) + nemo_run_utils.check_remote_mount_directories(unmounted_path, cluster_cfg) + + def check_mounted_path(cfg, cluster_cfg): + if hasattr(cfg, 'items'): + for k, v in cfg.items(): + if hasattr(v, 'items'): + check_mounted_path(v, cluster_cfg) + elif isinstance(v, list): + for item in v: + if isinstance(item, str): + filepath_check(item, cluster_cfg) + elif isinstance(v, str): + filepath_check(v, cluster_cfg) + + check_mounted_path(script_config, cluster_config) + + def get_pseudo_labeling_command( + self, + train_command_config: dict, + inference_command_config: dict, + num_ipl_epochs: int, + new_manifest_files, + new_tarr_files, + first_run: bool = False, + avg_cmd: str = None + ) -> str: + """ + Generate the pseudo-labeling command for the given configuration and training parameters. + + Args: + train_command_config (dict): Config for TrainingCommandGenerator + inference_command_config (dict): Config for InferenceCommandGenerator + num_ipl_epochs (int): Number of epochs to train with pseudo-labels + new_manifest_files: List of manifest files to use + new_tarr_files: List of tarred audio files to use + first_run (bool): Whether this is the first run of pseudo-labeling + + Returns: + str: The constructed pseudo-labeling command + """ + train_proc = TrainingCommandGenerator(**train_command_config) + infer_proc = InferenceCommandGenerator(**inference_command_config) + + exec_cmd = self.get_export_variables_cmd(train_command_config["training_config_local"], train_command_config["nemo_directory"]) + exec_cmd += train_proc.process() + exec_cmd += " && sleep 10" + if avg_cmd: + exec_cmd += " && " + avg_cmd + + exec_cmd += " " + infer_proc.process(first_run=first_run) + + for _ in range(num_ipl_epochs): + exec_cmd += " && sleep 10" + exec_cmd += " && " + train_proc.process(new_manifest_files, new_tarr_files) + if avg_cmd: + exec_cmd += " && " + avg_cmd + exec_cmd += " " + infer_proc.process(first_run=False) + + return exec_cmd + + def get_export_variables_cmd(self, merged_cfg , nemo_root): + """Generate command to export required environment variables.""" + wandb_key = os.environ.get("WANDB_API_KEY") or os.environ.get("WANDB") or os.environ.get("WANDB_KEY", "") + if not wandb_key: + logging.warning("WANDB key not found in environment variables. WANDB logging will not work.") + + if merged_cfg.get('exp_manager', {}).get('create_wandb_logger', False): + raise ValueError( + "WANDB key is required for logging but was not found in environment variables. " + "Please set WANDB_API_KEY to enable WANDB logging." + ) + + cmd = ( + "nvidia-smi && " + f"export PYTHONPATH={nemo_root} && " + f"export HF_TOKEN={os.getenv('HF_TOKEN', '')} && " + f"export WANDB_API_KEY={wandb_key} && ") + + return cmd + + def average_checkpoints(self, checkpoint_path: str, nemo_root:str) -> str: + """ + Generates the command to average all checkpoints in the given directory and returns the path to the averaged checkpoint. + + Args: + checkpoint_path (str): Path to the directory containing checkpoints + + Returns: + tuple: (command to run, path to the averaged checkpoint file) + """ + # Get the directory containing the checkpoints + checkpoint_dir = os.path.dirname(checkpoint_path) + + # Construct the command for checkpoint averaging + cmd = f"python {nemo_root}/scripts/checkpoint_averaging/legacy/checkpoint_averaging.py {checkpoint_dir}" + + # The averaged checkpoint will have the same name but with '-averaged' suffix + checkpoint_name = os.path.basename(checkpoint_path) + base_name = os.path.splitext(checkpoint_name)[0] + averaged_checkpoint = os.path.join(checkpoint_dir, f"{base_name}-averaged.nemo") + + return cmd, averaged_checkpoint diff --git a/sdp/utils/ipl_utils.py b/sdp/utils/ipl_utils.py new file mode 100644 index 00000000..07d50c5d --- /dev/null +++ b/sdp/utils/ipl_utils.py @@ -0,0 +1,330 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import glob +import json +import os +from typing import List, Optional, Tuple, Union + +from omegaconf import OmegaConf + +def separate_multiple_transcriptions(inference_config: dict) -> Tuple[List[str], Optional[List[str]]]: + """ + Separates and returns the manifest and tarred audio file paths from the configuration. + This function makes it easier to run transcribe_speech_parallel for each bucket separately + Args: + inference_config (str): Path to the inference configuration file. + Returns: + Tuple[List[str], Optional[List[str]]]: A tuple containing: + - A list of manifest file paths. + - An Optional list of tarred audio file paths, or None if not applicable. + """ + + if hasattr(inference_config.predict_ds, "is_tarred") and inference_config.predict_ds.is_tarred: + tarred_audio_filepaths = inference_config.predict_ds.tarred_audio_filepaths + manifest_filepaths = inference_config.predict_ds.manifest_filepath + if type(tarred_audio_filepaths) != str and len(tarred_audio_filepaths) > 1: + manifests = [] + tarr_audio_files = [] + for manifest_filepath, tarred_audio_filepath in zip(manifest_filepaths, tarred_audio_filepaths): + manifests.append(manifest_filepath[0]) + tarr_audio_files.append(tarred_audio_filepath[0]) + return manifests, tarr_audio_files + else: + return [manifest_filepaths], [tarred_audio_filepaths] + else: + if isinstance(inference_config.predict_ds.manifest_filepath, str): + return [inference_config.predict_ds.manifest_filepath], None + else: + return inference_config.predict_ds.manifest_filepath, None + + +def create_transcribed_shard_manifests( + prediction_filepaths: List[str], +) -> List[str]: + """ + Creates transcribed shard manifest files by processing predictions and organizing them by shard ID. + This function reads a `predictions_all.json` file from each given directory, organizes the data by + shard IDs, and writes the entries to separate shard manifest files. For each shard, the `pred_text` + field is updated as the main transcription (`text`), and the original transcription (`text`) is + stored as `orig_text`. + Args: + prediction_filepaths (List[str]): A list of file paths to directories containing + `predictions_all.json` files with prediction data, including shard IDs. + Returns: + List[str]: A list of file paths to the combined manifest files (`transcribed_manifest__OP_0..CL_.json`) + created for each directory. + """ + all_manifest_filepaths = [] + for prediction_filepath in prediction_filepaths: + max_shard_id = 0 + shard_data = {} + full_path = os.path.join(prediction_filepath, "predictions_all.json") + with open(full_path, 'r') as f: + for line in f.readlines(): + data_entry = json.loads(line) + shard_id = data_entry.get("shard_id") + if max_shard_id < shard_id: + max_shard_id = shard_id + if shard_id not in shard_data: + shard_data[shard_id] = [] + shard_data[shard_id].append(data_entry) + for shard_id, entries in shard_data.items(): + output_filename = os.path.join(prediction_filepath, f"transcribed_manifest_{shard_id}.json") + with open(output_filename, 'w') as f: + for data_entry in entries: + if data_entry['audio_filepath'].endswith(".wav"): + if 'text' in data_entry: + data_entry['orig_text'] = data_entry.pop('text') + data_entry['text'] = data_entry.pop('pred_text') + json.dump(data_entry, f, ensure_ascii=False) + f.write("\n") + shard_manifest_filepath = os.path.join( + prediction_filepath, f"transcribed_manifest__OP_0..{max_shard_id}_CL_.json" + ) + all_manifest_filepaths.append(shard_manifest_filepath) + return all_manifest_filepaths + + +def create_transcribed_manifests( + prediction_filepaths: List[str], +) -> List[str]: + """ + Creates updated transcribed manifest files by processing predictions. + This function reads prediction files (`predictions_all.json`) from the provided directories, + updates the transcription data by renaming the `pred_text` field to `text`, and stores the + original `text` field as `orig_text`. The updated data is written to new transcribed manifest + files (`transcribed_manifest.json`) in each directory. + Args: + prediction_filepaths (List[str]): A list of file paths to directories containing + prediction files (`predictions_all.json`). + Returns: + List[str]: A list of file paths to the newly created transcribed manifest files + (`transcribed_manifest.json`). + """ + all_manifest_filepaths = [] + for prediction_filepath in prediction_filepaths: + prediction_name = os.path.join(prediction_filepath, "predictions_all.json") + transcripted_name = os.path.join(prediction_filepath, f"transcribed_manifest.json") + + # Open and read the original predictions_all.json file + with open(transcripted_name, 'w', encoding='utf-8') as f: + with open(prediction_name, 'r', encoding='utf-8') as pred_f: + + for line in pred_f.readlines(): + data_entry = json.loads(line) + if 'text' in data_entry: + data_entry['orig_text'] = data_entry.pop('text') + data_entry['text'] = data_entry.pop('pred_text') + json.dump(data_entry, f, ensure_ascii=False) + f.write("\n") + # Append the path of the new manifest file to the list + all_manifest_filepaths.append(transcripted_name) + + return all_manifest_filepaths + + +def write_sampled_shard_transcriptions(manifest_filepaths: List[str]) -> List[List[str]]: + """ + Updates transcriptions by merging predicted shard data and transcribed manifest data. + This function processes prediction and transcribed manifest files, merges them + by matching the shard_id and audio file paths. For each shard, the corresponding + data entries are written to a new file. + Args: + manifest_filepaths (List[str]): A list of file paths to directories containing + prediction and transcribed manifest files. + Returns: + List[List[str]]: A list of lists containing the file paths to the generated + transcribed shard manifest files. + """ + all_manifest_filepaths = [] + + # Process each prediction directory + for prediction_filepath in manifest_filepaths: + predicted_shard_data = {} + # Collect entries from prediction files based on shard id + prediction_path = os.path.join(prediction_filepath, "predictions_all.json") + with open(prediction_path, 'r') as f: + for line in f: + data_entry = json.loads(line) + shard_id = data_entry.get("shard_id") + audio_filepath = data_entry['audio_filepath'] + predicted_shard_data.setdefault(shard_id, {})[audio_filepath] = data_entry + max_shard_id = 0 + for full_path in glob.glob(os.path.join(prediction_filepath, f"transcribed_manifest_[0-9]*.json")): + all_data_entries = [] + with open(full_path, 'r') as f: + for line in f: + data_entry = json.loads(line) + shard_id = data_entry.get("shard_id") + max_shard_id = max(max_shard_id, shard_id) + all_data_entries.append(data_entry) + # Write the merged data to a new manifest file keeping new transcriptions + output_filename = os.path.join(prediction_filepath, f"transcribed_manifest_{shard_id}.json") + with open(output_filename, 'w') as f: + for data_entry in all_data_entries: + audio_filepath = data_entry['audio_filepath'] + # Escape duplicated audio files that end with *dup + if audio_filepath.endswith(".wav"): + if shard_id in predicted_shard_data and audio_filepath in predicted_shard_data[shard_id]: + predicted_data_entry = predicted_shard_data[shard_id][audio_filepath] + if 'text' in predicted_data_entry: + predicted_data_entry['orig_text'] = predicted_data_entry.pop('text') + if "pred_text" in predicted_data_entry: + predicted_data_entry['text'] = predicted_data_entry.pop('pred_text') + json.dump(predicted_data_entry, f, ensure_ascii=False) + else: + json.dump(data_entry, f, ensure_ascii=False) + f.write("\n") + + shard_manifest_filepath = os.path.join( + prediction_filepath, f"transcribed_manifest__OP_0..{max_shard_id}_CL_.json" + ) + all_manifest_filepaths.append([shard_manifest_filepath]) + + return all_manifest_filepaths + +def write_sampled_transcriptions(manifest_filepaths: List[str]) -> List[str]: + """ + Updates transcriptions by merging predicted data with transcribed manifest data. + This function processes prediction and transcribed manifest files within given directories. + It matches audio file paths to update transcriptions with predictions, ensuring each audio file + is properly transcribed. The updated data is written to the transcribed manifest file. + Args: + manifest_filepaths (List[str]): A list of file paths to directories containing + the prediction file (`predictions_all.json`) and the transcribed manifest file + (`transcribed_manifest.json`). + Returns: + List[str]: A list of file paths to the updated transcribed manifest files. + """ + + all_manifest_filepaths = [] + for prediction_filepath in manifest_filepaths: + predicted_data = {} + + prediction_path = os.path.join(prediction_filepath, "predictions_all.json") + with open(prediction_path, 'r') as f: + for line in f: + data_entry = json.loads(line) + path = data_entry['audio_filepath'] + + predicted_data[path] = data_entry + full_path = os.path.join(prediction_filepath, f"transcribed_manifest.json") + all_data_entries = [] + count = 0 + with open(full_path, 'r') as f: + for line in f: + count += 1 + data_entry = json.loads(line) + all_data_entries.append(data_entry) + + + output_filename = os.path.join(prediction_filepath, f"transcribed_manifest.json") + with open(output_filename, 'w') as f: + for data_entry in all_data_entries: + audio_filepath = data_entry['audio_filepath'] + if audio_filepath.endswith(".wav"): + if audio_filepath in predicted_data: + predicted_data_entry = predicted_data[audio_filepath] + if 'text' in predicted_data_entry: + predicted_data_entry['orig_text'] = predicted_data_entry.pop('text') + predicted_data_entry['text'] = predicted_data_entry.pop('pred_text') + json.dump(predicted_data_entry, f, ensure_ascii=False) + f.write("\n") + else: + json.dump(data_entry, f, ensure_ascii=False) + f.write("\n") + all_manifest_filepaths.append(output_filename) + return all_manifest_filepaths + + +def update_training_sets( + merged_config: OmegaConf, final_cache_manifests: list, tarred_audio_filepaths: Union[list, str] +) -> OmegaConf: + """ + Adds pseudo-labeled sets to the training datasets based on dataset type and + handles tarred audio files differently. The function updates the 'manifest_filepath' + and 'tarred_audio_filepaths' fields in the training dataset configuration. + Args: + merged_config: The configuration object containing the model and dataset settings. + final_cache_manifests: A list of paths to the manifest files for the pseudo-labeled data. + tarred_audio_filepaths: A string or list of tarred audio file paths to be added to the training set. + Returns: + merged_config: The updated configuration object with the new training datasets. + """ + + print() + print(f"update_training_sets") + print(f"") + if merged_config.model.train_ds.get("is_tarred", False): + if isinstance(tarred_audio_filepaths, str): + if isinstance(merged_config.model.train_ds['tarred_audio_filepaths'], str): + merged_config.model.train_ds['tarred_audio_filepaths'] = [ + [merged_config.model.train_ds['tarred_audio_filepaths']], + [tarred_audio_filepaths], + ] + else: + merged_config.model.train_ds.tarred_audio_filepaths.append(tarred_audio_filepaths) + else: + if isinstance(merged_config.model.train_ds.tarred_audio_filepaths, str): + merged_config.model.train_ds.tarred_audio_filepaths = [ + [merged_config.model.train_ds.tarred_audio_filepaths] + ] + merged_config.model.train_ds.tarred_audio_filepaths += tarred_audio_filepaths + + if isinstance(merged_config.model.train_ds.manifest_filepath, str): + merged_config.model.train_ds.manifest_filepath = [merged_config.model.train_ds.manifest_filepath] + + merged_config.model.train_ds.manifest_filepath += final_cache_manifests + + else: + print(f"is not tarred") + if isinstance(merged_config.model.train_ds.manifest_filepath, str): + print(f"is str") + merged_config.model.train_ds.manifest_filepath = [merged_config.model.train_ds.manifest_filepath] + + if merged_config.model.train_ds.get("use_lhotse", False): + print(f"is lhotse") + merged_config.model.train_ds.manifest_filepath = [merged_config.model.train_ds.manifest_filepath] + merged_config.model.train_ds.manifest_filepath.append(final_cache_manifests) + else: + print(f"not lhotse") + print(f"merged_config.model.train_ds.manifest_filepath {merged_config.model.train_ds.manifest_filepath}") + print(f"final_cache_manifests {final_cache_manifests}") + merged_config.model.train_ds.manifest_filepath += final_cache_manifests + + + return merged_config + + +def count_files_for_pseudo_labeling(manifest_filepath: str, is_tarred: bool) -> int: + """ + Counts the number of files for pseudo-labeling. + Args: + manifest_filepath (str): The path to the manifest file(s). + is_tarred (bool): Flag to determine whether to count files for multiple shard manifests. + Returns: + int: The total number of audio files given for pseudo labeling. + """ + if is_tarred: + dir_path, filename = os.path.split(manifest_filepath) + prefix = filename.split('_', 1)[0] + number_of_files = 0 + for full_path in glob.glob(os.path.join(dir_path, f"{prefix}_[0-9]*.json")): + with open(full_path, 'r') as f: + number_of_files += len(f.readlines()) + else: + with open(manifest_filepath, 'r') as f: + number_of_files = len(f.readlines()) + + return number_of_files diff --git a/sdp/utils/nemo_run_utils.py b/sdp/utils/nemo_run_utils.py new file mode 100644 index 00000000..5cbd8575 --- /dev/null +++ b/sdp/utils/nemo_run_utils.py @@ -0,0 +1,406 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from functools import lru_cache +from nemo_run.core.tunnel import LocalTunnel, SSHTunnel +from omegaconf import DictConfig, OmegaConf +from sdp.utils.skills_utils import ( + get_mounts_from_config, + check_if_mounted, + add_task, + run_exp, +) +import logging +import copy +from sdp.utils import ipl_utils +@lru_cache(maxsize=2) +def get_tunnel(**ssh_tunnel): + return SSHTunnel(**ssh_tunnel) + + + +def add_mount_path(mount_source: str, mount_dest: str, cluster_config): + """ + Add a mount path to the cluster config. + + Args: + mount_source: The source filepath on the local/remote machine. + mount_dest: The destination filepath on the remote/local machine. Must be an absolute path. + cluster_config: The cluster config dictionary. + """ + + # Check if the cluster config is provided + if cluster_config is None: + raise ValueError("Cluster config is not provided.") + + # Check if the mounts key is present in the cluster config + if 'mounts' in cluster_config: + # Resolve the environment variables for the mount source and mount destination + original_mounts = get_mounts_from_config(cluster_config) + + added_mount = False + for mount_path in original_mounts: + source, destination = mount_path.split(':') + + # Check if the mount path already exists in the cluster config + if source == mount_source and destination == mount_dest: + return + + # Add the mount path to the cluster config if it does not already exist + if not added_mount: + cluster_config['mounts'].append(f"{mount_source}:{mount_dest}") + logging.info(f"Added mount path: `{mount_source}:{mount_dest}`") + + else: + # Don't add a new mount path if the mounts key is not present in the cluster config + raise ValueError("No mounts found in cluster config, can only add to existing mount list.") + + +def create_remote_directory(directory: str | list, cluster_config: dict): + """ + Create a remote directory on the cluster using the cluster config. + + **Note**: The ssh tunnel config must be provided in the cluster config for remote directory creation. + + Args: + directory: The directory path to be created on the remote cluster. Can be a single directory path or a list + of directory paths. + cluster_config: The cluster config dictionary. + """ + + if cluster_config is None: + raise ValueError("Cluster config is not provided.") + + # Check if the directory is a string or a list + if isinstance(directory, str): + directory = [directory] + + # Check if the executor is local + if cluster_config.get('executor') == 'local': + tunnel = LocalTunnel(job_dir=directory[0]) # temp job dir, unused + for dir_path in directory: + tunnel.run(f'mkdir -p {dir_path}', hide=False, warn=True) + logging.info(f"Created directory: {dir_path} in local filesystem.") + + # Dont cleanup, cache the tunnel + # tunnel.cleanup() + + # Check if the executor is slurm + elif cluster_config.get('executor') == 'slurm': + # Check if the ssh tunnel config is provided in the cluster config + ssh_tunnel_config = cluster_config.get('ssh_tunnel', None) + if ssh_tunnel_config is None: + raise ValueError("`ssh_tunnel` sub-config is not provided in cluster_config.") + + # Check for pre-existing job_dir in the ssh_tunnel_config + if 'job_dir' not in ssh_tunnel_config: + ssh_tunnel_config['job_dir'] = directory[0] + + # Create the remote directory on the cluster + tunnel = get_tunnel(**cluster_config['ssh_tunnel']) + for dir_path in directory: + tunnel.run(f'mkdir -p {dir_path}', hide=False, warn=True) + logging.info(f"Created directory: {dir_path} on remote cluster.") + + # Dont cleanup, cache the tunnel + # tunnel.cleanup() + + else: + raise ValueError(f"Unsupported executor: {cluster_config.get('executor')}") + + +def create_remote_config(config: dict, config_name: str, config_directory: str, cluster_config: dict): + """ + Utility to write a remote config file on the cluster using the cluster config. + + Args: + config: The config dictionary to be written to the file. Can be OmegaConf DictConfig or a dictionary. + config_name: The name of the config file to be created. + config_directory: The directory path where the config file will be created on the remote machine. + Can be a single directory path or a list of directory paths to copy the config file to. + cluster_config: The cluster config dictionary. + """ + if cluster_config is None: + raise ValueError("Cluster config is not provided.") + + # Check if the config_name is a string and ends with .yaml + if not config_name.endswith('.yaml'): + config_name = f"{config_name}.yaml" + + # Check if the config_directory is a string or a list + if isinstance(config_directory, str): + config_directory = [config_directory] + + # Cast a normal dict to OmeagConf DictConfig + if isinstance(config, dict): + config = OmegaConf.create(config) + + # Check if the executor is local + if cluster_config.get('executor') == 'local': + tunnel = LocalTunnel(job_dir=config_directory[0]) + + # Create the config file on the local filesystem + for dir_path in config_directory: + config_filepath = os.path.join(dir_path, config_name) + tunnel.run(f'mkdir -p {dir_path}', hide=False, warn=True) + tunnel.run(f"touch {config_filepath}", hide=False, warn=True) + tunnel.run(f"echo '{OmegaConf.to_yaml(config)}' > {config_filepath}", hide=False, warn=True) + logging.info(f"Created config file: {dir_path} in local filesystem.") + + # Dont cleanup, cache the tunnel + # tunnel.cleanup() + + # Check if the executor is slurm + elif cluster_config.get('executor') == 'slurm': + # Check if the ssh tunnel config is provided in the cluster config + ssh_tunnel_config = cluster_config.get('ssh_tunnel', None) + if ssh_tunnel_config is None: + raise ValueError("`ssh_tunnel` sub-config is not provided in cluster_config.") + + # Check for pre-existing job_dir in the ssh_tunnel_config + if 'job_dir' not in ssh_tunnel_config: + ssh_tunnel_config['job_dir'] = config_directory[0] + + tunnel = get_tunnel(**cluster_config['ssh_tunnel']) + + # Create the config file on the remote cluster + for dir_path in config_directory: + config_filepath = os.path.join(dir_path, config_name) + tunnel.run(f'mkdir -p {dir_path}', hide=False, warn=True) + tunnel.run(f"touch {config_filepath}", hide=False, warn=True) + tunnel.run(f"echo '{OmegaConf.to_yaml(config)}' > {config_filepath}", hide=False, warn=True) + logging.info(f"Created config file: {dir_path} on remote cluster.") + + # Dont cleanup, cache the tunnel + # tunnel.cleanup() + + else: + raise ValueError(f"Unsupported executor: {cluster_config.get('executor')}") + return config_filepath + +def create_remote_inference_config(cluster_config, config_directory: str, inference_config, checkpoint_path): + """ + Utility to create and write remote inference configuration files for a cluster setup. + + Args: + cluster_config (dict): The cluster configuration dictionary containing details about the cluster setup, + including the executor type (`local` or `slurm`) and optional SSH tunnel configurations. + config_directory (str or list of str): The directory path(s) where the inference configuration file(s) + will be created on the remote machine. If a single path is provided, it will be converted into a list. + inference_config: The base inference configuration object, which will be modified for each bucket. + Should be compatible with OmegaConf. + checkpoint_path (str): The path to the model checkpoint, which will be included in the modified inference configuration. + + Returns: + tuple: A tuple containing: + - new_config_paths (list): A list of paths to the newly created inference configuration files. + - manifests (list): A list of manifest file paths, one for each bucket. + - tarr_audio_files (list or None): A list of tarred audio file paths, one for each bucket, or None if not applicable. + """ + if isinstance(config_directory, str): + config_directory = [config_directory] + + # separating each bucket for creating different inference config + manifests, tarr_audio_files = ipl_utils.separate_multiple_transcriptions(inference_config) + + new_config_paths = [] + for i in range(len(manifests)): + output_dir = os.path.dirname(manifests[i]) + modified_cfg = copy.deepcopy(inference_config) + # Updating inference config for exact bucket + OmegaConf.update(modified_cfg, "output_path", output_dir) + OmegaConf.update(modified_cfg, "predict_ds.manifest_filepath", manifests[i]) + if tarr_audio_files: + OmegaConf.update(modified_cfg, "predict_ds.tarred_audio_filepaths", tarr_audio_files[i]) + OmegaConf.update(modified_cfg, "model", checkpoint_path) + + if cluster_config.get('executor') == 'local': + for dir_path in config_directory: + inference_config_filepath = os.path.join(dir_path, f"modified_config_{i}.yaml") + new_config_paths.append(os.path.abspath(inference_config_filepath)) + tunnel = LocalTunnel(job_dir=config_directory[0]) + tunnel.run(f"touch {inference_config_filepath}", hide=False, warn=True) + tunnel.run( + f"echo '{OmegaConf.to_yaml(modified_cfg)}' > {inference_config_filepath}", hide=False, warn=True + ) + logging.info(f"Created config file: {dir_path} in local filesystem.") + elif cluster_config.get('executor') == 'slurm': + ssh_tunnel_config = cluster_config.get('ssh_tunnel', None) + if ssh_tunnel_config is None: + raise ValueError("`ssh_tunnel` sub-config is not provided in cluster_config.") + if 'job_dir' not in ssh_tunnel_config: + ssh_tunnel_config['job_dir'] = config_directory[0] + tunnel = get_tunnel(**cluster_config['ssh_tunnel']) + + for dir_path in config_directory: + # Creating config files also locally to be able to count + inference_config_filepath = os.path.join(dir_path, f"modified_config_{i}.yaml") + new_config_paths.append(inference_config_filepath) + tunnel.run(f"touch {inference_config_filepath}", hide=False, warn=True) + tunnel.run( + f"echo '{OmegaConf.to_yaml(modified_cfg)}' > {inference_config_filepath}", hide=False, warn=True + ) + + return new_config_paths, manifests, tarr_audio_files + + +def check_remote_mount_directories(directories: str | list, cluster_config: dict, exit_on_failure: bool = True): + """ + Check if files and directories at the source location exist for later mounting on the cluster. + + Args: + directories: The directory path to be checked on the local/remote machine. Can be a single directory + path or a list. Can be either a file or a directory. + cluster_config: The cluster config dictionary. + exit_on_failure: If True, will raise an exception if the directories do not exist at the source location. + """ + + # Check if the cluster config is provided + if cluster_config is None: + raise ValueError("Cluster config is not provided.") + + # Check if the directories is a string or a list + if isinstance(directories, str): + directories = [directories] + + # Check if the executor is local + if cluster_config.get('executor') == 'local': + tunnel = LocalTunnel(job_dir=None) + + # Check if the directories exist at the source location for mounting + missing_source_locations = [] + for directory in directories: + result = tunnel.run(f'test -e {directory} && echo "Directory Exists"', hide=True, warn=True) + + if "Directory Exists" not in result.stdout: + missing_source_locations.append(directory) + + # Dont cleanup, cache the tunnel + # tunnel.cleanup() + + # Raise an exception if the directories do not exist at the source location + if len(missing_source_locations) > 0 and exit_on_failure: + missing_source_locations = [ + f"{loc} DOES NOT exist at source destination" for loc in missing_source_locations + ] + missing_source_locations = "\n".join(missing_source_locations) + raise FileNotFoundError( + f"Some files or directories do not exist at the source location for mounting !!\n\n" + f"{missing_source_locations}" + ) + + # Check if the executor is slurm + elif cluster_config.get('executor') == 'slurm': + # Check if the ssh tunnel config is provided in the cluster config + ssh_tunnel_config = cluster_config.get('ssh_tunnel', None) + if ssh_tunnel_config is None: + raise ValueError("`ssh_tunnel` sub-config is not provided in cluster_config.") + + # Check for pre-existing job_dir in the ssh_tunnel_config + if 'job_dir' not in ssh_tunnel_config: + ssh_tunnel_config['job_dir'] = os.getcwd() + + tunnel = get_tunnel(**cluster_config['ssh_tunnel']) + missing_source_locations = [] + + # Check if the directories exist at the source location for mounting + for directory in directories: + result = tunnel.run(f'test -e {directory} && echo "Directory Exists"', hide=True, warn=True) + + if "Directory Exists" not in result.stdout: + missing_source_locations.append(directory) + + # Dont cleanup, cache the tunnel + # tunnel.cleanup() + + # Raise an exception if the directories do not exist at the source location + if len(missing_source_locations) > 0 and exit_on_failure: + missing_source_locations = [ + f"{loc} DOES NOT exist at source destination" for loc in missing_source_locations + ] + missing_source_locations = "\n".join(missing_source_locations) + raise FileNotFoundError( + f"Some files or directories do not exist at the source location for mounting !!\n\n" + f"{missing_source_locations}" + ) + + else: + raise ValueError(f"Unsupported executor: {cluster_config.get('executor')}") + + +def get_unmounted_filepath(cluster_config: dict, filepath: str): + """ + Resolve the mounted filepath using the cluster config to merge the mount source path to the filepath. + Raises an exception if the mount path is not found for the file path. + + Args: + cluster_config: The cluster config dictionary. + filepath: The filepath to be unmounted using the cluster config. + + Returns: + str: unmounted filepath + """ + # Find which mount path matches the filepaths prefix + mount_path = None + for mount in cluster_config['mounts']: + mount_source, mount_dest = mount.split(':') + if filepath.startswith(mount_dest): + mount_path = mount + break + + if mount_path is None: + raise ValueError( + f"Could not find a mount path for the file path `{filepath}`. Below paths are mounted: \n" + f"{cluster_config['mounts']}" + ) + + # replace the mount destination inside the filepath with the mount source + mount_source, mount_dest = mount_path.split(':') + filepath = mount_source + filepath[len(mount_dest) :] # replace the mount destination with the mount source + + return filepath + + +def get_mounted_filepath(cluster_config: dict, filepath: str): + """ + Resolve the mounted filepath using the cluster config to merge the mount destination path to the filepath. + Raises an exception if the mount path is not found for the file path. + + Args: + cluster_config: The cluster config dictionary. + filepath: The filepath to be mounted using the cluster config. + + Returns: + str: mounted filepath + """ + # Find which mount path matches the filepaths prefix + mount_path = None + for mount in cluster_config['mounts']: + mount_source, mount_dest = mount.split(':') + if filepath.startswith(mount_source): + mount_path = mount + break + + if mount_path is None: + raise ValueError( + f"Could not find a mount path for the file path `{filepath}`. Below paths are mounted: \n" + f"{cluster_config['mounts']}" + ) + + # replace the mount destination inside the filepath with the mount source + mount_source, mount_dest = mount_path.split(':') + filepath = mount_dest + filepath[len(mount_source) :] # replace the mount destination with the mount source + + return filepath diff --git a/sdp/utils/skills_utils.py b/sdp/utils/skills_utils.py new file mode 100644 index 00000000..ac536043 --- /dev/null +++ b/sdp/utils/skills_utils.py @@ -0,0 +1,1229 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +#This file is maintained in sync with `nemo_skills/pipeline/utils.py` +# and is intended to be copied as-is to ensure consistency across projects. + +import logging +import os +import shlex +import subprocess +import sys +import tarfile +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import datetime +from functools import lru_cache +from pathlib import Path +from typing import Optional + +import nemo_run as run +import yaml +from huggingface_hub import get_token +try: + from invoke import StreamWatcher +except ImportError: + StreamWatcher = object # fallback if invoke is not installed +from nemo_run.config import set_nemorun_home +from nemo_run.core.execution.docker import DockerExecutor +from nemo_run.core.execution.slurm import SlurmJobDetails, get_packaging_job_key +from nemo_run.core.tunnel import SSHTunnel +from omegaconf import DictConfig + +LOG = logging.getLogger(__file__) + + +# TODO: this file is way too big - we need to split it into pieces + +# keeping a global variable for first submitted experiment (per cluster) and reusing it by default +# we are using ssh tunnel as a proxy for cluster identity, since even if other parameters are different +# we can still reuse code as long as ssh matches +REUSE_CODE_EXP = {} + + +@dataclass +class RepoMetadata: + """Metadata for a repo that is used in the experiment.""" + + name: str + path: Path + + def __post_init__(self): + if isinstance(self.path, str): + self.path = Path(self.path) + + if not self.path.exists(): + raise ValueError(f"Repository path `{self.path}` does not exist.") + + +# Registry of external repos that should be packaged with the code in the experiment +EXTERNAL_REPOS = { + 'nemo_skills': RepoMetadata( + name='nemo_skills', path=Path(__file__).absolute().parents[1] + ), # path to nemo_skills repo +} + + + +def register_external_repo(metadata: RepoMetadata): + """Register an external repo to be packaged with the code in the experiment. + + Args: + metadata (RepoMetadata): Metadata for the external repo. + """ + if metadata.name in EXTERNAL_REPOS: + raise ValueError(f"External repo {metadata.name} is already registered.") + + EXTERNAL_REPOS[metadata.name] = metadata + + +def get_registered_external_repo(name: str) -> Optional[RepoMetadata]: + """Get the path to the registered external repo. + + Args: + name (str): Name of the external repo. + + Returns: + A path to the external repo if it is registered, otherwise None. + """ + if name not in EXTERNAL_REPOS: + return None + + return EXTERNAL_REPOS[name] + + +def check_if_mounted(cluster_config, path_to_check): + """Will check that path_to_check is referenced inside one of the mounts.""" + for mount in get_mounts_from_config(cluster_config) + ['/nemo_run/code:/nemo_run/code']: + if path_to_check.startswith(mount.split(":")[1]): + return + raise ValueError(f"The path '{path_to_check}' is not mounted. Check cluster config.") + + +def get_unmounted_path(cluster_config, path): + """Will return the path on the filesystem before it's mounted.""" + if path is None: + return None + for mount in get_mounts_from_config(cluster_config): + if path.startswith(mount.split(":")[1]): + return mount.split(":")[0] + path[len(mount.split(":")[1]) :] + raise ValueError(f"The path '{path}' is not mounted. Check cluster config.") + + +# caching the status assuming it doesn't change while experiment is being scheduled +# otherwise this results in too many ssh calls +@lru_cache +def get_exp_handles(expname: str, ignore_finished=True, ignore_exp_not_exists=True) -> list[str]: + """Will return the handles of the tasks in the experiment. + + If ignore_finished=True, will only return handles for the tasks + that are not yet finished. Useful for filtering handles to set dependencies on. + + If ignore_exp_not_exists=True, will not raise an error if the experiment does not exist. + + TODO: it's still possible that job submission fails if the tasks exist when this function + is called, but finish before nemo-run submits a new job (which might take minutes) + """ + from torchx.specs.api import AppState + + def _get_handles(exp): + handles = [] + for job in exp.jobs: + if not ignore_finished or ( + job.status(exp._runner) in [AppState.RUNNING, AppState.PENDING, AppState.SUBMITTED, AppState.UNKNOWN] + ): + handles.append(job.handle) + continue + return handles + + # if we are given an experiment object, we can directly get the handles + if isinstance(expname, run.Experiment): + return _get_handles(expname) + + try: + with run.Experiment.from_title(expname) as exp: + return _get_handles(exp) + except FileNotFoundError: + try: + with run.Experiment.from_id(expname) as exp: + return _get_handles(exp) + except AssertionError: + if ignore_exp_not_exists: + LOG.warning("Experiment %s not found!", expname) + return [] + raise ValueError(f"Experiment {expname} not found!") + + +def get_timeout(cluster_config, partition): + if 'timeouts' not in cluster_config: + timeout = "10000:00:00:00" + else: + timeout = cluster_config["timeouts"][partition or cluster_config["partition"]] + + # subtracting 15 minutes to account for the time it takes to save the model + # the format expected by nemo is days:hours:minutes:seconds + time_diff = datetime.strptime(timeout, "%H:%M:%S") - datetime.strptime("00:15:00", "%H:%M:%S") + timeout = ( + f'00:{time_diff.seconds // 3600:02d}:{(time_diff.seconds % 3600) // 60:02d}:{time_diff.seconds % 60:02d}' + ) + return timeout + + +def get_free_port(exclude: list[int] | None = None, strategy: int | str = 5000) -> int: + """Will return a free port on the host.""" + exclude = exclude or [] + if isinstance(strategy, int): + port = strategy + while port in exclude: + port += 1 + return port + elif strategy == "random": + import random + + port = random.randint(1024, 65535) + while port in exclude: + port = random.randint(1024, 65535) + return port + else: + raise ValueError(f"Strategy {strategy} not supported.") + + +def get_generation_command(server_address, generation_commands): + cmd = ( + f"export PYTHONPATH=$PYTHONPATH:/nemo_run/code && " + f"cd /nemo_run/code && " + # might be required if we are not hosting server ourselves + # this will try to handshake in a loop and unblock when the server responds + f"echo 'Waiting for the server to start at {server_address}' && " + f"while [ $(curl -X PUT {server_address} >/dev/null 2>&1; echo $?) -ne 0 ]; do sleep 3; done && " + # will run in a single task always (no need to check mpi env vars) + f"{generation_commands}" + ) + return cmd + + +def get_reward_server_command( + server_type: str, + num_gpus: int, + num_nodes: int, + model_path: str, + cluster_config: dict, + server_port: int, + server_args: str = "", +): + num_tasks = num_gpus + + # check if the model path is mounted if not vllm; + # vllm can also pass model name as "model_path" so we need special processing + if server_type != "vllm": + check_if_mounted(cluster_config, model_path) + + # the model path will be mounted, so generally it will start with / + elif server_type == "vllm" and model_path.startswith("/"): + check_if_mounted(cluster_config, model_path) + + if server_type == 'nemo': + nemo_aligner_reward_model_port = get_free_port(strategy="random", exclude=[server_port]) + server_start_cmd = ( + # Note: The order of the two commands is important as the reward model server + # needs to be the first command so it can get the HF_TOKEN from the environment + f"python -m nemo_skills.inference.server.serve_nemo_aligner_reward_model " + f" ++rm_model_file={model_path} " + f" trainer.devices={num_gpus} " + f" trainer.num_nodes={num_nodes} " + f" +model.tensor_model_parallel_size={num_gpus} " + f" +model.pipeline_model_parallel_size={num_nodes} " + # This port could be configurable, but is hard coded to reduce + # the divergence of the server command parameters from pipeline/generate.py + f" inference.port={nemo_aligner_reward_model_port} " + f" {server_args} & " + f"python -m nemo_skills.inference.server.serve_nemo_reward_model " + # These ports could be configurable, but is hard coded to reduce + # the divergence of the server command parameters from pipeline/generate.py + f" inference_port={server_port} " + f" triton_server_address=localhost:{nemo_aligner_reward_model_port} " + ) + + # somehow on slurm nemo needs multiple tasks, but locally only 1 + if cluster_config["executor"] == "local": + num_tasks = 1 + + elif server_type == "vllm": + if num_nodes > 1: + raise ValueError("VLLM server does not support multi-node execution") + + server_start_cmd = ( + f"python3 -m nemo_skills.inference.server.serve_vllm " + f" --model {model_path} " + f" --num_gpus {num_gpus} " + f" --port {server_port} " + f" {server_args} " + ) + num_tasks = 1 + else: + raise ValueError(f"Server type '{server_type}' not supported for reward model.") + + server_cmd = ( + f"nvidia-smi && " + f"cd /nemo_run/code && " + f"export PYTHONPATH=$PYTHONPATH:/nemo_run/code && " + f"{server_start_cmd} " + ) + return server_cmd, num_tasks + + +def get_ray_server_cmd(start_cmd): + ports = ( + "--node-manager-port=12345 " + "--object-manager-port=12346 " + "--dashboard-port=8265 " + "--dashboard-agent-grpc-port=12347 " + "--runtime-env-agent-port=12349 " + "--metrics-export-port=12350 " + "--min-worker-port=14349 " + "--max-worker-port=18349 " + ) + + ray_start_cmd = ( + "if [ \"${SLURM_PROCID:-0}\" = 0 ]; then " + " echo 'Starting head node' && " + " export RAY_raylet_start_wait_time_s=120 && " + " ray start " + " --head " + " --port=6379 " + f" {ports} && " + f" {start_cmd} ;" + "else " + " echo 'Starting worker node' && " + " export RAY_raylet_start_wait_time_s=120 && " + " echo \"Connecting to head node at $SLURM_MASTER_NODE\" && " + " ray start " + " --block " + " --address=$SLURM_MASTER_NODE:6379 " + f" {ports} ;" + "fi" + ) + return ray_start_cmd + + +def get_server_command( + server_type: str, + num_gpus: int, + num_nodes: int, + model_path: str, + cluster_config: dict, + server_port: int, + server_args: str = "", +): + num_tasks = num_gpus + + # check if the model path is mounted if not vllm; + # vllm can also pass model name as "model_path" so we need special processing + if server_type != "vllm": + check_if_mounted(cluster_config, model_path) + + # the model path will be mounted, so generally it will start with / + elif server_type == "vllm" and model_path.startswith("/"): + check_if_mounted(cluster_config, model_path) + + if server_type == 'nemo': + server_start_cmd = ( + f"python -m nemo_skills.inference.server.serve_nemo " + f" gpt_model_file={model_path} " + f" trainer.devices={num_gpus} " + f" trainer.num_nodes={num_nodes} " + f" tensor_model_parallel_size={num_gpus} " + f" pipeline_model_parallel_size={num_nodes} " + f" ++port={server_port} " + f" {server_args} " + ) + + # somehow on slurm nemo needs multiple tasks, but locally only 1 + if cluster_config["executor"] == "local": + num_tasks = 1 + elif server_type == 'vllm': + start_vllm_cmd = ( + f"python3 -m nemo_skills.inference.server.serve_vllm " + f" --model {model_path} " + f" --num_gpus {num_gpus} " + f" --port {server_port} " + f" {server_args} " + ) + server_start_cmd = get_ray_server_cmd(start_vllm_cmd) + num_tasks = 1 + elif server_type == 'sglang': + if num_nodes > 1: + multinode_args = f" --dist_init_addr $SLURM_MASTER_NODE --node_rank $SLURM_PROCID " + else: + multinode_args = "" + server_start_cmd = ( + f"python3 -m nemo_skills.inference.server.serve_sglang " + f" --model {model_path} " + f" --num_gpus {num_gpus} " + f" --num_nodes {num_nodes} " + f" --port {server_port} " + f" {multinode_args} " + f" {server_args} " + ) + num_tasks = 1 + else: + # need this flag for stable Nemotron-4-340B deployment + server_start_cmd = ( + f"FORCE_NCCL_ALL_REDUCE_STRATEGY=1 python -m nemo_skills.inference.server.serve_trt " + f" --model_path {model_path} " + f" --port {server_port} " + f" {server_args} " + ) + num_tasks = num_gpus + + server_cmd = ( + f"nvidia-smi && " + f"cd /nemo_run/code && " + f"export PYTHONPATH=$PYTHONPATH:/nemo_run/code && " + f"{server_start_cmd} " + ) + return server_cmd, num_tasks + + +def get_sandox_command(): + return "/entrypoint.sh && /start.sh" + + +@dataclass(kw_only=True) +class CustomJobDetails(SlurmJobDetails): + # we have 1 srun per sub-task (e.g. server/sandbox/main), but only a single sbatch + srun_prefix: str = "main" + sbatch_prefix: str = "" + + @property + def stdout(self) -> Path: + return Path(self.folder) / f"{self.sbatch_prefix}%j_sbatch.log" + + @property + def srun_stdout(self) -> Path: + return Path(self.folder) / f"{self.srun_prefix}%j_srun.log" + + @property + def stderr(self) -> Path: + return Path(self.folder) / f"{self.sbatch_prefix}%j_sbatch.log" + + @property + def srun_stderr(self) -> Path: + return Path(self.folder) / f"{self.srun_prefix}%j_srun.log" + + @property + def ls_term(self) -> str: + """This term will be used to fetch the logs. + + The command used to list the files is ls -1 {ls_term} 2> /dev/null + """ + assert self.folder + return os.path.join(self.folder, "*srun.log") + + +def read_config(config_file): + with open(config_file, "rt", encoding="utf-8") as fin: + cluster_config = yaml.safe_load(fin) + + return cluster_config + + +def get_cluster_config(cluster=None, config_dir=None): + """Trying to find an appropriate cluster config. + + Will search in the following order: + 1. config_dir parameter + 2. NEMO_SKILLS_CONFIG_DIR environment variable + 3. Current folder / cluster_configs + 4. This file folder / ../../cluster_configs + + If NEMO_SKILLS_CONFIG is provided and cluster is None, + it will be used as a full path to the config file + and NEMO_SKILLS_CONFIG_DIR will be ignored. + + If cluster is a python object (dict-like), then we simply + return the cluster config, under the assumption that the + config is prepared by the user. + """ + # if cluster is provided, we try to find it in one of the folders + if cluster is not None: + # check if cluster is a python object instead of a str path, pass through + if isinstance(cluster, (dict, DictConfig)): + return cluster + + # either using the provided config_dir or getting from env var + config_dir = config_dir or os.environ.get("NEMO_SKILLS_CONFIG_DIR") + if config_dir: + return read_config(Path(config_dir) / f"{cluster}.yaml") + + # if it's not defined we are trying to find locally + if (Path.cwd() / 'cluster_configs' / f"{cluster}.yaml").exists(): + return read_config(Path.cwd() / 'cluster_configs' / f"{cluster}.yaml") + + if (Path(__file__).parents[2] / 'cluster_configs' / f"{cluster}.yaml").exists(): + return read_config(Path(__file__).parents[2] / 'cluster_configs' / f"{cluster}.yaml") + + raise ValueError(f"Cluster config {cluster} not found in any of the supported folders.") + + config_file = os.environ.get("NEMO_SKILLS_CONFIG") + if not config_file: + raise ValueError("Either cluster or NEMO_SKILLS_CONFIG must be provided.") + + if not Path(config_file).exists(): + raise ValueError(f"Cluster config {config_file} not found.") + + cluster_config = read_config(config_file) + + if cluster_config['executor'] == 'slurm' and "ssh_tunnel" not in cluster_config: + if "job_dir" not in cluster_config: + raise ValueError("job_dir must be provided in the cluster config if ssh_tunnel is not provided.") + set_nemorun_home(cluster_config["job_dir"]) + + return cluster_config + + +@lru_cache +def _get_tunnel_cached( + job_dir: str, + host: str, + user: str, + identity: str | None = None, + shell: str | None = None, + pre_command: str | None = None, +): + return run.SSHTunnel( + host=host, + user=user, + identity=identity, + shell=shell, + pre_command=pre_command, + job_dir=job_dir, + ) + + +def tunnel_hash(tunnel): + return f"{tunnel.job_dir}:{tunnel.host}:{tunnel.user}:{tunnel.identity}:{tunnel.shell}:{tunnel.pre_command}" + + +def get_tunnel(cluster_config): + if "ssh_tunnel" not in cluster_config: + LOG.info("No ssh_tunnel configuration found, assuming we are running from the cluster already.") + return run.LocalTunnel(job_dir="") + return _get_tunnel_cached(**cluster_config["ssh_tunnel"]) + + +# Helper class and function to support streaming updates +class OutputWatcher(StreamWatcher): + """Class for streaming remote tar/compression process.""" + + def submit(self, stream): + print(stream, end='\r') + sys.stdout.flush() + return [] + + +def progress_callback(transferred: int, total: int) -> None: + """Display SFTP transfer progress.""" + percent = (transferred / total) * 100 + bar = '=' * int(percent / 2) + '>' + sys.stdout.write( + f'\rFile Transfer Progress: [{bar:<50}] {percent:.1f}% ' + f'({transferred/1024/1024:.1f}MB/{total/1024/1024:.1f}MB)' + ) + sys.stdout.flush() + + +def cluster_download( + tunnel: SSHTunnel, remote_dir: str, local_dir: str, remote_tar_dir: Optional[str] = None, verbose: bool = True +): + """ + Downloads a directory from a remote cluster by creating a tar archive and transferring it. + + Args: + tunnel: SSHTunnel connection + remote_dir: Path to the directory on remote server + local_dir: Local path to save the downloaded directory + remote_tar_dir: Optional directory for temporary tar file creation + verbose: Print download progress + """ + + remote_dir = remote_dir.rstrip('/') + remote_dir_parent, remote_dir_name = os.path.split(remote_dir) + + # Directory where the remote tarball is written + remote_tar_dir = remote_tar_dir if remote_tar_dir else remote_dir_parent + # Path of the remote tar file + remote_tar_filename = f"{remote_dir_name}.tar.gz" + + # Remote and local tar files + remote_tar = f"{os.path.join(remote_tar_dir, remote_tar_filename)}" + local_tar = os.path.join(local_dir, remote_tar_filename) + + # Get the directory size + result = tunnel.run(f'du -sb {remote_dir} | cut -f1') + total_size = int(result.stdout.strip()) + + # Check if result directory compression is streamable + streaming_possible = False + try: + # Check whether the command pv is present on the remote system or not. + # Certain systems may not have the `pv` command + result = tunnel.run('which pv', warn=True) + streaming_possible = result.exited == 0 + except Exception: + streaming_possible = False + + if streaming_possible and verbose: + # We can do streaming compression + # Command for streaming the compression progress + command = ( + f'cd {remote_dir_parent} && ' + f'tar --exclude="*.log" -cf - {remote_dir_name} | ' + f'pv -s {total_size} -p -t -e -b -F "Compressing Remote Directory: %b %t %p" | ' + f'gzip > {remote_tar}' + ) + # Run the remote compression command and stream the progress + result = tunnel.run(command, watchers=[OutputWatcher()], pty=True, hide=(not verbose)) + else: + command = f'cd {remote_dir_parent} && tar -czf {remote_tar} {remote_dir_name}' + result = tunnel.run(command, hide=(not verbose)) + + # Get SFTP client from tunnel's session's underlying client + sftp = tunnel.session.client.open_sftp() + + # Use SFTP's get with callback + sftp.get(remote_tar, local_tar, callback=progress_callback if verbose else None) + print(f"\nTransfer complete: {local_tar}") + + # Extract the tarball locally + os.makedirs(local_dir, exist_ok=True) + with tarfile.open(local_tar, "r:gz") as tar: + tar.extractall(path=local_dir) + + # Clean up the tarball from the remote server + tunnel.run(f'rm {remote_tar}', hide=True) + + # Clean up the local tarball + os.remove(local_tar) + + +def cluster_upload(tunnel: SSHTunnel, local_file: str, remote_dir: str, verbose: bool = True): + """ + Uploads a file to cluster. + TODO: extend to a folder. + + Args: + tunnel: SSHTunnel connection + local_file: Path to the local file to upload + remote_dir: Cluster path where to save the file + verbose: Print upload progress + """ + sftp = tunnel.session.client.open_sftp() + sftp.put(str(local_file), str(remote_dir), callback=progress_callback if verbose else None) + print(f"\nTransfer complete") + + +def get_git_repo_path(path: str | Path = None): + """Check if the path is a git repo. + + Args: + path: Path to the directory to check. If None, will check the current directory. + + Returns: + Path to the repo if it is a git repo, otherwise None. + """ + original_path = os.getcwd() + try: + if path: + os.chdir(path) + + repo_path = ( + subprocess.run( + ["git", "rev-parse", "--show-toplevel"], + capture_output=True, + check=True, + ) + .stdout.decode() + .strip() + ) + return Path(repo_path) + + except subprocess.CalledProcessError: + return None + + finally: + os.chdir(original_path) + + +def get_packager(extra_package_dirs: tuple[str] | None = None): + """Will check if we are running from a git repo and use git packager or default packager otherwise.""" + nemo_skills_dir = get_registered_external_repo('nemo_skills').path + + if extra_package_dirs: + include_patterns = [str(Path(d) / '*') for d in extra_package_dirs] + include_pattern_relative_paths = [str(Path(d).parent) for d in extra_package_dirs] + else: + include_patterns = [] + include_pattern_relative_paths = [] + + check_uncommited_changes = False + + # are we in a git repo? If yes, we are uploading the current code + repo_path = get_git_repo_path(path=None) # check if we are in a git repo in pwd + + if repo_path: + # Do we have nemo_skills package in this repo? If no, we need to pick it up from installed location + if not (Path(repo_path) / 'nemo_skills').is_dir(): + logging.warning( + "Not running from NeMo-Skills repo, trying to upload installed package. " + "Make sure there are no extra files in %s", + str(nemo_skills_dir / '*'), + ) + include_patterns.append(str(nemo_skills_dir / '*')) + else: + # picking up local dataset files if we are in the right repo + include_patterns.append(str(nemo_skills_dir / "dataset/**/*.jsonl")) + include_pattern_relative_paths.append(str(nemo_skills_dir.parent)) + + root_package = run.GitArchivePackager( + include_pattern=include_patterns, + include_pattern_relative_path=include_pattern_relative_paths, + check_uncommitted_changes=check_uncommited_changes, + ) + else: + logging.warning( + "Not running from a git repo, trying to upload installed package. Make sure there are no extra files in %s", + str(nemo_skills_dir / '*'), + ) + include_patterns.append(str(nemo_skills_dir / '*')) + include_pattern_relative_paths.append(str(nemo_skills_dir.parent)) + + root_package = run.PatternPackager( + include_pattern=include_patterns, + relative_path=include_pattern_relative_paths, + ) + + extra_repos = {} + if len(EXTERNAL_REPOS) > 1: + # Insert root package as the first package + extra_repos['nemo_run'] = root_package + + for repo_name, repo_meta in EXTERNAL_REPOS.items(): + if repo_name == 'nemo_skills': + continue + + repo_path = repo_meta.path + if get_git_repo_path(repo_path): + # Extra repos is a git repos, so we need to package only committed files + extra_repos[repo_name] = run.GitArchivePackager( + basepath=str(repo_path), check_uncommitted_changes=check_uncommited_changes + ) + else: + # Extra repos is not a git repo, so we need to package all files in the directory + repo_include_pattern = [str(Path(repo_path) / '*')] + repo_include_pattern_relative_path = [str(Path(repo_path).parent)] + extra_repos[repo_name] = run.PatternPackager( + include_pattern=repo_include_pattern, + relative_path=repo_include_pattern_relative_path, + ) + + # Return hybrid packager + return run.HybridPackager(sub_packagers=extra_repos, extract_at_root=True) + + return root_package + + +def get_env_variables(cluster_config): + """ + Will get the environment variables from the cluster config and the user environment. + + The following items in the cluster config are supported: + - `required_env_vars` - list of required environment variables + - `env_vars` - list of optional environment variables + + WANDB_API_KEY, NVIDIA_API_KEY, OPENAI_API_KEY, and HF_TOKEN are always added if they exist. + + Args: + cluster_config: cluster config dictionary + + Returns: + dict: dictionary of environment + """ + env_vars = {} + # Check for user requested env variables + required_env_vars = cluster_config.get("required_env_vars", []) + for env_var in required_env_vars: + if "=" in env_var: + if env_var.count("=") == 1: + env_var, value = env_var.split("=") + else: + raise ValueError(f"Invalid required environment variable format: {env_var}") + env_vars[env_var.strip()] = value.strip() + logging.info(f"Adding required environment variable {env_var}") + elif env_var in os.environ: + logging.info(f"Adding required environment variable {env_var} from environment") + env_vars[env_var] = os.environ[env_var] + else: + raise ValueError(f"Required environment variable {env_var} not found.") + + # It is fine to have these as always optional even if they are required for some configs + # Assume it is required, then this will override the value set above with the same + # value, assuming it has not been updated externally between these two calls + always_optional_env_vars = ["WANDB_API_KEY", "NVIDIA_API_KEY", "OPENAI_API_KEY", "HF_TOKEN"] + default_factories = { + "HF_TOKEN": lambda: str(get_token()), + } + # Add optional env variables + optional_env_vars = cluster_config.get("env_vars", []) + for env_var in optional_env_vars + always_optional_env_vars: + if "=" in env_var: + if env_var.count("=") == 1: + env_var, value = env_var.split("=") + else: + raise ValueError(f"Invalid optional environment variable format: {env_var}") + env_vars[env_var.strip()] = value.strip() + logging.info(f"Adding optional environment variable {env_var}") + elif env_var in os.environ: + logging.info(f"Adding optional environment variable {env_var} from environment") + env_vars[env_var] = os.environ[env_var] + elif env_var in default_factories: + env_vars[env_var] = default_factories[env_var]() + logging.info(f"Adding optional environment variable {env_var} from environment") + else: + logging.info(f"Optional environment variable {env_var} not found in user environment; skipping.") + + return env_vars + + +def get_mounts_from_config(cluster_config: dict): + """ + Determines if there are mount paths that are being passed via environment variables. + Selects the key in the cluster config called `mounts` which is a list of strings. + Each string is in the format of `:` where `env_var` + is the name of the environment variable. + + Args: + cluster_config (dict): cluster config dictionary + + Returns: + list: updated list of mounts + """ + mounts = cluster_config.get('mounts', []) + + # if there are env_mounts, we will add the mounts from the env_mounts + for mount_id in range(len(mounts)): + mount = mounts[mount_id] + + if ":" not in mount: + raise ValueError(f"Invalid mount format: {mount}. The mount path must be separated by a colon.") + + mount_source, mount_target = mount.split(":") + + if mount_source[0] == "{" and mount_source[-1] == "}": + # Resolve the environment variable for the mount source + mount_source = mount_source[1:-1] + + if mount_source not in os.environ: + raise ValueError( + f"Required environment variable {mount_source} not found in env variables passed in cluster configs." + ) + + mount_source = os.environ[mount_source] + + if mount_target[0] == "{" and mount_target[-1] == "}": + # Resolve the environment variable for the mount target + mount_target = mount_target[1:-1] + + if mount_target not in os.environ: + raise ValueError( + f"Required environment variable {mount_target} not found in env variables passed in cluster configs." + ) + + mount_target = os.environ[mount_target] + + # add the mount to the list of mounts + resolved_mount = f"{mount_source}:{mount_target}" + mounts[mount_id] = resolved_mount + + return mounts + + +def get_executor( + cluster_config, + container, + num_nodes, + tasks_per_node, + gpus_per_node, + job_name, + log_dir, + log_prefix: str = "main", + mounts=None, + partition=None, + time_min=None, + dependencies=None, + extra_package_dirs: tuple[str] | None = None, + heterogeneous=False, + het_group=None, + total_het_groups=None, + slurm_kwargs: dict | None = None, +): + env_vars = get_env_variables(cluster_config) + config_mounts = get_mounts_from_config(cluster_config) + + mounts = mounts or config_mounts + if extra_package_dirs is not None: + extra_package_dirs = tuple(extra_package_dirs) + packager = get_packager(extra_package_dirs=extra_package_dirs) + if cluster_config["executor"] == "local": + if num_nodes > 1: + raise ValueError("Local executor does not support multi-node execution") + + env_vars["PYTHONUNBUFFERED"] = "1" # this makes sure logs are streamed right away + return DockerExecutor( + container_image=container, + packager=packager, + ipc_mode="host", + volumes=mounts, + ntasks_per_node=1, + num_gpus=gpus_per_node, + network="host", + env_vars=env_vars, + additional_kwargs={"entrypoint": ""}, + ) + + if not heterogeneous: + env_vars["SLURM_MASTER_NODE"] = "$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n1)" + else: + # master node will be within the same group + env_vars["SLURM_MASTER_NODE"] = ( + f"$(scontrol show hostnames $SLURM_JOB_NODELIST_HET_GROUP_{het_group} | head -n1)" + ) + # in addition defining master nodes for all groups to allow communication + for group in range(total_het_groups): + env_vars[f"SLURM_MASTER_NODE_HET_GROUP_{group}"] = ( + f"$(scontrol show hostnames $SLURM_JOB_NODELIST_HET_GROUP_{group} | head -n1)" + ) + + partition = partition or cluster_config.get("partition") + if 'timeouts' not in cluster_config: + timeout = "10000:00:00:00" + else: + timeout = cluster_config["timeouts"][partition] + + additional_parameters = {'time_min': time_min} if time_min is not None else {} + if cluster_config.get('mail_type') is not None: + additional_parameters['mail_type'] = cluster_config['mail_type'] + if cluster_config.get('mail_user') is not None: + additional_parameters['mail_user'] = cluster_config['mail_user'] + srun_args = [ + "--no-container-mount-home", + "--overlap", + "--mpi=pmix", + '--wait=10', + # we need to be explicit about this in srun as commands might need to run in parallel + f"--ntasks-per-node={tasks_per_node}", + f"--nodes={num_nodes}", + # NeMo-run should take care of this, but we'll put it here temporarily + f"--container-env={','.join([k.strip() for k in env_vars.keys()])}", + ] + if not cluster_config.get("disable_gpus_per_node", False) and gpus_per_node is not None: + srun_args.append(f"--gpus-per-node={gpus_per_node}") + + dependency_type = cluster_config.get("dependency_type", "afterany") + + return run.SlurmExecutor( + account=cluster_config["account"], + partition=partition, + nodes=num_nodes, + ntasks_per_node=tasks_per_node, + tunnel=get_tunnel(cluster_config), + container_image=container, + container_mounts=mounts, + time=timeout, + additional_parameters=additional_parameters, + packager=packager, + gpus_per_node=gpus_per_node if not cluster_config.get("disable_gpus_per_node", False) else None, + srun_args=srun_args, + job_details=CustomJobDetails( + job_name=cluster_config.get("job_name_prefix", "") + job_name, + folder=get_unmounted_path(cluster_config, log_dir), + srun_prefix=log_prefix + '_' + job_name + '_', + sbatch_prefix=job_name + '_', + ), + wait_time_for_group_job=0.01, + monitor_group_job_wait_time=20, + dependencies=dependencies, + dependency_type=dependency_type, + heterogeneous=heterogeneous, + env_vars=env_vars, + **(slurm_kwargs or {}), + ) + + +@contextmanager +def temporary_env_update(cluster_config, updates): + original_env_vars = cluster_config.get("env_vars", []).copy() + updated_env_vars = original_env_vars.copy() + for key, value in updates.items(): + updated_env_vars.append(f"{key}={value}") + cluster_config["env_vars"] = updated_env_vars + try: + yield + finally: + cluster_config["env_vars"] = original_env_vars + + +# TODO: this function has become too cumbersome to use with all recently added support +# we should make it simpler by perhaps removing separate logic for server/sandbox +# and supporting them through a list of cmds directly +# should also make heterogenous logic very clear and more robust +# and all parameters that can be list should be list for consistency +def add_task( + exp, + cmd: str | list[str], + task_name, + cluster_config, + container: str | list[str], + num_tasks: int | list[int] = 1, + num_gpus=None, + num_nodes=1, + log_dir=None, + partition=None, + time_min=None, + with_sandbox=False, + sandbox_port: int | None = None, + server_config=None, + reuse_code_exp: str = None, + reuse_code: bool = True, + task_dependencies: list[str] = None, + run_after: str | list[str] | None = None, + get_server_command=get_server_command, + extra_package_dirs: list[str] | None = None, + slurm_kwargs: dict | None = None, + heterogeneous: bool = False, +): + """Wrapper for nemo-run exp.add to help setting up executors and dependencies. + + Note that there are two parameters that control dependencies. + - task_dependencies: list of tasks that this task depends on **within the same experiment** + - run_after: a string with experiment name or a list of experiment names that this task + should run after. Will schedule dependencies on all tasks inside `run_after` experiments. + It needs to already be launched and running. + + Example of how to set task_dependencies: + + with run.Experiment(expname) as exp: + task1 = add_task(exp, ...) + task2 = add_task(exp, ..., task_dependencies=[task1]) + + You can use `reuse_code_exp` to reuse the code from another experiment + (and thus avoid costly packaging/ssh uploading). You can provide either experiment + name or the experiment object itself. + + By default we will reuse the code of the first submitted experiment. + If you want to avoid this, set `reuse_code=False`. + """ + if run_after is not None and cluster_config["executor"] == "slurm": + if isinstance(run_after, (str, run.Experiment)): + run_after = [run_after] + dependencies = [] + for dep_expname in run_after: + exp_handles = get_exp_handles(dep_expname) + if len(exp_handles) == 0: + LOG.warning( + "No pending or running tasks found for experiment %s, cannot set dependencies.", dep_expname + ) + dependencies.extend(exp_handles) + if len(dependencies) == 0: + dependencies = None + else: + dependencies = None + + if num_gpus is None and cluster_config['executor'] == "slurm": + if not 'cpu' in (partition or cluster_config.get("partition", "")): + num_gpus = 1 + + if sandbox_port is None: + sandbox_port = get_free_port(strategy="random") + + het_group = 0 + het_group_indices = [] + total_het_groups = (server_config is not None) + bool(cmd) + with_sandbox + + commands = [] + executors = [] + # assuming server always has the largest resources request, so it needs to go first + if server_config is not None: + server_cmd, num_server_tasks = get_server_command(**server_config, cluster_config=cluster_config) + if 'container' not in server_config: + server_container = cluster_config["containers"][server_config['server_type']] + server_executor = get_executor( + cluster_config=cluster_config, + container=server_container, + num_nodes=server_config['num_nodes'], + tasks_per_node=num_server_tasks, + gpus_per_node=server_config['num_gpus'], + partition=partition, + time_min=time_min, + dependencies=dependencies, + job_name=task_name, + log_dir=log_dir, + log_prefix="server", + extra_package_dirs=extra_package_dirs, + slurm_kwargs=slurm_kwargs, + heterogeneous=heterogeneous, + het_group=het_group, + total_het_groups=total_het_groups, + ) + if cluster_config["executor"] == "local" and num_server_tasks > 1: + server_cmd = f"mpirun --allow-run-as-root -np {num_server_tasks} bash -c {shlex.quote(server_cmd)}" + commands.append(server_cmd) + executors.append(server_executor) + het_group_indices.append(het_group) + het_group += 1 + + # then goes the main task(s) unless it's empty + if cmd: + if isinstance(cmd, str): + cmd = [cmd] + if isinstance(container, str): + container = [container] + if isinstance(num_tasks, int): + num_tasks = [num_tasks] + if len(cmd) != len(container) or len(cmd) != len(num_tasks): + raise ValueError("Number of commands, containers and num_tasks must match.") + for cur_idx, (cur_cmd, cur_container, cur_tasks) in enumerate(zip(cmd, container, num_tasks)): + if cluster_config["executor"] == "local" and cur_tasks > 1: + cur_cmd = f"mpirun --allow-run-as-root -np {cur_tasks} bash -c {shlex.quote(cur_cmd)}" + with temporary_env_update(cluster_config, {"NEMO_SKILLS_SANDBOX_PORT": sandbox_port}): + commands.append(cur_cmd) + executors.append( + get_executor( + cluster_config=cluster_config, + container=cur_container, + num_nodes=num_nodes, + tasks_per_node=cur_tasks, + gpus_per_node=num_gpus, + partition=partition, + time_min=time_min, + dependencies=dependencies, + job_name=task_name, + log_dir=log_dir, + log_prefix="main" if len(cmd) == 1 else f"main_{cur_idx}", + extra_package_dirs=extra_package_dirs, + slurm_kwargs=slurm_kwargs, + heterogeneous=heterogeneous, + het_group=het_group, + total_het_groups=total_het_groups, + ) + ) + het_group_indices.append(het_group) + het_group += 1 + + # finally a sandbox if needed + if with_sandbox: + sandbox_env_updates = {"LISTEN_PORT": sandbox_port} + current_env_vars = cluster_config.get("env_vars", []).copy() + for override in current_env_vars: + if "PYTHONPATH" in override: + if override.startswith("PYTHONPATH="): + override = override[11:] + sandbox_env_updates["PYTHONPATH"] = override + ":/app" + + with temporary_env_update(cluster_config, sandbox_env_updates): + commands.append(get_sandox_command()) + sandbox_executor = get_executor( + cluster_config=cluster_config, + container=cluster_config["containers"]["sandbox"], + num_nodes=executors[0].nodes if cluster_config["executor"] == "slurm" else 1, + tasks_per_node=1, + gpus_per_node=num_gpus, + partition=partition, + time_min=time_min, + mounts=tuple(), # we don't want to mount anything + dependencies=dependencies, + job_name=task_name, + log_dir=log_dir, + log_prefix="sandbox", + extra_package_dirs=extra_package_dirs, + slurm_kwargs=slurm_kwargs, + heterogeneous=heterogeneous, + het_group=het_group, + total_het_groups=total_het_groups, + ) + executors.append(sandbox_executor) + het_group_indices.append(het_group) + het_group += 1 + + if cluster_config["executor"] != "local": + tunnel = get_tunnel(cluster_config) + if isinstance(tunnel, run.SSHTunnel) and reuse_code: + reuse_code_exp = reuse_code_exp or REUSE_CODE_EXP.get(tunnel_hash(tunnel)) + if reuse_code_exp is not None: + if isinstance(reuse_code_exp, str): + try: + reuse_code_exp = run.Experiment.from_id(reuse_code_exp) + except Exception: + LOG.debug(f"Failed to create experiment from id {reuse_code_exp}, trying to find it by title") + reuse_code_exp = run.Experiment.from_title(reuse_code_exp) + + LOG.info("Trying to reuse code from experiment %s", reuse_code_exp._title) + reuse_key = get_packaging_job_key(reuse_code_exp._id, "nemo-run") + if reuse_key in reuse_code_exp.tunnels[tunnel.key].packaging_jobs: + reuse_dir = reuse_code_exp.tunnels[tunnel.key].packaging_jobs[reuse_key].dst_path + + for executor in executors: + executor.packager.symlink_from_remote_dir = reuse_dir + LOG.info(f"Successfully reused code from {reuse_key}") + else: + LOG.warning("Relevant packaging job not found for experiment %s", reuse_code_exp._title) + # if current is not reused, we are refreshing the cache as there is a reason to believe it's outdated + elif isinstance(tunnel, run.SSHTunnel): + REUSE_CODE_EXP.pop(tunnel_hash(tunnel), None) + + if len(commands) == 1: + # to keep sbatch script simpler, we don't wrap in a list in this case + return exp.add( + run.Script(inline=commands[0]), + executor=executors[0], + name="nemo-run", + dependencies=task_dependencies, + ) + else: + if heterogeneous: + executors[0].het_group_indices = het_group_indices + return exp.add( + [run.Script(inline=command) for command in commands], + executor=executors, + name="nemo-run", + dependencies=task_dependencies, + ) + + +def run_exp(exp, cluster_config, sequential=None): + """If sequential is not specified, using True locally and False otherwise. + + If it is specified, it will be used as is. + """ + if cluster_config['executor'] == 'local': + exp.run(detach=False, tail_logs=True, sequential=True if sequential is None else sequential) + else: + exp.run(detach=True, sequential=False if sequential is None else sequential) + + # caching the experiment code for reuse + tunnel = get_tunnel(cluster_config) + if isinstance(tunnel, run.SSHTunnel): + ssh_hash = tunnel_hash(tunnel) + if ssh_hash not in REUSE_CODE_EXP: + REUSE_CODE_EXP[ssh_hash] = exp diff --git a/tests/test_cfg_runtime_tests.py b/tests/test_cfg_runtime_tests.py index cce1b820..eb3cb4ec 100644 --- a/tests/test_cfg_runtime_tests.py +++ b/tests/test_cfg_runtime_tests.py @@ -25,7 +25,8 @@ def get_test_cases(): """Returns paths to all configs that are checked in.""" for config_path in glob.glob(f"{DATASET_CONFIGS_ROOT}/**/*.yaml", recursive=True): - yield config_path + if not config_path.endswith("nemo_run_config.yaml"): + yield config_path @pytest.mark.parametrize("config_path", get_test_cases())