1
+ import numpy as np
2
+ from PIL import PngImagePlugin , Image
3
+ import base64
4
+ from io import BytesIO
5
+ from fastapi .exceptions import HTTPException
6
+
7
+ import gradio as gr
8
+
9
+
10
+ from src .core import core_generation_funnel , CoreGenerationFunnelInp
11
+ from src import backbone
12
+ from src .api .api_constants import api_defaults , api_forced , models_to_index
13
+
14
+ # moedified from modules/api/api.py auto1111
15
+ def decode_base64_to_image (encoding ):
16
+ if encoding .startswith ("data:image/" ):
17
+ encoding = encoding .split (";" )[1 ].split ("," )[1 ]
18
+ try :
19
+ image = Image .open (BytesIO (base64 .b64decode (encoding )))
20
+ return image
21
+ except Exception as e :
22
+ raise HTTPException (status_code = 500 , detail = "Invalid encoded image" ) from e
23
+
24
+ # modified from modules/api/api.py auto1111. TODO check that internally we always use png. Removed webp and jpeg
25
+ def encode_pil_to_base64 (image , image_type = 'png' ):
26
+ with BytesIO () as output_bytes :
27
+
28
+ if image_type == 'png' :
29
+ use_metadata = False
30
+ metadata = PngImagePlugin .PngInfo ()
31
+ for key , value in image .info .items ():
32
+ if isinstance (key , str ) and isinstance (value , str ):
33
+ metadata .add_text (key , value )
34
+ use_metadata = True
35
+ image .save (output_bytes , format = "PNG" , pnginfo = (metadata if use_metadata else None ))
36
+
37
+ else :
38
+ raise HTTPException (status_code = 500 , detail = "Invalid image format" )
39
+
40
+ bytes_data = output_bytes .getvalue ()
41
+
42
+ return base64 .b64encode (bytes_data )
43
+
44
+ def encode_to_base64 (image ):
45
+ if type (image ) is str :
46
+ return image
47
+ elif type (image ) is Image .Image :
48
+ return encode_pil_to_base64 (image )
49
+ elif type (image ) is np .ndarray :
50
+ return encode_np_to_base64 (image )
51
+ else :
52
+ return ""
53
+
54
+ def encode_np_to_base64 (image ):
55
+ pil = Image .fromarray (image )
56
+ return encode_pil_to_base64 (pil )
57
+
58
+ def to_base64_PIL (encoding : str ):
59
+ return Image .fromarray (np .array (decode_base64_to_image (encoding )).astype ('uint8' ))
60
+
61
+
62
+ def api_gen (input_images , client_options ):
63
+
64
+ default_options = CoreGenerationFunnelInp (api_defaults ).values
65
+
66
+ #TODO try-catch type errors here
67
+ for key , value in client_options .items ():
68
+ if key == "model_type" :
69
+ default_options [key ] = models_to_index (value )
70
+ continue
71
+ default_options [key ] = value
72
+
73
+ for key , value in api_forced .items ():
74
+ default_options [key .lower ()] = value
75
+
76
+ print (f"Processing { str (len (input_images ))} images through the API" )
77
+
78
+ print (default_options )
79
+
80
+ pil_images = []
81
+ for input_image in input_images :
82
+ pil_images .append (to_base64_PIL (input_image ))
83
+ outpath = backbone .get_outpath ()
84
+ gen_obj = core_generation_funnel (outpath , pil_images , None , None , default_options )
85
+ return gen_obj
0 commit comments