diff --git a/frontends/krita/krita_diff/__init__.py b/frontends/krita/krita_diff/__init__.py index 42af00ed..a087dbf3 100644 --- a/frontends/krita/krita_diff/__init__.py +++ b/frontends/krita/krita_diff/__init__.py @@ -8,6 +8,7 @@ TAB_SDCOMMON, TAB_TXT2IMG, TAB_UPSCALE, + TAB_CONTROLNET ) from .docker import create_docker from .extension import SDPluginExtension @@ -18,6 +19,7 @@ SDCommonPage, Txt2ImgPage, UpscalePage, + ControlNetPage ) from .pages.preview import PreviewPage from .script import script @@ -60,6 +62,13 @@ create_docker(UpscalePage), ) ) +instance.addDockWidgetFactory( + DockWidgetFactory( + TAB_CONTROLNET, + DockWidgetFactoryBase.DockLeft, + create_docker(ControlNetPage), + ) +) instance.addDockWidgetFactory( DockWidgetFactory( TAB_CONFIG, diff --git a/frontends/krita/krita_diff/client.py b/frontends/krita/krita_diff/client.py index 2f89a393..1c124834 100644 --- a/frontends/krita/krita_diff/client.py +++ b/frontends/krita/krita_diff/client.py @@ -14,13 +14,21 @@ LONG_TIMEOUT, OFFICIAL_ROUTE_PREFIX, ROUTE_PREFIX, + CONTROLNET_ROUTE_PREFIX, SHORT_TIMEOUT, STATE_DONE, STATE_READY, STATE_URLERROR, THREADED, ) -from .utils import bytewise_xor, fix_prompt, get_ext_args, get_ext_key, img_to_b64 +from .utils import ( + bytewise_xor, + fix_prompt, + get_ext_args, + get_ext_key, + img_to_b64, + calculate_resized_image_dimensions +) # NOTE: backend queues up responses, so no explicit need to block multiple requests # except to prevent user from spamming themselves @@ -238,6 +246,78 @@ def common_params(self, has_selection): save_samples=self.cfg("save_temp_images", bool), ) return params + + def options_params(self): + """Parameters that are specific for the official API options endpoint + or overriding settings.""" + params = dict( + sd_model_checkpoint=self.cfg("sd_model", str), + sd_vae=self.cfg("sd_vae", str), + CLIP_stop_at_last_layers=self.cfg("clip_skip", int), + upscaler_for_img2img=self.cfg("upscaler_name", str), + face_restoration_model=self.cfg("face_restorer_model", str), + code_former_weight=self.cfg("codeformer_weight", float), + #Couldn't find filter_nsfw option for official API. + img2img_fix_steps=self.cfg("do_exact_steps", bool), #Not sure if this is matched correctly. + img2img_color_correction=self.cfg("img2img_color_correct", bool), + return_grid=self.cfg("include_grid", bool) + ) + return params + + def official_api_common_params(self, has_selection, width, height, + controlnet_src_imgs): + """Parameters used by most official API endpoints.""" + tiling = self.cfg("sd_tiling", bool) and not ( + self.cfg("only_full_img_tiling", bool) and has_selection + ) + + params = dict( + batch_size=self.cfg("sd_batch_size", int), + width=width, + height=height, + tiling=tiling, + restore_faces=self.cfg("face_restorer_model", str) != "None", + override_settings=self.options_params(), + override_settings_restore_afterwards=True, + alwayson_scripts={} + ) + + if controlnet_src_imgs: + controlnet_units_param = list() + + for i in range(len(self.cfg("controlnet_unit_list", "QStringList"))): + if self.cfg(f"controlnet{i}_enable", bool): + controlnet_units_param.append( + self.controlnet_unit_params(img_to_b64(controlnet_src_imgs[str(i)]), i, width, height) + ) + else: + controlnet_units_param.append({"enabled": False}) + + params["alwayson_scripts"].update({ + "controlnet": { + "args": controlnet_units_param + } + }) + + return params + + def controlnet_unit_params(self, image: str, unit: int, width: int, height: int): + preprocessor_resolution = min(width, height) if self.cfg(f"controlnet{unit}_pixel_perfect", bool) \ + else self.cfg(f"controlnet{unit}_preprocessor_resolution", int) + params = dict( + input_image=image, + module=self.cfg(f"controlnet{unit}_preprocessor", str), + model=self.cfg(f"controlnet{unit}_model", str), + weight=self.cfg(f"controlnet{unit}_weight", float), + lowvram=self.cfg(f"controlnet{unit}_low_vram", bool), + processor_res=preprocessor_resolution, + threshold_a=self.cfg(f"controlnet{unit}_threshold_a", float), + threshold_b=self.cfg(f"controlnet{unit}_threshold_b", float), + guidance_start=self.cfg(f"controlnet{unit}_guidance_start", float), + guidance_end=self.cfg(f"controlnet{unit}_guidance_end", float), + control_mode=self.cfg(f"controlnet{unit}_control_mode", str) + ) + return params def get_config(self): def cb(obj): @@ -294,6 +374,42 @@ def cb(obj): self.get("config", cb, ignore_no_connection=True) + def get_controlnet_config(self): + '''Get models and modules for ControlNet''' + def check_response(obj, key: str): + try: + assert key in obj + except: + self.status.emit( + f"{STATE_URLERROR}: incompatible response, are you running the right API?" + ) + print("Invalid Response:\n", obj) + return + + def set_model_list(obj): + key = "model_list" + check_response(obj, key) + self.cfg.set("controlnet_model_list", ["None"] + obj[key]) + + def set_preprocessor_list(obj): + key = "module_list" + check_response(obj, key) + self.cfg.set("controlnet_preprocessor_list", obj[key]) + + #Get controlnet API url + url = get_url(self.cfg, prefix=CONTROLNET_ROUTE_PREFIX) + self.get("model_list", set_model_list, base_url=url) + self.get("module_list", set_preprocessor_list, base_url=url) + + # def post_options(self): + # """Sets the options for the backend, using the official API""" + # def cb(response): + # assert response is not None, "Backend Error, check terminal" + + # params = self.options_params() + # url = get_url(self.cfg, prefix=OFFICIAL_ROUTE_PREFIX) + # self.post("options", params, cb, base_url=url) + def post_txt2img(self, cb, width, height, has_selection): params = dict(orig_width=width, orig_height=height) if not self.cfg("just_use_yaml", bool): @@ -320,8 +436,48 @@ def post_txt2img(self, cb, width, height, has_selection): self.post("txt2img", params, cb) + def post_official_api_txt2img(self, cb, width, height, has_selection, + controlnet_src_imgs: dict = {}): + """Uses official API. Leave controlnet_src_imgs empty to not use controlnet.""" + if not self.cfg("just_use_yaml", bool): + seed = ( + int(self.cfg("txt2img_seed", str)) # Qt casts int as 32-bit int + if not self.cfg("txt2img_seed", str).strip() == "" + else -1 + ) + ext_name = self.cfg("txt2img_script", str) + ext_args = get_ext_args(self.ext_cfg, "scripts_txt2img", ext_name) + resized_width, resized_height = calculate_resized_image_dimensions( + self.cfg("sd_base_size", int), self.cfg("sd_max_size", int), width, height + ) + disable_base_and_max_size = self.cfg("disable_sddebz_highres", bool) + params = self.official_api_common_params( + has_selection, + resized_width if not disable_base_and_max_size else width, + resized_height if not disable_base_and_max_size else height, + controlnet_src_imgs + ) + params.update( + prompt=fix_prompt(self.cfg("txt2img_prompt", str)), + negative_prompt=fix_prompt(self.cfg("txt2img_negative_prompt", str)), + sampler_name=self.cfg("txt2img_sampler", str), + steps=self.cfg("txt2img_steps", int), + cfg_scale=self.cfg("txt2img_cfg_scale", float), + seed=seed, + enable_hr=self.cfg("txt2img_highres", bool), + hr_upscaler=self.cfg("upscaler_name", str), + hr_resize_x=width, + hr_resize_y=height, + denoising_strength=self.cfg("txt2img_denoising_strength", float), + script_name=ext_name if ext_name != "None" else None, #Prevent unrecognized "None" script from backend + script_args=ext_args if ext_name != "None" else [] + ) + + url = get_url(self.cfg, prefix=OFFICIAL_ROUTE_PREFIX) + self.post("txt2img", params, cb, base_url=url) + def post_img2img(self, cb, src_img, mask_img, has_selection): - params = dict(is_inpaint=False, src_img=img_to_b64(src_img)) + params = dict(is_inpaint=False, src_img=img_to_b64(src_img)) if not self.cfg("just_use_yaml", bool): seed = ( int(self.cfg("img2img_seed", str)) # Qt casts int as 32-bit int @@ -346,11 +502,49 @@ def post_img2img(self, cb, src_img, mask_img, has_selection): self.post("img2img", params, cb) + def post_official_api_img2img(self, cb, src_img, width, height, has_selection, + controlnet_src_imgs: dict = {}): + """Uses official API. Leave controlnet_src_imgs empty to not use controlnet.""" + params = dict(init_images=[img_to_b64(src_img)]) + if not self.cfg("just_use_yaml", bool): + seed = ( + int(self.cfg("img2img_seed", str)) # Qt casts int as 32-bit int + if not self.cfg("img2img_seed", str).strip() == "" + else -1 + ) + ext_name = self.cfg("img2img_script", str) + ext_args = get_ext_args(self.ext_cfg, "scripts_img2img", ext_name) + resized_width, resized_height = calculate_resized_image_dimensions( + self.cfg("sd_base_size", int), self.cfg("sd_max_size", int), width, height + ) + disable_base_and_max_size = self.cfg("disable_sddebz_highres", bool) + params.update(self.official_api_common_params( + has_selection, + resized_width if not disable_base_and_max_size else width, + resized_height if not disable_base_and_max_size else height, + controlnet_src_imgs + )) + params.update( + prompt=fix_prompt(self.cfg("img2img_prompt", str)), + negative_prompt=fix_prompt(self.cfg("img2img_negative_prompt", str)), + sampler_name=self.cfg("img2img_sampler", str), + steps=self.cfg("img2img_steps", int), + cfg_scale=self.cfg("img2img_cfg_scale", float), + seed=seed, + denoising_strength=self.cfg("img2img_denoising_strength", float), + script_name=ext_name if ext_name != "None" else None, + script_args=ext_args if ext_name != "None" else [] + ) + + url = get_url(self.cfg, prefix=OFFICIAL_ROUTE_PREFIX) + self.post("img2img", params, cb, base_url=url) + def post_inpaint(self, cb, src_img, mask_img, has_selection): assert mask_img, "Inpaint layer is needed for inpainting!" params = dict( is_inpaint=True, src_img=img_to_b64(src_img), mask_img=img_to_b64(mask_img) ) + if not self.cfg("just_use_yaml", bool): seed = ( int(self.cfg("inpaint_seed", str)) # Qt casts int as 32-bit int @@ -385,6 +579,57 @@ def post_inpaint(self, cb, src_img, mask_img, has_selection): self.post("img2img", params, cb) + def post_official_api_inpaint(self, cb, src_img, mask_img, width, height, has_selection, + controlnet_src_imgs: dict = {}): + """Uses official API. Leave controlnet_src_imgs empty to not use controlnet.""" + assert mask_img, "Inpaint layer is needed for inpainting!" + params = dict( + init_images=[img_to_b64(src_img)], mask=img_to_b64(mask_img) + ) + if not self.cfg("just_use_yaml", bool): + seed = ( + int(self.cfg("inpaint_seed", str)) # Qt casts int as 32-bit int + if not self.cfg("inpaint_seed", str).strip() == "" + else -1 + ) + fill = self.cfg("inpaint_fill_list", "QStringList").index( + self.cfg("inpaint_fill", str) + ) + ext_name = self.cfg("inpaint_script", str) + ext_args = get_ext_args(self.ext_cfg, "scripts_inpaint", ext_name) + resized_width, resized_height = calculate_resized_image_dimensions( + self.cfg("sd_base_size", int), self.cfg("sd_max_size", int), width, height + ) + invert_mask = self.cfg("inpaint_invert_mask", bool) + disable_base_and_max_size = self.cfg("disable_sddebz_highres", bool) + params.update(self.official_api_common_params( + has_selection, + resized_width if not disable_base_and_max_size else width, + resized_height if not disable_base_and_max_size else height, + controlnet_src_imgs + )) + params.update( + prompt=fix_prompt(self.cfg("inpaint_prompt", str)), + negative_prompt=fix_prompt(self.cfg("inpaint_negative_prompt", str)), + sampler_name=self.cfg("inpaint_sampler", str), + steps=self.cfg("inpaint_steps", int), + cfg_scale=self.cfg("inpaint_cfg_scale", float), + seed=seed, + denoising_strength=self.cfg("inpaint_denoising_strength", float), + script_name=ext_name if ext_name != "None" else None, + script_args=ext_args if ext_name != "None" else [], + inpainting_mask_invert=0 if not invert_mask else 1, + inpainting_fill=fill, + mask_blur=0, + inpaint_full_res=False + #not sure what's the equivalent of mask weight for official API + ) + + params["override_settings"]["return_grid"] = False + + url = get_url(self.cfg, prefix=OFFICIAL_ROUTE_PREFIX) + self.post("img2img", params, cb, base_url=url) + def post_upscale(self, cb, src_img): params = ( { @@ -397,6 +642,61 @@ def post_upscale(self, cb, src_img): ) self.post("upscale", params, cb) + def post_official_api_upscale_postprocess(self, cb, src_imgs: list, width, height): + """Uses official API. Intended for finalizing img2img pipeline.""" + + params = dict( + resize_mode=1, + show_extras_results=False, + gfpgan_visibility=0, + codeformer_visibility=0, + codeformer_weight=0, + upscaling_resize=1, + upscaling_resize_w=width, + upscaling_resize_h=height, + upscaling_crop=True, + upscaler_1=self.cfg("upscaler_name", str), + upscaler_2="None", # Todo: would be nice to support blended upscalers + extras_upscaler_2_visibility=0, + upscale_first=False, + imageList=[] + ) + + for img in src_imgs: + params["imageList"].append({ + "data": img, + "name": "example_image" + }) + + url = get_url(self.cfg, prefix=OFFICIAL_ROUTE_PREFIX) + self.post("extra-batch-images", params, cb, base_url=url) + + def post_controlnet_preview(self, cb, src_img, width, height): + def get_pixel_perfect_preprocessor_resolution(): + if self.cfg("disable_sddebz_highres", bool): + return min(width, height) + + resized_width, resized_height = calculate_resized_image_dimensions( + self.cfg("sd_base_size", int), self.cfg("sd_max_size", int), width, height + ) + return min(resized_width, resized_height) + + unit = self.cfg("controlnet_unit", str) + preprocessor_resolution = get_pixel_perfect_preprocessor_resolution() if self.cfg(f"controlnet{unit}_pixel_perfect", bool) \ + else self.cfg(f"controlnet{unit}_preprocessor_resolution", int) + + params = ( + { + "controlnet_module": self.cfg(f"controlnet{unit}_preprocessor", str), + "controlnet_input_images": [img_to_b64(src_img)], + "controlnet_processor_res": preprocessor_resolution, + "controlnet_threshold_a": self.cfg(f"controlnet{unit}_threshold_a", float), + "controlnet_threshold_b": self.cfg(f"controlnet{unit}_threshold_b", float) + } #Not sure if it's necessary to make the just_use_yaml validation here + ) + url = get_url(self.cfg, prefix=CONTROLNET_ROUTE_PREFIX) + self.post("detect", params, cb, url) + def post_interrupt(self, cb): # get official API url url = get_url(self.cfg, prefix=OFFICIAL_ROUTE_PREFIX) @@ -405,4 +705,4 @@ def post_interrupt(self, cb): def get_progress(self, cb): # get official API url url = get_url(self.cfg, prefix=OFFICIAL_ROUTE_PREFIX) - self.get("progress", cb, base_url=url) + self.get("progress", cb, base_url=url) \ No newline at end of file diff --git a/frontends/krita/krita_diff/defaults.py b/frontends/krita/krita_diff/defaults.py index 9c3f100d..5d5597d0 100644 --- a/frontends/krita/krita_diff/defaults.py +++ b/frontends/krita/krita_diff/defaults.py @@ -25,6 +25,8 @@ THREADED = True ROUTE_PREFIX = "/sdapi/interpause/" OFFICIAL_ROUTE_PREFIX = "/sdapi/v1/" +CONTROLNET_ROUTE_PREFIX = "/controlnet/" +CONTROLNET_ROUTE_PREFIX = "/controlnet/" # error messages ERR_MISSING_CONFIG = "Report this bug, developer missed out a config key somewhere." @@ -39,8 +41,135 @@ TAB_IMG2IMG = "krita_diff_img2img" TAB_INPAINT = "krita_diff_inpaint" TAB_UPSCALE = "krita_diff_upscale" +TAB_CONTROLNET = "krita_diff_controlnet" +TAB_CONTROLNET = "krita_diff_controlnet" TAB_PREVIEW = "krita_diff_preview" +# controlnet +CONTROLNET_PREPROCESSOR_SETTINGS = { + "canny": { + "resolution_label": "Annotator resolution", + "threshold_a_label": "Canny low threshold", + "threshold_b_label": "Canny high threshold", + "threshold_a_value": 100, + "threshold_b_value": 200, + "threshold_a_min_value": 1, + "threshold_a_max_value": 255, + "threshold_b_min_value": 1, + "threshold_b_max_value": 255 + }, + "depth_leres": { + "resolution_label": "LeReS resolution", + "threshold_a_label": "Remove near %", + "threshold_b_label": "Remove background %", + "threshold_a_min_value": 0, + "threshold_a_max_value": 100, + "threshold_b_min_value": 0, + "threshold_b_max_value": 100 + }, + "depth_leres++": { + "resolution_label": "LeReS resolution", + "threshold_a_label": "Remove near %", + "threshold_b_label": "Remove background %", + "threshold_a_min_value": 0, + "threshold_a_max_value": 100, + "threshold_b_min_value": 0, + "threshold_b_max_value": 100 + }, + "mediapipe_face": { + "threshold_a_label": "Max Faces", + "threshold_b_label": "Min Face Confidence", + "threshold_a_min_value": 1, + "threshold_a_max_value": 10, + "threshold_b_min_value": 0.01, + "threshold_b_max_value": 1, + "threshold_b_step": 0.01 + }, + "normal_midas": { + "threshold_a_label": "Normal background threshold", + "threshold_a_value": 0.4, + "threshold_a_min_value": 0, + "threshold_a_max_value": 1, + "threshold_step": 0.01 + }, + "reference_adain": { + "threshold_a_label": "Style Fidelity (only for \"Balanced\" mode)", + "threshold_a_value": 0.5, + "threshold_a_min_value": 0, + "threshold_a_max_value": 1, + "threshold_step": 0.01 + }, + "reference_adain+attn": { + "threshold_a_label": "Style Fidelity (only for \"Balanced\" mode)", + "threshold_a_value": 0.5, + "threshold_a_min_value": 0, + "threshold_a_max_value": 1, + "threshold_step": 0.01 + }, + "reference_only": { + "threshold_a_label": "Style Fidelity (only for \"Balanced\" mode)", + "threshold_a_value": 0.5, + "threshold_a_min_value": 0, + "threshold_a_max_value": 1, + "threshold_step": 0.01 + }, + "scribble_xdog": { + "threshold_a_label": "XDoG Threshold", + "threshold_a_value": 32, + "threshold_a_min_value": 1, + "threshold_a_max_value": 64 + }, + "threshold": { + "threshold_a_label": "Binarization Threshold", + "threshold_a_value": 127, + "threshold_a_min_value": 0, + "threshold_a_max_value": 255 + }, + "tile_resample": { + "threshold_a_label": "Down Sampling Rate", + "threshold_a_value": 1, + "threshold_a_min_value": 1, + "threshold_a_max_value": 8, + "threshold_step": 0.01 + }, + "hed": { + "resolution_label": "HED resolution", + }, + "mlsd": { + "resolution_label": "Hough resolution", + "threshold_a_label": "Hough value threshold (MLSD)", + "threshold_b_label": "Hough distance threshold (MLSD)", + "threshold_a_value": 0.1, + "threshold_b_value": 0.1, + "threshold_a_min_value": 0.01, + "threshold_b_max_value": 2, + "threshold_a_min_value": 0.01, + "threshold_b_max_value": 20, + "threshold_step": 0.01 + }, + "normal_map": { + "threshold_a_label": "Normal background threshold", + "threshold_a_value": 0.4, + "threshold_a_min_value": 0, + "threshold_a_max_value": 1, + "threshold_step": 0.01 + }, + "openpose": {}, + "openpose_hand": {}, + "clip_vision": {}, + "color": {}, + "pidinet": {}, + "scribble": {}, + "fake_scribble": { + "resolution_label": "HED resolution", + }, + "segmentation": {}, + "binary": { + "threshold_a_label": "Binary threshold", + "threshold_a_min_value": 0, + "threshold_a_max_value": 255, + } +} @dataclass(frozen=True) class Defaults: @@ -128,5 +257,150 @@ class Defaults: upscale_upscaler_name: str = "None" upscale_downscale_first: bool = False + controlnet_unit: str = "0" + controlnet_unit_list: List[str] = field(default_factory=lambda: list(str(i) for i in range(10))) + controlnet_preprocessor_list: List[str] = field(default_factory=lambda: [ERROR_MSG]) + controlnet_model_list: List[str] = field(default_factory=lambda: [ERROR_MSG]) + controlnet_control_mode_list: List[str] = field(default_factory=lambda: ["Balanced", "My prompt is more important", "ControlNet is more important"]) + + controlnet0_enable: bool = False + controlnet0_low_vram: bool = False + controlnet0_pixel_perfect: bool = False + controlnet0_preprocessor: str = "None" + controlnet0_model: str = "None" + controlnet0_weight: float = 1.0 + controlnet0_guidance_start: float = 0 + controlnet0_guidance_end: float = 1 + controlnet0_preprocessor_resolution: int = 512 + controlnet0_threshold_a: float = 0 + controlnet0_threshold_b: float = 0 + controlnet0_input_image: str = "" + controlnet0_control_mode: str = "Balanced" + + controlnet1_enable: bool = False + controlnet1_low_vram: bool = False + controlnet1_pixel_perfect: bool = False + controlnet1_preprocessor: str = "None" + controlnet1_model: str = "None" + controlnet1_weight: float = 1.0 + controlnet1_guidance_start: float = 0 + controlnet1_guidance_end: float = 1 + controlnet1_preprocessor_resolution: int = 512 + controlnet1_threshold_a: float = 0 + controlnet1_threshold_b: float = 0 + controlnet1_input_image: str = "" + controlnet1_control_mode: str = "Balanced" + + controlnet2_enable: bool = False + controlnet2_low_vram: bool = False + controlnet2_pixel_perfect: bool = False + controlnet2_preprocessor: str = "None" + controlnet2_model: str = "None" + controlnet2_weight: float = 1.0 + controlnet2_guidance_start: float = 0 + controlnet2_guidance_end: float = 1 + controlnet2_preprocessor_resolution: int = 512 + controlnet2_threshold_a: float = 0 + controlnet2_threshold_b: float = 0 + controlnet2_input_image: str = "" + controlnet2_control_mode: str = "Balanced" + + controlnet3_enable: bool = False + controlnet3_low_vram: bool = False + controlnet3_pixel_perfect: bool = False + controlnet3_preprocessor: str = "None" + controlnet3_model: str = "None" + controlnet3_weight: float = 1.0 + controlnet3_guidance_start: float = 0 + controlnet3_guidance_end: float = 1 + controlnet3_preprocessor_resolution: int = 512 + controlnet3_threshold_a: float = 0 + controlnet3_threshold_b: float = 0 + controlnet3_input_image: str = "" + controlnet3_control_mode: str = "Balanced" + + controlnet4_enable: bool = False + controlnet4_low_vram: bool = False + controlnet4_pixel_perfect: bool = False + controlnet4_preprocessor: str = "None" + controlnet4_model: str = "None" + controlnet4_weight: float = 1.0 + controlnet4_guidance_start: float = 0 + controlnet4_guidance_end: float = 1 + controlnet4_preprocessor_resolution: int = 512 + controlnet4_threshold_a: float = 0 + controlnet4_threshold_b: float = 0 + controlnet4_input_image: str = "" + controlnet4_control_mode: str = "Balanced" + + controlnet5_enable: bool = False + controlnet5_low_vram: bool = False + controlnet5_pixel_perfect: bool = False + controlnet5_preprocessor: str = "None" + controlnet5_model: str = "None" + controlnet5_weight: float = 1.0 + controlnet5_guidance_start: float = 0 + controlnet5_guidance_end: float = 1 + controlnet5_preprocessor_resolution: int = 512 + controlnet5_threshold_a: float = 0 + controlnet5_threshold_b: float = 0 + controlnet5_input_image: str = "" + controlnet5_control_mode: str = "Balanced" + + controlnet6_enable: bool = False + controlnet6_low_vram: bool = False + controlnet6_pixel_perfect: bool = False + controlnet6_preprocessor: str = "None" + controlnet6_model: str = "None" + controlnet6_weight: float = 1.0 + controlnet6_guidance_start: float = 0 + controlnet6_guidance_end: float = 1 + controlnet6_preprocessor_resolution: int = 512 + controlnet6_threshold_a: float = 0 + controlnet6_threshold_b: float = 0 + controlnet6_input_image: str = "" + controlnet6_control_mode: str = "Balanced" + + controlnet7_enable: bool = False + controlnet7_low_vram: bool = False + controlnet7_pixel_perfect: bool = False + controlnet7_preprocessor: str = "None" + controlnet7_model: str = "None" + controlnet7_weight: float = 1.0 + controlnet7_guidance_start: float = 0 + controlnet7_guidance_end: float = 1 + controlnet7_preprocessor_resolution: int = 512 + controlnet7_threshold_a: float = 0 + controlnet7_threshold_b: float = 0 + controlnet7_input_image: str = "" + controlnet7_control_mode: str = "Balanced" + + controlnet8_enable: bool = False + controlnet8_low_vram: bool = False + controlnet8_pixel_perfect: bool = False + controlnet8_preprocessor: str = "None" + controlnet8_model: str = "None" + controlnet8_weight: float = 1.0 + controlnet8_guidance_start: float = 0 + controlnet8_guidance_end: float = 1 + controlnet8_preprocessor_resolution: int = 512 + controlnet8_threshold_a: float = 0 + controlnet8_threshold_b: float = 0 + controlnet8_input_image: str = "" + controlnet8_control_mode: str = "Balanced" + + controlnet9_enable: bool = False + controlnet9_low_vram: bool = False + controlnet9_pixel_perfect: bool = False + controlnet9_preprocessor: str = "None" + controlnet9_model: str = "None" + controlnet9_weight: float = 1.0 + controlnet9_guidance_start: float = 0 + controlnet9_guidance_end: float = 1 + controlnet9_preprocessor_resolution: int = 512 + controlnet9_threshold_a: float = 0 + controlnet9_threshold_b: float = 0 + controlnet9_input_image: str = "" + controlnet9_control_mode: str = "Balanced" DEFAULTS = Defaults() diff --git a/frontends/krita/krita_diff/pages/__init__.py b/frontends/krita/krita_diff/pages/__init__.py index aba4901a..7715bc20 100644 --- a/frontends/krita/krita_diff/pages/__init__.py +++ b/frontends/krita/krita_diff/pages/__init__.py @@ -5,3 +5,4 @@ from .preview import PreviewPage from .txt2img import Txt2ImgPage from .upscale import UpscalePage +from .controlnet import ControlNetPage \ No newline at end of file diff --git a/frontends/krita/krita_diff/pages/config.py b/frontends/krita/krita_diff/pages/config.py index 9060c58d..a9a813bf 100644 --- a/frontends/krita/krita_diff/pages/config.py +++ b/frontends/krita/krita_diff/pages/config.py @@ -153,6 +153,7 @@ def cfg_connect(self): self.base_url.textChanged.connect(partial(script.cfg.set, "base_url")) # NOTE: this triggers on every keystroke; theres no focus lost signal... self.base_url.textChanged.connect(lambda: script.action_update_config()) + self.base_url.textChanged.connect(lambda: script.action_update_controlnet_config()) self.base_url_reset.released.connect( lambda: self.base_url.setText(DEFAULTS.base_url) ) @@ -178,6 +179,7 @@ def restore_defaults(): script.cfg.set("first_setup", False) # retrieve list of available stuff again script.action_update_config() + script.action_update_controlnet_config() self.refresh_btn.released.connect(lambda: script.action_update_config()) self.restore_defaults.released.connect(restore_defaults) diff --git a/frontends/krita/krita_diff/pages/controlnet.py b/frontends/krita/krita_diff/pages/controlnet.py new file mode 100644 index 00000000..b5c570b2 --- /dev/null +++ b/frontends/krita/krita_diff/pages/controlnet.py @@ -0,0 +1,351 @@ +from krita import ( + QApplication, + QPixmap, + QImage, + QPushButton, + QWidget, + QVBoxLayout, + QHBoxLayout, + QStackedLayout, + Qt +) + +from functools import partial +from ..defaults import CONTROLNET_PREPROCESSOR_SETTINGS +from ..script import script +from ..widgets import ( + QLabel, + StatusBar, + ImageLoaderLayout, + QCheckBox, + TipsLayout, + QComboBoxLayout, + QSpinBoxLayout +) +from ..utils import img_to_b64, b64_to_img + +class ControlNetPage(QWidget): + name = "ControlNet" + + def __init__(self, *args, **kwargs): + super(ControlNetPage, self).__init__(*args, **kwargs) + self.status_bar = StatusBar() + self.controlnet_unit = QComboBoxLayout( + script.cfg, "controlnet_unit_list", "controlnet_unit", label="Unit:" + ) + self.controlnet_unit_layout_list = list(ControlNetUnitSettings(i) + for i in range(len(script.cfg("controlnet_unit_list")))) + + self.units_stacked_layout = QStackedLayout() + + for unit_layout in self.controlnet_unit_layout_list: + self.units_stacked_layout.addWidget(unit_layout) + + layout = QVBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self.status_bar) + layout.addLayout(self.controlnet_unit) + layout.addLayout(self.units_stacked_layout) + self.setLayout(layout) + + def controlnet_unit_changed(self, selected: str): + self.units_stacked_layout.setCurrentIndex(int(selected)) + + def cfg_init(self): + self.controlnet_unit.cfg_init() + + for controlnet_unit_layout in self.controlnet_unit_layout_list: + controlnet_unit_layout.cfg_init() + + self.controlnet_unit_changed(self.controlnet_unit.qcombo.currentText()) + + def cfg_connect(self): + self.controlnet_unit.cfg_connect() + + for controlnet_unit_layout in self.controlnet_unit_layout_list: + controlnet_unit_layout.cfg_connect() + + self.controlnet_unit.qcombo.currentTextChanged.connect(self.controlnet_unit_changed) + script.status_changed.connect(lambda s: self.status_bar.set_status(s)) + +class ControlNetUnitSettings(QWidget): + def __init__(self, cfg_unit_number: int = 0, *args, **kwargs): + super(ControlNetUnitSettings, self).__init__(*args, **kwargs) + self.unit = cfg_unit_number + self.preview_result = QPixmap() #This will help us to copy to clipboard the image with original dimensions. + + #Top checkbox + self.enable = QCheckBox( + script.cfg, f"controlnet{self.unit}_enable", f"Enable ControlNet {self.unit}" + ) + + self.image_loader = ImageLoaderLayout() + input_image = script.cfg(f"controlnet{self.unit}_input_image", str) + self.image_loader.preview.setPixmap( + QPixmap.fromImage(b64_to_img(input_image) if input_image else QImage()) + ) + + #Main settings + self.low_vram = QCheckBox( + script.cfg, f"controlnet{self.unit}_low_vram", "Low VRAM" + ) + + self.pixel_perfect = QCheckBox( + script.cfg, f"controlnet{self.unit}_pixel_perfect", "Pixel Perfect" + ) + + #Tips + self.tips = TipsLayout( + ["Invert colors if your image has white background.", + "Selection will be used as input if no image has been uploaded or pasted.", + "Remember to set multi-controlnet in the backend as well if you want to use more than one unit.", + "Enable pixel perfect if you want the preprocessor to automatically adjust to the selection size (respects base/max size)"] + ) + + #Preprocessor list + self.preprocessor_layout = QComboBoxLayout( + script.cfg, "controlnet_preprocessor_list", f"controlnet{self.unit}_preprocessor", label="Preprocessor:" + ) + + #Model list + self.model_layout = QComboBoxLayout( + script.cfg, "controlnet_model_list", f"controlnet{self.unit}_model", label="Model:" + ) + + #Refresh button + self.refresh_button = QPushButton("Refresh") + + self.weight_layout = QSpinBoxLayout( + script.cfg, f"controlnet{self.unit}_weight", label="Weight:", min=0, max=2, step=0.05 + ) + self.guidance_start_layout = QSpinBoxLayout( + script.cfg, f"controlnet{self.unit}_guidance_start", label="Guidance start:", min=0, max=1, step=0.01 + ) + self.guidance_end_layout = QSpinBoxLayout( + script.cfg, f"controlnet{self.unit}_guidance_end", label="Guidance end:", min=0, max=1, step=0.01 + ) + + self.control_mode = QComboBoxLayout( + script.cfg, "controlnet_control_mode_list", f"controlnet{self.unit}_control_mode", label="Control mode:" + ) + + #Preprocessor settings + self.annotator_resolution = QSpinBoxLayout( + script.cfg, + f"controlnet{self.unit}_preprocessor_resolution", + label="Preprocessor resolution:", + min=64, + max=2048, + step=1 + ) + self.threshold_a = QSpinBoxLayout( + script.cfg, + f"controlnet{self.unit}_threshold_a", + label="Threshold A:", + min=1, + max=255, + step=1, + always_float=True + ) + self.threshold_b = QSpinBoxLayout( + script.cfg, + f"controlnet{self.unit}_threshold_b", + label="Threshold B:", + min=1, + max=255, + step=1, + always_float=True + ) + + #Preview annotator + self.annotator_preview = QLabel() + self.annotator_preview.setAlignment(Qt.AlignCenter) + self.annotator_preview_button = QPushButton("Preview annotator") + self.annotator_clear_button = QPushButton("Clear preview") + self.copy_result_button = QPushButton("Copy result to clipboard") + + main_settings_layout_2 = QHBoxLayout() + main_settings_layout_2.addWidget(self.low_vram) + main_settings_layout_2.addWidget(self.pixel_perfect) + + guidance_layout = QHBoxLayout() + guidance_layout.addLayout(self.guidance_start_layout) + guidance_layout.addLayout(self.guidance_end_layout) + + threshold_layout = QHBoxLayout() + threshold_layout.addLayout(self.threshold_a) + threshold_layout.addLayout(self.threshold_b) + + layout = QVBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self.enable) + layout.addLayout(self.image_loader) + layout.addLayout(self.tips) + layout.addLayout(main_settings_layout_2) + layout.addLayout(self.preprocessor_layout) + layout.addLayout(self.model_layout) + layout.addWidget(self.refresh_button) + layout.addLayout(self.weight_layout) + layout.addLayout(guidance_layout) + layout.addLayout(self.control_mode) + layout.addLayout(self.annotator_resolution) + layout.addLayout(threshold_layout) + layout.addWidget(self.annotator_preview) + layout.addWidget(self.annotator_preview_button) + layout.addWidget(self.copy_result_button) + layout.addWidget(self.annotator_clear_button) + layout.addStretch() + + self.setLayout(layout) + + self.cfg_init() + self.set_preprocessor_options(self.preprocessor_layout.qcombo.currentText()) + + def set_preprocessor_options(self, selected: str): + if selected in CONTROLNET_PREPROCESSOR_SETTINGS: + self.show_preprocessor_options() + self.annotator_resolution.qlabel.setText(CONTROLNET_PREPROCESSOR_SETTINGS[selected]["resolution_label"] \ + if "resolution_label" in CONTROLNET_PREPROCESSOR_SETTINGS[selected] else "Preprocessor resolution:") + + if "threshold_a_label" in CONTROLNET_PREPROCESSOR_SETTINGS[selected]: + self.threshold_a.qlabel.show() + self.threshold_a.qspin.show() + self.threshold_a.qlabel.setText(CONTROLNET_PREPROCESSOR_SETTINGS[selected]["threshold_a_label"]) + self.threshold_a.qspin.setMinimum(CONTROLNET_PREPROCESSOR_SETTINGS[selected]["threshold_a_min_value"] \ + if "threshold_a_min_value" in CONTROLNET_PREPROCESSOR_SETTINGS[selected] else 0) + self.threshold_a.qspin.setMaximum(CONTROLNET_PREPROCESSOR_SETTINGS[selected]["threshold_a_max_value"] \ + if "threshold_a_max_value" in CONTROLNET_PREPROCESSOR_SETTINGS[selected] else 0) + self.threshold_a.qspin.setValue(CONTROLNET_PREPROCESSOR_SETTINGS[selected]["threshold_a_value"] \ + if "threshold_a_value" in CONTROLNET_PREPROCESSOR_SETTINGS[selected] else self.threshold_a.qspin.minimum()) + self.threshold_a.qspin.setSingleStep(CONTROLNET_PREPROCESSOR_SETTINGS[selected]["threshold_step"] \ + if "threshold_step" in CONTROLNET_PREPROCESSOR_SETTINGS[selected] else 1) + else: + self.threshold_a.qlabel.hide() + self.threshold_a.qspin.hide() + + if "threshold_b_label" in CONTROLNET_PREPROCESSOR_SETTINGS[selected]: + self.threshold_b.qlabel.show() + self.threshold_b.qspin.show() + self.threshold_b.qlabel.setText(CONTROLNET_PREPROCESSOR_SETTINGS[selected]["threshold_b_label"]) + self.threshold_b.qspin.setMinimum(CONTROLNET_PREPROCESSOR_SETTINGS[selected]["threshold_b_min_value"] \ + if "threshold_b_min_value" in CONTROLNET_PREPROCESSOR_SETTINGS[selected] else 0) + self.threshold_b.qspin.setMaximum(CONTROLNET_PREPROCESSOR_SETTINGS[selected]["threshold_b_max_value"] \ + if "threshold_b_max_value" in CONTROLNET_PREPROCESSOR_SETTINGS[selected] else 0) + self.threshold_b.qspin.setValue(CONTROLNET_PREPROCESSOR_SETTINGS[selected]["threshold_b_value"] \ + if "threshold_b_value" in CONTROLNET_PREPROCESSOR_SETTINGS[selected] else self.threshold_b.qspin.minimum()) + self.threshold_b.qspin.setSingleStep(CONTROLNET_PREPROCESSOR_SETTINGS[selected]["threshold_b_step"] \ + if "threshold_b_step" in CONTROLNET_PREPROCESSOR_SETTINGS[selected] else 1) + else: + self.threshold_b.qlabel.hide() + self.threshold_b.qspin.hide() + else: + self.hide_preprocessor_options(selected) + + def hide_preprocessor_options(self, selected: str): + #Hide all annotator settings if no annotator chosen. + #if there is an annotator that hasn't been listed in defaults, + #just show resolution option. Users may be able to play + #with new unsupported annotators, but they may or not work. + if selected == "none": + self.annotator_resolution.qlabel.hide() + self.annotator_resolution.qspin.hide() + + self.threshold_a.qlabel.hide() + self.threshold_a.qspin.hide() + self.threshold_b.qlabel.hide() + self.threshold_b.qspin.hide() + + def show_preprocessor_options(self): + self.annotator_resolution.qlabel.show() + self.annotator_resolution.qspin.show() + self.threshold_a.qlabel.show() + self.threshold_a.qspin.show() + self.threshold_b.qlabel.show() + self.threshold_b.qspin.show() + + def enable_changed(self, state): + if state == 1 or state == 2: + script.action_update_controlnet_config() + + def image_loaded(self): + image = self.image_loader.preview.pixmap().toImage().convertToFormat(QImage.Format_RGBA8888) + script.cfg.set(f"controlnet{self.unit}_input_image", img_to_b64(image)) + + def annotator_preview_received(self, pixmap): + self.preview_result = pixmap + if pixmap.width() > self.annotator_preview.width(): + pixmap = pixmap.scaledToWidth(self.annotator_preview.width(), Qt.SmoothTransformation) + self.annotator_preview.setPixmap(pixmap) + + def annotator_clear_button_released(self): + self.annotator_preview.setPixmap(QPixmap()) + self.preview_result = QPixmap() + + def copy_result_released(self): + if self.preview_result: + clipboard = QApplication.clipboard() + clipboard.setImage(self.preview_result.toImage()) + + def hide_or_show_preprocessor_resolution(self, pixel_perfect): + if pixel_perfect: + self.annotator_resolution.qlabel.hide() + self.annotator_resolution.qspin.hide() + else: + self.annotator_resolution.qlabel.show() + self.annotator_resolution.qspin.show() + + def cfg_init(self): + self.enable.cfg_init() + self.low_vram.cfg_init() + self.pixel_perfect.cfg_init() + self.preprocessor_layout.cfg_init() + self.model_layout.cfg_init() + self.weight_layout.cfg_init() + self.guidance_start_layout.cfg_init() + self.guidance_end_layout.cfg_init() + self.control_mode.cfg_init() + self.annotator_resolution.cfg_init() + self.threshold_a.cfg_init() + self.threshold_b.cfg_init() + + self.hide_or_show_preprocessor_resolution(self.pixel_perfect.isChecked()) + + if self.preprocessor_layout.qcombo.currentText() == "none": + self.annotator_preview_button.setEnabled(False) + else: + self.annotator_preview_button.setEnabled(True) + + def cfg_connect(self): + self.enable.cfg_connect() + self.low_vram.cfg_connect() + self.pixel_perfect.cfg_connect() + self.preprocessor_layout.cfg_connect() + self.model_layout.cfg_connect() + self.weight_layout.cfg_connect() + self.guidance_start_layout.cfg_connect() + self.guidance_end_layout.cfg_connect() + self.control_mode.cfg_connect() + self.annotator_resolution.cfg_connect() + self.threshold_a.cfg_connect() + self.threshold_b.cfg_connect() + self.enable.stateChanged.connect(self.enable_changed) + self.image_loader.import_button.released.connect(self.image_loaded) + self.image_loader.paste_button.released.connect(self.image_loaded) + self.image_loader.clear_button.released.connect( + partial(script.cfg.set, f"controlnet{self.unit}_input_image", "") + ) + self.pixel_perfect.stateChanged.connect( + lambda: self.hide_or_show_preprocessor_resolution(self.pixel_perfect.isChecked()) + ) + self.preprocessor_layout.qcombo.currentTextChanged.connect(self.set_preprocessor_options) + self.preprocessor_layout.qcombo.currentTextChanged.connect( + lambda: self.annotator_preview_button.setEnabled(False) if + self.preprocessor_layout.qcombo.currentText() == "none" else self.annotator_preview_button.setEnabled(True) + ) + self.refresh_button.released.connect(lambda: script.action_update_controlnet_config()) + self.annotator_preview_button.released.connect( + lambda: script.action_preview_controlnet_annotator() + ) + self.copy_result_button.released.connect(self.copy_result_released) + self.annotator_clear_button.released.connect(lambda: self.annotator_preview.setPixmap(QPixmap())) + script.controlnet_preview_annotator_received.connect(self.annotator_preview_received) \ No newline at end of file diff --git a/frontends/krita/krita_diff/script.py b/frontends/krita/krita_diff/script.py index 069eb27f..19031627 100644 --- a/frontends/krita/krita_diff/script.py +++ b/frontends/krita/krita_diff/script.py @@ -8,11 +8,15 @@ Krita, Node, QImage, + QColor, + QPainter, QObject, + QPixmap, + QRect, Qt, QTimer, Selection, - pyqtSignal, + pyqtSignal ) from .client import Client @@ -63,6 +67,7 @@ class Script(QObject): status_changed = pyqtSignal(str) config_updated = pyqtSignal() progress_update = pyqtSignal(object) + controlnet_preview_annotator_received = pyqtSignal(QPixmap) def __init__(self): super(Script, self).__init__() @@ -165,25 +170,34 @@ def get_selection_image(self) -> QImage: QImage.Format_RGBA8888, ).rgbSwapped() - def get_mask_image(self) -> Union[QImage, None]: + def get_mask_image(self, using_official_api) -> Union[QImage, None]: """QImage of mask layer for inpainting""" if self.node.type() not in {"paintlayer", "filelayer"}: assert False, "Please select a valid layer to use as inpaint mask!" elif self.node in self._inserted_layers: assert False, "Selected layer was generated. Copy the layer if sure you want to use it as inpaint mask." - return QImage( + mask = QImage( self.node.pixelData(self.x, self.y, self.width, self.height), self.width, self.height, QImage.Format_RGBA8888, - ).rgbSwapped() + ) - def img_inserter(self, x, y, width, height, group=False): + if using_official_api: + # Official API requires a black and white mask. + # Fastest way to do this: Convert to 1 channel alpha, tell it that + # it's grayscale, convert that to RGBA. + mask = mask.convertToFormat(QImage.Format_Alpha8) + mask.reinterpretAsFormat(QImage.Format_Grayscale8) + mask = mask.convertToFormat(QImage.Format_RGBA8888) + + return mask.rgbSwapped() + + def img_inserter(self, x, y, width, height, inpaint=False, glayer=None): """Return frozen image inserter to insert images as new layer.""" # Selection may change before callback, so freeze selection region has_selection = self.selection is not None - glayer = self.doc.createGroupLayer("Unnamed Group") if group else None def create_layer(name: str): """Create new layer in document or group""" @@ -191,11 +205,10 @@ def create_layer(name: str): parent = self.doc.rootNode() if glayer: glayer.addChildNode(layer, None) - parent.addChildNode(glayer, None) else: parent.addChildNode(layer, None) return layer - + def insert(layer_name, enc): nonlocal x, y, width, height, has_selection print(f"inserting layer {layer_name}") @@ -210,11 +223,11 @@ def insert(layer_name, enc): f"image created: {image}, {image.width()}x{image.height()}, depth: {image.depth()}, format: {image.format()}" ) - # NOTE: Scaling is usually done by backend (although I am reconsidering this) - # The scaling here is for SD Upscale or Upscale on a selection region rather than whole image + # NOTE: Scaling must be done by the frontend when using the official API. + # The scaling here is for SD Upscale, Upscale on a selection region, or inpainting. # Image won't be scaled down ONLY if there is no selection; i.e. selecting whole image will scale down, # not selecting anything won't scale down, leading to the canvas being resized afterwards - if has_selection and (image.width() != width or image.height() != height): + if (has_selection or inpaint) and (image.width() != width or image.height() != height): print(f"Rescaling image to selection: {width}x{height}") image = image.scaled( width, height, transformMode=Qt.SmoothTransformation @@ -246,14 +259,35 @@ def insert(layer_name, enc): print(f"inserting at x: {x}, y: {y}, w: {width}, h: {height}") layer.setPixelData(ba, x, y, width, height) self._inserted_layers.append(layer) + return layer - return insert, glayer + return insert + + def check_controlnet_enabled(self): + for i in range(len(self.cfg("controlnet_unit_list", "QStringList"))): + if self.cfg(f"controlnet{i}_enable", bool): + return True + + def get_controlnet_input_images(self, selected): + input_images = dict() + + for i in range(len(self.cfg("controlnet_unit_list", "QStringList"))): + if self.cfg(f"controlnet{i}_enable", bool): + input_image = b64_to_img(self.cfg(f"controlnet{i}_input_image", str)) if \ + self.cfg(f"controlnet{i}_input_image", str) else selected + + input_images.update({f"{i}": input_image}) + + return input_images def apply_txt2img(self): # freeze selection region - insert, glayer = self.img_inserter( - self.x, self.y, self.width, self.height, not self.cfg("no_groups", bool) + controlnet_enabled = self.check_controlnet_enabled() + glayer = self.doc.createGroupLayer("Unnamed Group") + self.doc.rootNode().addChildNode(glayer, None) + insert = self.img_inserter( + self.x, self.y, self.width, self.height, False, glayer ) mask_trigger = self.transparency_mask_inserter() @@ -261,7 +295,8 @@ def cb(response): if len(self.client.long_reqs) == 1: # last request self.eta_timer.stop() assert response is not None, "Backend Error, check terminal" - outputs = response["outputs"] + #response key varies for official api used for controlnet + outputs = response["outputs"] if not controlnet_enabled else response["images"] glayer_name, layer_names = get_desc_from_resp(response, "txt2img") layers = [ insert(name if name else f"txt2img {i + 1}", output) @@ -276,16 +311,28 @@ def cb(response): mask_trigger(layers) self.eta_timer.start(ETA_REFRESH_INTERVAL) - self.client.post_txt2img( - cb, self.width, self.height, self.selection is not None - ) + + if controlnet_enabled: + sel_image = self.get_selection_image() + self.client.post_official_api_txt2img( + cb, self.width, self.height, self.selection is not None, + self.get_controlnet_input_images(sel_image) + ) + else: + self.client.post_txt2img( + cb, self.width, self.height, self.selection is not None + ) def apply_img2img(self, is_inpaint): - insert, glayer = self.img_inserter( - self.x, self.y, self.width, self.height, not self.cfg("no_groups", bool) - ) + controlnet_enabled = self.check_controlnet_enabled() + mask_trigger = self.transparency_mask_inserter() - mask_image = self.get_mask_image() + mask_image = self.get_mask_image(controlnet_enabled) if is_inpaint else None + glayer = self.doc.createGroupLayer("Unnamed Group") + self.doc.rootNode().addChildNode(glayer, None) + insert = self.img_inserter( + self.x, self.y, self.width, self.height, is_inpaint, glayer + ) path = os.path.join(self.cfg("sample_path", str), f"{int(time.time())}.png") mask_path = os.path.join( @@ -296,6 +343,7 @@ def apply_img2img(self, is_inpaint): save_img(mask_image, mask_path) # auto-hide mask layer before getting selection image self.node.setVisible(False) + self.controlnet_transparency_mask_inserter(glayer, mask_image) self.doc.refreshProjection() sel_image = self.get_selection_image() @@ -303,11 +351,42 @@ def apply_img2img(self, is_inpaint): save_img(sel_image, path) def cb(response): + def cb_upscale(upscale_response): + if len(self.client.long_reqs) == 1: # last request + self.eta_timer.stop() + assert response is not None, "Backend Error, check terminal" + + outputs = upscale_response["images"] + layer_name_prefix = "inpaint" if is_inpaint else "img2img" + glayer_name, layer_names = get_desc_from_resp(response, layer_name_prefix) + layers = [ + insert(name if name else f"{layer_name_prefix} {i + 1}", output) + for output, name, i in zip(outputs, layer_names, itertools.count()) + ] + if self.cfg("hide_layers", bool): + for layer in layers[:-1]: + layer.setVisible(False) + if glayer: + glayer.setName(glayer_name) + self.doc.refreshProjection() + if len(self.client.long_reqs) == 1: # last request self.eta_timer.stop() assert response is not None, "Backend Error, check terminal" - outputs = response["outputs"] + outputs = response["outputs"] if not controlnet_enabled else response["images"] + + if controlnet_enabled: + if min(self.width,self.height) > self.cfg("sd_base_size", int) \ + or max(self.width,self.height) > self.cfg("sd_max_size", int): + # this only handles with base/max size enabled + self.client.post_official_api_upscale_postprocess( + cb_upscale, outputs, self.width, self.height) + else: + # passing response directly to the callback works fine + cb_upscale(response) + return + layer_name_prefix = "inpaint" if is_inpaint else "img2img" glayer_name, layer_names = get_desc_from_resp(response, layer_name_prefix) layers = [ @@ -324,17 +403,89 @@ def cb(response): if not is_inpaint: mask_trigger(layers) - method = self.client.post_inpaint if is_inpaint else self.client.post_img2img self.eta_timer.start() - method( - cb, - sel_image, - mask_image, # is unused by backend in img2img mode - self.selection is not None, - ) + if controlnet_enabled: + if is_inpaint: + self.client.post_official_api_inpaint( + cb, sel_image, mask_image, self.width, self.height, self.selection is not None, + self.get_controlnet_input_images(sel_image)) + else: + self.client.post_official_api_img2img( + cb, sel_image, self.width, self.height, self.selection is not None, + self.get_controlnet_input_images(sel_image)) + else: + method = self.client.post_inpaint if is_inpaint else self.client.post_img2img + method( + cb, + sel_image, + mask_image, # is unused by backend in img2img mode + self.selection is not None, + ) + + def controlnet_transparency_mask_inserter(self, glayer, mask_image): + orig_selection = self.selection.duplicate() if self.selection else None + create_mask = self.cfg("create_mask_layer", bool) + add_mask_action = self.app.action("add_new_transparency_mask") + merge_mask_action = self.app.action("flatten_layer") + + if orig_selection: + sx = orig_selection.x() + sy = orig_selection.y() + sw = orig_selection.width() + sh = orig_selection.height() + else: + sx = 0 + sy = 0 + sw = self.doc.width() + sh = self.doc.height() + + # must convert mask to single channel format + gray_mask = mask_image.convertToFormat(QImage.Format_Grayscale8) + + mw = gray_mask.width() + mh = gray_mask.height() + # crop mask to the actual selection size + crop_rect = QRect((mw - sw)/2,(mh - sh)/2, sw, sh) + crop_mask = gray_mask.copy(crop_rect) + + mask_ba = img_to_ba(crop_mask) + + # Why is sizeInBytes() different from width * height? Just... why? + w = crop_mask.bytesPerLine() + h = crop_mask.sizeInBytes()/w + + mask_selection = Selection() + mask_selection.setPixelData(mask_ba, sx, sy, w, h) + + def apply_mask_when_ready(): + # glayer will be selected when it is done being created + if self.doc.activeNode() == glayer: + self.doc.setSelection(mask_selection) + add_mask_action.trigger() + self.doc.setSelection(orig_selection) + timer.stop() + + timer = QTimer() + timer.timeout.connect(apply_mask_when_ready) + timer.start(0.05) + + def apply_controlnet_preview_annotator(self): + unit = self.cfg("controlnet_unit", str) + if self.cfg(f"controlnet{unit}_input_image"): + image = b64_to_img(self.cfg(f"controlnet{unit}_input_image")) + else: + image = self.get_selection_image() + + def cb(response): + assert response is not None, "Backend Error, check terminal" + output = response["images"][0] + pixmap = QPixmap.fromImage(b64_to_img(output)) + self.controlnet_preview_annotator_received.emit(pixmap) + + self.client.post_controlnet_preview(cb, image, self.width, self.height) def apply_simple_upscale(self): - insert, _ = self.img_inserter(self.x, self.y, self.width, self.height) + insert = self.img_inserter(self.x, self.y, self.width, self.height) sel_image = self.get_selection_image() path = os.path.join(self.cfg("sample_path", str), f"{int(time.time())}.png") @@ -379,7 +530,7 @@ def restore(): self.doc.setActiveNode(layer) self.doc.setSelection(orig_selection) add_mask_action.trigger() - + if create_mask: # collapse transparency mask by default layer.setCollapsed(True) @@ -444,6 +595,31 @@ def action_simple_upscale(self): def action_update_config(self): """Update certain config/state from the backend.""" self.client.get_config() + + def action_update_controlnet_config(self): + """Update controlnet config from the backend.""" + self.client.get_controlnet_config() + + def action_preview_controlnet_annotator(self): + self.status_changed.emit(STATE_WAIT) + self.update_selection() + if not self.doc: + return + self.adjust_selection() + self.apply_controlnet_preview_annotator() + + + def action_update_controlnet_config(self): + """Update controlnet config from the backend.""" + self.client.get_controlnet_config() + + def action_preview_controlnet_annotator(self): + self.status_changed.emit(STATE_WAIT) + self.update_selection() + if not self.doc: + return + self.adjust_selection() + self.apply_controlnet_preview_annotator() def action_interrupt(self): def cb(resp=None): @@ -455,4 +631,4 @@ def action_update_eta(self): self.client.get_progress(self.progress_update.emit) -script = Script() +script = Script() \ No newline at end of file diff --git a/frontends/krita/krita_diff/utils.py b/frontends/krita/krita_diff/utils.py index ff948be1..d12b2a62 100644 --- a/frontends/krita/krita_diff/utils.py +++ b/frontends/krita/krita_diff/utils.py @@ -14,6 +14,7 @@ TAB_SDCOMMON, TAB_TXT2IMG, TAB_UPSCALE, + TAB_CONTROLNET ) @@ -50,14 +51,11 @@ def get_ext_args(ext_cfg: Config, ext_type: str, ext_name: str): args.append(val) return args - -def find_fixed_aspect_ratio( - base_size: int, max_size: int, orig_width: int, orig_height: int +def calculate_resized_image_dimensions( + base_size: int, max_size: int, orig_width: int, orig_height: int ): - """Copy of `krita_server.utils.sddebz_highres_fix()`. - - This is used by `find_optimal_selection_region()` below to adjust the selected region. - """ + """Finds the dimensions of the resized images based on base_size and max_size. + See https://github.com/Interpause/auto-sd-paint-ext#faq for more details.""" def rnd(r, x, z=64): """Scale dimension x with stride z while attempting to preserve aspect ratio r.""" @@ -75,9 +73,19 @@ def rnd(r, x, z=64): width, height = base_size, rnd(1 / ratio, base_size) if height > max_size: width, height = rnd(ratio, max_size), max_size + + return width, height + +def find_fixed_aspect_ratio( + base_size: int, max_size: int, orig_width: int, orig_height: int +): + """Copy of `krita_server.utils.sddebz_highres_fix()`. - return width / height - + This is used by `find_optimal_selection_region()` below to adjust the selected region. + """ + width, height = calculate_resized_image_dimensions(base_size, max_size, orig_width, orig_height) + + return width/height def find_optimal_selection_region( base_size: int, @@ -179,7 +187,7 @@ def img_to_b64(img: QImage): def b64_to_img(enc: str): """Converts base64-encoded string to QImage""" ba = QByteArray.fromBase64(enc.encode("utf-8")) - return QImage.fromData(ba, "PNG") + return QImage.fromData(ba) #Removed explicit format to support other image formats. def bytewise_xor(msg: bytes, key: bytes): @@ -212,6 +220,7 @@ def reset_docker_layout(): docker_ids = { TAB_SDCOMMON, TAB_CONFIG, + TAB_CONTROLNET, TAB_IMG2IMG, TAB_TXT2IMG, TAB_UPSCALE, @@ -232,9 +241,11 @@ def reset_docker_layout(): qmainwindow.addDockWidget(Qt.LeftDockWidgetArea, d) qmainwindow.tabifyDockWidget(dockers[TAB_SDCOMMON], dockers[TAB_CONFIG]) + qmainwindow.tabifyDockWidget(dockers[TAB_SDCOMMON], dockers[TAB_CONTROLNET]) qmainwindow.tabifyDockWidget(dockers[TAB_SDCOMMON], dockers[TAB_PREVIEW]) qmainwindow.tabifyDockWidget(dockers[TAB_TXT2IMG], dockers[TAB_IMG2IMG]) qmainwindow.tabifyDockWidget(dockers[TAB_TXT2IMG], dockers[TAB_INPAINT]) qmainwindow.tabifyDockWidget(dockers[TAB_TXT2IMG], dockers[TAB_UPSCALE]) dockers[TAB_SDCOMMON].raise_() dockers[TAB_INPAINT].raise_() + diff --git a/frontends/krita/krita_diff/widgets/__init__.py b/frontends/krita/krita_diff/widgets/__init__.py index bb46b96d..5b950669 100644 --- a/frontends/krita/krita_diff/widgets/__init__.py +++ b/frontends/krita/krita_diff/widgets/__init__.py @@ -6,3 +6,4 @@ from .spin_box import QSpinBoxLayout from .status_bar import StatusBar from .tips import TipsLayout +from .image_loader import ImageLoaderLayout diff --git a/frontends/krita/krita_diff/widgets/image_loader.py b/frontends/krita/krita_diff/widgets/image_loader.py new file mode 100644 index 00000000..321341ce --- /dev/null +++ b/frontends/krita/krita_diff/widgets/image_loader.py @@ -0,0 +1,48 @@ +from krita import QApplication, QFileDialog, QPixmap, QPushButton, QVBoxLayout, QHBoxLayout, Qt +from ..widgets import QLabel + +class ImageLoaderLayout(QVBoxLayout): + def __init__(self, *args, **kwargs): + super(ImageLoaderLayout, self).__init__(*args, **kwargs) + + self.preview = QLabel() + self.preview.setAlignment(Qt.AlignCenter) + self.import_button = QPushButton('Import image') + self.paste_button = QPushButton('Paste image') + self.clear_button = QPushButton('Clear') + + button_layout = QHBoxLayout() + button_layout.addWidget(self.import_button) + button_layout.addWidget(self.paste_button) + + self.addLayout(button_layout) + self.addWidget(self.clear_button) + self.addWidget(self.preview) + + self.import_button.released.connect(self.load_image) + self.paste_button.released.connect(self.paste_image) + self.clear_button.released.connect(self.clear_image) + + def load_image(self): + file_name, _ = QFileDialog.getOpenFileName(self.import_button, 'Open File', '', 'Image Files (*.png *.jpg *.bmp)') + if file_name: + pixmap = QPixmap(file_name) + + if pixmap.width() > self.preview.width(): + pixmap = pixmap.scaledToWidth(self.preview.width(), Qt.SmoothTransformation) + + self.preview.setPixmap(pixmap) + + def paste_image(self): + pixmap = QPixmap(QApplication.clipboard().pixmap()) + + if pixmap.width() > self.preview.width(): + pixmap = pixmap.scaledToWidth(self.preview.width(), Qt.SmoothTransformation) + + self.preview.setPixmap(pixmap) + + def clear_image(self): + self.preview.setPixmap(QPixmap()) + + + \ No newline at end of file diff --git a/frontends/krita/krita_diff/widgets/spin_box.py b/frontends/krita/krita_diff/widgets/spin_box.py index 89eadf15..88e6f818 100644 --- a/frontends/krita/krita_diff/widgets/spin_box.py +++ b/frontends/krita/krita_diff/widgets/spin_box.py @@ -17,6 +17,7 @@ def __init__( min: Union[int, float] = 0.0, max: Union[int, float] = 1.0, step: Union[int, float] = 0.1, + always_float: bool = False, #Workaround for controlnet threshold spin boxes *args, **kwargs ): @@ -40,6 +41,7 @@ def __init__( self.qlabel = QLabel(field_cfg if label is None else label) is_integer = ( + not always_float and float(step).is_integer() and float(min).is_integer() and float(max).is_integer()