diff --git a/.github/workflows/auth-ci.yml b/.github/workflows/auth-ci.yml index 61023148c..afbd9d1e8 100644 --- a/.github/workflows/auth-ci.yml +++ b/.github/workflows/auth-ci.yml @@ -26,14 +26,20 @@ jobs: with: python-version: '3.11' - - name: Install dependencies + - name: Install dependencies - Client working-directory: . run: | python -m pip install --upgrade pip pip install -e cli/ pip install -r cli/test-requirements.txt - pip install -r server/requirements.txt - pip install -r server/test-requirements.txt + + - name: Install Dependencies - Server + working-directory: ./server + run: | + python -m venv .venv_server + source .venv_server/bin/activate + pip install -r requirements.txt + pip install -r test-requirements.txt - name: Set server environment vars working-directory: ./server @@ -41,15 +47,21 @@ jobs: - name: Run postgresql server in background working-directory: ./server - run: sh run_dev_postgresql.sh && sleep 6 + run: | + source .venv_server/bin/activate + sh run_dev_postgresql.sh && sleep 6 - name: Run django server in background with generated certs working-directory: ./server - run: sh setup-dev-server.sh & sleep 6 + run: | + source .venv_server/bin/activate + sh setup-dev-server.sh & sleep 6 - name: Run server integration tests working-directory: ./server - run: python seed.py --cert cert.crt --auth online + run: | + source .venv_server/bin/activate + python seed.py --cert cert.crt --auth online - name: Run client integration tests working-directory: . diff --git a/.github/workflows/docker-ci.yml b/.github/workflows/docker-ci.yml index d90419f62..841208959 100644 --- a/.github/workflows/docker-ci.yml +++ b/.github/workflows/docker-ci.yml @@ -34,14 +34,21 @@ jobs: make sudo make install - - name: Install dependencies + - name: Install dependencies - Client working-directory: . 
run: | python -m pip install --upgrade pip pip install -e cli/ pip install -r cli/test-requirements.txt - pip install -r server/requirements.txt - pip install -r server/test-requirements.txt + + - name: Install Dependencies - Server + working-directory: ./server + run: | + python -m venv .venv_server + source .venv_server/bin/activate + pip install -r requirements.txt + pip install -r test-requirements.txt + - name: Set server environment vars working-directory: ./server @@ -49,7 +56,9 @@ jobs: - name: Generate SSL certificate working-directory: ./server - run: sh setup-dev-server.sh -c cert.crt -k cert.key -d 0 + run: | + source .venv_server/bin/activate + sh setup-dev-server.sh -c cert.crt -k cert.key -d 0 - name: Build container image working-directory: ./server diff --git a/.github/workflows/encrypted-containers-ci.yml b/.github/workflows/encrypted-containers-ci.yml index a9d32c4ed..23be0a735 100644 --- a/.github/workflows/encrypted-containers-ci.yml +++ b/.github/workflows/encrypted-containers-ci.yml @@ -34,14 +34,20 @@ jobs: make sudo make install - - name: Install dependencies + - name: Install dependencies - Client working-directory: . 
run: | python -m pip install --upgrade pip pip install -e cli/ pip install -r cli/test-requirements.txt - pip install -r server/requirements.txt - pip install -r server/test-requirements.txt + + - name: Install Dependencies - Server + working-directory: ./server + run: | + python -m venv .venv_server + source .venv_server/bin/activate + pip install -r requirements.txt + pip install -r test-requirements.txt - name: Set server environment vars working-directory: ./server @@ -49,15 +55,21 @@ jobs: - name: Run postgresql server in background working-directory: ./server - run: sh run_dev_postgresql.sh && sleep 6 + run: | + source .venv_server/bin/activate + sh run_dev_postgresql.sh && sleep 6 - name: Run django server in background with generated certs working-directory: ./server - run: sh setup-dev-server.sh & sleep 6 + run: | + source .venv_server/bin/activate + sh setup-dev-server.sh & sleep 6 - name: Run server integration tests working-directory: ./server - run: python seed.py --cert cert.crt + run: | + source .venv_server/bin/activate + python seed.py --cert cert.crt - name: Run client integration tests working-directory: . diff --git a/.github/workflows/local-ci.yml b/.github/workflows/local-ci.yml index e8f9e3901..9213b3179 100644 --- a/.github/workflows/local-ci.yml +++ b/.github/workflows/local-ci.yml @@ -34,14 +34,20 @@ jobs: make sudo make install - - name: Install dependencies + - name: Install dependencies - Client working-directory: . 
run: | python -m pip install --upgrade pip pip install -e cli/ pip install -r cli/test-requirements.txt - pip install -r server/requirements.txt - pip install -r server/test-requirements.txt + + - name: Install Dependencies - Server + working-directory: ./server + run: | + python -m venv .venv_server + source .venv_server/bin/activate + pip install -r requirements.txt + pip install -r test-requirements.txt - name: Set server environment vars working-directory: ./server @@ -49,17 +55,40 @@ jobs: - name: Run postgresql server in background working-directory: ./server - run: sh run_dev_postgresql.sh && sleep 6 + run: | + source .venv_server/bin/activate + sh run_dev_postgresql.sh && sleep 6 - name: Run django server in background with generated certs working-directory: ./server - run: sh setup-dev-server.sh & sleep 6 + run: | + source .venv_server/bin/activate + sh setup-dev-server.sh & sleep 6 - name: Run server integration tests working-directory: ./server - run: python seed.py --cert cert.crt + run: | + source .venv_server/bin/activate + python seed.py --cert cert.crt + + - name: Run chestxray demo including private model + working-directory: . + run: sh cli/cli_chestxray_tutorial_test.sh -f -p + + - name: Reset DB for Workflow test + working-directory: ./server + run: | + source .venv_server/bin/activate + sh reset_db.sh + sh reset_db_postgresql.sh + + - name: Seed DB for Workflow test + working-directory: ./server + run: | + source .venv_server/bin/activate + python seed.py --cert cert.crt -w - - name: Run chestxray demo + - name: Run chestxray demo with workflow (no private model) working-directory: . 
run: sh cli/cli_chestxray_tutorial_test.sh -f diff --git a/.github/workflows/train-ci.yml b/.github/workflows/train-ci.yml index 835f0e2f3..e793166e8 100644 --- a/.github/workflows/train-ci.yml +++ b/.github/workflows/train-ci.yml @@ -15,14 +15,20 @@ jobs: with: python-version: '3.11' - - name: Install dependencies + - name: Install dependencies - Client working-directory: . run: | python -m pip install --upgrade pip pip install -e cli/ pip install -r cli/test-requirements.txt - pip install -r server/requirements.txt - pip install -r server/test-requirements.txt + + - name: Install Dependencies - Server + working-directory: ./server + run: | + python -m venv .venv_server + source .venv_server/bin/activate + pip install -r requirements.txt + pip install -r test-requirements.txt - name: Set server environment vars working-directory: ./server @@ -30,15 +36,21 @@ jobs: - name: Run postgresql server in background working-directory: ./server - run: sh run_dev_postgresql.sh && sleep 6 + run: | + source .venv_server/bin/activate + sh run_dev_postgresql.sh && sleep 6 - name: Run django server in background with generated certs working-directory: ./server - run: sh setup-dev-server.sh & sleep 6 + run: | + source .venv_server/bin/activate + sh setup-dev-server.sh & sleep 6 - name: Run server integration tests working-directory: ./server - run: python seed.py --cert cert.crt + run: | + source .venv_server/bin/activate + python seed.py --cert cert.crt - name: Run client integration tests working-directory: . diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml index 1a215e1df..75f5816d1 100644 --- a/.github/workflows/unittests.yml +++ b/.github/workflows/unittests.yml @@ -16,14 +16,23 @@ jobs: uses: actions/setup-python@v2 with: python-version: 3.11 - - name: Install dependencies + + - name: Install dependencies - Client + working-directory: . 
run: | python -m pip install --upgrade pip pip install flake8 pytest if [ -f requirements.txt ]; then pip install -r requirements.txt; fi if [ -f cli/requirements.txt ]; then pip install -e cli; fi - pip install -r server/requirements.txt - pip install -r server/test-requirements.txt + + - name: Install Dependencies - Server + working-directory: ./server + run: | + python -m venv .venv_server + source .venv_server/bin/activate + pip install -r requirements.txt + pip install -r test-requirements.txt + - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names @@ -34,20 +43,29 @@ jobs: # Exclude examples folder as it doesn't contain code related to medperf tools # Exclude migrations folder as it contains autogenerated code # Ignore E231, as it is raising warnings with auto-generated code. - flake8 . --count --max-complexity=10 --max-line-length=127 --ignore F821,W503,E231 --statistics --exclude=examples/,"*/migrations/*",cli/medperf/templates/ + flake8 . 
--count --max-complexity=10 --max-line-length=127 --ignore F821,W503,E231 --statistics --exclude=examples/,"*/migrations/*",cli/medperf/templates/,server/.venv_server/ + - name: Test with pytest working-directory: ./cli/medperf/tests run: | pytest + - name: Set server environment vars working-directory: ./server run: cp .env.local.local-auth .env + - name: Run postgresql server in background working-directory: ./server run: sh run_dev_postgresql.sh && sleep 6 + - name: Run migrations working-directory: ./server - run: python manage.py migrate + run: | + source .venv_server/bin/activate + python manage.py migrate + - name: Run server unit tests working-directory: ./server - run: python manage.py test --parallel + run: | + source .venv_server/bin/activate + python manage.py test --parallel diff --git a/.github/workflows/webui-ci.yml b/.github/workflows/webui-ci.yml index 4f875360c..b963ef138 100644 --- a/.github/workflows/webui-ci.yml +++ b/.github/workflows/webui-ci.yml @@ -15,14 +15,20 @@ jobs: with: python-version: '3.11' - - name: Install dependencies + - name: Install dependencies - Client working-directory: . 
run: | python -m pip install --upgrade pip pip install -e cli/ pip install -r cli/test-requirements.txt - pip install -r server/requirements.txt - pip install -r server/test-requirements.txt + + - name: Install Dependencies - Server + working-directory: ./server + run: | + python -m venv .venv_server + source .venv_server/bin/activate + pip install -r requirements.txt + pip install -r test-requirements.txt - name: Set server environment vars working-directory: ./server @@ -30,15 +36,21 @@ jobs: - name: Run postgresql server in background working-directory: ./server - run: sh run_dev_postgresql.sh && sleep 6 + run: | + source .venv_server/bin/activate + sh run_dev_postgresql.sh && sleep 6 - name: Run django server in background with generated certs working-directory: ./server - run: sh setup-dev-server.sh & sleep 6 + run: | + source .venv_server/bin/activate + sh setup-dev-server.sh & sleep 6 - name: Reset and seed database with tutorial data working-directory: ./server - run: sh reset_db.sh && python seed.py --demo tutorial + run: | + source .venv_server/bin/activate + sh reset_db.sh && python seed.py --demo tutorial - name: Set up tutorial files working-directory: . 
diff --git a/cli/cli_chestxray_tutorial_test.sh b/cli/cli_chestxray_tutorial_test.sh index 659ede731..816c08eaf 100755 --- a/cli/cli_chestxray_tutorial_test.sh +++ b/cli/cli_chestxray_tutorial_test.sh @@ -11,6 +11,7 @@ echo "=====================================" echo "downloading files to $DIRECTORY" wget -P $DIRECTORY https://storage.googleapis.com/medperf-storage/chestxray_tutorial/sample_raw_data.tar.gz tar -xzvf $DIRECTORY/sample_raw_data.tar.gz -C $DIRECTORY + chmod a+w $DIRECTORY/sample_raw_data ########################################################## @@ -27,8 +28,7 @@ print_eval medperf profile create -n testdata checkFailed "testdata profile creation failed" print_eval medperf profile create -n noserver checkFailed "noserver profile creation failed" -print_eval medperf profile create -n testprivate -checkFailed "testprivate profile creation failed" + print_eval medperf profile set --server https://example.com checkFailed "setting mock server failed" @@ -47,15 +47,9 @@ checkFailed "testdata profile activation failed" print_eval medperf auth login -e $DATAOWNER checkFailed "testdata login failed" -print_eval medperf profile activate testprivate -checkFailed "testprivate profile activation failed" - -print_eval medperf auth login -e $PRIVATEMODELOWNER -checkFailed "testprivate login failed" - ########################################################## echo "=====================================" -echo ""Activate benchmarkowner profile"" +echo "Activate benchmarkowner profile" echo "=====================================" # Log in as the benchmark owner print_eval medperf profile activate testbenchmark @@ -66,7 +60,7 @@ echo "\n" ########################################################## echo "=====================================" -echo ""Change association approval policy to auto approve always"" +echo "Change association approval policy to auto approve always" echo "=====================================" # Log in as the benchmark owner print_eval medperf 
benchmark update_associations_policy -b 1 \ @@ -102,7 +96,7 @@ echo "\n" echo "=====================================" echo "Running data preparation step" echo "=====================================" -print_eval medperf dataset prepare -d $DSET_UID +print_eval medperf dataset prepare -d $DSET_UID -y checkFailed "Data preparation step failed" ########################################################## @@ -128,79 +122,6 @@ checkFailed "Data association step failed" echo "\n" -########################################################## -echo "=============================================" -echo "Getting a certificate" -echo "=============================================" -print_eval medperf certificate get_client_certificate -checkFailed "Failed to obtain Data Owner Certificate" -########################################################## - -echo "\n" - -########################################################## -echo "=============================================" -echo "Submitting the certificate" -echo "=============================================" -print_eval medperf certificate submit_client_certificate -y -checkFailed "Failed to submit Data Owner Certificate" -########################################################## - -echo "\n" - -########################################################## -echo "=====================================" -echo "Activate Model Owner Profile" -echo "=====================================" -print_eval medperf profile activate testprivate -checkFailed "testprivate profile activation failed" -########################################################## - -echo "\n" - -########################################################## -echo "=====================================" -echo "Submit a private model" -echo "=====================================" -print_eval medperf container submit --name privmodel \ - -m $CHESTXRAY_ENCRYPTED_MODEL -p $CHESTXRAY_ENCRYPTED_MODEL_PARAMS \ - -a $CHESTXRAY_ENCRYPTED_MODEL_ADD --decryption_key 
$PRIVATE_MODEL_LOCAL/key.bin --operational -checkFailed "private container submission failed" -PMODEL_UID=$(medperf container ls | grep privmodel | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) -########################################################## - -echo "\n" - -########################################################## -echo "=====================================" -echo "Running private model association" -echo "=====================================" -print_eval medperf container associate -m $PMODEL_UID -b 1 -y -checkFailed "private model association failed" -########################################################## - -echo "\n" - -########################################################## -echo "=====================================" -echo "Give Access to Private Model" -echo "=====================================" -print_eval medperf container grant_access --model-id $PMODEL_UID --benchmark-id 1 -y -checkFailed "Failed to Give Model Access to Data owner" -########################################################## - -echo "\n" - -########################################################## -echo "=====================================" -echo "Activate Data Owner profile" -echo "=====================================" -print_eval medperf profile activate testdata -checkFailed "testdata profile activation failed" -########################################################## - -echo "\n" - ########################################################## echo "=====================================" echo "Running benchmark execution step - Public" @@ -212,17 +133,6 @@ checkFailed "Benchmark execution step failed (public)" echo "\n" -########################################################## -echo "=====================================" -echo "Running benchmark execution step - Private" -echo "=====================================" -# Create results -print_eval medperf run -b 1 -d $DSET_UID -m $PMODEL_UID -y -checkFailed "Benchmark execution step failed (private)" 
-########################################################## - -echo "\n" - ########################################################## echo "=====================================" echo " Offline Compatibility Test - Public " @@ -253,29 +163,12 @@ checkFailed "offline compatibility test execution step failed - public model" echo "\n" -########################################################## -echo "=====================================" -echo " Offline Compatibility Test - Private " -echo "=====================================" -print_eval medperf test run --offline --no-cache \ - --demo_dataset_url https://storage.googleapis.com/medperf-storage/chestxray_tutorial/demo_data.tar.gz \ - --demo_dataset_hash "71faabd59139bee698010a0ae3a69e16d97bc4f2dde799d9e187b94ff9157c00" \ - -p $PREP_LOCAL/container_config.yaml \ - -m $PRIVATE_MODEL_LOCAL/container_config.yaml \ - -e $METRIC_LOCAL/container_config.yaml \ - -d $PRIVATE_MODEL_LOCAL/key.bin \ - --data_preparator_parameters $PREP_LOCAL/workspace/parameters.yaml \ - --model_parameters $MODEL_LOCAL/workspace/parameters.yaml \ - --evaluator_parameters $METRIC_LOCAL/workspace/parameters.yaml \ - --model_additional_files $MODEL_LOCAL/workspace/additional_files/ - -checkFailed "offline compatibility test execution step failed - private model" +if ${PRIVATE}; then + . 
"$(dirname $(realpath "$0"))/cli_chestxray_tutorial_test_private_model.sh" +fi print_eval rm $MODEL_LOCAL/workspace/additional_files/cnn_weights.tar.gz print_eval rm $MODEL_LOCAL/workspace/additional_files/cnn_weights.pth -########################################################## - -echo "\n" ########################################################## echo "=====================================" @@ -312,9 +205,6 @@ checkFailed "Profile deletion failed" print_eval medperf profile delete noserver checkFailed "Profile deletion failed" -print_eval medperf profile delete testprivate -checkFailed "Profile deletion failed" - if ${CLEANUP}; then clean fi diff --git a/cli/cli_chestxray_tutorial_test_private_model.sh b/cli/cli_chestxray_tutorial_test_private_model.sh new file mode 100644 index 000000000..bcbe62808 --- /dev/null +++ b/cli/cli_chestxray_tutorial_test_private_model.sh @@ -0,0 +1,153 @@ +# This script is meant to be called as part of cli_chestxray_tutorial_test.sh + + +########################################################## +echo "=====================================" +echo "Activate dataowner profile" +echo "=====================================" +print_eval medperf profile activate testdata +checkFailed "testdata profile activation failed" +########################################################## + +echo "\n" +########################################################## +echo "=============================================" +echo "Getting a certificate" +echo "=============================================" +print_eval medperf certificate get_client_certificate --overwrite +checkFailed "Failed to obtain Data Owner Certificate" +########################################################## + +echo "\n" + +########################################################## +echo "=============================================" +echo "Submitting the certificate" +echo "=============================================" +print_eval medperf certificate submit_client_certificate 
-y +checkFailed "Failed to submit Data Owner Certificate" +########################################################## + +echo "\n" + +########################################################## +echo "==========================================" +echo "Creating test profile for private model owner" +echo "==========================================" +print_eval medperf profile create -n testprivate +checkFailed "testprivate profile creation failed" +########################################################## + +echo "\n" + +########################################################## +echo "==========================================" +echo "Login Private Model Owner" +echo "==========================================" + +print_eval medperf profile activate testprivate +checkFailed "testprivate profile activation failed" + +print_eval medperf auth logout +checkFailed "logout failed" + +print_eval medperf auth login -e $PRIVATEMODELOWNER +checkFailed "testprivate login failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Submit a private model" +echo "=====================================" +print_eval medperf container submit --name privmodel \ +-m $CHESTXRAY_ENCRYPTED_MODEL -p $CHESTXRAY_ENCRYPTED_MODEL_PARAMS \ +-a $CHESTXRAY_ENCRYPTED_MODEL_ADD --decryption_key $PRIVATE_MODEL_LOCAL/key.bin --operational +checkFailed "private container submission failed" +PMODEL_UID=$(medperf container ls | grep privmodel | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running private model association" +echo "=====================================" +print_eval medperf container associate -m $PMODEL_UID -b 1 -y +checkFailed "private model association failed" 
+########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Give Access to Private Model" +echo "=====================================" +print_eval medperf container grant_access --model-id $PMODEL_UID --benchmark-id 1 -y +checkFailed "Failed to Give Model Access to Data owner" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate dataowner profile" +echo "=====================================" +print_eval medperf profile activate testdata +checkFailed "testdata profile activation failed" + +print_eval medperf auth logout + +print_eval medperf auth login -e $DATAOWNER +checkFailed "testdata login failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running benchmark execution step - Private" +echo "=====================================" +# Create results +print_eval medperf run -b 1 -d $DSET_UID -m $PMODEL_UID -y +checkFailed "Benchmark execution step failed (private)" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo " Offline Compatibility Test - Private " +echo "=====================================" + +## Change the server and logout just to make sure this command will work without connecting to a server +print_eval medperf profile activate noserver +checkFailed "noserver profile activation failed" + +print_eval medperf test run --offline --no-cache \ +--demo_dataset_url https://storage.googleapis.com/medperf-storage/chestxray_tutorial/demo_data.tar.gz \ +--demo_dataset_hash 
"71faabd59139bee698010a0ae3a69e16d97bc4f2dde799d9e187b94ff9157c00" \ +-p $PREP_LOCAL/container_config.yaml \ +-m $PRIVATE_MODEL_LOCAL/container_config.yaml \ +-e $METRIC_LOCAL/container_config.yaml \ +-d $PRIVATE_MODEL_LOCAL/key.bin \ +--data_preparator_parameters $PREP_LOCAL/workspace/parameters.yaml \ +--model_parameters $MODEL_LOCAL/workspace/parameters.yaml \ +--evaluator_parameters $METRIC_LOCAL/workspace/parameters.yaml \ +--model_additional_files $MODEL_LOCAL/workspace/additional_files/ + +checkFailed "offline compatibility test execution step failed - private model" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Delete test Private Model Owner profile" +echo "=====================================" +print_eval medperf profile delete testprivate +checkFailed "Profile deletion failed" +########################################################## + +echo "\n" \ No newline at end of file diff --git a/cli/medperf/commands/association/utils.py b/cli/medperf/commands/association/utils.py index a67ebadf4..9e3b0da6e 100644 --- a/cli/medperf/commands/association/utils.py +++ b/cli/medperf/commands/association/utils.py @@ -1,6 +1,6 @@ from medperf.exceptions import InvalidArgumentError, MedperfException from medperf import config -from pydantic.datetime_parse import parse_datetime +from medperf.utils import parse_datetime def validate_args( diff --git a/cli/medperf/commands/dataset/dataset.py b/cli/medperf/commands/dataset/dataset.py index 02bb9898e..7aa5e8576 100644 --- a/cli/medperf/commands/dataset/dataset.py +++ b/cli/medperf/commands/dataset/dataset.py @@ -33,8 +33,12 @@ def list( ), name: str = typer.Option(None, "--name", help="Filter by name"), owner: int = typer.Option(None, "--owner", help="Filter by owner"), - state: str = typer.Option(None, "--state", help="Filter by state (DEVELOPMENT/OPERATION)"), - is_valid: bool = 
typer.Option(None, "--valid/--invalid", help="Filter by valid status"), + state: str = typer.Option( + None, "--state", help="Filter by state (DEVELOPMENT/OPERATION)" + ), + is_valid: bool = typer.Option( + None, "--valid/--invalid", help="Filter by valid status" + ), ): """List datasets""" EntityList.run( @@ -118,10 +122,19 @@ def prepare( "-y", help="Skip report submission approval step (In this case, it is assumed to be approved)", ), + no_cache: bool = typer.Option( + False, + "-nc", + "--no-cache", + "--no_cache", + help="Start a clean run without previous cached results.", + ), ): """Runs the Data preparation step for a raw dataset""" ui = config.ui - DataPreparation.run(data_uid, approve_sending_reports=approval) + DataPreparation.run( + data_uid, approve_sending_reports=approval, use_cached_results=not no_cache + ) ui.print("✅ Done!") diff --git a/cli/medperf/commands/dataset/prepare.py b/cli/medperf/commands/dataset/prepare.py index 8129a3f98..67db296ee 100644 --- a/cli/medperf/commands/dataset/prepare.py +++ b/cli/medperf/commands/dataset/prepare.py @@ -1,10 +1,10 @@ +from __future__ import annotations import logging import os -import pandas as pd from medperf.entities.dataset import Dataset import medperf.config as config from medperf.entities.cube import Cube -from medperf.utils import approval_prompt, dict_pretty_print +from medperf.utils import approval_prompt, dict_pretty_print, remove_path from medperf.exceptions import ( CommunicationError, ExecutionError, @@ -18,7 +18,7 @@ class ReportHandler(FileSystemEventHandler): - def __init__(self, preparation_obj: "DataPreparation"): + def __init__(self, preparation_obj: DataPreparation): self.preparation = preparation_obj self.timer = None @@ -48,7 +48,7 @@ def on_modified(self, event): class ReportSender: - def __init__(self, preparation_obj: "DataPreparation"): + def __init__(self, preparation_obj: DataPreparation): self.preparation = preparation_obj def start(self): @@ -81,8 +81,14 @@ def run( 
dataset_id: int, approve_sending_reports: bool = False, data_preparation_cube: Cube = None, + use_cached_results: bool = True, ): - preparation = cls(dataset_id, approve_sending_reports, data_preparation_cube) + preparation = cls( + dataset_id, + approve_sending_reports, + data_preparation_cube, + use_cached_results, + ) preparation.get_dataset() preparation.validate() with preparation.ui.interactive(): @@ -97,9 +103,10 @@ def run( with preparation.ui.interactive(): preparation.run_prepare() - with preparation.ui.interactive(): - preparation.run_sanity_check() - preparation.run_statistics() + if not preparation.is_workflow: + with preparation.ui.interactive(): + preparation.run_sanity_check() + preparation.run_statistics() preparation.mark_dataset_as_ready() @@ -110,6 +117,7 @@ def __init__( dataset_id: int, approve_sending_reports: bool, data_preparation_cube: Cube = None, + use_cached_results: bool = True, ): self.comms = config.comms self.ui = config.ui @@ -127,6 +135,11 @@ def __init__( self.report_specified = None self.metadata_specified = None self._lock = Lock() + self.use_cached_results = use_cached_results + + @property + def is_workflow(self): + return self.cube.is_workflow def should_run_prepare(self): return not self.dataset.submitted_as_prepared and not self.dataset.is_ready() @@ -179,6 +192,9 @@ def run_prepare(self): report_sender = ReportSender(self) report_sender.start() + if not self.use_cached_results: + self._remove_old_results() + prepare_mounts = { "data_path": self.raw_data_path, "labels_path": self.raw_labels_path, @@ -191,6 +207,10 @@ def run_prepare(self): if self.report_specified: prepare_mounts["report_file"] = self.report_path + if self.cube.is_workflow: + prepare_mounts["statistics_file"] = self.out_statistics_path + prepare_mounts["dataset_path"] = self.dataset.path + self.ui.text = "Running preparation step..." 
try: with self.ui.interactive(): @@ -211,6 +231,23 @@ def run_prepare(self): self.ui.print("> Container execution complete") report_sender.stop("finished") + def _remove_old_results(self): + paths_to_remove = [ + self.out_datapath, + self.out_labelspath, + self.out_statistics_path, + self.metadata_path, + self.report_path, + ] + + if self.is_workflow: + airflow_home = os.path.join(self.dataset.path, "airflow_home") + paths_to_remove.append(airflow_home) + + for path_to_remove in paths_to_remove: + if os.path.exists(path_to_remove): + remove_path(path_to_remove) + def run_sanity_check(self): sanity_check_timeout = config.sanity_check_timeout out_datapath = self.out_datapath @@ -287,31 +324,20 @@ def mark_dataset_as_ready(self): self.dataset.mark_as_ready() def __generate_report_dict(self): - report_status_dict = {} - if os.path.exists(self.report_path): with open(self.report_path, "r") as f: report_dict = yaml.safe_load(f) + return report_dict.get("progress", {}) - # TODO: this specific logic with status is very tuned to the RANO. 
Hope we'd - # make it more general once - report = pd.DataFrame(report_dict) - if "status" in report.keys(): - report_status = report.status.value_counts() / len(report) - report_status_dict = report_status.round(3).to_dict() - report_status_dict = { - f"Stage {key}": str(val * 100) + "%" - for key, val in report_status_dict.items() - } - - return report_status_dict + # If no report has been generated yet, return a blank report + return {} def prompt_for_report_sending_approval(self): example = { "execution_status": "running", "progress": { - "Stage 1": "40%", - "Stage 3": "60%", + "Stage 1": "40.0", + "Stage 3": "60.0", }, } @@ -356,6 +382,7 @@ def _send_report(self, report_metadata): report_status_dict = {} if self.allow_sending_reports: report_status_dict = self.__generate_report_dict() + report = {"progress": report_status_dict, **report_metadata} if report == self.dataset.report: # Watchdog may trigger an event even if contents didn't change diff --git a/cli/medperf/commands/execution/utils.py b/cli/medperf/commands/execution/utils.py index b8406cef3..98fdd20e4 100644 --- a/cli/medperf/commands/execution/utils.py +++ b/cli/medperf/commands/execution/utils.py @@ -1,4 +1,4 @@ -from pydantic.datetime_parse import parse_datetime +from medperf.utils import parse_datetime from medperf.entities.execution import Execution from typing import List diff --git a/cli/medperf/commands/mlcube/submit.py b/cli/medperf/commands/mlcube/submit.py index f56cace25..453fad6df 100644 --- a/cli/medperf/commands/mlcube/submit.py +++ b/cli/medperf/commands/mlcube/submit.py @@ -28,6 +28,7 @@ def run( submit_info, container_config, parameters_config, decryption_key ) submission.read_config_files() + submission.validate_hash_format() submission.create_cube_object() with ui.interactive(): @@ -151,3 +152,17 @@ def store_decryption_key(self): return logging.debug(f"Decryption key provided: {self.decryption_key}") store_decryption_key(self.cube.id, self.decryption_key) + + def 
validate_hash_format(self): + """ + Changes a string hash (i.e sent from the command line) into a dict + Also removes a None hash (so it can be generated properly as an empty dict) + """ + tentative_hash = self.submit_info.pop("image_hash", None) + + if tentative_hash is None: + return + elif isinstance(tentative_hash, str): + tentative_hash = {"default": tentative_hash} + + self.submit_info["image_hash"] = tentative_hash diff --git a/cli/medperf/config.py b/cli/medperf/config.py index a28160b09..b60d03e65 100644 --- a/cli/medperf/config.py +++ b/cli/medperf/config.py @@ -63,6 +63,7 @@ config_storage = Path.home().resolve() / ".medperf_config" logs_storage = Path.home().resolve() / ".medperf_logs" +airflow_venv_dir = str(config_storage / ".airflow_venv") config_path = str(config_storage / "config.yaml") auth_jwks_file = str(config_storage / ".jwks") creds_folder = str(config_storage / ".tokens") @@ -269,6 +270,8 @@ webui_max_chunk_length = 20 # Max 20 events in a chunk webui_max_chunk_size = 64 * 1024 # Max 64 Bytes as chunk size +# Airflow-related config +airflow_postgres_image = "postgres:14.19" default_profile_name = "default" testauth_profile_name = "testauth" diff --git a/cli/medperf/containers/parsers/airflow_parser.py b/cli/medperf/containers/parsers/airflow_parser.py new file mode 100644 index 000000000..b626c91c4 --- /dev/null +++ b/cli/medperf/containers/parsers/airflow_parser.py @@ -0,0 +1,222 @@ +from typing import Dict, Union, Literal, Set +from medperf.exceptions import InvalidContainerSpec +from medperf.containers.parsers.parser import Parser +import os +import yaml +from medperf import config + + +class ContainerForAirflow: + + def __init__(self, image: str, platform: Literal["docker", "singularity"]): + self._image = image + if platform not in ["docker", "singularity"]: + raise InvalidContainerSpec(f"Container type {platform} is not supported!") + self._platform = platform + + @property + def image(self): + return self._image + + @property + def 
class ContainerForAirflow:
    """Immutable (image, platform) pair describing one container of a DAG.

    Equality and hashing are value-based so parsers can deduplicate the
    distinct containers of a workflow in a set.
    """

    def __init__(self, image: str, platform: Literal["docker", "singularity"]):
        self._image = image
        if platform not in ("docker", "singularity"):
            raise InvalidContainerSpec(f"Container type {platform} is not supported!")
        self._platform = platform

    @property
    def image(self):
        """Registry image reference for this container."""
        return self._image

    @property
    def platform(self):
        """Container runtime this image targets ('docker' or 'singularity')."""
        return self._platform

    def __eq__(self, other):
        # Only another ContainerForAirflow with identical fields compares equal.
        if not isinstance(other, ContainerForAirflow):
            return False
        return (self._image, self._platform) == (other._image, other._platform)

    def __hash__(self):
        return hash((self._image, self._platform))
+ """ + if "steps" not in self.airflow_config: + raise InvalidContainerSpec("Airflow config should have a 'steps' field.") + + self._steps = self.airflow_config["steps"] + + final_step_candidates = [] + error_msgs = [] + tmp_pools = {} + + for i, step in enumerate(self._steps): + self.step_ids.append(step["id"]) + + self._check_mandatory_fields( + step=step, step_index=i, error_msg_list=error_msgs + ) + + is_last_step_candidate = self._check_last_step( + step=step, error_msg_list=error_msgs + ) + if is_last_step_candidate: + final_step_candidates.append(step["id"]) + + container_image = self._verify_container(step, error_msg_list=error_msgs) + if container_image is not None: + self.containers.add( # TODO add support for singularity containers + ContainerForAirflow(image=container_image, platform="docker") + ) + + if "limit" in step.keys(): + tmp_pools.update(self._create_pool_info(step)) + + if not final_step_candidates: + msg = "The provided YAML DAG file has no clear final step!" + error_msgs.append(msg) + + elif len(final_step_candidates) > 1: + final_step_candidates = [f"- {step}" for step in final_step_candidates] + msg = "The provided YAML DAG file has multiple potential last steps. 
Please verify the following steps:\n" + msg += "\n".join(final_step_candidates) + error_msgs.append(msg) + + if error_msgs: + full_msg = "\n\n".join(error_msgs) + raise InvalidContainerSpec(full_msg) + + self.pools = tmp_pools or None + + @staticmethod + def _check_mandatory_fields( + step: Dict[str, str], step_index: int, error_msg_list: list[str] + ) -> Union[str, None]: + mandatory_fields = {"id", "type"} + missing_fields = mandatory_fields.difference(set(step.keys())) + if missing_fields: + ordered_fields = sorted(list(missing_fields)) + if step.get("id"): + step_identifier = f"step {step['id']}" + else: + step_identifier = f"{step_index + 1}th step" + msg = ( + f"The {step_identifier} in the yaml file is missing the " + f"following mandatory fields: ', '.join({ordered_fields})" + ) + error_msg_list.append(msg) + + @staticmethod + def _check_last_step(step: Dict[str, str], error_msg_list: list[str]) -> bool: + is_marked_as_last = step.get("last", False) + inferred_as_last = step.get("next") is None + + if is_marked_as_last or inferred_as_last: + is_last_step_candidate = True + if step.get("partition", False): + error_msg = f"Step {step['id']} appears to be the final step, but is also part of a partition.\n" + error_msg += ( + "Please make sure the final step is not part of any partitions. " + ) + error_msg += ( + "You may need to add a final dummy step to join all results." 
+ ) + error_msg_list.append(error_msg) + else: + is_last_step_candidate = False + + return is_last_step_candidate + + @staticmethod + def _create_pool_info(step: Dict[str, Union[str, int]]): + return { + f'{step["id"]}_pool': { + "slots": step["limit"], + "include_deferred": False, + "description": f"Pool to limit the execution of " + f'tasks with ID {step["id"]} to {step["limit"]} ' + "parallel executions", + } + } + + @staticmethod + def _verify_container(step: Dict[str, str], error_msg_list: list[str]): + if step.get("type") != "container": + return None + + if "command" not in step: + msg = f"Step {step['id']} is of type 'container' bu does not specify a 'command' field!" + error_msg_list.append(msg) + + try: + return step["image"] + except KeyError: + msg = f"Step {step['id']} is of type 'container' but does not specify a 'image' field!!" + error_msg_list.append(msg) + return None + + # TODO validate how to add these methods to this parser + def check_task_schema(self, task): + pass + + def get_setup_args(self): + pass + + def get_volumes(self, task: str, medperf_mounts: dict): + pass + + def get_run_args(self, task: str, medperf_mounts: dict): + pass + + def is_report_specified(self): + """Can always get report data from Airflow REST API""" + return True + + @property + def has_metadata(self): + if self._has_metadata is None: + self._has_metadata = any( + "metadata_path" in step["mounts"].get("output_volumes", {}) + for step in self._steps + ) + + return self._has_metadata + + def is_metadata_specified(self): + return self.has_metadata + + def is_container_encrypted(self): + return False + + def is_docker_archive(self): + return False + + def is_docker_image(self): + return False + + def is_singularity_file(self): + return False diff --git a/cli/medperf/containers/parsers/factory.py b/cli/medperf/containers/parsers/factory.py index d518e759d..3001d4604 100644 --- a/cli/medperf/containers/parsers/factory.py +++ b/cli/medperf/containers/parsers/factory.py @@ 
-3,6 +3,7 @@ from .simple_container import SimpleContainerParser from medperf.enums import ContainerTypes import logging +from .airflow_parser import AirflowParser DOCKER_TYPES = [ ContainerTypes.DOCKER_IMAGE.value, @@ -15,11 +16,25 @@ ] -def load_parser(container_config: dict) -> Parser: - if container_config is None: +def _is_airflow_yaml_file(airflow_config: dict): + return "container_type" not in airflow_config and "steps" in airflow_config + + +def load_parser(container_config: dict, container_files_base_path: str) -> Parser: + + if _is_airflow_yaml_file(container_config): + parser = AirflowParser( + airflow_config=container_config, + allowed_runners=["docker", "singularity"], + container_files_base_path=container_files_base_path, + ) + parser.check_schema() + return parser + + elif container_config is None: raise InvalidContainerSpec("Empty container configuration") - if "container_type" not in container_config: + elif "container_type" not in container_config: raise InvalidContainerSpec( "Container configuration should contain a 'container_type' field." 
from .runner import Runner
from typing import Dict
from medperf.containers.runners.airflow_runner_utils.system_runner import (
    AirflowSystemRunner,
)
from pathlib import Path
from .airflow_runner_utils.dags import constants
from .airflow_runner_utils.plugins import auto_login
import os
from medperf.containers.parsers.airflow_parser import AirflowParser
from medperf.account_management import get_medperf_user_data
from .utils import get_expected_hash
from .singularity_utils import check_docker_image_by_name
import logging
from .docker_utils import download_docker_image
from medperf import config


class AirflowRunner(Runner):
    """Runner that executes a multi-step workflow via a local Airflow deployment."""

    # DAG/plugin code shipped with medperf, located relative to the installed package.
    _DAGS_FOLDER = str(Path(constants.__file__).parent.resolve())
    _PLUGINS_FOLDER = str(Path(auto_login.__file__).parent.resolve())

    def __init__(
        self,
        airflow_config_parser: AirflowParser,
        workflow_name,
    ):
        self.parser = airflow_config_parser
        self.workflow_name = workflow_name

    def download(
        self,
        expected_image_hash: Dict[str, str],
        download_timeout: int = None,
        get_hash_timeout: int = None,
    ) -> Dict[str, str]:
        """Fetch/verify every container of the workflow.

        Returns the image -> computed hash mapping.
        """
        # TODO add support for Docker Archives, Encrypted Containers, singularity files
        if config.platform == "docker":
            return self._download_containers_for_docker(
                expected_image_hash, download_timeout, get_hash_timeout
            )
        elif config.platform == "singularity":
            return self._check_containers_for_singularity(
                expected_image_hash, get_hash_timeout
            )
        # NOTE(review): any other platform value silently returns None —
        # confirm whether an explicit error would be preferable here.

    def _download_containers_for_docker(
        self, hashes_dict, download_timeout, get_hash_timeout
    ):
        """Pull each docker image, verifying against its expected hash."""
        for container in self.parser.containers:
            expected_image_hash = get_expected_hash(hashes_dict, container.image)
            computed_image_hash = download_docker_image(
                docker_image=container.image,
                expected_image_hash=expected_image_hash,
                download_timeout=download_timeout,
                get_hash_timeout=get_hash_timeout,
            )
            hashes_dict[container.image] = computed_image_hash

        return hashes_dict

    def _check_containers_for_singularity(self, hashes_dict, get_hash_timeout):
        """
        Note: currently assumes image always come from some Docker registry (i.e docker hub)
        and then are converted into singularity during run
        """
        for container in self.parser.containers:
            expected_image_hash = get_expected_hash(hashes_dict, container.image)
            computed_image_hash = check_docker_image_by_name(
                docker_image=container.image,
                expected_image_hash=expected_image_hash,
                get_hash_timeout=get_hash_timeout,
            )
            hashes_dict[container.image] = computed_image_hash

        return hashes_dict

    def run(
        self,
        task: str = None,  # Not used
        tmp_folder: str = None,  # TODO implement
        output_logs: str = None,
        timeout: int = None,
        medperf_mounts: dict[str, str] = None,
        medperf_env: dict[str, str] = None,
        ports: list = None,
        disable_network: bool = True,
        container_decryption_key_file: str = None,
    ):
        """Bring up a scoped Airflow deployment and wait for the DAG to finish.

        Raises KeyError if `medperf_mounts` lacks "dataset_path"/"additional_files".
        """
        # Work on a copy: the previous shared mutable defaults ({} / []) were
        # popped below, which both corrupted the default across calls and
        # silently removed "dataset_path" from the caller's dict.
        medperf_mounts = dict(medperf_mounts or {})
        medperf_env = medperf_env or {}
        ports = ports or []

        email = get_medperf_user_data()["email"]
        username = email.split("@", maxsplit=1)[0]

        dataset_dir = medperf_mounts.pop("dataset_path")
        airflow_home = os.path.join(dataset_dir, "airflow_home")
        additional_files_path = medperf_mounts["additional_files"]

        logging.debug(
            f"Starting Airflow runner with the following airflow home directory: {airflow_home}."
        )
        with AirflowSystemRunner(
            airflow_home=airflow_home,
            user=username,
            email=email,
            dags_folder=self._DAGS_FOLDER,
            plugins_folder=self._PLUGINS_FOLDER,
            additional_files_dir=additional_files_path,
            mounts=medperf_mounts,
            project_name=self.workflow_name,
            yaml_parser=self.parser,
        ) as system_runner:
            system_runner.init_airflow()
            system_runner.wait_for_dag()

    @property
    def is_workflow(self):
        return True
+ """ + headers = {"Content-Type": "application/json"} + data = {"username": username, "password": password.get_secret_value()} + + auth_url = f"{base_url}/auth/token" + async with httpx.AsyncClient() as client: + response = await client.post(auth_url, headers=headers, json=data) + + if response.status_code != 201: + print("Failed to get token:", response.status_code, response.text) + jwt_token = response.json().get("access_token") + return jwt_token + + +class BearerAuth(AirflowBearerAuth): + def __init__( + self, + token: str, + expires_at: Optional[float] = None, + leeway_seconds: float = None, + ): + if expires_at is None: + twenty_four_hours = 60 * 60 * 24 # Default duration from airflow + leeway_seconds = leeway_seconds or 30 + now = time.time() + expires_at = now + twenty_four_hours + leeway_seconds + + self.expires_at = expires_at + super().__init__(token=token) + + def is_valid(self): + now = time.time() + return now < self.expires_at + + +class AirflowAPIClient(httpx.Client): + + def __init__( + self, + username: str, + password: Union[str, SecretStr], + api_url: str, + **kwargs, + ): + self.username = username + self.password = password + if isinstance(self.password, str): + self.password = SecretStr(self.password) + + token = self.get_token(api_url) + auth = BearerAuth(token) + event_hooks = {"request": [self._renew_token]} + super().__init__(base_url=api_url, auth=auth, event_hooks=event_hooks, **kwargs) + + def get_token(self, base_url=None): + if base_url is None: + base_url = str(self.base_url) + + base_for_auth = base_url.split("/api")[0] + headers = {"Content-Type": "application/json"} + data = {"username": self.username, "password": self.password.get_secret_value()} + + auth_url = f"{base_for_auth}/auth/token" + response = httpx.post(auth_url, headers=headers, json=data) + + if response.status_code != 201: + print("Failed to get token:", response.status_code, response.text) + jwt_token = response.json().get("access_token") + return jwt_token + + 
def _renew_token(self, request: httpx.Request): + if not self.auth.is_valid(): + new_token = self.get_token() + self.auth.token = new_token + request.headers["Authorization"] = "Bearer " + self.auth.token + + @lru_cache() + @property + def dags(self) -> DagOperations: + return DagOperations(self) + + @lru_cache() + @property + def dag_runs(self) -> DagRunOperations: + return DagRunOperations(self) + + @lru_cache() + @property + def task_instances(self) -> TaskInstanceOperations: + return TaskInstanceOperations(self) + + @lru_cache() + @property + def tasks(self) -> TaskOperations: + return TaskOperations(self) + + @lru_cache() + @property + def assets(self) -> AssetOperations: + return AssetOperations(self) + + +class BaseOperations: + __slots__ = ("client",) + + def __init__(self, client: AirflowAPIClient): + self.client = client + + +class DagOperations(BaseOperations): + + def get_all_dags(self): + return self.client.get("dags").json() + + +class DagRunOperations(BaseOperations): + + def get_most_recent_dag_run(self, dag_id: str): + params = {"dag_id": dag_id, "limit": 1, "order_by": "logical_date"} + return self.client.get(f"dags/{dag_id}/dagRuns", params=params).json() + + def get_dag_run_by_run_id(self, dag_id: str, dag_run_id: str): + return self.client.get(f"dags/{dag_id}/dagRuns/{dag_run_id}").json() + + def trigger_dag_run(self, dag_id: str): + json_data = {"logical_date": None} + response = self.client.post(f"dags/{dag_id}/dagRuns", json=json_data) + return response.json() + + +class TaskInstanceOperations(BaseOperations): + def get_task_instances_in_dag_run(self, dag_id: str, dag_run_id: str): + return self.client.get( + f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances" + ).json() + + +class TaskOperations(BaseOperations): + def get_tasks(self, dag_id: str): + return self.client.get(f"dags/{dag_id}/tasks").json() + + +class AssetOperations(BaseOperations): + + def get_asset_events(self): + return self.client.get("assets/events").json() + + def 
class ReportSummary:
    """In-memory summary of a workflow run, serializable to the report YAML."""

    def __init__(
        self,
        report_file: str,
        execution_status: Literal["running", "failure", "done"],
        progress_dict: dict[str, Any] = None,
    ):
        self.report_file = report_file
        self.execution_status = execution_status
        # Fall back to a fresh dict only when nothing was passed, avoiding
        # the shared-mutable-default pitfall.
        self.progress_dict = {} if progress_dict is None else progress_dict

    def to_dict(self):
        """Return the report as a plain dict, ready for YAML serialization."""
        return {
            "execution_status": self.execution_status,
            "progress": self.progress_dict,
        }

    def write_yaml(self):
        """Dump the summary to ``report_file``; no-op when no file is set."""
        if self.report_file is None:
            return

        with open(self.report_file, "w") as fh:
            yaml.dump(self.to_dict(), fh, sort_keys=False)
class Summarizer:
    """Aggregates DAG and DAG-run state from the Airflow API into a report."""

    def __init__(self, yaml_parser: AirflowParser, report_file: os.PathLike):
        self.step_ids = yaml_parser.step_ids
        self.report_file = report_file

    @staticmethod
    def _get_dag_id_to_dag_dict(client: AirflowAPIClient) -> dict[str, dict[str, Any]]:
        """Map dag_id -> DAG description for every DAG known to Airflow."""
        all_dags = client.dags.get_all_dags()["dags"]

        all_dags = {dag["dag_id"]: dag for dag in all_dags}

        return all_dags

    @staticmethod
    def _get_most_recent_dag_runs(
        all_dags: dict[str, dict[str, Any]], client: AirflowAPIClient
    ) -> dict[str, dict[str, Any] | None]:
        """Map dag_id -> its latest run dict, or None if it never ran."""
        most_recent_dag_runs = {}

        for dag_id in all_dags.keys():
            runs = client.dag_runs.get_most_recent_dag_run(dag_id=dag_id)["dag_runs"]
            most_recent_dag_runs[dag_id] = runs[0] if runs else None

        return most_recent_dag_runs

    def _sort_column(self, col):
        """sort_values key: order step tags by their position in the workflow."""
        sorted_indices = []
        for task_id in col:
            if task_id in self.step_ids:
                sorted_indices.append(self.step_ids.index(task_id))
            else:
                # Unknown tags sort as if they were the first step.
                sorted_indices.append(0)

        return sorted_indices

    def _get_report_summary(
        self,
        all_dags: dict[str, dict[str, Any]],
        most_recent_dag_runs: dict[str, dict[str, Any] | None],
    ) -> ReportSummary:
        """Compute per-step success percentages plus the overall status."""
        # Guard: with no DAGs at all there is nothing to summarize (the old
        # code would KeyError on exploding an empty DataFrame).
        if not most_recent_dag_runs:
            return ReportSummary(
                report_file=self.report_file,
                execution_status="running",
                progress_dict={},
            )

        dag_info_dicts = []
        for dag_id, run_dict in most_recent_dag_runs.items():
            this_dag = all_dags[dag_id]
            run_state = None if run_dict is None else run_dict["state"]

            dag_step_tags = [
                tag["name"] for tag in this_dag["tags"] if tag["name"] in self.step_ids
            ]
            dag_info_dicts.append(
                {
                    "DAG ID": dag_id,
                    "DAG Display Name": this_dag["dag_display_name"],
                    "DAG Run State": run_state,
                    "DAG Step Tag": dag_step_tags,
                }
            )

        progress_df = pd.DataFrame(dag_info_dicts)
        progress_df = progress_df.explode("DAG Step Tag")
        progress_df = progress_df.sort_values(
            by=["DAG Step Tag"],
            key=self._sort_column,
        )

        summary_dict = {}
        for dag_tag in progress_df["DAG Step Tag"].unique():
            relevant_df = progress_df[progress_df["DAG Step Tag"] == dag_tag]
            if relevant_df.empty:
                continue
            task_success_ratio = len(
                relevant_df[relevant_df["DAG Run State"] == DagRunState.SUCCESS]
            ) / len(relevant_df)
            summary_dict[dag_tag] = round(task_success_ratio * 100, 3)

        # Bug fix: the overall status must consider *every* DAG run. The old
        # check read `relevant_df`, i.e. only the rows of the last tag
        # iterated, and raised NameError when there were no tags at all.
        all_states = list(progress_df["DAG Run State"])
        if all_states and all(state == DagRunState.SUCCESS for state in all_states):
            execution_status = "done"
        else:
            execution_status = "running"

        report_summary = ReportSummary(
            report_file=self.report_file,
            execution_status=execution_status,
            progress_dict=summary_dict,
        )
        return report_summary

    def summarize(
        self,
        airflow_client: AirflowAPIClient,
    ):
        """Write the current progress summary to the report file (if any)."""
        if self.report_file is None:
            return

        all_dags = self._get_dag_id_to_dag_dict(airflow_client)
        most_recent_dag_runs = self._get_most_recent_dag_runs(all_dags, airflow_client)
        report_summary = self._get_report_summary(all_dags, most_recent_dag_runs)
        report_summary.write_yaml()

    async def summarize_every_x_seconds(
        self,
        interval_seconds: int,
        airflow_client: AirflowAPIClient,
    ):
        """Periodically refresh the report until the surrounding task is cancelled."""
        while True:
            self.summarize(airflow_client)
            await asyncio.sleep(interval_seconds)
class AirflowComponentRunner(ComponentRunner):
    """Base for Airflow components launched as `python -m airflow ...` subprocesses."""

    def __init__(
        self,
        python_executable: os.PathLike,
        airflow_home: os.PathLike,
        container_type: Literal["docker", "singularity"],
        workflow_yaml_file: os.PathLike,
        additional_files_dir: os.PathLike,
        dags_folder: os.PathLike,
    ):
        # Interpreter used to launch the component (allows a dedicated venv).
        self._python_exec = python_executable
        self.process = None
        user_dags_folder = os.path.join(dags_folder, "user")
        self.airflow_home = airflow_home
        # Environment handed to every spawned Airflow process; PYTHONPATH makes
        # both the bundled and the user-provided DAG modules importable.
        # NOTE(review): subprocess env values must be str — confirm callers
        # pass strings (the os.PathLike annotations suggest Paths may arrive).
        self._env_vars = {
            "AIRFLOW_HOME": airflow_home,
            "PYTHONPATH": f"{dags_folder}:{user_dags_folder}:{additional_files_dir}",
            "WORKFLOW_YAML_FILE": workflow_yaml_file,
            "CONTAINER_TYPE": container_type,
        }

    @property
    def _run_env(self):
        # Merge the component-specific variables over the current environment.
        base_env = os.environ.copy()
        base_env.update(**self._env_vars)
        return base_env

    @property
    @abstractmethod
    def initialize_command(self) -> List[str]:
        """`python -m <...>` argument list that starts this component."""
        pass

    async def start_logic(self):
        """Spawn the component and poll until its readiness probe succeeds."""
        actual_command = [self._python_exec, "-m", *self.initialize_command]

        # Each component appends to its own logfile under AIRFLOW_HOME/logs.
        logfile_path = os.path.join(self.airflow_home, "logs", self.logfile)
        self.run_command_with_logging(
            command=actual_command, logfile_path=logfile_path, env=self._run_env
        )
        await self.wait_for_start()
class ComponentRunner(ABC):
    """Base class for long-running components started as subprocesses.

    Subclasses implement `start_logic` (how to spawn the process) and
    `has_successfully_started` (the readiness probe); this base provides
    polling, timeout handling, logging and the context-manager lifecycle.
    """

    START_PERIOD = 5  # Grace-period hint, available to subclasses
    INTERVAL = 10  # Seconds between readiness probes
    TIMEOUT = 60  # Overall startup deadline in seconds
    MAX_TRIES = 5  # Maximum number of readiness probes

    def __init__(self):
        self.process = None

    @property
    @abstractmethod
    def logfile(self) -> str:
        """File name (relative) receiving the component's stdout/stderr."""
        pass

    @abstractmethod
    def has_successfully_started(self) -> bool:
        """Readiness probe; True once the component is up.

        Declared as a plain method (previously an abstract *property*) to
        match how every subclass implements it and how
        `_sync_wait_for_start` invokes it — `self.has_successfully_started()`.
        """
        pass

    @property
    def component_name(self):
        """Can be overriden to customize name"""
        return self.__class__.__name__

    @property
    def starting_message(self):
        return f"Starting component {self.component_name}"

    @property
    def finished_message(self):
        # Typo fix: "succesfully" -> "successfully".
        return f"Component {self.component_name} started successfully"

    @property
    def fail_message(self):
        return f"Failed to start up component {self.component_name}"

    def run_command_with_logging(
        self, command: List[str], logfile_path: os.PathLike, env: Dict = None
    ):
        """Spawn `command`, appending its stdout/stderr to `logfile_path`."""
        with open(logfile_path, "a") as f:
            self.process = subprocess.Popen(command, env=env, stdout=f, stderr=f)

    async def start(self):
        logging.debug(self.starting_message)
        await self.start_logic()
        logging.debug(self.finished_message)

    @abstractmethod
    async def start_logic(self):
        """Component-specific startup; must leave the component running."""
        pass

    async def wait_for_start(self):
        """Poll the readiness probe, raising TimeoutError past TIMEOUT seconds."""
        try:
            async with asyncio.timeout(self.TIMEOUT):
                await self._sync_wait_for_start()
        except asyncio.TimeoutError:
            raise TimeoutError(self.fail_message)

    async def _sync_wait_for_start(self):
        current_try = 1
        # `<=` so exactly MAX_TRIES probes run (the previous `<` stopped one
        # probe short, performing only MAX_TRIES - 1 checks).
        while current_try <= self.MAX_TRIES:
            if self.has_successfully_started():
                return
            await asyncio.sleep(self.INTERVAL)
            current_try += 1
        raise ValueError(
            f"Component {self.component_name} exceeded maximum number of checks to start."
        )

    def terminate(self):
        """Ask the subprocess to exit gracefully (no-op if never started)."""
        if self.process is not None:
            self.process.terminate()

    def kill(self):
        """Force-stop the subprocess (no-op if never started)."""
        if self.process is not None:
            self.process.kill()

    def __enter__(self):
        # Bug fix: the async start must actually be executed. The previous
        # code called `self.start_logic()` without awaiting, so the coroutine
        # was created but never ran and the component never started.
        # NOTE(review): assumes no event loop is already running in this
        # thread (asyncio.run would raise otherwise).
        asyncio.run(self.start())
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            self.terminate()

        else:
            self.kill()

        return False  # Propagate exception, if any
b/cli/medperf/containers/runners/airflow_runner_utils/components/db_postgres_component.py new file mode 100644 index 000000000..f2251b81b --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/components/db_postgres_component.py @@ -0,0 +1,75 @@ +from .component import ComponentRunner +from .utils import validate_port, generate_random_password +from pydantic import SecretStr +from typing import Union + +import os +import random +import string +from abc import abstractmethod + + +class PostgresDatabaseComponent(ComponentRunner): + + START_PERIOD = 2 + INTERVAL = 5 + MAX_TRIES = 5 + TIMEOUT = 30 + + def __init__( + self, + project_name: str, + root_dir: os.PathLike, + postgres_user: str = "postgres", + postgres_password: SecretStr = None, + postgres_db: str = "postgres", + postgres_port: Union[str, int] = 5432, + hostname: str = "localhost", + ): + super().__init__() + self.user = postgres_user + self.db = postgres_db + self.password = postgres_password or generate_random_password() + self.hostname = hostname + self.port = validate_port(postgres_port) + self.data_dir = os.path.join(root_dir, "postgres_data") + self.logs_dir = os.path.join(root_dir, "logs") + self.container_name = self.generate_container_name(project_name) + self.container_id = None + + @property + def component_name(self): + return "PostgreSQL Database for Airflow" + + @property + def logfile(self): + return os.path.join(self.logs_dir, "postgres_db.log") + + @staticmethod + def generate_container_name(project_name): + all_characters = string.ascii_lowercase + string.digits + num_digits = 4 + random_part = "".join(random.choice(all_characters) for _ in range(num_digits)) + complete_name = f"postgres_{project_name}_{random_part}" + return complete_name + + @property + def connection_string(self): + return f"postgresql+psycopg2://{self.user}:{self.password.get_secret_value()}@{self.hostname}:{self.port}/{self.db}" + + @abstractmethod + async def start_logic(self): + """Logic to start 
the Postgres container goes here""" + pass + + @abstractmethod + def has_successfully_started(self): + """Logic to check if Postgres container has successfully started goes here""" + + @abstractmethod + def terminate(self): + """Logic to gracefully terminate Postgres container goes here""" + + @abstractmethod + def kill(self): + """Logic to forcefully stop Postgres container goes here""" diff --git a/cli/medperf/containers/runners/airflow_runner_utils/components/db_postgres_docker.py b/cli/medperf/containers/runners/airflow_runner_utils/components/db_postgres_docker.py new file mode 100644 index 000000000..44e9cb917 --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/components/db_postgres_docker.py @@ -0,0 +1,70 @@ +from .db_postgres_component import PostgresDatabaseComponent +import subprocess +from medperf import config +import logging + + +class PostgresDBDocker(PostgresDatabaseComponent): + + async def start_logic(self): + self.process = subprocess.run( + [ + "docker", + "run", + "-d", + "--name", + self.container_name, + "-e", + f"POSTGRES_USER={self.user}", + "-e", + f"POSTGRES_PASSWORD={self.password.get_secret_value()}", + "-e", + f"POSTGRES_DB={self.db}", + "-v", + f"{self.data_dir}:/var/lib/postgresql/data:rw", + "-p", + f"{self.port}:5432", + f"{config.airflow_postgres_image}", + ], + capture_output=True, + text=True, + ) + self.container_id = self.process.stdout.strip() + await self.wait_for_start() + + def has_successfully_started(self): + postgres_status: subprocess.CompletedProcess = subprocess.run( + [ + "docker", + "exec", + self.container_id, + "pg_isready", + "-U", + self.user, + "-d", + self.db, + ], + capture_output=True, + text=True, + ) + has_started = postgres_status.returncode == 0 + + if not has_started: + logging.debug("Postgres DB not started yet") + logging.debug(f"stdout=\n{postgres_status.stdout}") + logging.debug(f"stderr=\n{postgres_status.stderr}") + + return postgres_status.returncode == 0 + + def terminate(self): +
subprocess.run( + ["docker", "stop", self.container_id], capture_output=True, text=True + ) + subprocess.run( + ["docker", "rm", self.container_id], capture_output=True, text=True + ) + + def kill(self): + subprocess.run( + ["docker", "rm", "-f", self.container_id], capture_output=True, text=True + ) diff --git a/cli/medperf/containers/runners/airflow_runner_utils/components/db_postgres_singularity.py b/cli/medperf/containers/runners/airflow_runner_utils/components/db_postgres_singularity.py new file mode 100644 index 000000000..835b1486f --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/components/db_postgres_singularity.py @@ -0,0 +1,78 @@ +from .db_postgres_component import PostgresDatabaseComponent +import subprocess +from medperf import config +from medperf.containers.runners.singularity_utils import ( + get_singularity_executable_props, +) + + +class PostgresDBSingularity(PostgresDatabaseComponent): + """Note: currently untested!""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + executable, runtime, version = get_singularity_executable_props() + self.executable = executable + self.runtime = runtime + self.version = version + + @property + def instance_name(self): + return f"instance://{self.container_name}" + + async def start_logic(self): + """Notes: + - singularity instance start runs detached (similar to docker run -d) + """ + run_env = { + "SINGULARITYENV_POSTGRES_USER": self.user, + "SINGULARITYENV_POSTGRES_PASSWORD": self.password.get_secret_value(), + "SINGULARITYENV_POSTGRES_DB": self.db, + } + self.process = subprocess.run( + [ + self.executable, + "instance", + "start", + "-eC", + "--bind", + f"{self.data_dir}:/var/lib/postgresql/data:rw", + f"docker://{config.airflow_postgres_image}", + self.container_name, + ], + capture_output=True, + text=True, + env=run_env, + ) + await self.wait_for_start() + + def has_successfully_started(self): + postgres_status: subprocess.CompletedProcess = 
subprocess.run( + [ + self.executable, + "exec", + self.instance_name, + "pg_isready", + "-U", + self.user, + "-d", + self.db, + ], + capture_output=True, + text=True, + ) + return postgres_status.returncode == 0 + + def terminate(self): + subprocess.run( + [self.executable, "stop", self.instance_name], + capture_output=True, + text=True, + ) + + def kill(self): + subprocess.run( + [self.executable, "stop", "-F", self.instance_name], + capture_output=True, + text=True, + ) diff --git a/cli/medperf/containers/runners/airflow_runner_utils/components/scheduler.py b/cli/medperf/containers/runners/airflow_runner_utils/components/scheduler.py new file mode 100644 index 000000000..49c3eaeea --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/components/scheduler.py @@ -0,0 +1,26 @@ +from medperf.containers.runners.airflow_runner_utils.components.airflow_component import ( + AirflowComponentRunner, +) +from .utils import run_healthcheck, build_mounts_dict + + +class AirflowScheduler(AirflowComponentRunner): + def __init__(self, mounts: dict, *args, **kwargs): + super().__init__(*args, **kwargs) + formatted_mounts = build_mounts_dict(mounts) + self._env_vars.update(**formatted_mounts) + + @property + def logfile(self): + return "scheduler.log" + + @property + def component_name(self): + return "Airflow Scheduler" + + @property + def initialize_command(self): + return ["airflow", "scheduler"] + + def has_successfully_started(self): + return run_healthcheck("http://localhost:8974/health") diff --git a/cli/medperf/containers/runners/airflow_runner_utils/components/triggerer.py b/cli/medperf/containers/runners/airflow_runner_utils/components/triggerer.py new file mode 100644 index 000000000..986f87949 --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/components/triggerer.py @@ -0,0 +1,35 @@ +from medperf.containers.runners.airflow_runner_utils.components.airflow_component import ( + AirflowComponentRunner, +) +import subprocess + + 
+class AirflowTriggerer(AirflowComponentRunner): + @property + def logfile(self): + return "triggerer.log" + + @property + def component_name(self): + return "Airflow Triggerer" + + @property + def initialize_command(self): + return ["airflow", "triggerer"] + + def has_successfully_started(self): + has_triggerer_jobs = subprocess.run( + [ + self._python_exec, + "-m", + "airflow", + "jobs", + "check", + "--job-type", + "TriggererJob", + "--local", + ], + capture_output=True, + env=self._run_env, + ) + return has_triggerer_jobs.returncode == 0 diff --git a/cli/medperf/containers/runners/airflow_runner_utils/components/utils.py b/cli/medperf/containers/runners/airflow_runner_utils/components/utils.py new file mode 100644 index 000000000..08a2b5e9f --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/components/utils.py @@ -0,0 +1,36 @@ +from typing import Union +import requests +from http import HTTPStatus +from pydantic import SecretStr +import secrets + + +def validate_port(port: Union[int, str]) -> str: + try: + is_valid_port = 1 <= int(port) <= 65535 + except ValueError: + is_valid_port = False + + if not is_valid_port: + raise ValueError(f"Port value sent ({port}) is not a valid port!") + + return str(port) + + +def run_healthcheck(healthcheck_url: str) -> bool: + try: + response = requests.get(healthcheck_url) + return response.status_code == HTTPStatus.OK + except requests.exceptions.RequestException: + return False + + +def generate_random_password(nbytes: int = 16) -> SecretStr: + return SecretStr(secrets.token_urlsafe(nbytes)) + + +def build_mounts_dict(mounts: dict[str, str]): + formatted_dict = { + f"host_{mount_name}": host_path for mount_name, host_path in mounts.items() + } + return formatted_dict diff --git a/cli/medperf/containers/runners/airflow_runner_utils/dags/.airflowignore b/cli/medperf/containers/runners/airflow_runner_utils/dags/.airflowignore new file mode 100644 index 000000000..b4bb6e52d --- /dev/null +++ 
b/cli/medperf/containers/runners/airflow_runner_utils/dags/.airflowignore @@ -0,0 +1,8 @@ +dag_utils.py +operator_factory.py +dag_builder.py +constants.py +pipeline_state.py +operator_builders/ +yaml_parser/ +api_client/ \ No newline at end of file diff --git a/cli/medperf/containers/runners/airflow_runner_utils/dags/constants.py b/cli/medperf/containers/runners/airflow_runner_utils/dags/constants.py new file mode 100644 index 000000000..e21d55cdc --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/dags/constants.py @@ -0,0 +1,7 @@ +from datetime import datetime, timedelta +import os + +ALWAYS_CONDITION = "ALWAYS" +YESTERDAY = datetime.today() - timedelta(days=1) +FINAL_ASSET = "medperf_airflow_completed_asset" +WORKFLOW_YAML_FILE = os.getenv("WORKFLOW_YAML_FILE") diff --git a/cli/medperf/containers/runners/airflow_runner_utils/dags/dag_builder.py b/cli/medperf/containers/runners/airflow_runner_utils/dags/dag_builder.py new file mode 100644 index 000000000..821649080 --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/dags/dag_builder.py @@ -0,0 +1,76 @@ +from __future__ import annotations +from airflow.sdk import Asset, DAG +from typing import Any +from medperf.containers.runners.airflow_runner_utils.dags.constants import YESTERDAY +from medperf.containers.runners.airflow_runner_utils.dags.operator_factory import ( + operator_factory, +) + + +class DagBuilder: + + def __init__(self, expanded_step: dict[str, Any]): + raw_inlets = expanded_step.pop("inlets", []) + is_first = not bool(raw_inlets) # If a step has no inlets, it is the first step + self.builder_id = expanded_step["id"] + self.inlets = [Asset(raw_inlet) for raw_inlet in raw_inlets] + self.partition = expanded_step.get("partition", None) + self.operator_builders = operator_factory(is_first=is_first, **expanded_step) + self._operator_id_to_builder_obj = { + operator.operator_id: operator for operator in self.operator_builders + } + self._generated_operators = {} + + 
@property + def num_operators(self) -> int: + return len(self.operator_builders) + + def __str__(self): + return f"{self.__class__.__name__}(id={self.builder_id}, num_ops={self.num_operators})" + + def __repr__(self): + return str(self) + + @property + def display_name(self): + return self.operator_builders[0].display_name + + @property + def tags(self): + return self.operator_builders[0].tags + + def build_dag(self): + schedule = self.inlets or "@once" + + with DAG( + dag_id=self.builder_id, + dag_display_name=self.display_name, + catchup=False, + max_active_runs=1, + schedule=schedule, + start_date=YESTERDAY, + is_paused_upon_creation=False, + tags=self.tags, + auto_register=True, + ) as dag: + for operator_builder in self.operator_builders: + current_operator = self._get_generated_operator_by_id( + operator_builder.operator_id + ) + + for next_id in operator_builder.next_ids: + next_operator = self._get_generated_operator_by_id(next_id) + if next_operator is not None: + current_operator >> next_operator + return dag + + def _get_generated_operator_by_id( + self, + operator_id, + ): + if operator_id not in self._generated_operators: + builder_for_this_operator = self._operator_id_to_builder_obj[operator_id] + self._generated_operators[operator_id] = ( + builder_for_this_operator.get_airflow_operator() + ) + return self._generated_operators[operator_id] diff --git a/cli/medperf/containers/runners/airflow_runner_utils/dags/dag_from_yaml.py b/cli/medperf/containers/runners/airflow_runner_utils/dags/dag_from_yaml.py new file mode 100644 index 000000000..52f73d859 --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/dags/dag_from_yaml.py @@ -0,0 +1,8 @@ +from __future__ import annotations +from yaml_parser import YamlParser + +# We need to import DAG so that airflow recognizes the auto-generated DAGs +from airflow.sdk import DAG # noqa: F401 + +parser = YamlParser() +dags = parser.build_dags() diff --git 
a/cli/medperf/containers/runners/airflow_runner_utils/dags/dag_utils.py b/cli/medperf/containers/runners/airflow_runner_utils/dags/dag_utils.py new file mode 100644 index 000000000..6d5392ff5 --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/dags/dag_utils.py @@ -0,0 +1,17 @@ +import re + + +def import_external_python_function(function_path: str): + import importlib + + condition_module, condition_function = function_path.rsplit(".", maxsplit=1) + imported_module = importlib.import_module(condition_module) + function_obj = getattr(imported_module, condition_function) + + return function_obj + + +def create_legal_dag_id(subject_slash_timepoint, replace_char="_"): + legal_chars = "A-Za-z0-9_-" + legal_id = re.sub(rf"[^{legal_chars}]", replace_char, subject_slash_timepoint) + return legal_id diff --git a/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/__init__.py b/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/branch_from_sensor_operator_builder.py b/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/branch_from_sensor_operator_builder.py new file mode 100644 index 000000000..e5567069f --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/branch_from_sensor_operator_builder.py @@ -0,0 +1,40 @@ +from medperf.containers.runners.airflow_runner_utils.dags.operator_builders.python_sensor_builder import ( + PythonSensorBuilder, +) +from medperf.containers.runners.airflow_runner_utils.dags.operator_builders.operator_builder import ( + OperatorBuilder, +) +from airflow.decorators import task +from airflow.models.taskinstance import TaskInstance + + +class BranchFromSensorOperatorBuilder(OperatorBuilder): + """ + BranchOperators are used together with Sensors to to automatically create 
branching behavior. + Once any condition in the sensor is met, it pushes the ID of the corresponding task as an Airflow XCom. + This BranchOperator then reads this XCom and branches accordingly. + """ + + def __init__( + self, + previous_sensor: PythonSensorBuilder, + **kwargs, + ): + + self.sensor_task_id = previous_sensor.operator_id + super().__init__(**kwargs) + + def _define_base_operator(self): + + @task.branch( + task_id=self.operator_id, + task_display_name=self.display_name, + outlets=self.outlets, + ) + def branching(task_instance: TaskInstance): + """Read next task from the Sensor XCom (which detected any of the branching conditions) + and branch into that""" + xcom_data = task_instance.xcom_pull(task_ids=self.sensor_task_id) + return [xcom_data] # This corresponds to the ID of the next task + + return branching() diff --git a/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/container_operator_builder.py b/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/container_operator_builder.py new file mode 100644 index 000000000..cad4124e6 --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/container_operator_builder.py @@ -0,0 +1,87 @@ +from __future__ import annotations +from medperf.containers.runners.airflow_runner_utils.dags.operator_builders.operator_builder import ( + OperatorBuilder, +) +from abc import abstractmethod +import os +from medperf.exceptions import MedperfException + + +class MountInfo: + + def __init__(self, source: os.PathLike, target: os.PathLike, read_only: bool): + self.source = source + self.target = target + self.read_only = read_only + + def __eq__(self, other): + if not isinstance(other, MountInfo): + return False + return ( + self.source == other.source + and self.target == other.target + and self.read_only == other.read_only + ) + + def __hash__(self): + return hash((self.source, self.target, self.read_only)) + + +class 
ContainerOperatorBuilder(OperatorBuilder): + + def __init__( + self, + image: str, + command: str | list[str], + mounts: dict, + host_mounts: dict, + **kwargs, + ): + super().__init__(**kwargs) + self.image = image + if isinstance(command, str): + self.base_command = command.split(" ") + else: + self.base_command = command + + self.mounts = self.build_mounts(mounts, host_mounts) + + def build_mounts(self, mounts, host_mounts): + mount_infos = set() + for mount_type, mount_info in mounts.items(): + read_only = mount_type == "input_volumes" and self.is_first + for var_name, mount_details in mount_info.items(): + host_path = host_mounts[var_name] + + if host_path is None: + raise MedperfException( + f"Could not find definition for mount {var_name} in the Airflow environment!" + ) + mount_path = mount_details["mount_path"] + + mount_infos.add( + MountInfo(source=host_path, target=mount_path, read_only=read_only) + ) + + if mount_details.get("type") == "directory": + os.makedirs(host_path, exist_ok=True) + elif mount_details.get("type") == "file" and not os.path.exists( + host_path + ): + open(host_path, "x").close() + container_mounts = [ + self._build_mount_item(mount_info) for mount_info in mount_infos + ] + return container_mounts + + @abstractmethod + def _build_mount_item(self, mount_info: MountInfo): + """Logic for building mounts in Docker or Singularity. 
Implemented in subclasses""" + pass + + def _get_command(self): + command = [*self.base_command] + if self.partition: + command = [*command, "--partition", self.partition] + + return command diff --git a/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/docker_operator_buider.py b/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/docker_operator_buider.py new file mode 100644 index 000000000..fdb84a8d8 --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/docker_operator_buider.py @@ -0,0 +1,42 @@ +from medperf.containers.runners.airflow_runner_utils.dags.operator_builders.container_operator_builder import ( + ContainerOperatorBuilder, + MountInfo, +) +from airflow.providers.docker.operators.docker import DockerOperator +from docker.types import Mount +import os + + +class DockerOperatorBuilder(ContainerOperatorBuilder): + + def _build_mount_item(self, mount_info: MountInfo): + return Mount( + source=mount_info.source, + target=mount_info.target, + type="bind", + read_only=mount_info.read_only, + ) + + def _define_base_operator(self) -> DockerOperator: + + command = self._get_command() + + # TODO when adding device requests, it should be similar to what is defined below + # from docker.types import eviceRequest + # device_request = DeviceRequest(device_ids=["0", "2"], capabilities=[["gpu"]]) + return DockerOperator( + image=self.image, + command=command, + mounts=self.mounts, + task_id=self.operator_id, + task_display_name=self.display_name, + auto_remove="success", + mount_tmp_dir=False, + outlets=self.outlets, + user=f"{os.getuid()}:{os.getgid()}", + # TODO add medperf arguments: shm_size, user, network, ports, entrypoint, gpus + shm_size=None, + network_mode=None, + port_bindings=None, + device_requests=None, # gpus + ) diff --git a/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/empty_operator_builder.py 
b/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/empty_operator_builder.py new file mode 100644 index 000000000..835ee7293 --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/empty_operator_builder.py @@ -0,0 +1,15 @@ +from medperf.containers.runners.airflow_runner_utils.dags.operator_builders.operator_builder import ( + OperatorBuilder, +) +from airflow.providers.standard.operators.empty import EmptyOperator + + +class EmptyOperatorBuilder(OperatorBuilder): + + def _define_base_operator(self): + + return EmptyOperator( + task_id=self.operator_id, + task_display_name=self.display_name, + outlets=self.outlets, + ) diff --git a/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/manual_approval_buider.py b/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/manual_approval_buider.py new file mode 100644 index 000000000..77203a54a --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/manual_approval_buider.py @@ -0,0 +1,15 @@ +from medperf.containers.runners.airflow_runner_utils.dags.operator_builders.operator_builder import ( + OperatorBuilder, +) +from airflow.providers.standard.operators.hitl import ApprovalOperator + + +class ManualApprovalBuilder(OperatorBuilder): + def _define_base_operator(self): + + return ApprovalOperator( + task_id=self.operator_id, + subject="Manual Approval Task", + fail_on_reject=True, + body="Please confirm all generated results before approving the workflow.", + ) diff --git a/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/operator_builder.py b/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/operator_builder.py new file mode 100644 index 000000000..fcdac4d6a --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/operator_builder.py @@ -0,0 +1,199 @@ +from __future__ import annotations 
+from airflow.sdk import Asset, BaseOperator +from abc import ABC, abstractmethod +from medperf.containers.runners.airflow_runner_utils.dags.constants import ( + ALWAYS_CONDITION, + FINAL_ASSET, +) + + +class OperatorBuilder(ABC): + def __init__( + self, + operator_id: str, + raw_id: str, + next_ids: list[str] | str = None, + limit: int = None, + from_yaml: bool = True, + make_outlet: bool = True, + on_error: str = None, + is_first: bool = False, + **kwargs, + ): + # TODO add logic to import on_error as a callable + # Always call this init during subclass inits + self.operator_id = operator_id + self.raw_id = raw_id + self.display_name = self._convert_id_to_display_name(self.raw_id) + self.is_first = is_first + if make_outlet: + self.next_ids = [] + self.outlets = self._make_outlets(next_ids) + else: + self.next_ids = next_ids or [] + self.outlets = None + + self.from_yaml = from_yaml + if limit is None: + self.pool_info = None + else: + self.pool_info = f"{self.raw_id}_pool" + + self.partition = kwargs.get("partition") + self.tags = [self.raw_id, self.display_name] + if self.partition: + self.tags.append(self.partition) + self.display_name += f" - {self.partition}" + + def __str__(self): + return f"{self.__class__.__name__}(operator_id={self.operator_id})" + + def __repr__(self): + return str(self) + + def __hash__(self): + return hash(self.operator_id) + + def _make_outlets(self, next_ids): + if next_ids is not None: + if isinstance(next_ids, str): + next_ids = [next_ids] + outlets = [Asset(next_id) for next_id in next_ids] + else: + outlets = [Asset(FINAL_ASSET)] + return outlets + + @staticmethod + def _convert_id_to_display_name(original_id): + return original_id.replace("_", " ").title() + + def get_airflow_operator(self) -> BaseOperator: + base_operator = self._define_base_operator() + if self.pool_info is not None: + base_operator.pool = self.pool_info + + if self.outlets: + base_operator.outlets = self.outlets + + return base_operator + + @abstractmethod + 
def _define_base_operator(self) -> BaseOperator: + """ + Returns the initial definition of the operator object, without defining pools or outlets. + These, if defined, are patched later in get_airflow_operator. + """ + pass + + @classmethod + def build_operator_list(cls, is_first: bool, **kwargs) -> list[OperatorBuilder]: + """ + Helper method to build a list of required Operators for a DAG Builder. + Usually will return a list with a single element that is the desired operator + If conditional next_ids are sent from the YAML file, then this will return a list including + a Python Sensor Operator and a Python Branching Operator, which are both used to deal with branching + """ + operator_list = [] + kwargs["operator_id"] = kwargs.pop("id", None) + + id_info = kwargs.pop("next", []) + make_outlet_for_main_operator = True + if isinstance(id_info, dict): + # If we have a branching condition in YAML, we return three operators: + # OperatorFromYAML -> PythonSensorOperator -> PythonBranchOperator -> EmptyOperator -> NextOperatorFromYAML + # OperatorFromYAML runs as defind by the YAML File. + # A PythonSensorOperator then waits for any of the defind conditions to be True and + # forwards the True condition to the PythonBranchOperator, which then branches accordingly. + # The Sensor and Branch Operators are defined here, so we can adapt the input arguments of the first operator + # accordingly (ie make it go into sensor that goes into branch which then goes into other operators from the + # YAML file). Empty operators are used between the branch operator and next operator from YAML to simplify + # breaking DAG cycles, if any are present. If DAG cycles are not present, the Empty operators do not + # interfere with DAG execution. 
+ from .branch_from_sensor_operator_builder import ( + BranchFromSensorOperatorBuilder, + ) + from .python_sensor_builder import PythonSensorBuilder + from .empty_operator_builder import EmptyOperatorBuilder + + make_outlet_for_main_operator = False + conditions_definitions = kwargs.pop( + "conditions_definitions", [] + ) # [{'id': 'condition_1', 'type': 'function', 'function_name': 'function_name'}...] + conditions_definitions = { + condition["id"]: { + key: value for key, value in condition.items() if key != "id" + } + for condition in conditions_definitions + } # {'condition_1: {'type': 'function', 'function_name': 'function_name'}, ...} + + branching_info: list[dict[str, str]] = id_info.pop("if") + partition = kwargs.get("partition") + operator_id = kwargs["operator_id"] + operator_raw_id = kwargs["raw_id"] + sensor_id = f"conditions_{operator_id}" + branching_id = f"branch_{operator_id}" + wait_time = id_info.pop("wait", None) + default_conditions = id_info.pop("else", []) + kwargs["next_ids"] = [sensor_id] + + conditions = branching_info + for default_condition in default_conditions: + if default_condition and default_condition != kwargs["operator_id"]: + conditions.append( + {"condition": ALWAYS_CONDITION, "target": default_condition} + ) + processed_conditions = [] + for condition in conditions: + temp_conditions = [ + { + "condition": condition["condition"], + "target": f"empty_{operator_id}_{target}", + } + for target in condition["target"] + ] + processed_conditions.extend(temp_conditions) + + empty_ids = [condition["target"] for condition in processed_conditions] + ids_after_empty = [condition["target"] for condition in conditions] + + sensor_operator = PythonSensorBuilder( + conditions=processed_conditions, + raw_id=f"conditions_{operator_raw_id}", + wait_time=wait_time, + operator_id=sensor_id, + next_ids=[branching_id], + conditions_definitions=conditions_definitions, + from_yaml=False, + make_outlet=False, + partition=partition, + ) + + 
empty_operators = [ + EmptyOperatorBuilder( + operator_id=empty_id, + raw_id=empty_id, + from_yaml=False, + next_ids=next_id, + partition=partition, + make_outlet=True, + ) + for empty_id, next_id in zip(empty_ids, ids_after_empty) + ] + + branch_operator = BranchFromSensorOperatorBuilder( + next_ids=[empty_id for empty_id in empty_ids], + previous_sensor=sensor_operator, + operator_id=branching_id, + raw_id=f"branch_{operator_raw_id}", + from_yaml=False, + make_outlet=False, + ) + operator_list.extend([sensor_operator, branch_operator, *empty_operators]) + else: + kwargs["next_ids"] = id_info + + this_operator = cls( + **kwargs, is_first=is_first, make_outlet=make_outlet_for_main_operator + ) + operator_list = [this_operator, *operator_list] + return operator_list diff --git a/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/python_sensor_builder.py b/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/python_sensor_builder.py new file mode 100644 index 000000000..41d5208c5 --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/python_sensor_builder.py @@ -0,0 +1,102 @@ +from __future__ import annotations +from medperf.containers.runners.airflow_runner_utils.dags.operator_builders.operator_builder import ( + OperatorBuilder, +) +from airflow.decorators import task +from airflow.sdk import PokeReturnValue +from medperf.containers.runners.airflow_runner_utils.dags.pipeline_state import ( + PipelineState, +) +from medperf.containers.runners.airflow_runner_utils.dags.constants import ( + ALWAYS_CONDITION, +) +from datetime import timedelta +from medperf.containers.runners.airflow_runner_utils.dags.dag_utils import ( + import_external_python_function, +) + +DEFAULT_WAIT_TIME = timedelta(seconds=60) + + +class Condition: + + def __init__( + self, + condition_id: str, + next_id: str, + conditions_definitions: dict[str, dict[str, str]], + ): + self.condition_id = condition_id + 
self.next_id = next_id + + if self.condition_id == ALWAYS_CONDITION: + self.type = ALWAYS_CONDITION + self.complete_function_name = None + + else: + this_definition = conditions_definitions[self.condition_id] + self.type = this_definition["type"] # Currently unused + self.complete_function_name = this_definition["function_name"] + + +def evaluate_external_condition(condition: Condition, pipeline_state: PipelineState): + if condition.condition_id == ALWAYS_CONDITION: + return True + + condition_function_obj = import_external_python_function( + condition.complete_function_name + ) + print(f"Checking condition {condition.condition_id}...") + condition_result = condition_function_obj(pipeline_state) + + if condition_result: + print(f"Condition {condition.condition_id} met!") + else: + print(f"Condition {condition.condition_id} not met.") + return condition_result + + +class PythonSensorBuilder(OperatorBuilder): + """ + Sensors are used together with BranchOperators to automatically create branching behavior. + Once any condition in the sensor is met, the ID of the corresponding task to that condition is pushed + as an Airflow XCom. The BranchOperator then reads this XCom and branches accordingly. 
+ """ + + def __init__( + self, + conditions: list[dict[str, str]], + conditions_definitions: dict[str, dict[str, str]], + wait_time: float = 60, + **kwargs, + ): + super().__init__(**kwargs) + self.conditions = [ + Condition( + condition_id=condition["condition"], + next_id=condition["target"], + conditions_definitions=conditions_definitions, + ) + for condition in conditions + ] + self.wait_time = wait_time or DEFAULT_WAIT_TIME + + def _define_base_operator(self): + + @task.sensor( + poke_interval=self.wait_time, + mode="reschedule", + task_id=self.operator_id, + task_display_name=self.display_name, + outlets=self.outlets, + ) + def wait_for_conditions(**kwargs): + pipeline_state = PipelineState(running_subject=self.partition, **kwargs) + + for condition in self.conditions: + if evaluate_external_condition(condition, pipeline_state): + return PokeReturnValue(is_done=True, xcom_value=condition.next_id) + + return False + + return wait_for_conditions() diff --git a/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/singularity_operator_builder.py b/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/singularity_operator_builder.py new file mode 100644 index 000000000..a4936e335 --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_builders/singularity_operator_builder.py @@ -0,0 +1,30 @@ +from medperf.containers.runners.airflow_runner_utils.dags.operator_builders.container_operator_builder import ( + ContainerOperatorBuilder, + MountInfo, +) +from airflow.providers.singularity.operators.singularity import SingularityOperator + + +class SingularityOperatorBuilder(ContainerOperatorBuilder): + """ + Currently untested!! 
+ """ + + def _build_mount_item( + self, host_path, mount_path, read_only, mount_info: MountInfo + ): + mount_suffix = "ro" if mount_info.read_only else "rw" + mount_str = f"{mount_info.source}:{mount_info.target}:{mount_suffix}" + return mount_str + + def _define_base_operator(self) -> SingularityOperator: + command = self._get_command() + return SingularityOperator( + image=self.image, + command=command, + volumes=self.mounts, + task_id=self.operator_id, + task_display_name=self.display_name, + auto_remove=True, + outlets=self.outlets, + ) diff --git a/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_factory.py b/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_factory.py new file mode 100644 index 000000000..6c4b05d96 --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/dags/operator_factory.py @@ -0,0 +1,31 @@ +from __future__ import annotations +from .operator_builders.singularity_operator_builder import SingularityOperatorBuilder +from .operator_builders.docker_operator_buider import DockerOperatorBuilder +from .operator_builders.empty_operator_builder import EmptyOperatorBuilder +from .operator_builders.manual_approval_buider import ManualApprovalBuilder +from .operator_builders.operator_builder import OperatorBuilder +import os + +container_builder = ( + SingularityOperatorBuilder + if os.getenv("CONTAINER_TYPE") == "singularity" + else DockerOperatorBuilder +) + +OPERATOR_MAPPING: dict[str, OperatorBuilder] = { + "container": container_builder, + "dummy": EmptyOperatorBuilder, + "manual_approval": ManualApprovalBuilder, +} + + +def operator_factory(type, is_first: bool, **kwargs) -> list[OperatorBuilder]: + + return_list = [] + try: + operator_obj = OPERATOR_MAPPING[type] + except KeyError: + raise TypeError(f"Tasks of type {type} are not supported!") + + return_list = operator_obj.build_operator_list(is_first=is_first, **kwargs) + return return_list diff --git 
a/cli/medperf/containers/runners/airflow_runner_utils/dags/pipeline_state.py b/cli/medperf/containers/runners/airflow_runner_utils/dags/pipeline_state.py new file mode 100644 index 000000000..d3846caf6 --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/dags/pipeline_state.py @@ -0,0 +1,13 @@ +import os + + +class PipelineState: + # TODO properly define + + def __init__(self, running_subject: str = None, **airflow_kwargs): + self.running_subject = running_subject + self.airflow_kwargs = airflow_kwargs + self.host_input_data_path = os.getenv("host_data_path") + self.host_output_data_path = os.getenv("host_output_path") + self.host_labels_path = os.getenv("host_labels_path") + self.host_output_labels_path = os.getenv("host_output_labels_path") diff --git a/cli/medperf/containers/runners/airflow_runner_utils/dags/yaml_parser.py b/cli/medperf/containers/runners/airflow_runner_utils/dags/yaml_parser.py new file mode 100644 index 000000000..251341805 --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/dags/yaml_parser.py @@ -0,0 +1,366 @@ +from medperf.containers.runners.airflow_runner_utils.dags.constants import ( + WORKFLOW_YAML_FILE, +) +from typing import Union, Any +from medperf.containers.runners.airflow_runner_utils.dags.dag_utils import ( + import_external_python_function, +) +from collections import defaultdict +from airflow.sdk import DAG +from medperf.containers.runners.airflow_runner_utils.dags.dag_utils import ( + create_legal_dag_id, +) +from medperf.containers.runners.airflow_runner_utils.dags.dag_builder import DagBuilder +from copy import deepcopy +from medperf.enums import ContainerConfigMountKeys +from medperf.exceptions import MedperfException +import os +import yaml +from medperf.containers.runners.airflow_runner_utils.dags.pipeline_state import ( + PipelineState, +) + +valid_mount_keys = [item.value for item in ContainerConfigMountKeys] + + +def get_dict_value_by_key_prefix( + key_prefix: str, dictionary: 
dict, key_suffix: str = None
+):
+    """
+    This assumes the values are effectively the same for the purposes this function is called.
+    If they are different, then calling by suffix does not make sense!
+    """
+    if key_suffix is not None:
+        complete_key = create_legal_dag_id(f"{key_prefix}_{key_suffix}")
+        if complete_key in dictionary:
+            return dictionary[complete_key]
+
+    relevant_dict = {
+        key: value for key, value in dictionary.items() if key.startswith(key_prefix)
+    }
+    return list(relevant_dict.values())[0]
+
+
+class YamlParser:
+
+    def __init__(self, yaml_file: str = None):
+        self.yaml_file = yaml_file or WORKFLOW_YAML_FILE
+        yaml_content = self.read_yaml_definition()
+        self.raw_steps = yaml_content["steps"]
+        self._raw_conditions = yaml_content.get("conditions", [])
+        self._raw_subject_definitions = yaml_content.get("partition_def", {})
+        self.dag_builders = None
+
+    def read_yaml_definition(
+        self,
+    ) -> dict[str, Union[list[dict[str, str]], dict[str, str]]]:
+        try:
+            with open(self.yaml_file, "r") as f:
+                raw_content = f.read()
+                yaml_info = yaml.safe_load(raw_content)
+        except Exception:
+            raise MedperfException(f"Unable to load Workflow YAML file {self.yaml_file}!")
+
+        return yaml_info
+
+    def read_subject_partitions(self):
+        if not self._raw_subject_definitions:
+            return []
+
+        partition_function_name = self._raw_subject_definitions["function_name"]
+        partition_function_obj = import_external_python_function(
+            partition_function_name
+        )
+        subject_partition_list = partition_function_obj(PipelineState())
+        return subject_partition_list
+
+    def _get_next_id_from_expanded_step(self, raw_step):
+        next_field = raw_step.get("next")
+        if next_field is None:
+            return []
+        elif isinstance(next_field, str):
+            return [next_field]
+        elif isinstance(next_field, list):
+            return next_field
+        else:
+            if_fields = next_field.get("if", [])
+            default_fields = next_field.get("else", None)
+            next_fields = [
+                target for if_field in if_fields for target in if_field["target"]
+            ]
+ if default_fields: + next_fields.extend(default_fields) + return next_fields + + def _update_next_id_in_expanded_step( # noqa: C901 + self, current_step, id_to_partition_to_partition_id + ): + next_field = deepcopy(current_step.get("next")) + this_partition = current_step["partition"] + + def get_updated_ids(this_partition, partition_to_partition_id): + if this_partition is None: + # This step is not partitioned, but leads to a partition -> use all as next + updated_ids = list(partition_to_partition_id.values()) + else: + # This step is also partitioned. Pick corresponding partition + updated_ids = [partition_to_partition_id[this_partition]] + return updated_ids + + def update_next_ids(original_next_ids): + new_ids = [] + for next_id in original_next_ids: + partition_to_partition_id = id_to_partition_to_partition_id.get(next_id) + if not partition_to_partition_id: + new_ids.append(next_id) + continue + + updated_ids = get_updated_ids( + this_partition=this_partition, + partition_to_partition_id=partition_to_partition_id, + ) + new_ids.extend(updated_ids) + return new_ids + + if not next_field: + return + elif isinstance(next_field, str): + next_field = [next_field] + + if isinstance(next_field, list): + updated_next = update_next_ids(next_field) + next_field = updated_next + else: + if_fields = next_field.get("if", []) + for if_field in if_fields: + next_ids = if_field["target"] + + if isinstance(next_ids, str): + next_ids = [next_ids] + updated_ids = update_next_ids(next_ids) + if_field["target"] = updated_ids + + default_field = next_field.get("else") + if default_field: + if isinstance(default_field, str): + default_field = [default_field] + new_default = update_next_ids(default_field) + next_field["else"] = new_default + + return next_field + + def _verify_unique_id(self, potential_id, original_id, mapped_steps): + if potential_id in mapped_steps: + raise ValueError(f"ID {original_id} has been used more than one time!") + + def _create_expanded_steps( + self, 
raw_steps: list[dict[str, Any]], subject_partitions: list[str] + ): + step_id_to_expanded_step = {} + original_id_to_partition_to_partitioned_id = defaultdict(dict) + next_id_to_upstream_ids = defaultdict(set) + + self._expanded_steps_first_pass( + raw_steps=raw_steps, + subject_partitions=subject_partitions, + step_id_to_expanded_step=step_id_to_expanded_step, + original_id_to_partition_to_partitioned_id=original_id_to_partition_to_partitioned_id, + ) + + for step_id, step in step_id_to_expanded_step.items(): + new_next = self._update_next_id_in_expanded_step( + step, original_id_to_partition_to_partitioned_id + ) + step["next"] = new_next + + for step_id, step in step_id_to_expanded_step.items(): + next_ids = self._get_next_id_from_expanded_step(step) + for next_id in next_ids: + next_id_to_upstream_ids[next_id].add(step_id) + + self._make_inlets_for_expanded_steps( + step_id_to_expanded_step=step_id_to_expanded_step, + next_id_to_upstream_ids=next_id_to_upstream_ids, + ) + self._make_host_mounts(step_id_to_expanded_step) + expanded_steps = list(step_id_to_expanded_step.values()) + + return expanded_steps + + def _expanded_steps_first_pass( + self, + raw_steps, + subject_partitions, + step_id_to_expanded_step, + original_id_to_partition_to_partitioned_id, + ): + for step in raw_steps: + original_id = step["id"] + step["conditions_definitions"] = self._raw_conditions + if step.get("partition"): + for subject_partition in subject_partitions: + partitioned_step = {k: v for k, v in step.items()} + partitioned_id = create_legal_dag_id( + f"{original_id}_{subject_partition}" + ) + self._verify_unique_id( + potential_id=partitioned_id, + original_id=original_id, + mapped_steps=step_id_to_expanded_step, + ) + + partitioned_step["id"] = partitioned_id + partitioned_step["raw_id"] = original_id + partitioned_step["partition"] = subject_partition + step_id_to_expanded_step[partitioned_step["id"]] = partitioned_step + original_id_to_partition_to_partitioned_id[original_id][ 
+ subject_partition + ] = partitioned_id + else: + step["partition"] = None + step_id = create_legal_dag_id(original_id) + step["id"] = step_id + step["raw_id"] = original_id + self._verify_unique_id( + potential_id=step_id, + original_id=original_id, + mapped_steps=step_id_to_expanded_step, + ) + step_id_to_expanded_step[step_id] = step + + def _make_inlets_for_expanded_steps( + self, step_id_to_expanded_step, next_id_to_upstream_ids + ): + for step_id, step in step_id_to_expanded_step.items(): + upstream_ids = list(next_id_to_upstream_ids[step_id]) + if upstream_ids: + this_step_inlets = [] + if step["partition"] is None: + for upstream_id in upstream_ids: + upstream_dict = step_id_to_expanded_step[upstream_id] + upstream_partition = upstream_dict["partition"] + if upstream_partition: + new_inlet = create_legal_dag_id( + f"{step_id}_{upstream_partition}" + ) + this_step_inlets.append(new_inlet) + self._update_next_with_new_partition( + upstream_dict, step_id, new_inlet + ) + if not this_step_inlets: + this_step_inlets = [step_id] + step["inlets"] = this_step_inlets + else: + step["inlets"] = [] + + def _make_host_mounts(self, step_id_to_expanded_step: dict): + + look_on_second_pass = set() + + for step_id, step in step_id_to_expanded_step.items(): + host_mounts = {} + self._host_mounts_first_pass( + step=step, + look_on_second_pass=look_on_second_pass, + host_mounts=host_mounts, + volume_key="input_volumes", + ) + self._host_mounts_first_pass( + step=step, + look_on_second_pass=look_on_second_pass, + host_mounts=host_mounts, + volume_key="output_volumes", + ) + step["host_mounts"] = host_mounts + + for step_id in look_on_second_pass: + step = step_id_to_expanded_step[step_id] + self._host_mounts_second_pass( + step=step, + step_id_to_expanded_step=step_id_to_expanded_step, + volume_key="input_volumes", + ) + + @staticmethod + def _host_mounts_first_pass( + step: dict, + look_on_second_pass: set, + host_mounts: dict, + volume_key: str, + ): + step_mounts = 
step.get("mounts")
+        if step_mounts is None:
+            return
+
+        volumes = step_mounts.get(volume_key, {})
+        for volume_name, volume_data in volumes.items():
+            from_step = volume_data.get("from")
+            if from_step is not None:
+                look_on_second_pass.add(step["id"])
+            elif volume_name in valid_mount_keys:
+                host_mounts[volume_name] = os.getenv(f"host_{volume_name}")
+            else:
+                raise MedperfException(
+                    f'Invalid mount {volume_name} in step {step["id"]}'
+                )
+
+    @staticmethod
+    def _host_mounts_second_pass(
+        step: dict,
+        step_id_to_expanded_step: dict,
+        volume_key: str,
+    ):
+        step_mounts = step.get("mounts")
+        if step_mounts is None:
+            return
+
+        volumes = step_mounts.get(volume_key, {})
+        for volume_name, volume_data in volumes.items():
+            input_step_info = volume_data.get("from")
+
+            if input_step_info is None:
+                continue  # Done in first pass
+
+            input_step = get_dict_value_by_key_prefix(
+                key_prefix=input_step_info["step"],
+                dictionary=step_id_to_expanded_step,
+                key_suffix=step.get("partition"),
+            )
+            output_key = input_step_info["mount"]
+            step["host_mounts"][volume_name] = input_step["host_mounts"][output_key]
+
+    def _update_next_with_new_partition(self, original_dict, original_next, new_next):
+        next_field = original_dict["next"]
+        if isinstance(next_field, list):
+            next_field = [
+                item if item != original_next else new_next for item in next_field
+            ]
+        else:
+            if_fields = next_field["if"]
+            next_field["if"] = [
+                item if item != original_next else new_next for item in if_fields
+            ]
+            default_fields = next_field.get("else", [])
+            next_field["else"] = [
+                item if item != original_next else new_next for item in default_fields
+            ]
+        original_dict["next"] = next_field
+
+    def map_dag_builders_from_yaml(self) -> list[DagBuilder]:
+
+        subject_partitions = self.read_subject_partitions()
+        expanded_steps = self._create_expanded_steps(
+            self.raw_steps, subject_partitions=subject_partitions
+        )
+        dag_builder_list = [
+            DagBuilder(expanded_step=expanded_step) for expanded_step in
expanded_steps + ] + + return dag_builder_list + + def build_dags(self) -> list[DAG]: + if self.dag_builders is None: + self.dag_builders = self.map_dag_builders_from_yaml() + + dags_list = [builder.build_dag() for builder in self.dag_builders] + return dags_list diff --git a/cli/medperf/containers/runners/airflow_runner_utils/plugins/auto_login.py b/cli/medperf/containers/runners/airflow_runner_utils/plugins/auto_login.py new file mode 100644 index 000000000..dbba37215 --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/plugins/auto_login.py @@ -0,0 +1,49 @@ +from airflow.plugins_manager import AirflowPlugin +from fastapi import FastAPI, Query, Request, status +from fastapi.responses import RedirectResponse +from airflow.api_fastapi.auth.managers.base_auth_manager import COOKIE_NAME_JWT_TOKEN +from airflow.configuration import conf +from medperf.containers.runners.airflow_runner_utils.airflow_api_client import ( + get_token_async, +) +from pydantic import SecretStr + +app = FastAPI() + + +@app.get("/auto_login", status_code=status.HTTP_302_FOUND) +async def root( + request: Request, + username: str = Query(...), + password: str = Query(...), +) -> RedirectResponse: + host = conf.get("api", "host") + port = conf.get("api", "port") + homepage_url = f"http://{host}:{port}" + + # Needs to be async to not lock itself up. + # Seems like the plugin also runs in the API server that takes the request. 
+ token = await get_token_async( + base_url=homepage_url, username=username, password=SecretStr(password) + ) + response = RedirectResponse(url=homepage_url, status_code=status.HTTP_302_FOUND) + + response.set_cookie( + key=COOKIE_NAME_JWT_TOKEN, + value=token, + path="/", + ) + + return response + + +app_with_metadata = { + "app": app, + "url_prefix": "/medperf", + "name": "Auto Login for MedPerf", +} + + +class AutoLoginPlugin(AirflowPlugin): + name = "auto_login" + fastapi_apps = [app_with_metadata] diff --git a/cli/medperf/containers/runners/airflow_runner_utils/system_runner.py b/cli/medperf/containers/runners/airflow_runner_utils/system_runner.py new file mode 100644 index 000000000..08defc4da --- /dev/null +++ b/cli/medperf/containers/runners/airflow_runner_utils/system_runner.py @@ -0,0 +1,430 @@ +from __future__ import annotations +import subprocess +import os +from .airflow_api_client import AirflowAPIClient +from .components.api_server import AirflowApiServer +from .components.airflow_component import AirflowComponentRunner +from .components.dag_processor import AirflowDagProcessor +from .components.db_postgres_docker import PostgresDBDocker +from .components.db_postgres_singularity import PostgresDBSingularity +from .components.scheduler import AirflowScheduler +from .components.triggerer import AirflowTriggerer +from .components.utils import validate_port +from .airflow_monitor import Summarizer +from airflow.utils.state import DagRunState +import configparser +from typing import Union, List +import secrets +import urllib.parse +from pydantic import SecretStr +import json +import logging +import asyncio +from medperf.containers.runners.airflow_runner_utils.dags.constants import ( + FINAL_ASSET, +) +from medperf.containers.parsers.airflow_parser import AirflowParser +from medperf import config +import sys +from medperf.utils import parse_datetime + + +class AirflowSystemRunner: + def __init__( + self, + airflow_home: os.PathLike, + user: str, + email: 
str, + dags_folder: os.PathLike, + plugins_folder: os.PathLike, + mounts: dict[str, os.PathLike], + additional_files_dir: os.PathLike, + project_name: str, + yaml_parser: AirflowParser, + port: Union[str, int] = 8080, + postgres_port: Union[ + str, int + ] = 5423, # Change default postgres port to avoid conflicts + airflow_python_executable: os.PathLike = None, + **extra_airflow_configs: dict, + ): + self.airflow_home = airflow_home + self._python_exec = airflow_python_executable or sys.executable + # TODO windows + self.port = validate_port(port) + self.dags_folder = dags_folder + self.plugins_folder = plugins_folder + self.extra_configs = extra_airflow_configs + self.mounts = mounts + self.yaml_parser = yaml_parser + self.additional_files_dir = additional_files_dir + self.user = user + self.email = email + self._password = SecretStr(secrets.token_urlsafe(16)) + self.airflow_config_file = os.path.join(self.airflow_home, "airflow.cfg") + self.resuming_from_previous_execution = False + self.project_name = project_name + self._postgres_password = SecretStr(secrets.token_urlsafe(16)) + self._postgres_user = self._postgres_db = "airflow" + self.postgres_port = postgres_port + self.scheduler = self.api_server = self.dag_processor = self.db = ( + self.triggerer + ) = None + self.host = "localhost" + + @property + def _complete_link(self): + return f"http://{self.host}:{self.port}" + + @property + def _airflow_components(self) -> List[AirflowComponentRunner]: + return [ + self.scheduler, + self.api_server, + self.dag_processor, + self.db, + self.triggerer, + ] + + def _initial_setup(self): + logging.debug("Creating initial Airflow configuration") + config_create_process = subprocess.run( + [self._python_exec, "-m", "airflow", "config", "list"], + capture_output=True, + env=self._run_env, + ) + airflow_config = configparser.ConfigParser() + logging.debug( + f"Airflow process creation stdout:\n{config_create_process.stdout}" + ) + with open(self.airflow_config_file, "r") as 
f: + airflow_config.read_file(f) + airflow_config["core"].update( + { + "dags_folder": self.dags_folder, + "plugins_folder": self.plugins_folder, + "executor": "LocalExecutor", + "auth_manager": "airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager", + "load_examples": "false", + } + ) + airflow_config["database"].update( + {"sql_alchemy_conn": self.db.connection_string} + ) + airflow_config["scheduler"].update({"enable_health_check": "true"}) + airflow_config["api"].update( + { + "host": self.host, + "port": self.port, + "instance_name": f"MedPerf Workflow - {self.project_name}", + } + ) + logging.debug(f"Saving Airflow configuration to {self.airflow_config_file}") + with open(self.airflow_config_file, "w") as f: + airflow_config.write(f) + + def init_airflow(self, force_venv_creation: bool = False): + os.makedirs(os.path.join(self.airflow_home, "logs"), exist_ok=True) + self._initialize_components() + + config.ui.print("Starting Airflow components") + asyncio.run(self.db.start()) + if not os.path.exists(self.airflow_config_file): + self._initial_setup() + else: + self.resuming_from_previous_execution = True # Has old config + + self._start_airflow_db() + self._create_admin_user() + self._create_pools() + asyncio.run(self._start_non_db_components()) + config.ui.print("Airflow components successfully started") + + @property + def _run_env(self): + sys_env = os.environ.copy() + sys_env["AIRFLOW_HOME"] = self.airflow_home + return sys_env + + def _initialize_components(self): + common_args = { + "python_executable": self._python_exec, + "airflow_home": self.airflow_home, + "container_type": config.platform, + "additional_files_dir": self.additional_files_dir, + "workflow_yaml_file": self.yaml_parser.config_file_path, + "dags_folder": self.dags_folder, + } + + if config.platform == "singularity": + self.db = PostgresDBSingularity( + project_name=self.project_name, + root_dir=self.airflow_home, + postgres_user="airflow", + postgres_db="airflow", + 
postgres_port=self.postgres_port, + hostname=self.host, + ) + else: # Default to docker + self.db = PostgresDBDocker( + project_name=self.project_name, + root_dir=self.airflow_home, + postgres_user="airflow", + postgres_db="airflow", + postgres_port=self.postgres_port, + hostname=self.host, + ) + self.api_server = AirflowApiServer(**common_args, port=self.port) + self.scheduler = AirflowScheduler(**common_args, mounts=self.mounts) + self.dag_processor = AirflowDagProcessor(**common_args, mounts=self.mounts) + self.triggerer = AirflowTriggerer(**common_args) + + def _start_airflow_db(self): + logging.debug("Migrating Airflow DB") + init_db_logs = os.path.join(self.airflow_home, "logs", "init_db.log") + with open(init_db_logs, "a") as f: + db_migrate = subprocess.run( + [self._python_exec, "-m", "airflow", "db", "migrate"], + stdout=f, + stderr=f, + env=self._run_env, + ) + + if db_migrate.returncode != 0: + raise ValueError("Failed Airflow Database migration") + + async def _start_non_db_components(self): + await asyncio.gather( + self.api_server.start(), + self.scheduler.start(), + self.dag_processor.start(), + self.triggerer.start(), + ) + + def _create_admin_user(self): + """First attempts to delete user by email, so that it can be generated again with a new password""" + admin_email = "admin@admin.com" + + logging.debug("Deleting existing admin user, if any") + delete_previous_user = subprocess.run( + [ + self._python_exec, + "-m", + "airflow", + "users", + "delete", + "--email", + admin_email, + ], + capture_output=True, + text=True, + env=self._run_env, + ) + logging.debug(f"User deletion stdout:\n{delete_previous_user.stdout}") + logging.debug("Creating new admin user") + create_new_user = subprocess.run( + [ + self._python_exec, + "-m", + "airflow", + "users", + "create", + "--username", + self.user, + "--role", + "Admin", + "--password", + self._password.get_secret_value(), + "--firstname", + "admin", + "--lastname", + "admin", + "--email", + self.email, + 
], + capture_output=True, + text=True, + env=self._run_env, + ) + logging.debug(f"Admin creation stdout:\n{create_new_user.stdout}") + + def _create_pools(self): + logging.debug("Creating pools") + if self.yaml_parser.pools is None: + logging.debug("No pool definitions found") + return + + pools_json = os.path.join(self.airflow_home, "pools.json") + + logging.debug(f"Writing pools to {pools_json}") + with open(pools_json, "w") as f: + json.dump(self.yaml_parser.pools, f, indent=4) + + logging.debug("Running pool creation command in Airflow") + create_pools = subprocess.run( + [self._python_exec, "-m", "airflow", "pools", "import", pools_json], + capture_output=True, + text=True, + env=self._run_env, + ) + + if create_pools.returncode != 0: + raise ValueError( + f"Error when attempting to create pools:\n{create_pools.stderr}" + ) + + def wait_for_dag(self): + asyncio.run(self._async_wait_for_dag()) + + def _print_login_link(self, airflow_client: AirflowAPIClient): + + token = airflow_client.get_token() + # token = urllib.parse.urlencode({"token": token}) + token = urllib.parse.urlencode( + {"username": self.user, "password": self._password.get_secret_value()} + ) + full_link = SecretStr(f"{self._complete_link}/medperf/auto_login?{token}") + + msg = [ + f"MedPerf has started executing the Data Pipeline {self.project_name} via Airflow." + ] + msg.append("Execution will continue until the pipeline successfully completes.") + msg.append("Please use the following link to access the Airflow WebUI:\n") + msg.append(full_link.get_secret_value()) + msg.append( + "\nIf this process must be stopped prematurely, please use the Ctrl+C command!" 
+ ) + + for line in msg: + config.ui.print(line) + + async def _async_wait_for_dag(self): + + wait_interval = 10 # seconds + + api_url = f"{self._complete_link}/api/v2" + with AirflowAPIClient( + username=self.user, password=self._password, api_url=api_url + ) as airflow_client: + + self._print_login_link(airflow_client) + self._check_resuming_from_previous_execution(airflow_client) + try: + summarizer = Summarizer( + yaml_parser=self.yaml_parser, + report_file=self.mounts.get("report_file"), + ) + summarizer_task = asyncio.create_task( + summarizer.summarize_every_x_seconds( + interval_seconds=60, airflow_client=airflow_client + ) + ) + while True: + await asyncio.sleep(wait_interval) + if self._check_completed_asset(airflow_client): + break + + except KeyboardInterrupt: + config.ui.print("Interrupting Airflow Execution. Please wait...") + raise + + finally: + summarizer_task.cancel() + summarizer.summarize(airflow_client) + + config.ui.print( + "Pipeline Execution finished. Airflow will now be closed and MedPerf will proceed." + ) + + def _check_resuming_from_previous_execution(self, airflow_client: AirflowAPIClient): + if not self.resuming_from_previous_execution: + # Nothing to do + return + + try: + asset_events = airflow_client.assets.get_asset_events()["asset_events"] + except json.JSONDecodeError: + config.ui.print( + "Could not verify outstanding tasks from previous execution. " + "Please use the Airflow WebUI on next start up to initially resume outstanding tasks." 
+ ) + return + + asset_events = sorted( + asset_events, + key=lambda event: parse_datetime(event["timestamp"]), + reverse=True, + ) + restarted_tasks = [] + seen_asset_uris = set() + for event in asset_events: + if event["uri"] in seen_asset_uris: + continue + seen_asset_uris.add(event["uri"]) + + created_dagruns = event["created_dagruns"] + + try: + last_dagrun = created_dagruns[-1] + except IndexError: + # If no runs started, create fake state to ensure we restart it + last_dagrun = {"state": DagRunState.FAILED} + event["created_dagruns"] = [last_dagrun] + + if last_dagrun["state"] in [DagRunState.SUCCESS, DagRunState.RUNNING]: + continue # Successfully completed or running, nothing do to + + airflow_client.assets.create_asset_event(event["asset_id"]) + restarted_tasks.append(event["name"]) + + if restarted_tasks: + formatted_asset_names = [ + f"- {asset_name}" for asset_name in restarted_tasks + ] + formatted_asset_names_str = "\n".join(formatted_asset_names) + + config.ui.print( + f"Automatically restarting the following uncompleted tasks from last execution:\n{formatted_asset_names_str}" + ) + + def _check_completed_asset(self, airflow_client: AirflowAPIClient) -> bool: + """Checks if the final asset that marks pipeline completion has been updated""" + try: + asset_events = airflow_client.assets.get_asset_events()["asset_events"] + except json.JSONDecodeError: + config.ui.print( + "Error checking completion of the pipeline. Please use the Airflow WebUI to verify execution status." 
+ ) + asset_events = [] + + if not asset_events: + return False + + final_asset = [event for event in asset_events if event["uri"] == FINAL_ASSET] + return bool(final_asset) + + def _stop_airflow(self): + logging.debug("Stopping Airflow execution") + for component in self._airflow_components: + if component is not None and component.process: + logging.debug(f"Stopping component {component.component_name}") + component.terminate() + + def _kill_airflow(self): + logging.debug("Forcefully terminating Airflow execution") + for component in self._airflow_components: + if component is not None: + logging.debug(f"Killing component {component.component_name}") + component.kill() + + def __enter__(self): + logging.debug("Entering Airflow context manager") + return self + + def __exit__(self, exc_type, exc_value, traceback): + logging.debug("Exiting Airflow context manager") + if exc_type is None: + self._stop_airflow() + + else: + self._kill_airflow() diff --git a/cli/medperf/containers/runners/docker_runner.py b/cli/medperf/containers/runners/docker_runner.py index dffe1c32c..84ff9ea82 100644 --- a/cli/medperf/containers/runners/docker_runner.py +++ b/cli/medperf/containers/runners/docker_runner.py @@ -1,4 +1,3 @@ -from medperf.comms.entity_resources import resources from medperf.exceptions import MedperfException from .utils import ( check_allowed_run_args, @@ -7,16 +6,16 @@ add_medperf_environment_variables, add_network_config, add_medperf_tmp_folder, - check_docker_image_hash, + download_image_file, ) from .runner import Runner import logging from .docker_utils import ( craft_docker_run_command, - get_docker_image_hash, get_repo_tags_from_archive, load_image, delete_images, + download_docker_image, ) from medperf.encryption import decrypt_gpg_file, check_gpg from medperf.utils import remove_path, run_command, tmp_path_for_file_decryption @@ -31,7 +30,7 @@ def __init__(self, container_config_parser): def download( self, - expected_image_hash, + expected_image_hash: str, 
download_timeout: int = None, get_hash_timeout: int = None, ): @@ -39,8 +38,6 @@ def download( logging.debug("Downloading Docker archive") return self._download_docker_archive( expected_image_hash, - download_timeout, - get_hash_timeout, ) else: logging.debug("Downloading Docker image") @@ -53,25 +50,24 @@ def _download_docker_image( expected_image_hash, download_timeout: int = None, get_hash_timeout: int = None, - ): + ) -> str: docker_image = self.parser.get_setup_args() - command = ["docker", "pull", docker_image] - logging.debug("Running pull command") - run_command(command, timeout=download_timeout) - computed_image_hash = get_docker_image_hash(docker_image, get_hash_timeout) - check_docker_image_hash(computed_image_hash, expected_image_hash) + computed_image_hash = download_docker_image( + docker_image=docker_image, + expected_image_hash=expected_image_hash, + download_timeout=download_timeout, + get_hash_timeout=get_hash_timeout, + ) return computed_image_hash def _download_docker_archive( self, expected_image_hash, - download_timeout: int = None, - get_hash_timeout: int = None, ): file_url = self.parser.get_setup_args() - image_path, computed_image_hash = resources.get_cube_image( - file_url, expected_image_hash - ) # Hash checking happens in resources + image_path, computed_image_hash = download_image_file( + image_url=file_url, expected_image_hash=expected_image_hash + ) self.image_archive_path = image_path return computed_image_hash diff --git a/cli/medperf/containers/runners/docker_utils.py b/cli/medperf/containers/runners/docker_utils.py index 38367e217..0fd758f9a 100644 --- a/cli/medperf/containers/runners/docker_utils.py +++ b/cli/medperf/containers/runners/docker_utils.py @@ -8,6 +8,7 @@ import json import tarfile import logging +from .utils import check_docker_image_hash def get_docker_image_hash(docker_image, timeout: int = None): @@ -177,3 +178,17 @@ def delete_images(images): run_command(delete_image_cmd) except ExecutionError: 
config.ui.print_warning("WARNING: Failed to delete docker images.") + + +def download_docker_image( + docker_image: str, + expected_image_hash: str, + download_timeout: int = None, + get_hash_timeout: int = None, +) -> str: + command = ["docker", "pull", docker_image] + logging.debug("Running pull command") + run_command(command, timeout=download_timeout) + computed_image_hash = get_docker_image_hash(docker_image, get_hash_timeout) + check_docker_image_hash(computed_image_hash, expected_image_hash) + return computed_image_hash diff --git a/cli/medperf/containers/runners/factory.py b/cli/medperf/containers/runners/factory.py index 559d0f1f2..bd72865e3 100644 --- a/cli/medperf/containers/runners/factory.py +++ b/cli/medperf/containers/runners/factory.py @@ -1,15 +1,19 @@ from .docker_runner import DockerRunner from .singularity_runner import SingularityRunner +from .airflow_runner import AirflowRunner from medperf import config from medperf.exceptions import InvalidArgumentError +from medperf.containers.parsers.airflow_parser import AirflowParser import logging -def load_runner(container_config_parser): +def load_runner(container_config_parser, container_name): if config.platform not in container_config_parser.allowed_runners: logging.debug(f"Allowed runners: {container_config_parser.allowed_runners}") raise InvalidArgumentError(f"Cannot run this container using {config.platform}") + if isinstance(container_config_parser, AirflowParser): + return AirflowRunner(container_config_parser, container_name) if config.platform == "docker": return DockerRunner(container_config_parser) if config.platform == "singularity": diff --git a/cli/medperf/containers/runners/runner.py b/cli/medperf/containers/runners/runner.py index 61ad69dfb..5d5fb3c8d 100644 --- a/cli/medperf/containers/runners/runner.py +++ b/cli/medperf/containers/runners/runner.py @@ -1,14 +1,15 @@ from abc import ABC, abstractmethod +from typing import Union, Dict class Runner(ABC): @abstractmethod def download( 
self, - expected_image_hash: str, + expected_image_hash: Union[str, Dict[str, str]], download_timeout: int = None, get_hash_timeout: int = None, - ): + ) -> Union[str, Dict[str, str]]: pass @abstractmethod @@ -25,3 +26,8 @@ def run( container_decryption_key_file: str = None, ): pass + + @property + def is_workflow(self): + """Can be overriden for workflow runners""" + return False diff --git a/cli/medperf/containers/runners/singularity_runner.py b/cli/medperf/containers/runners/singularity_runner.py index 2cb4b9ad3..09846f395 100644 --- a/cli/medperf/containers/runners/singularity_runner.py +++ b/cli/medperf/containers/runners/singularity_runner.py @@ -1,4 +1,3 @@ -from medperf.comms.entity_resources import resources from medperf.exceptions import InvalidArgumentError, MedperfException from medperf.utils import remove_path, run_command, tmp_path_for_file_decryption from .utils import ( @@ -8,14 +7,14 @@ add_medperf_environment_variables, add_network_config, add_medperf_tmp_folder, - check_docker_image_hash, + download_image_file, ) from .singularity_utils import ( cleanup_singularity_cache, - get_docker_image_hash_from_dockerhub, get_singularity_executable_props, craft_singularity_run_command, convert_docker_image_to_sif, + check_docker_image_by_name, ) from medperf.encryption import decrypt_gpg_file, check_gpg @@ -62,10 +61,10 @@ def _supports_nvccli(self): def download( self, - expected_image_hash, + expected_image_hash: str, download_timeout: int = None, get_hash_timeout: int = None, - ): + ) -> str: if self.parser.is_docker_archive() or self.parser.is_singularity_file(): logging.debug("Downloading image file") return self._download_image_file( @@ -83,28 +82,25 @@ def download( def _download_image_file( self, - expected_image_hash, + expected_image_hash: str, download_timeout: int = None, get_hash_timeout: int = None, - ): + ) -> str: image_file_url = self.parser.get_setup_args() - image_file_path, computed_image_hash = resources.get_cube_image( - 
image_file_url, expected_image_hash - ) # Hash checking happens in resources + image_file_path, computed_image_hash = download_image_file( + image_url=image_file_url, expected_image_hash=expected_image_hash + ) self.image_file_path = image_file_path self.image_file_hash = computed_image_hash return computed_image_hash def _check_docker_image( self, - expected_image_hash, + expected_image_hash: str, get_hash_timeout: int = None, - ): + ) -> str: docker_image = self.parser.get_setup_args() - computed_image_hash = get_docker_image_hash_from_dockerhub( - docker_image, get_hash_timeout - ) - check_docker_image_hash(computed_image_hash, expected_image_hash) + computed_image_hash = check_docker_image_by_name(docker_image, get_hash_timeout) self.docker_image_hash = computed_image_hash.replace(":", "_") return computed_image_hash diff --git a/cli/medperf/containers/runners/singularity_utils.py b/cli/medperf/containers/runners/singularity_utils.py index 88831dee5..93ca3efa7 100644 --- a/cli/medperf/containers/runners/singularity_utils.py +++ b/cli/medperf/containers/runners/singularity_utils.py @@ -8,6 +8,7 @@ import shlex from medperf.utils import run_command import logging +from .utils import check_docker_image_hash def get_docker_image_hash_from_dockerhub(docker_image, timeout: int = None): @@ -192,3 +193,13 @@ def cleanup_singularity_cache(singularity_executable): run_command(command) except ExecutionError: config.ui.print_warning("WARNING: Failed to clean singularity cache.") + + +def check_docker_image_by_name( + docker_image: str, expected_image_hash: str, get_hash_timeout: int = None +) -> str: + computed_image_hash = get_docker_image_hash_from_dockerhub( + docker_image, get_hash_timeout + ) + check_docker_image_hash(computed_image_hash, expected_image_hash) + return computed_image_hash diff --git a/cli/medperf/containers/runners/utils.py b/cli/medperf/containers/runners/utils.py index 5a20ac70a..0df482021 100644 --- a/cli/medperf/containers/runners/utils.py +++ 
b/cli/medperf/containers/runners/utils.py @@ -1,4 +1,5 @@ from typing import Optional +from medperf.comms.entity_resources import resources from medperf.exceptions import InvalidContainerSpec, MedperfException from medperf import config import os @@ -107,3 +108,15 @@ def check_docker_image_hash(computed_image_hash, expected_image_hash=None): raise InvalidContainerSpec( f"Hash mismatch. Expected {expected_image_hash}, found {computed_image_hash}." ) + + +def get_expected_hash(hashes_dict, image_name): + """Gets hash from hashes_dict using image_name as a key, or 'default' if not present""" + return hashes_dict.get(image_name, hashes_dict.get("default")) + + +def download_image_file(image_url: str, expected_image_hash: str): + image_path, computed_image_hash = resources.get_cube_image( + image_url, expected_image_hash + ) # Hash checking happens in resources + return image_path, computed_image_hash diff --git a/cli/medperf/entities/aggregator.py b/cli/medperf/entities/aggregator.py index 472730b09..f1efba317 100644 --- a/cli/medperf/entities/aggregator.py +++ b/cli/medperf/entities/aggregator.py @@ -1,5 +1,5 @@ import os -from pydantic import validator +from pydantic import field_validator, Field, ValidationInfo from medperf.entities.interface import Entity from medperf.entities.schemas import MedperfSchema @@ -24,8 +24,8 @@ class Aggregator(Entity, MedperfSchema): - results """ - metadata: dict = {} - config: dict + metadata: dict = Field(default_factory=dict) + config: dict = Field(validate_default=True) aggregation_mlcube: int @staticmethod @@ -48,8 +48,8 @@ def get_metadata_filename(): def get_comms_uploader(): return config.comms.upload_aggregator - @validator("config", pre=True, always=True) - def check_config(cls, v, *, values, **kwargs): + @field_validator("config", mode="before") + def check_config(cls, v: dict, info: ValidationInfo): keys = set(v.keys()) allowed_keys = { "address", diff --git a/cli/medperf/entities/benchmark.py 
b/cli/medperf/entities/benchmark.py index bd7135b99..ebc071d53 100644 --- a/cli/medperf/entities/benchmark.py +++ b/cli/medperf/entities/benchmark.py @@ -23,15 +23,15 @@ class Benchmark(Entity, ApprovableSchema, DeployableSchema): """ description: Optional[str] = Field(None, max_length=256) - docs_url: Optional[HttpUrl] + docs_url: Optional[HttpUrl] = None demo_dataset_tarball_url: str - demo_dataset_tarball_hash: Optional[str] - demo_dataset_generated_uid: Optional[str] + demo_dataset_tarball_hash: Optional[str] = None + demo_dataset_generated_uid: Optional[str] = None data_preparation_mlcube: int reference_model_mlcube: int data_evaluator_mlcube: int - metadata: dict = {} - user_metadata: dict = {} + metadata: dict = Field(default_factory=dict) + user_metadata: dict = Field(default_factory=dict) is_active: bool = True dataset_auto_approval_allow_list: list[str] = [] dataset_auto_approval_mode: str = "NEVER" diff --git a/cli/medperf/entities/ca.py b/cli/medperf/entities/ca.py index 01a32d78b..d6dc2a67c 100644 --- a/cli/medperf/entities/ca.py +++ b/cli/medperf/entities/ca.py @@ -2,7 +2,7 @@ import os from medperf.entities.interface import Entity from medperf.entities.schemas import MedperfSchema -from pydantic import validator +from pydantic import field_validator, Field, ValidationInfo import medperf.config as config from medperf.account_management import get_medperf_user_data @@ -22,11 +22,11 @@ class CA(Entity, MedperfSchema): - results """ - metadata: dict = {} + metadata: dict = Field(default_factory=dict) client_mlcube: int server_mlcube: int ca_mlcube: int - config: dict + config: dict = Field(validate_default=True) @staticmethod def get_type(): @@ -48,8 +48,8 @@ def get_metadata_filename(): def get_comms_uploader(): return config.comms.upload_ca - @validator("config", pre=True, always=True) - def check_config(cls, v, *, values, **kwargs): + @field_validator("config", mode="before") + def check_config(cls, v: dict, info: ValidationInfo): keys = set(v.keys()) 
allowed_keys = { "address", diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index 49b6b4842..abe200069 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -17,6 +17,7 @@ remove_path, get_decryption_key_path, ) +from medperf.containers.runners.utils import get_expected_hash from medperf.entities.encrypted_key import EncryptedKey import logging @@ -32,8 +33,8 @@ class Cube(Entity, DeployableSchema): """ container_config: dict - parameters_config: Optional[dict] - image_hash: Optional[str] + parameters_config: Optional[dict] = Field(default_factory=dict) + image_hash: Optional[dict] = Field(default_factory=dict) additional_files_tarball_url: Optional[str] = Field(None, alias="tarball_url") additional_files_tarball_hash: Optional[str] = Field(None, alias="tarball_hash") metadata: dict = Field(default_factory=dict) @@ -78,15 +79,19 @@ def __init__(self, *args, **kwargs): @property def parser(self): if self._parser is None: - self._parser = load_parser(self.container_config) + self._parser = load_parser(self.container_config, self.path) return self._parser @property def runner(self): if self._runner is None: - self._runner = load_runner(self.parser) + self._runner = load_runner(self.parser, self.name) return self._runner + @property + def is_workflow(self) -> bool: + return self.runner.is_workflow + @property def local_id(self): return self.name @@ -159,11 +164,23 @@ def download_run_files(self): raise InvalidEntityError(f"Container {self.name} additional files: {e}") try: - self.image_hash = self.runner.download( - expected_image_hash=self.image_hash, + if self.is_workflow: + expected_hash = self.image_hash + else: + expected_hash = get_expected_hash( + hashes_dict=self.image_hash, image_name="default" + ) + + image_hash = self.runner.download( + expected_image_hash=expected_hash, download_timeout=config.mlcube_configure_timeout, get_hash_timeout=config.mlcube_inspect_timeout, ) + if isinstance(image_hash, str): + 
image_hash = {"default": image_hash} + + self.image_hash.update(**image_hash) + except InvalidEntityError as e: raise InvalidEntityError(f"Container {self.name} image: {e}") diff --git a/cli/medperf/entities/dataset.py b/cli/medperf/entities/dataset.py index 0009f7ca7..6303c5d7f 100644 --- a/cli/medperf/entities/dataset.py +++ b/cli/medperf/entities/dataset.py @@ -1,7 +1,7 @@ import os from medperf.commands.association.utils import get_user_associations import yaml -from pydantic import Field, validator +from pydantic import Field, field_validator, ValidationInfo from typing import Optional, Union, List from medperf.utils import remove_path @@ -26,11 +26,11 @@ class Dataset(Entity, DeployableSchema): location: Optional[str] = Field(None, max_length=128) input_data_hash: str generated_uid: str - data_preparation_mlcube: Union[int, str] - split_seed: Optional[int] + data_preparation_mlcube: Union[int, str] = Field(validate_default=True) + split_seed: Optional[int] = None generated_metadata: dict = Field(..., alias="metadata") - user_metadata: dict = {} - report: dict = {} + user_metadata: dict = Field(default_factory=dict) + report: dict = Field(default_factory=dict) submitted_as_prepared: bool @staticmethod @@ -53,9 +53,9 @@ def get_metadata_filename(): def get_comms_uploader(): return config.comms.upload_dataset - @validator("data_preparation_mlcube", pre=True, always=True) - def check_data_preparation_mlcube(cls, v, *, values, **kwargs): - if not isinstance(v, int) and not values["for_test"]: + @field_validator("data_preparation_mlcube", mode="before") + def check_data_preparation_mlcube(cls, v: int, info: ValidationInfo): + if not isinstance(v, int) and not info.data.get("for_test"): raise ValueError( "data_preparation_mlcube must be an integer if not running a compatibility test" ) diff --git a/cli/medperf/entities/event.py b/cli/medperf/entities/event.py index 303fa2a8e..291419d4f 100644 --- a/cli/medperf/entities/event.py +++ b/cli/medperf/entities/event.py @@ 
-26,8 +26,8 @@ class TrainingEvent(Entity, MedperfSchema): training_exp: int participants: dict finished: bool = False - finished_at: Optional[datetime] - report: Optional[dict] + finished_at: Optional[datetime] = None + report: Optional[dict] = None @staticmethod def get_type(): diff --git a/cli/medperf/entities/execution.py b/cli/medperf/entities/execution.py index 0fa0237a6..f92bbb5f9 100644 --- a/cli/medperf/entities/execution.py +++ b/cli/medperf/entities/execution.py @@ -5,7 +5,7 @@ from medperf.account_management import get_medperf_user_data from typing import Optional from datetime import datetime - +from pydantic import Field from medperf.utils import remove_path import yaml @@ -24,13 +24,13 @@ class Execution(Entity, ApprovableSchema): benchmark: int model: int dataset: int - results: dict = {} - metadata: dict = {} - user_metadata: dict = {} - model_report: dict = {} - evaluation_report: dict = {} + results: dict = Field(default_factory=dict) + metadata: dict = Field(default_factory=dict) + user_metadata: dict = Field(default_factory=dict) + model_report: dict = Field(default_factory=dict) + evaluation_report: dict = Field(default_factory=dict) finalized: bool = False - finalized_at: Optional[datetime] + finalized_at: Optional[datetime] = None @staticmethod def get_type(): diff --git a/cli/medperf/entities/report.py b/cli/medperf/entities/report.py index 6f962d5d7..5314941f7 100644 --- a/cli/medperf/entities/report.py +++ b/cli/medperf/entities/report.py @@ -28,15 +28,15 @@ class TestReport(Entity): """ name: Optional[str] = "name" - demo_dataset_url: Optional[str] - demo_dataset_hash: Optional[str] - data_path: Optional[str] - labels_path: Optional[str] - prepared_data_hash: Optional[str] + demo_dataset_url: Optional[str] = None + demo_dataset_hash: Optional[str] = None + data_path: Optional[str] = None + labels_path: Optional[str] = None + prepared_data_hash: Optional[str] = None data_preparation_mlcube: Optional[Union[int, str]] model: Union[int, str] 
data_evaluator_mlcube: Union[int, str] - results: Optional[dict] + results: Optional[dict] = None @staticmethod def get_type(): @@ -73,7 +73,9 @@ def all(cls, unregistered: bool = False, filters: dict = {}) -> List["TestReport return super().all(unregistered=True, filters={}) @classmethod - def get(cls, uid: str, local_only: bool = False, valid_only: bool = True) -> "TestReport": + def get( + cls, uid: str, local_only: bool = False, valid_only: bool = True + ) -> "TestReport": """Gets an instance of the TestReport. ignores local_only inherited flag as TestReport is always a local entity. Args: uid (str): Report Unique Identifier diff --git a/cli/medperf/entities/schemas.py b/cli/medperf/entities/schemas.py index f5c564901..7a7a904b7 100644 --- a/cli/medperf/entities/schemas.py +++ b/cli/medperf/entities/schemas.py @@ -1,5 +1,14 @@ from datetime import datetime -from pydantic import BaseModel, Field, validator, HttpUrl, ValidationError +from pydantic import ( + BaseModel, + Field, + field_validator, + HttpUrl, + ValidationError, + ConfigDict, + ValidationInfo, +) +from pydantic_core import PydanticUndefined from typing import Optional from collections import defaultdict @@ -10,12 +19,12 @@ class MedperfSchema(BaseModel): for_test: bool = False - id: Optional[int] - name: str = Field(..., max_length=128) - owner: Optional[int] + id: Optional[int] = None + name: str = Field(..., max_length=128, validate_default=True) + owner: Optional[int] = None is_valid: bool = True - created_at: Optional[datetime] - modified_at: Optional[datetime] + created_at: Optional[datetime] = None + modified_at: Optional[datetime] = None def __init__(self, *args, **kwargs): """Override the ValidationError procedure so we can @@ -42,7 +51,7 @@ def dict(self, *args, **kwargs) -> dict: Returns: dict: filtered dictionary """ - fields = self.__fields__ + fields = self.__class__.model_fields valid_fields = [] # Gather all the field names, both original an alias names for field_name, field_item in 
fields.items(): @@ -50,10 +59,17 @@ def dict(self, *args, **kwargs) -> dict: valid_fields.append(field_item.alias) # Remove duplicates valid_fields = set(valid_fields) - model_dict = super().dict(*args, **kwargs) + model_dict = super().model_dump(*args, **kwargs) out_dict = {k: v for k, v in model_dict.items() if k in valid_fields} return out_dict + def model_dump(self, *args, **kwargs) -> dict: + """ + Added method to have a similar API to Pydantic V2, which recommends using + .model_dump instead of .dict + """ + return self.dict(*args, **kwargs) + def todict(self) -> dict: """Dictionary containing both original and alias fields @@ -70,16 +86,23 @@ def todict(self) -> dict: og_dict[k] = str(v) return og_dict - @validator("*", pre=True) - def empty_str_to_none(cls, v): + @field_validator("*", mode="before") + @classmethod + def empty_str_to_none(cls, v, info: ValidationInfo): if v == "": - return None + current_attribute = cls.model_fields[info.field_name] + default_value = None + if current_attribute.default != PydanticUndefined: + default_value = current_attribute.default + elif current_attribute.default_factory is not None: + default_value = current_attribute.default_factory() + return default_value + return v - class Config: - allow_population_by_field_name = True - extra = "allow" - use_enum_values = True + model_config = ConfigDict( + populate_by_name=True, use_enum_values=True, extra="allow" + ) class DeployableSchema(BaseModel): @@ -87,10 +110,10 @@ class DeployableSchema(BaseModel): class ApprovableSchema(BaseModel): - approved_at: Optional[datetime] - approval_status: Status = None + approved_at: Optional[datetime] = None + approval_status: Status = Field(None, validate_default=True) - @validator("approval_status", pre=True, always=True) + @field_validator("approval_status", mode="before") def default_status(cls, v): status = Status.PENDING if v is not None: diff --git a/cli/medperf/entities/training_exp.py b/cli/medperf/entities/training_exp.py index 
d6641c953..3ed1e5b63 100644 --- a/cli/medperf/entities/training_exp.py +++ b/cli/medperf/entities/training_exp.py @@ -22,16 +22,16 @@ class TrainingExp(Entity, MedperfSchema, ApprovableSchema, DeployableSchema): """ description: Optional[str] = Field(None, max_length=256) - docs_url: Optional[HttpUrl] + docs_url: Optional[HttpUrl] = None demo_dataset_tarball_url: str - demo_dataset_tarball_hash: Optional[str] - demo_dataset_generated_uid: Optional[str] + demo_dataset_tarball_hash: Optional[str] = None + demo_dataset_generated_uid: Optional[str] = None data_preparation_mlcube: int fl_mlcube: int - fl_admin_mlcube: Optional[int] - plan: dict = {} - metadata: dict = {} - user_metadata: dict = {} + fl_admin_mlcube: Optional[int] = None + plan: dict = Field(default_factory=dict) + metadata: dict = Field(default_factory=dict) + user_metadata: dict = Field(default_factory=dict) @staticmethod def get_type(): diff --git a/cli/medperf/enums.py b/cli/medperf/enums.py index b00b74ee4..dafd184b0 100644 --- a/cli/medperf/enums.py +++ b/cli/medperf/enums.py @@ -27,3 +27,14 @@ class ContainerTypes(Enum): DOCKER_ARCHIVE = "DockerArchive" ENCRYPTED_DOCKER_ARCHIVE = "EncryptedDockerArchive" ENCRYPTED_SINGULARITY_FILE = "EncryptedSingularityFile" + + +class ContainerConfigMountKeys(Enum): + data_path = "data_path" + output_path = "output_path" + labels_path = "labels_path" + output_labels_path = "output_labels_path" + statistics_file = "statistics_file" + additional_files = "additional_files" + parameters_file = "parameters_file" + metadata_path = "metadata_path" diff --git a/cli/medperf/logging/__init__.py b/cli/medperf/logging/__init__.py index 6f13bc315..ff8e6e6e3 100644 --- a/cli/medperf/logging/__init__.py +++ b/cli/medperf/logging/__init__.py @@ -8,12 +8,18 @@ def setup_logging(log_file: str, loglevel: str): + # Airflow overrides root logger, so we import it + # here first to override it immediately with our config + import airflow # noqa: F401 + # Ensure root folder exists 
log_folder = os.path.dirname(log_file) os.makedirs(log_folder, exist_ok=True) log_fmt = "%(asctime)s | %(module)s.%(funcName)s | %(levelname)s: %(message)s" - handler = handlers.RotatingFileHandler(log_file, backupCount=config.logs_backup_count) + handler = handlers.RotatingFileHandler( + log_file, backupCount=config.logs_backup_count + ) handler.setFormatter(NewLineFormatter(log_fmt)) logging.basicConfig( level=loglevel.upper(), diff --git a/cli/medperf/tests/commands/benchmark/test_submit.py b/cli/medperf/tests/commands/benchmark/test_submit.py index 754f85e65..aeba5f8a1 100644 --- a/cli/medperf/tests/commands/benchmark/test_submit.py +++ b/cli/medperf/tests/commands/benchmark/test_submit.py @@ -40,7 +40,7 @@ def test_submit_prepares_tmp_path_for_cleanup(): def test_submit_uploads_benchmark_data(mocker, result, comms, ui): # Arrange submission = SubmitBenchmark(BENCHMARK_INFO) - submission.bmk.metadata = {"results": result} + submission.bmk.metadata = {"results": {}} expected_data = submission.bmk.todict() spy_upload = mocker.patch.object( comms, "upload_benchmark", return_value=TestBenchmark().todict() diff --git a/cli/medperf/tests/commands/compatibility_test/test_utils.py b/cli/medperf/tests/commands/compatibility_test/test_utils.py index ac1b50592..df349c31a 100644 --- a/cli/medperf/tests/commands/compatibility_test/test_utils.py +++ b/cli/medperf/tests/commands/compatibility_test/test_utils.py @@ -167,7 +167,6 @@ def test_setup_cube_skips_download_with_local_image(self, mocker): def test_setup_cube_stores_decryption_key(self, mocker): # Arrange mock_cube = TestCube(id=5) - mock_cube.identifier = 5 dummy_key_path = "/tmp/somekey.key" stored_key_path = "/tmp/stored_key_path" diff --git a/cli/medperf/tests/mocks/benchmark.py b/cli/medperf/tests/mocks/benchmark.py index 38d0a35a7..19ab70382 100644 --- a/cli/medperf/tests/mocks/benchmark.py +++ b/cli/medperf/tests/mocks/benchmark.py @@ -1,6 +1,7 @@ from typing import Optional from medperf.enums import Status from 
medperf.entities.benchmark import Benchmark +from pydantic import Field class TestBenchmark(Benchmark): @@ -12,4 +13,4 @@ class TestBenchmark(Benchmark): data_preparation_mlcube: int = 1 reference_model_mlcube: int = 2 data_evaluator_mlcube: int = 3 - approval_status: Status = Status.APPROVED + approval_status: Status = Field(Status.APPROVED, validate_default=True) diff --git a/cli/medperf/tests/mocks/cube.py b/cli/medperf/tests/mocks/cube.py index 1793c9a9d..f2e8fba6d 100644 --- a/cli/medperf/tests/mocks/cube.py +++ b/cli/medperf/tests/mocks/cube.py @@ -16,4 +16,8 @@ class TestCube(Cube): ) additional_files_tarball_hash: Optional[str] = EMPTY_FILE_HASH state: str = "OPERATION" - is_valid = True + is_valid: bool = True + + @property + def is_workflow(self): + return False diff --git a/cli/medperf/tests/mocks/dataset.py b/cli/medperf/tests/mocks/dataset.py index fc9e57add..4362960ec 100644 --- a/cli/medperf/tests/mocks/dataset.py +++ b/cli/medperf/tests/mocks/dataset.py @@ -1,6 +1,7 @@ from typing import Optional, Union from medperf.enums import Status from medperf.entities.dataset import Dataset +from pydantic import Field class TestDataset(Dataset): @@ -11,7 +12,7 @@ class TestDataset(Dataset): data_preparation_mlcube: Union[int, str] = 1 input_data_hash: str = "input_data_hash" generated_uid: str = "generated_uid" - generated_metadata: dict = {} + generated_metadata: dict = Field(default_factory=dict) status: Status = Status.APPROVED.value state: str = "OPERATION" submitted_as_prepared: bool = False diff --git a/cli/medperf/tests/mocks/execution.py b/cli/medperf/tests/mocks/execution.py index c50f93048..8f7e742c7 100644 --- a/cli/medperf/tests/mocks/execution.py +++ b/cli/medperf/tests/mocks/execution.py @@ -1,5 +1,6 @@ from typing import Optional from medperf.entities.execution import Execution +from pydantic import Field class TestExecution(Execution): @@ -9,7 +10,7 @@ class TestExecution(Execution): benchmark: int = 1 model: int = 1 dataset: int = 1 - results: 
dict = {} + results: dict = Field(default_factory=dict) def upload(self): # self.id = 1 diff --git a/cli/medperf/tests/mocks/report.py b/cli/medperf/tests/mocks/report.py index 4a92660d2..f0c27a265 100644 --- a/cli/medperf/tests/mocks/report.py +++ b/cli/medperf/tests/mocks/report.py @@ -6,9 +6,9 @@ class TestTestReport(TestReport): __test__ = False demo_dataset_url: Optional[str] = "url" demo_dataset_hash: Optional[str] = "hash" - data_path: Optional[str] - labels_path: Optional[str] - prepared_data_hash: Optional[str] + data_path: Optional[str] = None + labels_path: Optional[str] = None + prepared_data_hash: Optional[str] = None data_preparation_mlcube: Optional[Union[int, str]] = 1 model: Union[int, str] = 2 data_evaluator_mlcube: Union[int, str] = 3 diff --git a/cli/medperf/utils.py b/cli/medperf/utils.py index 87c603ff3..154e18ab0 100644 --- a/cli/medperf/utils.py +++ b/cli/medperf/utils.py @@ -16,13 +16,18 @@ import shutil from pexpect import spawn from datetime import datetime -from typing import List +from typing import List, Union from colorama import Fore, Style from pexpect.exceptions import TIMEOUT from git import Repo, GitCommandError import medperf.config as config -from medperf.exceptions import CleanExit, ExecutionError, InvalidArgumentError +from medperf.exceptions import ( + CleanExit, + ExecutionError, + InvalidArgumentError, +) import shlex +from pydantic import TypeAdapter from email_validator import validate_email, EmailNotValidError @@ -668,3 +673,18 @@ def validate_and_normalize_emails(emails: list[str]): logging.debug(f"Invalid email: |{email}|") raise InvalidArgumentError(str(e)) return emails + + +def parse_datetime(datetime_obj: Union[str, int, datetime]): + # Pydantic v2 way of implementing old parse_datetime functionality. 
+ # Adapted from https://github.com/pydantic/pydantic/discussions/6204#discussioncomment-6266717 + + if isinstance(datetime_obj, datetime): + return datetime_obj + elif isinstance(datetime_obj, (str, int)): + return TypeAdapter(datetime).validate_strings(str(datetime_obj)) + else: + raise ValueError( + "Current implementation of parse_datetime only supports strings, ints and datetimes!\n" + f"Object sent was of type {type(datetime_obj)}\n{datetime_obj=}" + ) diff --git a/cli/medperf/web_ui/common.py b/cli/medperf/web_ui/common.py index e6e982c8b..7b95c57e9 100644 --- a/cli/medperf/web_ui/common.py +++ b/cli/medperf/web_ui/common.py @@ -11,7 +11,7 @@ from fastapi.requests import Request from medperf import config from starlette.responses import RedirectResponse -from pydantic.datetime_parse import parse_datetime +from medperf.utils import parse_datetime from medperf.enums import Status from medperf.web_ui.auth import ( diff --git a/cli/medperf/web_ui/containers/routes.py b/cli/medperf/web_ui/containers/routes.py index 5c9431941..85ed53ede 100644 --- a/cli/medperf/web_ui/containers/routes.py +++ b/cli/medperf/web_ui/containers/routes.py @@ -142,6 +142,9 @@ def register_container( "state": "OPERATION", } container_id = None + if not decryption_file: + decryption_file = None # Avoid mapping to current directory from empty string + try: container_id = SubmitCube.run( container_info, @@ -181,6 +184,8 @@ def test_container( ): initialize_state_task(request, task_name="container_compatibility_test") return_response = {"status": "", "error": "", "results": None} + if not decryption_file: + decryption_file = None # Avoid mapping to current directory from empty string try: _, results = CompatibilityTestExecution.run( benchmark=benchmark, diff --git a/cli/pytest.ini b/cli/pytest.ini new file mode 100644 index 000000000..23d657cda --- /dev/null +++ b/cli/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +filterwarnings = ignore:color:UserWarning:yaspin \ No newline at end of file diff --git 
a/cli/requirements.txt b/cli/requirements.txt index a9e853d85..43acfb281 100644 --- a/cli/requirements.txt +++ b/cli/requirements.txt @@ -1,9 +1,9 @@ -typer~=0.12.0 +typer~=0.15.0 rich~=13.7.0 PyYAML==6.0.1 requests>=2.26.0 -pydantic==1.10.13 -yaspin==2.1.0 +pydantic==2.12.5 +yaspin==3.3.0 tabulate==0.9.0 pexpect==4.8.0 colorama==0.4.4 @@ -12,7 +12,6 @@ pytest-mock==3.14.0 pyfakefs==5.8.0 validators==0.18.2 merge-args==0.1.5 -synapseclient==4.1.1 schema==0.7.5 email-validator==2.0.0 auth0-python==4.3.0 @@ -20,11 +19,22 @@ pandas==2.1.0 numpy==1.26.4 watchdog==3.0.0 GitPython==3.1.41 -psutil==5.9.8 +dill==0.4.0 semver==3.0.4 cookiecutter==2.1.1 -uvicorn==0.30.1 -fastapi==0.111.1 +uvicorn==0.37.0 +fastapi==0.116.1 fastapi-login==1.10.2 cryptography==46.0.3 +methodtools==0.4.7 click==8.1.8 +synapseclient @ git+https://github.com/Sage-Bionetworks/synapsePythonClient@develop +methodtools==0.4.7 +apache-airflow==3.1.5 +apache-airflow-task-sdk==1.1.5 +apache-airflow-providers-singularity==3.9.1 +apache-airflow-providers-docker==4.5.1 +apache-airflow-providers-fab==3.1.0 +psycopg2-binary==2.9.9 +asyncpg==0.30.0 + diff --git a/cli/tests_setup.sh b/cli/tests_setup.sh index 554d13755..5a5321eb5 100644 --- a/cli/tests_setup.sh +++ b/cli/tests_setup.sh @@ -1,5 +1,5 @@ #! 
/bin/sh -while getopts s:d:c:ft:rl: flag; do +while getopts s:d:c:ft:rl::p flag; do case "${flag}" in s) SERVER_URL=${OPTARG} ;; d) DIRECTORY=${OPTARG} ;; @@ -8,6 +8,7 @@ while getopts s:d:c:ft:rl: flag; do t) TIMEOUT=${OPTARG} ;; r) RESUME_TEST="true" ;; l) TEST_FROM_LINE=${OPTARG} ;; + p) PRIVATE="true" ;; esac done @@ -16,6 +17,8 @@ DIRECTORY="${DIRECTORY:-/tmp/medperf_test_files}" CLEANUP="${CLEANUP:-false}" RESUME_TEST="${RESUME_TEST:-false}" FRESH="${FRESH:-false}" +PRIVATE="${PRIVATE:-false}" # Include private model in cli_chestxray_tutorial_test +OS=$(uname) # if resume test, read the test root from local file if "${RESUME_TEST}"; then diff --git a/examples/HEMnet/data_preparator/.gitignore b/examples/HEMnet/data_preparator/.gitignore new file mode 100644 index 000000000..de412e3b4 --- /dev/null +++ b/examples/HEMnet/data_preparator/.gitignore @@ -0,0 +1,2 @@ +svs/ +additional_files/workflow.yaml* \ No newline at end of file diff --git a/examples/HEMnet/data_preparator/README.md b/examples/HEMnet/data_preparator/README.md new file mode 100644 index 000000000..7965b76be --- /dev/null +++ b/examples/HEMnet/data_preparator/README.md @@ -0,0 +1,67 @@ +# HEMnet pipeline + +This pipeline runs the training data preparation procedure from [HEMnet](https://github.com/BiomedicalMachineLearning/HEMnet/tree/master). A modified version of their original [Docker image](https://hub.docker.com/layers/andrewsu1/hemnet/latest/images/sha256-5b371f828cfd41e223b46678cef157ec599847f17f0cf5711a0288908b287d5b) is used here, which splits the processing into separate steps that are chained together via CWL. This modified Docker image is available in DockerHub at [this link](https://hub.docker.com/r/mlcommons/hemnet-airflow), with the source code available in the `./project`subdirectory. The main modification of this version is splitting the pipeline into separate stages that can be executed by Airflow. + +## 1. 
Get the HEMnet data + +The data used for the HEMnet study is available on [this location](https://dna-discovery.stanford.edu/publicmaterial/web-resources/HEMnet/images/). The pipeline runs on pairs of TP53 and H&E (HandE suffix) slides. *The code does **NOT** check for valid pairings in inputs!* Make sure your input data is correctly formatted as shown below + + +``` +. +svs +├── NNN_C_XXXX_Y_TP53.svs +├── NNN_C_XXXX_Y_HandE.svs +``` + +Where `NNN` and `XXXX` are ID numbers (`XXXX` may contain letters), `C` may be either equal to `T` (tumor) or `N` (non_tumor) and `Y` is a single digit number that may differ between slides of a given pair. For example, the structure below contains only valid input data pairings: + +``` +. +svs +├── 526_T_15907_2_TP53.svs +├── 526_T_15907_3_HandE.svs +├── 2065_N_127524A_2_HandE.svs +├── 2065_N_127524A_4_TP53.svs +├── 2171_T_11524A_2_HandE.svs +├── 2171_T_11524A_4_TP53.svs +``` + +The example below, however, is invalid. Notice the slightly different IDs in the pseudo-pairings marked with (*) and (**). + +``` +. +svs +├── 526_T_15907_2_TP53.svs +├── 526_T_15907_3_HandE.svs +├── 2065_N_12752A_2_HandE.svs (*) +├── 2065_T_12756A_4_TP53.svs (*) +├── 2171_N_11521A_2_HandE.svs (**) +├── 2171_T_11524A_4_TP53.svs (**) +``` + +### 2.1 Define the template slide +One of the slides used as input must also be used as the template slide. This template must be a copy of the original slide into the `template` directory, inside the `svs` directory. The example below shows valid input using the `2171_T_11524A_4_TP53.svs` slide as the template slide: + +``` +. +svs +├──template +│ └── 2171_T_11524A_4_TP53.svs +├── 526_T_15907_2_TP53.svs +├── 526_T_15907_3_HandE.svs +├── 2065_N_127524A_2_HandE.svs +├── 2065_N_127524A_4_TP53.svs +├── 2171_T_11524A_2_HandE.svs +├── 2171_T_11524A_4_TP53.svs +``` + + + +## Appendix. Build the customized Docker image +From the directory of this README file, run the following commands +```shell +cd project +docker build . 
-t local/hemnet:0.0.2 +``` +*NOTE!* If a different image tag is used, the `image` field of the workflow file located at `./workflow.yaml` must be modified to match the name used. \ No newline at end of file diff --git a/examples/HEMnet/data_preparator/project/Dockerfile b/examples/HEMnet/data_preparator/project/Dockerfile new file mode 100644 index 000000000..9af4106e3 --- /dev/null +++ b/examples/HEMnet/data_preparator/project/Dockerfile @@ -0,0 +1,6 @@ +FROM andrewsu1/hemnet:latest +RUN rm -rf /HEMnet && git clone -b master --single-branch --depth 1 https://github.com/BiomedicalMachineLearning/HEMnet.git +COPY *.py /HEMnet/HEMnet/ +WORKDIR "/HEMnet/HEMnet" +ENV PATH="/opt/conda/envs/HEMnet/bin:$PATH" +CMD [ "/bin/bash" ] \ No newline at end of file diff --git a/examples/HEMnet/data_preparator/project/affine_registration.py b/examples/HEMnet/data_preparator/project/affine_registration.py new file mode 100644 index 000000000..3cef6ff8a --- /dev/null +++ b/examples/HEMnet/data_preparator/project/affine_registration.py @@ -0,0 +1,159 @@ +import argparse +from mod_utils import ( + load_pil_image, + save_fig, + dump_sitk_transform, + dump_sitk_image, + get_fixed_and_moving_images, + load_df, + dump_df, +) +import matplotlib.pyplot as plt +import SimpleITK as sitk +from mod_constants import ( + OUTPUT_PATH, + INTERPOLATOR, + AFFINE_TRANSFORM_HDF, + MOVING_RESAMPLED_AFFINE_NPY, + TP53_GRAY, + HE_GRAY, +) +from utils import ( + calculate_mutual_info, + get_pil_from_itk, + start_plot, + update_multires_iterations, + update_plot, + plot_metric, + end_plot, +) +import time +import numpy as np + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-s", + "--subject-subdir", + type=str, + required=True, + help="Prefix that defines the slides used in this step.", + ) + parser.add_argument( + "-a", + "--align_mag", + type=float, + default=2, + help="Magnification for aligning H&E and TP53 slide", + ) + parser.add_argument( + "-v", 
"--verbosity", action="store_true", help="Increase output verbosity" + ) + + args = parser.parse_args() + # PATHS + PREFIX = args.subject_subdir + + # User selectable parameters + ALIGNMENT_MAG = args.align_mag + VERBOSE = args.verbosity + + print("Runing Affine Registration step on slide: {0}".format(PREFIX)) + + start = time.perf_counter() + tp53_gray = load_pil_image(TP53_GRAY, PREFIX) + he_gray = load_pil_image(HE_GRAY, PREFIX) + fixed_img, moving_img = get_fixed_and_moving_images(tp53_gray, he_gray) + performance_df = load_df(subdir=PREFIX) + end = time.perf_counter() + + print(f"Time spent on reloading normaliser and slides: {end-start}s") + + initial_transform = sitk.CenteredTransformInitializer( + fixed_img, + moving_img, + sitk.Euler2DTransform(), + sitk.CenteredTransformInitializerFilter.GEOMETRY, + ) + affine_method = sitk.ImageRegistrationMethod() + + # Similarity metric settings. + affine_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50) + affine_method.SetMetricSamplingStrategy(affine_method.RANDOM) + affine_method.SetMetricSamplingPercentage(0.15) + + affine_method.SetInterpolator(INTERPOLATOR) + + # Optimizer settings. + affine_method.SetOptimizerAsGradientDescent( + learningRate=1, + numberOfIterations=100, + convergenceMinimumValue=1e-6, + convergenceWindowSize=20, + ) + affine_method.SetOptimizerScalesFromPhysicalShift() + + # Setup for the multi-resolution framework. + affine_method.SetShrinkFactorsPerLevel(shrinkFactors=[8, 4]) + affine_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[4, 2]) + affine_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn() + + # Don't optimize in-place, we would possibly like to run this cell multiple times. + affine_method.SetInitialTransform(initial_transform, inPlace=False) + + # Connect all of the observers so that we can perform plotting during registration. 
+ affine_method.AddCommand(sitk.sitkStartEvent, start_plot) + affine_method.AddCommand( + sitk.sitkMultiResolutionIterationEvent, update_multires_iterations + ) + affine_method.AddCommand( + sitk.sitkIterationEvent, lambda: update_plot(affine_method) + ) + + affine_transform = affine_method.Execute( + sitk.Cast(fixed_img, sitk.sitkFloat32), + sitk.Cast(moving_img, sitk.sitkFloat32), + ) + + if VERBOSE: + affine_fig = plot_metric( + "Plot of mutual information cost in affine registration" + ) + plt.show() + save_fig(affine_fig, OUTPUT_PATH.joinpath(PREFIX + "affine_metric_plot.jpeg")) + end_plot() + + print( + "Affine Optimizer's stopping condition, {0}".format( + affine_method.GetOptimizerStopConditionDescription() + ) + ) + + # Compute the mutual information between the two images after affine registration + moving_resampled_affine = sitk.Resample( + moving_img, + fixed_img, + affine_transform, + INTERPOLATOR, + 0.0, + moving_img.GetPixelID(), + ) + affine_mutual_info = calculate_mutual_info( + np.array(he_gray), np.array(get_pil_from_itk(moving_resampled_affine)) + ) + if VERBOSE: + print("Affine mutual information metric: {0}".format(affine_mutual_info)) + + performance_df["Affine_Mutual_Info"] = affine_mutual_info + + dump_df(performance_df, subdir=PREFIX) + + dump_sitk_image( + sitk_image=moving_resampled_affine, + data_name=MOVING_RESAMPLED_AFFINE_NPY, + subdir=PREFIX, + ) + dump_sitk_transform( + sitk_transform=affine_transform, data_name=AFFINE_TRANSFORM_HDF, subdir=PREFIX + ) diff --git a/examples/HEMnet/data_preparator/project/bspline_registration.py b/examples/HEMnet/data_preparator/project/bspline_registration.py new file mode 100644 index 000000000..113856960 --- /dev/null +++ b/examples/HEMnet/data_preparator/project/bspline_registration.py @@ -0,0 +1,241 @@ +import argparse +from mod_utils import ( + load_sitk_transform, + save_fig, + save_img, + dump_pil_image, + load_sitk_image, + load_pil_image, + get_fixed_and_moving_images, + 
load_and_magnify_slides_by_prefix, + load_df, + dump_df, +) +import matplotlib.pyplot as plt +import SimpleITK as sitk +from mod_constants import ( + OUTPUT_PATH, + INTERPOLATOR, + MOVING_RESAMPLED_AFFINE_NPY, + AFFINE_TRANSFORM_HDF, + HE_FILTERED_NPY, + TP53_FILTERED_NPY, + HE_GRAY, + HE_NORM, + TP53_GRAY, +) +from utils import ( + calculate_mutual_info, + get_pil_from_itk, + start_plot, + update_multires_iterations, + update_plot, + plot_metric, + end_plot, + sitk_transform_rgb, + PlotImageAlignment, + filter_green, + filter_grays, + show_alignment, +) +import time +import numpy as np + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-s", + "--subject-subdir", + type=str, + required=True, + help="Prefix that defines the slides used in this step.", + ) + parser.add_argument( + "-a", + "--align_mag", + type=float, + default=2, + help="Magnification for aligning H&E and TP53 slide", + ) + parser.add_argument( + "-v", "--verbosity", action="store_true", help="Increase output verbosity" + ) + + args = parser.parse_args() + # PATHS + PREFIX = args.subject_subdir + + # User selectable parameters + ALIGNMENT_MAG = args.align_mag + VERBOSE = args.verbosity + + print("Running B-Spline registration step on Slide: {0}".format(PREFIX)) + + start = time.perf_counter() + he, tp53 = load_and_magnify_slides_by_prefix(PREFIX, ALIGNMENT_MAG) + he_norm = load_pil_image(HE_NORM, PREFIX) + tp53_gray = load_pil_image(TP53_GRAY, PREFIX) + he_gray = load_pil_image(HE_GRAY, PREFIX) + fixed_img, moving_img = get_fixed_and_moving_images(tp53_gray, he_gray) + + moving_resampled_affine = load_sitk_image( + data_name=MOVING_RESAMPLED_AFFINE_NPY, subdir=PREFIX + ) + affine_transform = load_sitk_transform( + data_name=AFFINE_TRANSFORM_HDF, subdir=PREFIX + ) + performance_df = load_df(subdir=PREFIX) + end = time.perf_counter() + print(f"Time spent on reloading images and transforms: {end-start}s") + + bspline_method = sitk.ImageRegistrationMethod() + + 
# Similarity metric settings. + bspline_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50) + bspline_method.SetMetricSamplingStrategy(bspline_method.RANDOM) + bspline_method.SetMetricSamplingPercentage(0.15) + + bspline_method.SetInterpolator(INTERPOLATOR) + + # Optimizer settings. + bspline_method.SetOptimizerAsGradientDescent( + learningRate=1, + numberOfIterations=200, + convergenceMinimumValue=1e-6, + convergenceWindowSize=10, + ) + bspline_method.SetOptimizerScalesFromPhysicalShift() + + # Setup for the multi-resolution framework. + bspline_method.SetShrinkFactorsPerLevel(shrinkFactors=[2, 1]) + bspline_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[1, 0]) + bspline_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn() + + # Don't optimize in-place, we would possibly like to run this cell multiple times. + transformDomainMeshSize = [8] * moving_resampled_affine.GetDimension() + initial_transform = sitk.BSplineTransformInitializer( + fixed_img, transformDomainMeshSize + ) + bspline_method.SetInitialTransform(initial_transform, inPlace=False) + + # Connect all of the observers so that we can perform plotting during registration. 
+ bspline_method.AddCommand(sitk.sitkStartEvent, start_plot) + bspline_method.AddCommand( + sitk.sitkMultiResolutionIterationEvent, update_multires_iterations + ) + bspline_method.AddCommand( + sitk.sitkIterationEvent, lambda: update_plot(bspline_method) + ) + + bspline_transform = bspline_method.Execute( + sitk.Cast(fixed_img, sitk.sitkFloat32), + sitk.Cast(moving_resampled_affine, sitk.sitkFloat32), + ) + + if VERBOSE: + bspline_fig = plot_metric( + "Plot of mutual information cost in B-spline registration" + ) + plt.show() + + save_fig(bspline_fig, OUTPUT_PATH.joinpath(PREFIX + "bspline_metric_plot.jpeg")) + end_plot() + + print( + "B-spline Optimizer's stopping condition, {0}".format( + bspline_method.GetOptimizerStopConditionDescription() + ) + ) + + # Compute the mutual information between the two images after B-spline registration + moving_resampled_final = sitk.Resample( + moving_resampled_affine, + fixed_img, + bspline_transform, + INTERPOLATOR, + 0.0, + moving_img.GetPixelID(), + ) + bspline_mutual_info = calculate_mutual_info( + np.array(he_gray), np.array(get_pil_from_itk(moving_resampled_final)) + ) + if VERBOSE: + print("B-spline mutual information metric: {0}".format(bspline_mutual_info)) + + performance_df["Final_Mutual_Info"] = bspline_mutual_info + + # Transform the original TP53 into the aligned TP53 image + moving_rgb_affine = sitk_transform_rgb( + tp53, he_norm, affine_transform, INTERPOLATOR + ) + tp53_aligned = sitk_transform_rgb( + moving_rgb_affine, he_norm, bspline_transform, INTERPOLATOR + ) + + # Visualise and save alignment + if VERBOSE: + align_plotter = PlotImageAlignment("vertical", 300) + comparison_post_v_stripes = align_plotter.plot_images(he, tp53_aligned) + save_img( + comparison_post_v_stripes.convert("RGB"), + OUTPUT_PATH.joinpath(PREFIX + "comparison_post_align_v_stripes.jpeg"), + "JPEG", + ) + + align_plotter = PlotImageAlignment("horizontal", 300) + comparison_post_h_stripes = align_plotter.plot_images(he, tp53_aligned) + 
save_img( + comparison_post_h_stripes.convert("RGB"), + OUTPUT_PATH.joinpath(PREFIX + "comparison_post_align_h_stripes.jpeg"), + "JPEG", + ) + + align_plotter = PlotImageAlignment("mosaic", 300) + comparison_post_mosaic = align_plotter.plot_images(he, tp53_aligned) + save_img( + comparison_post_mosaic.convert("RGB"), + OUTPUT_PATH.joinpath(PREFIX + "comparison_post_align_mosaic.jpeg"), + "JPEG", + ) + + # Remove backgrounds from TP53 and H&E images + tp53_filtered = filter_green(tp53_aligned) + he_filtered = filter_green(he_norm) + tp53_filtered = filter_grays(tp53_filtered, tolerance=2) + he_filtered = filter_grays(he_filtered, tolerance=15) + + dump_df(performance_df, subdir=PREFIX) + + dump_pil_image( + pil_image=tp53_filtered, + data_name=TP53_FILTERED_NPY, + subdir=PREFIX, + ) + + dump_pil_image( + pil_image=he_filtered, + data_name=HE_FILTERED_NPY, + subdir=PREFIX, + ) + # Visually compare alignment between the registered TP53 and original H&E image + if VERBOSE: + comparison_post_colour_overlay = show_alignment(he_filtered, tp53_filtered) + save_img( + comparison_post_colour_overlay.convert("RGB"), + OUTPUT_PATH.joinpath(PREFIX + "comparison_post_align_colour_overlay.jpeg"), + "JPEG", + ) + + save_img( + tp53_aligned.convert("RGB"), + OUTPUT_PATH.joinpath(PREFIX + str(ALIGNMENT_MAG) + "x_TP53_aligned.jpeg"), + "JPEG", + ) + save_img( + tp53_filtered.convert("RGB"), + OUTPUT_PATH.joinpath( + PREFIX + str(ALIGNMENT_MAG) + "x_TP53_aligned_white.jpeg" + ), + "JPEG", + ) diff --git a/examples/HEMnet/data_preparator/project/cleanup.py b/examples/HEMnet/data_preparator/project/cleanup.py new file mode 100644 index 000000000..962eb9e44 --- /dev/null +++ b/examples/HEMnet/data_preparator/project/cleanup.py @@ -0,0 +1,7 @@ +import os +import shutil +from mod_constants import TEMP_DATA_PATH + +print("Running cleanup step") +if os.path.exists(TEMP_DATA_PATH): + shutil.rmtree(TEMP_DATA_PATH) diff --git a/examples/HEMnet/data_preparator/project/consolidate_metrics.py 
b/examples/HEMnet/data_preparator/project/consolidate_metrics.py new file mode 100644 index 000000000..ab46c933b --- /dev/null +++ b/examples/HEMnet/data_preparator/project/consolidate_metrics.py @@ -0,0 +1,29 @@ +import pandas as pd +import os +from mod_constants import OUTPUT_PATH, TEMP_DATA_PATH, PERFORMANCE_DF + +if __name__ == "__main__": + + final_df = pd.DataFrame() + + slide_subidrs = [subdir for subdir in os.listdir(TEMP_DATA_PATH)] + + for slide_subdir in slide_subidrs: + full_subdir = TEMP_DATA_PATH.joinpath(slide_subdir) + if not os.path.isdir(full_subdir): + continue + elif PERFORMANCE_DF not in os.listdir(full_subdir): + print( + f"Performance data not found for slide prefix {slide_subdir}. Will not be included in final metrics." + ) + continue + + df_path = os.path.join(full_subdir, PERFORMANCE_DF) + tmp_df = pd.read_csv(df_path, encoding="utf-8", index_col=0) + final_df = pd.concat([final_df, tmp_df], axis=0) + print(f"tmp_df=\n{tmp_df}") + print(f"final_df=\n{final_df}") + print("------------------------------------") + + final_path = OUTPUT_PATH.joinpath(PERFORMANCE_DF) + final_df.to_csv(final_path) diff --git a/examples/HEMnet/data_preparator/project/generate_masks.py b/examples/HEMnet/data_preparator/project/generate_masks.py new file mode 100644 index 000000000..2e2da54f0 --- /dev/null +++ b/examples/HEMnet/data_preparator/project/generate_masks.py @@ -0,0 +1,189 @@ +import argparse +from mod_utils import save_img, dump_numpy_array, load_pil_image +from mod_constants import ( + OUTPUT_PATH, + HE_FILTERED_NPY, + TP53_FILTERED_NPY, + U_MASK_FILTERED, + C_MASK_FILTERED, + NON_C_MASK_FILTERED, + T_MASK_FILTERED, +) +from utils import ( + cancer_mask, + tissue_mask_grabcut, + plot_masks, +) + +from HEMnet_train_dataset import uncertain_mask, restricted_float +import time +import numpy as np + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-s", + "--subject-subdir", + type=str, + required=True, + 
help="Prefix that defines the slides used in this step.", + ) + parser.add_argument( + "-a", + "--align_mag", + type=float, + default=2, + help="Magnification for aligning H&E and TP53 slide", + ) + parser.add_argument( + "-m", + "--tile_mag", + type=float, + default=10, + help="Magnification for generating tiles", + ) + parser.add_argument( + "-ts", "--tile_size", type=int, default=224, help="Output tile size in pixels" + ) + parser.add_argument( + "-c", + "--cancer_thresh", + type=restricted_float, + default=0.39, + help="TP53 threshold for cancer classification", + ) + parser.add_argument( + "-nc", + "--non_cancer_thresh", + type=restricted_float, + default=0.40, + help="TP53 threshold for non-cancer classification", + ) + parser.add_argument( + "-v", "--verbosity", action="store_true", help="Increase output verbosity" + ) + + args = parser.parse_args() + # PATHS + PREFIX = args.subject_subdir + + # User selectable parameters + ALIGNMENT_MAG = args.align_mag + VERBOSE = args.verbosity + CANCER_THRESH = args.cancer_thresh + NON_CANCER_THRESH = args.non_cancer_thresh + TILE_MAG = args.tile_mag + OUTPUT_TILE_SIZE = args.tile_size + + print("Running Mask Generation step on Slide: {0}".format(PREFIX)) + + start = time.perf_counter() + he_filtered = load_pil_image(HE_FILTERED_NPY, PREFIX) + tp53_filtered = load_pil_image(TP53_FILTERED_NPY, PREFIX) + end = time.perf_counter() + print(f"Time spent on reloading filtered images: {end-start}s") + #################################### + # Generate cancer and tissue masks # + #################################### + + # Scale tile size for alignment mag + tile_size = OUTPUT_TILE_SIZE * ALIGNMENT_MAG / TILE_MAG + + # Generate cancer mask and tissue mask from filtered tp53 image + c_mask = cancer_mask(tp53_filtered, tile_size, 250).astype(np.bool) + t_mask_tp53 = tissue_mask_grabcut(tp53_filtered, tile_size) + t_mask_he = tissue_mask_grabcut(he_filtered, tile_size) + + # Generate tissue mask with tissue common to both the TP53 
and H&E image + t_mask = np.logical_not(np.logical_not(t_mask_tp53) & np.logical_not(t_mask_he)) + + # Generate uncertain mask + u_mask = uncertain_mask(tp53_filtered, tile_size, CANCER_THRESH, NON_CANCER_THRESH) + u_mask_filtered = np.logical_not(np.logical_not(u_mask) & np.logical_not(t_mask)) + + # Filter tissue mask such that any uncertain tiles are removed + t_mask_filtered = np.zeros(t_mask.shape) + for x in range(t_mask.shape[0]): + for y in range(t_mask.shape[1]): + if t_mask[x, y] == 0 and u_mask[x, y] == 1: + t_mask_filtered[x, y] = False + else: + t_mask_filtered[x, y] = True + + non_c_mask = np.invert(c_mask) + non_c_mask = np.logical_not( + np.logical_and(np.logical_not(non_c_mask), np.logical_not(t_mask_filtered)) + ) + + # If Slide is normal + if "_N_" in PREFIX: + # if True: + # Merge cancer mask with uncertain mask + # This marks all tiles that are uncertain or cancer as uncertain + u_mask_filtered = np.logical_not( + np.logical_or(np.logical_not(u_mask_filtered), np.logical_not(c_mask)) + ) + # Blank out cancer mask so no cancer tiles exist + c_mask_filtered = np.ones(c_mask.shape, dtype=bool) + # Non cancer tiles are tiles that are in the tissue and not cancer + non_c_mask_filtered = np.logical_not( + np.logical_and(np.logical_not(non_c_mask), np.logical_not(t_mask_filtered)) + ) + if VERBOSE: + print("Normal Slide Identified") + + # If Slide is cancerous + if "T" in PREFIX: + # if False: + # Merge non-cancer mask with uncertain mask + u_mask_filtered = np.logical_not( + np.logical_or(np.logical_not(non_c_mask), np.logical_not(u_mask_filtered)) + ) + # Blank out non cancer mask + non_c_mask_filtered = np.ones(non_c_mask.shape, dtype=bool) + # Cancer tile are tiles that are in the tissue and not cancer + # Make sure all cancer tiles exist in the tissue mask + c_mask_filtered = np.logical_not( + np.logical_not(c_mask) & np.logical_not(t_mask_filtered) + ) + + # Overlay masks onto TP53 and H&E Image + if VERBOSE: + print("Cancer Slide Identified") + 
overlay_tp53 = plot_masks( + tp53_filtered, + c_mask_filtered, + t_mask_filtered, + tile_size, + u_mask_filtered, + ) + save_img( + overlay_tp53.convert("RGB"), + OUTPUT_PATH.joinpath(PREFIX + "TP53_overlay.jpeg"), + "JPEG", + ) + + overlay_he = plot_masks( + he_filtered, + c_mask_filtered, + t_mask_filtered, + tile_size, + u_mask_filtered, + ) + save_img( + overlay_he.convert("RGB"), + OUTPUT_PATH.joinpath(PREFIX + "HE_overlay.jpeg"), + "JPEG", + ) + + arrays_to_save_tuples = [ + (u_mask_filtered, U_MASK_FILTERED), + (c_mask_filtered, C_MASK_FILTERED), + (non_c_mask_filtered, NON_C_MASK_FILTERED), + (t_mask_filtered, T_MASK_FILTERED), + ] + + for array_to_save_tuple in arrays_to_save_tuples: + array, filename = array_to_save_tuple + dump_numpy_array(array, filename, PREFIX) diff --git a/examples/HEMnet/data_preparator/project/image_registration.py b/examples/HEMnet/data_preparator/project/image_registration.py new file mode 100644 index 000000000..1f8b750d6 --- /dev/null +++ b/examples/HEMnet/data_preparator/project/image_registration.py @@ -0,0 +1,165 @@ +import argparse +from mod_utils import ( + load_and_magnify_slides_by_prefix, + save_img, + load_data, + dump_pil_image, + get_fixed_and_moving_images, + dump_data, + get_slide_names_by_prefix, + get_template_slide_from_dir, + dump_df, +) +import SimpleITK as sitk +from mod_constants import ( + OUTPUT_PATH, + INTERPOLATOR, + NORMALISER_PKL, + HE_NORM, + HE_GRAY, + TP53_GRAY, +) +from utils import ( + sitk_transform_rgb, + PlotImageAlignment, + show_alignment, + calculate_mutual_info, + get_pil_from_itk, +) +import numpy as np +import os +import pandas as pd + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-s", + "--subject-subdir", + type=str, + required=True, + help="Prefix that defines the slides used in this step.", + ) + parser.add_argument( + "-a", + "--align_mag", + type=float, + default=2, + help="Magnification for aligning H&E and TP53 slide", + ) + 
parser.add_argument( + "-v", "--verbosity", action="store_true", help="Increase output verbosity" + ) + + args = parser.parse_args() + # PATHS + PREFIX = args.subject_subdir + + # User selectable parameters + ALIGNMENT_MAG = args.align_mag + VERBOSE = args.verbosity + + print("Running Image Registration step on Slide: {0}".format(PREFIX)) + + template_slide_name = get_template_slide_from_dir() + template_slide_name = os.path.basename(template_slide_name) + he_name, tp53_name = get_slide_names_by_prefix(PREFIX) + he, tp53 = load_and_magnify_slides_by_prefix(PREFIX, ALIGNMENT_MAG) + normaliser = load_data(data_name=NORMALISER_PKL) + + # Normalise H&E Slide + normaliser.fit_source(he) + he_norm = normaliser.transform_tile(he) + dump_data(data_obj=normaliser, data_name=NORMALISER_PKL, subdir=PREFIX) + + if VERBOSE: + save_img( + he_norm.convert("RGB"), + OUTPUT_PATH.joinpath(PREFIX + str(ALIGNMENT_MAG) + "x_normalised.jpeg"), + "JPEG", + ) + + ###################### + # Image Registration # + ###################### + + # Convert to grayscale + tp53_gray = tp53.convert("L") + he_gray = he_norm.convert("L") + + # Dump images necessary for future steps + dump_pil_image(he_norm, HE_NORM, PREFIX) + dump_pil_image(he_gray, HE_GRAY, PREFIX) + dump_pil_image(tp53_gray, TP53_GRAY, PREFIX) + + # Set fixed and moving images + fixed_img, moving_img = get_fixed_and_moving_images(tp53_gray, he_gray) + + # Check initial registration + # Centre the two images, then compare their alignment + initial_transform = sitk.CenteredTransformInitializer( + fixed_img, + moving_img, + sitk.Euler2DTransform(), + sitk.CenteredTransformInitializerFilter.GEOMETRY, + ) + moving_rgb = sitk_transform_rgb(tp53, he_norm, initial_transform) + + # Visualise and save alignment + if VERBOSE: + align_plotter = PlotImageAlignment("vertical", 300) + comparison_pre_v_stripes = align_plotter.plot_images(he, moving_rgb) + save_img( + comparison_pre_v_stripes.convert("RGB"), + OUTPUT_PATH.joinpath(PREFIX + 
"comparison_pre_align_v_stripes.jpeg"), + "JPEG", + ) + + align_plotter = PlotImageAlignment("horizontal", 300) + comparison_pre_h_stripes = align_plotter.plot_images(he, moving_rgb) + save_img( + comparison_pre_h_stripes.convert("RGB"), + OUTPUT_PATH.joinpath(PREFIX + "comparison_pre_align_h_stripes.jpeg"), + "JPEG", + ) + + align_plotter = PlotImageAlignment("mosaic", 300) + comparison_pre_mosaic = align_plotter.plot_images(he, moving_rgb) + save_img( + comparison_pre_mosaic.convert("RGB"), + OUTPUT_PATH.joinpath(PREFIX + "comparison_pre_align_mosaic.jpeg"), + "JPEG", + ) + + comparison_pre_colour_overlay = show_alignment( + he_norm, moving_rgb, prefilter=True + ) + save_img( + comparison_pre_colour_overlay.convert("RGB"), + OUTPUT_PATH.joinpath(PREFIX + "comparison_pre_align_colour_overlay.jpeg"), + "JPEG", + ) + + # Compute the mutual information between the two images before registration + moving_resampled_initial = sitk.Resample( + moving_img, + fixed_img, + initial_transform, + INTERPOLATOR, + 0.0, + moving_img.GetPixelID(), + ) + initial_mutual_info = calculate_mutual_info( + np.array(he_gray), np.array(get_pil_from_itk(moving_resampled_initial)) + ) + if VERBOSE: + print("Initial mutual information metric: {0}".format(initial_mutual_info)) + + info_dict = { + "TP53_Slide_Name": tp53_name, + "H&E_Slide_Name": he_name, + "Template_Slide_Name": template_slide_name, + "Initial_Mutual_Info": initial_mutual_info, + } + performance_df = pd.DataFrame([info_dict]) + dump_df(performance_df, subdir=PREFIX) diff --git a/examples/HEMnet/data_preparator/project/mod_constants.py b/examples/HEMnet/data_preparator/project/mod_constants.py new file mode 100644 index 000000000..4661142ea --- /dev/null +++ b/examples/HEMnet/data_preparator/project/mod_constants.py @@ -0,0 +1,21 @@ +from pathlib import Path +import SimpleITK as sitk + +BASE_DIR = Path("/workspace") +INPUT_PATH = BASE_DIR.joinpath("input_data") +OUTPUT_PATH = BASE_DIR.joinpath("data") +TEMP_DATA_PATH = 
OUTPUT_PATH.joinpath(".tmp") +NORMALISER_PKL = "normaliser.pkl" +AFFINE_TRANSFORM_HDF = "affine_transform.hdf" +MOVING_RESAMPLED_AFFINE_NPY = "moving_resampled_affine.npy" +TP53_FILTERED_NPY = "tp53_filtered.npy" +HE_FILTERED_NPY = "he_filtered.npy" +INTERPOLATOR = sitk.sitkLanczosWindowedSinc +U_MASK_FILTERED = "u_mask_filtered.npy" +C_MASK_FILTERED = "c_mask_filtered.npy" +NON_C_MASK_FILTERED = "non_c_mask_filtered.npy" +T_MASK_FILTERED = "t_mask_filtered.npy" +HE_NORM = "he_norm.npy" +HE_GRAY = "he_gray.npy" +TP53_GRAY = "tp53_gray.npy" +PERFORMANCE_DF = "performance_metrics.csv" diff --git a/examples/HEMnet/data_preparator/project/mod_utils.py b/examples/HEMnet/data_preparator/project/mod_utils.py new file mode 100644 index 000000000..2008f1dd0 --- /dev/null +++ b/examples/HEMnet/data_preparator/project/mod_utils.py @@ -0,0 +1,253 @@ +import os +from openslide import open_slide +from mod_constants import INPUT_PATH, TEMP_DATA_PATH, NORMALISER_PKL, PERFORMANCE_DF +from slide import read_slide_at_mag +from normaliser import IterativeNormaliser +import pickle +import SimpleITK as sitk +import numpy as np +from utils import get_pil_from_itk +from PIL import Image +from utils import get_itk_from_pil +import pandas as pd + + +def save_img(img, path, img_type): + img.save(path, img_type) + + +def save_fig(fig, path, dpi=300): + fig.savefig(path, dpi=dpi) + + +def get_template_slide_from_dir(template_slide_path: str = None): + input_dir = str(INPUT_PATH) + template_dir = os.path.join(input_dir, "template") + try: + slides = [file for file in os.listdir(template_dir) if file.endswith(".svs")] + + template_slide_path = os.path.join(template_dir, slides[0]) + + if len(slides) > 1: + print( + f"More than 1 slide found at {template_dir}. Using {template_slide_path} as the template." + ) + except OSError: + raise ValueError( + f"Please provide an explicit template slide either with the -t option or by setting a single .svs file at the {os.path.join(template_dir)} directory!" 
+ ) + return template_slide_path + + +def create_target_fitted_normaliser( + alignment_mag, normaliser_method, standardise_luminosity +) -> IterativeNormaliser: + template_slide_path = get_template_slide_from_dir() + + print( + f"Using slide located at {template_slide_path} as the template to instantiate normaliser." + ) + template_slide = open_slide(str(template_slide_path)) + template_img = read_slide_at_mag(template_slide, alignment_mag).convert("RGB") + + normaliser = IterativeNormaliser(normaliser_method, standardise_luminosity) + normaliser.fit_target(template_img) + + return normaliser + + +def _get_saved_file_full_path(filename, subdir: str = None): + os.makedirs(TEMP_DATA_PATH, exist_ok=True) + data_dir = TEMP_DATA_PATH + if subdir is not None: + data_dir = TEMP_DATA_PATH.joinpath(subdir) + os.makedirs(data_dir, exist_ok=True) + full_pickle_path = data_dir.joinpath(filename) + return full_pickle_path + + +def dump_numpy_array(np_array, data_name: str, subdir: str = None): + np_path = _get_saved_file_full_path(data_name, subdir) + with open(np_path, "wb") as f: + np.save(f, np_array) + + +def load_numpy_array(data_name: str, subdir: str = None): + np_path = _get_saved_file_full_path(data_name, subdir) + with open(np_path, "rb") as f: + np_array = np.load(f) + return np_array + + +def dump_pil_image(pil_image, data_name: str, subdir: str = None): + as_np = np.array(pil_image) + dump_numpy_array(as_np, data_name, subdir) + + +def load_pil_image(data_name: str, subdir: str = None): + as_np = load_numpy_array(data_name, subdir) + pil_image = Image.fromarray(as_np) + return pil_image + + +def dump_sitk_image(sitk_image, data_name: str, subdir: str = None): + as_pil = get_pil_from_itk(sitk_image) + dump_pil_image(as_pil, data_name, subdir) + + +def load_sitk_image(data_name, subdir: str = None): + as_np = load_numpy_array(data_name, subdir) + as_itk = sitk.GetImageFromArray(as_np) + return as_itk + + +def dump_sitk_transform( + sitk_transform: sitk.Transform, 
data_name: str, subdir: str = None +): + print(f"Dumping SITK transform {data_name}...") + transform_path = str(_get_saved_file_full_path(data_name, subdir)) + sitk_transform.FlattenTransform() + sitk_transform.WriteTransform(transform_path) + + +def load_sitk_transform(data_name, subdir: str = None): + transform_path = str(_get_saved_file_full_path(data_name, subdir)) + sitk_transform = sitk.ReadTransform(transform_path) + return sitk_transform + + +def dump_data(data_obj, data_name: str, subdir: str = None): + full_path = _get_saved_file_full_path(data_name, subdir) + with open(full_path, "wb") as f: + pickle.dump(data_obj, f) + print(f"Successfully dumped object {data_name} at {full_path}.") + + +def load_data(data_name: str, subdir: str = None): + full_path = _get_saved_file_full_path(data_name, subdir) + with open(full_path, "rb") as f: + normalizer_obj = pickle.load(f) + print(f"Successfully loaded object {data_name} from {full_path}.") + return normalizer_obj + + +def dump_df(df: pd.DataFrame, df_name: str = PERFORMANCE_DF, subdir: str = None): + full_path = _get_saved_file_full_path(df_name, subdir) + df.to_csv(full_path, encoding="utf-8") + + +def load_df(df_name: str = PERFORMANCE_DF, subdir: str = None): + full_path = _get_saved_file_full_path(df_name, subdir) + df = pd.read_csv(full_path, encoding="utf-8", index_col=0) + return df + + +def get_slide_names_by_prefix(prefix: str): + relevant_filenames = sorted( + [file for file in os.listdir(INPUT_PATH) if prefix in file], + key=lambda file: ( + file.split("_")[0], + file.split("_")[-1], + ), # Order by prefix (integer) then by suffix (HandE first, then TP53) + ) + he_name, tp53_name = relevant_filenames + return he_name, tp53_name + + +def load_slides_by_prefix(prefix: str): + print(f"Loading slides with prefix {prefix}") + relevant_filenames = get_slide_names_by_prefix(prefix) + relevant_filepaths = sorted( + [os.path.join(INPUT_PATH, file) for file in relevant_filenames] + ) + + he_path, tp53_path = 
relevant_filepaths + + tp53_slide = open_slide(tp53_path) + he_slide = open_slide(he_path) + + print(f"Successfully loaded slides with prefix {prefix}.") + return he_slide, tp53_slide + + +def load_and_magnify_slides_by_prefix(prefix: str, aligment_mag: float): + he_slide, tp53_slide = load_slides_by_prefix(prefix) + + # Load Slides + he = read_slide_at_mag(he_slide, aligment_mag) + tp53 = read_slide_at_mag(tp53_slide, aligment_mag) + print(f"Successfully loaded and magnified slides with prefix {prefix}.") + return he, tp53 + + +def get_fixed_and_moving_images(tp53_gray, he_gray): + # Convert to ITK format + tp53_itk = get_itk_from_pil(tp53_gray) + he_itk = get_itk_from_pil(he_gray) + + fixed_img = he_itk + moving_img = tp53_itk + + return fixed_img, moving_img + + +def save_train_tiles( + path, + tile_gen, + cancer_mask, + tissue_mask, + uncertain_mask, + prefix="", + verbose: bool = False, +): + """Save tiles for train dataset + + Parameters + ---------- + path : Pathlib Path + tile_gen : tile_gen + cancer_mask : ndarray + tissue_mask : ndarray + uncertain_mask : ndarray + prefix : str (optional) + + Returns + ------- + None + """ + normaliser = load_data(data_name=NORMALISER_PKL, subdir=prefix) + os.makedirs(path.joinpath("cancer"), exist_ok=True) + os.makedirs(path.joinpath("non-cancer"), exist_ok=True) + os.makedirs(path.joinpath("uncertain"), exist_ok=True) + x_tiles, y_tiles = next(tile_gen) + + if verbose: + print("Whole Image Size is {0} x {1}".format(x_tiles, y_tiles)) + i = 0 + cancer = 0 + uncertain = 0 + non_cancer = 0 + for tile in tile_gen: + img = tile.convert("RGB") + ### + img_norm = normaliser.transform_tile(img) + ### + # Name tile as horizontal position _ vertical position starting at (0,0) + tile_name = prefix + str(np.floor_divide(i, x_tiles)) + "_" + str(i % x_tiles) + if uncertain_mask.ravel()[i] == 0: + img_norm.save(path.joinpath("uncertain", tile_name + ".jpeg"), "JPEG") + uncertain += 1 + elif cancer_mask.ravel()[i] == 0: + 
img_norm.save(path.joinpath("cancer", tile_name + ".jpeg"), "JPEG") + cancer += 1 + elif tissue_mask.ravel()[i] == 0: + img_norm.save(path.joinpath("non-cancer", tile_name + ".jpeg"), "JPEG") + non_cancer += 1 + i += 1 + if verbose: + print( + "Cancer tiles: {0}, Non Cancer tiles: {1}, Uncertain tiles: {2}".format( + cancer, non_cancer, uncertain + ) + ) + print("Exported tiles for {0}".format(prefix)) diff --git a/examples/HEMnet/data_preparator/project/normaliser_step.py b/examples/HEMnet/data_preparator/project/normaliser_step.py new file mode 100644 index 000000000..32f098663 --- /dev/null +++ b/examples/HEMnet/data_preparator/project/normaliser_step.py @@ -0,0 +1,42 @@ +import argparse +from pathlib import Path +from mod_utils import ( + create_target_fitted_normaliser, + dump_data, +) +from mod_constants import NORMALISER_PKL + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-n", + "--normaliser", + type=str, + default="vahadane", + choices=["none", "reinhard", "macenko", "vahadane"], + help="H&E normalisation method", + ) + parser.add_argument( + "-std", + "--standardise_luminosity", + action="store_false", + help="Disable luminosity standardisation", + ) + parser.add_argument( + "-a", + "--align_mag", + type=float, + default=2, + help="Magnification for aligning H&E and TP53 slide", + ) + + print("Running Normaliser step") + args = parser.parse_args() + + ALIGNMENT_MAG = args.align_mag + NORMALISER_METHOD = args.normaliser + STANDARDISE_LUMINOSITY = args.standardise_luminosity + + normaliser = create_target_fitted_normaliser(ALIGNMENT_MAG, NORMALISER_METHOD, STANDARDISE_LUMINOSITY) + dump_data(data_obj=normaliser, data_name=NORMALISER_PKL) diff --git a/examples/HEMnet/data_preparator/project/save_tiles.py b/examples/HEMnet/data_preparator/project/save_tiles.py new file mode 100644 index 000000000..798a61936 --- /dev/null +++ b/examples/HEMnet/data_preparator/project/save_tiles.py @@ -0,0 +1,115 @@ +import 
argparse +from mod_utils import load_numpy_array, load_slides_by_prefix, save_train_tiles, load_df, dump_df +from mod_constants import ( + OUTPUT_PATH, + U_MASK_FILTERED, + NON_C_MASK_FILTERED, + C_MASK_FILTERED, + T_MASK_FILTERED, +) +from slide import tile_gen_at_mag + +from HEMnet_train_dataset import restricted_float +import time +import numpy as np +import os + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-s", + "--subject-subdir", + type=str, + required=True, + help="Prefix that defines the slides used in this step.", + ) + parser.add_argument( + "-a", + "--align_mag", + type=float, + default=2, + help="Magnification for aligning H&E and TP53 slide", + ) + parser.add_argument( + "-m", + "--tile_mag", + type=float, + default=10, + help="Magnification for generating tiles", + ) + parser.add_argument( + "-ts", "--tile_size", type=int, default=224, help="Output tile size in pixels" + ) + parser.add_argument( + "-c", + "--cancer_thresh", + type=restricted_float, + default=0.39, + help="TP53 threshold for cancer classification", + ) + parser.add_argument( + "-nc", + "--non_cancer_thresh", + type=restricted_float, + default=0.40, + help="TP53 threshold for non-cancer classification", + ) + parser.add_argument( + "-v", "--verbosity", action="store_true", help="Increase output verbosity" + ) + + args = parser.parse_args() + # PATHS + PREFIX = args.subject_subdir + + # User selectable parameters + ALIGNMENT_MAG = args.align_mag + VERBOSE = args.verbosity + CANCER_THRESH = args.cancer_thresh + NON_CANCER_THRESH = args.non_cancer_thresh + TILE_MAG = args.tile_mag + OUTPUT_TILE_SIZE = args.tile_size + + print("Saving tiles from Slide: {0}".format(PREFIX)) + + start = time.perf_counter() + he_slide, _ = load_slides_by_prefix(PREFIX) + u_mask_filtered = load_numpy_array(U_MASK_FILTERED, PREFIX) + c_mask_filtered = load_numpy_array(C_MASK_FILTERED, PREFIX) + non_c_mask_filtered = (load_numpy_array(NON_C_MASK_FILTERED, PREFIX),) 
+ t_mask_filtered = load_numpy_array(T_MASK_FILTERED, PREFIX) + performance_df = load_df(subdir=PREFIX) + end = time.perf_counter() + print(f"Time spent on reloading normaliser and slides: {end-start}s") + + ############## + # Save Tiles # + ############## + + # Make Directory to save tiles + TILES_PATH = OUTPUT_PATH.joinpath("tiles_" + str(TILE_MAG) + "x") + os.makedirs(TILES_PATH, exist_ok=True) + + # Save tiles + tgen = tile_gen_at_mag(he_slide, TILE_MAG, OUTPUT_TILE_SIZE) + save_train_tiles( + TILES_PATH, + tgen, + c_mask_filtered, + t_mask_filtered, + u_mask_filtered, + prefix=PREFIX, + ) + + non_cancer_tiles = np.invert(non_c_mask_filtered).sum() + + uncertain_tiles = np.invert(u_mask_filtered).sum() + + cancer_tiles = np.invert(c_mask_filtered).sum() + + performance_df["Cancer_Tiles"] = cancer_tiles + performance_df["Uncertain_Tiles"] = uncertain_tiles + performance_df["Non_Cancer_Tiles"] = non_cancer_tiles + + dump_df(performance_df, subdir=PREFIX) diff --git a/examples/HEMnet/data_preparator/workflow.yaml b/examples/HEMnet/data_preparator/workflow.yaml new file mode 100644 index 000000000..e1c593705 --- /dev/null +++ b/examples/HEMnet/data_preparator/workflow.yaml @@ -0,0 +1,160 @@ +base_step: &BASE_STEP + - type: container + image: mlcommons/hemnet-airflow:0.0.1 + +steps: + - id: create_normalisation + <<: *BASE_STEP + command: python normaliser_step.py + mounts: + input_volumes: + data_path: + mount_path: /workspace/input_data + type: directory + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + next: image_registration + per_subject: false + + - id: image_registration + <<: *BASE_STEP + mounts: + input_volumes: + data_path: + mount_path: /workspace/input_data + type: directory + normalisation_input: + mount_path: /workspace/data + type: directory + from: + step: create_normalisation + mount: output_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + command: python image_registration.py 
-v + per_subject: true + next: affine_registration + + - id: affine_registration + <<: *BASE_STEP + mounts: + input_volumes: + image_registration_input: + mount_path: /workspace/data + type: directory + from: + step: image_registration + mount: output_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + command: python affine_registration.py -v + next: bspline_registration + per_subject: true + limit: 1 + + - id: bspline_registration + <<: *BASE_STEP + mounts: + input_volumes: + data_path: + mount_path: /workspace/input_data + type: directory + affine_registration_input: + mount_path: /workspace/data + type: directory + from: + step: affine_registration + mount: output_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + command: python bspline_registration.py -v + next: generate_masks + per_subject: true + cpu_share: 4 + mem_limit: 3g + limit: 1 + + - id: generate_masks + <<: *BASE_STEP + mounts: + input_volumes: + bspline_registration_input: + mount_path: /workspace/data + type: directory + from: + step: bspline_registration + mount: output_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + command: python generate_masks.py -v + next: save_tiles + per_subject: true + limit: 2 + + - id: save_tiles + <<: *BASE_STEP + mounts: + input_volumes: + data_path: + mount_path: /workspace/input_data + type: directory + generate_masks_input: + mount_path: /workspace/data + type: directory + from: + step: generate_masks + mount: output_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + command: python save_tiles.py -v + next: consolidate_metrics + per_subject: true + cpu_share: 3 + limit: 2 + + - id: consolidate_metrics + <<: *BASE_STEP + mounts: + input_volumes: + save_tiles_input: + mount_path: /workspace/data + type: directory + from: + step: save_tiles + mount: output_path + output_volumes: + output_path: + mount_path: 
/workspace/data + type: directory + command: python consolidate_metrics.py + next: cleanup + + - id: cleanup + <<: *BASE_STEP + per_subject: false + mounts: + input_volumes: + consolidate_metrics_input: + mount_path: /workspace/data + type: directory + from: + step: consolidate_metrics + mount: output_path + command: python cleanup.py + next: null + +per_subject_def: + type: function + function_name: slides_definition.slides_definition \ No newline at end of file diff --git a/examples/HEMnet/data_preparator/workspace/additional_files/slides_definition.py b/examples/HEMnet/data_preparator/workspace/additional_files/slides_definition.py new file mode 100644 index 000000000..d75f279e8 --- /dev/null +++ b/examples/HEMnet/data_preparator/workspace/additional_files/slides_definition.py @@ -0,0 +1,20 @@ +def slides_definition(pipeline_state): + """ + This preserves the behavior of the original code of assuming the user only + sends paired slides into the input directory. + If unpaired slides are sent, this code does NOT check for that! 
+ """ + from pathlib import Path + + input_data_dir = Path(pipeline_state.host_input_data_path) + slides = [] + + for slide in input_data_dir.glob("*.svs"): + name = slide.name + slides.append(name) + slides.sort() + TP53_slides = [slide for slide in slides if "TP53" in slide] + HE_slides = [slide for slide in slides if "HandE" in slide] + Paired_slides = list(zip(TP53_slides, HE_slides)) + prefixes = [paired_slide[0][:-10] for paired_slide in Paired_slides] + return prefixes diff --git a/examples/RANO/data_preparator_workflow/.gitignore b/examples/RANO/data_preparator_workflow/.gitignore new file mode 100644 index 000000000..5365add4c --- /dev/null +++ b/examples/RANO/data_preparator_workflow/.gitignore @@ -0,0 +1,7 @@ +*data/ +constants.py +tmpmodel/ +models/ +atlasImage_*.nii.gz +container_config.yaml +dev_models_and_more.tar.gz \ No newline at end of file diff --git a/examples/RANO/data_preparator_workflow/README.md b/examples/RANO/data_preparator_workflow/README.md new file mode 100644 index 000000000..fcdf904e4 --- /dev/null +++ b/examples/RANO/data_preparator_workflow/README.md @@ -0,0 +1,277 @@ +# RANO Data Preparation + +Here, a modified version of the data preparation pipeline used in the RANO study is presented as a use case for the YAML Pipelines in Airflow. The YAML file defining this pipeline is located at `./workflow.yaml`, with auxiliary Python code used to evaluate conditional steps in the pipeline located at `additional_files/conditions.py` and `additional_files/subject_definition.py`. A slightly modified version of the pipeline used for development and testing purposes is located at `./workflow_dev.yaml`. + +[Section 1](#1-preparing-to-run-the-pipeline-with-the-developement-dataset) of this README describes the initial setup to run the development pipeline, while [Section 2](#2-preparing-to-run-the-pipeline-with-real-data) describes how to do the initial setup to run the pipeline on real data. 
The following sections then describe how to properly execute the pipeline via MedPerf and monitor the pipeline on Airflow, regardless of whether the development or real pipeline is used. + +## 1. Preparing to run the Pipeline with the Developement dataset + +The following steps allow running the development pipeline in a local environment: + +- Ensure you have a valid local MedPerf installation. Detailed instructions are available in the [Official MedPerf Documentation for Installation](https://docs.medperf.org/getting_started/installation/). + +- Start a local MedPerf development server. Detailed instructions are available in the [Official MedPerf Documentation for local server setup](https://docs.medperf.org/getting_started/setup/#install-the-medperf-client). + +- (Optional step) If logged into a real MedPerf user, first logout. +```bash +medperf auth logout +``` + +- Log into the test Benchmark Owner user. +```bash +medperf auth login -e testbo@example.com +``` + +- (Optional step) Reset the local development server database by executing the following command inside the `server` directory at the root of this repository: +```bash +sh reset_db.sh +``` + +- Seed the server with the Development RANO Workflow. This can be done by executing the following command, inside the `server` directory at the root of this repository: +```bash +python seed.py --cert cert.crt --demo rano +``` + +- Confirm the seeding process was successful by executing the following command in any directory: +```bash +medperf container ls +``` + +The output should be similar to what's shown below: +```text + UID Name State Registered +----- ------------------ --------- ------------ + 1 MedPerf CA OPERATION True + 2 rano_workflow_prep OPERATION True +``` + +Note the `UID` field of the `rano_workflow_prep` workflow. It should be `2` if the local database was reset prior to seeding, but may be a different value if no reset was done. This UID will be used when running the data preparation step. 
+ +- Get the RANO development dataset by downloading and extracting (sha256: 701fbba8b253fc5b2f54660837c493a38dec986df9bdbf3d97f07c8bc276a965): + + +- Extract the `input_data` directory into any directory you like, but keep note of its location. It will be used later. + - The other files and directories inside the tarball may be safely deleted or ignored. + + +## 2. Preparing to run the Pipeline with Real data + +The following steps allow running the development pipeline in a local environment: + +- Ensure you have a valid local MedPerf installation. Detailed instructions are available in the [Official MedPerf Documentation for Installation](https://docs.medperf.org/getting_started/installation/). + +- Start a local MedPerf development server. Detailed instructions are available in the [Official MedPerf Documentation for local server setup](https://docs.medperf.org/getting_started/setup/#install-the-medperf-client). + +- (Optional step) If logged into a real MedPerf user, first logout. +```bash +medperf auth logout +``` + +- Log into the test Benchmark Owner user. +```bash +medperf auth login -e testbo@example.com +``` + +- (Optional step) Reset the local development server databse by execution the following command inside the `server` directory at the roost of this repository: +```bash +sh reset_db.sh +``` + +- Register the real dataset workflow in MedPerf. 
**Inside the same directory as this README file**, run the following command: +```bash +medperf container submit --name rano_workflow --container-config-file /workflow.yaml --parameters-file ./workspace/parameters.yaml --additional-file https://storage.googleapis.com/medperf-storage/rano_test_assets/dev_models_and_more.tar.gz --operational +``` + +- Confirm the seeding process was succesful by executing the following command in any directory: +```bash +medperf container ls +``` + +The output should be similar to what's shown below: +```text + UID Name State Registered +----- ------------------ --------- ------------ + 1 MedPerf CA OPERATION True + 2 rano_workflow OPERATION True +``` + +Note the `UID` field of the `rano_workflow` workflow. It should be `2` if the local database was reset prior to seeding, but may be a different value if no reset was done. This UID will be used when running the data preparation step. + +- Prepare your input data according to [Section 2.1](#21-structuring-your-data) below. + +### 2.1. Structuring your data + +You may create your `input_data` directory anywhere, but please ensure that it is in a location with relatively fast read/write access and with at least 2x more free disk space than your dataset currently occupies. Inside the `input_data` directory, your data needs to follow a folder hierarchy where images are separated by \/\/\. + +**Please note**: For the RANO study, Series-level folders must use the following abbreviations: t2f (T2-weighted FLAIR), t1n (T1-weighted non-contrast), t1c (T1-weighted with contrast), and t2w (T2-weighted). For more information about the required series, please refer to the FeTS 2.0 manual. PatientID and Timepoint must be unique between and within patients, respectively, and Timepoint should be sortable into chronologic order. + +``` +. +├── input_data +│ ├── AAAC_0 +│ ├── 2008.03.30 +│ │ ├── t2f +│ │ ├── t2_Flair_axial-2_echo1_S0002_I000001.dcm +│ │ │ └── ... 
+│ │ ├── t1n +│ │ │ ├── t1_axial-3_echo1_S0003_I000001.dcm +│ │ │ └── ... +│ │ ├── t1c +│ │ │ ├── t1_axial_stealth-post-14_echo1_S0014_I000001.dcm +| │ │ │ └── ... +│ │ │ └── t2w +│ │ │ ├── T2_SAG_SPACE-4_echo1_S0004_I000001.dcm +│ │ │ └── ... +``` + +## 3. Registering your dataset +Your dataset must be registered into MedPerf for pipeline execution. The following command may be used to register the dataset on MedPerf +```bash +medperf dataset --name RANODataset --data_path --labels_path --data_prep -y +``` +Make sure to substitute in the proper values for the `--data_path`, `--labels_path` and `--data_prep` arguments according to the setup done in either [Section 1](#1-preparing-to-run-the-pipeline-with-the-developement-dataset) or [Section 2](#2-preparing-to-run-the-pipeline-with-real-data). + +Confirm your dataset's UID by running the command +```bash +medperf dataset ls --mine +``` + +The output should be similar to what is shown below. If a local database reset was run at the start of the tutorial, it should be `UID` `1`. Keep note of the dataset UID, as it will be used in [Section 4](#4-running-the-rano-pipeline-via-medperf) +```text + UID Name Data Preparation Container UID State Status Owner +----- ------ -------------------------------- ----------- -------- ------- + 1 RANO 2 DEVELOPMENT +``` + + +## 4. Running the RANO Pipeline via MedPerf +Once the dataset is registered, the pipeline may be executed by running +```bash +medperf dataset prepare --data_uid -y +``` +Make sure to substitute in the dataset UID seen in [Section 3](#3-registering-your-dataset). MedPerf will then begin pulling the necessary Docker image and then immediately run Airflow. [Section 5](#5-pipeline-overview) will give a general view of the pipeline, while [Section 6](#6-monitoring-in-airflow) will go into details of how to monitor the pipeline in Airflow. + +## 5. Pipeline Overview +A general view of the pipeline is shown in the Figure below. 
A initial setup creating required directories is performed at first. Then, the pipeline will run NIfTI Conversion for multiple subjects in parallel. For each subject, once NIfTi conversion is completed, the pipeline will automatically run the Brain Extraction and Tumor Extraction stages and then await for manual confirmation (see [Section 5.1](#51-manual-approval-steps) for instructions regarding manual confirmation). The `per subject: true` configuration present in multiple steps of the pipeline signifies that this splitting per subject must be done at these steps. + +![Representation of the whole pipeline](./readme_images/pipeline_diagram.png) + +When the parser converts the YAML file into Airflow, each box in the above Figure is converted into a Directed Acyclic Graph (DAG) in Airflow. This results in the Airflow form of the pipeline being constructed as multiple DAGs, which can be though of as a grouping of one or more data processing steps. + +## 6. Monitoring in Airflow +Airflow’s Web UI can be used to monitor the Pipeline while it is running. The WebUI will be ready for accessing when a message similar to the image below appears in the terminal used to run DataPreparation: +```text +Starting Airflow components +Airflow components successfully started +MedPerf has started executing the Data Pipeline rano_workflow_prep via Airflow. +Execution will continue until the pipeline successfully completes. +Please use the following link to access the Airflow WebUI: + +http://localhost:8080/medperf/auto_login?username=SOME_USERNAME&password=RANDOMLY_GENERATED_PASSWORD + +If this process must be stopped prematurely, please use the Ctrl+C command! +``` + +The auto-generated link can be clicked to automatically login in the Airflow WebUI. By clicking the link, the Airflow home screen will be displayed in your web browser, as shown below. 
+![Airflow home screen](./readme_images/airflow_home.png) + +You can click on the DAGs button, in red in the figure, to switch to the DAGs view. A list of all currently loaded Airflow DAGs will be displayed, as shown below. The pipeline itself consists of multiple DAGs and each DAG maps to one of the `steps` defined in the YAML version of the Pipeline. Each DAG is named after the corresponding step, both in its raw format from the YAML file (`some_step`) and in a more readable format (`Some Step`) and, in case of steps with `per_subject: true`, also by the Subject ID and Timepoint. + +![DAG view in Airflow](./readme_images/dag_list.png) + +A view of Airflow Task Instances, which are the unit of execution used by Airflow, may be displayed by clicking the `Task Instances` button at the top of the screen. In this screen, Task Instances may be filtered by their state. We recommend filtering by `Running`, `Failed`, `Success` and `Up for Reschedule` states. The `Up for Reschedule` state is relevant for the Manual Approval Steps discussed in [Section 6.1](#61-manual-approval-steps). The Figure below shows a view of Task Instances with these filters applied, with the `Task Instances` button showcased in red and the state filters in blue. + +![Task Instances view in Airflow](./readme_images/task_instances_view.png) + +### 6.1 Manual Approval Steps +The automatic Tumor Segmentations must be manually validated before the Pipeline concludes. To help with finding the tasks that are awaiting Manual Approval, we recommend going into the Task Instance view described previously and filtering by `Up for Reschedule` tasks. The pipeline automatically creates the `Conditions Prepare for Manual Review` task to evaluate the `if` fields from the `prepare_for_manual_review` step defined in the YAML file. While awaiting approval, these tasks remain in the `Up for Reschedule` state. 
The Figure below shows a Task Instance list view in this situation, with the Task IDs and State in red: + +![DAGs ready for Manual Review](./readme_images/tasks_manual_review.png) + +In the Figure above, Subjects AAAC_1/2008.03.031 and AAAC_1/2012.01.02 are ready for manual review, signalled by the `State` (in blue) column having the status `Up for Reschedule`. This status means that none of the conditions defined in step `prepare_for_manual_review` of the YAML file (`dags_from_yaml/rano.yaml`) have been met yet, and therefore the pipeline is waiting for their manual completion by a user. The procedure for Manual Review is described in Sections [5.1](#51-manual-approval-steps---tumor-segmentation) and [5.2](#52-brain-mask-correction). Subject AAAC_2/2001.01.01 on the other hand, has a currently running task, signalled by the `Running` state, and therefore is not ready for manual review yet. + +#### 6.1.1 Tumor Segmentation +Once the segmentation for a given subject is ready for review, it will be available at the following path: + +``` +{MEDPERF_DATA_DIR}/{DATASET_UID}/data/manual_review/tumor_extraction/{SUBJECT_ID}/{TIMEPOINT}/under_review/{SUBJECT_ID}_{TIMEPOINT}_tumorMask_model_0.nii.gz +``` + +In this path `{MEDPERF_DATA_DIR}` MedPerf data directory, located at `~/.medperf/data/localhost_8000` if running the local development server. `{DATASET_UID}` is the UID of the registered dataset and shoul be `1` if a server reset was run prior to this tutorial,`{SUBJECT_ID}` and `{TIMEPOINT}` must be substituted for the corresponding SubjectID and Timepoint of each data point. Note that this is in the `under_review` directory, signalling the tumor segmentation has not been reviewed yet. 
For example, for subject AAAC_2 and timepoint 2001.01.01 the complete path would be: + +``` +{MEDPERF_DATA_DIR}/{DATASET_UID}/data/manual_review/tumor_extraction/AAAC_2/2001.01.01/under_review/AAAC_2_2001.01.01_tumorMask_model_0.nii.gz +``` + +The tumor segmentation can be reviewed with the software of your choice and, if necessary, corrections can be made. Once the review is finished, the file must be moved to the adjacent `finalized` directory. The complete path to the `finalized` file is, then + +``` +{MEDPERF_DATA_DIR}/{DATASET_UID}/data/manual_review/tumor_extraction/{SUBJECT_ID}/{TIMEPOINT}/finalized/{SUBJECT_ID}_{TIMEPOINT}_tumorMask_model_0.nii.gz +``` + +Note that this is in the `finalized` directory, signalling the review has been done. Once the Tumor Segmentation is in the `finalized` directory, the pipeline will automatically detect it and proceed for this subject. ***IMPORTANT!! Do NOT change the filename when moving the file into the finalized directory!*** The pipeline will only detect the reviewed Tumor Segmentation if it keeps the exact same filename. + +Please do this review process for all subjects in the study. If the brain mask itself must be corrected for any subjects, please refer to [Section 5.1.2](#512-brain-mask-correction). Note that modifying the Brain Mask of a Subject will cause the pipeline to rollback to the Brain Extraction step corresponding to that subject to run again, after which the given Tumor Segmentation must be manually approved once ready. + +#### 6.1.2 Brain Mask Correction + +If the automatic brain mask is correct, no action from this section is required. However, it is also possible to make corrections to the automatic brain mask, if necessary. 
**Note that if the Brain Mask is modified, the pipeline will go back to the Brain Extraction stage for this subject, then run Tumor Extraction and await manual approval once again once the Tumor Extraction is completed.** Once the pipeline reaches the manual approval step for a given subject/timepoint, the brain mask file will be located at the path below: + +``` +{MEDPERF_DATA_DIR}/{DATASET_UID}/data/manual_review/brain_mask/{SUBJECT_ID}/{TIMEPOINT}/under_review/brainMask_fused.nii.gz +``` + +In this path `{MEDPERF_DATA_DIR}` is the MedPerf data directory, located at `~/.medperf/data/localhost_8000` if running the local development server. `{DATASET_UID}` is the UID of the registered dataset and should be `1` if a server reset was run prior to this tutorial, `{SUBJECT_ID}` and `{TIMEPOINT}` must be substituted for the corresponding SubjectID and Timepoint of each data point. Note that this is in the `under_review` directory, signalling the brain mask has not been reviewed yet. + +The brain mask can be reviewed with the software of your choice and, if necessary, corrections can be made. Once the corrections are finished, the file must be moved to the adjacent `finalized` directory. The complete path to the finalized file is, then: + +``` +{MEDPERF_DATA_DIR}/{DATASET_UID}/data/manual_review/brain_mask/{SUBJECT_ID}/{TIMEPOINT}/finalized/brainMask_fused.nii.gz +``` + +***IMPORTANT!! Do NOT change the filename when moving the file into the finalized directory!*** The pipeline will only detect the corrected Brain Mask if it keeps the exact same filename. + +#### 6.1.3 Auto-approval script + +For testing and debugging purposes, a script is available in this directory to automatically approve the generated tumor segmentations of the development dataset. If desired, this script may be run by executing the following command inside the same directory as this README file. 
+ +```shell +sh auto_approve.sh +``` + +#### 6.2 Final Confirmation +There is also a manual confirmation step towards the end of the pipeline (step ID `final_confirmation`, of type `manual_approval`). When converted into an Airflow task, this step results into an Approval Task requires manual approval by the user. This task may be easily found by enabling the `Required Actions` in the Airflow UI, as shown in the figure below. If final approval is not yet required, the UI will instead display no DAGs once the filter is selected. + +![Filtering DAGs by Required Actions](./readme_images/filter_by_required_actions.png) + +***IMPORTANT!!* This task will *NOT* show up if all Manual Reviews are not done yet!** If you are unable to find the `Final Confirmation` task instance, make sure you have completed all the Manual Review steps outlind in [Section 5.1](#51-manual-approval-steps). + +Once the `Final Confirmation` task is available, click on the Task Name as shown in the Airflow WebUI. Once the task name is clicked, the window should display a view similar to the figure below. + +![Final Confirmation DAG View](./readme_images/final_confirmation_dag.png) + +In this view, click the `Required Actions (1)` button. The view will then change into a view similar to the Figure below. + +![Required Actions view](./readme_images/required_actions_dag.png) + +Here, click the `Manual Approval Task` text to display the final confirmation view shown below. + +![Final confirmation button](./readme_images/final_confirmation_approve_button.png) + +**Before proceeding with this step, *make sure to review and Tumor Segmentations as per [Section 6.1.1](#611-manual-approval-steps---tumor-segmentation) and ensure you approve all of the results, along with necessary corrections to Brain Masks ([Section 6.1.2](#612-brain-mask-correction)) if any are necessary.*** If all generated images are to your liking, click the `Approve` button to confirm the generated images. 
Airflow will then automatically proceed with executing the remainder of the Pipeline. + +## 7. Output Data + +The outputs of the pipeline, upon its conclusion, are as follows: + +- The `report.yaml` file, located at `{MEDPERF_DATA_DIR}/{DATASET_UID}/data/report.yaml` which is updated every minute with the completion percentages of each step defined on the Pipeline YAML file (`./workflow.yaml` or `/.workflow_dev.yaml`). Its contents are uploaded to the local MedPerf server. In a real use case, the contents are uploaded to the production MedPerf server and are used for monitoring the progression of data preparation at the different sites running data preparation. Note that the `report.yaml` file contains only a summary of pipeline status and progression completion, not any actual data. + +- The `{MEDPERF_DATA_DIR}/{DATASET_UID}/metadata` directory contains metadata YAML files for each subject, extracted from the initial DICOM data. + +- The `{MEDPERF_DATA_DIR}/{DATASET_UID}/labels` directory contains the final tumor segmentations for each subject. + +- The `{MEDPERF_DATA_DIR}/{DATASET_UID}/data` directory contains two different outputs. + - The NIfTi files obtained for each subject after Brain Extraction, located at `{MEDPERF_DATA_DIR}/{DATASET_UID}/data/{SUBJECT_ID}/{TIMEPOINT}` for each subject/timepoint combination. + - A `splits.csv` file detailing whether each subject was separated into the training or validation data sets. + - A `train.csv` file containing only subjects in the training dataset. + - A `val.csv` file containing only subjects in the validation dataset. 
diff --git a/examples/RANO/data_preparator_workflow/auto_approve.sh b/examples/RANO/data_preparator_workflow/auto_approve.sh new file mode 100644 index 000000000..85a2bcc7a --- /dev/null +++ b/examples/RANO/data_preparator_workflow/auto_approve.sh @@ -0,0 +1,4 @@ +BASE_DIR=~/.medperf/data/localhost_8000/1/data/manual_review/tumor_extraction +cp $BASE_DIR/AAAC_1/2012.01.02/under_review/AAAC_1_2012.01.02_tumorMask_model_0.nii.gz $BASE_DIR/AAAC_1/2012.01.02/finalized/AAAC_1_2012.01.02_tumorMask_model_0.nii.gz +cp $BASE_DIR/AAAC_1/2008.03.31/under_review/AAAC_1_2008.03.31_tumorMask_model_0.nii.gz $BASE_DIR/AAAC_1/2008.03.31/finalized/AAAC_1_2008.03.31_tumorMask_model_0.nii.gz +cp $BASE_DIR/AAAC_2/2001.01.01/under_review/AAAC_2_2001.01.01_tumorMask_model_0.nii.gz $BASE_DIR/AAAC_2/2001.01.01/finalized/AAAC_2_2001.01.01_tumorMask_model_0.nii.gz \ No newline at end of file diff --git a/examples/RANO/data_preparator_workflow/pipeline/Dockerfile b/examples/RANO/data_preparator_workflow/pipeline/Dockerfile new file mode 100644 index 000000000..c8b936927 --- /dev/null +++ b/examples/RANO/data_preparator_workflow/pipeline/Dockerfile @@ -0,0 +1,5 @@ +FROM mlcommons/rano-data-prep-mlcube:1.0.10 + +COPY ./project /project +ENV WORKSPACE_DIRECTORY='/workspace' +ENTRYPOINT ["python", "/project/direct_stages.py"] diff --git a/examples/RANO/data_preparator_workflow/pipeline/Dockerfile.dev b/examples/RANO/data_preparator_workflow/pipeline/Dockerfile.dev new file mode 100644 index 000000000..8c4cf6c26 --- /dev/null +++ b/examples/RANO/data_preparator_workflow/pipeline/Dockerfile.dev @@ -0,0 +1,20 @@ +FROM mlcommons/rano-data-prep-mlcube:1.0.10 + +COPY ./project /project +ENV WORKSPACE_DIRECTORY='/workspace' + +COPY ./atlasImage_0.125.nii.gz /project +COPY ./tmpmodel/ /project/tmpmodel/ + +# use a downsampled reference image for DICOM to NIFTI conversion +RUN mv /project/atlasImage_0.125.nii.gz /Front-End/bin/install/appdir/usr/data/sri24/atlasImage.nii.gz + +# remove heavy brain 
extraction models +RUN rm -rf /project/stages/data_prep_models/brain_extraction/model_0/ +RUN rm -rf /project/stages/data_prep_models/brain_extraction/model_1/ + +# use dummy brain extraction models +RUN cp -r /project/tmpmodel /project/stages/data_prep_models/brain_extraction/model_0 +RUN mv /project/tmpmodel /project/stages/data_prep_models/brain_extraction/model_1 + +ENTRYPOINT ["python", "/project/direct_stages.py"] diff --git a/examples/RANO/data_preparator_workflow/pipeline/project/direct_stages.py b/examples/RANO/data_preparator_workflow/pipeline/project/direct_stages.py new file mode 100644 index 000000000..f424c4c99 --- /dev/null +++ b/examples/RANO/data_preparator_workflow/pipeline/project/direct_stages.py @@ -0,0 +1,315 @@ +"""MLCube handler file""" + +import typer +import os +from stages.env_vars import WORKSPACE_DIR, DATA_DIR, INPUT_DIR +from stages.utils import get_aux_files_dir, get_data_csv_filepath, convert_path_to_index +from stages.mlcube_constants import ( + RAW_PATH, + AUX_FILES_PATH, + VALID_PATH, + PREP_PATH, + BRAIN_PATH, + TUMOR_PATH, + DONE_STAGE_STATUS, + BRAIN_STAGE_STATUS, + TUMOR_STAGE_STATUS, + TUMOR_BACKUP_PATH, + MANUAL_STAGE_STATUS, + MANUAL_REVIEW_PATH, + LABELS_PATH, + METADATA_PATH, +) +from stages.constants import INTERIM_FOLDER +from sanity_check import sanity_check +from metrics import calculate_statistics + +app = typer.Typer() + + +@app.command("initial_setup") +def initial_setup(): + from stages.generate_report import InitialSetup + + raw_dir = os.path.join(DATA_DIR, RAW_PATH) + labels_out_dir = os.path.join(WORKSPACE_DIR, LABELS_PATH) + brain_out = os.path.join(DATA_DIR, BRAIN_PATH) + tumor_out = os.path.join(DATA_DIR, TUMOR_PATH) + report_generator = InitialSetup( + data_csv=None, + input_path=INPUT_DIR, + output_path=raw_dir, + input_labels_path=INPUT_DIR, + output_labels_path=labels_out_dir, + done_data_out_path=DATA_DIR, + done_status=DONE_STAGE_STATUS, + brain_data_out_path=brain_out, + 
brain_status=BRAIN_STAGE_STATUS, + tumor_data_out_path=tumor_out, + reviewed_status=MANUAL_STAGE_STATUS, + ) + report_generator.execute(None) + + +@app.command("make_csv") +def prepare( + subject_subdir: str = typer.Option(..., "--partition"), +): + from stages.get_csv import ( + AddToCSV, + ) + + output_csv_dir = get_aux_files_dir(subject_subdir) + os.makedirs(output_csv_dir, exist_ok=True) + output_csv = get_data_csv_filepath(subject_subdir) + out_dir = os.path.join(DATA_DIR, VALID_PATH) + csv_creator = AddToCSV( + input_dir=INPUT_DIR, + output_csv=output_csv, + out_dir=out_dir, + prev_stage_path=INPUT_DIR, + ) + subject_index = convert_path_to_index(subject_subdir) + csv_creator.execute(subject_index) + print(output_csv) + + +@app.command("convert_nifti") +def convert_nifti( + subject_subdir: str = typer.Option(..., "--partition"), +): + from stages.nifti_transform import NIfTITransform + + csv_path = get_data_csv_filepath(subject_subdir) + output_path = os.path.join(DATA_DIR, PREP_PATH, subject_subdir) + metadata_path = os.path.join(WORKSPACE_DIR, METADATA_PATH) + os.makedirs(output_path, exist_ok=True) + os.makedirs(metadata_path, exist_ok=True) + + nifti_transform = NIfTITransform( + data_csv=csv_path, + out_path=output_path, + prev_stage_path=INPUT_DIR, + metadata_path=metadata_path, + data_out=DATA_DIR, + ) + subject_index = convert_path_to_index(subject_subdir) + nifti_transform.execute(subject_index) + print(output_path) + + +@app.command("extract_brain") +def extract_brain( + subject_subdir: str = typer.Option(..., "--partition"), +): + from stages.extract import Extract + + csv_path = get_data_csv_filepath(subject_subdir) + output_path = os.path.join(DATA_DIR, BRAIN_PATH, subject_subdir) + prev_path = os.path.join(DATA_DIR, PREP_PATH, subject_subdir) + os.makedirs(output_path, exist_ok=True) + + brain_extract = Extract( + data_csv=csv_path, + out_path=output_path, + subpath=INTERIM_FOLDER, + prev_stage_path=prev_path, + prev_subpath=INTERIM_FOLDER, + 
func_name="extract_brain", + status_code=BRAIN_STAGE_STATUS, + ) + subject_index = convert_path_to_index(subject_subdir) + brain_extract.execute(subject_index) + print(output_path) + + +@app.command("extract_tumor") +def extract_tumor( + subject_subdir: str = typer.Option(..., "--partition"), +): + from stages.extract_nnunet import ExtractNnUNet + + csv_path = get_data_csv_filepath(subject_subdir) + output_path = os.path.join(DATA_DIR, TUMOR_PATH, subject_subdir) + prev_path = os.path.join(DATA_DIR, BRAIN_PATH, subject_subdir) + os.makedirs(output_path, exist_ok=True) + + models_path = os.path.join(WORKSPACE_DIR, "additional_files", "models") + tmpfolder = os.path.join(WORKSPACE_DIR, DATA_DIR, ".tmp", subject_subdir) + cbica_tmpfolder = os.path.join(tmpfolder, ".cbicaTemp") + os.environ["TMPDIR"] = tmpfolder + os.environ["CBICA_TEMP_DIR"] = cbica_tmpfolder + os.makedirs(tmpfolder, exist_ok=True) + os.makedirs(cbica_tmpfolder, exist_ok=True) + os.environ["RESULTS_FOLDER"] = os.path.join(models_path, "nnUNet_trained_models") + os.environ["nnUNet_raw_data_base"] = os.path.join(tmpfolder, "nnUNet_raw_data_base") + os.environ["nnUNet_preprocessed"] = os.path.join(tmpfolder, "nnUNet_preprocessed") + tumor_extract = ExtractNnUNet( + data_csv=csv_path, + out_path=output_path, + subpath=INTERIM_FOLDER, + prev_stage_path=prev_path, + prev_subpath=INTERIM_FOLDER, + status_code=TUMOR_STAGE_STATUS, + ) + subject_index = convert_path_to_index(subject_subdir) + tumor_extract.execute(subject_index) + print(output_path) + + +@app.command("prepare_for_manual_review") +def prepare_for_manual_review( + subject_subdir: str = typer.Option(..., "--partition"), +): + + from stages.manual import ManualStage + + csv_path = get_data_csv_filepath(subject_subdir) + prev_stage_path = os.path.join(DATA_DIR, TUMOR_PATH, subject_subdir) + backup_out = os.path.join(WORKSPACE_DIR, LABELS_PATH, TUMOR_BACKUP_PATH) + + manual_validation = ManualStage( + data_csv=csv_path, + out_path=prev_stage_path, + 
prev_stage_path=prev_stage_path, + backup_path=backup_out, + ) + subject_index = convert_path_to_index(subject_subdir) + manual_validation.prepare_directories(subject_index) + + +@app.command("rollback_to_brain_extract") +def rollback( + subject_subdir: str = typer.Option(..., "--partition"), +): + + from stages.manual import ManualStage + + csv_path = get_data_csv_filepath(subject_subdir) + prev_stage_path = os.path.join(DATA_DIR, TUMOR_PATH, subject_subdir) + backup_out = os.path.join(WORKSPACE_DIR, LABELS_PATH, TUMOR_BACKUP_PATH) + + manual_validation = ManualStage( + data_csv=csv_path, + out_path=prev_stage_path, + prev_stage_path=prev_stage_path, + backup_path=backup_out, + ) + subject_index = convert_path_to_index(subject_subdir) + manual_validation.rollback(subject_index) + + +@app.command("segmentation_comparison") +def segmentation_comparison( + subject_subdir: str = typer.Option(..., "--partition"), +): + from stages.comparison import SegmentationComparisonStage + + csv_path = get_data_csv_filepath(subject_subdir) + prev_stage_path = os.path.join(DATA_DIR, TUMOR_PATH, subject_subdir) + labels_out = os.path.join(WORKSPACE_DIR, LABELS_PATH) + backup_out = os.path.join(labels_out, TUMOR_BACKUP_PATH) + + segment_compare = SegmentationComparisonStage( + data_csv=csv_path, + out_path=labels_out, + prev_stage_path=prev_stage_path, + backup_path=backup_out, + ) + subject_index = convert_path_to_index(subject_subdir) + segment_compare.execute(subject_index) + + +@app.command("calculate_changed_voxels") +def calculate_changed_voxels(): + from stages.confirm import ConfirmStage + + prev_stage_path = os.path.join(DATA_DIR, TUMOR_PATH) + labels_out = os.path.join(WORKSPACE_DIR, LABELS_PATH) + backup_out = os.path.join(labels_out, TUMOR_BACKUP_PATH) + + confirm_stage = ConfirmStage( + out_data_path=DATA_DIR, + out_labels_path=labels_out, + prev_stage_path=prev_stage_path, + backup_path=backup_out, + ) + confirm_stage.execute() + + +@app.command("move_labeled_files") 
+def move_labeled_files(): + from stages.confirm import ConfirmStage + + prev_stage_path = os.path.join(DATA_DIR, TUMOR_PATH) + labels_out = os.path.join(WORKSPACE_DIR, LABELS_PATH) + backup_out = os.path.join(labels_out, TUMOR_BACKUP_PATH) + + confirm_stage = ConfirmStage( + out_data_path=DATA_DIR, + out_labels_path=labels_out, + prev_stage_path=prev_stage_path, + backup_path=backup_out, + ) + confirm_stage.move_labels() + + +@app.command("consolidation_stage") +def consolidation_stage(keep_files: bool = typer.Option(False, "--keep-files")): + from stages.split import SplitStage + + labels_out = os.path.join(WORKSPACE_DIR, LABELS_PATH) + params_path = os.path.join(WORKSPACE_DIR, "parameters.yaml") + base_finalized_dir = os.path.join(DATA_DIR, TUMOR_PATH, INTERIM_FOLDER) + + if keep_files: + dirs_to_remove = [] + else: + subdirs_to_remove = [ + BRAIN_PATH, + AUX_FILES_PATH, + PREP_PATH, + TUMOR_PATH, + RAW_PATH, + TUMOR_PATH, + VALID_PATH, + MANUAL_REVIEW_PATH, + ] + dirs_to_remove = [ + os.path.join(DATA_DIR, subdir) for subdir in subdirs_to_remove + ] + dirs_to_remove.extend( + [ + os.path.join(WORKSPACE_DIR, DATA_DIR, ".tmp"), + os.path.join(labels_out, ".tmp"), + os.path.join(labels_out, ".tumor_segmentation_backup"), + ] + ) + + split = SplitStage( + params=params_path, + data_path=DATA_DIR, + labels_path=labels_out, + staging_folders=dirs_to_remove, + base_finalized_dir=base_finalized_dir, + ) + split.execute() + + +@app.command("sanity_check") +def sanity_check_cmdline(): + data_path = DATA_DIR + labels_path = os.path.join(WORKSPACE_DIR, LABELS_PATH) + sanity_check(data_path=data_path, labels_path=labels_path) + + +@app.command("metrics") +def metrics_cmdline(): + splits_path = os.path.join(DATA_DIR, "splits.csv") + invalid_path = os.path.join(WORKSPACE_DIR, METADATA_PATH, ".invalid.txt") + out_file = os.path.join(WORKSPACE_DIR, METADATA_PATH, "statistics.yml") + calculate_statistics(splits_path, invalid_path, out_file) + + +if __name__ == "__main__": + 
app() diff --git a/examples/RANO/data_preparator_workflow/pipeline/project/metrics.py b/examples/RANO/data_preparator_workflow/pipeline/project/metrics.py new file mode 100644 index 000000000..9c9b70016 --- /dev/null +++ b/examples/RANO/data_preparator_workflow/pipeline/project/metrics.py @@ -0,0 +1,57 @@ +import os +import yaml +import argparse +import pandas as pd + + +def calculate_statistics(splits_path, invalid_path, out_file): + invalid_subjects = [] + if os.path.exists(invalid_path): + with open(invalid_path, "r") as f: + invalid_subjects = f.readlines() + + splits_df = pd.read_csv(splits_path) + + num_train_subjects = len(splits_df[splits_df["Split"] == "Train"]) + num_val_subjects = len(splits_df[splits_df["Split"] == "Val"]) + + num_invalid_subjects = len(invalid_subjects) + + stats = { + "num_train_subjects": num_train_subjects, + "num_val_subjects": num_val_subjects, + "num_invalid_subjects": num_invalid_subjects, + } + + with open(out_file, "w") as f: + yaml.dump(stats, f) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("MedPerf Statistics Example") + parser.add_argument( + "--data_path", + dest="data", + type=str, + help="directory containing the prepared data", + ) + parser.add_argument( + "--labels_path", + dest="labels", + ) + parser.add_argument( + "--out_file", dest="out_file", type=str, help="file to store statistics" + ) + parser.add_argument( + "--metadata_path", + dest="metadata_path", + type=str, + help="path to the local metadata folder", + ) + + args = parser.parse_args() + + splits_path = os.path.join(args.data, "splits.csv") + invalid_path = os.path.join(args.metadata_path, ".invalid.txt") + + calculate_statistics(splits_path, invalid_path) diff --git a/examples/RANO/data_preparator_workflow/pipeline/project/requirements.txt b/examples/RANO/data_preparator_workflow/pipeline/project/requirements.txt new file mode 100644 index 000000000..a46e7b08e --- /dev/null +++ 
b/examples/RANO/data_preparator_workflow/pipeline/project/requirements.txt @@ -0,0 +1,12 @@ +typer==0.9.0 +pandas==1.5.3 +PyYAML==6.0.1 +filelock==3.16.1 +# Include all your requirements here +SimpleITK==2.3.1 +tqdm==4.66.2 +scikit-image==0.21.0 +FigureGenerator==0.0.4 +gandlf==0.0.16 +labelfusion==1.0.14 +nibabel==5.1.0 diff --git a/examples/RANO/data_preparator_workflow/pipeline/project/sanity_check.py b/examples/RANO/data_preparator_workflow/pipeline/project/sanity_check.py new file mode 100644 index 000000000..d2137e37b --- /dev/null +++ b/examples/RANO/data_preparator_workflow/pipeline/project/sanity_check.py @@ -0,0 +1,45 @@ +import yaml +import argparse +import pandas as pd + +from stages.utils import has_prepared_folder_structure + + +def sanity_check(data_path: str, labels_path: str): + """Runs a few checks to ensure data quality and integrity + + Args: + data_path (str): Path to data. + labels_path (str): Path to labels. + """ + # Here you must add all the checks you consider important regarding the + # state of the data + if not has_prepared_folder_structure(data_path, labels_path): + print("The contents of the labels and data don't resemble a prepared dataset", flush=True) + exit(1) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Medperf Model Sanity Check Example") + parser.add_argument( + "--data_path", + dest="data", + type=str, + help="directory containing the prepared data", + ) + parser.add_argument( + "--labels_path", + dest="labels", + type=str, + help="directory containing the prepared labels", + ) + parser.add_argument( + "--metadata_path", + dest="metadata_path", + type=str, + help="path to the local metadata folder", + ) + + args = parser.parse_args() + + sanity_check(args.data, args.labels) diff --git a/examples/RANO/data_preparator_workflow/pipeline/project/stages/comparison.py b/examples/RANO/data_preparator_workflow/pipeline/project/stages/comparison.py new file mode 100644 index 000000000..9d1b04c85 --- /dev/null +++ 
b/examples/RANO/data_preparator_workflow/pipeline/project/stages/comparison.py @@ -0,0 +1,178 @@ +from typing import Union, Tuple +import os + +import pandas as pd +from pandas import DataFrame +import numpy as np +import nibabel as nib + +from .row_stage import RowStage +from .utils import ( + get_id_tp, + get_changed_voxels_file, + md5_file, + get_manual_approval_finalized_path, +) +from .constants import INTERIM_FOLDER +from .mlcube_constants import ( + COMPARISON_STAGE_STATUS, + GROUND_TRUTH_PATH, + TUMOR_EXTRACTION_REVIEW_PATH, + TUMOR_PATH, +) +from .env_vars import DATA_DIR + + +class SegmentationComparisonStage(RowStage): + def __init__( + self, + data_csv: str, + out_path: str, + prev_stage_path, + backup_path: str, + ): + self.data_csv = data_csv + self.out_path = out_path + self.prev_stage_path = prev_stage_path + self.backup_path = backup_path + + @property + def name(self): + return "Label Segmentation Comparison" + + @property + def status_code(self): + return COMPARISON_STAGE_STATUS + + def __get_input_path(self, index: Union[str, int]) -> str: + id, tp = get_id_tp(index) + path = get_manual_approval_finalized_path(id, tp, TUMOR_EXTRACTION_REVIEW_PATH) + return path + + def __get_backup_path(self, index: Union[str, int]) -> str: + id, tp = get_id_tp(index) + path = os.path.join( + DATA_DIR, TUMOR_PATH, id, tp, INTERIM_FOLDER, id, tp, GROUND_TRUTH_PATH + ) + return path + + def __get_output_path(self, index: Union[str, int]) -> str: + id, tp = get_id_tp(index) + path = os.path.join(self.out_path, id, tp) + return path + + def __get_case_path(self, index: Union[str, int]) -> str: + path = self.__get_input_path(index) + case = os.listdir(path)[0] + + return os.path.join(path, case) + + def __report_gt_not_found( + self, index: Union[str, int], report: pd.DataFrame, reviewed_hash: str + ) -> pd.DataFrame: + case_path = self.__get_case_path(index) + data_path = report.loc[index, "data_path"] + report_data = { + "status": -self.status_code - 0.2, # -6.2 + 
"data_path": data_path, + "labels_path": case_path, + "segmentation_hash": reviewed_hash, + } + + return report + + def __report_exact_match( + self, index: Union[str, int], report: pd.DataFrame, reviewed_hash: str + ) -> pd.DataFrame: + case_path = self.__get_case_path(index) + data_path = report.loc[index, "data_path"] + report_data = { + "status": -self.status_code - 0.1, # -6.1 + "data_path": data_path, + "labels_path": case_path, + "num_changed_voxels": 0, + "segmentation_hash": reviewed_hash, + } + + return report + + def __report_success( + self, + index: Union[str, int], + report: pd.DataFrame, + num_changed_voxels: int, + reviewed_hash: str, + ) -> pd.DataFrame: + case_path = self.__get_case_path(index) + data_path = report.loc[index, "data_path"] + report_data = { + "status": -self.status_code, # -6 + "data_path": data_path, + "labels_path": case_path, + "num_changed_voxels": num_changed_voxels, + "segmentation_hash": reviewed_hash, + } + update_row_with_dict(report, report_data, index) + return report + + def could_run(self, index: Union[str, int], report: DataFrame) -> bool: + print(f"Checking if {self.name} can run") + # Ensure a single reviewed segmentation file exists + path = self.__get_input_path(index) + gt_path = self.__get_backup_path(index) + + is_valid = True + path_exists = os.path.exists(path) + gt_path_exists = os.path.exists(gt_path) + contains_case = False + reviewed_hash = None + if path_exists: + cases = os.listdir(path) + num_cases = len(cases) + if num_cases: + reviewed_file = os.path.join(path, cases[0]) + reviewed_hash = md5_file(reviewed_file) + contains_case = num_cases == 1 + + prev_hash = report.loc[index]["segmentation_hash"] + hash_changed = prev_hash != reviewed_hash + print( + f"{path_exists=} and {contains_case=} and {gt_path_exists=} and {hash_changed=}" + ) + is_valid = path_exists and contains_case and gt_path_exists and hash_changed + + return is_valid + + def execute(self, index: Union[str, int]) -> Tuple[DataFrame, 
bool]: + + id, tp = get_id_tp(index) + path = self.__get_input_path(index) + cases = os.listdir(path) + print(f"{path=}") + print(f"{cases=}") + match_output_path = self.__get_output_path(index) + os.makedirs(match_output_path, exist_ok=True) + # Get the necessary files for match check + # We assume reviewed and gt files have the same name + reviewed_file = os.path.join(path, cases[0]) + reviewed_hash = md5_file(reviewed_file) + gt_file = os.path.join(self.__get_backup_path(index), cases[0]) + print(f"{gt_file=}") + if not os.path.exists(gt_file): + print("Ground truth file not found, reviewed file most probably renamed") + report = self.__report_gt_not_found(index, report, reviewed_hash) + raise ValueError("Ground truth file not found") + + reviewed_img = nib.load(reviewed_file) + gt_img = nib.load(gt_file) + + reviewed_voxels = np.array(reviewed_img.dataobj) + gt_voxels = np.array(gt_img.dataobj) + + num_changed_voxels = np.sum(reviewed_voxels != gt_voxels) + print(f"{num_changed_voxels=}") + changed_voxels_file = get_changed_voxels_file(id, tp) + with open(changed_voxels_file, "w") as f: + f.write(str(num_changed_voxels)) + + return True diff --git a/examples/RANO/data_preparator_workflow/pipeline/project/stages/confirm.py b/examples/RANO/data_preparator_workflow/pipeline/project/stages/confirm.py new file mode 100644 index 000000000..1a02ddb56 --- /dev/null +++ b/examples/RANO/data_preparator_workflow/pipeline/project/stages/confirm.py @@ -0,0 +1,178 @@ +from typing import Union, Tuple +import os +import shutil +from time import sleep + +import pandas as pd +from pandas import DataFrame + +from .dset_stage import DatasetStage +from .utils import ( + get_id_tp, + get_subdirectories, + get_manual_approval_finalized_path, + get_changed_voxels_file, + find_finalized_subjects, +) +from .constants import FINAL_FOLDER +from .mlcube_constants import ( + CONFIRM_STAGE_STATUS, + TUMOR_EXTRACTION_REVIEW_PATH, + AUX_FILES_PATH, +) +from .env_vars import DATA_DIR + + +class 
ConfirmStage(DatasetStage): + def __init__( + self, + out_data_path: str, + out_labels_path: str, + prev_stage_path: str, + backup_path: str, + ): + self.out_data_path = out_data_path + self.out_labels_path = out_labels_path + self.prev_stage_path = prev_stage_path + self.backup_path = backup_path + self.prompt_file = ".prompt.txt" + self.response_file = ".response.txt" + + @property + def name(self): + return "Annotations Confirmation" + + @property + def status_code(self): + return CONFIRM_STAGE_STATUS + + def __get_input_data_path(self, id, tp): + path = os.path.join(self.prev_stage_path, id, tp, FINAL_FOLDER, id, tp) + return path + + def __get_input_label_path(self, id, tp): + path = get_manual_approval_finalized_path(id, tp, TUMOR_EXTRACTION_REVIEW_PATH) + + case = os.listdir(path)[0] + + return os.path.join(path, case) + + def __get_output_data_path(self, id, tp): + path = os.path.join(self.out_data_path, id, tp) + return path + + def __get_output_label_path(self, id, tp): + path = os.path.join(self.out_labels_path, id, tp) + filename = f"{id}_{tp}_final_seg.nii.gz" + return path, filename + + def __confirm(self, exact_match_percent: float) -> bool: + exact_match_percent = round(exact_match_percent * 100, 2) + msg = ( + f"We've identified {exact_match_percent}% of cases have not been modified " + + "with respect to the baseline segmentation. Do you confirm this is intended? 
" + + "[Y]/n" + ) + + prompt_path = os.path.join(self.out_data_path, self.prompt_file) + response_path = os.path.join(self.out_data_path, self.response_file) + + with open(prompt_path, "w") as f: + f.write(msg) + + while not os.path.exists(response_path): + sleep(1) + + with open(response_path, "r") as f: + user_input = f.readline().strip() + + os.remove(prompt_path) + os.remove(response_path) + + return user_input == "y" or user_input == "" + + def __report_failure(self, report: DataFrame) -> DataFrame: + # For this stage, failure is done when the user doesn't confirm + # This means he probably wants to keep working on the data + # And needs to know which rows are exact matches. + # Because of this, failing this stage keeps the report intact + return report + + def __process_row(self, subject_dict) -> pd.Series: + """process a row by moving the required files + to their respective locations, and removing any extra files + + Args: + report (DataFrame): data preparation report + + Returns: + DataFrame: modified data preparation report + """ + subject_id, timepoint = subject_dict["SubjectID"], subject_dict["Timepoint"] + input_data_path = self.__get_input_data_path(subject_id, timepoint) + input_label_filepath = self.__get_input_label_path(subject_id, timepoint) + output_data_path = self.__get_output_data_path(subject_id, timepoint) + output_label_path, filename = self.__get_output_label_path( + subject_id, timepoint + ) + output_label_filepath = os.path.join(output_label_path, filename) + + shutil.rmtree(output_data_path, ignore_errors=True) + shutil.copytree(input_data_path, output_data_path) + os.makedirs(output_label_path, exist_ok=True) + shutil.copy(input_label_filepath, output_label_filepath) + + def could_run(self, report: DataFrame) -> bool: + print(f"Checking if {self.name} can run") + # could run once all cases have been compared to the ground truth + missing_voxels = report["num_changed_voxels"].isnull().values.any() + prev_path_exists = 
os.path.exists(self.prev_stage_path) + empty_prev_path = True + if prev_path_exists: + empty_prev_path = len(os.listdir(self.prev_stage_path)) == 0 + + print( + f"{prev_path_exists=} and not {empty_prev_path=} and not {missing_voxels=}" + ) + return prev_path_exists and not empty_prev_path and not missing_voxels + + @staticmethod + def calculate_exact_match_percent(): + """ + This value is equal to the sum of all subjects where no voxels were + changed in the Tumor Segmentation divided by the total number of subjects. + """ + base_aux_dir = os.path.join(DATA_DIR, AUX_FILES_PATH) + num_subjects = 0 + num_unchanged_subjects = 0 + + for subject_id in get_subdirectories(base_aux_dir): + complete_subject_path = os.path.join(base_aux_dir, subject_id) + for timepoint in get_subdirectories(complete_subject_path): + changed_voxels_file = get_changed_voxels_file(subject_id, timepoint) + if not os.path.isfile(changed_voxels_file): + continue + num_subjects += 1 + with open(changed_voxels_file, "r") as f: + changed_voxels = int(f.read()) + + if not changed_voxels: + num_unchanged_subjects += 1 + + return num_unchanged_subjects / num_subjects + + def execute(self) -> Tuple[DataFrame, bool]: + exact_match_percent = self.calculate_exact_match_percent() + + rounded_percent = round(exact_match_percent * 100, 2) + msg_file = os.path.join(DATA_DIR, AUX_FILES_PATH, ".msg") + print(f"{str(rounded_percent)=}") + with open(msg_file, "w") as f: + f.write(str(rounded_percent)) + return + + def move_labels(self): + finalized_subjects = find_finalized_subjects() + for finalized_subject in finalized_subjects: + self.__process_row(finalized_subject) + return True diff --git a/examples/RANO/data_preparator_workflow/pipeline/project/stages/dset_stage.py b/examples/RANO/data_preparator_workflow/pipeline/project/stages/dset_stage.py new file mode 100644 index 000000000..9792a9e26 --- /dev/null +++ b/examples/RANO/data_preparator_workflow/pipeline/project/stages/dset_stage.py @@ -0,0 +1,32 @@ +from 
abc import ABC, abstractmethod +import pandas as pd +from typing import Tuple + +from .stage import Stage + + +class DatasetStage(Stage, ABC): + @abstractmethod + def could_run(self, report: pd.DataFrame) -> bool: + """Establishes if this step could be executed + + Args: + index (Union[str, int]): case index in the report + report (pd.DataFrame): Dataframe containing the current state of the preparation flow + + Returns: + bool: wether this stage could be executed + """ + + @abstractmethod + def execute(self, report: pd.DataFrame) -> Tuple[pd.DataFrame, bool]: + """Executes the stage + + Args: + index (Union[str, int]): case index in the report + report (pd.DataFrame): DataFrame containing the current state of the preparation flow + + Returns: + pd.DataFrame: Updated report dataframe + bool: Success status + """ diff --git a/examples/RANO/data_preparator_workflow/pipeline/project/stages/env_vars.py b/examples/RANO/data_preparator_workflow/pipeline/project/stages/env_vars.py new file mode 100644 index 000000000..c9f72e5aa --- /dev/null +++ b/examples/RANO/data_preparator_workflow/pipeline/project/stages/env_vars.py @@ -0,0 +1,7 @@ +import os + + +WORKSPACE_DIR = os.getenv("WORKSPACE_DIRECTORY") +DATA_DIR = os.path.join(WORKSPACE_DIR, "data") +DATA_SUBDIR = os.path.join(*DATA_DIR.split(os.sep)[-2:]) +INPUT_DIR = os.path.join(WORKSPACE_DIR, "input_data") diff --git a/examples/RANO/data_preparator_workflow/pipeline/project/stages/extract.py b/examples/RANO/data_preparator_workflow/pipeline/project/stages/extract.py new file mode 100644 index 000000000..5c4ed7456 --- /dev/null +++ b/examples/RANO/data_preparator_workflow/pipeline/project/stages/extract.py @@ -0,0 +1,152 @@ +from typing import Union, Tuple +from tqdm import tqdm +import pandas as pd +import os +import shutil + +from .row_stage import RowStage +from .PrepareDataset import Preparator +from .utils import get_id_tp +from .constants import FINAL_FOLDER, EXEC_NAME + + +class Extract(RowStage): + def __init__( 
+ self, + data_csv: str, + out_path: str, + subpath: str, + prev_stage_path: str, + prev_subpath: str, + # pbar: tqdm, + func_name: str, + status_code: int, + extra_labels_path=None, + ): + self.data_csv = data_csv + self.out_path = out_path + self.subpath = subpath + self.data_subpath = FINAL_FOLDER + self.prev_path = prev_stage_path + self.prev_subpath = prev_subpath + os.makedirs(self.out_path, exist_ok=True) + self.prep = Preparator(data_csv, out_path, EXEC_NAME) + self.func_name = func_name + self.func = getattr(self.prep, func_name) + self.pbar = tqdm() + self.failed = False + self.exception = None + self.__status_code = status_code + self.extra_labels_path = extra_labels_path or [] + + @property + def name(self) -> str: + return self.func_name.replace("_", " ").capitalize() + + @property + def status_code(self) -> str: + return self.__status_code + + def could_run(self, index: Union[str, int], report: pd.DataFrame) -> bool: + """Determine if case at given index needs to be converted to NIfTI + + Args: + index (Union[str, int]): Case index, as used by the report dataframe + report (pd.DataFrame): Report Dataframe for providing additional context + + Returns: + bool: Wether this stage could be executed for the given case + """ + print(f"Checking if {self.name} can run") + prev_paths = self.__get_paths(index, self.prev_path, self.prev_subpath) + is_valid = all([os.path.exists(path) for path in prev_paths]) + print(f"{is_valid=}") + return is_valid + + def execute( + self, + index: Union[str, int], + ) -> Tuple[pd.DataFrame, bool]: + """Executes the NIfTI transformation stage on the given case + + Args: + index (Union[str, int]): case index, as used by the report + report (pd.DataFrame): DataFrame containing the current state of the preparation flow + + Returns: + pd.DataFrame: Updated report dataframe + """ + self.__prepare_exec() + self.__copy_case(index) + try: + self._process_case(index) + except Exception as e: + del_paths = self.__get_paths(index, 
self.out_path, self.subpath) + for del_path in del_paths: + shutil.rmtree(del_path, ignore_errors=True) + raise + + success = self.__update_state(index) + self.prep.write() + + return success + + def __prepare_exec(self): + # Reset the file contents for errors + open(self.prep.stderr_log, "w").close() + + # Update the out dataframes to current state + self.prep.read() + + def __get_paths(self, index: Union[str, int], path: str, subpath: str): + id, tp = get_id_tp(index) + data_path = os.path.join(path, self.data_subpath, id, tp) + out_path = os.path.join(path, subpath, id, tp) + return data_path, out_path + + def __copy_case(self, index: Union[str, int]): + prev_paths = self.__get_paths(index, self.prev_path, self.prev_subpath) + copy_paths = self.__get_paths(index, self.out_path, self.prev_subpath) + for prev, copy in zip(prev_paths, copy_paths): + shutil.rmtree(copy, ignore_errors=True) + shutil.copytree(prev, copy, dirs_exist_ok=True) + + def _process_case(self, index: Union[str, int]): + id, tp = get_id_tp(index) + df = self.prep.subjects_df + row_search = df[(df["SubjectID"] == id) & (df["Timepoint"] == tp)] + if len(row_search) > 0: + row = row_search.iloc[0] + else: + # Most probably this case was semi-prepared. 
Mock a row + row = pd.Series( + { + "SubjectID": id, + "Timepoint": tp, + "T1": "", + "T1GD": "", + "T2": "", + "FLAIR": "", + } + ) + self.func(row, self.pbar) + + def __hide_paths(self, hide_paths): + for path in hide_paths: + dirname = os.path.dirname(path) + hidden_name = f".{os.path.basename(path)}" + hidden_path = os.path.join(dirname, hidden_name) + if os.path.exists(hidden_path): + shutil.rmtree(hidden_path) + shutil.move(path, hidden_path) + + def __update_state(self, index: Union[str, int]) -> bool: + # Backup the paths in case we need to revert to this stage + hide_paths = self.__get_paths(index, self.prev_path, self.prev_subpath) + # Wait a little so that file gets created + # Handle the case where a brain mask doesn't exist + # Due to the subject being semi-prepared + success = True + self.__hide_paths(hide_paths) + + return success diff --git a/examples/RANO/data_preparator_workflow/pipeline/project/stages/extract_nnunet.py b/examples/RANO/data_preparator_workflow/pipeline/project/stages/extract_nnunet.py new file mode 100644 index 000000000..045ab6551 --- /dev/null +++ b/examples/RANO/data_preparator_workflow/pipeline/project/stages/extract_nnunet.py @@ -0,0 +1,232 @@ +from typing import Union +from tqdm import tqdm +import os +import shutil +import time +import SimpleITK as sitk +import subprocess + +from .extract import Extract +from .PrepareDataset import ( + Preparator, + FINAL_FOLDER, + generate_tumor_segmentation_fused_images, + save_screenshot, +) +from .utils import ( + get_id_tp, + get_manual_approval_under_review_path, + get_manual_approval_finalized_path, +) +from .mlcube_constants import ( + GROUND_TRUTH_PATH, + TUMOR_EXTRACTION_REVIEW_PATH, + BRAIN_MASK_REVIEW_PATH, + BRAIN_MASK_FILE, +) +from .constants import INTERIM_FOLDER, FINAL_FOLDER + +MODALITY_MAPPING = { + "t1c": "t1c", + "t1ce": "t1c", + "t1": "t1n", + "t1n": "t1n", + "t2": "t2w", + "t2w": "t2w", + "t2f": "t2f", + "flair": "t2f", +} + +MODALITY_VARIANTS = { + "t1c": "T1GD", + 
"t1ce": "T1GD", + "t1": "T1", + "t1n": "T1", + "t2": "T2", + "t2w": "T2", + "t2f": "FLAIR", + "flair": "FLAIR", +} + + +class ExtractNnUNet(Extract): + def __init__( + self, + data_csv: str, + out_path: str, + subpath: str, + prev_stage_path: str, + prev_subpath: str, + status_code: int, + extra_labels_path=None, + nnunet_executable: str = "/nnunet_env/bin/nnUNet_predict", + ): + self.data_csv = data_csv + self.out_path = out_path + self.subpath = subpath + self.data_subpath = FINAL_FOLDER + self.prev_path = prev_stage_path + self.prev_subpath = prev_subpath + os.makedirs(self.out_path, exist_ok=True) + self.prep = Preparator(data_csv, out_path, "BraTSPipeline") + self.pbar = tqdm() + self.failed = False + self.exception = None + self.__status_code = status_code + self.extra_labels_path = extra_labels_path or [] + self.nnunet_executable = nnunet_executable + + @property + def name(self) -> str: + return "nnUNet Tumor Extraction" + + @property + def status_code(self) -> str: + return self.__status_code + + def __get_models(self): + models_path = os.path.join(os.environ["RESULTS_FOLDER"], "nnUNet", "3d_fullres") + return os.listdir(models_path) + + def __get_mod_order(self, model): + order_path = os.path.join( + os.environ["RESULTS_FOLDER"], + os.pardir, + "nnUNet_modality_order", + model, + "order", + ) + with open(order_path, "r") as f: + order_str = f.readline() + # remove 'order = ' from the splitted list + modalities = order_str.split()[2:] + modalities = [MODALITY_MAPPING[mod] for mod in modalities] + return modalities + + def __prepare_case(self, path, id, tp, order): + tmp_subject = f"{id}-{tp}" + tmp_path = os.path.join(path, "tmp-data") + tmp_subject_path = os.path.join(tmp_path, tmp_subject) + tmp_out_path = os.path.join(path, "tmp-out", tmp_subject) + shutil.rmtree(tmp_subject_path, ignore_errors=True) + shutil.rmtree(tmp_out_path, ignore_errors=True) + os.makedirs(tmp_subject_path) + os.makedirs(tmp_out_path) + in_modalities_path = os.path.join(path, 
FINAL_FOLDER, id, tp) + input_modalities = {} + for modality_file in os.listdir(in_modalities_path): + if not modality_file.endswith(".nii.gz"): + continue + modality = modality_file[:-7].split("_")[-1] + norm_mod = MODALITY_MAPPING[modality] + mod_idx = order.index(norm_mod) + mod_idx = str(mod_idx).zfill(4) + + out_modality_file = f"{tmp_subject}_{mod_idx}.nii.gz" + in_file = os.path.join(in_modalities_path, modality_file) + out_file = os.path.join(tmp_subject_path, out_modality_file) + input_modalities[MODALITY_VARIANTS[modality]] = in_file + shutil.copyfile(in_file, out_file) + + return tmp_subject_path, tmp_out_path, input_modalities + + def __run_model(self, model, data_path, out_path): + # models are named Task_..., where is always 3 numbers + task_id = model[4:7] + cmd = f"{self.nnunet_executable} -i {data_path} -o {out_path} -t {task_id}" + print(cmd) + print(os.listdir(data_path)) + start = time.time() + subprocess.call(cmd, shell=True) + end = time.time() + total_time = end - start + print(f"Total time elapsed is {total_time} seconds") + + def __finalize_pred(self, tmp_out_path, out_pred_filepath, *copy_paths): + # We assume there's only one file in out_path + pred = None + for file in os.listdir(tmp_out_path): + if file.endswith(".nii.gz"): + pred = file + + if pred is None: + raise RuntimeError("No tumor segmentation was found") + + pred_filepath = os.path.join(tmp_out_path, pred) + pred_dir = os.path.dirname(pred_filepath) + os.makedirs(pred_dir, exist_ok=True) + shutil.move(pred_filepath, out_pred_filepath) + for copy_path in copy_paths: + copy_dir = os.path.dirname(copy_path) + os.makedirs(copy_dir, exist_ok=True) + shutil.copy(out_pred_filepath, copy_path) + return out_pred_filepath + + def _process_case(self, index: Union[str, int]): + id, tp = get_id_tp(index) + subject_id = f"{id}_{tp}" + + models = self.__get_models() + outputs = [] + images_for_fusion = [] + out_path = os.path.join(self.out_path, INTERIM_FOLDER, id, tp) + out_pred_path = 
    def _process_case(self, index: Union[str, int]):
        """Run every installed nnU-Net model on one case, fuse the resulting
        tumor masks, save review screenshots, and stage the brain mask and
        tumor masks for manual approval."""
        id, tp = get_id_tp(index)
        subject_id = f"{id}_{tp}"

        models = self.__get_models()
        outputs = []
        images_for_fusion = []
        out_path = os.path.join(self.out_path, INTERIM_FOLDER, id, tp)
        out_pred_path = os.path.join(out_path, GROUND_TRUTH_PATH)
        finalized_tumor_path = get_manual_approval_finalized_path(
            id, tp, TUMOR_EXTRACTION_REVIEW_PATH
        )
        under_review_tumor_path = get_manual_approval_under_review_path(
            id, tp, TUMOR_EXTRACTION_REVIEW_PATH
        )
        finalized_brain_mask_path = get_manual_approval_finalized_path(
            id, tp, BRAIN_MASK_REVIEW_PATH
        )
        under_review_brain_mask_path = get_manual_approval_under_review_path(
            id, tp, BRAIN_MASK_REVIEW_PATH
        )
        brain_mask_filepath = os.path.join(out_path, BRAIN_MASK_FILE)
        brain_mask_review_filepath = os.path.join(
            under_review_brain_mask_path, BRAIN_MASK_FILE
        )

        os.makedirs(out_pred_path, exist_ok=True)

        # One prediction per installed model; each is moved to the hidden
        # ground-truth folder and copied to the under-review folder.
        for i, model in enumerate(models):
            order = self.__get_mod_order(model)
            tmp_data_path, tmp_out_path, input_modalities = self.__prepare_case(
                self.out_path, id, tp, order
            )
            filename = f"{id}_{tp}_tumorMask_model_{i}.nii.gz"
            out_pred_filepath = os.path.join(out_pred_path, filename)
            under_review_filepath = os.path.join(under_review_tumor_path, filename)
            self.__run_model(model, tmp_data_path, tmp_out_path)
            output = self.__finalize_pred(
                tmp_out_path, out_pred_filepath, under_review_filepath
            )
            outputs.append(output)
            images_for_fusion.append(sitk.ReadImage(output, sitk.sitkUInt8))

            # cleanup
            shutil.rmtree(tmp_data_path, ignore_errors=True)
            shutil.rmtree(tmp_out_path, ignore_errors=True)

        fused_outputs = generate_tumor_segmentation_fused_images(
            images_for_fusion, out_pred_path, subject_id
        )
        outputs += fused_outputs

        for output in outputs:
            # save the screenshot
            tumor_mask_id = os.path.basename(output).replace(".nii.gz", "")
            # NOTE(review): input_modalities is the value left over from the
            # *last* loop iteration; this raises NameError if no models are
            # installed — confirm that case cannot happen.
            save_screenshot(
                input_modalities,
                os.path.join(
                    out_path,
                    f"{tumor_mask_id}_summary.png",
                ),
                output,
            )

        os.makedirs(under_review_brain_mask_path, exist_ok=True)
        shutil.copy(brain_mask_filepath, brain_mask_review_filepath)
        os.makedirs(finalized_brain_mask_path, exist_ok=True)
        os.makedirs(finalized_tumor_path, exist_ok=True)
# --- generate_report.py (module header) ---
import os
import re
import shutil
from typing import Tuple

import numpy as np
import pandas as pd

# NOTE(review): this module additionally imports DatasetStage from .dset_stage,
# helpers from .utils and constants from .mlcube_constants (project-local).

# DICOM filename prefixes identifying each expected modality folder.
DICOM_MODALITIES_PREFIX = {
    "fl": "t2_Flair",
    "t1": "t1_axial-3",
    "t1c": "t1_axial_stealth",
    "t2": "T2_SAG",
}
NIFTI_MODALITIES = ["t1c", "t1n", "t2f", "t2w"]
BRAIN_SCAN_NAME = "brain_(.*)"
TUMOR_SEG_NAME = "final_seg"
CSV_HEADERS = ["SubjectID", "Timepoint", "T1", "T1GD", "T2", "FLAIR"]


def get_index(subject, timepoint):
    """Build the report index key for a subject/timepoint pair."""
    return "{}|{}".format(subject, timepoint)


def has_alternative_folder_structure(subject_tp_path, og_path):
    """Detect the "alternative" layout: a directory whose DICOM files carry all
    four expected modality prefixes.

    Descends into the first subdirectory encountered. Returns a
    (found, path) tuple where path is the directory holding the DICOMs when
    found, otherwise the original og_path.
    """
    seen = dict.fromkeys(DICOM_MODALITIES_PREFIX.values(), False)
    for entry in os.listdir(subject_tp_path):
        entry_path = os.path.join(subject_tp_path, entry)
        # Search recursively across folders
        if os.path.isdir(entry_path):
            return has_alternative_folder_structure(entry_path, og_path)

        # Only DICOM files with a known prefix count towards detection
        if not entry.endswith(".dcm"):
            continue
        for prefix in seen:
            if entry.startswith(prefix):
                seen[prefix] = True

    # All prefixes present at this level => the structure was identified here
    if all(seen.values()):
        return True, subject_tp_path

    # Structure not identified at this tree
    return False, og_path
def to_expected_folder_structure(subject_tp_path, contents_path):
    """Rearrange a detected "alternative" layout into the expected one: one
    folder per modality key, DICOMs moved in by filename prefix, and any
    non-modality folders removed."""
    # One folder per expected modality
    for modality in DICOM_MODALITIES_PREFIX:
        os.mkdir(os.path.join(subject_tp_path, modality))

    # Route each DICOM into its modality folder based on its filename prefix
    prefix_to_modality = {
        prefix: modality for modality, prefix in DICOM_MODALITIES_PREFIX.items()
    }
    for dicom in os.listdir(contents_path):
        for prefix, modality in prefix_to_modality.items():
            if not dicom.startswith(prefix):
                continue
            shutil.move(
                os.path.join(contents_path, dicom),
                os.path.join(subject_tp_path, modality, dicom),
            )

    # Drop anything that is not one of the modality folders
    expected = set(DICOM_MODALITIES_PREFIX)
    for leftover in set(os.listdir(subject_tp_path)) - expected:
        shutil.rmtree(os.path.join(subject_tp_path, leftover))


def has_semiprepared_folder_structure(subject_tp_path, og_path, recursive=True):
    """Detect the "semi-prepared" layout: NIfTI files ending in
    _brain_<modality>.nii.gz for all four modalities.

    When recursive, descends into the first subdirectory found; otherwise
    subdirectories are skipped. Returns a (found, path) tuple, with path being
    where the NIfTIs live on success and og_path otherwise.
    """
    seen = dict.fromkeys(NIFTI_MODALITIES, False)
    for entry in os.listdir(subject_tp_path):
        entry_path = os.path.join(subject_tp_path, entry)
        if os.path.isdir(entry_path):
            if recursive:
                return has_semiprepared_folder_structure(entry_path, og_path)
            continue

        if not entry.endswith(".nii.gz"):
            continue

        for modality in seen:
            if entry.endswith(f"_brain_{modality}.nii.gz"):
                seen[modality] = True

    if all(seen.values()):
        return True, subject_tp_path

    return False, og_path
def get_timepoints(subject, subject_tp_path):
    """Infer the timepoints for a semi-prepared subject.

    Any directory in the subject folder is treated as a timepoint; otherwise
    the timepoint is parsed out of brain-scan / final-seg NIfTI filenames.
    """
    timepoints = set()
    # Hoisted out of the loop: the pattern is loop-invariant.
    pattern = re.compile(
        rf"{subject}_(.*)_(?:{BRAIN_SCAN_NAME}|{TUMOR_SEG_NAME})\.nii\.gz"
    )
    for content in os.listdir(subject_tp_path):
        # BUG FIX: was os.path.join(subject_tp_path, subject), so the isdir
        # check always inspected the wrong path and directories were never
        # recognized as timepoints.
        content_path = os.path.join(subject_tp_path, content)
        if os.path.isdir(content_path):
            # Assume any directory at this point represents a timepoint
            timepoints.add(content)
            continue

        result = pattern.search(content)
        if result is None:
            # FIX: skip files that don't follow the naming convention instead
            # of crashing with AttributeError on result.group().
            continue
        timepoints.add(result.group(1))

    return list(timepoints)


def get_tumor_segmentation(subject, timepoint, subject_tp_path):
    """Return the final-seg filename for this case if present, else None."""
    contents = os.listdir(subject_tp_path)
    seg_file = f"{subject}_{timepoint}_{TUMOR_SEG_NAME}.nii.gz"
    if seg_file in contents:
        return seg_file
    return None


def move_brain_scans(subject, timepoint, in_subject_path, out_data_path):
    """Copy this case's brain-scan NIfTIs into <out>/<final>/<subject>/<tp>."""
    final_path = os.path.join(out_data_path, FINAL_FOLDER, subject, timepoint)
    os.makedirs(final_path, exist_ok=True)

    contents = os.listdir(in_subject_path)

    # FIX: raw f-string — "\." in a plain f-string is an invalid escape
    # sequence (SyntaxWarning on modern Python).
    pattern = re.compile(rf"{subject}_{timepoint}_{BRAIN_SCAN_NAME}\.nii\.gz")
    brain_scans = [content for content in contents if pattern.match(content)]

    for scan in brain_scans:
        in_scan = os.path.join(in_subject_path, scan)
        out_scan = os.path.join(final_path, scan)
        shutil.copyfile(in_scan, out_scan)


def move_tumor_segmentation(
    subject, timepoint, seg_file, in_subject_path, out_data_path, out_labels_path
):
    """Copy the provided segmentation into the interim tumor-mask tree (root,
    under_review and finalized) plus a hidden backup folder.

    Returns (input_seg_path, finalized_seg_path).
    """
    interim_path = os.path.join(out_data_path, INTERIM_FOLDER, subject, timepoint)
    os.makedirs(interim_path, exist_ok=True)

    in_seg_path = os.path.join(in_subject_path, seg_file)
    tumor_mask_path = os.path.join(interim_path, TUMOR_MASK_FOLDER)
    under_review_path = os.path.join(tumor_mask_path, "under_review")
    finalized_path = os.path.join(tumor_mask_path, FINALIZED_PATH)
    os.makedirs(under_review_path, exist_ok=True)
    os.makedirs(finalized_path, exist_ok=True)

    seg_root_path = os.path.join(tumor_mask_path, seg_file)
    seg_under_review_path = os.path.join(under_review_path, seg_file)
    seg_finalized_path = os.path.join(finalized_path, seg_file)
    shutil.copyfile(in_seg_path, seg_root_path)
    shutil.copyfile(in_seg_path, seg_under_review_path)
    shutil.copyfile(in_seg_path, seg_finalized_path)

    # Place the segmentation in the backup folder
    backup_path = os.path.join(out_labels_path, ".tumor_segmentation_backup")
    subject_tp_backup_path = os.path.join(
        backup_path, subject, timepoint, TUMOR_MASK_FOLDER
    )
    os.makedirs(subject_tp_backup_path, exist_ok=True)
    seg_backup_path = os.path.join(subject_tp_backup_path, seg_file)
    shutil.copyfile(in_seg_path, seg_backup_path)

    return in_seg_path, seg_finalized_path
def write_partial_csv(csv_path, subject, timepoint):
    """Append a minimal (SubjectID, Timepoint) row to the data CSV.

    Used when cases are semi-prepared, in which case they skip the formal CSV
    creation. No-op if the row already exists; scan columns are left empty.
    """
    if csv_path is None:
        csv_path = get_data_csv_filepath(os.path.join(subject, timepoint))

    if os.path.exists(csv_path):
        df = pd.read_csv(csv_path)
    else:
        df = pd.DataFrame(columns=CSV_HEADERS)

    row = pd.Series(index=CSV_HEADERS)
    row["SubjectID"] = subject
    row["Timepoint"] = timepoint
    row.name = get_index(subject, timepoint)
    row = row.fillna("")

    # Only append when this subject/timepoint is not already present
    row_search = df[(df["SubjectID"] == subject) & (df["Timepoint"] == timepoint)]
    if len(row_search) == 0:
        # FIX: DataFrame.append was deprecated in pandas 1.4 and removed in
        # pandas 2.0; pd.concat with the row as a one-row frame is equivalent.
        df = pd.concat([df, row.to_frame().T])

    df.to_csv(csv_path, index=False)


class InitialSetup(DatasetStage):
    """First pipeline stage: normalizes the raw input layout, seeds the report,
    and routes each case to the right downstream stage."""

    def __init__(
        self,
        data_csv: str,
        input_path: str,
        output_path: str,
        input_labels_path: str,
        output_labels_path,
        done_data_out_path: str,
        done_status: int,
        brain_data_out_path: str,
        brain_status: int,
        tumor_data_out_path: str,
        reviewed_status: int,
    ):
        self.data_csv = data_csv
        self.input_path = input_path
        self.output_path = output_path
        self.input_labels_path = input_labels_path
        self.output_labels_path = output_labels_path
        self.done_data_out_path = done_data_out_path
        self.done_status_code = done_status
        self.brain_data_out_path = brain_data_out_path
        self.brain_status = brain_status
        self.tumor_data_out_path = tumor_data_out_path
        self.reviewed_status = reviewed_status

    @property
    def name(self) -> str:
        return "Initial Setup"

    @property
    def status_code(self) -> int:
        return SETUP_STAGE_STATUS
    def _proceed_to_comparison(self, subject, timepoint, in_subject_path, report):
        """Route a semi-prepared case that already has a tumor segmentation
        straight to the comparison/review stage, updating the report."""
        index = get_index(subject, timepoint)
        final_path = os.path.join(
            self.tumor_data_out_path, FINAL_FOLDER, subject, timepoint
        )
        input_hash = md5_dir(in_subject_path)
        # Stop if the subject was already present and no input change has happened
        if index in report.index:
            if input_hash == report.loc[index]["input_hash"]:
                return report

        # Move brain scans to its expected location
        move_brain_scans(subject, timepoint, in_subject_path, self.tumor_data_out_path)

        # Move tumor segmentation to its expected location
        seg_file = f"{subject}_{timepoint}_{TUMOR_SEG_NAME}.nii.gz"
        _, seg_finalized_path = move_tumor_segmentation(
            subject,
            timepoint,
            seg_file,
            in_subject_path,
            self.tumor_data_out_path,
            self.output_labels_path,
        )

        # Update the report
        data = {
            "status": self.reviewed_status,
            "data_path": final_path,
            "labels_path": seg_finalized_path,
            "num_changed_voxels": np.nan,
            "brain_mask_hash": "",
            "segmentation_hash": "",
            "input_hash": input_hash,
        }

        subject_series = pd.Series(data)
        subject_series.name = index
        # NOTE(review): DataFrame.append was removed in pandas 2.0; this file
        # requires pandas < 2 (or should migrate to pd.concat).
        report = report.append(subject_series)

        write_partial_csv(self.data_csv, subject, timepoint)

        return report

    def _proceed_to_tumor_extraction(self, subject, timepoint, in_subject_path, report):
        """Route a semi-prepared case without a segmentation to the tumor
        extraction stage, updating the report."""
        index = get_index(subject, timepoint)
        input_hash = md5_dir(in_subject_path)
        # Stop if the subject was already present and no input change has happened
        if index in report.index:
            if input_hash == report.loc[index]["input_hash"]:
                return report
        final_path = os.path.join(
            self.brain_data_out_path, FINAL_FOLDER, subject, timepoint
        )
        labels_path = os.path.join(
            self.brain_data_out_path, INTERIM_FOLDER, subject, timepoint
        )
        os.makedirs(final_path, exist_ok=True)
        os.makedirs(labels_path, exist_ok=True)

        # Move brain scans to its expected location
        move_brain_scans(subject, timepoint, in_subject_path, self.brain_data_out_path)

        # Update the report
        data = {
            "status": self.brain_status,
            "data_path": final_path,
            "labels_path": labels_path,
            "num_changed_voxels": np.nan,
            "brain_mask_hash": "",
            "segmentation_hash": "",
            "input_hash": input_hash,
        }

        subject_series = pd.Series(data)
        subject_series.name = index
        # NOTE(review): see append-removal note in _proceed_to_comparison
        report = report.append(subject_series)

        write_partial_csv(self.data_csv, subject, timepoint)

        return report

    def could_run(self, report: pd.DataFrame):
        # The setup stage is always applicable.
        return True

    def execute(self, report: pd.DataFrame) -> Tuple[pd.DataFrame, bool]:
        """Scan the input tree, normalize each case's layout, seed or refresh
        its report entry, and drop entries whose input disappeared."""
        # Rewrite the report
        cols = [
            "status",
            "status_name",
            "comment",
            "data_path",
            "labels_path",
            "input_hash",
        ]
        print("Initializing report")
        if report is None:
            print("No previous report was identified. Creating a new one")
            report = pd.DataFrame(columns=cols)

        input_is_prepared = has_prepared_folder_structure(
            self.input_path, self.input_labels_path
        )
        if input_is_prepared:
            # If prepared, store data directly in the data folder
            print("Input data looks prepared already. Skipping preprocessing")
            self.output_path = self.done_data_out_path

        observed_cases = set()

        for subject in os.listdir(self.input_path):
            in_subject_path = os.path.join(self.input_path, subject)
            out_subject_path = os.path.join(self.output_path, subject)
            in_labels_subject_path = os.path.join(self.input_labels_path, subject)
            out_labels_subject_path = os.path.join(self.output_labels_path, subject)

            if not os.path.isdir(in_subject_path):
                continue

            # Semi-prepared at the subject level: route each timepoint directly
            has_semiprepared, _ = has_semiprepared_folder_structure(
                in_subject_path, in_subject_path, recursive=False
            )
            if has_semiprepared:
                timepoints = get_timepoints(subject, in_subject_path)
                for timepoint in timepoints:
                    index = get_index(subject, timepoint)
                    tumor_seg = get_tumor_segmentation(
                        subject, timepoint, in_subject_path
                    )
                    if tumor_seg is not None:
                        report = self._proceed_to_comparison(
                            subject, timepoint, in_subject_path, report
                        )
                    else:
                        report = self._proceed_to_tumor_extraction(
                            subject, timepoint, in_subject_path, report
                        )
                    observed_cases.add(index)
                continue

            for timepoint in os.listdir(in_subject_path):
                in_tp_path = os.path.join(in_subject_path, timepoint)
                out_tp_path = os.path.join(out_subject_path, timepoint)
                in_labels_tp_path = os.path.join(in_labels_subject_path, timepoint)
                out_labels_tp_path = os.path.join(out_labels_subject_path, timepoint)

                if not os.path.isdir(in_tp_path):
                    continue

                input_hash = md5_dir(in_tp_path)

                index = get_index(subject, timepoint)

                # Keep track of the cases that were found on the input folder
                observed_cases.add(index)

                # Semi-prepared at the timepoint level
                has_semiprepared, in_tp_path = has_semiprepared_folder_structure(
                    in_tp_path, in_tp_path, recursive=True
                )
                if has_semiprepared:
                    tumor_seg = get_tumor_segmentation(subject, timepoint, in_tp_path)
                    if tumor_seg is not None:
                        report = self._proceed_to_comparison(
                            subject, timepoint, in_tp_path, report
                        )
                    else:
                        report = self._proceed_to_tumor_extraction(
                            subject, timepoint, in_tp_path, report
                        )
                    continue

                if index in report.index:
                    # Case has already been identified, see if input hash is different
                    # or if status is corrupted
                    # if so, override the contents and restart the state for that case
                    case = report.loc[index]
                    has_not_changed = case["input_hash"] == input_hash
                    # NOTE(review): np.isnan raises TypeError on non-numeric
                    # status values (e.g. strings) — confirm status is numeric.
                    has_a_valid_status = not np.isnan(case["status"])
                    if has_not_changed and has_a_valid_status:
                        continue

                    print(
                        f"Case {index} has either changed ({not has_not_changed}) or has a corrupted status ({not has_a_valid_status}). Starting from scratch"
                    )

                    shutil.rmtree(out_tp_path, ignore_errors=True)
                    shutil.copytree(in_tp_path, out_tp_path)
                    report = report.drop(index)
                else:
                    # New case not identified by the report. Add it
                    print(f"New case identified: {index}. Adding to report")
                    shutil.rmtree(out_tp_path, ignore_errors=True)
                    shutil.copytree(in_tp_path, out_tp_path)

                data = {
                    "status": self.status_code,
                    "data_path": out_tp_path,
                    "labels_path": "",
                    "num_changed_voxels": np.nan,
                    "brain_mask_hash": "",
                    "segmentation_hash": "",
                    "input_hash": input_hash,
                }

                has_alternative, contents_path = has_alternative_folder_structure(
                    out_tp_path, out_tp_path
                )
                if has_alternative:
                    # Move files around so it has the expected structure
                    to_expected_folder_structure(out_tp_path, contents_path)

                if input_is_prepared:
                    # NOTE(review): every other write uses the "status" key/column;
                    # "status_code" looks like a typo that leaves the prepared
                    # case at the setup status and adds a stray column — verify.
                    data["status_code"] = self.done_status_code
                    shutil.rmtree(out_labels_tp_path, ignore_errors=True)
                    shutil.copytree(in_labels_tp_path, out_labels_tp_path)

                subject_series = pd.Series(data)
                subject_series.name = index
                # NOTE(review): see append-removal note in _proceed_to_comparison
                report = report.append(subject_series)

        reported_cases = set(report.index)
        removed_cases = reported_cases - observed_cases

        # Stop reporting removed cases
        for case_index in removed_cases:
            report = report.drop(case_index)

        report = report.sort_index()
        return report
b/examples/RANO/data_preparator_workflow/pipeline/project/stages/get_csv.py new file mode 100644 index 000000000..9bf7deca8 --- /dev/null +++ b/examples/RANO/data_preparator_workflow/pipeline/project/stages/get_csv.py @@ -0,0 +1,123 @@ +from .row_stage import RowStage +from .CreateCSVForDICOMs import CSVCreator +from .utils import get_id_tp +import pandas as pd +from typing import Union, Tuple +import os +import shutil +from .mlcube_constants import CSV_STAGE_STATUS + + +class AddToCSV(RowStage): + def __init__( + self, input_dir: str, output_csv: str, out_dir: str, prev_stage_path: str + ): + self.input_dir = input_dir + self.output_csv = output_csv + self.out_dir = out_dir + self.prev_stage_path = prev_stage_path + os.makedirs(self.out_dir, exist_ok=True) + self.csv_processor = CSVCreator(self.input_dir, self.output_csv) + if os.path.exists(self.output_csv): + # Use the updated version of the CSV + self.contents = pd.read_csv(self.output_csv) + self.csv_processor.output_df_for_csv = self.contents + else: + # Use the default, empty version + self.contents = self.csv_processor.output_df_for_csv + + @property + def name(self) -> str: + return "Initial Validation" + + @property + def status_code(self) -> int: + return CSV_STAGE_STATUS + + def could_run(self, index: Union[str, int], report: pd.DataFrame) -> bool: + """Determines if getting a new CSV is necessary. 
+ This is done by checking the existence of the expected file + + Args: + index (Union[str, int]): case index in the report + report (pd.DataFrame): Dataframe containing the current state of the preparation flow + + Returns: + bool: wether this stage could be executed + """ + print(f"Checking if {self.name} can run") + id, tp = get_id_tp(index) + prev_case_path = os.path.join(self.prev_stage_path, id, tp) + is_valid = os.path.exists(prev_case_path) + print(f"{is_valid=}") + return is_valid + + def execute(self, index: Union[str, int]) -> Tuple[pd.DataFrame, bool]: + """Adds valid cases to the data csv that is used for later processing + Invalid cases are flagged in the report + + Args: + index (Union[str, int]): case index in the report + report (pd.DataFrame): DataFrame containing the current state of the preparation flow + + Returns: + pd.DataFrame: Updated report dataframe + """ + id, tp = get_id_tp(index) + subject_path = os.path.join(self.input_dir, id) + tp_path = os.path.join(subject_path, tp) + subject_out_path = os.path.join(self.out_dir, id) + tp_out_path = os.path.join(subject_out_path, tp) + # We will first copy the timepoint to the out folder + # This is so, if successful, the csv will point to the data + # in the next stage, instead of the previous + shutil.rmtree(tp_out_path, ignore_errors=True) + shutil.copytree(tp_path, tp_out_path) + + try: + self.csv_processor.process_timepoint(tp, id, subject_out_path) + report_data = { + "status": self.status_code, + "data_path": tp_out_path, + "labels_path": "", + } + except Exception as e: + report_data = { + "status": -self.status_code - 0.3, + "comment": str(e), + "data_path": tp_path, + "labels_path": "", + } + raise + + missing = self.csv_processor.subject_timepoint_missing_modalities + extra = self.csv_processor.subject_timepoint_extra_modalities + + success = True + report_data["comment"] = "" + for missing_subject, msg in missing: + if f"{id}_{tp}" in missing_subject: + # Differentiate errors by 
floating point value + status_code = -self.status_code - 0.1 # -1.1 + report_data["status"] = status_code + report_data["data_path"] = tp_path + report_data["comment"] += "\n\n" + msg + success = False + + for extra_subject, msg in extra: + if f"{id}_{tp}" in extra_subject: + # Differentiate errors by floating point value + status_code = -self.status_code - 0.2 # -1.2 + report_data["status"] = status_code + report_data["data_path"] = tp_path + report_data["comment"] += "\n\n" + msg + success = False + + report_data["comment"] = report_data["comment"].strip() + if not success: + shutil.rmtree(tp_out_path, ignore_errors=True) + raise TypeError(report_data["comment"]) + + self.csv_processor.write() + + return tp_out_path diff --git a/examples/RANO/data_preparator_workflow/pipeline/project/stages/manual.py b/examples/RANO/data_preparator_workflow/pipeline/project/stages/manual.py new file mode 100644 index 000000000..dc2bb562b --- /dev/null +++ b/examples/RANO/data_preparator_workflow/pipeline/project/stages/manual.py @@ -0,0 +1,182 @@ +from typing import Union, Tuple +import pandas as pd +import os +import shutil +import json +from .row_stage import RowStage +from .constants import INTERIM_FOLDER, FINAL_FOLDER +from .env_vars import DATA_DIR +from .mlcube_constants import ( + MANUAL_STAGE_STATUS, + BRAIN_MASK_CHANGED_FILE, + PREP_PATH, + BRAIN_MASK_REVIEW_PATH, + TUMOR_EXTRACTION_REVIEW_PATH, + BRAIN_MASK_FILE, + GROUND_TRUTH_PATH, +) +from .utils import ( + get_id_tp, + set_files_read_only, + copy_files, + get_manual_approval_finalized_path, + get_manual_approval_base_path, + delete_empty_directory, +) + + +class ManualStage(RowStage): + def __init__( + self, + data_csv: str, + out_path: str, + prev_stage_path: str, + backup_path: str, + ): + self.data_csv = data_csv + self.out_path = out_path + self.prev_stage_path = prev_stage_path + self.backup_path = backup_path + + @property + def name(self): + return "Manual review" + + @property + def status_code(self) -> int: 
        return MANUAL_STAGE_STATUS

    def __get_input_paths(self, index: Union[str, int]):
        """Return (tumor-mask dir from the previous stage, finalized brain-mask file)."""
        id, tp = get_id_tp(index)
        print(f"{self.prev_stage_path=}")
        tumor_mask_path = os.path.join(
            self.prev_stage_path, INTERIM_FOLDER, id, tp, GROUND_TRUTH_PATH
        )
        brain_mask_dir = get_manual_approval_finalized_path(
            id, tp, BRAIN_MASK_REVIEW_PATH
        )
        brain_mask_path = os.path.join(brain_mask_dir, BRAIN_MASK_FILE)
        return tumor_mask_path, brain_mask_path

    def __get_output_path(self, index: Union[str, int]):
        """Finalized tumor-extraction review folder for this case."""
        id, tp = get_id_tp(index)
        path = get_manual_approval_finalized_path(id, tp, TUMOR_EXTRACTION_REVIEW_PATH)
        return path

    def __get_backup_path(self, index: Union[str, int]):
        """Backup folder holding the read-only baseline segmentations."""
        id, tp = get_id_tp(index)
        path = os.path.join(self.backup_path, id, tp, GROUND_TRUTH_PATH)
        return path

    def __get_rollback_paths(self, index: Union[str, int]):
        """Hidden prep-stage (data, labels) paths used to restore prior state.

        NOTE(review): id/tp appear twice in these paths (in base_rollback_path
        and again below) — confirm this nesting is intentional.
        """
        id, tp = get_id_tp(index)
        base_rollback_path = os.path.join(DATA_DIR, PREP_PATH, id, tp)
        data_path = os.path.join(base_rollback_path, FINAL_FOLDER, id, tp)
        labels_path = os.path.join(base_rollback_path, INTERIM_FOLDER, id, tp)
        return data_path, labels_path

    def rollback(self, index):
        """Restore the hidden prep-stage data for a case after its brain mask
        was changed, and wipe the derived outputs and review folders."""
        # Unhide the rollback paths
        rollback_paths = self.__get_rollback_paths(index)
        for rollback_path in rollback_paths:
            rollback_dirname = os.path.dirname(rollback_path)
            rollback_basename = os.path.basename(rollback_path)
            hidden_rollback_path = os.path.join(
                rollback_dirname, f".{rollback_basename}"
            )

            if os.path.exists(hidden_rollback_path):
                shutil.move(hidden_rollback_path, rollback_path)

        # Move the modified brain mask to the rollback path
        _, rollback_labels_path = rollback_paths
        tumor_masks_path, brain_mask_path = self.__get_input_paths(index)
        rollback_brain_mask_path = os.path.join(rollback_labels_path, BRAIN_MASK_FILE)
        if os.path.exists(rollback_brain_mask_path):
            os.remove(rollback_brain_mask_path)
        shutil.move(brain_mask_path, rollback_brain_mask_path)

        # Remove the complete subject path
        # (four levels up from <prev>/<interim>/<id>/<tp>/<ground_truth>)
        subject_path = os.path.abspath(
            os.path.join(tumor_masks_path, "..", "..", "..", "..")
        )

        shutil.rmtree(subject_path)

        # Also remove the manual review directory for this subject/timepoint
        id, tp = get_id_tp(index)
        for approval_type in [TUMOR_EXTRACTION_REVIEW_PATH, BRAIN_MASK_REVIEW_PATH]:
            manual_review_path = get_manual_approval_base_path(id, tp, approval_type)
            shutil.rmtree(manual_review_path)
            subject_review_path = os.path.abspath(
                os.path.join(manual_review_path, "..")
            )
            delete_empty_directory(subject_review_path)

    def prepare_directories(self, index: Union[str, int]) -> Tuple[str, str]:
        """Back up the baseline segmentations (read-only) if not done yet.

        Returns (finalized tumor output dir, finalized brain-mask file path).
        """
        # Generate a hidden copy of the baseline segmentations
        in_path, brain_path = self.__get_input_paths(index)
        out_path = self.__get_output_path(index)
        bak_path = self.__get_backup_path(index)
        print(f"{in_path=}")
        print(f"{out_path=}")
        print(f"{bak_path=}")
        if not os.path.exists(bak_path) or not os.listdir(bak_path):
            print("Entered if")
            copy_files(in_path, bak_path)
            set_files_read_only(bak_path)

        return out_path, brain_path

    def could_run(self, index: Union[str, int], report: pd.DataFrame) -> bool:
        """Runnable when a segmentation exists and either no single finalized
        annotation is present or the brain mask hash differs from the report."""
        print(f"Checking if {self.name} can run")
        out_path = self.__get_output_path(index)
        cases = []
        if os.path.exists(out_path):
            cases = os.listdir(out_path)

        in_path, brain_path = self.__get_input_paths(index)
        brain_mask_hash = ""
        if os.path.exists(brain_path):
            # NOTE(review): md5_file is not among this module's visible
            # imports — confirm it is provided (e.g. via .utils) or this
            # raises NameError at runtime.
            brain_mask_hash = md5_file(brain_path)
        expected_brain_mask_hash = report.loc[index, "brain_mask_hash"]

        segmentation_exists = os.path.exists(in_path)
        annotation_exists = len(cases) == 1
        brain_mask_changed = brain_mask_hash != expected_brain_mask_hash
        print(
            f"{segmentation_exists=} and (not {annotation_exists=} or {brain_mask_changed=})"
        )
        return segmentation_exists and (not annotation_exists or brain_mask_changed)

    def execute(
        self, index: Union[str, int], report: pd.DataFrame = None
    ) -> Tuple[pd.DataFrame, bool]:
        """Manual steps are by definition not doable by an algorithm. Therefore,
        execution of this step leads to a failed stage message, indicating that
        the manual step has not been done.

        Args:
            index (Union[str, int]): current case index
            report (pd.DataFrame): data preparation report

        Returns:
            Tuple[pd.DataFrame, bool]: updated report and continue flag
        """

        # Generate a hidden copy of the baseline segmentations
        out_path, brain_path = self.prepare_directories(index)

        if report is None:
            # NOTE(review): load_report is not among this module's visible
            # imports — confirm its origin.
            report = load_report()
        # NOTE(review): check_brain_mask_changed, __report_rollback and
        # check_finalized_cases are not defined in the visible portion of this
        # class — confirm they exist (e.g. on RowStage or a mixin).
        brain_mask_changed, brain_mask_hash = self.check_brain_mask_changed(
            index, brain_path, report
        )

        if brain_mask_changed:
            # Found brain mask changed
            self.rollback(index)
            # Label this as able to continue
            return self.__report_rollback(index, report, brain_mask_hash), True

        return self.check_finalized_cases(index, report, out_path)


# --- mlcube_constants.py ---
# Folder-layout constants shared by every pipeline stage.
RAW_PATH = "raw"
AUX_FILES_PATH = "auxiliary_files"
VALID_PATH = "validated"
PREP_PATH = "prepared"
BRAIN_PATH = "brain_extracted"
TUMOR_PATH = "tumor_extracted"
LABELS_PATH = "labels"
TUMOR_BACKUP_PATH = ".tumor_segmentation_backup"
OUT_CSV = "data.csv"
TRASH_PATH = ".trash"
INVALID_FILE = ".invalid.txt"
REPORT_FILE = "report.yaml"
BRAIN_MASK_FILE = "brainMask_fused.nii.gz"
METADATA_PATH = "metadata"
CHANGED_VOXELS_FILE = ".changed_voxels.txt"

# Directories used for the Manual Approval steps
MANUAL_REVIEW_PATH = "manual_review"
BRAIN_MASK_REVIEW_PATH = "brain_mask"
TUMOR_EXTRACTION_REVIEW_PATH = "tumor_extraction"
UNDER_REVIEW_PATH = "under_review"
FINALIZED_PATH = "finalized"

# Backup segmentation in case the user changes the one being used
GROUND_TRUTH_PATH = ".ground_truth"
+# JSON file (just true/false) for evaluating brain mask changes +BRAIN_MASK_CHANGED_FILE = "brain_mask_changed.json" + +SETUP_STAGE_STATUS = 0 +CSV_STAGE_STATUS = 1 +NIFTI_STAGE_STATUS = 2 +BRAIN_STAGE_STATUS = 3 +TUMOR_STAGE_STATUS = 4 +MANUAL_STAGE_STATUS = 5 +COMPARISON_STAGE_STATUS = 6 +CONFIRM_STAGE_STATUS = 7 +DONE_STAGE_STATUS = 8 diff --git a/examples/RANO/data_preparator_workflow/pipeline/project/stages/nifti_transform.py b/examples/RANO/data_preparator_workflow/pipeline/project/stages/nifti_transform.py new file mode 100644 index 000000000..f5e50f8e9 --- /dev/null +++ b/examples/RANO/data_preparator_workflow/pipeline/project/stages/nifti_transform.py @@ -0,0 +1,146 @@ +from typing import Union +from tqdm import tqdm +import pandas as pd +import os +import shutil + +from .row_stage import RowStage +from .PrepareDataset import Preparator +from .utils import get_id_tp +from .mlcube_constants import NIFTI_STAGE_STATUS +from .constants import FINAL_FOLDER, EXEC_NAME + + +class NIfTITransform(RowStage): + def __init__( + self, + data_csv: str, + out_path: str, + prev_stage_path: str, + metadata_path: str, + data_out: str, + ): + self.data_csv = data_csv + self.out_path = out_path + self.data_out = data_out + self.prev_stage_path = prev_stage_path + self.metadata_path = metadata_path + os.makedirs(self.out_path, exist_ok=True) + self.prep = Preparator(data_csv, out_path, EXEC_NAME) + # self.pbar = pbar + self.pbar = tqdm() + + @property + def name(self) -> str: + return "NiFTI Conversion" + + @property + def status_code(self) -> int: + return NIFTI_STAGE_STATUS + + def could_run(self, index: Union[str, int], report: pd.DataFrame) -> bool: + """Determine if case at given index needs to be converted to NIfTI + + Args: + index (Union[str, int]): Case index, as used by the report dataframe + report (pd.DataFrame): Report Dataframe for providing additional context + + Returns: + bool: Wether this stage could be executed for the given case + """ + print(f"Checking if 
{self.name} can run") + id, tp = get_id_tp(index) + prev_case_path = os.path.join(self.prev_stage_path, id, tp) + if os.path.exists(prev_case_path): + is_valid = len(os.listdir(prev_case_path)) > 0 + print(f"{is_valid}") + return is_valid + return False + + def execute( + self, + index: Union[str, int], + ) -> pd.DataFrame: + """Executes the NIfTI transformation stage on the given case + + Args: + index (Union[str, int]): case index, as used by the report + report (pd.DataFrame): DataFrame containing the current state of the preparation flow + + Returns: + pd.DataFrame: Updated report dataframe + """ + self.__prepare_exec() + self.__process_case(index) + self.__cleanup_artifacts(index) + success = self.__validate_result(index) + self.prep.write() + self.__update_metadata(index) + + return success + + def __cleanup_artifacts(self, index): + unused_artifacts_substrs = ["raw", "to_SRI", ".mat"] + _, out_path = self.__get_output_paths(index) + root_artifacts = os.listdir(out_path) + for artifact in root_artifacts: + if not any([substr in artifact for substr in unused_artifacts_substrs]): + continue + artifact_path = os.path.join(out_path, artifact) + os.remove(artifact_path) + + def __get_output_paths(self, index: Union[str, int]): + id, tp = get_id_tp(index) + fets_path = os.path.join(self.prep.final_output_dir, id, tp) + qc_path = os.path.join(self.prep.interim_output_dir, id, tp) + + return fets_path, qc_path + + def __prepare_exec(self): + # Reset the file contents for errors + open(self.prep.stderr_log, "w").close() + + self.prep.read() + + def __process_case(self, index: Union[str, int]): + id, tp = get_id_tp(index) + df = self.prep.subjects_df + row = df[(df["SubjectID"] == id) & (df["Timepoint"] == tp)].iloc[0] + self.prep.convert_to_dicom(hash(index), row, self.pbar) + + def __undo_current_stage_changes(self, index: Union[str, int]): + fets_path, qc_path = self.__get_output_paths(index) + shutil.rmtree(fets_path, ignore_errors=True) + shutil.rmtree(qc_path, 
ignore_errors=True) + + def __validate_result(self, index: Union[str, int]) -> pd.DataFrame: + id, tp = get_id_tp(index) + failing = self.prep.failing_subjects + failing_subject = failing[ + (failing["SubjectID"] == id) & (failing["Timepoint"] == tp) + ] + if len(failing_subject): + self.__undo_current_stage_changes(index) + self.__report_failure() + else: + success = True + + return success + + def __update_metadata(self, index): + id, tp = get_id_tp(index) + fets_path = os.path.join(self.out_path, FINAL_FOLDER) + outfile_dir = os.path.join(self.metadata_path, id, tp) + os.makedirs(outfile_dir, exist_ok=True) + for file in os.listdir(fets_path): + filepath = os.path.join(fets_path, file) + out_filepath = os.path.join(self.metadata_path, id, tp, file) + if os.path.isfile(filepath) and filepath.endswith(".yaml"): + shutil.copyfile(filepath, out_filepath) + + def __report_failure(self): + + with open(self.prep.stderr_log, "r") as f: + msg = f.read() + + raise TypeError(msg) diff --git a/examples/RANO/data_preparator_workflow/pipeline/project/stages/pipeline.py b/examples/RANO/data_preparator_workflow/pipeline/project/stages/pipeline.py new file mode 100644 index 000000000..c709fff25 --- /dev/null +++ b/examples/RANO/data_preparator_workflow/pipeline/project/stages/pipeline.py @@ -0,0 +1,285 @@ +from pandas import DataFrame +from typing import Union, List, Tuple +from tqdm import tqdm +import traceback +import os + +from .utils import write_report +from .dset_stage import DatasetStage +from .row_stage import RowStage +from .stage import Stage +from .utils import cleanup_storage +from .mlcube_constants import DONE_STAGE_STATUS + + +class Pipeline: + def __init__( + self, + init_stage: DatasetStage, + stages: List[Union[DatasetStage, RowStage]], + staging_folders: List[str], + trash_folders: List[str], + invalid_subjects_file: str, + ): + self.init_stage = init_stage + self.stages = stages + self.staging_folders = staging_folders + self.trash_folders = trash_folders + 
self.invalid_subjects_file = invalid_subjects_file + + def __invalid_subjects(self) -> List[str]: + """Retrieve invalid subjects + + Returns: + List[str]: list of invalid subjects + """ + if not os.path.exists(self.invalid_subjects_file): + open(self.invalid_subjects_file, "a").close() + + with open(self.invalid_subjects_file, "r") as f: + invalid_subjects = set([line.strip() for line in f]) + + return invalid_subjects + + def __is_subject_done(self, subject: Union[str, int], report: DataFrame) -> bool: + """Determines if a subject is considered done + + Args: + subject (Union[str, int]): subject index + report (DataFrame): DataFrame containing the state of the processing + + Returns: + bool: wether the subject is done or not + """ + subject_status = report.loc[subject, "status"] + + return subject_status == DONE_STAGE_STATUS + + def __is_done(self, report: DataFrame) -> bool: + """Determines if the preparation is complete + + Args: + report (DataFrame): DataFrame containing the state of the processing + + Returns: + bool: Wether the preparation is complete + """ + return all(report["status"] == DONE_STAGE_STATUS) + + def __get_report_stage_to_run( + self, subject: Union[str, int], report: DataFrame + ) -> Union[DatasetStage, RowStage]: + """Retrieves the stage a subject is in indicated by the report + + Args: + subject (Union[str, int]): Subject index + report (DataFrame): Dataframe containing the state of the processing + + Returns: + Union[DatasetStage, RowStage]: Stage the current subject is in + """ + report_status_code = int(report.loc[subject, "status"]) + if report_status_code < 0: + # Error code, rerun the stage specified in the report + report_status_code = abs(report_status_code) + else: + # Success code, reported stage works so move to the next one + report_status_code += 1 + for stage in self.stages: + if stage.status_code == report_status_code: + return stage + + return None + + def determine_next_stage( + self, subject: Union[str, int], report + ) -> 
Tuple[List[Union[DatasetStage, RowStage]], bool]: + """Determines what stage to run + First priority goes to a stage if it is the only one that could run. (only one stage can run) + Second priority goes to what the report says should run next. (The report knows what stage can run) + Third priority goes to the first of all possible stages that could run. (Earliest of all possible stages) + + Args: + subject (Union[str, int]): Subject name (SubjectID, Timepoint) + report (pd.DataFrame): report dataframe + + Returns: + Tuple[List[Union[DatasetStage, RowStage]], bool]: Stage to run, and wether it is done or not + """ + could_run_stages = [] + for i, stage in enumerate(self.stages): + could_run = False + if isinstance(stage, RowStage): + could_run = stage.could_run(subject, report) + else: + could_run = stage.could_run(report) + + if could_run: + runnable_stage = self.stages[i] + could_run_stages.append(runnable_stage) + + print(f"Possible next stages: {[stage.name for stage in could_run_stages]}") + + # TODO: split into a function + if len(could_run_stages) == 1: + stage = could_run_stages[0] + is_last_subject = subject == report.index[-1] + if isinstance(stage, DatasetStage) and not is_last_subject: + # Only run dataset stages on the last subject, so all subjects can update + # their state if needed before proceeding + return None, False + return stage, False + + # Handle errors + # Either no stage can be executed (len(could_run_stages == 0)) + # or multiple stages can be executed (len(could_run_stages > 1)) + report_stage = self.__get_report_stage_to_run(subject, report) + if report_stage is not None: + print(f"Reported next stage: {report_stage.name}") + + # TODO: split into a function + if len(could_run_stages) == 0: + # Either the case processing was on-going but it's state is broken + # or the next stage is a dataset stage, which means we're done with this one + # or the case is done and no stage can nor should run + # We can tell this by looking at the report + 
is_done = self.__is_subject_done(subject, report) + is_dset_stage = isinstance(report_stage, DatasetStage) + if is_done or is_dset_stage: + return None, True + else: + return None, False + # TODO: split into a function + else: + # Multiple stages could run. Remove ambiguity by + # syncing with the report + if report_stage in could_run_stages: + return report_stage, False + + return could_run_stages[0], False + + def run(self, report: DataFrame, report_path: str): + # cleanup the trash at the very beginning + cleanup_storage(self.trash_folders) + + # The init stage always has to be executed + report, _ = self.init_stage.execute(report) + write_report(report, report_path) + + invalid_subjects = self.__invalid_subjects() + + should_loop = True + should_stop = False + while should_loop: + + # Since we could have row and dataset stages interwoven, we want + # to make sure we continue processing subjects until nothing new has happened. + # This means we can resume a given subject and its row stages even after a dataset stage + prev_status = report["status"].copy() + subjects = list(report.index) + subjects_loop = tqdm(subjects) + + for subject in subjects_loop: + report, should_stop = self.process_subject( + subject, report, report_path, subjects_loop + ) + + if should_stop: + break + + # If a new invalid subject is identified, start over + new_invalid_subjects = self.__invalid_subjects() + if invalid_subjects != new_invalid_subjects: + invalid_subjects = new_invalid_subjects + # We're going to restart the subjects loop + break + + # Check for report differences. 
If there are, rerun the loop + should_loop = any(report["status"] != prev_status) and not should_stop + + if self.__is_done(report): + cleanup_folders = self.staging_folders + self.trash_folders + cleanup_storage(cleanup_folders) + + def process_subject( + self, subject: Union[int, str], report: DataFrame, report_path: str, pbar: tqdm + ): + should_stop = False + while True: + # Check if subject has been invalidated + invalid_subjects = self.__invalid_subjects() + if subject in invalid_subjects: + break + + # Filter out invalid subjects + working_report = report[~report.index.isin(invalid_subjects)].copy() + + print(f"Determining next stage for {subject}", flush=True) + stage, done = self.determine_next_stage(subject, working_report) + if stage is not None: + print(f"Next stage for {subject}: {stage.name}", flush=True) + + if done: + print(f"Subject {subject} is Done", flush=True) + break + + try: + working_report, successful = self.run_stage( + stage, subject, working_report, pbar + ) + except Exception: + # TODO: The superclass could be in charge of catching the error, reporting it and cleaning up + # and raise the exception again to be caught here + working_report = self.__report_unhandled_exception( + stage, subject, working_report + ) + print(traceback.format_exc()) + successful = False + + report.update(working_report) + write_report(report, report_path) + + if not successful: + # Send back a signal that a dset stage failed + if isinstance(stage, DatasetStage): + should_stop = True + break + + return report, should_stop + + def run_stage(self, stage, subject, report, pbar): + successful = False + if isinstance(stage, RowStage): + pbar.set_description(f"{subject} | {stage.name}") + report, successful = stage.execute(subject, report) + elif isinstance(stage, DatasetStage): + pbar.set_description(f"{stage.name}") + report, successful = stage.execute(report) + + return report, successful + + def __report_unhandled_exception( + self, + stage: Stage, + subject: 
Union[int, str],
+        report: DataFrame,
+    ):
+        # Assign a special status code for unhandled errors, associated
+        # to the stage status code
+        status_code = -stage.status_code - 0.101
+        name = f"{stage.name.upper().replace(' ', '_')}_UNHANDLED_ERROR"
+        comment = traceback.format_exc()
+        data_path = report.loc[subject]["data_path"]
+        labels_path = report.loc[subject]["labels_path"]
+
+        body = {
+            "status": status_code,
+            "status_name": name,
+            "comment": comment,
+            "data_path": data_path,
+            "labels_path": labels_path,
+        }
+
+        report.loc[subject] = body
+
+        return report
diff --git a/examples/RANO/data_preparator_workflow/pipeline/project/stages/row_stage.py b/examples/RANO/data_preparator_workflow/pipeline/project/stages/row_stage.py
new file mode 100644
index 000000000..70701beb0
--- /dev/null
+++ b/examples/RANO/data_preparator_workflow/pipeline/project/stages/row_stage.py
@@ -0,0 +1,34 @@
+from abc import ABC, abstractmethod
+from typing import Union, Tuple
+import pandas as pd
+
+from .stage import Stage
+
+
+class RowStage(Stage, ABC):
+    @abstractmethod
+    def could_run(self, index: Union[str, int], report: pd.DataFrame) -> bool:
+        """Establishes if this step could be executed for the given case
+
+        Args:
+            index (Union[str, int]): case index in the report
+            report (pd.DataFrame): Dataframe containing the current state of the preparation flow
+
+        Returns:
+            bool: whether this stage could be executed
+        """
+
+    @abstractmethod
+    def execute(
+        self, index: Union[str, int], report: pd.DataFrame
+    ) -> Tuple[pd.DataFrame, bool]:
+        """Executes the stage on the given case
+
+        Args:
+            index (Union[str, int]): case index in the report
+            report (pd.DataFrame): DataFrame containing the current state of the preparation flow
+
+        Returns:
+            pd.DataFrame: Updated report dataframe
+            bool: Success status
+        """
diff --git a/examples/RANO/data_preparator_workflow/pipeline/project/stages/split.py b/examples/RANO/data_preparator_workflow/pipeline/project/stages/split.py
new file mode 
100644 index 000000000..a22789a4f --- /dev/null +++ b/examples/RANO/data_preparator_workflow/pipeline/project/stages/split.py @@ -0,0 +1,154 @@ +import os +import yaml +import pandas as pd +from typing import List +import math + +from .dset_stage import DatasetStage +from .utils import ( + get_id_tp, + cleanup_storage, + safe_remove, + find_finalized_subjects, + delete_empty_directory, +) +from .mlcube_constants import DONE_STAGE_STATUS, METADATA_PATH +from .env_vars import WORKSPACE_DIR +from .constants import DICOM_ANON_FILENAME, DICOM_COLLAB_FILENAME + + +def row_to_path(row: pd.Series) -> str: + id = row["SubjectID"] + tp = row["Timepoint"] + return os.path.join(id, tp) + + +class SplitStage(DatasetStage): + def __init__( + self, + params: str, + data_path: str, + labels_path: str, + staging_folders: List[str], + base_finalized_dir: str, + ): + self.params = params + self.data_path = data_path + self.labels_path = labels_path + self.split_csv_path = os.path.join(data_path, "splits.csv") + self.train_csv_path = os.path.join(data_path, "train.csv") + self.val_csv_path = os.path.join(data_path, "val.csv") + self.staging_folders = staging_folders + self.base_finalized_dir = base_finalized_dir + + @property + def name(self) -> str: + return "Generate splits" + + @property + def status_code(self) -> int: + return DONE_STAGE_STATUS + + def could_run(self, report: pd.DataFrame) -> bool: + split_exists = os.path.exists(self.split_csv_path) + if split_exists: + # This stage already ran + return False + + for index in report.index: + id, tp = get_id_tp(index) + case_data_path = os.path.join(self.data_path, id, tp) + case_labels_path = os.path.join(self.labels_path, id, tp) + data_exists = os.path.exists(case_data_path) + labels_exist = os.path.exists(case_labels_path) + + if not data_exists or not labels_exist: + # Some subjects are not ready + return False + + return True + + def consolidate_metadata(self): + base_metadata_dir = os.path.join(WORKSPACE_DIR, METADATA_PATH) 
+ anon_dict = {} + collab_dict = {} + files_to_delete = set() + for subject_id_dir in os.listdir(base_metadata_dir): + try: + subject_complete_dir = os.path.join(base_metadata_dir, subject_id_dir) + + for timepoint_dir in os.listdir(subject_complete_dir): + subject_timepoint_complete_dir = os.path.join( + subject_complete_dir, timepoint_dir + ) + subject_metadata_path = os.path.join(subject_timepoint_complete_dir) + if not os.path.isdir(subject_metadata_path): + continue + + anon_yaml = os.path.join(subject_metadata_path, DICOM_ANON_FILENAME) + collab_yaml = os.path.join( + subject_metadata_path, DICOM_COLLAB_FILENAME + ) + + update_tuples = [(anon_yaml, anon_dict), (collab_yaml, collab_dict)] + for yaml_file, data_dict in update_tuples: + if not os.path.isfile(yaml_file): + continue + + with open(yaml_file, "r") as f: + update_dict = yaml.safe_load(f) + + data_dict.update(**update_dict) + files_to_delete.add(yaml_file) + except OSError: + pass + + final_anon_file = os.path.join(base_metadata_dir, DICOM_ANON_FILENAME) + final_collab_file = os.path.join(base_metadata_dir, DICOM_COLLAB_FILENAME) + + write_tuples = [(final_anon_file, anon_dict), (final_collab_file, collab_dict)] + for file_path, data in write_tuples: + with open(file_path, "w") as f: + yaml.dump(data, f) + + for file in files_to_delete: + safe_remove(file) + + for subdir in os.listdir(base_metadata_dir): + complete_subdir_path = os.path.join(base_metadata_dir, subdir) + delete_empty_directory(complete_subdir_path) + + def execute(self) -> pd.DataFrame: + with open(self.params, "r") as f: + params = yaml.safe_load(f) + + seed = params["seed"] + train_pct = params["train_percent"] + + finalized_subjects = find_finalized_subjects() + split_df = pd.DataFrame(finalized_subjects) + subjects = split_df["SubjectID"].drop_duplicates() + subjects = subjects.sample(frac=1, random_state=seed) + train_size = math.floor(len(subjects) * train_pct) + + train_subjects = subjects.iloc[:train_size] + val_subjects = 
subjects.iloc[train_size:] + + train_mask = split_df["SubjectID"].isin(train_subjects) + val_mask = split_df["SubjectID"].isin(val_subjects) + + split_df.loc[train_mask, "Split"] = "Train" + split_df.loc[val_mask, "Split"] = "Val" + + split_df.to_csv(self.split_csv_path, index=False) + + # Generate separate splits files with relative path + split_df["path"] = split_df.apply(row_to_path, axis=1) + + split_df.loc[train_mask].to_csv(self.train_csv_path, index=False) + split_df.loc[val_mask].to_csv(self.val_csv_path, index=False) + + self.consolidate_metadata() + cleanup_storage(self.staging_folders) + + return True diff --git a/examples/RANO/data_preparator_workflow/pipeline/project/stages/stage.py b/examples/RANO/data_preparator_workflow/pipeline/project/stages/stage.py new file mode 100644 index 000000000..ac453bd6d --- /dev/null +++ b/examples/RANO/data_preparator_workflow/pipeline/project/stages/stage.py @@ -0,0 +1,5 @@ +from abc import ABC + +class Stage(ABC): + name: str + status_code: int \ No newline at end of file diff --git a/examples/RANO/data_preparator_workflow/pipeline/project/stages/utils.py b/examples/RANO/data_preparator_workflow/pipeline/project/stages/utils.py new file mode 100644 index 000000000..f6a2130f4 --- /dev/null +++ b/examples/RANO/data_preparator_workflow/pipeline/project/stages/utils.py @@ -0,0 +1,320 @@ +import os +import shutil +from pandas import DataFrame +from tqdm import tqdm +from functools import reduce +from pathlib import Path +import hashlib +import yaml +import pandas as pd + +from .env_vars import DATA_DIR, DATA_SUBDIR +from .mlcube_constants import ( + OUT_CSV, + AUX_FILES_PATH, + FINALIZED_PATH, + MANUAL_REVIEW_PATH, + UNDER_REVIEW_PATH, + TUMOR_EXTRACTION_REVIEW_PATH, + REPORT_FILE, + CHANGED_VOXELS_FILE, +) + + +def convert_path_to_index(path: str): + as_list = path.split(os.sep) + as_index = "|".join(as_list) + return as_index + + +# Taken from 
https://code.activestate.com/recipes/577879-create-a-nested-dictionary-from-oswalk/ +def get_directory_structure(rootdir): + """ + Creates a nested dictionary that represents the folder structure of rootdir + """ + dir = {} + rootdir = rootdir.rstrip(os.sep) + start = rootdir.rfind(os.sep) + 1 + for path, dirs, files in os.walk(rootdir): + folders = path[start:].split(os.sep) + subdir = dict.fromkeys(files) + parent = reduce(dict.get, folders[:-1], dir) + parent[folders[-1]] = subdir + return dir + + +def get_subdirectories(base_directory: str): + return [ + subdir + for subdir in os.listdir(base_directory) + if os.path.isdir(os.path.join(base_directory, subdir)) + ] + + +def has_prepared_folder_structure(data_path, labels_path) -> bool: + data_struct = list(get_directory_structure(data_path).values())[0] + labels_struct = list(get_directory_structure(labels_path).values())[0] + + expected_data_files = [ + "brain_t1c.nii.gz", + "brain_t1n.nii.gz", + "brain_t2f.nii.gz", + "brain_t2w.nii.gz", + ] + expected_labels_files = ["final_seg.nii.gz"] + + if "splits.csv" not in data_struct: + return False + + for id in data_struct.keys(): + if data_struct[id] is None: + # This is a file, ignore + continue + for tp in data_struct[id].keys(): + expected_subject_data_files = set( + ["_".join([id, tp, file]) for file in expected_data_files] + ) + expected_subject_labels_files = set( + ["_".join([id, tp, file]) for file in expected_labels_files] + ) + + found_data_files = set(data_struct[id][tp].keys()) + found_labels_files = set(labels_struct[id][tp].keys()) + + data_files_diff = len(expected_subject_data_files - found_data_files) + labels_files_diff = len(expected_subject_labels_files - found_labels_files) + if data_files_diff or labels_files_diff: + return False + + # Passed all checks + return True + + +def normalize_path(path: str) -> str: + """Remove mlcube-specific components from the given path + + Args: + path (str): mlcube path + + Returns: + str: normalized path + """ + 
# for this specific problem, we know that all paths start with `/mlcube_io*` + # and that this pattern won't change, shrink or grow. We can therefore write a + # simple, specific solution + if path.startswith("/mlcube_io"): + return path[12:] + + # In case the path has already been normalized + return path + + +def unnormalize_path(path: str, parent: str) -> str: + """Add back mlcube-specific components to the given path + + Args: + path (str): normalized path + + Returns: + str: mlcube-specific path + """ + if path.startswith(os.path.sep): + path = path[1:] + return os.path.join(parent, path) + + +def load_report(subject_id: str, timepoint: str) -> pd.DataFrame: + report_path = get_report_yaml_filepath(subject_id, timepoint) + + try: + with open(report_path, "r") as f: + report_data = yaml.safe_load(f) + except FileNotFoundError: + report_data = None + + report_df = pd.DataFrame(report_data) + return report_df + + +def normalize_report_paths(report: DataFrame) -> DataFrame: + """Ensures paths are normalized and converts them to relative paths for the local machine + + Args: + report (DataFrame): report to normalize + + Returns: + DataFrame: report with transformed paths + """ + pattern = DATA_SUBDIR + report["data_path"] = report["data_path"].str.split(pattern).str[-1] + report["labels_path"] = report["labels_path"].str.split(pattern).str[-1] + return report + + +def write_report(report: DataFrame, subject_id: str, timepoint: str): + filepath = get_report_yaml_filepath(subject_id, timepoint) + report_dict = report.to_dict() + + # Use a temporary file to avoid quick writes collisions and corruption + temp_path = Path(filepath).parent / ".report.yaml" + with open(temp_path, "w") as f: + yaml.dump(report_dict, f) + os.rename(temp_path, filepath) + + +def get_id_tp(index: str): + return index.split("|") + + +def set_files_read_only(path): + for root, dirs, files in os.walk(path): + for file_name in files: + file_path = os.path.join(root, file_name) + 
os.chmod(file_path, 0o444) # Set read-only permission for files + + for dir_name in dirs: + dir_path = os.path.join(root, dir_name) + set_files_read_only( + dir_path + ) # Recursively call the function for subdirectories + + +def cleanup_storage(remove_folders): + for folder in remove_folders: + print(f"Deleting directory {folder}...") + shutil.rmtree(folder, ignore_errors=True) + + +def copy_files(src_dir, dest_dir): + # Ensure the destination directory exists + os.makedirs(dest_dir, exist_ok=True) + + # Iterate through the files in the source directory + for filename in os.listdir(src_dir): + src_file = os.path.join(src_dir, filename) + dest_file = os.path.join(dest_dir, filename) + + # Check if the item is a file (not a directory) + if os.path.isfile(src_file): + shutil.copy2(src_file, dest_file) # Copy the file + + +# Taken from https://stackoverflow.com/questions/24937495/how-can-i-calculate-a-hash-for-a-filesystem-directory-using-python +def md5_update_from_dir(directory, hash): + assert Path(directory).is_dir() + for path in sorted(Path(directory).iterdir(), key=lambda p: str(p).lower()): + hash.update(path.name.encode()) + if path.is_file(): + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash.update(chunk) + elif path.is_dir(): + hash = md5_update_from_dir(path, hash) + return hash + + +def md5_dir(directory): + return md5_update_from_dir(directory, hashlib.md5()).hexdigest() + + +def md5_file(filepath): + return hashlib.md5(open(filepath, "rb").read()).hexdigest() + + +class MockTqdm(tqdm): + def __getattr__(self, attr): + return lambda *args, **kwargs: None + + +def get_aux_files_dir(subject_subdir): + return os.path.join(DATA_DIR, AUX_FILES_PATH, subject_subdir) + + +def get_report_yaml_filepath(subject_id, timepoint): + yaml_dir = os.path.join( + DATA_DIR, AUX_FILES_PATH, os.path.join(subject_id, timepoint) + ) + return os.path.join(yaml_dir, REPORT_FILE) + + +def get_data_csv_filepath(subject_subdir): + csv_dir = 
get_aux_files_dir(subject_subdir) + return os.path.join(csv_dir, OUT_CSV) + + +def find_finalized_subjects(): + base_finalized_dir = os.path.join( + DATA_DIR, MANUAL_REVIEW_PATH, TUMOR_EXTRACTION_REVIEW_PATH + ) + subject_and_timepoint_list = [] + + candidate_subjects = get_subdirectories(base_finalized_dir) + for candidate_subject in candidate_subjects: + subject_path = os.path.join(base_finalized_dir, candidate_subject) + timepoint_dirs = get_subdirectories(subject_path) + + for timepoint in timepoint_dirs: + timepoint_complete_path = os.path.join(subject_path, timepoint) + finalized_path = os.path.join(timepoint_complete_path, FINALIZED_PATH) + try: + path_exists = os.path.exists(finalized_path) + path_is_dir = os.path.isdir(finalized_path) + only_one_case = len(os.listdir(finalized_path)) == 1 + if path_exists and path_is_dir and only_one_case: + subject_timepoint_dict = { + "SubjectID": candidate_subject, + "Timepoint": timepoint, + } + subject_and_timepoint_list.append(subject_timepoint_dict) + except OSError: + pass + return subject_and_timepoint_list + + +def get_manual_approval_base_path(subject_id, timepoint, approval_type): + manual_approval_base_path = os.path.join( + DATA_DIR, + MANUAL_REVIEW_PATH, + approval_type, + subject_id, + timepoint, + ) + return manual_approval_base_path + + +def get_manual_approval_finalized_path(subject_id, timepoint, approval_type): + base_path = get_manual_approval_base_path(subject_id, timepoint, approval_type) + full_path = os.path.join(base_path, FINALIZED_PATH) + return full_path + + +def get_manual_approval_under_review_path(subject_id, timepoint, approval_type): + base_path = get_manual_approval_base_path(subject_id, timepoint, approval_type) + full_path = os.path.join(base_path, UNDER_REVIEW_PATH) + return full_path + + +def safe_remove(path_to_remove: str): + try: + os.remove(path_to_remove) + except FileNotFoundError: + pass + + +def delete_empty_directory(path_to_directory: str): + if 
os.path.isdir(path_to_directory): + inside_this_dir = os.listdir(path_to_directory) + for subdir in inside_this_dir: + complete_path = os.path.join(path_to_directory, subdir) + delete_empty_directory(complete_path) + + # List again, could be empty now + inside_this_dir = os.listdir(path_to_directory) + if not inside_this_dir: + shutil.rmtree(path_to_directory) + + +def get_changed_voxels_file(subject_id, timepoint): + return os.path.join( + DATA_DIR, AUX_FILES_PATH, subject_id, timepoint, CHANGED_VOXELS_FILE + ) diff --git a/examples/RANO/data_preparator_workflow/readme_images/airflow_home.png b/examples/RANO/data_preparator_workflow/readme_images/airflow_home.png new file mode 100644 index 000000000..ed873fa9d Binary files /dev/null and b/examples/RANO/data_preparator_workflow/readme_images/airflow_home.png differ diff --git a/examples/RANO/data_preparator_workflow/readme_images/airflow_login.png b/examples/RANO/data_preparator_workflow/readme_images/airflow_login.png new file mode 100644 index 000000000..c89ff9b55 Binary files /dev/null and b/examples/RANO/data_preparator_workflow/readme_images/airflow_login.png differ diff --git a/examples/RANO/data_preparator_workflow/readme_images/dag_list.png b/examples/RANO/data_preparator_workflow/readme_images/dag_list.png new file mode 100644 index 000000000..f47ea0c06 Binary files /dev/null and b/examples/RANO/data_preparator_workflow/readme_images/dag_list.png differ diff --git a/examples/RANO/data_preparator_workflow/readme_images/filter_by_required_actions.png b/examples/RANO/data_preparator_workflow/readme_images/filter_by_required_actions.png new file mode 100644 index 000000000..5fc0a7141 Binary files /dev/null and b/examples/RANO/data_preparator_workflow/readme_images/filter_by_required_actions.png differ diff --git a/examples/RANO/data_preparator_workflow/readme_images/final_confirmation_approve_button.png b/examples/RANO/data_preparator_workflow/readme_images/final_confirmation_approve_button.png new file mode 
100644 index 000000000..f02e7788d Binary files /dev/null and b/examples/RANO/data_preparator_workflow/readme_images/final_confirmation_approve_button.png differ diff --git a/examples/RANO/data_preparator_workflow/readme_images/final_confirmation_dag.png b/examples/RANO/data_preparator_workflow/readme_images/final_confirmation_dag.png new file mode 100644 index 000000000..2b7b36e99 Binary files /dev/null and b/examples/RANO/data_preparator_workflow/readme_images/final_confirmation_dag.png differ diff --git a/examples/RANO/data_preparator_workflow/readme_images/pipeline_diagram.png b/examples/RANO/data_preparator_workflow/readme_images/pipeline_diagram.png new file mode 100644 index 000000000..66e57993a Binary files /dev/null and b/examples/RANO/data_preparator_workflow/readme_images/pipeline_diagram.png differ diff --git a/examples/RANO/data_preparator_workflow/readme_images/required_actions_dag.png b/examples/RANO/data_preparator_workflow/readme_images/required_actions_dag.png new file mode 100644 index 000000000..9cb6e734b Binary files /dev/null and b/examples/RANO/data_preparator_workflow/readme_images/required_actions_dag.png differ diff --git a/examples/RANO/data_preparator_workflow/readme_images/task_instances_view.png b/examples/RANO/data_preparator_workflow/readme_images/task_instances_view.png new file mode 100644 index 000000000..0ca6fba6c Binary files /dev/null and b/examples/RANO/data_preparator_workflow/readme_images/task_instances_view.png differ diff --git a/examples/RANO/data_preparator_workflow/readme_images/task_list_filtered_final_confirmation.png b/examples/RANO/data_preparator_workflow/readme_images/task_list_filtered_final_confirmation.png new file mode 100644 index 000000000..635d0dffe Binary files /dev/null and b/examples/RANO/data_preparator_workflow/readme_images/task_list_filtered_final_confirmation.png differ diff --git a/examples/RANO/data_preparator_workflow/readme_images/tasks_manual_review.png 
b/examples/RANO/data_preparator_workflow/readme_images/tasks_manual_review.png new file mode 100644 index 000000000..d9c6abab4 Binary files /dev/null and b/examples/RANO/data_preparator_workflow/readme_images/tasks_manual_review.png differ diff --git a/examples/RANO/data_preparator_workflow/workflow.yaml b/examples/RANO/data_preparator_workflow/workflow.yaml new file mode 100644 index 000000000..623c5c856 --- /dev/null +++ b/examples/RANO/data_preparator_workflow/workflow.yaml @@ -0,0 +1,336 @@ +base_step: &BASE_STEP + - type: container + image: mlcommons/rano-data-prep-workflow:0.0.2 + + +steps: + - id: setup + <<: *BASE_STEP + command: initial_setup + mounts: + input_volumes: + data_path: + mount_path: /workspace/input_data + type: directory + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + next: make_csv + on_error: do_something + partition: false + + - id: make_csv + <<: *BASE_STEP + command: make_csv + mounts: + input_volumes: + data_path: + mount_path: /workspace/input_data + type: directory + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + next: nifti_conversion + on_error: do_something + partition: true + + - id: nifti_conversion + <<: *BASE_STEP + mounts: + input_volumes: + csv_input: + mount_path: /workspace/data + type: directory + from: + step: make_csv + mount: output_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + metadata_path: + mount_path: /workspace/metadata + type: directory + command: convert_nifti + next: brain_extraction + on_error: do_something + partition: true + + - id: brain_extraction + <<: *BASE_STEP + mounts: + input_volumes: + nifti_input: + mount_path: /workspace/data + type: directory + from: + step: nifti_conversion + mount: output_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + command: extract_brain + next: tumor_extraction + on_error: do_something + partition: true + limit: 2 + + - id: 
tumor_extraction + <<: *BASE_STEP + mounts: + input_volumes: + brain_extraction_input: + mount_path: /workspace/data + type: directory + from: + step: brain_extraction + mount: output_path + additional_files: + mount_path: /workspace/additional_files + type: directory + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + command: extract_tumor + next: manual_review + on_error: do_something + partition: true + limit: 2 + + - id: manual_review + <<: *BASE_STEP + mounts: + input_volumes: + tumor_extraction_input: + mount_path: /workspace/data + type: directory + from: + step: tumor_extraction + mount: output_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + output_labels_path: + mount_path: /workspace/labels + type: directory + command: prepare_for_manual_review + next: + if: + - condition: annotation_done + target: segmentation_comparison + - condition: brain_mask_changed + target: rollback_to_brain_extract + else: manual_review + wait: 60 + on_error: do_something + partition: true + + - id: rollback_to_brain_extract + <<: *BASE_STEP + mounts: + input_volumes: + review_input_data: + mount_path: /workspace/data + type: directory + from: + step: manual_review + mount: output_path + review_input_labels: + mount_path: /workspace/labels + type: directory + from: + step: manual_review + mount: output_labels_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + output_labels_path: + mount_path: /workspace/labels + type: directory + command: rollback_to_brain_extract + next: brain_extraction + partition: true + + - id: segmentation_comparison + <<: *BASE_STEP + mounts: + input_volumes: + review_input_data: + mount_path: /workspace/data + type: directory + from: + step: manual_review + mount: output_path + review_input_labels: + mount_path: /workspace/labels + type: directory + from: + step: manual_review + mount: output_labels_path + output_volumes: + output_path: + 
mount_path: /workspace/data + type: directory + output_labels_path: + mount_path: /workspace/labels + type: directory + command: segmentation_comparison + next: calculate_changed_voxels + partition: true + + - id: calculate_changed_voxels + <<: *BASE_STEP + mounts: + input_volumes: + segmentation_comparison_input_data: + mount_path: /workspace/data + type: directory + from: + step: segmentation_comparison + mount: output_path + segmentation_comparison_input_labels: + mount_path: /workspace/labels + type: directory + from: + step: segmentation_comparison + mount: output_labels_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + output_labels_path: + mount_path: /workspace/labels + type: directory + command: calculate_changed_voxels + next: final_confirmation + + - id: final_confirmation + type: manual_approval + next: move_labeled_files + + + - id: move_labeled_files + <<: *BASE_STEP + mounts: + input_volumes: + changed_voxels_input_data: + mount_path: /workspace/data + type: directory + from: + step: calculate_changed_voxels + mount: output_path + changed_voxels_input_input_labels: + mount_path: /workspace/labels + type: directory + from: + step: calculate_changed_voxels + mount: output_labels_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + output_labels_path: + mount_path: /workspace/labels + type: directory + command: move_labeled_files + next: consolidation_stage + + + - id: consolidation_stage + <<: *BASE_STEP + mounts: + input_volumes: + parameters_file: + mount_path: /workspace/parameters.yaml + type: file + move_labeled_files_input_data: + mount_path: /workspace/data + type: directory + from: + step: move_labeled_files + mount: output_path + changed_voxels_input_input_labels: + mount_path: /workspace/labels + type: directory + from: + step: move_labeled_files + mount: output_labels_path + nifti_metadata: + mount_path: /workspace/metadata + type: directory + from: + step: 
nifti_conversion + mount: metadata_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + output_labels_path: + mount_path: /workspace/labels + type: directory + metadata_path: + mount_path: /workspace/metadata + type: directory + command: consolidation_stage + next: sanity_check + + - id: sanity_check + <<: *BASE_STEP + mounts: + input_volumes: + consolidation_input_data: + mount_path: /workspace/data + type: directory + from: + step: consolidation_stage + mount: output_path + consolidation_input_labels: + mount_path: /workspace/labels + type: directory + from: + step: consolidation_stage + mount: output_labels_path + command: sanity_check + next: metrics + + - id: metrics + <<: *BASE_STEP + mounts: + input_volumes: + consolidation_input: + mount_path: /workspace/data + type: directory + from: + step: consolidation_stage + mount: output_path + output_volumes: + statistics_file: + mount_path: /workspace/data/statistics.yaml + type: file + metadata_path: + mount_path: /workspace/metadata + type: directory + output_labels_path: + mount_path: /workspace/labels + type: directory + command: metrics + next: null + +conditions: + - id: annotation_done + type: function + function_name: conditions.annotation_done + + - id: brain_mask_changed + type: function + function_name: conditions.brain_mask_changed + +partition_def: + type: function + function_name: subject_definition.subject_definition \ No newline at end of file diff --git a/examples/RANO/data_preparator_workflow/workflow_dev.yaml b/examples/RANO/data_preparator_workflow/workflow_dev.yaml new file mode 100644 index 000000000..0206a7da1 --- /dev/null +++ b/examples/RANO/data_preparator_workflow/workflow_dev.yaml @@ -0,0 +1,336 @@ +base_step: &BASE_STEP + - type: container + image: mlcommons/rano-data-prep-workflow-dev:0.0.2 + + +steps: + - id: setup + <<: *BASE_STEP + command: initial_setup + mounts: + input_volumes: + data_path: + mount_path: /workspace/input_data + type: directory + 
output_volumes: + output_path: + mount_path: /workspace/data + type: directory + next: make_csv + on_error: do_something + partition: false + + - id: make_csv + <<: *BASE_STEP + command: make_csv + mounts: + input_volumes: + data_path: + mount_path: /workspace/input_data + type: directory + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + next: nifti_conversion + on_error: do_something + partition: true + + - id: nifti_conversion + <<: *BASE_STEP + mounts: + input_volumes: + csv_input: + mount_path: /workspace/data + type: directory + from: + step: make_csv + mount: output_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + metadata_path: + mount_path: /workspace/metadata + type: directory + command: convert_nifti + next: brain_extraction + on_error: do_something + partition: true + + - id: brain_extraction + <<: *BASE_STEP + mounts: + input_volumes: + nifti_input: + mount_path: /workspace/data + type: directory + from: + step: nifti_conversion + mount: output_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + command: extract_brain + next: tumor_extraction + on_error: do_something + partition: true + limit: 2 + + - id: tumor_extraction + <<: *BASE_STEP + mounts: + input_volumes: + brain_extraction_input: + mount_path: /workspace/data + type: directory + from: + step: brain_extraction + mount: output_path + additional_files: + mount_path: /workspace/additional_files + type: directory + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + command: extract_tumor + next: manual_review + on_error: do_something + partition: true + limit: 2 + + - id: manual_review + <<: *BASE_STEP + mounts: + input_volumes: + tumor_extraction_input: + mount_path: /workspace/data + type: directory + from: + step: tumor_extraction + mount: output_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + output_labels_path: + 
mount_path: /workspace/labels + type: directory + command: prepare_for_manual_review + next: + if: + - condition: annotation_done + target: segmentation_comparison + - condition: brain_mask_changed + target: rollback_to_brain_extract + else: manual_review + wait: 60 + on_error: do_something + partition: true + + - id: rollback_to_brain_extract + <<: *BASE_STEP + mounts: + input_volumes: + review_input_data: + mount_path: /workspace/data + type: directory + from: + step: manual_review + mount: output_path + review_input_labels: + mount_path: /workspace/labels + type: directory + from: + step: manual_review + mount: output_labels_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + output_labels_path: + mount_path: /workspace/labels + type: directory + command: rollback_to_brain_extract + next: brain_extraction + partition: true + + - id: segmentation_comparison + <<: *BASE_STEP + mounts: + input_volumes: + review_input_data: + mount_path: /workspace/data + type: directory + from: + step: manual_review + mount: output_path + review_input_labels: + mount_path: /workspace/labels + type: directory + from: + step: manual_review + mount: output_labels_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + output_labels_path: + mount_path: /workspace/labels + type: directory + command: segmentation_comparison + next: calculate_changed_voxels + partition: true + + - id: calculate_changed_voxels + <<: *BASE_STEP + mounts: + input_volumes: + segmentation_comparison_input_data: + mount_path: /workspace/data + type: directory + from: + step: segmentation_comparison + mount: output_path + segmentation_comparison_input_labels: + mount_path: /workspace/labels + type: directory + from: + step: segmentation_comparison + mount: output_labels_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + output_labels_path: + mount_path: /workspace/labels + type: directory + command: 
calculate_changed_voxels + next: final_confirmation + + - id: final_confirmation + type: manual_approval + next: move_labeled_files + + + - id: move_labeled_files + <<: *BASE_STEP + mounts: + input_volumes: + changed_voxels_input_data: + mount_path: /workspace/data + type: directory + from: + step: calculate_changed_voxels + mount: output_path + changed_voxels_input_input_labels: + mount_path: /workspace/labels + type: directory + from: + step: calculate_changed_voxels + mount: output_labels_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + output_labels_path: + mount_path: /workspace/labels + type: directory + command: move_labeled_files + next: consolidation_stage + + + - id: consolidation_stage + <<: *BASE_STEP + mounts: + input_volumes: + parameters_file: + mount_path: /workspace/parameters.yaml + type: file + move_labeled_files_input_data: + mount_path: /workspace/data + type: directory + from: + step: move_labeled_files + mount: output_path + changed_voxels_input_input_labels: + mount_path: /workspace/labels + type: directory + from: + step: move_labeled_files + mount: output_labels_path + nifti_metadata: + mount_path: /workspace/metadata + type: directory + from: + step: nifti_conversion + mount: metadata_path + output_volumes: + output_path: + mount_path: /workspace/data + type: directory + output_labels_path: + mount_path: /workspace/labels + type: directory + metadata_path: + mount_path: /workspace/metadata + type: directory + command: consolidation_stage + next: sanity_check + + - id: sanity_check + <<: *BASE_STEP + mounts: + input_volumes: + consolidation_input_data: + mount_path: /workspace/data + type: directory + from: + step: consolidation_stage + mount: output_path + consolidation_input_labels: + mount_path: /workspace/labels + type: directory + from: + step: consolidation_stage + mount: output_labels_path + command: sanity_check + next: metrics + + - id: metrics + <<: *BASE_STEP + mounts: + input_volumes: + 
consolidation_input: + mount_path: /workspace/data + type: directory + from: + step: consolidation_stage + mount: output_path + output_volumes: + statistics_file: + mount_path: /workspace/data/statistics.yaml + type: file + metadata_path: + mount_path: /workspace/metadata + type: directory + output_labels_path: + mount_path: /workspace/labels + type: directory + command: metrics + next: null + +conditions: + - id: annotation_done + type: function + function_name: conditions.annotation_done + + - id: brain_mask_changed + type: function + function_name: conditions.brain_mask_changed + +partition_def: + type: function + function_name: subject_definition.subject_definition \ No newline at end of file diff --git a/examples/RANO/data_preparator_workflow/workspace/additional_files/conditions.py b/examples/RANO/data_preparator_workflow/workspace/additional_files/conditions.py new file mode 100644 index 000000000..9d40f7660 --- /dev/null +++ b/examples/RANO/data_preparator_workflow/workspace/additional_files/conditions.py @@ -0,0 +1,62 @@ +import os + + +def annotation_done(pipeline_state): + + base_review_dir = os.path.join( + pipeline_state.host_output_data_path, + "manual_review", + "tumor_extraction", + pipeline_state.running_subject, + ) + finalized_dir = os.path.join(base_review_dir, "finalized") + dir_files = os.listdir(finalized_dir) + + if len(dir_files) == 0: + print("Reviewed annotation not Found!") + return False + + elif len(dir_files) > 1: + print( + "More than one annotation found! Please only keep one file in the finalized directory" + ) + return False + + formatted_subject = pipeline_state.running_subject.replace("/", "_") + proper_name = f"{formatted_subject}_tumorMask_model_0.nii.gz" + if dir_files[0] != proper_name: + print( + f"Reviewed file has been renamed! Please make sure the file is named\n{proper_name}\nto ensure the pipeline runs correctly!" 
+ ) + return False + return True + + +def brain_mask_changed(pipeline_state): + + base_review_dir = os.path.join( + pipeline_state.host_output_data_path, + "manual_review", + "brain_mask", + pipeline_state.running_subject, + ) + finalized_dir = os.path.join(base_review_dir, "finalized") + dir_files = os.listdir(finalized_dir) + + if len(dir_files) == 0: + print("No brain mask change detected.") + return False + + elif len(dir_files) > 1: + print( + "More than one brain mask correction found! Please only keep one file in the finalized directory." + ) + return False + + proper_name = f"brainMask_fused.nii.gz" + if dir_files[0] != proper_name: + print( + f"Brain Mask file has been renamed! Please make sure the file is named\n{proper_name}\nto ensure the pipeline runs correctly!" + ) + return False + return True diff --git a/examples/RANO/data_preparator_workflow/workspace/additional_files/subject_definition.py b/examples/RANO/data_preparator_workflow/workspace/additional_files/subject_definition.py new file mode 100644 index 000000000..b5bb38a4b --- /dev/null +++ b/examples/RANO/data_preparator_workflow/workspace/additional_files/subject_definition.py @@ -0,0 +1,20 @@ +import os + + +def subject_definition(pipeline_state): + + input_data_dir = pipeline_state.host_input_data_path + subject_slash_timepoint_list = [] + + for subject_id_dir in os.listdir(input_data_dir): + subject_complete_dir = os.path.join(input_data_dir, subject_id_dir) + + if not os.path.isdir(subject_complete_dir): + continue + + for timepoint_dir in os.listdir(subject_complete_dir): + subject_slash_timepoint_list.append( + os.path.join(subject_id_dir, timepoint_dir) + ) + + return subject_slash_timepoint_list diff --git a/examples/RANO/data_preparator_workflow/workspace/parameters.yaml b/examples/RANO/data_preparator_workflow/workspace/parameters.yaml new file mode 100644 index 000000000..7755bbbd7 --- /dev/null +++ b/examples/RANO/data_preparator_workflow/workspace/parameters.yaml @@ -0,0 +1,21 @@ 
+seed: 2784 +train_percent: 0.8 +medperf_report_stages: +- "IDENTIFIED" +- "VALIDATED" +- "MISSING_MODALITIES" +- "EXTRA_MODALITIES" +- "VALIDATION_FAILED" +- "CONVERTED_TO_NIfTI" +- "NIfTI_CONVERSION_FAILED" +- "BRAIN_EXTRACT_FINISHED" +- "BRAIN_EXTRACT_FINISHED" +- "TUMOR_EXTRACT_FAILED" +- "MANUAL_REVIEW_COMPLETE" +- "MANUAL_REVIEW_REQUIRED" +- "MULTIPLE_ANNOTATIONS_ERROR" +- "COMPARISON_COMPLETE" +- "EXACT_MATCH_IDENTIFIED" +- "ANNOTATION_COMPARISON_FAILED" +- "ANNOTATION_CONFIRMED" +- "DONE" \ No newline at end of file diff --git a/examples/chestxray_tutorial/data_preparator_workflow/workflow.yaml b/examples/chestxray_tutorial/data_preparator_workflow/workflow.yaml new file mode 100644 index 000000000..cfa99e653 --- /dev/null +++ b/examples/chestxray_tutorial/data_preparator_workflow/workflow.yaml @@ -0,0 +1,77 @@ +steps: + - id: prepare + type: container + image: mlcommons/chestxray-tutorial-prep:0.0.1 + command: python3 /project/prepare.py + mounts: + input_volumes: + data_path: + mount_path: /mlcommons/volumes/raw_data + type: directory + labels_path: + mount_path: /mlcommons/volumes/raw_labels + type: directory + parameters_file: + mount_path: /mlcommons/volumes/parameters/parameters_file.yaml + type: file + output_volumes: + output_path: + mount_path: /mlcommons/volumes/data + type: directory + output_labels_path: + mount_path: /mlcommons/volumes/labels + type: directory + next: sanity_check + on_error: do_something + + - id: sanity_check + type: container + image: mlcommons/chestxray-tutorial-prep:0.0.1 + command: python3 /project/sanity_check.py + mounts: + input_volumes: + data_path: + mount_path: /mlcommons/volumes/data + type: directory + from: + step: prepare + mount: output_path + labels_path: + mount_path: /mlcommons/volumes/labels + type: directory + from: + step: prepare + mount: output_labels_path + parameters_file: + mount_path: /mlcommons/volumes/parameters/parameters_file.yaml + type: file + next: statistics + on_error: do_something + + - 
id: statistics + type: container + image: mlcommons/chestxray-tutorial-prep:0.0.1 + command: python3 /project/statistics.py + mounts: + input_volumes: + data_path: + mount_path: /mlcommons/volumes/data + type: directory + from: + step: prepare + mount: output_path + labels_path: + mount_path: /mlcommons/volumes/labels + type: directory + from: + step: prepare + mount: output_path + parameters_file: + mount_path: /mlcommons/volumes/parameters/parameters_file.yaml + type: file + output_volumes: + statistics_file: + mount_path: /mlcommons/volumes/statistics/statistics.yaml + type: file + next: null + on_error: do_something \ No newline at end of file diff --git a/examples/chestxray_tutorial/data_preparator_workflow/workspace/parameters.yaml b/examples/chestxray_tutorial/data_preparator_workflow/workspace/parameters.yaml new file mode 100644 index 000000000..cf1348d59 --- /dev/null +++ b/examples/chestxray_tutorial/data_preparator_workflow/workspace/parameters.yaml @@ -0,0 +1,20 @@ +labels_list: + [ + "atelectasis", + "cardiomegaly", + "effusion", + "infiltration", + "mass", + "nodule", + "pneumonia", + "pneumothorax", + "consolidation", + "edema", + "emphysema", + "fibrosis", + "pleural", + "hernia", + ] +image_column_id: Image ID +label_column_id: Label +image_output_size: [28, 28, 1] diff --git a/server/medperf/testing_utils.py b/server/medperf/testing_utils.py index bde917800..1e0251627 100644 --- a/server/medperf/testing_utils.py +++ b/server/medperf/testing_utils.py @@ -59,7 +59,7 @@ def mock_mlcube(**kwargs): "name": "testmlcube", "container_config": {"key": "value"}, "parameters_config": {}, - "image_hash": "string", + "image_hash": {"default": "string"}, "additional_files_tarball_url": "string", "additional_files_tarball_hash": "string", "state": "DEVELOPMENT", diff --git a/server/mlcube/migrations/0006_convert_image_hash_to_json.py b/server/mlcube/migrations/0006_convert_image_hash_to_json.py new file mode 100644 index 000000000..e03fa2e46 --- /dev/null +++ 
b/server/mlcube/migrations/0006_convert_image_hash_to_json.py @@ -0,0 +1,75 @@ +# Generated by Django 4.2.23 on 2025-11-06 12:53 + +from django.db import migrations, models + + +def convert_image_hash_from_string_to_json(apps, schema_editor): + MlCube = apps.get_model("mlcube", "MlCube") + for mlcube in MlCube.objects.all(): + mlcube.image_hash_tmp = {"default": mlcube.image_hash} + mlcube.save() + + +def convert_image_hash_from_json_to_str(apps, schema_editor): + MlCube = apps.get_model("mlcube", "MlCube") + for mlcube in MlCube.objects.all(): + mlcube.image_hash = list(mlcube.image_hash_tmp.values())[0] + mlcube.save() + + +class Migration(migrations.Migration): + + dependencies = [ + ( + "mlcube", + "0005_alter_mlcube_unique_together_alter_mlcube_image_hash_and_more", + ), + ] + + operations = [ + migrations.AddField( + model_name="mlcube", + name="image_hash_tmp", + field=models.JSONField(default=dict), + preserve_default=False, + ), + migrations.RunPython( + code=convert_image_hash_from_string_to_json, + reverse_code=convert_image_hash_from_json_to_str, + ), + migrations.AlterUniqueTogether( + name="mlcube", + unique_together=set(), + ), + migrations.AlterUniqueTogether( + name="mlcube", + unique_together={ + ( + "image_hash_tmp", + "additional_files_tarball_hash", + "container_config", + "parameters_config", + ) + }, + ), + migrations.RemoveField( + model_name="mlcube", + name="image_hash", + ), + migrations.RenameField( + model_name="mlcube", + old_name="image_hash_tmp", + new_name="image_hash", + ), + migrations.AlterUniqueTogether( + name="mlcube", + unique_together={ + ( + "image_hash", + "additional_files_tarball_hash", + "container_config", + "parameters_config", + ) + }, + ), + ] diff --git a/server/mlcube/models.py b/server/mlcube/models.py index d7e4c802b..b9583dcfe 100644 --- a/server/mlcube/models.py +++ b/server/mlcube/models.py @@ -13,7 +13,7 @@ class MlCube(models.Model): name = models.CharField(max_length=128, unique=True) container_config = 
models.JSONField() parameters_config = models.JSONField(blank=True, null=True) - image_hash = models.CharField(max_length=100) + image_hash = models.JSONField() additional_files_tarball_url = models.CharField(max_length=256, blank=True) additional_files_tarball_hash = models.CharField(max_length=100, blank=True) owner = models.ForeignKey(User, on_delete=models.PROTECT) diff --git a/server/mlcube/serializers.py b/server/mlcube/serializers.py index a34ec64f5..199741215 100644 --- a/server/mlcube/serializers.py +++ b/server/mlcube/serializers.py @@ -16,6 +16,14 @@ def validate_optional_mlcube_components(data): ) +def validate_image_hash(data): + hashes_dict = data.get("image_hash") + if not hashes_dict: + raise serializers.ValidationError( + "Cannot submit Container with empty image_hash!" + ) + + class MlCubeSerializer(serializers.ModelSerializer): class Meta: model = MlCube @@ -24,6 +32,7 @@ class Meta: def validate(self, data): validate_optional_mlcube_components(data) + validate_image_hash(data) return data @@ -58,5 +67,5 @@ def validate(self, data): updated_dict[key] = data.get(key, getattr(self.instance, key)) validate_optional_mlcube_components(updated_dict) - + validate_image_hash(updated_dict) return data diff --git a/server/mlcube/tests/test_.py b/server/mlcube/tests/test_.py index 25dc02613..f24d54483 100644 --- a/server/mlcube/tests/test_.py +++ b/server/mlcube/tests/test_.py @@ -167,8 +167,8 @@ def test_additional_files_should_have_a_hash(self, url_provided): @parameterized.expand( [ - ("hash", status.HTTP_201_CREATED), - ("", status.HTTP_400_BAD_REQUEST), + ({"image": "hash"}, status.HTTP_201_CREATED), + ({}, status.HTTP_400_BAD_REQUEST), ] ) def test_required_image_hash(self, image_hash, exp_status): diff --git a/server/mlcube/tests/test_pk.py b/server/mlcube/tests/test_pk.py index 7c2c905f0..9e2e1b93f 100644 --- a/server/mlcube/tests/test_pk.py +++ b/server/mlcube/tests/test_pk.py @@ -87,7 +87,7 @@ def 
test_put_modifies_editable_fields_in_development(self): "name": "newtestmlcube", "container_config": {"newstring": "newstring"}, "parameters_config": {"newstring": "newstring"}, - "image_hash": "newstring", + "image_hash": {"hash": "newstring"}, "additional_files_tarball_url": "newstring", "additional_files_tarball_hash": "newstring", "state": "OPERATION", @@ -134,7 +134,7 @@ def test_put_does_not_modify_non_editable_fields_in_operation(self): newtestmlcube = { "name": "newtestmlcube", - "image_hash": "newhash", + "image_hash": {"default": "newhash"}, "additional_files_tarball_hash": "newstring", "state": "DEVELOPMENT", "metadata": {"newkey": "newvalue"}, @@ -163,7 +163,6 @@ def test_put_does_not_modify_readonly_fields_in_both_states(self, state): # Act response = self.client.put(url, newtestmlcube, format="json") - # Assert self.assertEqual(response.status_code, status.HTTP_200_OK) for k, v in newtestmlcube.items(): @@ -255,7 +254,7 @@ def test_put_clearing_image_hash(self): testmlcube = self.mock_mlcube(state="DEVELOPMENT") testmlcube = self.create_mlcube(testmlcube).data - put_body = {"image_hash": ""} + put_body = {"image_hash": {}} url = self.url.format(testmlcube["id"]) # Act @@ -341,7 +340,7 @@ def test_put_permissions(self, user, expected_status): "name": "newtestmlcube", "container_config": {"newstring": "newstring"}, "parameters_config": {"newstring": "newstring"}, - "image_hash": "", + "image_hash": {"default": ""}, "additional_files_tarball_url": "newstring", "additional_files_tarball_hash": "newstring", "state": "OPERATION", diff --git a/server/seed.py b/server/seed.py index b8a0ec7cb..5a5baaca0 100644 --- a/server/seed.py +++ b/server/seed.py @@ -8,12 +8,25 @@ executes Django code to set admin permissions for a test user.""" import argparse -from seed_utils import Server, set_user_as_admin, create_benchmark, create_model +from seed_utils import ( + Server, + set_user_as_admin, + create_benchmark, + create_model, + create_workflow_benchmark, + 
create_rano_workflow_mlcube, +) from auth_provider_token import auth_provider_token import json from pathlib import Path REPO_BASE_DIR = Path(__file__).resolve().parent.parent +default_cert_file = str(REPO_BASE_DIR / "server" / "cert.crt") +default_tokens_file = str(REPO_BASE_DIR / "mock_tokens" / "tokens.json") +default_xray_containers_assets_path = str( + REPO_BASE_DIR / "examples" / "chestxray_tutorial" +) +rano_assets_path = str(REPO_BASE_DIR / "examples" / "RANO") def populate_mock_benchmarks(api_server, admin_token): @@ -68,11 +81,30 @@ def seed(args): return # create benchmark benchmark_owner_token = get_token("testbo@example.com") - benchmark = create_benchmark( - api_server, - benchmark_owner_token, - args.containers_assets_path, + + if args.demo == "rano": + create_rano_workflow_mlcube( + api_server=api_server, + benchmark_owner_token=benchmark_owner_token, + assets_path=rano_assets_path, + ) + return + + xray_assets_path = ( + args.containers_assets_path or default_xray_containers_assets_path ) + if args.workflow: + benchmark = create_workflow_benchmark( + api_server, + benchmark_owner_token, + xray_assets_path, + ) + else: + benchmark = create_benchmark( + api_server, + benchmark_owner_token, + xray_assets_path, + ) if args.demo == "model": return # create model @@ -82,17 +114,11 @@ def seed(args): model_owner_token, benchmark_owner_token, benchmark, - args.containers_assets_path, + xray_assets_path, ) if __name__ == "__main__": - default_cert_file = str(REPO_BASE_DIR / "server" / "cert.crt") - default_tokens_file = str(REPO_BASE_DIR / "mock_tokens" / "tokens.json") - default_containers_assets_path = str( - REPO_BASE_DIR / "examples" / "chestxray_tutorial" - ) - parser = argparse.ArgumentParser(description="Seed the db with demo entries") parser.add_argument( "--server", @@ -114,9 +140,9 @@ def seed(args): parser.add_argument( "--demo", type=str, - help="Seed for a tutorial: 'benchmark', 'model', or 'data'.", + help="Seed for a tutorial: 'benchmark', 
'model', 'data', 'tutorial' or 'rano.", default="data", - choices=["benchmark", "model", "data", "tutorial"], + choices=["benchmark", "model", "data", "tutorial", "rano"], ) parser.add_argument( "--tokens", @@ -125,10 +151,17 @@ def seed(args): default=default_tokens_file, ) parser.add_argument( + "-w", + "--workflow", + action="store_true", + help="Use an Airflow workflow instead of a container for Data Preparation", + ) + parser.add_argument( + "-c", "--containers-assets-path", type=str, help="Path to folder containing container asset files for seeding dev database", - default=default_containers_assets_path, + default=None, ) args = parser.parse_args() if args.cert.lower() == "none": diff --git a/server/seed_utils.py b/server/seed_utils.py index 9e4dd0687..8a828151d 100644 --- a/server/seed_utils.py +++ b/server/seed_utils.py @@ -18,6 +18,14 @@ def _load_asset_content(assets_path: str, file_relative_path: str): return content +def load_workflow_config(assets_path: str, dirname: str, dev: bool = False): + if dev: + workflow_file = "workflow_dev.yaml" + else: + workflow_file = "workflow.yaml" + return _load_asset_content(assets_path, f"{dirname}/{workflow_file}") + + def load_container_config(assets_path: str, dirname: str): return _load_asset_content(assets_path, f"{dirname}/container_config.yaml") @@ -119,7 +127,171 @@ def create_benchmark(api_server, benchmark_owner_token, assets_path): "name": "chestxray_prep", "container_config": data_prep_config, "parameters_config": data_prep_params, - "image_hash": "sha256:f8697dc1c646395ad1ac54b8c0373195dbcfde0c4ef5913d4330a5fe481ae9a4", + "image_hash": { + "default": "sha256:f8697dc1c646395ad1ac54b8c0373195dbcfde0c4ef5913d4330a5fe481ae9a4" + }, + "additional_files_tarball_url": "", + "additional_files_tarball_hash": "", + "metadata": {}, + }, + "id", + ) + print( + "Data Preprocessor MLCube Created(by Benchmark Owner). 
ID:", + data_preprocessor_mlcube, + ) + + # Update state of the Data preprocessor MLCube to OPERATION + data_preprocessor_mlcube_state = api_server.request( + "/mlcubes/" + str(data_preprocessor_mlcube) + "/", + "PUT", + benchmark_owner_token, + {"state": "OPERATION"}, + "state", + ) + print( + "Data Preprocessor MlCube state updated to", + data_preprocessor_mlcube_state, + "by Benchmark Owner", + ) + + model_cnn_container_config = load_container_config(assets_path, "model_custom_cnn") + model_cnn_parameters_config = load_parameters_config( + assets_path, "model_custom_cnn" + ) + # Create a reference model executor mlcube by Benchmark Owner + reference_model_executor_mlcube = api_server.request( + "/mlcubes/", + "POST", + benchmark_owner_token, + { + "name": "chestxray_cnn", + "container_config": model_cnn_container_config, + "parameters_config": model_cnn_parameters_config, + "additional_files_tarball_url": ( + "https://storage.googleapis.com/medperf-storage/" + "chestxray_tutorial/cnn_weights.tar.gz" + ), + "additional_files_tarball_hash": "bff003e244759c3d7c8b9784af0819c7f252da8626745671ccf7f46b8f19a0ca", + "image_hash": { + "default": "sha256:a1bdddce05b9d156df359dd570de8031fdd1ea5a858f755139bed4a95fad19d1" + }, + "metadata": {}, + }, + "id", + ) + print( + "Reference Model Executor MlCube Created(by Benchmark Owner). 
ID:", + reference_model_executor_mlcube, + ) + + # Update state of the Reference Model Executor MLCube to OPERATION + reference_model_executor_mlcube_state = api_server.request( + "/mlcubes/" + str(reference_model_executor_mlcube) + "/", + "PUT", + benchmark_owner_token, + {"state": "OPERATION"}, + "state", + ) + print( + "Reference Model Executor MlCube state updated to", + reference_model_executor_mlcube_state, + "by Benchmark Owner", + ) + + evaluator_container_config = load_container_config(assets_path, "metrics") + evaluator_parameters_config = load_parameters_config(assets_path, "metrics") + + # Create a Data evalutor MLCube by Benchmark Owner + data_evaluator_mlcube = api_server.request( + "/mlcubes/", + "POST", + benchmark_owner_token, + { + "name": "chestxray_metrics", + "container_config": evaluator_container_config, + "parameters_config": evaluator_parameters_config, + "image_hash": { + "default": "sha256:d33904c1104d0a3df314f29c603901a8584fec01e58b90d7ae54c8d74d32986c" + }, + "additional_files_tarball_url": "", + "additional_files_tarball_hash": "", + "metadata": {}, + }, + "id", + ) + print( + "Data Evaluator MlCube Created(by Benchmark Owner). 
ID:", + data_evaluator_mlcube, + ) + + # Update state of the Data Evaluator MLCube to OPERATION + data_evaluator_mlcube_state = api_server.request( + "/mlcubes/" + str(data_evaluator_mlcube) + "/", + "PUT", + benchmark_owner_token, + {"state": "OPERATION"}, + "state", + ) + print( + "Data Evaluator MlCube state updated to", + data_evaluator_mlcube_state, + "by Benchmark Owner", + ) + + # Create a new benchmark by Benchmark owner + benchmark = api_server.request( + "/benchmarks/", + "POST", + benchmark_owner_token, + { + "name": "chestxray", + "description": "benchmark-sample", + "docs_url": "", + "demo_dataset_tarball_url": "https://storage.googleapis.com/medperf-storage/chestxray_tutorial/demo_data.tar.gz", + "demo_dataset_tarball_hash": "71faabd59139bee698010a0ae3a69e16d97bc4f2dde799d9e187b94ff9157c00", + "demo_dataset_generated_uid": "730d2474d8f22340d9da89fa2eb925fcb95683e0", + "data_preparation_mlcube": data_preprocessor_mlcube, + "reference_model_mlcube": reference_model_executor_mlcube, + "data_evaluator_mlcube": data_evaluator_mlcube, + }, + "id", + ) + print("Benchmark Created(by Benchmark Owner). 
ID:", benchmark) + + # Update the benchmark state to OPERATION + benchmark_state = api_server.request( + "/benchmarks/" + str(benchmark) + "/", + "PUT", + benchmark_owner_token, + {"state": "OPERATION"}, + "state", + ) + print("Benchmark state updated to", benchmark_state, "by Benchmark owner") + + return benchmark + + +def create_workflow_benchmark(api_server, benchmark_owner_token, assets_path): + print( + "##########################BENCHMARK OWNER (WORKFLOW)##########################" + ) + + data_prep_config = load_workflow_config(assets_path, "data_preparator_workflow") + data_prep_params = load_parameters_config(assets_path, "data_preparator_workflow") + + # Create a Data preprocessor MLCube by Benchmark Owner + data_preprocessor_mlcube = api_server.request( + "/mlcubes/", + "POST", + benchmark_owner_token, + { + "name": "chestxray_prep", + "container_config": data_prep_config, + "parameters_config": data_prep_params, + "image_hash": { + "default": "sha256:f8697dc1c646395ad1ac54b8c0373195dbcfde0c4ef5913d4330a5fe481ae9a4" + }, "additional_files_tarball_url": "", "additional_files_tarball_hash": "", "metadata": {}, @@ -149,6 +321,9 @@ def create_benchmark(api_server, benchmark_owner_token, assets_path): model_cnn_parameters_config = load_parameters_config( assets_path, "model_custom_cnn" ) + model_cnn_hash = ( + "sha256:a1bdddce05b9d156df359dd570de8031fdd1ea5a858f755139bed4a95fad19d1" + ) # Create a reference model executor mlcube by Benchmark Owner reference_model_executor_mlcube = api_server.request( "/mlcubes/", @@ -163,7 +338,7 @@ def create_benchmark(api_server, benchmark_owner_token, assets_path): "chestxray_tutorial/cnn_weights.tar.gz" ), "additional_files_tarball_hash": "bff003e244759c3d7c8b9784af0819c7f252da8626745671ccf7f46b8f19a0ca", - "image_hash": "sha256:a1bdddce05b9d156df359dd570de8031fdd1ea5a858f755139bed4a95fad19d1", + "image_hash": {"default": model_cnn_hash}, "metadata": {}, }, "id", @@ -189,6 +364,9 @@ def create_benchmark(api_server, 
benchmark_owner_token, assets_path): evaluator_container_config = load_container_config(assets_path, "metrics") evaluator_parameters_config = load_parameters_config(assets_path, "metrics") + evaluator_hash = ( + "sha256:d33904c1104d0a3df314f29c603901a8584fec01e58b90d7ae54c8d74d32986c" + ) # Create a Data evalutor MLCube by Benchmark Owner data_evaluator_mlcube = api_server.request( "/mlcubes/", @@ -198,7 +376,7 @@ def create_benchmark(api_server, benchmark_owner_token, assets_path): "name": "chestxray_metrics", "container_config": evaluator_container_config, "parameters_config": evaluator_parameters_config, - "image_hash": "sha256:d33904c1104d0a3df314f29c603901a8584fec01e58b90d7ae54c8d74d32986c", + "image_hash": {"default": evaluator_hash}, "additional_files_tarball_url": "", "additional_files_tarball_hash": "", "metadata": {}, @@ -267,7 +445,6 @@ def create_model( mobilenet_parameters_config = load_parameters_config( assets_path, "model_mobilenetv2" ) - # Create a model mlcube by Model Owner model_executor1_mlcube = api_server.request( "/mlcubes/", @@ -282,7 +459,9 @@ def create_model( "chestxray_tutorial/mobilenetv2_weights.tar.gz" ), "additional_files_tarball_hash": "771f67bba92a11c83d16a522f0ba1018020ff758e2277d33f49056680c788892", - "image_hash": "sha256:f27deb052eafd48ad1e350ceef7b0b9600aef0ea3f8cba47baee2b1d17411a83", + "image_hash": { + "default": "sha256:f27deb052eafd48ad1e350ceef7b0b9600aef0ea3f8cba47baee2b1d17411a83" + }, "metadata": {}, }, "id", @@ -347,3 +526,51 @@ def create_model( model_executor1_in_benchmark_status, "(by Benchmark Owner)", ) + + +def create_rano_workflow_mlcube(api_server, benchmark_owner_token, assets_path): + print( + "##########################BENCHMARK OWNER (RANO WORKFLOW)##########################" + ) + + data_prep_config = load_workflow_config( + assets_path, "data_preparator_workflow", dev=True + ) + data_prep_params = load_parameters_config(assets_path, "data_preparator_workflow") + additional_files_url = 
"https://storage.googleapis.com/medperf-storage/rano_test_assets/dev_models_and_more.tar.gz" + # Create a Data preprocessor MLCube by Benchmark Owner + data_preprocessor_mlcube = api_server.request( + "/mlcubes/", + "POST", + benchmark_owner_token, + { + "name": "rano_workflow_prep", + "container_config": data_prep_config, + "parameters_config": data_prep_params, + "image_hash": { + "default": "sha256:bc9c50d360d5ac2369eb2eaae8146c33c7eef3d0b0506bbdf26692262c786f50" + }, + "additional_files_tarball_url": additional_files_url, + "additional_files_tarball_hash": "808632d9b9fa1da00faa923a752ab47eb0bc19daff037e9c2447b994dd415084", + "metadata": {}, + }, + "id", + ) + print( + "Data Preprocessor MLCube Created(by Benchmark Owner). ID:", + data_preprocessor_mlcube, + ) + + # Update state of the Data preprocessor MLCube to OPERATION + data_preprocessor_mlcube_state = api_server.request( + "/mlcubes/" + str(data_preprocessor_mlcube) + "/", + "PUT", + benchmark_owner_token, + {"state": "OPERATION"}, + "state", + ) + print( + "Data Preprocessor MlCube state updated to", + data_preprocessor_mlcube_state, + "by Benchmark Owner", + )