From 5b8385ecb77f1dfb6f0a831a086d28fda25d6c72 Mon Sep 17 00:00:00 2001 From: Chao Li Date: Thu, 30 May 2024 11:13:06 +0800 Subject: [PATCH 01/11] Add a script for testing the accuracy of timm models after quantization --- tests/ryzenai/test_timm_acc.py | 130 +++++++++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 tests/ryzenai/test_timm_acc.py diff --git a/tests/ryzenai/test_timm_acc.py b/tests/ryzenai/test_timm_acc.py new file mode 100644 index 00000000..bd312629 --- /dev/null +++ b/tests/ryzenai/test_timm_acc.py @@ -0,0 +1,130 @@ +import time +from argparse import ArgumentParser +from functools import partial + +import numpy as np +import onnxruntime +import timm +from timm.data import create_dataset, create_loader +from timm.utils import AverageMeter + +from optimum.amd.ryzenai import ( + AutoQuantizationConfig, + RyzenAIModelForImageClassification, + RyzenAIOnnxQuantizer, +) + + +def parse_args(): + parser = ArgumentParser("RyzenAIQuantization") + parser.add_argument("--data-path", metavar="DIR", required=True, help="path to dataset") + parser.add_argument( + "--model_id", type=str, default="timm/resnet50.a1_in1k", help='Model id, default to "timm/resnet50.a1_in1k"' + ) + parser.add_argument( + "--dataset", type=str, default="imagenet-1k", help='Calibration dataset, default to "imagenet-1k"' + ) + parser.add_argument( + "--onnx-output-opt", default="", type=str, metavar="PATH", help="path to output optimized onnx graph" + ) + parser.add_argument("--profile", action="store_true", default=False, help="Enable profiler output.") + parser.add_argument( + "-j", "--workers", default=2, type=int, metavar="N", help="number of data loading workers (default: 2)" + ) + parser.add_argument("-b", "--batch-size", default=1, type=int, metavar="N", help="mini-batch size (default: 1)") + args, _ = parser.parse_known_args() + return args + + +def main(args): + model_id = args.model_id + + onnx_model = 
RyzenAIModelForImageClassification.from_pretrained( + model_id, export=True, provider="CPUExecutionProvider" + ) + # preprocess config + data_config = timm.data.resolve_data_config(pretrained_cfg=onnx_model.config.to_dict()) + transforms = timm.data.create_transform(**data_config, is_training=False) + + def preprocess_fn(ex, transforms): + image = ex["image"] + if image.mode == "L": + # Three channels. + image = image.convert("RGB") + pixel_values = transforms(image) + + return {"pixel_values": pixel_values} + + # quantize + quantizer = RyzenAIOnnxQuantizer.from_pretrained(onnx_model) + quantization_config = AutoQuantizationConfig.cpu_cnn_config() + + calibration_dataset = quantizer.get_calibration_dataset( + args.dataset, + preprocess_function=partial(preprocess_fn, transforms=transforms), + num_samples=200, + dataset_split="validation", + preprocess_batch=False, + streaming=True, + ) + quantizer.quantize( + quantization_config=quantization_config, dataset=calibration_dataset, save_dir="quantized_model" + ) + + # Set graph optimization level + sess_options = onnxruntime.SessionOptions() + sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + if args.profile: + sess_options.enable_profiling = True + if args.onnx_output_opt: + sess_options.optimized_model_filepath = args.onnx_output_opt + + session = onnxruntime.InferenceSession("quantized_model/model_quantized.onnx", sess_options) + + data_config = timm.data.resolve_data_config(pretrained_cfg=onnx_model.config.to_dict(), use_test_size=True) + + loader = create_loader( + create_dataset("", args.data_path), + input_size=data_config["input_size"], + batch_size=args.batch_size, + use_prefetcher=False, + interpolation=data_config["interpolation"], + mean=data_config["mean"], + std=data_config["std"], + num_workers=args.workers, + crop_pct=data_config["crop_pct"], + ) + + input_name = session.get_inputs()[0].name + + batch_time = AverageMeter() + top1 = AverageMeter() + top5 = 
AverageMeter() + end = time.time() + for i, (input, target) in enumerate(loader): + # run the net and return prediction + output = session.run([], {input_name: input.data.numpy()}) + output = output[0] + + # measure accuracy and record loss + prec1, prec5 = accuracy_np(output, target.numpy()) + top1.update(prec1.item(), input.size(0)) + top5.update(prec5.item(), input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + print(f" * Prec@1 {top1.avg:.3f} ({100-top1.avg:.3f}) Prec@5 {top5.avg:.3f} ({100.-top5.avg:.3f})") + + +def accuracy_np(output, target): + max_indices = np.argsort(output, axis=1)[:, ::-1] + top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean() + top1 = 100 * np.equal(max_indices[:, 0], target).mean() + return top1, top5 + + +if __name__ == "__main__": + args = parse_args() + main(args) From a6cea67736ea59fab9b0554b9d9ac2b60e43573c Mon Sep 17 00:00:00 2001 From: Chao Li Date: Thu, 30 May 2024 11:15:26 +0800 Subject: [PATCH 02/11] Add license --- tests/ryzenai/test_timm_acc.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/ryzenai/test_timm_acc.py b/tests/ryzenai/test_timm_acc.py index bd312629..119c2e40 100644 --- a/tests/ryzenai/test_timm_acc.py +++ b/tests/ryzenai/test_timm_acc.py @@ -1,3 +1,7 @@ +# +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: MIT +# import time from argparse import ArgumentParser from functools import partial From e859c4f8493ced64837a74caadd50a800a288407 Mon Sep 17 00:00:00 2001 From: Chao Li Date: Thu, 30 May 2024 18:37:30 +0800 Subject: [PATCH 03/11] rename --- tests/ryzenai/{test_timm_acc.py => test_timm_quant_and_eval.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/ryzenai/{test_timm_acc.py => test_timm_quant_and_eval.py} (100%) diff --git a/tests/ryzenai/test_timm_acc.py b/tests/ryzenai/test_timm_quant_and_eval.py similarity index 100% rename from tests/ryzenai/test_timm_acc.py rename to tests/ryzenai/test_timm_quant_and_eval.py From b361d11fa78a32f4af89c064f602afc1c17c8e36 Mon Sep 17 00:00:00 2001 From: Chao Li Date: Sat, 1 Jun 2024 22:39:53 +0800 Subject: [PATCH 04/11] Add doc string --- tests/ryzenai/test_timm_quant_and_eval.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/tests/ryzenai/test_timm_quant_and_eval.py b/tests/ryzenai/test_timm_quant_and_eval.py index 119c2e40..1151b38d 100644 --- a/tests/ryzenai/test_timm_quant_and_eval.py +++ b/tests/ryzenai/test_timm_quant_and_eval.py @@ -9,8 +9,10 @@ import numpy as np import onnxruntime import timm +import vai_q_onnx from timm.data import create_dataset, create_loader from timm.utils import AverageMeter +from tqdm import tqdm from optimum.amd.ryzenai import ( AutoQuantizationConfig, @@ -19,11 +21,23 @@ ) +""" +For example: +Float Accuracy of resnet50.tv_in1k: +- Prec@1: 76.128% +- Prec@5: 92.858% + +Quantization Accuracy of resnet50.tv_in1k: +- Prec@1: 74.072% +- Prec@5: 91.816% +""" + + def parse_args(): parser = ArgumentParser("RyzenAIQuantization") parser.add_argument("--data-path", metavar="DIR", required=True, help="path to dataset") parser.add_argument( - "--model_id", type=str, default="timm/resnet50.a1_in1k", help='Model id, default to "timm/resnet50.a1_in1k"' + "--model_id", type=str, default="timm/resnet50.tv_in1k", help='Model 
id, default to "timm/resnet50.tv_in1k"' ) parser.add_argument( "--dataset", type=str, default="imagenet-1k", help='Calibration dataset, default to "imagenet-1k"' @@ -62,11 +76,12 @@ def preprocess_fn(ex, transforms): # quantize quantizer = RyzenAIOnnxQuantizer.from_pretrained(onnx_model) quantization_config = AutoQuantizationConfig.cpu_cnn_config() + quantization_config.calibration_method = vai_q_onnx.CalibrationMethod.Percentile calibration_dataset = quantizer.get_calibration_dataset( args.dataset, preprocess_function=partial(preprocess_fn, transforms=transforms), - num_samples=200, + num_samples=100, dataset_split="validation", preprocess_batch=False, streaming=True, @@ -105,7 +120,7 @@ def preprocess_fn(ex, transforms): top1 = AverageMeter() top5 = AverageMeter() end = time.time() - for i, (input, target) in enumerate(loader): + for i, (input, target) in enumerate(tqdm(loader, desc="Processing")): # run the net and return prediction output = session.run([], {input_name: input.data.numpy()}) output = output[0] From ff0a20a527757a210a885e4afa75ca6db531743d Mon Sep 17 00:00:00 2001 From: Chao Li Date: Fri, 7 Jun 2024 09:51:52 +0800 Subject: [PATCH 05/11] Add new parameters to vai_q_onnx --- optimum/amd/ryzenai/configuration.py | 6 ++++++ optimum/amd/ryzenai/quantization.py | 15 ++++++++++----- tests/ryzenai/test_timm_quant_and_eval.py | 12 ++++++++++++ 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/optimum/amd/ryzenai/configuration.py b/optimum/amd/ryzenai/configuration.py index c043d473..75efa7eb 100644 --- a/optimum/amd/ryzenai/configuration.py +++ b/optimum/amd/ryzenai/configuration.py @@ -48,6 +48,12 @@ class QuantizationConfig: weights_dtype: QuantType = QuantType.QInt8 weights_symmetric: bool = True enable_dpu: bool = True + use_external_data_format: bool = False + include_cle: bool = False + include_sq: bool = False + include_fast_ft: bool = False + include_auto_mp: bool = False + extra_options: dict = None @staticmethod def 
quantization_type_str(activations_dtype: QuantType, weights_dtype: QuantType) -> str: diff --git a/optimum/amd/ryzenai/quantization.py b/optimum/amd/ryzenai/quantization.py index 28fc5621..822ddaed 100644 --- a/optimum/amd/ryzenai/quantization.py +++ b/optimum/amd/ryzenai/quantization.py @@ -161,7 +161,10 @@ def quantize( suffix = f"_{file_suffix}" if file_suffix else "" quantized_model_path = save_dir.joinpath(f"{self.onnx_model_path.stem}{suffix}").with_suffix(".onnx") - + if quantization_config.extra_options is None: + quantization_config.extra_options = {} + quantization_config.extra_options["WeightSymmetric"] = quantization_config.weights_symmetric + quantization_config.extra_options["ActivationSymmetric"] = quantization_config.activations_symmetric LOGGER.info("Quantizing model...") quantize_static( model_input=Path(self.onnx_model_path).as_posix(), @@ -172,10 +175,12 @@ def quantize( weight_type=quantization_config.weights_dtype, activation_type=quantization_config.activations_dtype, enable_dpu=quantization_config.enable_dpu, - extra_options={ - "WeightSymmetric": quantization_config.weights_symmetric, - "ActivationSymmetric": quantization_config.activations_symmetric, - }, + use_external_data_format=quantization_config.use_external_data_format, + include_cle=quantization_config.include_cle, + include_sq=quantization_config.include_sq, + include_fast_ft=quantization_config.include_fast_ft, + include_auto_mp=quantization_config.include_auto_mp, + extra_options=quantization_config.extra_options, ) LOGGER.info(f"Saved quantized model at: {save_dir}") diff --git a/tests/ryzenai/test_timm_quant_and_eval.py b/tests/ryzenai/test_timm_quant_and_eval.py index 1151b38d..959bbd08 100644 --- a/tests/ryzenai/test_timm_quant_and_eval.py +++ b/tests/ryzenai/test_timm_quant_and_eval.py @@ -77,6 +77,18 @@ def preprocess_fn(ex, transforms): quantizer = RyzenAIOnnxQuantizer.from_pretrained(onnx_model) quantization_config = AutoQuantizationConfig.cpu_cnn_config() 
quantization_config.calibration_method = vai_q_onnx.CalibrationMethod.Percentile + quantization_config.include_cle = True + quantization_config.include_fast_ft = True + quantization_config.extra_options = { + "FastFinetune": { + "BatchSize": 1, + "NumIterations": 1000, + "LearningRate": 0.1, + "OptimAlgorithm": "adaround", + "OptimDevice": "cpu", + "EarlyStop": True, + }, + } calibration_dataset = quantizer.get_calibration_dataset( args.dataset, From 2b0b78fdad8173d6b9ef8db547889f9752fd5f1f Mon Sep 17 00:00:00 2001 From: Chao Li Date: Thu, 13 Jun 2024 22:44:59 +0800 Subject: [PATCH 06/11] Add parameters for val data and calib data. --- tests/ryzenai/test_timm_quant_and_eval.py | 146 +++++++++++++++++----- 1 file changed, 115 insertions(+), 31 deletions(-) diff --git a/tests/ryzenai/test_timm_quant_and_eval.py b/tests/ryzenai/test_timm_quant_and_eval.py index 959bbd08..6fc18856 100644 --- a/tests/ryzenai/test_timm_quant_and_eval.py +++ b/tests/ryzenai/test_timm_quant_and_eval.py @@ -2,14 +2,18 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT # +import os +import shutil +import tarfile import time from argparse import ArgumentParser -from functools import partial import numpy as np import onnxruntime import timm +import torch import vai_q_onnx +from datasets import Dataset from timm.data import create_dataset, create_loader from timm.utils import AverageMeter from tqdm import tqdm @@ -22,22 +26,38 @@ """ +If you already have an ImageNet dataset, you can directly use your dataset path with '--calib-data-path' and '--eval-data-path'. + +To prepare the test data, please check the download section of the main website: +https://huggingface.co/datasets/imagenet-1k/tree/main/data. +You need to register and download **val_images.tar.gz**. 
+ For example: +python test_timm_quant_and_eval.py -c $PATH/calib_100 -e $PATH/val_data -m timm/resnetv2_50.a1h_in1k +or +python test_timm_quant_and_eval.py -v $PATH/val_images.tar.gz -m timm/resnetv2_50.a1h_in1k + Float Accuracy of resnet50.tv_in1k: - Prec@1: 76.128% - Prec@5: 92.858% Quantization Accuracy of resnet50.tv_in1k: -- Prec@1: 74.072% -- Prec@5: 91.816% +- Prec@1: 74.384% +- Prec@5: 91.968% """ def parse_args(): parser = ArgumentParser("RyzenAIQuantization") - parser.add_argument("--data-path", metavar="DIR", required=True, help="path to dataset") + parser.add_argument("-v", "--val-path", metavar="DIR", required=False, help="path to dataset") + parser.add_argument("-c", "--calib-data-path", metavar="DIR", required=False, help="path to dataset") + parser.add_argument("-e", "--eval-data-path", metavar="DIR", required=False, help="path to dataset") parser.add_argument( - "--model_id", type=str, default="timm/resnet50.tv_in1k", help='Model id, default to "timm/resnet50.tv_in1k"' + "-m", + "--model_id", + type=str, + default="timm/resnet50.tv_in1k", + help='Model id, default to "timm/resnet50.tv_in1k"', ) parser.add_argument( "--dataset", type=str, default="imagenet-1k", help='Calibration dataset, default to "imagenet-1k"' @@ -51,53 +71,119 @@ def parse_args(): ) parser.add_argument("-b", "--batch-size", default=1, type=int, metavar="N", help="mini-batch size (default: 1)") args, _ = parser.parse_known_args() + if args.val_path is None and (args.calib_data_path is None and args.eval_data_path is None): + parser.error("You must either provide --calib-data-path and --eval-data-path, or --val-path") + return args def main(args): + # prepare val data and calib data + if (args.calib_data_path is None and args.eval_data_path is None) and args.val_path is not None: + os.makedirs("val_data", exist_ok=True) + with tarfile.open(args.val_path, "r:gz") as tar: + tar.extractall(path="val_data") + source_folder = "val_data" + calib_data_path = "calib_data" + if not 
os.path.exists(source_folder): + raise ValueError("The val_data does not exist.") + files = os.listdir(source_folder) + for filename in files: + if not filename.startswith("ILSVRC2012_val_") or not filename.endswith(".JPEG"): + continue + + n_identifier = filename.split("_")[-1].split(".")[0] + folder_name = n_identifier + folder_path = os.path.join(source_folder, folder_name) + if not os.path.exists(folder_path): + os.makedirs(folder_path) + file_path = os.path.join(source_folder, filename) + destination = os.path.join(folder_path, filename) + shutil.move(file_path, destination) + + print("File organization complete.") + + if not os.path.exists(calib_data_path): + os.makedirs(calib_data_path) + + destination_folder = calib_data_path + + subfolders = os.listdir(source_folder) + + for subfolder in subfolders: + source_subfolder = os.path.join(source_folder, subfolder) + destination_subfolder = os.path.join(destination_folder, subfolder) + os.makedirs(destination_subfolder, exist_ok=True) + + files = os.listdir(source_subfolder) + + if files: + file_to_copy = files[0] + source_file = os.path.join(source_subfolder, file_to_copy) + destination_file = os.path.join(destination_subfolder, file_to_copy) + + shutil.copy(source_file, destination_file) + + print("Creating calibration dataset complete.") + model_id = args.model_id onnx_model = RyzenAIModelForImageClassification.from_pretrained( model_id, export=True, provider="CPUExecutionProvider" ) - # preprocess config - data_config = timm.data.resolve_data_config(pretrained_cfg=onnx_model.config.to_dict()) - transforms = timm.data.create_transform(**data_config, is_training=False) - - def preprocess_fn(ex, transforms): - image = ex["image"] - if image.mode == "L": - # Three channels. 
- image = image.convert("RGB") - pixel_values = transforms(image) - - return {"pixel_values": pixel_values} + # # preprocess config + data_config = timm.data.resolve_data_config(pretrained_cfg=onnx_model.config.to_dict(), use_test_size=True) - # quantize + # # quantize quantizer = RyzenAIOnnxQuantizer.from_pretrained(onnx_model) quantization_config = AutoQuantizationConfig.cpu_cnn_config() quantization_config.calibration_method = vai_q_onnx.CalibrationMethod.Percentile quantization_config.include_cle = True quantization_config.include_fast_ft = True quantization_config.extra_options = { + "CalibDataSize": 200, "FastFinetune": { - "BatchSize": 1, - "NumIterations": 1000, + "BatchSize": 2, + "NumIterations": 10000, "LearningRate": 0.1, "OptimAlgorithm": "adaround", "OptimDevice": "cpu", "EarlyStop": True, }, + "Percentile": 99.9999, } - calibration_dataset = quantizer.get_calibration_dataset( - args.dataset, - preprocess_function=partial(preprocess_fn, transforms=transforms), - num_samples=100, - dataset_split="validation", - preprocess_batch=False, - streaming=True, + calib_loader = create_loader( + create_dataset("", args.calib_data_path), + input_size=data_config["input_size"], + batch_size=args.batch_size, + use_prefetcher=False, + interpolation=data_config["interpolation"], + mean=data_config["mean"], + std=data_config["std"], + num_workers=args.workers, + crop_pct=data_config["crop_pct"], ) + + data_list = [] + labels_list = [] + + for batch in calib_loader: + data, labels = batch + data_list.append(data) + labels_list.append(labels) + + data_list = torch.cat(data_list, dim=0) + labels_list = torch.cat(labels_list, dim=0) + + data_np = data_list.numpy() + + data_dict = { + "pixel_values": data_np, + } + + calibration_dataset = Dataset.from_dict(data_dict) + quantizer.quantize( quantization_config=quantization_config, dataset=calibration_dataset, save_dir="quantized_model" ) @@ -112,10 +198,8 @@ def preprocess_fn(ex, transforms): session = 
onnxruntime.InferenceSession("quantized_model/model_quantized.onnx", sess_options) - data_config = timm.data.resolve_data_config(pretrained_cfg=onnx_model.config.to_dict(), use_test_size=True) - loader = create_loader( - create_dataset("", args.data_path), + create_dataset("", args.eval_data_path), input_size=data_config["input_size"], batch_size=args.batch_size, use_prefetcher=False, @@ -132,7 +216,7 @@ def preprocess_fn(ex, transforms): top1 = AverageMeter() top5 = AverageMeter() end = time.time() - for i, (input, target) in enumerate(tqdm(loader, desc="Processing")): + for input, target in tqdm(loader, desc="Processing"): # run the net and return prediction output = session.run([], {input_name: input.data.numpy()}) output = output[0] From 1dd53c610bcf50691a7d990b940494cf8d18d120 Mon Sep 17 00:00:00 2001 From: Chao Li Date: Sat, 15 Jun 2024 22:56:45 +0800 Subject: [PATCH 07/11] Change the number of calibration images to 200 --- tests/ryzenai/test_timm_quant_and_eval.py | 100 ++++++++++++---------- 1 file changed, 55 insertions(+), 45 deletions(-) diff --git a/tests/ryzenai/test_timm_quant_and_eval.py b/tests/ryzenai/test_timm_quant_and_eval.py index 6fc18856..e0ff5ff9 100644 --- a/tests/ryzenai/test_timm_quant_and_eval.py +++ b/tests/ryzenai/test_timm_quant_and_eval.py @@ -80,51 +80,62 @@ def parse_args(): def main(args): # prepare val data and calib data if (args.calib_data_path is None and args.eval_data_path is None) and args.val_path is not None: - os.makedirs("val_data", exist_ok=True) - with tarfile.open(args.val_path, "r:gz") as tar: - tar.extractall(path="val_data") source_folder = "val_data" calib_data_path = "calib_data" - if not os.path.exists(source_folder): - raise ValueError("The val_data does not exist.") - files = os.listdir(source_folder) - for filename in files: - if not filename.startswith("ILSVRC2012_val_") or not filename.endswith(".JPEG"): - continue - - n_identifier = filename.split("_")[-1].split(".")[0] - folder_name = n_identifier - 
folder_path = os.path.join(source_folder, folder_name) - if not os.path.exists(folder_path): - os.makedirs(folder_path) - file_path = os.path.join(source_folder, filename) - destination = os.path.join(folder_path, filename) - shutil.move(file_path, destination) - - print("File organization complete.") - - if not os.path.exists(calib_data_path): - os.makedirs(calib_data_path) - - destination_folder = calib_data_path - - subfolders = os.listdir(source_folder) - - for subfolder in subfolders: - source_subfolder = os.path.join(source_folder, subfolder) - destination_subfolder = os.path.join(destination_folder, subfolder) - os.makedirs(destination_subfolder, exist_ok=True) - - files = os.listdir(source_subfolder) - - if files: - file_to_copy = files[0] - source_file = os.path.join(source_subfolder, file_to_copy) - destination_file = os.path.join(destination_subfolder, file_to_copy) - - shutil.copy(source_file, destination_file) - - print("Creating calibration dataset complete.") + if os.path.isdir(source_folder) and os.path.isdir(calib_data_path): + print( + f"Detected that {source_folder} and {calib_data_path} already exist, skipping the creation of the calibration dataset." 
+ ) + else: + os.makedirs(source_folder, exist_ok=True) + with tarfile.open(args.val_path, "r:gz") as tar: + tar.extractall(path=source_folder) + + if not os.path.exists(source_folder): + raise ValueError("The val_data does not exist.") + files = os.listdir(source_folder) + for filename in files: + if not filename.startswith("ILSVRC2012_val_") or not filename.endswith(".JPEG"): + continue + + n_identifier = filename.split("_")[-1].split(".")[0] + folder_name = n_identifier + folder_path = os.path.join(source_folder, folder_name) + if not os.path.exists(folder_path): + os.makedirs(folder_path) + file_path = os.path.join(source_folder, filename) + destination = os.path.join(folder_path, filename) + shutil.move(file_path, destination) + + print("File organization complete.") + + if not os.path.exists(calib_data_path): + os.makedirs(calib_data_path) + + destination_folder = calib_data_path + + subfolders = os.listdir(source_folder) + cnt = 0 + for subfolder in subfolders: + source_subfolder = os.path.join(source_folder, subfolder) + destination_subfolder = os.path.join(destination_folder, subfolder) + os.makedirs(destination_subfolder, exist_ok=True) + + files = os.listdir(source_subfolder) + + if files: + file_to_copy = files[0] + source_file = os.path.join(source_subfolder, file_to_copy) + destination_file = os.path.join(destination_subfolder, file_to_copy) + + shutil.copy(source_file, destination_file) + cnt += 1 + if cnt > 200: + break + + print("Creating calibration dataset complete.") + args.calib_data_path = source_folder + args.eval_data_path = calib_data_path model_id = args.model_id @@ -141,13 +152,12 @@ def main(args): quantization_config.include_cle = True quantization_config.include_fast_ft = True quantization_config.extra_options = { - "CalibDataSize": 200, "FastFinetune": { "BatchSize": 2, "NumIterations": 10000, "LearningRate": 0.1, "OptimAlgorithm": "adaround", - "OptimDevice": "cpu", + "OptimDevice": "cuda:0", "EarlyStop": True, }, "Percentile": 
99.9999, From ad96e1abf4304c6343ac7c450a705f4761d8537c Mon Sep 17 00:00:00 2001 From: Chao Li Date: Sun, 16 Jun 2024 13:39:40 +0800 Subject: [PATCH 08/11] bugfix --- tests/ryzenai/test_timm_quant_and_eval.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/ryzenai/test_timm_quant_and_eval.py b/tests/ryzenai/test_timm_quant_and_eval.py index e0ff5ff9..bafe9e97 100644 --- a/tests/ryzenai/test_timm_quant_and_eval.py +++ b/tests/ryzenai/test_timm_quant_and_eval.py @@ -78,6 +78,7 @@ def parse_args(): def main(args): + torch.multiprocessing.set_sharing_strategy("file_system") # prepare val data and calib data if (args.calib_data_path is None and args.eval_data_path is None) and args.val_path is not None: source_folder = "val_data" @@ -130,12 +131,12 @@ def main(args): shutil.copy(source_file, destination_file) cnt += 1 - if cnt > 200: + if cnt >= 200: break print("Creating calibration dataset complete.") - args.calib_data_path = source_folder - args.eval_data_path = calib_data_path + args.calib_data_path = calib_data_path + args.eval_data_path = source_folder model_id = args.model_id From a4283b0b0e1ac9fd8ea7e4cc126fceae2a444952 Mon Sep 17 00:00:00 2001 From: Chao Li Date: Sun, 16 Jun 2024 22:47:47 +0800 Subject: [PATCH 09/11] update --- tests/ryzenai/test_timm_quant_and_eval.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/ryzenai/test_timm_quant_and_eval.py b/tests/ryzenai/test_timm_quant_and_eval.py index bafe9e97..6176ba46 100644 --- a/tests/ryzenai/test_timm_quant_and_eval.py +++ b/tests/ryzenai/test_timm_quant_and_eval.py @@ -149,17 +149,26 @@ def main(args): # # quantize quantizer = RyzenAIOnnxQuantizer.from_pretrained(onnx_model) quantization_config = AutoQuantizationConfig.cpu_cnn_config() + quantization_config.activations_dtype=vai_q_onnx.QuantType.QInt8 quantization_config.calibration_method = vai_q_onnx.CalibrationMethod.Percentile quantization_config.include_cle = True 
quantization_config.include_fast_ft = True quantization_config.extra_options = { "FastFinetune": { "BatchSize": 2, - "NumIterations": 10000, - "LearningRate": 0.1, - "OptimAlgorithm": "adaround", - "OptimDevice": "cuda:0", - "EarlyStop": True, + 'FixedSeed': 1705472343, + 'NumBatches': 1, + 'NumIterations': 10000, + 'LearningRate': 0.1, + 'OptimAlgorithm': 'adaround', + 'OptimDevice': "cuda:0", # or 'cpu' + 'LRAdjust': (), + 'SelectiveUpdate': False, + 'EarlyStop': True, + 'DropRatio': 0.75, + 'RegParam': 0.01, # default + 'BetaRange': (20, 2), # default + 'WarmStart': 0.2, # default }, "Percentile": 99.9999, } From 863bba8e1697ae3107e84ce83eca955857a6e60c Mon Sep 17 00:00:00 2001 From: Chao Li Date: Sun, 16 Jun 2024 22:58:00 +0800 Subject: [PATCH 10/11] Add ipu transformer config --- optimum/amd/ryzenai/configuration.py | 23 +++++++++++++++-- tests/ryzenai/test_timm_quant_and_eval.py | 31 ++++++++++------------- 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/optimum/amd/ryzenai/configuration.py b/optimum/amd/ryzenai/configuration.py index 75efa7eb..2540d973 100644 --- a/optimum/amd/ryzenai/configuration.py +++ b/optimum/amd/ryzenai/configuration.py @@ -48,6 +48,7 @@ class QuantizationConfig: weights_dtype: QuantType = QuantType.QInt8 weights_symmetric: bool = True enable_dpu: bool = True + enable_ipu_transformer: bool = False use_external_data_format: bool = False include_cle: bool = False include_sq: bool = False @@ -88,20 +89,38 @@ def ipu_cnn_config(): enable_dpu=True, ) + @staticmethod + def ipu_transformer_config(): + return QuantizationConfig( + format=QuantFormat.QDQ, + calibration_method=vai_q_onnx.CalibrationMethod.MinMax, + activations_dtype=QuantType.QInt8, + activations_symmetric=False, + weights_dtype=QuantType.QInt8, + weights_symmetric=True, + enable_ipu_transformer=True, + ) + @staticmethod def cpu_cnn_config( use_symmetric_activations: bool = False, use_symmetric_weights: bool = True, enable_dpu: bool = False, + include_cle: bool 
= True, + include_fast_ft: bool = True, + extra_options: dict = None, ): return QuantizationConfig( format=QuantFormat.QDQ, - calibration_method=vai_q_onnx.CalibrationMethod.MinMax, - activations_dtype=QuantType.QUInt8, + calibration_method=vai_q_onnx.CalibrationMethod.Percentile, + activations_dtype=QuantType.QInt8, activations_symmetric=use_symmetric_activations, weights_dtype=QuantType.QInt8, weights_symmetric=use_symmetric_weights, enable_dpu=enable_dpu, + include_cle=include_cle, + include_fast_ft=include_fast_ft, + extra_options=extra_options, ) diff --git a/tests/ryzenai/test_timm_quant_and_eval.py b/tests/ryzenai/test_timm_quant_and_eval.py index 6176ba46..7fea484b 100644 --- a/tests/ryzenai/test_timm_quant_and_eval.py +++ b/tests/ryzenai/test_timm_quant_and_eval.py @@ -12,7 +12,6 @@ import onnxruntime import timm import torch -import vai_q_onnx from datasets import Dataset from timm.data import create_dataset, create_loader from timm.utils import AverageMeter @@ -149,26 +148,22 @@ def main(args): # # quantize quantizer = RyzenAIOnnxQuantizer.from_pretrained(onnx_model) quantization_config = AutoQuantizationConfig.cpu_cnn_config() - quantization_config.activations_dtype=vai_q_onnx.QuantType.QInt8 - quantization_config.calibration_method = vai_q_onnx.CalibrationMethod.Percentile - quantization_config.include_cle = True - quantization_config.include_fast_ft = True quantization_config.extra_options = { "FastFinetune": { "BatchSize": 2, - 'FixedSeed': 1705472343, - 'NumBatches': 1, - 'NumIterations': 10000, - 'LearningRate': 0.1, - 'OptimAlgorithm': 'adaround', - 'OptimDevice': "cuda:0", # or 'cpu' - 'LRAdjust': (), - 'SelectiveUpdate': False, - 'EarlyStop': True, - 'DropRatio': 0.75, - 'RegParam': 0.01, # default - 'BetaRange': (20, 2), # default - 'WarmStart': 0.2, # default + "FixedSeed": 1705472343, + "NumBatches": 1, + "NumIterations": 10000, + "LearningRate": 0.1, + "OptimAlgorithm": "adaround", + "OptimDevice": "cuda:0", # or 'cpu' + "LRAdjust": (), + 
"SelectiveUpdate": False, + "EarlyStop": True, + "DropRatio": 0.75, + "RegParam": 0.01, # default + "BetaRange": (20, 2), # default + "WarmStart": 0.2, # default }, "Percentile": 99.9999, } From de99a3f19b45d512811af6bb01613c4d1789187a Mon Sep 17 00:00:00 2001 From: Chao Li Date: Tue, 25 Jun 2024 16:14:49 +0800 Subject: [PATCH 11/11] remove unused labels_list --- tests/ryzenai/test_timm_quant_and_eval.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/ryzenai/test_timm_quant_and_eval.py b/tests/ryzenai/test_timm_quant_and_eval.py index 7fea484b..856be184 100644 --- a/tests/ryzenai/test_timm_quant_and_eval.py +++ b/tests/ryzenai/test_timm_quant_and_eval.py @@ -181,15 +181,12 @@ def main(args): ) data_list = [] - labels_list = [] for batch in calib_loader: - data, labels = batch + data = batch[0] data_list.append(data) - labels_list.append(labels) data_list = torch.cat(data_list, dim=0) - labels_list = torch.cat(labels_list, dim=0) data_np = data_list.numpy()