From d18610cb75c5d2361a4abca6b1156a3cc15a2ce0 Mon Sep 17 00:00:00 2001 From: jeb Date: Tue, 13 Jun 2023 00:28:50 -0700 Subject: [PATCH] Added the ability to batch process a directory of .safetensor and .ckpt models. Added the ability to append the shape of the converted trt model to the model name. i.e. my_model_1x512x512.trt 1 batch size, max width 512, max height 512. To help differentiate between different trt model conversion shapes for the same base model. Added dynamic alert at the top for when you have gone over the limits of the shape a trt model can be generated at as you slide the max width, height, and batch size sliders. --- ui_trt.py | 250 +++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 219 insertions(+), 31 deletions(-) diff --git a/ui_trt.py b/ui_trt.py index 0dd72f5..b7df969 100644 --- a/ui_trt.py +++ b/ui_trt.py @@ -1,3 +1,4 @@ +import asyncio import html import os @@ -13,26 +14,162 @@ from modules.ui_components import FormRow -def export_unet_to_onnx(filename, opset): - if not filename: - modelname = shared.sd_model.sd_checkpoint_info.model_name + ".onnx" - filename = os.path.join(paths_internal.models_path, "Unet-onnx", modelname) - - export_onnx.export_current_unet_to_onnx(filename, opset) - - return f'Saved as {filename}', '' - - -def get_trt_filename(filename, onnx_filename): +def export_unet_to_onnx(filename, opset, batch_run, batch_directory): + print(f'Starting Conversion to .onnx') + + # Check if 'Unet-onnx' directory exists and create it if not + unet_onnx_path = os.path.join(paths_internal.models_path, "Unet-onnx") + os.makedirs(unet_onnx_path, exist_ok=True) + + # Use default folder if batch_directory is empty + if not batch_directory: + batch_directory = os.path.join(paths_internal.models_path, "Stable-diffusion") + # Batch mode + if batch_run: + print(f"--Batch Models mode--") + + onnx_files = os.listdir(os.path.join(paths_internal.models_path, "Unet-onnx")) + onnx_files_to_process = list() + # Take all files if destination folder is empty + if not onnx_files: + print(f"Unet-onnx is empty, adding all .safetensors and .ckpt files from {batch_directory}\n") # Debug line + onnx_files_to_process = [file for file in os.listdir(batch_directory) if file.endswith(".safetensors") or file.endswith(".ckpt")] + else: + for batch_file in os.listdir(batch_directory): + add_flag = True + for onnx_file in onnx_files: + if batch_file.split('.')[0] == onnx_file.split('.')[0]: + add_flag = False + if add_flag and (batch_file.endswith(".safetensors") or batch_file.endswith(".ckpt")): + onnx_files_to_process.append(batch_file) + print(f"Files to process:\n{onnx_files_to_process}\n") + # Exit if no files to process + if not onnx_files_to_process: + print(f'No files to convert...\nPlease uncheck the "Run Batch" checkbox or use a folder containing models.') + return f'No files to convert...\nPlease uncheck the "Run Batch" checkbox or use a folder containing models.', '' + # Process files + for i, file in enumerate(onnx_files_to_process): + print(f"Converting model file: {file}") # Debug line + modelname = os.path.splitext(file)[0] + ".onnx" + onnx_filename = os.path.join(unet_onnx_path, modelname) + print(f"Target ONNX filename: {onnx_filename}\n") # Debug line + + export_onnx.export_current_unet_to_onnx(onnx_filename, opset) + + # Ending message + print(f'Batch conversion completed for files in {batch_directory}') + return f'Batch conversion completed for files in {batch_directory}', '' + # Single mode + else: + print(f"--Single Model mode--") + if not filename: + modelname = shared.sd_model.sd_checkpoint_info.model_name + ".onnx" + filename = os.path.join(unet_onnx_path, modelname) + else: + modelname = os.path.splitext(filename)[0] + ".onnx" + filename = os.path.join(unet_onnx_path, modelname) + print(f"Target ONNX filename: {filename}\n") # Debug line + + export_onnx.export_current_unet_to_onnx(filename, opset) + + # Ending message + print(f'Done! Model saved as {filename}') + return f'Done! Model saved as {filename}', '' + + +def convert_onnx_to_trt(filename, onnx_filename, add_shape_to_filename, batch_run, batch_directory, *args): + assert not cmd_opts.disable_extension_access, "Won't run the command to create TensorRT file because extension access is disabled (use --enable-insecure-extension-access)" + print(f'Starting Conversion to .trt') + + # Check if 'Unet-trt' directory exists and create it if not + unet_onnx_path = os.path.join(paths_internal.models_path, "Unet-trt") + os.makedirs(unet_onnx_path, exist_ok=True) + + # Use default folder if batch_directory is empty + if not batch_directory: + batch_directory = os.path.join(paths_internal.models_path, "Unet-onnx") + # Batch mode + if batch_run: + print(f"--Batch Models mode--") + + trt_files = os.listdir(os.path.join(paths_internal.models_path, "Unet-trt")) + trt_files_to_process = list() + # Take all files if destination folder is empty + if not trt_files: + print(f"Unet-trt is empty, adding all .onnx files from {batch_directory}\n") # Debug line + trt_files_to_process = [file for file in os.listdir(batch_directory) if file.endswith(".onnx")] + else: + for batch_file in os.listdir(batch_directory): + add_flag = True + for trt_file in trt_files: + size = len(trt_file.split('_')) + trt_file_shape = "" + for tokens in trt_file.split('_'): + if size > 1: + trt_file_shape += tokens + if size > 2: + trt_file_shape += '_' + size -= 1 + if batch_file.split('.')[0] == trt_file.split('.')[0] or batch_file.split('.')[0] == trt_file_shape: + add_flag = False + if add_flag and batch_file.endswith(".onnx"): + trt_files_to_process.append(batch_file) + print(f"Files to process:\n{trt_files_to_process}\n") + # Exit if no files to process + if not trt_files_to_process: + print(f'No files to convert...\nPlease uncheck the "Run Batch" checkbox or use a folder containing models.') + return f'No files to convert...\nPlease uncheck the "Run Batch" checkbox or use a folder containing models.', '' + # Process files + for file in trt_files_to_process: + onnx_file = os.path.join(batch_directory, file) + print(f"Converting ONNX file: {onnx_file}") # Debug line + modelname = os.path.splitext(file)[0] + ".trt" + filename = os.path.join(paths_internal.models_path, "Unet-trt", modelname) + + trt_filename = get_trt_filename(filename, onnx_file, batch_run, add_shape_to_filename, *args) + print(f"Target TRT filename: {trt_filename}\n") # Debug line + command = export_trt.get_trt_command(trt_filename, onnx_file, *args) + + launch.run(command, live=True) + + # Ending message + print(f'Batch conversion completed for files in {batch_directory}') + return f'Batch conversion completed for files in {batch_directory}', '' + # Single mode + else: + print(f"--Single Model mode--") + trt_filename = get_trt_filename(filename, onnx_filename, add_shape_to_filename, *args) + print(f"Target TRT filename: {trt_filename}\n") # Debug line + command = export_trt.get_trt_command(trt_filename, onnx_filename, *args) + + launch.run(command, live=True) + + # Ending message + print(f'Done! Model saved as {trt_filename}') + return f'Done! Model saved as {trt_filename}', '' + + +def get_trt_filename(filename, onnx_filename, batch_run=False, add_shape_to_filename=False, *args): + modelname = os.path.splitext(os.path.basename(onnx_filename))[0]; + #print("Shape args: ", args) # args: (1, 1, 75, 750, 512, 768, 512, 960, True, '') + #({0}min_bs, {1}max_bs, {2}min_token_count, {3}max_token_count, {4}min_width, {5}max_width, {6}min_height, {7}max_height, {8}use_fp16, {9}trt_extra_args) + if(add_shape_to_filename): + if len(args) == 9: + modelname += f'_{args[0]}x{args[4]}x{args[6]}' + ".trt" + if len(args) == 10: + modelname += f'_{args[1]}x{args[5]}x{args[7]}' + ".trt" + else: + modelname += ".trt" + if batch_run: + return os.path.join(paths_internal.models_path, "Unet-trt", modelname) if filename: return filename - - modelname = os.path.splitext(os.path.basename(onnx_filename))[0] + ".trt" return os.path.join(paths_internal.models_path, "Unet-trt", modelname) -def get_trt_command(filename, onnx_filename, *args): - filename = get_trt_filename(filename, onnx_filename) +def get_trt_command(filename, onnx_filename, add_shape_to_filename, *args): + filename = get_trt_filename(filename, onnx_filename, False, add_shape_to_filename, *args) + command = export_trt.get_trt_command(filename, onnx_filename, *args) env_command = f""" @@ -59,16 +196,41 @@ def get_trt_command(filename, onnx_filename, *args):

