-
Notifications
You must be signed in to change notification settings - Fork 234
Expand file tree
/
Copy pathgenerator_factory.py
More file actions
30 lines (23 loc) · 1016 Bytes
/
generator_factory.py
File metadata and controls
30 lines (23 loc) · 1016 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import logging
from stable_diffusion_tf.stable_diffusion import StableDiffusion
from tensorflow import keras
import numpy as np
from typing import Optional
def make_stable_diffusion_model(height: int, width: int, *, jit_compile: bool = False) -> StableDiffusion:
    """Construct a ``StableDiffusion`` generator for images of the given size.

    Args:
        height: Image height, forwarded to ``StableDiffusion`` as ``img_height``.
        width: Image width, forwarded to ``StableDiffusion`` as ``img_width``.
        jit_compile: Forwarded unchanged to ``StableDiffusion``. It defaults to
            ``False``, which matches the value this function previously
            hard-coded, so existing callers see no change.
            NOTE(review): presumably this toggles XLA/JIT compilation inside
            the library; confirm against the stable_diffusion_tf source.

    Returns:
        The newly constructed ``StableDiffusion`` instance.
    """
    # Lazy %-style args: the message is only formatted if DEBUG is enabled.
    # The rendered text matches the previous f-string exactly.
    logging.debug("Creating stable diffusion model for images of dimension %sx%s", width, height)
    return StableDiffusion(img_height=height, img_width=width, jit_compile=jit_compile)
def run_generator(generator: StableDiffusion, prompt: str, steps: int, scale: float, temperature: float,
                  batch_size: int, seed: int, negative_prompt: Optional[str],
                  input_image: Optional[np.ndarray]) -> np.ndarray:
    """Run a single ``generator.generate`` call and return the first result.

    Args:
        generator: The ``StableDiffusion`` instance to run.
        prompt: Text prompt, passed positionally to ``generate``.
        steps: Forwarded as ``num_steps``.
        scale: Forwarded as ``unconditional_guidance_scale``.
        temperature: Forwarded as ``temperature``. It is annotated ``float``
            because it is a continuous scaling factor; plain ints are still
            accepted.
        batch_size: Forwarded as ``batch_size``. NOTE: only element ``[0]`` of
            the result is returned, so any extra images produced when
            ``batch_size > 1`` are discarded.
        seed: Forwarded as ``seed``.
        negative_prompt: Forwarded as ``negative_prompt``; may be ``None``.
        input_image: Forwarded as ``input_image``; may be ``None``.
            NOTE(review): presumably ``None`` means plain text-to-image and an
            array means image-to-image; confirm against the library.

    Returns:
        ``generator.generate(...)[0]``, the first item of the batch.
        NOTE(review): presumably an (H, W, 3) uint8 image array; confirm
        against the stable_diffusion_tf ``generate`` implementation.
    """
    # Lazy %-style args. ``%s`` renders ``None`` as "None", exactly like the
    # previous f-string did.
    logging.debug("Start running generation for prompt `%s` with negative prompt `%s`", prompt, negative_prompt)
    images = generator.generate(
        prompt,
        negative_prompt=negative_prompt,
        num_steps=steps,
        unconditional_guidance_scale=scale,
        temperature=temperature,
        input_image=input_image,
        batch_size=batch_size,
        seed=seed,
    )
    # Only the first image of the batch is handed back to the caller.
    return images[0]