diff --git a/README.md b/README.md index 2f07c511a..550d281ca 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ Inside this repo you can find all important pieces for running MedPerf. In its c If you use MedPerf, please cite our main paper: Karargyris, A., Umeton, R., Sheller, M.J. et al. Federated benchmarking of medical artificial intelligence with MedPerf. *Nature Machine Intelligence* **5**, 799–810 (2023). [https://www.nature.com/articles/s42256-023-00652-2](https://www.nature.com/articles/s42256-023-00652-2) -Additonally, here you can see how others used MedPerf already: [https://scholar.google.com/scholar?q="medperf"](https://scholar.google.com/scholar?q="medperf"). +Additionally, here you can see how others used MedPerf already: [https://scholar.google.com/scholar?q="medperf"](https://scholar.google.com/scholar?q="medperf"). ## Experiments diff --git a/cli/cli_tests.sh b/cli/cli_tests.sh index 697c40105..4640fcca0 100755 --- a/cli/cli_tests.sh +++ b/cli/cli_tests.sh @@ -5,7 +5,6 @@ ################### Start Testing ######################## ########################################################## - ########################################################## echo "==========================================" echo "Printing MedPerf version" @@ -195,7 +194,7 @@ echo "Running data submission step" echo "=====================================" print_eval "medperf dataset submit -p $PREP_UID -d $DIRECTORY/dataset_a -l $DIRECTORY/dataset_a --name='dataset_a' --description='mock dataset a' --location='mock location a' -y" checkFailed "Data submission step failed" -DSET_A_UID=$(medperf dataset ls | grep dataset_a | tr -s ' ' | cut -d ' ' -f 1) +DSET_A_UID=$(medperf dataset ls | grep dataset_a | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) echo "DSET_A_UID=$DSET_A_UID" ########################################################## diff --git a/cli/medperf/cli.py b/cli/medperf/cli.py index 4fc7102c4..0910c3ed8 100644 --- a/cli/medperf/cli.py +++ 
b/cli/medperf/cli.py @@ -71,7 +71,7 @@ def execute( please run the command again with the --no-cache option.\n""" ) else: - ResultSubmission.run(result.generated_uid, approved=approval) + ResultSubmission.run(result.local_id, approved=approval) config.ui.print("✅ Done!") diff --git a/cli/medperf/commands/benchmark/benchmark.py b/cli/medperf/commands/benchmark/benchmark.py index f02d67cb4..35d719b0d 100644 --- a/cli/medperf/commands/benchmark/benchmark.py +++ b/cli/medperf/commands/benchmark/benchmark.py @@ -16,14 +16,16 @@ @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local benchmarks"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered benchmarks" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user benchmarks"), ): - """List benchmarks stored locally and remotely from the user""" + """List benchmarks""" EntityList.run( Benchmark, fields=["UID", "Name", "Description", "State", "Approval Status", "Registered"], - local_only=local, + unregistered=unregistered, mine_only=mine, ) @@ -162,10 +164,10 @@ def view( "--format", help="Format to display contents. 
Available formats: [yaml, json]", ), - local: bool = typer.Option( + unregistered: bool = typer.Option( False, - "--local", - help="Display local benchmarks if benchmark ID is not provided", + "--unregistered", + help="Display unregistered benchmarks if benchmark ID is not provided", ), mine: bool = typer.Option( False, @@ -180,4 +182,4 @@ def view( ), ): """Displays the information of one or more benchmarks""" - EntityView.run(entity_id, Benchmark, format, local, mine, output) + EntityView.run(entity_id, Benchmark, format, unregistered, mine, output) diff --git a/cli/medperf/commands/benchmark/submit.py b/cli/medperf/commands/benchmark/submit.py index ebace1880..05d1a0d10 100644 --- a/cli/medperf/commands/benchmark/submit.py +++ b/cli/medperf/commands/benchmark/submit.py @@ -79,7 +79,7 @@ def run_compatibility_test(self): self.ui.print("Running compatibility test") self.bmk.write() data_uid, results = CompatibilityTestExecution.run( - benchmark=self.bmk.generated_uid, + benchmark=self.bmk.local_id, no_cache=self.no_cache, skip_data_preparation_step=self.skip_data_preparation_step, ) diff --git a/cli/medperf/commands/compatibility_test/compatibility_test.py b/cli/medperf/commands/compatibility_test/compatibility_test.py index a3b25ac78..0bd4a4695 100644 --- a/cli/medperf/commands/compatibility_test/compatibility_test.py +++ b/cli/medperf/commands/compatibility_test/compatibility_test.py @@ -95,7 +95,11 @@ def run( @clean_except def list(): """List previously executed tests reports.""" - EntityList.run(TestReport, fields=["UID", "Data Source", "Model", "Evaluator"]) + EntityList.run( + TestReport, + fields=["UID", "Data Source", "Model", "Evaluator"], + unregistered=True, + ) @app.command("view") @@ -116,4 +120,4 @@ def view( ), ): """Displays the information of one or more test reports""" - EntityView.run(entity_id, TestReport, format, output=output) + EntityView.run(entity_id, TestReport, format, unregistered=True, output=output) diff --git 
a/cli/medperf/commands/compatibility_test/run.py b/cli/medperf/commands/compatibility_test/run.py index 2e3082849..f06603d57 100644 --- a/cli/medperf/commands/compatibility_test/run.py +++ b/cli/medperf/commands/compatibility_test/run.py @@ -239,7 +239,7 @@ def cached_results(self): """ if self.no_cache: return - uid = self.report.generated_uid + uid = self.report.local_id try: report = TestReport.get(uid) except InvalidArgumentError: diff --git a/cli/medperf/commands/compatibility_test/utils.py b/cli/medperf/commands/compatibility_test/utils.py index a12ac5ea2..c56a57d41 100644 --- a/cli/medperf/commands/compatibility_test/utils.py +++ b/cli/medperf/commands/compatibility_test/utils.py @@ -138,23 +138,23 @@ def create_test_dataset( # TODO: existing dataset could make problems # make some changes since this is a test dataset config.tmp_paths.remove(data_creation.dataset.path) - data_creation.dataset.write() if skip_data_preparation_step: data_creation.make_dataset_prepared() dataset = data_creation.dataset + old_generated_uid = dataset.generated_uid + old_path = dataset.path # prepare/check dataset DataPreparation.run(dataset.generated_uid) # update dataset generated_uid - old_path = dataset.path - generated_uid = get_folders_hash([dataset.data_path, dataset.labels_path]) - dataset.generated_uid = generated_uid - dataset.write() - if dataset.input_data_hash != dataset.generated_uid: + new_generated_uid = get_folders_hash([dataset.data_path, dataset.labels_path]) + if new_generated_uid != old_generated_uid: # move to a correct location if it underwent preparation - new_path = old_path.replace(dataset.input_data_hash, generated_uid) + new_path = old_path.replace(old_generated_uid, new_generated_uid) remove_path(new_path) os.rename(old_path, new_path) + dataset.generated_uid = new_generated_uid + dataset.write() - return generated_uid + return new_generated_uid diff --git a/cli/medperf/commands/dataset/dataset.py b/cli/medperf/commands/dataset/dataset.py index 
a27e36814..fc18022ac 100644 --- a/cli/medperf/commands/dataset/dataset.py +++ b/cli/medperf/commands/dataset/dataset.py @@ -17,17 +17,19 @@ @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local datasets"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered datasets" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user datasets"), mlcube: int = typer.Option( None, "--mlcube", "-m", help="Get datasets for a given data prep mlcube" ), ): - """List datasets stored locally and remotely from the user""" + """List datasets""" EntityList.run( Dataset, fields=["UID", "Name", "Data Preparation Cube UID", "State", "Status", "Owner"], - local_only=local, + unregistered=unregistered, mine_only=mine, mlcube=mlcube, ) @@ -149,8 +151,10 @@ def view( "--format", help="Format to display contents. Available formats: [yaml, json]", ), - local: bool = typer.Option( - False, "--local", help="Display local datasets if dataset ID is not provided" + unregistered: bool = typer.Option( + False, + "--unregistered", + help="Display unregistered datasets if dataset ID is not provided", ), mine: bool = typer.Option( False, @@ -165,4 +169,4 @@ def view( ), ): """Displays the information of one or more datasets""" - EntityView.run(entity_id, Dataset, format, local, mine, output) + EntityView.run(entity_id, Dataset, format, unregistered, mine, output) diff --git a/cli/medperf/commands/execution.py b/cli/medperf/commands/execution.py index d8afb2244..85416fe96 100644 --- a/cli/medperf/commands/execution.py +++ b/cli/medperf/commands/execution.py @@ -47,12 +47,12 @@ def prepare(self): logging.debug(f"tmp results output: {self.results_path}") def __setup_logs_path(self): - model_uid = self.model.generated_uid - eval_uid = self.evaluator.generated_uid - data_hash = self.dataset.generated_uid + model_uid = self.model.local_id + eval_uid = self.evaluator.local_id + data_uid = self.dataset.local_id 
logs_path = os.path.join( - config.experiments_logs_folder, str(model_uid), str(data_hash) + config.experiments_logs_folder, str(model_uid), str(data_uid) ) os.makedirs(logs_path, exist_ok=True) model_logs_path = os.path.join(logs_path, "model.log") @@ -60,10 +60,10 @@ def __setup_logs_path(self): return model_logs_path, metrics_logs_path def __setup_predictions_path(self): - model_uid = self.model.generated_uid - data_hash = self.dataset.generated_uid + model_uid = self.model.local_id + data_uid = self.dataset.local_id preds_path = os.path.join( - config.predictions_folder, str(model_uid), str(data_hash) + config.predictions_folder, str(model_uid), str(data_uid) ) if os.path.exists(preds_path): msg = f"Found existing predictions for model {self.model.id} on dataset " diff --git a/cli/medperf/commands/list.py b/cli/medperf/commands/list.py index 5fd462bf7..99236ac3f 100644 --- a/cli/medperf/commands/list.py +++ b/cli/medperf/commands/list.py @@ -1,3 +1,5 @@ +from typing import List, Type +from medperf.entities.interface import Entity from medperf.exceptions import InvalidArgumentError from tabulate import tabulate @@ -8,29 +10,38 @@ class EntityList: @staticmethod def run( - entity_class, - fields, - local_only: bool = False, + entity_class: Type[Entity], + fields: List[str], + unregistered: bool = False, mine_only: bool = False, **kwargs, ): """Lists all local datasets Args: - local_only (bool, optional): Display all local results. Defaults to False. - mine_only (bool, optional): Display all current-user results. Defaults to False. + unregistered (bool, optional): Display only local unregistered results. Defaults to False. + mine_only (bool, optional): Display all registered current-user results. Defaults to False. kwargs (dict): Additional parameters for filtering entity lists. 
""" - entity_list = EntityList(entity_class, fields, local_only, mine_only, **kwargs) + entity_list = EntityList( + entity_class, fields, unregistered, mine_only, **kwargs + ) entity_list.prepare() entity_list.validate() entity_list.filter() entity_list.display() - def __init__(self, entity_class, fields, local_only, mine_only, **kwargs): + def __init__( + self, + entity_class: Type[Entity], + fields: List[str], + unregistered: bool, + mine_only: bool, + **kwargs, + ): self.entity_class = entity_class self.fields = fields - self.local_only = local_only + self.unregistered = unregistered self.mine_only = mine_only self.filters = kwargs self.data = [] @@ -40,7 +51,7 @@ def prepare(self): self.filters["owner"] = get_medperf_user_data()["id"] entities = self.entity_class.all( - local_only=self.local_only, filters=self.filters + unregistered=self.unregistered, filters=self.filters ) self.data = [entity.display_dict() for entity in entities] diff --git a/cli/medperf/commands/mlcube/edit.py b/cli/medperf/commands/mlcube/edit.py new file mode 100644 index 000000000..31b44df00 --- /dev/null +++ b/cli/medperf/commands/mlcube/edit.py @@ -0,0 +1,112 @@ +import logging +from typing import Union + +import medperf.config as config +from medperf.entities.cube import Cube +from medperf.entities.edit_cube import EditCubeData + + +class EditCube: + @classmethod + def run(cls, cube_uid: Union[str, int], mlcube_partial_info: EditCubeData): + """Update mlcube in the development mode on the medperf server + + Args: + cube_uid: uid of cube to modify + mlcube_partial_info (dict): Dictionary containing the modified fields. 
+ """ + ui = config.ui + + logging.debug("Downloading initial MLCube..") + edition = cls(cube_uid, mlcube_partial_info) + logging.debug("Validating MLCube DEVELOPMENT state..") + edition.validate_dev_state() + + with ui.interactive(): + ui.text = "Validating updated MLCube can be downloaded" + logging.debug("Applying MLCube edit..") + edition.apply_and_get_hashes() + ui.text = "Submitting MLCube edit to MedPerf" + logging.debug("Uploading MLCube..") + edition.upload() + edition.write() + + def __init__(self, cube_uid: Union[str, int], edit_info: EditCubeData): + self.ui = config.ui + self.cube = Cube.get(cube_uid) + self.edit_info = edit_info + + def validate_dev_state(self): + if self.cube.state != "DEVELOPMENT": + raise ValueError("Only cubes in development state can be edited") + + def apply_and_get_hashes(self): + cube = self.cube + new = self.edit_info + + if new.name: + cube.name = new.name + + if new.git_mlcube_url: + cube.git_mlcube_url = new.git_mlcube_url + # Differs from further ifs: if mlcube.yaml url is provided, reset image also + cube.image_hash = "" + + if new.mlcube_hash: + cube.mlcube_hash = new.mlcube_hash + elif new.git_mlcube_url is not None: + cube.mlcube_hash = "" + + if new.git_parameters_url: + cube.git_parameters_url = new.git_parameters_url + + if new.parameters_hash: + cube.parameters_hash = new.parameters_hash + elif new.git_parameters_url is not None: + cube.parameters_hash = "" + + if new.image_tarball_url: + cube.image_tarball_url = new.image_tarball_url + # same as with git_mlcube_url + cube.image_hash = "" + + if new.image_tarball_hash: + cube.image_tarball_hash = new.image_tarball_hash + elif new.image_tarball_url is not None: + cube.image_tarball_hash = "" + + if new.additional_files_tarball_url: + cube.additional_files_tarball_url = new.additional_files_tarball_url + + if new.additional_files_tarball_hash: + cube.additional_files_tarball_hash = new.additional_files_tarball_hash + elif new.additional_files_tarball_url is not 
None: + cube.additional_files_tarball_hash = "" + + self.download() + + if new.git_mlcube_url and not new.mlcube_hash: + new.mlcube_hash = cube.mlcube_hash + if new.git_parameters_url and not new.parameters_hash: + new.parameters_hash = cube.parameters_hash + if new.image_tarball_url and not new.image_tarball_hash: + new.image_tarball_hash = cube.image_tarball_hash + if new.additional_files_tarball_url and not new.additional_files_tarball_hash: + new.additional_files_tarball_hash = cube.additional_files_tarball_hash + if new.git_mlcube_url or new.image_tarball_url: + new.image_hash = cube.image_hash + + def download(self): + logging.debug("removing from filesystem...") + self.cube.remove_from_filesystem() + logging.debug("download config files..") + self.cube.download_config_files() + logging.debug("download run files..") + self.cube.download_run_files() + + def upload(self): + updated_body = Cube.edit(self.cube.id, self.edit_info) + self.cube = Cube(**updated_body) + + def write(self): + self.cube.write() diff --git a/cli/medperf/commands/mlcube/mlcube.py b/cli/medperf/commands/mlcube/mlcube.py index 4c365e574..5d36160bc 100644 --- a/cli/medperf/commands/mlcube/mlcube.py +++ b/cli/medperf/commands/mlcube/mlcube.py @@ -4,11 +4,13 @@ import medperf.config as config from medperf.decorators import clean_except from medperf.entities.cube import Cube +from medperf.entities.edit_cube import EditCubeData from medperf.commands.list import EntityList from medperf.commands.view import EntityView from medperf.commands.mlcube.create import CreateCube from medperf.commands.mlcube.submit import SubmitCube from medperf.commands.mlcube.associate import AssociateCube +from medperf.commands.mlcube.edit import EditCube app = typer.Typer() @@ -16,14 +18,16 @@ @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local mlcubes"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered mlcubes" + ), mine: bool 
= typer.Option(False, "--mine", help="Get current-user mlcubes"), ): - """List mlcubes stored locally and remotely from the user""" + """List mlcubes""" EntityList.run( Cube, fields=["UID", "Name", "State", "Registered"], - local_only=local, + unregistered=unregistered, mine_only=mine, ) @@ -121,6 +125,71 @@ def submit( config.ui.print("✅ Done!") +@app.command("edit") +@clean_except +def edit( + uid: str = typer.Option(..., "--uid", "-u", help="UID of the MLCube to edit"), + name: str = typer.Option(None, "--name", "-n", help="Name of the mlcube"), + mlcube_file: str = typer.Option( + None, + "--mlcube-file", + "-m", + help="Identifier to download the mlcube file. See the description above", + ), + mlcube_hash: str = typer.Option(None, "--mlcube-hash", help="hash of mlcube file"), + parameters_file: str = typer.Option( + None, + "--parameters-file", + "-p", + help="Identifier to download the parameters file. See the description above", + ), + parameters_hash: str = typer.Option(None, "--parameters-hash", help="hash of parameters file"), + additional_file: str = typer.Option( + None, + "--additional-file", + "-a", + help="Identifier to download the additional files tarball. See the description above", + ), + additional_hash: str = typer.Option(None, "--additional-hash", help="hash of additional file"), + image_file: str = typer.Option( + None, + "--image-file", + "-i", + help="Identifier to download the image file. See the description above", + ), + image_hash: str = typer.Option(None, "--image-hash", help="hash of image file"), +): + """Updates the existing mlcube. Only mlcubes in DEVELOPMENT state may be updated.\n + The following assets:\n + - mlcube_file\n + - parameters_file\n + - additional_file\n + - image_file\n + are expected to be given in the following format: + where `source_prefix` instructs the client how to download the resource, and `resource_identifier` + is the identifier used to download the asset. The following are supported:\n + 1. 
A direct link: "direct:"\n + 2. An asset hosted on the Synapse platform: "synapse:"\n\n + + If a URL is given without a source prefix, it will be treated as a direct download link. + """ + + mlcube_partial_info = EditCubeData( + uid=uid, + name=name, + git_mlcube_url=mlcube_file, + git_mlcube_hash=mlcube_hash, + git_parameters_url=parameters_file, + parameters_hash=parameters_hash, + image_tarball_url=image_file, + image_tarball_hash=image_hash, + additional_files_tarball_url=additional_file, + additional_files_tarball_hash=additional_hash, + ) + EditCube.run(uid, mlcube_partial_info) + config.ui.print("✅ Done!") + + @app.command("associate") @clean_except def associate( @@ -148,8 +217,10 @@ def view( "--format", help="Format to display contents. Available formats: [yaml, json]", ), - local: bool = typer.Option( - False, "--local", help="Display local mlcubes if mlcube ID is not provided" + unregistered: bool = typer.Option( + False, + "--unregistered", + help="Display unregistered mlcubes if mlcube ID is not provided", ), mine: bool = typer.Option( False, @@ -164,4 +235,4 @@ def view( ), ): """Displays the information of one or more mlcubes""" - EntityView.run(entity_id, Cube, format, local, mine, output) + EntityView.run(entity_id, Cube, format, unregistered, mine, output) diff --git a/cli/medperf/commands/result/create.py b/cli/medperf/commands/result/create.py index 42f97d990..26d52fa2e 100644 --- a/cli/medperf/commands/result/create.py +++ b/cli/medperf/commands/result/create.py @@ -1,5 +1,6 @@ import os from typing import List, Optional +from medperf.account_management.account_management import get_medperf_user_data from medperf.commands.execution import Execution from medperf.entities.result import Result from tabulate import tabulate @@ -143,7 +144,9 @@ def __validate_models(self, benchmark_models): raise InvalidArgumentError(msg) def load_cached_results(self): - results = Result.all() + user_id = get_medperf_user_data()["id"] + results = 
Result.all(filters={"owner": user_id}) + results += Result.all(unregistered=True) benchmark_dset_results = [ result for result in results @@ -254,7 +257,7 @@ def print_summary(self): data_lists_for_display.append( [ experiment["model_uid"], - experiment["result"].generated_uid, + experiment["result"].local_id, experiment["result"].metadata["partial"], experiment["cached"], experiment["error"], diff --git a/cli/medperf/commands/result/result.py b/cli/medperf/commands/result/result.py index 6fbb3b08a..40b65c52e 100644 --- a/cli/medperf/commands/result/result.py +++ b/cli/medperf/commands/result/result.py @@ -62,17 +62,19 @@ def submit( @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local results"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered results" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user results"), benchmark: int = typer.Option( None, "--benchmark", "-b", help="Get results for a given benchmark" ), ): - """List results stored locally and remotely from the user""" + """List results""" EntityList.run( Result, fields=["UID", "Benchmark", "Model", "Dataset", "Registered"], - local_only=local, + unregistered=unregistered, mine_only=mine, benchmark=benchmark, ) @@ -88,8 +90,10 @@ def view( "--format", help="Format to display contents. 
Available formats: [yaml, json]", ), - local: bool = typer.Option( - False, "--local", help="Display local results if result ID is not provided" + unregistered: bool = typer.Option( + False, + "--unregistered", + help="Display unregistered results if result ID is not provided", ), mine: bool = typer.Option( False, @@ -107,4 +111,6 @@ def view( ), ): """Displays the information of one or more results""" - EntityView.run(entity_id, Result, format, local, mine, output, benchmark=benchmark) + EntityView.run( + entity_id, Result, format, unregistered, mine, output, benchmark=benchmark + ) diff --git a/cli/medperf/commands/result/submit.py b/cli/medperf/commands/result/submit.py index 15649ee04..b69a596ce 100644 --- a/cli/medperf/commands/result/submit.py +++ b/cli/medperf/commands/result/submit.py @@ -3,7 +3,6 @@ from medperf.exceptions import CleanExit from medperf.utils import remove_path, dict_pretty_print, approval_prompt from medperf.entities.result import Result -from medperf.enums import Status from medperf import config @@ -11,6 +10,7 @@ class ResultSubmission: @classmethod def run(cls, result_uid, approved=False): sub = cls(result_uid, approved=approved) + sub.get_result() updated_result_dict = sub.upload_results() sub.to_permanent_path(updated_result_dict) sub.write(updated_result_dict) @@ -21,27 +21,26 @@ def __init__(self, result_uid, approved=False): self.ui = config.ui self.approved = approved - def request_approval(self, result): - if result.approval_status == Status.APPROVED: - return True + def get_result(self): + self.result = Result.get(self.result_uid) - dict_pretty_print(result.results) + def request_approval(self): + dict_pretty_print(self.result.results) self.ui.print("Above are the results generated by the model") approved = approval_prompt( - "Do you approve uploading the presented results to the MLCommons comms? [Y/n]" + "Do you approve uploading the presented results to the MedPerf? 
[Y/n]" ) return approved def upload_results(self): - result = Result.get(self.result_uid) - approved = self.approved or self.request_approval(result) + approved = self.approved or self.request_approval() if not approved: raise CleanExit("Results upload operation cancelled") - updated_result_dict = result.upload() + updated_result_dict = self.result.upload() return updated_result_dict def to_permanent_path(self, result_dict: dict): @@ -50,12 +49,12 @@ def to_permanent_path(self, result_dict: dict): Args: result_dict (dict): updated results dictionary """ - result = Result(**result_dict) - result_storage = config.results_folder - old_res_loc = os.path.join(result_storage, result.generated_uid) - new_res_loc = result.path - remove_path(new_res_loc) - os.rename(old_res_loc, new_res_loc) + + old_result_loc = self.result.path + updated_result = Result(**result_dict) + new_result_loc = updated_result.path + remove_path(new_result_loc) + os.rename(old_result_loc, new_result_loc) def write(self, updated_result_dict): result = Result(**updated_result_dict) diff --git a/cli/medperf/commands/view.py b/cli/medperf/commands/view.py index b4c242f0a..d19aedec0 100644 --- a/cli/medperf/commands/view.py +++ b/cli/medperf/commands/view.py @@ -1,6 +1,6 @@ import yaml import json -from typing import Union +from typing import Union, Type from medperf import config from medperf.account_management import get_medperf_user_data @@ -12,9 +12,9 @@ class EntityView: @staticmethod def run( entity_id: Union[int, str], - entity_class: Entity, + entity_class: Type[Entity], format: str = "yaml", - local_only: bool = False, + unregistered: bool = False, mine_only: bool = False, output: str = None, **kwargs, @@ -24,14 +24,14 @@ def run( Args: entity_id (Union[int, str]): Entity identifies entity_class (Entity): Entity type - local_only (bool, optional): Display all local entities. Defaults to False. + unregistered (bool, optional): Display only local unregistered entities. Defaults to False. 
mine_only (bool, optional): Display all current-user entities. Defaults to False. format (str, optional): What format to use to display the contents. Valid formats: [yaml, json]. Defaults to yaml. output (str, optional): Path to a file for storing the entity contents. If not provided, the contents are printed. kwargs (dict): Additional parameters for filtering entity lists. """ entity_view = EntityView( - entity_id, entity_class, format, local_only, mine_only, output, **kwargs + entity_id, entity_class, format, unregistered, mine_only, output, **kwargs ) entity_view.validate() entity_view.prepare() @@ -41,12 +41,19 @@ def run( entity_view.store() def __init__( - self, entity_id, entity_class, format, local_only, mine_only, output, **kwargs + self, + entity_id: Union[int, str], + entity_class: Type[Entity], + format: str, + unregistered: bool, + mine_only: bool, + output: str, + **kwargs, ): self.entity_id = entity_id self.entity_class = entity_class self.format = format - self.local_only = local_only + self.unregistered = unregistered self.mine_only = mine_only self.output = output self.filters = kwargs @@ -65,7 +72,7 @@ def prepare(self): self.filters["owner"] = get_medperf_user_data()["id"] entities = self.entity_class.all( - local_only=self.local_only, filters=self.filters + unregistered=self.unregistered, filters=self.filters ) self.data = [entity.todict() for entity in entities] diff --git a/cli/medperf/comms/interface.py b/cli/medperf/comms/interface.py index 01436e435..c8635b3eb 100644 --- a/cli/medperf/comms/interface.py +++ b/cli/medperf/comms/interface.py @@ -87,6 +87,18 @@ def get_cube_metadata(self, cube_uid: int) -> dict: dict: Dictionary containing url and hashes for the cube files """ + @abstractmethod + def edit_cube(self, cube_uid: int, edited_fields: dict) -> dict: + """Updates mlcube with dict of changed fields + + Args: + cube_uid (int): UID of the desired cube. 
+ edited_fields: Dictionary containing the fields to be updated + + Returns: + dict: Dictionary containing the full mlcube + """ + @abstractmethod def get_user_cubes(self) -> List[dict]: """Retrieves metadata from all cubes registered by the user diff --git a/cli/medperf/comms/rest.py b/cli/medperf/comms/rest.py index 5ac236f93..9ccb362dc 100644 --- a/cli/medperf/comms/rest.py +++ b/cli/medperf/comms/rest.py @@ -227,6 +227,25 @@ def get_cube_metadata(self, cube_uid: int) -> dict: ) return res.json() + def edit_cube(self, cube_uid: int, edited_fields: dict) -> dict: + """Updates mlcube with dict of changed fields + + Args: + cube_uid (int): UID of the desired cube. + edited_fields: Dictionary containing the fields to be updated + + Returns: + dict: Dictionary containing the full mlcube + """ + res = self.__auth_put(f"{self.server_url}/mlcubes/{cube_uid}/", json=edited_fields) + if res.status_code != 200: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRetrievalError( + f"the specified cube doesn't exist {details}" + ) + return res.json() + def get_user_cubes(self) -> List[dict]: """Retrieves metadata from all cubes registered by the user diff --git a/cli/medperf/entities/benchmark.py b/cli/medperf/entities/benchmark.py index 849ea3fcd..e03fcdb4f 100644 --- a/cli/medperf/entities/benchmark.py +++ b/cli/medperf/entities/benchmark.py @@ -1,18 +1,13 @@ -import os -from medperf.exceptions import MedperfException -import yaml -import logging -from typing import List, Optional, Union +from typing import List, Optional from pydantic import HttpUrl, Field import medperf.config as config -from medperf.entities.interface import Entity, Uploadable -from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError -from medperf.entities.schemas import MedperfSchema, ApprovableSchema, DeployableSchema +from medperf.entities.interface import Entity +from medperf.entities.schemas import ApprovableSchema, DeployableSchema from 
medperf.account_management import get_medperf_user_data -class Benchmark(Entity, Uploadable, MedperfSchema, ApprovableSchema, DeployableSchema): +class Benchmark(Entity, ApprovableSchema, DeployableSchema): """ Class representing a Benchmark @@ -35,6 +30,26 @@ class Benchmark(Entity, Uploadable, MedperfSchema, ApprovableSchema, DeployableS user_metadata: dict = {} is_active: bool = True + @staticmethod + def get_type(): + return "benchmark" + + @staticmethod + def get_storage_path(): + return config.benchmarks_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_benchmark + + @staticmethod + def get_metadata_filename(): + return config.benchmarks_filename + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_benchmark + def __init__(self, *args, **kwargs): """Creates a new benchmark instance @@ -43,54 +58,12 @@ def __init__(self, *args, **kwargs): """ super().__init__(*args, **kwargs) - self.generated_uid = f"p{self.data_preparation_mlcube}m{self.reference_model_mlcube}e{self.data_evaluator_mlcube}" - path = config.benchmarks_folder - if self.id: - path = os.path.join(path, str(self.id)) - else: - path = os.path.join(path, self.generated_uid) - self.path = path - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Benchmark"]: - """Gets and creates instances of all retrievable benchmarks - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. - - Returns: - List[Benchmark]: a list of Benchmark instances. 
- """ - logging.info("Retrieving all benchmarks") - benchmarks = [] - - if not local_only: - benchmarks = cls.__remote_all(filters=filters) - - remote_uids = set([bmk.id for bmk in benchmarks]) + @property + def local_id(self): + return self.name - local_benchmarks = cls.__local_all() - - benchmarks += [bmk for bmk in local_benchmarks if bmk.id not in remote_uids] - - return benchmarks - - @classmethod - def __remote_all(cls, filters: dict) -> List["Benchmark"]: - benchmarks = [] - try: - comms_fn = cls.__remote_prefilter(filters) - bmks_meta = comms_fn() - benchmarks = [cls(**meta) for meta in bmks_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all benchmarks from the server" - logging.warning(msg) - - return benchmarks - - @classmethod - def __remote_prefilter(cls, filters: dict) -> callable: + @staticmethod + def remote_prefilter(filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: @@ -104,104 +77,6 @@ def __remote_prefilter(cls, filters: dict) -> callable: comms_fn = config.comms.get_user_benchmarks return comms_fn - @classmethod - def __local_all(cls) -> List["Benchmark"]: - benchmarks = [] - bmks_storage = config.benchmarks_folder - try: - uids = next(os.walk(bmks_storage))[1] - except StopIteration: - msg = "Couldn't iterate over benchmarks directory" - logging.warning(msg) - raise MedperfException(msg) - - for uid in uids: - meta = cls.__get_local_dict(uid) - benchmark = cls(**meta) - benchmarks.append(benchmark) - - return benchmarks - - @classmethod - def get( - cls, benchmark_uid: Union[str, int], local_only: bool = False - ) -> "Benchmark": - """Retrieves and creates a Benchmark instance from the server. - If benchmark already exists in the platform then retrieve that - version. - - Args: - benchmark_uid (str): UID of the benchmark. - comms (Comms): Instance of a communication interface. - - Returns: - Benchmark: a Benchmark instance with the retrieved data. 
- """ - - if not str(benchmark_uid).isdigit() or local_only: - return cls.__local_get(benchmark_uid) - - try: - return cls.__remote_get(benchmark_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting Benchmark {benchmark_uid} from comms failed") - logging.info(f"Looking for benchmark {benchmark_uid} locally") - return cls.__local_get(benchmark_uid) - - @classmethod - def __remote_get(cls, benchmark_uid: int) -> "Benchmark": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving benchmark {benchmark_uid} remotely") - benchmark_dict = config.comms.get_benchmark(benchmark_uid) - benchmark = cls(**benchmark_dict) - benchmark.write() - return benchmark - - @classmethod - def __local_get(cls, benchmark_uid: Union[str, int]) -> "Benchmark": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. 
- - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving benchmark {benchmark_uid} locally") - benchmark_dict = cls.__get_local_dict(benchmark_uid) - benchmark = cls(**benchmark_dict) - return benchmark - - @classmethod - def __get_local_dict(cls, benchmark_uid) -> dict: - """Retrieves a local benchmark information - - Args: - benchmark_uid (str): uid of the local benchmark - - Returns: - dict: information of the benchmark - """ - logging.info(f"Retrieving benchmark {benchmark_uid} from local storage") - storage = config.benchmarks_folder - bmk_storage = os.path.join(storage, str(benchmark_uid)) - bmk_file = os.path.join(bmk_storage, config.benchmarks_filename) - if not os.path.exists(bmk_file): - raise InvalidArgumentError("No benchmark with the given uid could be found") - with open(bmk_file, "r") as f: - data = yaml.safe_load(f) - - return data - @classmethod def get_models_uids(cls, benchmark_uid: int) -> List[int]: """Retrieves the list of models associated to the benchmark @@ -221,43 +96,6 @@ def get_models_uids(cls, benchmark_uid: int) -> List[int]: ] return models_uids - def todict(self) -> dict: - """Dictionary representation of the benchmark instance - - Returns: - dict: Dictionary containing benchmark information - """ - return self.extended_dict() - - def write(self) -> str: - """Writes the benchmark into disk - - Args: - filename (str, optional): name of the file. Defaults to config.benchmarks_filename. 
- - Returns: - str: path to the created benchmark file - """ - data = self.todict() - bmk_file = os.path.join(self.path, config.benchmarks_filename) - if not os.path.exists(bmk_file): - os.makedirs(self.path, exist_ok=True) - with open(bmk_file, "w") as f: - yaml.dump(data, f) - return bmk_file - - def upload(self): - """Uploads a benchmark to the server - - Args: - comms (Comms): communications entity to submit through - """ - if self.for_test: - raise InvalidArgumentError("Cannot upload test benchmarks.") - body = self.todict() - updated_body = config.comms.upload_benchmark(body) - return updated_body - def display_dict(self): return { "UID": self.identifier, diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index 98d2b95a8..fa250083e 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -1,7 +1,7 @@ import os import yaml import logging -from typing import List, Dict, Optional, Union +from typing import Dict, Optional, Union from pydantic import Field from pathlib import Path @@ -12,21 +12,16 @@ generate_tmp_path, spawn_and_kill, ) -from medperf.entities.interface import Entity, Uploadable -from medperf.entities.schemas import MedperfSchema, DeployableSchema -from medperf.exceptions import ( - InvalidArgumentError, - ExecutionError, - InvalidEntityError, - MedperfException, - CommunicationRetrievalError, -) +from medperf.entities.edit_cube import EditCubeData +from medperf.entities.interface import Entity +from medperf.entities.schemas import DeployableSchema +from medperf.exceptions import InvalidArgumentError, ExecutionError, InvalidEntityError import medperf.config as config from medperf.comms.entity_resources import resources from medperf.account_management import get_medperf_user_data -class Cube(Entity, Uploadable, MedperfSchema, DeployableSchema): +class Cube(Entity, DeployableSchema): """ Class representing an MLCube Container @@ -48,6 +43,32 @@ class Cube(Entity, Uploadable, MedperfSchema, 
DeployableSchema): metadata: dict = {} user_metadata: dict = {} + @staticmethod + def get_type(): + return "cube" + + @staticmethod + def get_storage_path(): + return config.cubes_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_cube_metadata + + @staticmethod + def get_metadata_filename(): + return config.cube_metadata_filename + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_mlcube + + # as currently edit is implemented only for mlcubes, this function is not defined + # in interface and thus is not overridden. + @staticmethod + def get_comms_edit(): + return config.comms.edit_cube + def __init__(self, *args, **kwargs): """Creates a Cube instance @@ -56,60 +77,17 @@ def __init__(self, *args, **kwargs): """ super().__init__(*args, **kwargs) - self.generated_uid = self.name - path = config.cubes_folder - if self.id: - path = os.path.join(path, str(self.id)) - else: - path = os.path.join(path, self.generated_uid) - # NOTE: maybe have these as @property, to have the same entity reusable - # before and after submission - self.path = path - self.cube_path = os.path.join(path, config.cube_filename) + self.cube_path = os.path.join(self.path, config.cube_filename) self.params_path = None if self.git_parameters_url: - self.params_path = os.path.join(path, config.params_filename) - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Cube"]: - """Class method for retrieving all retrievable MLCubes + self.params_path = os.path.join(self.path, config.params_filename) - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. 
- - Returns: - List[Cube]: List containing all cubes - """ - logging.info("Retrieving all cubes") - cubes = [] - if not local_only: - cubes = cls.__remote_all(filters=filters) - - remote_uids = set([cube.id for cube in cubes]) - - local_cubes = cls.__local_all() - - cubes += [cube for cube in local_cubes if cube.id not in remote_uids] - - return cubes - - @classmethod - def __remote_all(cls, filters: dict) -> List["Cube"]: - cubes = [] - - try: - comms_fn = cls.__remote_prefilter(filters) - cubes_meta = comms_fn() - cubes = [cls(**meta) for meta in cubes_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all cubes from the server" - logging.warning(msg) - - return cubes + @property + def local_id(self): + return self.name - @classmethod - def __remote_prefilter(cls, filters: dict): + @staticmethod + def remote_prefilter(filters: dict): """Applies filtering logic that must be done before retrieving remote entities Args: @@ -124,25 +102,6 @@ def __remote_prefilter(cls, filters: dict): return comms_fn - @classmethod - def __local_all(cls) -> List["Cube"]: - cubes = [] - cubes_folder = config.cubes_folder - try: - uids = next(os.walk(cubes_folder))[1] - logging.debug(f"Local cubes found: {uids}") - except StopIteration: - msg = "Couldn't iterate over cubes directory" - logging.warning(msg) - raise MedperfException(msg) - - for uid in uids: - meta = cls.__get_local_dict(uid) - cube = cls(**meta) - cubes.append(cube) - - return cubes - @classmethod def get(cls, cube_uid: Union[str, int], local_only: bool = False) -> "Cube": """Retrieves and creates a Cube instance from the comms. If cube already exists @@ -155,36 +114,12 @@ def get(cls, cube_uid: Union[str, int], local_only: bool = False) -> "Cube": Cube : a Cube instance with the retrieved data. 
""" - if not str(cube_uid).isdigit() or local_only: - cube = cls.__local_get(cube_uid) - else: - try: - cube = cls.__remote_get(cube_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting MLCube {cube_uid} from comms failed") - logging.info(f"Retrieving MLCube {cube_uid} from local storage") - cube = cls.__local_get(cube_uid) - + cube = super().get(cube_uid, local_only) if not cube.is_valid: raise InvalidEntityError("The requested MLCube is marked as INVALID.") cube.download_config_files() return cube - @classmethod - def __remote_get(cls, cube_uid: int) -> "Cube": - logging.debug(f"Retrieving mlcube {cube_uid} remotely") - meta = config.comms.get_cube_metadata(cube_uid) - cube = cls(**meta) - cube.write() - return cube - - @classmethod - def __local_get(cls, cube_uid: Union[str, int]) -> "Cube": - logging.debug(f"Retrieving cube {cube_uid} locally") - local_meta = cls.__get_local_dict(cube_uid) - cube = cls(**local_meta) - return cube - def download_mlcube(self): url = self.git_mlcube_url path, file_hash = resources.get_cube(url, self.path, self.mlcube_hash) @@ -316,11 +251,11 @@ def run( Defaults to {}. timeout (int, optional): timeout for the task in seconds. Defaults to None. read_protected_input (bool, optional): Wether to disable write permissions on input volumes. Defaults to True. 
- kwargs (dict): additional arguments that are passed directly to the mlcube command + kwargs: additional arguments that are passed directly to the mlcube command """ kwargs.update(string_params) cmd = f"mlcube --log-level {config.loglevel} run" - cmd += f" --mlcube=\"{self.cube_path}\" --task={task} --platform={config.platform} --network=none" + cmd += f' --mlcube="{self.cube_path}" --task={task} --platform={config.platform} --network=none' if config.gpus is not None: cmd += f" --gpus={config.gpus}" if read_protected_input: @@ -430,36 +365,6 @@ def get_config(self, identifier): return cube - def todict(self) -> Dict: - return self.extended_dict() - - def write(self): - cube_loc = str(Path(self.cube_path).parent) - meta_file = os.path.join(cube_loc, config.cube_metadata_filename) - os.makedirs(cube_loc, exist_ok=True) - with open(meta_file, "w") as f: - yaml.dump(self.todict(), f) - return meta_file - - def upload(self): - if self.for_test: - raise InvalidArgumentError("Cannot upload test mlcubes.") - cube_dict = self.todict() - updated_cube_dict = config.comms.upload_mlcube(cube_dict) - return updated_cube_dict - - @classmethod - def __get_local_dict(cls, uid): - cubes_folder = config.cubes_folder - meta_file = os.path.join(cubes_folder, str(uid), config.cube_metadata_filename) - if not os.path.exists(meta_file): - raise InvalidArgumentError( - "The requested mlcube information could not be found locally" - ) - with open(meta_file, "r") as f: - meta = yaml.safe_load(f) - return meta - def display_dict(self): return { "UID": self.identifier, @@ -469,3 +374,16 @@ def display_dict(self): "Created At": self.created_at, "Registered": self.is_registered, } + + @staticmethod + def edit(cube_uid: Union[str, int], edited_fields: EditCubeData) -> Dict: + """Uploads the mlcube diff and updates the entity + + Returns: + Dict: Dictionary with the updated cube + """ + + comms_func = Cube.get_comms_edit() + logging.debug(f"Editing cube {cube_uid} with fields: {edited_fields}") + 
updated_body = comms_func(cube_uid, edited_fields.not_null_dict()) + return updated_body diff --git a/cli/medperf/entities/dataset.py b/cli/medperf/entities/dataset.py index 4c210431f..7f13c2185 100644 --- a/cli/medperf/entities/dataset.py +++ b/cli/medperf/entities/dataset.py @@ -1,22 +1,17 @@ import os import yaml -import logging from pydantic import Field, validator -from typing import List, Optional, Union +from typing import Optional, Union from medperf.utils import remove_path -from medperf.entities.interface import Entity, Uploadable -from medperf.entities.schemas import MedperfSchema, DeployableSchema -from medperf.exceptions import ( - InvalidArgumentError, - MedperfException, - CommunicationRetrievalError, -) +from medperf.entities.interface import Entity +from medperf.entities.schemas import DeployableSchema + import medperf.config as config from medperf.account_management import get_medperf_user_data -class Dataset(Entity, Uploadable, MedperfSchema, DeployableSchema): +class Dataset(Entity, DeployableSchema): """ Class representing a Dataset @@ -37,6 +32,26 @@ class Dataset(Entity, Uploadable, MedperfSchema, DeployableSchema): report: dict = {} submitted_as_prepared: bool + @staticmethod + def get_type(): + return "dataset" + + @staticmethod + def get_storage_path(): + return config.datasets_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_dataset + + @staticmethod + def get_metadata_filename(): + return config.reg_file + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_dataset + @validator("data_preparation_mlcube", pre=True, always=True) def check_data_preparation_mlcube(cls, v, *, values, **kwargs): if not isinstance(v, int) and not values["for_test"]: @@ -47,20 +62,16 @@ def check_data_preparation_mlcube(cls, v, *, values, **kwargs): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - - path = config.datasets_folder - if self.id: - path = os.path.join(path, str(self.id)) - 
else: - path = os.path.join(path, self.generated_uid) - - self.path = path self.data_path = os.path.join(self.path, "data") self.labels_path = os.path.join(self.path, "labels") self.report_path = os.path.join(self.path, config.report_file) self.metadata_path = os.path.join(self.path, config.metadata_folder) self.statistics_path = os.path.join(self.path, config.statistics_filename) + @property + def local_id(self): + return self.generated_uid + def set_raw_paths(self, raw_data_path: str, raw_labels_path: str): raw_paths_file = os.path.join(self.path, config.dataset_raw_paths_file) data = {"data_path": raw_data_path, "labels_path": raw_labels_path} @@ -86,48 +97,8 @@ def is_ready(self): flag_file = os.path.join(self.path, config.ready_flag_file) return os.path.exists(flag_file) - def todict(self): - return self.extended_dict() - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Dataset"]: - """Gets and creates instances of all the locally prepared datasets - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. - - Returns: - List[Dataset]: a list of Dataset instances. 
- """ - logging.info("Retrieving all datasets") - dsets = [] - if not local_only: - dsets = cls.__remote_all(filters=filters) - - remote_uids = set([dset.id for dset in dsets]) - - local_dsets = cls.__local_all() - - dsets += [dset for dset in local_dsets if dset.id not in remote_uids] - - return dsets - - @classmethod - def __remote_all(cls, filters: dict) -> List["Dataset"]: - dsets = [] - try: - comms_fn = cls.__remote_prefilter(filters) - dsets_meta = comms_fn() - dsets = [cls(**meta) for meta in dsets_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all datasets from the server" - logging.warning(msg) - - return dsets - - @classmethod - def __remote_prefilter(cls, filters: dict) -> callable: + @staticmethod + def remote_prefilter(filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: @@ -149,111 +120,6 @@ def func(): return comms_fn - @classmethod - def __local_all(cls) -> List["Dataset"]: - dsets = [] - datasets_folder = config.datasets_folder - try: - uids = next(os.walk(datasets_folder))[1] - except StopIteration: - msg = "Couldn't iterate over the dataset directory" - logging.warning(msg) - raise MedperfException(msg) - - for uid in uids: - local_meta = cls.__get_local_dict(uid) - dset = cls(**local_meta) - dsets.append(dset) - - return dsets - - @classmethod - def get(cls, dset_uid: Union[str, int], local_only: bool = False) -> "Dataset": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. 
- - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - if not str(dset_uid).isdigit() or local_only: - return cls.__local_get(dset_uid) - - try: - return cls.__remote_get(dset_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting Dataset {dset_uid} from comms failed") - logging.info(f"Looking for dataset {dset_uid} locally") - return cls.__local_get(dset_uid) - - @classmethod - def __remote_get(cls, dset_uid: int) -> "Dataset": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving dataset {dset_uid} remotely") - meta = config.comms.get_dataset(dset_uid) - dataset = cls(**meta) - dataset.write() - return dataset - - @classmethod - def __local_get(cls, dset_uid: Union[str, int]) -> "Dataset": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving dataset {dset_uid} locally") - local_meta = cls.__get_local_dict(dset_uid) - dataset = cls(**local_meta) - return dataset - - def write(self): - logging.info(f"Updating registration information for dataset: {self.id}") - logging.debug(f"registration information: {self.todict()}") - regfile = os.path.join(self.path, config.reg_file) - os.makedirs(self.path, exist_ok=True) - with open(regfile, "w") as f: - yaml.dump(self.todict(), f) - return regfile - - def upload(self): - """Uploads the registration information to the comms. - - Args: - comms (Comms): Instance of the comms interface. 
- """ - if self.for_test: - raise InvalidArgumentError("Cannot upload test datasets.") - dataset_dict = self.todict() - updated_dataset_dict = config.comms.upload_dataset(dataset_dict) - return updated_dataset_dict - - @classmethod - def __get_local_dict(cls, data_uid): - dataset_path = os.path.join(config.datasets_folder, str(data_uid)) - regfile = os.path.join(dataset_path, config.reg_file) - if not os.path.exists(regfile): - raise InvalidArgumentError( - "The requested dataset information could not be found locally" - ) - with open(regfile, "r") as f: - reg = yaml.safe_load(f) - return reg - def display_dict(self): return { "UID": self.identifier, diff --git a/cli/medperf/entities/edit_cube.py b/cli/medperf/entities/edit_cube.py new file mode 100644 index 000000000..f4f68bda7 --- /dev/null +++ b/cli/medperf/entities/edit_cube.py @@ -0,0 +1,21 @@ +from typing import Union, Optional +from pydantic import BaseModel + + +class EditCubeData(BaseModel): + """represents a partial mlcube with fields to be updated""" + uid: Union[str, int] + name: Optional[str] + git_mlcube_url: Optional[str] + mlcube_hash: Optional[str] + git_parameters_url: Optional[str] + parameters_hash: Optional[str] + image_tarball_url: Optional[str] + image_tarball_hash: Optional[str] + additional_files_tarball_url: Optional[str] + additional_files_tarball_hash: Optional[str] + image_hash: Optional[str] = None + + def not_null_dict(self): + """returns a dictionary of the non-null fields""" + return {k: v for k, v in self.dict().items() if v is not None} diff --git a/cli/medperf/entities/interface.py b/cli/medperf/entities/interface.py index af2afabd7..cff39768d 100644 --- a/cli/medperf/entities/interface.py +++ b/cli/medperf/entities/interface.py @@ -1,77 +1,231 @@ -from typing import List, Dict, Union -from abc import ABC, abstractmethod +import shutil +from typing import List, Dict, Union, Callable +from abc import ABC +import logging +import os +import yaml +from medperf.exceptions import 
MedperfException, InvalidArgumentError +from medperf.entities.schemas import MedperfSchema +from typing import Type, TypeVar +EntityType = TypeVar("EntityType", bound="Entity") -class Entity(ABC): - @abstractmethod + +class Entity(MedperfSchema, ABC): + @staticmethod + def get_type() -> str: + raise NotImplementedError() + + @staticmethod + def get_storage_path() -> str: + raise NotImplementedError() + + @staticmethod + def get_comms_retriever() -> Callable[[int], dict]: + raise NotImplementedError() + + @staticmethod + def get_metadata_filename() -> str: + raise NotImplementedError() + + @staticmethod + def get_comms_uploader() -> Callable[[dict], dict]: + raise NotImplementedError() + + @property + def local_id(self) -> str: + raise NotImplementedError() + + @property + def identifier(self) -> Union[int, str]: + return self.id or self.local_id + + @property + def is_registered(self) -> bool: + return self.id is not None + + @property + def path(self) -> str: + storage_path = self.get_storage_path() + return os.path.join(storage_path, str(self.identifier)) + + @classmethod def all( - cls, local_only: bool = False, comms_func: callable = None - ) -> List["Entity"]: + cls: Type[EntityType], unregistered: bool = False, filters: dict = {} + ) -> List[EntityType]: """Gets a list of all instances of the respective entity. - Wether the list is local or remote depends on the implementation. + Whether the list is local or remote depends on the implementation. Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - comms_func (callable, optional): Function to use to retrieve remote entities. - If not provided, will use the default entrypoint. + unregistered (bool, optional): Wether to retrieve only unregistered local entities. Defaults to False. + filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. + Returns: List[Entity]: a list of entities. 
""" + logging.info(f"Retrieving all {cls.get_type()} entities") + if unregistered: + if filters: + raise InvalidArgumentError( + "Filtering is not supported for unregistered entities" + ) + return cls.__unregistered_all() + return cls.__remote_all(filters=filters) + + @classmethod + def __remote_all(cls: Type[EntityType], filters: dict) -> List[EntityType]: + comms_fn = cls.remote_prefilter(filters) + entity_meta = comms_fn() + entities = [cls(**meta) for meta in entity_meta] + return entities + + @classmethod + def __unregistered_all(cls: Type[EntityType]) -> List[EntityType]: + entities = [] + storage_path = cls.get_storage_path() + try: + uids = next(os.walk(storage_path))[1] + except StopIteration: + msg = f"Couldn't iterate over the {cls.get_type()} storage" + logging.warning(msg) + raise MedperfException(msg) + + for uid in uids: + if uid.isdigit(): + continue + entity = cls.__local_get(uid) + entities.append(entity) + + return entities + + @staticmethod + def remote_prefilter(filters: dict) -> callable: + """Applies filtering logic that must be done before retrieving remote entities - @abstractmethod - def get(cls, uid: Union[str, int]) -> "Entity": + Args: + filters (dict): filters to apply + + Returns: + callable: A function for retrieving remote entities with the applied prefilters + """ + raise NotImplementedError + + @classmethod + def get( + cls: Type[EntityType], uid: Union[str, int], local_only: bool = False + ) -> EntityType: """Gets an instance of the respective entity. Wether this requires only local read or remote calls depends on the implementation. 
Args: uid (str): Unique Identifier to retrieve the entity + local_only (bool): If True, the entity will be retrieved locally Returns: Entity: Entity Instance associated to the UID """ - @abstractmethod - def todict(self) -> Dict: - """Dictionary representation of the entity + if not str(uid).isdigit() or local_only: + return cls.__local_get(uid) + return cls.__remote_get(uid) + + @classmethod + def __remote_get(cls: Type[EntityType], uid: int) -> EntityType: + """Retrieves and creates an entity instance from the comms instance. + + Args: + uid (int): server UID of the entity Returns: - Dict: Dictionary containing information about the entity + Entity: Specified Entity Instance """ + logging.debug(f"Retrieving {cls.get_type()} {uid} remotely") + comms_func = cls.get_comms_retriever() + entity_dict = comms_func(uid) + entity = cls(**entity_dict) + entity.write() + return entity - @abstractmethod - def write(self) -> str: - """Writes the entity to the local storage + @classmethod + def __local_get(cls: Type[EntityType], uid: Union[str, int]) -> EntityType: + """Retrieves and creates an entity instance from the local storage. 
+ + Args: + uid (str|int): UID of the entity Returns: - str: Path to the stored entity + Entity: Specified Entity Instance """ + logging.debug(f"Retrieving {cls.get_type()} {uid} locally") + entity_dict = cls.__get_local_dict(uid) + entity = cls(**entity_dict) + return entity - @abstractmethod - def display_dict(self) -> dict: - """Returns a dictionary of entity properties that can be displayed - to a user interface using a verbose name of the property rather than - the internal names + @classmethod + def __get_local_dict(cls: Type[EntityType], uid: Union[str, int]) -> dict: + """Retrieves a local entity information + + Args: + uid (str): uid of the local entity Returns: - dict: the display dictionary + dict: information of the entity """ + logging.info(f"Retrieving {cls.get_type()} {uid} from local storage") + storage_path = cls.get_storage_path() + metadata_filename = cls.get_metadata_filename() + entity_file = os.path.join(storage_path, str(uid), metadata_filename) + if not os.path.exists(entity_file): + raise InvalidArgumentError( + f"No {cls.get_type()} with the given uid could be found" + ) + with open(entity_file, "r") as f: + data = yaml.safe_load(f) + + return data + + def write(self) -> str: + """Writes the entity to the local storage + Returns: + str: Path to the stored entity + """ + data = self.todict() + metadata_filename = self.get_metadata_filename() + entity_file = os.path.join(self.path, metadata_filename) + os.makedirs(self.path, exist_ok=True) + with open(entity_file, "w") as f: + yaml.dump(data, f) + return entity_file + + def remove_from_filesystem(self): + """Removes the entity folder recursively from the local storage""" + # TODO: might be dangerous + shutil.rmtree(self.path, ignore_errors=True) -class Uploadable: - @abstractmethod def upload(self) -> Dict: """Upload the entity-related information to the communication's interface Returns: Dict: Dictionary with the updated entity information """ + if self.for_test: + raise 
InvalidArgumentError( + f"This test {self.get_type()} cannot be uploaded." + ) + body = self.todict() + comms_func = self.get_comms_uploader() + updated_body = comms_func(body) + return updated_body - @property - def identifier(self): - return self.id or self.generated_uid + def display_dict(self) -> dict: + """Returns a dictionary of entity properties that can be displayed + to a user interface using a verbose name of the property rather than + the internal names - @property - def is_registered(self): - return self.id is not None + Returns: + dict: the display dictionary + """ + raise NotImplementedError diff --git a/cli/medperf/entities/report.py b/cli/medperf/entities/report.py index c76f09894..cefd168b3 100644 --- a/cli/medperf/entities/report.py +++ b/cli/medperf/entities/report.py @@ -1,16 +1,11 @@ import hashlib -import os -import yaml -import logging from typing import List, Union, Optional -from medperf.entities.schemas import MedperfBaseSchema import medperf.config as config -from medperf.exceptions import InvalidArgumentError from medperf.entities.interface import Entity -class TestReport(Entity, MedperfBaseSchema): +class TestReport(Entity): """ Class representing a compatibility test report entry @@ -23,8 +18,16 @@ class TestReport(Entity, MedperfBaseSchema): - model cube - evaluator cube - results + + Note: This entity is only a local one, there is no TestReports on the server + However, we still use the same Entity interface used by other entities + in order to reduce repeated code. Consequently, we mocked a few methods + and attributes inherited from the Entity interface that are not relevant to + this entity, such as the `name` and `id` attributes, and such as + the `get` and `all` methods. 
""" + name: Optional[str] = "name" demo_dataset_url: Optional[str] demo_dataset_hash: Optional[str] data_path: Optional[str] @@ -35,13 +38,25 @@ class TestReport(Entity, MedperfBaseSchema): data_evaluator_mlcube: Union[int, str] results: Optional[dict] + @staticmethod + def get_type(): + return "report" + + @staticmethod + def get_storage_path(): + return config.tests_folder + + @staticmethod + def get_metadata_filename(): + return config.test_report_file + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.generated_uid = self.__generate_uid() - path = config.tests_folder - self.path = os.path.join(path, self.generated_uid) + self.id = None + self.for_test = True - def __generate_uid(self): + @property + def local_id(self): """A helper that generates a unique hash for a test report.""" params = self.todict() del params["results"] @@ -52,71 +67,21 @@ def set_results(self, results): self.results = results @classmethod - def all( - cls, local_only: bool = False, mine_only: bool = False - ) -> List["TestReport"]: - """Gets and creates instances of test reports. - Arguments are only specified for compatibility with - `Entity.List` and `Entity.View`, but they don't contribute to - the logic. 
- - Returns: - List[TestReport]: List containing all test reports - """ - logging.info("Retrieving all reports") - reports = [] - tests_folder = config.tests_folder - try: - uids = next(os.walk(tests_folder))[1] - except StopIteration: - msg = "Couldn't iterate over the tests directory" - logging.warning(msg) - raise RuntimeError(msg) - - for uid in uids: - local_meta = cls.__get_local_dict(uid) - report = cls(**local_meta) - reports.append(report) - - return reports + def all(cls, unregistered: bool = False, filters: dict = {}) -> List["TestReport"]: + assert unregistered, "Reports are only unregistered" + assert filters == {}, "Reports cannot be filtered" + return super().all(unregistered=True, filters={}) @classmethod - def get(cls, report_uid: str) -> "TestReport": - """Retrieves and creates a TestReport instance obtained the user's machine - + def get(cls, uid: str, local_only: bool = False) -> "TestReport": + """Gets an instance of the TestReport. ignores local_only inherited flag as TestReport is always a local entity. Args: - report_uid (str): UID of the TestReport instance - + uid (str): Report Unique Identifier + local_only (bool): ignored. 
Left for aligning with parent Entity class Returns: - TestReport: Specified TestReport instance + TestReport: Report Instance associated to the UID """ - logging.debug(f"Retrieving report {report_uid}") - report_dict = cls.__get_local_dict(report_uid) - report = cls(**report_dict) - report.write() - return report - - def todict(self): - return self.extended_dict() - - def write(self): - report_file = os.path.join(self.path, config.test_report_file) - os.makedirs(self.path, exist_ok=True) - with open(report_file, "w") as f: - yaml.dump(self.todict(), f) - return report_file - - @classmethod - def __get_local_dict(cls, local_uid): - report_path = os.path.join(config.tests_folder, str(local_uid)) - report_file = os.path.join(report_path, config.test_report_file) - if not os.path.exists(report_file): - raise InvalidArgumentError( - f"The requested report {local_uid} could not be retrieved" - ) - with open(report_file, "r") as f: - report_info = yaml.safe_load(f) - return report_info + return super().get(uid, local_only=True) def display_dict(self): if self.data_path: @@ -127,7 +92,7 @@ def display_dict(self): data_source = f"{self.prepared_data_hash}" return { - "UID": self.generated_uid, + "UID": self.local_id, "Data Source": data_source, "Model": ( self.model if isinstance(self.model, int) else self.model[:27] + "..." 
diff --git a/cli/medperf/entities/result.py b/cli/medperf/entities/result.py index c82add87b..0e96d1feb 100644 --- a/cli/medperf/entities/result.py +++ b/cli/medperf/entities/result.py @@ -1,16 +1,10 @@ -import os -import yaml -import logging -from typing import List, Union - -from medperf.entities.interface import Entity, Uploadable -from medperf.entities.schemas import MedperfSchema, ApprovableSchema +from medperf.entities.interface import Entity +from medperf.entities.schemas import ApprovableSchema import medperf.config as config -from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError from medperf.account_management import get_medperf_user_data -class Result(Entity, Uploadable, MedperfSchema, ApprovableSchema): +class Result(Entity, ApprovableSchema): """ Class representing a Result entry @@ -28,59 +22,36 @@ class Result(Entity, Uploadable, MedperfSchema, ApprovableSchema): metadata: dict = {} user_metadata: dict = {} - def __init__(self, *args, **kwargs): - """Creates a new result instance""" - super().__init__(*args, **kwargs) - - self.generated_uid = f"b{self.benchmark}m{self.model}d{self.dataset}" - path = config.results_folder - if self.id: - path = os.path.join(path, str(self.id)) - else: - path = os.path.join(path, self.generated_uid) - - self.path = path - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Result"]: - """Gets and creates instances of all the user's results - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. 
- - Returns: - List[Result]: List containing all results - """ - logging.info("Retrieving all results") - results = [] - if not local_only: - results = cls.__remote_all(filters=filters) - - remote_uids = set([result.id for result in results]) + @staticmethod + def get_type(): + return "result" - local_results = cls.__local_all() + @staticmethod + def get_storage_path(): + return config.results_folder - results += [res for res in local_results if res.id not in remote_uids] + @staticmethod + def get_comms_retriever(): + return config.comms.get_result - return results + @staticmethod + def get_metadata_filename(): + return config.results_info_file - @classmethod - def __remote_all(cls, filters: dict) -> List["Result"]: - results = [] + @staticmethod + def get_comms_uploader(): + return config.comms.upload_result - try: - comms_fn = cls.__remote_prefilter(filters) - results_meta = comms_fn() - results = [cls(**meta) for meta in results_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all results from the server" - logging.warning(msg) + def __init__(self, *args, **kwargs): + """Creates a new result instance""" + super().__init__(*args, **kwargs) - return results + @property + def local_id(self): + return self.name - @classmethod - def __remote_prefilter(cls, filters: dict) -> callable: + @staticmethod + def remote_prefilter(filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: @@ -104,113 +75,6 @@ def get_benchmark_results(): return comms_fn - @classmethod - def __local_all(cls) -> List["Result"]: - results = [] - results_folder = config.results_folder - try: - uids = next(os.walk(results_folder))[1] - except StopIteration: - msg = "Couldn't iterate over the dataset directory" - logging.warning(msg) - raise RuntimeError(msg) - - for uid in uids: - local_meta = cls.__get_local_dict(uid) - result = cls(**local_meta) - results.append(result) - - return results - - @classmethod - def get(cls, 
result_uid: Union[str, int], local_only: bool = False) -> "Result": - """Retrieves and creates a Result instance obtained from the platform. - If the result instance already exists in the user's machine, it loads - the local instance - - Args: - result_uid (str): UID of the Result instance - - Returns: - Result: Specified Result instance - """ - if not str(result_uid).isdigit() or local_only: - return cls.__local_get(result_uid) - - try: - return cls.__remote_get(result_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting Result {result_uid} from comms failed") - logging.info(f"Looking for result {result_uid} locally") - return cls.__local_get(result_uid) - - @classmethod - def __remote_get(cls, result_uid: int) -> "Result": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - result_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving result {result_uid} remotely") - meta = config.comms.get_result(result_uid) - result = cls(**meta) - result.write() - return result - - @classmethod - def __local_get(cls, result_uid: Union[str, int]) -> "Result": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - result_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving result {result_uid} locally") - local_meta = cls.__get_local_dict(result_uid) - result = cls(**local_meta) - return result - - def todict(self): - return self.extended_dict() - - def upload(self): - """Uploads the results to the comms - - Args: - comms (Comms): Instance of the communications interface. 
- """ - if self.for_test: - raise InvalidArgumentError("Cannot upload test results.") - results_info = self.todict() - updated_results_info = config.comms.upload_result(results_info) - return updated_results_info - - def write(self): - result_file = os.path.join(self.path, config.results_info_file) - os.makedirs(self.path, exist_ok=True) - with open(result_file, "w") as f: - yaml.dump(self.todict(), f) - return result_file - - @classmethod - def __get_local_dict(cls, local_uid): - result_path = os.path.join(config.results_folder, str(local_uid)) - result_file = os.path.join(result_path, config.results_info_file) - if not os.path.exists(result_file): - raise InvalidArgumentError( - f"The requested result {local_uid} could not be retrieved" - ) - with open(result_file, "r") as f: - results_info = yaml.safe_load(f) - return results_info - def display_dict(self): return { "UID": self.identifier, diff --git a/cli/medperf/entities/schemas.py b/cli/medperf/entities/schemas.py index 0e7a54291..79926abd9 100644 --- a/cli/medperf/entities/schemas.py +++ b/cli/medperf/entities/schemas.py @@ -8,7 +8,15 @@ from medperf.utils import format_errors_dict -class MedperfBaseSchema(BaseModel): +class MedperfSchema(BaseModel): + for_test: bool = False + id: Optional[int] + name: str = Field(..., max_length=64) + owner: Optional[int] + is_valid: bool = True + created_at: Optional[datetime] + modified_at: Optional[datetime] + def __init__(self, *args, **kwargs): """Override the ValidationError procedure so we can format the error message in our desired way @@ -46,7 +54,7 @@ def dict(self, *args, **kwargs) -> dict: out_dict = {k: v for k, v in model_dict.items() if k in valid_fields} return out_dict - def extended_dict(self) -> dict: + def todict(self) -> dict: """Dictionary containing both original and alias fields Returns: @@ -68,27 +76,17 @@ def empty_str_to_none(cls, v): return None return v - class Config: - allow_population_by_field_name = True - extra = "allow" - use_enum_values = 
True - - -class MedperfSchema(MedperfBaseSchema): - for_test: bool = False - id: Optional[int] - name: str = Field(..., max_length=64) - owner: Optional[int] - is_valid: bool = True - created_at: Optional[datetime] - modified_at: Optional[datetime] - @validator("name", pre=True, always=True) def name_max_length(cls, v, *, values, **kwargs): if not values["for_test"] and len(v) > 20: raise ValueError("The name must have no more than 20 characters") return v + class Config: + allow_population_by_field_name = True + extra = "allow" + use_enum_values = True + class DeployableSchema(BaseModel): state: str = "DEVELOPMENT" diff --git a/cli/medperf/tests/commands/benchmark/test_submit.py b/cli/medperf/tests/commands/benchmark/test_submit.py index b00e1c5a8..7e2d5b23b 100644 --- a/cli/medperf/tests/commands/benchmark/test_submit.py +++ b/cli/medperf/tests/commands/benchmark/test_submit.py @@ -94,7 +94,7 @@ def test_run_compatibility_test_uses_expected_default_parameters(mocker, comms, # Assert comp_spy.assert_called_once_with( - benchmark=bmk.generated_uid, no_cache=True, skip_data_preparation_step=False + benchmark=bmk.local_id, no_cache=True, skip_data_preparation_step=False ) @@ -117,7 +117,7 @@ def test_run_compatibility_test_with_passed_parameters(mocker, force, skip, comm # Assert comp_spy.assert_called_once_with( - benchmark=bmk.generated_uid, no_cache=force, skip_data_preparation_step=skip + benchmark=bmk.local_id, no_cache=force, skip_data_preparation_step=skip ) diff --git a/cli/medperf/tests/commands/mlcube/test_submit.py b/cli/medperf/tests/commands/mlcube/test_submit.py index 630390205..a946c1fef 100644 --- a/cli/medperf/tests/commands/mlcube/test_submit.py +++ b/cli/medperf/tests/commands/mlcube/test_submit.py @@ -57,7 +57,7 @@ def test_to_permanent_path_renames_correctly(mocker, comms, ui, cube, uid): submission.cube = cube spy = mocker.patch("os.rename") mocker.patch("os.path.exists", return_value=False) - old_path = os.path.join(config.cubes_folder, 
cube.generated_uid) + old_path = os.path.join(config.cubes_folder, cube.local_id) new_path = os.path.join(config.cubes_folder, str(uid)) # Act submission.to_permanent_path({**cube.todict(), "id": uid}) diff --git a/cli/medperf/tests/commands/result/test_create.py b/cli/medperf/tests/commands/result/test_create.py index 74299c77e..c69544781 100644 --- a/cli/medperf/tests/commands/result/test_create.py +++ b/cli/medperf/tests/commands/result/test_create.py @@ -57,6 +57,9 @@ def mock_result_all(mocker, state_variables): TestResult(benchmark=triplet[0], model=triplet[1], dataset=triplet[2]) for triplet in cached_results_triplets ] + mocker.patch( + PATCH_EXECUTION.format("get_medperf_user_data", return_value={"id": 1}) + ) mocker.patch(PATCH_EXECUTION.format("Result.all"), return_value=results) diff --git a/cli/medperf/tests/commands/result/test_submit.py b/cli/medperf/tests/commands/result/test_submit.py index 10680fbe1..26b03fbcc 100644 --- a/cli/medperf/tests/commands/result/test_submit.py +++ b/cli/medperf/tests/commands/result/test_submit.py @@ -25,6 +25,7 @@ def submission(mocker, comms, ui, result, dataset): sub = ResultSubmission(1) mocker.patch(PATCH_SUBMISSION.format("Result"), return_value=result) mocker.patch(PATCH_SUBMISSION.format("Result.get"), return_value=result) + sub.get_result() return sub diff --git a/cli/medperf/tests/commands/test_execution.py b/cli/medperf/tests/commands/test_execution.py index 669d7dfd9..d50ca5d31 100644 --- a/cli/medperf/tests/commands/test_execution.py +++ b/cli/medperf/tests/commands/test_execution.py @@ -102,8 +102,8 @@ def test_failure_with_existing_predictions(mocker, setup, ignore_model_errors, f # Arrange preds_path = os.path.join( config.predictions_folder, - INPUT_MODEL.generated_uid, - INPUT_DATASET.generated_uid, + INPUT_MODEL.local_id, + INPUT_DATASET.local_id, ) fs.create_dir(preds_path) @@ -149,22 +149,22 @@ def test_cube_run_are_called_properly(mocker, setup): # Arrange exp_preds_path = os.path.join( 
config.predictions_folder, - INPUT_MODEL.generated_uid, - INPUT_DATASET.generated_uid, + INPUT_MODEL.local_id, + INPUT_DATASET.local_id, ) exp_model_logs_path = os.path.join( config.experiments_logs_folder, - INPUT_MODEL.generated_uid, - INPUT_DATASET.generated_uid, + INPUT_MODEL.local_id, + INPUT_DATASET.local_id, "model.log", ) exp_metrics_logs_path = os.path.join( config.experiments_logs_folder, - INPUT_MODEL.generated_uid, - INPUT_DATASET.generated_uid, - f"metrics_{INPUT_EVALUATOR.generated_uid}.log", + INPUT_MODEL.local_id, + INPUT_DATASET.local_id, + f"metrics_{INPUT_EVALUATOR.local_id}.log", ) exp_model_call = call( diff --git a/cli/medperf/tests/commands/test_list.py b/cli/medperf/tests/commands/test_list.py index 1c2dc3267..ce7035960 100644 --- a/cli/medperf/tests/commands/test_list.py +++ b/cli/medperf/tests/commands/test_list.py @@ -47,18 +47,18 @@ def set_common_attributes(self, setup): self.state_variables = state_variables self.spies = spies - @pytest.mark.parametrize("local_only", [False, True]) + @pytest.mark.parametrize("unregistered", [False, True]) @pytest.mark.parametrize("mine_only", [False, True]) - def test_entity_all_is_called_properly(self, mocker, local_only, mine_only): + def test_entity_all_is_called_properly(self, mocker, unregistered, mine_only): # Arrange filters = {"owner": 1} if mine_only else {} # Act - EntityList.run(Entity, [], local_only, mine_only) + EntityList.run(Entity, [], unregistered, mine_only) # Assert self.spies["all"].assert_called_once_with( - local_only=local_only, filters=filters + unregistered=unregistered, filters=filters ) @pytest.mark.parametrize("fields", [["UID", "MLCube"]]) diff --git a/cli/medperf/tests/commands/test_view.py b/cli/medperf/tests/commands/test_view.py index a2dddfeda..0ffe0fb13 100644 --- a/cli/medperf/tests/commands/test_view.py +++ b/cli/medperf/tests/commands/test_view.py @@ -1,143 +1,86 @@ import pytest -import yaml -import json from medperf.entities.interface import Entity -from 
medperf.exceptions import InvalidArgumentError from medperf.commands.view import EntityView - -def expected_output(entities, format): - if isinstance(entities, list): - data = [entity.todict() for entity in entities] - else: - data = entities.todict() - - if format == "yaml": - return yaml.dump(data) - if format == "json": - return json.dumps(data) - - -def generate_entity(id, mocker): - entity = mocker.create_autospec(spec=Entity) - mocker.patch.object(entity, "todict", return_value={"id": id}) - return entity +PATCH_VIEW = "medperf.commands.view.{}" @pytest.fixture -def ui_spy(mocker, ui): - return mocker.patch.object(ui, "print") +def entity(mocker): + return mocker.create_autospec(Entity) -@pytest.fixture( - params=[{"local": ["1", "2", "3"], "remote": ["4", "5", "6"], "user": ["4"]}] -) -def setup(request, mocker): - local_ids = request.param.get("local", []) - remote_ids = request.param.get("remote", []) - user_ids = request.param.get("user", []) - all_ids = list(set(local_ids + remote_ids + user_ids)) - - local_entities = [generate_entity(id, mocker) for id in local_ids] - remote_entities = [generate_entity(id, mocker) for id in remote_ids] - user_entities = [generate_entity(id, mocker) for id in user_ids] - all_entities = list(set(local_entities + remote_entities + user_entities)) - - def mock_all(filters={}, local_only=False): - if "owner" in filters: - return user_entities - if local_only: - return local_entities - return all_entities - - def mock_get(entity_id): - if entity_id in all_ids: - return generate_entity(entity_id, mocker) - else: - raise InvalidArgumentError - - mocker.patch("medperf.commands.view.get_medperf_user_data", return_value={"id": 1}) - mocker.patch.object(Entity, "all", side_effect=mock_all) - mocker.patch.object(Entity, "get", side_effect=mock_get) - - return local_entities, remote_entities, user_entities, all_entities - - -class TestViewEntityID: - def test_view_displays_entity_if_given(self, mocker, setup, ui_spy): - # Arrange - 
entity_id = "1" - entity = generate_entity(entity_id, mocker) - output = expected_output(entity, "yaml") - - # Act - EntityView.run(entity_id, Entity) - - # Assert - ui_spy.assert_called_once_with(output) - - def test_view_displays_all_if_no_id(self, setup, ui_spy): - # Arrange - *_, entities = setup - output = expected_output(entities, "yaml") - - # Act - EntityView.run(None, Entity) - - # Assert - ui_spy.assert_called_once_with(output) - - -class TestViewFilteredEntities: - def test_view_displays_local_entities(self, setup, ui_spy): - # Arrange - entities, *_ = setup - output = expected_output(entities, "yaml") - - # Act - EntityView.run(None, Entity, local_only=True) - - # Assert - ui_spy.assert_called_once_with(output) - - def test_view_displays_user_entities(self, setup, ui_spy): - # Arrange - *_, entities, _ = setup - output = expected_output(entities, "yaml") - - # Act - EntityView.run(None, Entity, mine_only=True) - - # Assert - ui_spy.assert_called_once_with(output) - - -@pytest.mark.parametrize("entity_id", ["4", None]) -@pytest.mark.parametrize("format", ["yaml", "json"]) -class TestViewOutput: - @pytest.fixture - def output(self, setup, mocker, entity_id, format): - if entity_id is None: - *_, entities = setup - return expected_output(entities, format) - else: - entity = generate_entity(entity_id, mocker) - return expected_output(entity, format) - - def test_view_displays_specified_format(self, entity_id, output, ui_spy, format): - # Act - EntityView.run(entity_id, Entity, format=format) - - # Assert - ui_spy.assert_called_once_with(output) - - def test_view_stores_specified_format(self, entity_id, output, format, fs): - # Arrange - filename = "file.txt" - - # Act - EntityView.run(entity_id, Entity, format=format, output=filename) - - # Assert - contents = open(filename, "r").read() - assert contents == output +@pytest.fixture +def entity_view(mocker): + view_class = EntityView(None, Entity, "", "", "", "") + return view_class + + +def 
test_prepare_with_id_given(mocker, entity_view, entity): + # Arrange + entity_view.entity_id = 1 + get_spy = mocker.patch(PATCH_VIEW.format("Entity.get"), return_value=entity) + all_spy = mocker.patch(PATCH_VIEW.format("Entity.all"), return_value=[entity]) + + # Act + entity_view.prepare() + + # Assert + get_spy.assert_called_once_with(1) + all_spy.assert_not_called() + assert not isinstance(entity_view.data, list) + + +def test_prepare_with_no_id_given(mocker, entity_view, entity): + # Arrange + entity_view.entity_id = None + entity_view.mine_only = False + get_spy = mocker.patch(PATCH_VIEW.format("Entity.get"), return_value=entity) + all_spy = mocker.patch(PATCH_VIEW.format("Entity.all"), return_value=[entity]) + + # Act + entity_view.prepare() + + # Assert + all_spy.assert_called_once() + get_spy.assert_not_called() + assert isinstance(entity_view.data, list) + + +@pytest.mark.parametrize("unregistered", [False, True]) +def test_prepare_with_no_id_calls_all_with_unregistered_properly( + mocker, entity_view, entity, unregistered +): + # Arrange + entity_view.entity_id = None + entity_view.mine_only = False + entity_view.unregistered = unregistered + all_spy = mocker.patch(PATCH_VIEW.format("Entity.all"), return_value=[entity]) + + # Act + entity_view.prepare() + + # Assert + all_spy.assert_called_once_with(unregistered=unregistered, filters={}) + + +@pytest.mark.parametrize("filters", [{}, {"f1": "v1"}]) +@pytest.mark.parametrize("mine_only", [False, True]) +def test_prepare_with_no_id_calls_all_with_proper_filters( + mocker, entity_view, entity, filters, mine_only +): + # Arrange + entity_view.entity_id = None + entity_view.mine_only = False + entity_view.unregistered = False + entity_view.filters = filters + all_spy = mocker.patch(PATCH_VIEW.format("Entity.all"), return_value=[entity]) + mocker.patch(PATCH_VIEW.format("get_medperf_user_data"), return_value={"id": 1}) + if mine_only: + filters["owner"] = 1 + + # Act + entity_view.prepare() + + # Assert + 
all_spy.assert_called_once_with(unregistered=False, filters=filters) diff --git a/cli/medperf/tests/entities/test_benchmark.py b/cli/medperf/tests/entities/test_benchmark.py index 3f1fde2e2..6fa6aae47 100644 --- a/cli/medperf/tests/entities/test_benchmark.py +++ b/cli/medperf/tests/entities/test_benchmark.py @@ -3,30 +3,20 @@ from medperf.entities.benchmark import Benchmark from medperf.tests.entities.utils import setup_benchmark_fs, setup_benchmark_comms - PATCH_BENCHMARK = "medperf.entities.benchmark.{}" -@pytest.fixture( - params={ - "local": [1, 2, 3], - "remote": [4, 5, 6], - "user": [4], - "models": [10, 11], - } -) +@pytest.fixture(autouse=True) def setup(request, mocker, comms, fs): local_ids = request.param.get("local", []) remote_ids = request.param.get("remote", []) user_ids = request.param.get("user", []) models = request.param.get("models", []) # Have a list that will contain all uploaded entities of the given type - uploaded = [] setup_benchmark_fs(local_ids, fs) - setup_benchmark_comms(mocker, comms, remote_ids, user_ids, uploaded) + setup_benchmark_comms(mocker, comms, remote_ids, user_ids) mocker.patch.object(comms, "get_benchmark_model_associations", return_value=models) - request.param["uploaded"] = uploaded return request.param @@ -51,10 +41,10 @@ def setup(request, mocker, comms, fs): class TestModels: def test_benchmark_get_models_works_as_expected(self, setup, expected_models): # Arrange - id = setup["remote"][0] + id_ = setup["remote"][0] # Act - assciated_models = Benchmark.get_models_uids(id) + associated_models = Benchmark.get_models_uids(id_) # Assert - assert set(assciated_models) == set(expected_models) + assert set(associated_models) == set(expected_models) diff --git a/cli/medperf/tests/entities/test_cube.py b/cli/medperf/tests/entities/test_cube.py index 96f81dba0..cd182489f 100644 --- a/cli/medperf/tests/entities/test_cube.py +++ b/cli/medperf/tests/entities/test_cube.py @@ -24,18 +24,16 @@ } -@pytest.fixture(params={"local": [1, 
2, 3], "remote": [4, 5, 6], "user": [4]}) +@pytest.fixture(autouse=True) def setup(request, mocker, comms, fs): local_ents = request.param.get("local", []) remote_ents = request.param.get("remote", []) user_ents = request.param.get("user", []) # Have a list that will contain all uploaded entities of the given type - uploaded = [] setup_cube_fs(local_ents, fs) - setup_cube_comms(mocker, comms, remote_ents, user_ents, uploaded) + request.param["storage"] = setup_cube_comms(mocker, comms, remote_ents, user_ents) setup_cube_comms_downloads(mocker, fs) - request.param["uploaded"] = uploaded # Mock additional third party elements mpexpect = MockPexpect(0) @@ -282,7 +280,9 @@ def test_run_stops_execution_if_child_fails(self, mocker, setup, task): cube.run(task) -@pytest.mark.parametrize("setup", [{"local": [DEFAULT_CUBE]}], indirect=True) +@pytest.mark.parametrize( + "setup", [{"local": [DEFAULT_CUBE], "remote": [DEFAULT_CUBE]}], indirect=True +) @pytest.mark.parametrize("task", ["task"]) @pytest.mark.parametrize( "out_key,out_value", diff --git a/cli/medperf/tests/entities/test_entity.py b/cli/medperf/tests/entities/test_entity.py index c636b2c26..960a8edb2 100644 --- a/cli/medperf/tests/entities/test_entity.py +++ b/cli/medperf/tests/entities/test_entity.py @@ -15,7 +15,8 @@ setup_result_fs, setup_result_comms, ) -from medperf.exceptions import InvalidArgumentError +from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError +from medperf.tests.mocks.comms import TestEntityStorage @pytest.fixture(params=[Benchmark, Cube, Dataset, Result]) @@ -23,13 +24,12 @@ def Implementation(request): return request.param -@pytest.fixture(params={"local": [1, 2, 3], "remote": [4, 5, 6], "user": [4]}) +@pytest.fixture(autouse=True) def setup(request, mocker, comms, Implementation, fs): local_ids = request.param.get("local", []) remote_ids = request.param.get("remote", []) user_ids = request.param.get("user", []) # Have a list that will contain all uploaded entities 
of the given type - uploaded = [] if Implementation == Benchmark: setup_fs = setup_benchmark_fs @@ -44,49 +44,65 @@ def setup(request, mocker, comms, Implementation, fs): elif Implementation == Result: setup_fs = setup_result_fs setup_comms = setup_result_comms + else: + raise NotImplementedError("Wrong implementation") - setup_comms(mocker, comms, remote_ids, user_ids, uploaded) + storage = setup_comms(mocker, comms, remote_ids, user_ids) setup_fs(local_ids, fs) - request.param["uploaded"] = uploaded + + request.param["storage"] = storage return request.param @pytest.mark.parametrize( "setup", - [{"local": [283, 17, 493], "remote": [283, 1, 2], "user": [2]}], + [ + { + "unregistered": ["e1", "e2"], + "local": ["e1", "e2", 283], + "remote": [283, 1, 2], + "user": [2], + } + ], indirect=True, ) class TestAll: @pytest.fixture(autouse=True) def set_common_attributes(self, setup): self.ids = setup + self.unregistered_ids = set(self.ids["unregistered"]) self.local_ids = set(self.ids["local"]) self.remote_ids = set(self.ids["remote"]) self.user_ids = set(self.ids["user"]) - def test_all_returns_all_remote_and_local(self, Implementation): - # Arrange - all_ids = self.local_ids.union(self.remote_ids) - + def test_all_returns_all_remote_by_default(self, Implementation): # Act entities = Implementation.all() # Assert retrieved_ids = set([e.todict()["id"] for e in entities]) - assert all_ids == retrieved_ids + assert self.remote_ids == retrieved_ids - def test_all_local_only_returns_all_local(self, Implementation): + def test_all_unregistered_returns_all_unregistered(self, Implementation): # Act - entities = Implementation.all(local_only=True) + entities = Implementation.all(unregistered=True) # Assert - retrieved_ids = set([e.todict()["id"] for e in entities]) - assert self.local_ids == retrieved_ids + retrieved_ids = set([e.local_id for e in entities]) + assert self.unregistered_ids == retrieved_ids @pytest.mark.parametrize( - "setup", [{"local": [78], "remote": [479, 42, 
7, 1]}], indirect=True + "setup", + [ + { + "unregistered": ["e1", "e2"], + "local": ["e1", "e2", 479], + "remote": [479, 42, 7, 1], + } + ], + indirect=True, ) class TestGet: def test_get_retrieves_entity_from_server(self, Implementation, setup): @@ -99,30 +115,20 @@ def test_get_retrieves_entity_from_server(self, Implementation, setup): # Assert assert entity.todict()["id"] == id - def test_get_retrieves_entity_local_if_not_on_server(self, Implementation, setup): - # Arrange - id = setup["local"][0] - - # Act - entity = Implementation.get(id) - - # Assert - assert entity.todict()["id"] == id - def test_get_raises_error_if_nonexistent(self, Implementation, setup): # Arrange id = str(19283) # Act & Assert - with pytest.raises(InvalidArgumentError): + with pytest.raises(CommunicationRetrievalError): Implementation.get(id) -@pytest.mark.parametrize("setup", [{"local": [742]}], indirect=True) +@pytest.mark.parametrize("setup", [{"remote": [742]}], indirect=True) class TestToDict: @pytest.fixture(autouse=True) def set_common_attributes(self, setup): - self.id = setup["local"][0] + self.id = setup["remote"][0] def test_todict_returns_dict_representation(self, Implementation): # Arrange @@ -147,7 +153,16 @@ def test_todict_can_recreate_object(self, Implementation): assert ent_dict == ent_copy_dict -@pytest.mark.parametrize("setup", [{"local": [36]}], indirect=True) +@pytest.mark.parametrize( + "setup", + [ + { + "unregistered": ["e1", "e2"], + "local": ["e1", "e2"], + } + ], + indirect=True, +) class TestUpload: @pytest.fixture(autouse=True) def set_common_attributes(self, setup): @@ -155,14 +170,16 @@ def set_common_attributes(self, setup): def test_upload_adds_to_remote(self, Implementation, setup): # Arrange - uploaded_entities = setup["uploaded"] + storage: TestEntityStorage = setup["storage"] + assert self.id not in storage.storage + ent = Implementation.get(self.id) # Act ent.upload() # Assert - assert ent.todict() in uploaded_entities + assert ent.todict() in 
storage.uploaded def test_upload_returns_dict(self, Implementation): # Arrange @@ -172,20 +189,30 @@ def test_upload_returns_dict(self, Implementation): ent_dict = ent.upload() # Assert - assert ent_dict == ent.todict() + real_dict = ent.todict() + diff = {} + for k in set(real_dict) | set(ent_dict): + if real_dict.get(k) != ent_dict.get(k): + diff[k] = (real_dict.get(k), ent_dict.get(k)) + assert ent_dict == ent.todict(), f"Expected: {ent_dict}\nGot: {real_dict}\nDiff: {diff}" def test_upload_fails_for_test_entity(self, Implementation, setup): # Arrange - uploaded_entities = setup["uploaded"] + storage: TestEntityStorage = setup["storage"] ent = Implementation.get(self.id) ent.for_test = True + # pre-check + len_before_test = len(storage.uploaded) + assert self.id not in storage.storage # Act with pytest.raises(InvalidArgumentError): ent.upload() # Assert - assert ent.todict() not in uploaded_entities + assert self.id not in storage.storage + assert ent.todict() not in storage.uploaded + assert len(storage.uploaded) == len_before_test @pytest.mark.parametrize( diff --git a/cli/medperf/tests/entities/utils.py b/cli/medperf/tests/entities/utils.py index 522251ca7..98c3a5070 100644 --- a/cli/medperf/tests/entities/utils.py +++ b/cli/medperf/tests/entities/utils.py @@ -1,4 +1,5 @@ import os + from medperf import config import yaml @@ -8,21 +9,23 @@ from medperf.tests.mocks.dataset import TestDataset from medperf.tests.mocks.result import TestResult from medperf.tests.mocks.cube import TestCube -from medperf.tests.mocks.comms import mock_comms_entity_gets +from medperf.tests.mocks.comms import mock_comms_entity_gets, TestEntityStorage PATCH_RESOURCES = "medperf.comms.entity_resources.resources.{}" # Setup Benchmark def setup_benchmark_fs(ents, fs): - bmks_path = config.benchmarks_folder for ent in ents: - if not isinstance(ent, dict): - # Assume we're passing ids - ent = {"id": str(ent)} - id = ent["id"] - bmk_filepath = os.path.join(bmks_path, str(id), 
config.benchmarks_filename) - bmk_contents = TestBenchmark(**ent) + # Assume we're passing ids, local_ids, or dicts + if isinstance(ent, dict): + bmk_contents = TestBenchmark(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): + bmk_contents = TestBenchmark(id=str(ent)) + else: + bmk_contents = TestBenchmark(id=None, name=ent) + + bmk_filepath = os.path.join(bmk_contents.path, config.benchmarks_filename) cubes_ids = [] cubes_ids.append(bmk_contents.data_preparation_mlcube) cubes_ids.append(bmk_contents.reference_model_mlcube) @@ -30,12 +33,12 @@ def setup_benchmark_fs(ents, fs): cubes_ids = list(set(cubes_ids)) setup_cube_fs(cubes_ids, fs) try: - fs.create_file(bmk_filepath, contents=yaml.dump(bmk_contents.dict())) + fs.create_file(bmk_filepath, contents=yaml.dump(bmk_contents.todict())) except FileExistsError: pass -def setup_benchmark_comms(mocker, comms, all_ents, user_ents, uploaded): +def setup_benchmark_comms(mocker, comms, all_ents, user_ents) -> TestEntityStorage: generate_fn = TestBenchmark comms_calls = { "get_all": "get_benchmarks", @@ -44,31 +47,31 @@ def setup_benchmark_comms(mocker, comms, all_ents, user_ents, uploaded): "upload_instance": "upload_benchmark", } mocker.patch.object(comms, "get_benchmark_model_associations", return_value=[]) - mock_comms_entity_gets( - mocker, comms, generate_fn, comms_calls, all_ents, user_ents, uploaded + return mock_comms_entity_gets( + mocker, comms, generate_fn, comms_calls, all_ents, user_ents ) # Setup Cube def setup_cube_fs(ents, fs): - cubes_path = config.cubes_folder for ent in ents: - if not isinstance(ent, dict): - # Assume we're passing ids - ent = {"id": str(ent)} - id = ent["id"] - meta_cube_file = os.path.join( - cubes_path, str(id), config.cube_metadata_filename - ) - cube = TestCube(**ent) - meta = cube.dict() + # Assume we're passing ids, names, or dicts + if isinstance(ent, dict): + cube = TestCube(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): 
+ cube = TestCube(id=str(ent)) + else: + cube = TestCube(id=None, name=ent) + + meta_cube_file = os.path.join(cube.path, config.cube_metadata_filename) + meta = cube.todict() try: fs.create_file(meta_cube_file, contents=yaml.dump(meta)) except FileExistsError: pass -def setup_cube_comms(mocker, comms, all_ents, user_ents, uploaded): +def setup_cube_comms(mocker, comms, all_ents, user_ents) -> TestEntityStorage: generate_fn = TestCube comms_calls = { "get_all": "get_cubes", @@ -76,15 +79,15 @@ def setup_cube_comms(mocker, comms, all_ents, user_ents, uploaded): "get_instance": "get_cube_metadata", "upload_instance": "upload_mlcube", } - mock_comms_entity_gets( - mocker, comms, generate_fn, comms_calls, all_ents, user_ents, uploaded + return mock_comms_entity_gets( + mocker, comms, generate_fn, comms_calls, all_ents, user_ents ) def generate_cubefile_fn(fs, path, filename): # all_ids = [ent["id"] if type(ent) == dict else ent for ent in all_ents] - def cubefile_fn(url, cube_path, *args): + def cubefile_fn(url: str, cube_path: str, *args): if url == "broken_url": raise CommunicationRetrievalError filepath = os.path.join(cube_path, path, filename) @@ -124,23 +127,25 @@ def setup_cube_comms_downloads(mocker, fs): # Setup Dataset def setup_dset_fs(ents, fs): - dsets_path = config.datasets_folder for ent in ents: - if not isinstance(ent, dict): - # Assume passing ids - ent = {"id": str(ent)} - id = ent["id"] - reg_dset_file = os.path.join(dsets_path, str(id), config.reg_file) - dset_contents = TestDataset(**ent) + # Assume we're passing ids, generated_uids, or dicts + if isinstance(ent, dict): + dset_contents = TestDataset(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): + dset_contents = TestDataset(id=str(ent)) + else: + dset_contents = TestDataset(id=None, generated_uid=ent) + + reg_dset_file = os.path.join(dset_contents.path, config.reg_file) cube_id = dset_contents.data_preparation_mlcube setup_cube_fs([cube_id], fs) try: - 
fs.create_file(reg_dset_file, contents=yaml.dump(dset_contents.dict())) + fs.create_file(reg_dset_file, contents=yaml.dump(dset_contents.todict())) except FileExistsError: pass -def setup_dset_comms(mocker, comms, all_ents, user_ents, uploaded): +def setup_dset_comms(mocker, comms, all_ents, user_ents) -> TestEntityStorage: generate_fn = TestDataset comms_calls = { "get_all": "get_datasets", @@ -148,34 +153,37 @@ def setup_dset_comms(mocker, comms, all_ents, user_ents, uploaded): "get_instance": "get_dataset", "upload_instance": "upload_dataset", } - mock_comms_entity_gets( - mocker, comms, generate_fn, comms_calls, all_ents, user_ents, uploaded + return mock_comms_entity_gets( + mocker, comms, generate_fn, comms_calls, all_ents, user_ents ) # Setup Result def setup_result_fs(ents, fs): - results_path = config.results_folder for ent in ents: - if not isinstance(ent, dict): - # Assume passing ids - ent = {"id": str(ent)} - id = ent["id"] - result_file = os.path.join(results_path, str(id), config.results_info_file) - bmk_id = ent.get("benchmark", 1) - cube_id = ent.get("model", 1) - dataset_id = ent.get("dataset", 1) + # Assume we're passing ids, names, or dicts + if isinstance(ent, dict): + result_contents = TestResult(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): + result_contents = TestResult(id=str(ent)) + else: + result_contents = TestResult(id=None, name=ent) + + result_file = os.path.join(result_contents.path, config.results_info_file) + bmk_id = result_contents.benchmark + cube_id = result_contents.model + dataset_id = result_contents.dataset setup_benchmark_fs([bmk_id], fs) setup_cube_fs([cube_id], fs) setup_dset_fs([dataset_id], fs) - result_contents = TestResult(**ent) + try: - fs.create_file(result_file, contents=yaml.dump(result_contents.dict())) + fs.create_file(result_file, contents=yaml.dump(result_contents.todict())) except FileExistsError: pass -def setup_result_comms(mocker, comms, all_ents, user_ents, uploaded): 
+def setup_result_comms(mocker, comms, all_ents, user_ents) -> TestEntityStorage: generate_fn = TestResult comms_calls = { "get_all": "get_results", @@ -185,7 +193,7 @@ def setup_result_comms(mocker, comms, all_ents, user_ents, uploaded): } # Enable dset retrieval since its required for result creation - setup_dset_comms(mocker, comms, [1], [1], uploaded) - mock_comms_entity_gets( - mocker, comms, generate_fn, comms_calls, all_ents, user_ents, uploaded + setup_dset_comms(mocker, comms, [1], [1]) + return mock_comms_entity_gets( + mocker, comms, generate_fn, comms_calls, all_ents, user_ents ) diff --git a/cli/medperf/tests/mocks/comms.py b/cli/medperf/tests/mocks/comms.py index 57084c849..73c1619bd 100644 --- a/cli/medperf/tests/mocks/comms.py +++ b/cli/medperf/tests/mocks/comms.py @@ -1,11 +1,44 @@ # Utility functions for mocking comms and its methods -from typing import Dict, List, Callable, Union +from typing import Dict, List, Callable, Union, TypeVar, Tuple from unittest.mock import MagicMock from pytest_mock.plugin import MockFixture - from medperf.exceptions import CommunicationRetrievalError +class TestEntityStorage: + AnyEntity = TypeVar("AnyEntity") + + def __init__(self, + generate_fun: Callable[[Dict], AnyEntity], + ents: Dict[str, Dict]): + + self.storage = ents + self.uploaded = [] + self.generate_fun = generate_fun # 🥳 <- generated fun + + def get(self, id_) -> Dict: + if id_ not in self.storage: + raise CommunicationRetrievalError(f"Get entity {id_}: not found in test storage") + return self.storage[id_] + + def upload(self, ent: Dict) -> Dict: + id_ = ent["id"] + if id_ is None or id_ == "": # not include 0 as 0 is a valid id + id_ = str(-len(self.storage)) # some non-existent id + assert id_ not in self.storage, f"Upload failed: generated id {id_} already exists in storage" + self.storage[id_] = self.generate_fun(**ent).todict() + self.uploaded.append(ent) + return self.storage[id_] + + def edit(self, ent: Dict): + id_ = ent["id"] + if id_ not in 
self.storage: + raise CommunicationRetrievalError(f"Edit entity {id_}: not found in test storage") + orig_value = self.storage[id_] + new_value = {**orig_value, **ent} # rewrites all fields from ent + self.storage[id_] = new_value + + def mock_comms_entity_gets( mocker: MockFixture, comms: MagicMock, @@ -13,8 +46,7 @@ def mock_comms_entity_gets( comms_calls: Dict[str, str], all_ents: List[Union[str, Dict]], user_ents: List[Union[str, Dict]], - uploaded: List, -): +) -> TestEntityStorage: """Mocks API endpoints used by an entity instance. Allows to define what is returned by each endpoint, and keeps track of submitted instances. @@ -28,69 +60,34 @@ def mock_comms_entity_gets( - get_user - get_instance - upload_instance - all_ids (List[Union[str, Dict]]): List of ids or curations that should be returned by the all endpoint - user_ids (List[Union[str, Dict]]): List of ids or configurations that should be returned by the user endpoint - uploaded (List): List that will be updated with uploaded instances + - edit_instance [optional] + all_ents (List[Union[str, Dict]]): List of ids or configurations to init storage. Should be returned by the + `all` endpoint. + user_ents (List[Union[str, Dict]]): List of ids or configurations that should be returned by the user endpoint. + Non-updatable. + Returns: + TestEntityStorage: A link to the storage. 
Whenever new entity is uploaded / edited, it is updated """ get_all = comms_calls["get_all"] get_user = comms_calls["get_user"] get_instance = comms_calls["get_instance"] upload_instance = comms_calls["upload_instance"] - all_ents = [ent if isinstance(ent, dict) else {"id": ent} for ent in all_ents] - user_ents = [ent if isinstance(ent, dict) else {"id": ent} for ent in user_ents] - - instances = [generate_fn(**ent).dict() for ent in all_ents] - user_instances = [generate_fn(**ent).dict() for ent in user_ents] - mocker.patch.object(comms, get_all, return_value=instances) - mocker.patch.object(comms, get_user, return_value=user_instances) - get_behavior = get_comms_instance_behavior(generate_fn, all_ents) - mocker.patch.object( - comms, - get_instance, - side_effect=get_behavior, - ) - upload_behavior = upload_comms_instance_behavior(uploaded) - mocker.patch.object(comms, upload_instance, side_effect=upload_behavior) - - -def get_comms_instance_behavior( - generate_fn: Callable, ents: List[Union[str, Dict]] -) -> Callable: - """Function that defines a GET behavior - - Args: - generate_fn (Callable): Function to generate entity dictionaries - ents (List[Union[str, Dict]]): List of Entities configurations that are allowed to return - - Return: - function: Function that returns an entity dictionary if found, - or raises an error if not - """ - ids = [ent["id"] if isinstance(ent, dict) else ent for ent in ents] - - def get_behavior(id: int): - if id in ids: - idx = ids.index(id) - return generate_fn(**ents[idx]).dict() + def _to_dict_entity(ent: Union[str, Dict]) -> Tuple[str, Dict]: + """returns pair (id, entity-as-a-full-dict)""" + if isinstance(ent, dict): + id_, ent_params = ent["id"], ent else: - raise CommunicationRetrievalError - - return get_behavior + id_, ent_params = ent, {"id": ent} + return id_, generate_fn(**ent_params).dict() + all_ents = dict(_to_dict_entity(ent) for ent in all_ents) + user_ents = dict(_to_dict_entity(ent) for ent in user_ents) -def 
upload_comms_instance_behavior(uploaded: List) -> Callable: - """Function that defines the comms mocked behavior when uploading entities - - Args: - uploaded (List): List that will be updated with uploaded entities - - Returns: - Callable: Function containing the desired behavior - """ - - def upload_behavior(entity_dict): - uploaded.append(entity_dict) - return entity_dict + storage = TestEntityStorage(generate_fn, all_ents) + mocker.patch.object(comms, get_all, return_value=list(all_ents.values())) + mocker.patch.object(comms, get_user, return_value=list(user_ents.values())) - return upload_behavior + mocker.patch.object(comms, get_instance, side_effect=storage.get) + mocker.patch.object(comms, upload_instance, side_effect=storage.upload) + return storage diff --git a/cli/medperf/tests/mocks/cube.py b/cli/medperf/tests/mocks/cube.py index 9c1acbb8a..b8f8d828e 100644 --- a/cli/medperf/tests/mocks/cube.py +++ b/cli/medperf/tests/mocks/cube.py @@ -1,6 +1,6 @@ from typing import Optional from medperf.entities.cube import Cube - +from pydantic import Field EMPTY_FILE_HASH = "da39a3ee5e6b4b0d3255bfef95601890afd80709" @@ -14,9 +14,10 @@ class TestCube(Cube): parameters_hash: Optional[str] = EMPTY_FILE_HASH image_tarball_url: Optional[str] = "https://test.com/image.tar.gz" image_tarball_hash: Optional[str] = EMPTY_FILE_HASH - additional_files_tarball_url: Optional[str] = ( - "https://test.com/additional_files.tar.gz" + additional_files_tarball_url: Optional[str] = Field( + "https://test.com/additional_files.tar.gz", + alias="tarball_url" ) - additional_files_tarball_hash: Optional[str] = EMPTY_FILE_HASH + additional_files_tarball_hash: Optional[str] = Field(EMPTY_FILE_HASH, alias="tarball_hash") state: str = "OPERATION" is_valid = True diff --git a/cli/medperf/tests/mocks/dataset.py b/cli/medperf/tests/mocks/dataset.py index b3d6d4217..faed25876 100644 --- a/cli/medperf/tests/mocks/dataset.py +++ b/cli/medperf/tests/mocks/dataset.py @@ -1,6 +1,6 @@ from typing import 
Optional, Union -from medperf.enums import Status from medperf.entities.dataset import Dataset +from pydantic import Field class TestDataset(Dataset): @@ -10,7 +10,6 @@ class TestDataset(Dataset): data_preparation_mlcube: Union[int, str] = 1 input_data_hash: str = "input_data_hash" generated_uid: str = "generated_uid" - generated_metadata: dict = {} - status: Status = Status.APPROVED.value + generated_metadata: dict = Field({}, alias="metadata") state: str = "OPERATION" submitted_as_prepared: bool = False