Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 33 additions & 19 deletions demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,50 @@
from diffusers import StableDiffusionPipeline
from free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d



# Model and pipeline setup
model_id = "stabilityai/stable-diffusion-2-1"
# model_id = "./stable-diffusion-2-1"
pip_2_1 = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pip_2_1 = pip_2_1.to("cuda")

# Global variables
prompt_prev = None
sd_options_prev = None
seed_prev = None
sd_image_prev = None

def infer(prompt, sd_options, seed, b1, b2, s1, s2):
"""
Inference function for generating images.

Args:
prompt (str): The input prompt.
sd_options (str): The Stable Diffusion options.
seed (int): The random seed.
b1 (float): The backbone factor of the first stage block of decoder.
b2 (float): The backbone factor of the second stage block of decoder.
s1 (float): The skip factor of the first stage block of decoder.
s2 (float): The skip factor of the second stage block of decoder.

Returns:
images (list): A list of two images, the first one generated by Stable Diffusion and the second one generated by FreeU.
"""
global prompt_prev
global sd_options_prev
global seed_prev
global sd_image_prev

# if sd_options == 'SD1.5':
# pip = pip_1_5
# elif sd_options == 'SD2.1':
# pip = pip_2_1
# else:
# pip = pip_1_4

# Model selection
pip = pip_2_1

# Check if the input has changed
run_baseline = False
if prompt != prompt_prev or sd_options != sd_options_prev or seed != seed_prev:
run_baseline = True
prompt_prev = prompt
sd_options_prev = sd_options
seed_prev = seed

# Generate the baseline image
if run_baseline:
register_free_upblock2d(pip, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
register_free_crossattn_upblock2d(pip, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
Expand All @@ -50,20 +60,19 @@ def infer(prompt, sd_options, seed, b1, b2, s1, s2):
else:
sd_image = sd_image_prev


# Generate the FreeU image
register_free_upblock2d(pip, b1=b1, b2=b2, s1=s1, s2=s1)
register_free_crossattn_upblock2d(pip, b1=b1, b2=b2, s1=s1, s2=s1)

torch.manual_seed(seed)
print("Generating FreeU:")
freeu_image = pip(prompt, num_inference_steps=25).images[0]

# First SD, then freeu
# Return the images
images = [sd_image, freeu_image]

return images


# Examples
examples = [
[
"A drone view of celebration with Christma tree and fireworks, starry sky - background.",
Expand Down Expand Up @@ -117,8 +126,8 @@ def infer(prompt, sd_options, seed, b1, b2, s1, s2):
"a drone flying over a snowy forest."
],
]

# CSS styles
css = """
h1 {
text-align: center;
Expand All @@ -130,10 +139,13 @@ def infer(prompt, sd_options, seed, b1, b2, s1, s2):
}
"""

# Gradio app
block = gr.Blocks(css='style.css')

# Options
options = ['SD2.1']

# App layout
with block:
gr.Markdown("# SD 2.1 vs. FreeU")
with gr.Group():
Expand All @@ -150,9 +162,7 @@ def infer(prompt, sd_options, seed, b1, b2, s1, s2):
with gr.Row():
sd_options = gr.Dropdown(["SD2.1"], label="SD options", value="SD2.1", visible=False)




# FreeU parameters
with gr.Group():
with gr.Row():
with gr.Accordion('FreeU Parameters (feel free to adjust these parameters based on your prompt): ', open=False):
Expand Down Expand Up @@ -185,6 +195,7 @@ def infer(prompt, sd_options, seed, b1, b2, s1, s2):
step=1,
value=42)

# Image display
with gr.Row():
with gr.Group():
# btn = gr.Button("Generate image", scale=0)
Expand All @@ -201,11 +212,14 @@ def infer(prompt, sd_options, seed, b1, b2, s1, s2):
image_2_label = gr.Markdown("FreeU")


# Examples
ex = gr.Examples(examples=examples, fn=infer, inputs=[text, sd_options, seed, b1, b2, s1, s2], outputs=[image_1, image_2], cache_examples=False)
ex.dataset.headers = [""]

# Button click event
text.submit(infer, inputs=[text, sd_options, seed, b1, b2, s1, s2], outputs=[image_1, image_2])
btn.click(infer, inputs=[text, sd_options, seed, b1, b2, s1, s2], outputs=[image_1, image_2])

# Launch the app
block.launch()
# block.queue(default_enabled=False).launch(share=False)