""" - -def convert_onnx_to_trt(filename, onnx_filename, *args): - assert not cmd_opts.disable_extension_access, "won't run the command to create TensorRT file because extension access is dsabled (use --enable-insecure-extension-access)" - - filename = get_trt_filename(filename, onnx_filename) - command = export_trt.get_trt_command(filename, onnx_filename, *args) - - launch.run(command, live=True) - - return f'Saved as {filename}', '' +batch_sizes = { + 1: 92160, + 2: 129024, + 3: 159744, + 4: 184320, + 5: 184320, + 6: 221184, + 7: 229376, + 8: 229376, + 9: 276480, + 10: 286720, + 11: 281600 +} + +async def calculate_and_check_constraints(max_width, max_height, max_batch_size): + B = max_batch_size * 2 + unknown = 4 + H = max_height / 8 + W = max_width / 8 + + calculated_value = int(B * unknown * H * W) + + color = "#00FF00"; #green + batch_size = batch_sizes[max_batch_size] if max_batch_size in batch_sizes else batch_sizes[11] + if(calculated_value > batch_size): + color="#FF0000" #red + value = f"{calculated_value} / {batch_size}" + + return gr.Label.update(value=value, color=color), "" + + +def on_slider_change(max_width, max_height, max_bs): + loop = asyncio.get_event_loop() + calculated_value, _ = loop.run_until_complete(calculate_and_check_constraints(max_width, max_height, max_bs)) + return calculated_value, "" def on_ui_tabs(): @@ -82,11 +244,18 @@ def on_ui_tabs(): onnx_filename = gr.Textbox(label='Filename', value="", elem_id="onnx_filename", info="Leave empty to use the same name as model and put results into models/Unet-onnx directory") onnx_opset = gr.Number(label='ONNX opset version', precision=0, value=17, info="Leave this alone unless you know what you are doing") + batch_run_onnx = gr.Checkbox(label='Run Batch', value=False) + batch_directory_onnx = gr.Textbox(label='Directory', value="", info="Input directory containing models. Leave empty to use the default 'models/Stable-diffusion' folder as source.") + button_export_unet = gr.Button(value="Convert Unet to ONNX", variant='primary', elem_id="onnx_export_unet") with gr.Tab(label="Convert ONNX to TensorRT"): trt_source_filename = gr.Textbox(label='Onnx model filename', value="", elem_id="trt_source_filename") trt_filename = gr.Textbox(label='Output filename', value="", elem_id="trt_filename", info="Leave empty to use the same name as onnx and put results into models/Unet-trt directory") + add_shape_to_filename = gr.Checkbox(label='Add Shape to end of filename', value=False) + + with gr.Column(elem_id="trt_calculated_value"): + calculated_value_label = gr.Label(elem_id="calculated_value", value="32768 / 92160", label="Current Value / Limit NOTE: move sliders slowly to make sure they update properly", show_label=True, color="#00FF00") with gr.Column(elem_id="trt_width"): min_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Minimum width", value=512, elem_id="trt_min_width") @@ -103,36 +272,55 @@ def on_ui_tabs(): with gr.Column(elem_id="trt_token_count"): min_token_count = gr.Slider(minimum=75, maximum=750, step=75, label="Minimum prompt token count", value=75, elem_id="trt_min_token_count") max_token_count = gr.Slider(minimum=75, maximum=750, step=75, label="Maximum prompt token count", value=75, elem_id="trt_max_token_count") - + trt_extra_args = gr.Textbox(label='Extra arguments', value="", elem_id="trt_extra_args", info="Extra arguments for trtexec command in plain text form") with FormRow(elem_classes="checkboxes-row", variant="compact"): use_fp16 = gr.Checkbox(label='Use half floats', value=True, elem_id="trt_fp16") + batch_run_trt = gr.Checkbox(label='Run Batch', value=False) + batch_directory_trt = gr.Textbox(label='Directory', value="", info="Input directory containing models. Leave empty to use the default 'models/Unet-onnx' folder as source.") + button_export_trt = gr.Button(value="Convert ONNX to TensorRT", variant='primary', elem_id="trt_convert_from_onnx") button_show_trt_command = gr.Button(value="Show command for conversion", variant='secondary', elem_id="trt_convert_from_onnx") with gr.Column(variant='panel'): trt_result = gr.Label(elem_id="trt_result", value="", show_label=False) trt_info = gr.HTML(elem_id="trt_info", value="") - + calculated_value_label + max_width.change( + on_slider_change, + inputs=[max_width, max_height, max_bs], + outputs=[calculated_value_label, trt_info], + ) + max_height.change( + on_slider_change, + inputs=[max_width, max_height, max_bs], + outputs=[calculated_value_label, trt_info], + ) + max_bs.change( + on_slider_change, + inputs=[max_width, max_height, max_bs], + outputs=[calculated_value_label, trt_info], + ) + button_export_unet.click( wrap_gradio_gpu_call(export_unet_to_onnx, extra_outputs=["Conversion failed"]), - inputs=[onnx_filename, onnx_opset], + inputs=[onnx_filename, onnx_opset, batch_run_onnx, batch_directory_onnx], outputs=[trt_result, trt_info], ) button_export_trt.click( wrap_gradio_gpu_call(convert_onnx_to_trt, extra_outputs=[""]), - inputs=[trt_filename, trt_source_filename, min_bs, max_bs, min_token_count, max_token_count, min_width, max_width, min_height, max_height, use_fp16, trt_extra_args], + inputs=[trt_filename, trt_source_filename, add_shape_to_filename, batch_run_trt, batch_directory_trt, min_bs, max_bs, min_token_count, max_token_count, min_width, max_width, min_height, max_height, use_fp16, trt_extra_args], outputs=[trt_result, trt_info], ) button_show_trt_command.click( wrap_gradio_gpu_call(get_trt_command, extra_outputs=[""]), - inputs=[trt_filename, trt_source_filename, min_bs, max_bs, min_token_count, max_token_count, min_width, max_width, min_height, max_height, use_fp16, trt_extra_args], + inputs=[trt_filename, trt_source_filename, add_shape_to_filename, min_bs, max_bs, min_token_count, max_token_count, min_width, max_width, min_height, max_height, use_fp16, trt_extra_args], outputs=[trt_result, trt_info], ) return [(trt_interface, "TensorRT", "tensorrt")] - + \ No newline at end of file