56
56
from inpaint .utils import path_planning
57
57
from inpaint .bilateral_filtering import sparse_bilateral_filtering
58
58
59
+ # background removal
60
+ from rembg import new_session , remove
61
+
59
62
whole_size_threshold = 1600 # R_max from the paper
60
63
pix2pixsize = 1024
61
64
scriptname = "DepthMap v0.3.6"
@@ -108,10 +111,17 @@ def ui(self, is_img2img):
108
111
with gr .Row ():
109
112
inpaint = gr .Checkbox (label = "Generate 3D inpainted mesh. (Slooooooooow)" ,value = False , visible = False )
110
113
114
+ with gr .Group ():
115
+ with gr .Row ():
116
+ background_removal_model = gr .Dropdown (label = "Model" , choices = ['u2net' ,'u2netp' ,'u2net_human_seg' , 'silueta' ], value = 'u2net' , type = "value" , elem_id = "model_type" )
117
+ with gr .Row ():
118
+ background_removal = gr .Checkbox (label = "remove background" ,value = False )
119
+ save_background_removal_masks = gr .Checkbox (label = "save the foreground masks" ,value = False )
120
+ pre_depth_background_removal = gr .Checkbox (label = "pre-depth background removal" ,value = False )
121
+
111
122
with gr .Box ():
112
123
gr .HTML ("Information, comment and share @ <a href='https://github.com/thygate/stable-diffusion-webui-depthmap-script'>https://github.com/thygate/stable-diffusion-webui-depthmap-script</a>" )
113
124
114
-
115
125
clipthreshold_far .change (
116
126
fn = lambda a , b : a if b < a else b ,
117
127
inputs = [clipthreshold_far , clipthreshold_near ],
@@ -124,10 +134,10 @@ def ui(self, is_img2img):
124
134
outputs = [clipthreshold_far ]
125
135
)
126
136
127
- return [compute_device , model_type , net_width , net_height , match_size , invert_depth , boost , save_depth , show_depth , show_heat , combine_output , combine_output_axis , gen_stereo , gen_anaglyph , stereo_divergence , stereo_fill , stereo_balance , clipdepth , clipthreshold_far , clipthreshold_near , inpaint ]
137
+ return [compute_device , model_type , net_width , net_height , match_size , invert_depth , boost , save_depth , show_depth , show_heat , combine_output , combine_output_axis , gen_stereo , gen_anaglyph , stereo_divergence , stereo_fill , stereo_balance , clipdepth , clipthreshold_far , clipthreshold_near , inpaint , background_removal_model , background_removal , pre_depth_background_removal , save_background_removal_masks ]
128
138
129
139
# run from script in txt2img or img2img
130
- def run (self , p , compute_device , model_type , net_width , net_height , match_size , invert_depth , boost , save_depth , show_depth , show_heat , combine_output , combine_output_axis , gen_stereo , gen_anaglyph , stereo_divergence , stereo_fill , stereo_balance , clipdepth , clipthreshold_far , clipthreshold_near , inpaint ):
140
+ def run (self , p , compute_device , model_type , net_width , net_height , match_size , invert_depth , boost , save_depth , show_depth , show_heat , combine_output , combine_output_axis , gen_stereo , gen_anaglyph , stereo_divergence , stereo_fill , stereo_balance , clipdepth , clipthreshold_far , clipthreshold_near , inpaint , background_removal_model , background_removal , pre_depth_background_removal , save_background_removal_masks ):
131
141
132
142
# sd process
133
143
processed = processing .process_images (p )
@@ -140,14 +150,24 @@ def run(self, p, compute_device, model_type, net_width, net_height, match_size,
140
150
if count == 0 and len (processed .images ) > 1 :
141
151
continue
142
152
inputimages .append (processed .images [count ])
153
+
154
+ #remove on base image before depth calculation
155
+ background_removed_images = []
156
+ if background_removal :
157
+ if pre_depth_background_removal :
158
+ inputimages = batched_background_removal (inputimages , background_removal_model )
159
+ background_removed_images = inputimages
160
+ else :
161
+ background_removed_images = batched_background_removal (inputimages , background_removal_model )
143
162
144
- newmaps , mesh_fi = run_depthmap (processed , p .outpath_samples , inputimages , None , compute_device , model_type , net_width , net_height , match_size , invert_depth , boost , save_depth , show_depth , show_heat , combine_output , combine_output_axis , gen_stereo , gen_anaglyph , stereo_divergence , stereo_fill , stereo_balance , clipdepth , clipthreshold_far , clipthreshold_near , inpaint , "mp4" , 0 )
163
+ newmaps , mesh_fi = run_depthmap (processed , p .outpath_samples , inputimages , None , compute_device , model_type , net_width , net_height , match_size , invert_depth , boost , save_depth , show_depth , show_heat , combine_output , combine_output_axis , gen_stereo , gen_anaglyph , stereo_divergence , stereo_fill , stereo_balance , clipdepth , clipthreshold_far , clipthreshold_near , inpaint , "mp4" , 0 , background_removal , background_removed_images , save_background_removal_masks )
164
+
145
165
for img in newmaps :
146
166
processed .images .append (img )
147
167
148
168
return processed
149
169
150
- def run_depthmap (processed , outpath , inputimages , inputnames , compute_device , model_type , net_width , net_height , match_size , invert_depth , boost , save_depth , show_depth , show_heat , combine_output , combine_output_axis , gen_stereo , gen_anaglyph , stereo_divergence , stereo_fill , stereo_balance , clipdepth , clipthreshold_far , clipthreshold_near , inpaint , fnExt , vid_ssaa ):
170
+ def run_depthmap (processed , outpath , inputimages , inputnames , compute_device , model_type , net_width , net_height , match_size , invert_depth , boost , save_depth , show_depth , show_heat , combine_output , combine_output_axis , gen_stereo , gen_anaglyph , stereo_divergence , stereo_fill , stereo_balance , clipdepth , clipthreshold_far , clipthreshold_near , inpaint , fnExt , vid_ssaa , background_removal , background_removed_images , save_background_removal_masks ):
151
171
152
172
if len (inputimages ) == 0 or inputimages [0 ] == None :
153
173
return []
@@ -379,6 +399,29 @@ def run_depthmap(processed, outpath, inputimages, inputnames, compute_device, mo
379
399
p = Path (inputnames [count ])
380
400
basename = p .stem
381
401
402
+ rgb_image = inputimages [count ]
403
+
404
+ #applying background masks after depth
405
+ if background_removal :
406
+ print ('applying background masks' )
407
+ background_removed_image = background_removed_images [count - 1 ]
408
+ #maybe a threshold cut would be better on the line below.
409
+ background_removed_array = np .array (background_removed_image )
410
+ bg_mask = (background_removed_array [:,:,0 ]== 0 )| (background_removed_array [:,:,1 ]== 0 )| (background_removed_array [:,:,2 ]== 0 )
411
+ far_value = 255 if invert_depth else 0
412
+
413
+ img_output [bg_mask ] = far_value * far_value #255*255 or 0*0
414
+
415
+ #should this be optional
416
+ images .save_image (background_removed_image , path = outpath , basename = 'depthmap' , seed = None , prompt = None , extension = opts .samples_format , info = info , short_filename = True ,no_prompt = True , grid = False , pnginfo_section_name = "extras" , existing_info = None , forced_filename = None , suffix = "_background_removed" )
417
+ outimages .append (background_removed_image )
418
+ if save_background_removal_masks :
419
+ bg_array = (1 - bg_mask .astype ('int8' ))* 255
420
+ mask_array = np .stack ( (bg_array , bg_array , bg_array , bg_array ), axis = 2 )
421
+ mask_image = Image .fromarray ( mask_array .astype (np .uint8 ))
422
+ images .save_image (mask_image , path = outpath , basename = 'depthmap' , seed = None , prompt = None , extension = opts .samples_format , info = info , short_filename = True ,no_prompt = True , grid = False , pnginfo_section_name = "extras" , existing_info = None , forced_filename = None , suffix = "_foreground_mask" )
423
+ outimages .append (mask_image )
424
+
382
425
if not combine_output :
383
426
if show_depth :
384
427
outimages .append (Image .fromarray (img_output ))
@@ -396,7 +439,7 @@ def run_depthmap(processed, outpath, inputimages, inputnames, compute_device, mo
396
439
else :
397
440
images .save_image (Image .fromarray (img_output2 ), path = outpath , basename = basename , seed = None , prompt = None , extension = opts .samples_format , info = info , short_filename = True ,no_prompt = True , grid = False , pnginfo_section_name = "extras" , existing_info = None , forced_filename = None )
398
441
else :
399
- img_concat = np .concatenate ((inputimages [ count ] , img_output2 ), axis = combine_output_axis )
442
+ img_concat = np .concatenate ((rgb_image , img_output2 ), axis = combine_output_axis )
400
443
if show_depth :
401
444
outimages .append (Image .fromarray (img_concat ))
402
445
if save_depth and processed is not None :
@@ -1005,6 +1048,10 @@ def run_generate(depthmap_mode,
1005
1048
clipthreshold_far ,
1006
1049
clipthreshold_near ,
1007
1050
inpaint ,
1051
+ background_removal_model ,
1052
+ background_removal ,
1053
+ pre_depth_background_removal ,
1054
+ save_background_removal_masks ,
1008
1055
vid_format ,
1009
1056
vid_ssaa
1010
1057
):
@@ -1048,8 +1095,15 @@ def run_generate(depthmap_mode,
1048
1095
else :
1049
1096
outpath = opts .outdir_samples or opts .outdir_extras_samples
1050
1097
1098
+ background_removed_images = []
1099
+ if background_removal :
1100
+ if pre_depth_background_removal :
1101
+ imageArr = batched_background_removal (imageArr , background_removal_model )
1102
+ background_removed_images = imageArr
1103
+ else :
1104
+ background_removed_images = batched_background_removal (imageArr , background_removal_model )
1051
1105
1052
- outputs , mesh_fi = run_depthmap (None , outpath , imageArr , imageNameArr , compute_device , model_type , net_width , net_height , match_size , invert_depth , boost , save_depth , show_depth , show_heat , combine_output , combine_output_axis , gen_stereo , gen_anaglyph , stereo_divergence , stereo_fill , stereo_balance , clipdepth , clipthreshold_far , clipthreshold_near , inpaint , fnExt , vid_ssaa )
1106
+ outputs , mesh_fi = run_depthmap (None , outpath , imageArr , imageNameArr , compute_device , model_type , net_width , net_height , match_size , invert_depth , boost , save_depth , show_depth , show_heat , combine_output , combine_output_axis , gen_stereo , gen_anaglyph , stereo_divergence , stereo_fill , stereo_balance , clipdepth , clipthreshold_far , clipthreshold_near , inpaint , fnExt , vid_ssaa , background_removal , background_removed_images , save_background_removal_masks )
1053
1107
1054
1108
return outputs , mesh_fi , plaintext_to_html ('info' ), ''
1055
1109
@@ -1114,6 +1168,14 @@ def on_ui_tabs():
1114
1168
with gr .Row ():
1115
1169
inpaint = gr .Checkbox (label = "Generate 3D inpainted mesh and demo videos. (Sloooow)" ,value = False )
1116
1170
1171
+ with gr .Group ():
1172
+ with gr .Row ():
1173
+ background_removal_model = gr .Dropdown (label = "Model" , choices = ['u2net' ,'u2netp' ,'u2net_human_seg' , 'silueta' ], value = 'u2net' , type = "value" , elem_id = "model_type" )
1174
+ with gr .Row ():
1175
+ background_removal = gr .Checkbox (label = "remove background" ,value = False )
1176
+ save_background_removal_masks = gr .Checkbox (label = "save the foreground masks" ,value = False )
1177
+ pre_depth_background_removal = gr .Checkbox (label = "pre-depth background removal" ,value = False )
1178
+
1117
1179
with gr .Box ():
1118
1180
gr .HTML ("Information, comment and share @ <a href='https://github.com/thygate/stable-diffusion-webui-depthmap-script'>https://github.com/thygate/stable-diffusion-webui-depthmap-script</a>" )
1119
1181
@@ -1188,6 +1250,10 @@ def on_ui_tabs():
1188
1250
clipthreshold_far ,
1189
1251
clipthreshold_near ,
1190
1252
inpaint ,
1253
+ background_removal_model ,
1254
+ background_removal ,
1255
+ pre_depth_background_removal ,
1256
+ save_background_removal_masks ,
1191
1257
vid_format ,
1192
1258
vid_ssaa
1193
1259
],
@@ -1224,6 +1290,26 @@ def on_ui_tabs():
1224
1290
script_callbacks .on_ui_settings (on_ui_settings )
1225
1291
script_callbacks .on_ui_tabs (on_ui_tabs )
1226
1292
1293
+ def batched_background_removal (inimages , model_name ):
1294
+ print ('creating background masks' )
1295
+ outimages = []
1296
+
1297
+ # model path and name
1298
+ bg_model_dir = Path .joinpath (Path ().resolve (), "models/rem_bg" )
1299
+ os .makedirs (bg_model_dir , exist_ok = True )
1300
+ os .environ ["U2NET_HOME" ] = str (bg_model_dir )
1301
+
1302
+ #starting a session
1303
+ background_removal_session = new_session (model_name )
1304
+ for count in range (0 , len (inimages )):
1305
+ # skip first grid image
1306
+ if count == 0 and len (inimages ) > 1 :
1307
+ continue
1308
+ bg_remove_img = np .array (remove (inimages [count ], session = background_removal_session ))
1309
+ outimages .append (Image .fromarray (bg_remove_img ))
1310
+ #The line below might be redundant
1311
+ del background_removal_session
1312
+ return outimages
1227
1313
1228
1314
def download_file (filename , url ):
1229
1315
print ("Downloading" , url , "to" , filename )
0 commit comments