Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

2 changes: 1 addition & 1 deletion backends/model_converter/convert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def convert_model(checkpoint_filename=None, out_filename=None, torch_weights=No
raise ValueError("Invalid sd_version "+ sd_version)
model_metadata = {"float_type" : cur_dtype , "sd_type" :sd_version, "type" : sd_type }
print("__converted_model_data__" , json.dumps(model_metadata))
return {"output_path": out_filename, "model_metadata": model_metadata}


def usage():
Expand Down Expand Up @@ -155,4 +156,3 @@ def usage():

convert_model(checkpoint_filename , out_filename )


1 change: 1 addition & 0 deletions backends/stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

125 changes: 18 additions & 107 deletions backends/stable_diffusion/diffusionbee_backend.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,23 @@
print("starting backend")
import numpy as np
import argparse
from PIL import Image
import json
import random
import multiprocessing
import sys
import copy
import math
import time
import traceback
import os
from pathlib import Path


# b2py t2im {"prompt": "sun glasses" , "img_width":640 , "img_height" : 640 , "num_imgs" : 10 , "input_image":"/Users/divamgupta/Downloads/inn.png" , "mask_image" : "/Users/divamgupta/Downloads/maa.png" , "is_inpaint":true }

if not ( getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS')):
print("Adding sys paths")
dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(dir_path , "../model_converter"))

model_interface_path = os.environ.get('MODEL_INTERFACE_PATH') or "../stable_diffusion_tf_models"
sys.path.append( os.path.join(dir_path , model_interface_path) )
else:
print("not adding sys paths")


from convert_model import convert_model
from stable_diffusion.stable_diffusion import StableDiffusion , ModelContainer
from stable_diffusion.utils.utils import get_sd_run_from_dict

from applets.applets import register_applet , run_applet
from applets.frame_interpolator import FrameInterpolator
# get the model interface form the environ
USE_DUMMY_INTERFACE = False
if USE_DUMMY_INTERFACE :
from fake_interface import ModelInterface
else:
from interface import ModelInterface

model_container = ModelContainer()



home_path = Path.home()

projects_root_path = os.path.join(home_path, ".diffusionbee")

if not os.path.isdir(projects_root_path):
os.mkdir(projects_root_path)




if 'DEBUG' in os.environ and str(os.environ['DEBUG']) == '1':
debug_output_path = os.path.join(projects_root_path, "debug_outs")
if not os.path.isdir(debug_output_path):
os.mkdir(debug_output_path)
print("Debug outputs stored at : " , debug_output_path )
else:
debug_output_path = None


try:
from .applets.applets import register_applet, run_applet
from .applets.frame_interpolator import FrameInterpolator
from .service import DiffusionBeeService, generation_result_payload
except ImportError:
from applets.applets import register_applet, run_applet
from applets.frame_interpolator import FrameInterpolator
from service import DiffusionBeeService, generation_result_payload


defualt_data_root = os.path.join(projects_root_path, "images")

# b2py t2im {"prompt": "sun glasses" , "img_width":640 , "img_height" : 640 , "num_imgs" : 10 , "input_image":"/Users/divamgupta/Downloads/inn.png" , "mask_image" : "/Users/divamgupta/Downloads/maa.png" , "is_inpaint":true }

if not os.path.isdir(defualt_data_root):
os.mkdir(defualt_data_root)
service = DiffusionBeeService()



Expand All @@ -95,47 +43,10 @@ def __getattr__(self, attr):



def process_opt(d, generator):

batch_size = 1# int(d['batch_size'])
n_imgs = math.ceil(d['num_imgs'] / batch_size)
sd_run = get_sd_run_from_dict(d)

for i in range(n_imgs):

sd_run.img_id = i

print("got" , d )

outs = generator.generate(sd_run)

if outs is None:
return

img = outs['img']

if img is None:
return

for i in range(len(img)):
s = ''.join(filter(str.isalnum, str(d['prompt'])[:30] ))
fpath = os.path.join(defualt_data_root , "%s_%d.png"%(s , random.randint(0 ,100000000)) )

Image.fromarray(img[i]).save(fpath)
ret_dict = {"generated_img_path" : fpath}

if 'aux_img' in outs:
ret_dict['aux_output_image_path'] = outs['aux_img']

print("sdbk nwim %s"%(json.dumps(ret_dict)) )




def diffusion_bee_main():

time.sleep(2)
register_applet(model_container , FrameInterpolator)
register_applet(service.get_model_container(), FrameInterpolator)

print("sdbk mltl Loading Model")

Expand All @@ -148,8 +59,8 @@ def callback(state="" , progress=-1):
if "__stop__" in get_input():
return "stop"

generator = StableDiffusion( model_container , ModelInterface , None , model_name=None, callback=callback, debug_output_path=debug_output_path )

service._progress_callback = callback
service._get_generator()

print("sdbk mdld")

Expand All @@ -171,8 +82,9 @@ def callback(state="" , progress=-1):
try:
d = json.loads(inp_str)
print("sdbk inwk") # working on the input

process_opt(d, generator)
result = service.generate_images(d, progress_callback=callback)
for image in generation_result_payload(result)["images"]:
print("sdbk nwim %s"%(json.dumps(image)) )

except Exception as e:
traceback.print_exc()
Expand All @@ -193,7 +105,7 @@ def callback(state="" , progress=-1):
print("sdbk errr %s"%(str(e)))


from stable_diffusion.utils.stdin_input import is_avail, get_input
from stable_diffusion.utils.stdin_input import is_avail, get_input


if __name__ == "__main__":
Expand All @@ -202,8 +114,7 @@ def callback(state="" , progress=-1):
if len(sys.argv) > 1 and sys.argv[1] == 'convert_model':
checkpoint_filename = sys.argv[2]
out_filename = sys.argv[3]
convert_model(checkpoint_filename, out_filename )
service.convert_model(checkpoint_filename, out_filename)
print("model converted ")
else:
diffusion_bee_main()

Loading