Skip to content

Commit 367d3db

Browse files
authored
Merge pull request #78 from graemeniedermayer/rembg-integration
Rembg integration
2 parents 3e1bd5e + bb4bf46 commit 367d3db

File tree

3 files changed

+107
-7
lines changed

3 files changed

+107
-7
lines changed

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,4 +246,16 @@ Boosting Monocular Depth Estimation Models to High-Resolution via Content-Adapti
246246
booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
247247
year = {2020}
248248
}
249+
```
250+
251+
U2-Net:
252+
```
253+
@article{Qin_2020_PR,
254+
title = {U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection},
255+
author = {Qin, Xuebin and Zhang, Zichen and Huang, Chenyang and Dehghan, Masood and Zaiane, Osmar and Jagersand, Martin},
256+
journal = {Pattern Recognition},
257+
volume = {106},
258+
pages = {107404},
259+
year = {2020}
260+
}
249261
```

install.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
if not launch.is_installed("vispy"):
1212
launch.run_pip("install vispy", "vispy requirement for depthmap script")
1313

14+
if not launch.is_installed("rembg"):
15+
launch.run_pip("install rembg", "rembg requirement for depthmap script")
1416

1517
if not launch.is_installed("moviepy"):
1618
launch.run_pip("install moviepy==1.0.2", "moviepy requirement for depthmap script")

scripts/depthmap.py

Lines changed: 93 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@
5656
from inpaint.utils import path_planning
5757
from inpaint.bilateral_filtering import sparse_bilateral_filtering
5858

59+
# background removal
60+
from rembg import new_session, remove
61+
5962
whole_size_threshold = 1600 # R_max from the paper
6063
pix2pixsize = 1024
6164
scriptname = "DepthMap v0.3.6"
@@ -108,10 +111,17 @@ def ui(self, is_img2img):
108111
with gr.Row():
109112
inpaint = gr.Checkbox(label="Generate 3D inpainted mesh. (Slooooooooow)",value=False, visible=False)
110113

114+
with gr.Group():
115+
with gr.Row():
116+
background_removal_model = gr.Dropdown(label="Model", choices=['u2net','u2netp','u2net_human_seg', 'silueta'], value='u2net', type="value", elem_id="model_type")
117+
with gr.Row():
118+
background_removal = gr.Checkbox(label="remove background",value=False)
119+
save_background_removal_masks = gr.Checkbox(label="save the foreground masks",value=False)
120+
pre_depth_background_removal = gr.Checkbox(label="pre-depth background removal",value=False)
121+
111122
with gr.Box():
112123
gr.HTML("Information, comment and share @ <a href='https://github.com/thygate/stable-diffusion-webui-depthmap-script'>https://github.com/thygate/stable-diffusion-webui-depthmap-script</a>")
113124

114-
115125
clipthreshold_far.change(
116126
fn = lambda a, b: a if b < a else b,
117127
inputs = [clipthreshold_far, clipthreshold_near],
@@ -124,10 +134,10 @@ def ui(self, is_img2img):
124134
outputs=[clipthreshold_far]
125135
)
126136

127-
return [compute_device, model_type, net_width, net_height, match_size, invert_depth, boost, save_depth, show_depth, show_heat, combine_output, combine_output_axis, gen_stereo, gen_anaglyph, stereo_divergence, stereo_fill, stereo_balance, clipdepth, clipthreshold_far, clipthreshold_near, inpaint]
137+
return [compute_device, model_type, net_width, net_height, match_size, invert_depth, boost, save_depth, show_depth, show_heat, combine_output, combine_output_axis, gen_stereo, gen_anaglyph, stereo_divergence, stereo_fill, stereo_balance, clipdepth, clipthreshold_far, clipthreshold_near, inpaint, background_removal_model, background_removal, pre_depth_background_removal, save_background_removal_masks]
128138

129139
# run from script in txt2img or img2img
130-
def run(self, p, compute_device, model_type, net_width, net_height, match_size, invert_depth, boost, save_depth, show_depth, show_heat, combine_output, combine_output_axis, gen_stereo, gen_anaglyph, stereo_divergence, stereo_fill, stereo_balance, clipdepth, clipthreshold_far, clipthreshold_near, inpaint):
140+
def run(self, p, compute_device, model_type, net_width, net_height, match_size, invert_depth, boost, save_depth, show_depth, show_heat, combine_output, combine_output_axis, gen_stereo, gen_anaglyph, stereo_divergence, stereo_fill, stereo_balance, clipdepth, clipthreshold_far, clipthreshold_near, inpaint, background_removal_model, background_removal, pre_depth_background_removal, save_background_removal_masks):
131141

132142
# sd process
133143
processed = processing.process_images(p)
@@ -140,14 +150,24 @@ def run(self, p, compute_device, model_type, net_width, net_height, match_size,
140150
if count == 0 and len(processed.images) > 1:
141151
continue
142152
inputimages.append(processed.images[count])
153+
154+
#remove on base image before depth calculation
155+
background_removed_images = []
156+
if background_removal:
157+
if pre_depth_background_removal:
158+
inputimages = batched_background_removal(inputimages, background_removal_model)
159+
background_removed_images = inputimages
160+
else:
161+
background_removed_images = batched_background_removal(inputimages, background_removal_model)
143162

144-
newmaps, mesh_fi = run_depthmap(processed, p.outpath_samples, inputimages, None, compute_device, model_type, net_width, net_height, match_size, invert_depth, boost, save_depth, show_depth, show_heat, combine_output, combine_output_axis, gen_stereo, gen_anaglyph, stereo_divergence, stereo_fill, stereo_balance, clipdepth, clipthreshold_far, clipthreshold_near, inpaint, "mp4", 0)
163+
newmaps, mesh_fi = run_depthmap(processed, p.outpath_samples, inputimages, None, compute_device, model_type, net_width, net_height, match_size, invert_depth, boost, save_depth, show_depth, show_heat, combine_output, combine_output_axis, gen_stereo, gen_anaglyph, stereo_divergence, stereo_fill, stereo_balance, clipdepth, clipthreshold_far, clipthreshold_near, inpaint, "mp4", 0, background_removal, background_removed_images, save_background_removal_masks)
164+
145165
for img in newmaps:
146166
processed.images.append(img)
147167

148168
return processed
149169

150-
def run_depthmap(processed, outpath, inputimages, inputnames, compute_device, model_type, net_width, net_height, match_size, invert_depth, boost, save_depth, show_depth, show_heat, combine_output, combine_output_axis, gen_stereo, gen_anaglyph, stereo_divergence, stereo_fill, stereo_balance, clipdepth, clipthreshold_far, clipthreshold_near, inpaint, fnExt, vid_ssaa):
170+
def run_depthmap(processed, outpath, inputimages, inputnames, compute_device, model_type, net_width, net_height, match_size, invert_depth, boost, save_depth, show_depth, show_heat, combine_output, combine_output_axis, gen_stereo, gen_anaglyph, stereo_divergence, stereo_fill, stereo_balance, clipdepth, clipthreshold_far, clipthreshold_near, inpaint, fnExt, vid_ssaa, background_removal, background_removed_images, save_background_removal_masks):
151171

152172
if len(inputimages) == 0 or inputimages[0] == None:
153173
return []
@@ -379,6 +399,29 @@ def run_depthmap(processed, outpath, inputimages, inputnames, compute_device, mo
379399
p = Path(inputnames[count])
380400
basename = p.stem
381401

402+
rgb_image = inputimages[count]
403+
404+
#applying background masks after depth
405+
if background_removal:
406+
print('applying background masks')
407+
background_removed_image = background_removed_images[count-1]
408+
#maybe a threshold cut would be better on the line below.
409+
background_removed_array = np.array(background_removed_image)
410+
bg_mask = (background_removed_array[:,:,0]==0)|(background_removed_array[:,:,1]==0)|(background_removed_array[:,:,2]==0)
411+
far_value = 255 if invert_depth else 0
412+
413+
img_output[bg_mask] = far_value * far_value #255*255 or 0*0
414+
415+
#should this be optional
416+
images.save_image(background_removed_image, path=outpath, basename='depthmap', seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=None, forced_filename=None, suffix="_background_removed")
417+
outimages.append(background_removed_image )
418+
if save_background_removal_masks:
419+
bg_array = (1 - bg_mask.astype('int8'))*255
420+
mask_array = np.stack( (bg_array, bg_array, bg_array, bg_array), axis=2)
421+
mask_image = Image.fromarray( mask_array.astype(np.uint8))
422+
images.save_image(mask_image, path=outpath, basename='depthmap', seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=None, forced_filename=None, suffix="_foreground_mask")
423+
outimages.append(mask_image)
424+
382425
if not combine_output:
383426
if show_depth:
384427
outimages.append(Image.fromarray(img_output))
@@ -396,7 +439,7 @@ def run_depthmap(processed, outpath, inputimages, inputnames, compute_device, mo
396439
else:
397440
images.save_image(Image.fromarray(img_output2), path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=None, forced_filename=None)
398441
else:
399-
img_concat = np.concatenate((inputimages[count], img_output2), axis=combine_output_axis)
442+
img_concat = np.concatenate((rgb_image, img_output2), axis=combine_output_axis)
400443
if show_depth:
401444
outimages.append(Image.fromarray(img_concat))
402445
if save_depth and processed is not None:
@@ -1005,6 +1048,10 @@ def run_generate(depthmap_mode,
10051048
clipthreshold_far,
10061049
clipthreshold_near,
10071050
inpaint,
1051+
background_removal_model,
1052+
background_removal,
1053+
pre_depth_background_removal,
1054+
save_background_removal_masks,
10081055
vid_format,
10091056
vid_ssaa
10101057
):
@@ -1048,8 +1095,15 @@ def run_generate(depthmap_mode,
10481095
else:
10491096
outpath = opts.outdir_samples or opts.outdir_extras_samples
10501097

1098+
background_removed_images = []
1099+
if background_removal:
1100+
if pre_depth_background_removal:
1101+
imageArr = batched_background_removal(imageArr, background_removal_model)
1102+
background_removed_images = imageArr
1103+
else:
1104+
background_removed_images = batched_background_removal(imageArr, background_removal_model)
10511105

1052-
outputs, mesh_fi = run_depthmap(None, outpath, imageArr, imageNameArr, compute_device, model_type, net_width, net_height, match_size, invert_depth, boost, save_depth, show_depth, show_heat, combine_output, combine_output_axis, gen_stereo, gen_anaglyph, stereo_divergence, stereo_fill, stereo_balance, clipdepth, clipthreshold_far, clipthreshold_near, inpaint, fnExt, vid_ssaa)
1106+
outputs, mesh_fi = run_depthmap(None, outpath, imageArr, imageNameArr, compute_device, model_type, net_width, net_height, match_size, invert_depth, boost, save_depth, show_depth, show_heat, combine_output, combine_output_axis, gen_stereo, gen_anaglyph, stereo_divergence, stereo_fill, stereo_balance, clipdepth, clipthreshold_far, clipthreshold_near, inpaint, fnExt, vid_ssaa, background_removal, background_removed_images, save_background_removal_masks)
10531107

10541108
return outputs, mesh_fi, plaintext_to_html('info'), ''
10551109

@@ -1114,6 +1168,14 @@ def on_ui_tabs():
11141168
with gr.Row():
11151169
inpaint = gr.Checkbox(label="Generate 3D inpainted mesh and demo videos. (Sloooow)",value=False)
11161170

1171+
with gr.Group():
1172+
with gr.Row():
1173+
background_removal_model = gr.Dropdown(label="Model", choices=['u2net','u2netp','u2net_human_seg', 'silueta'], value='u2net', type="value", elem_id="model_type")
1174+
with gr.Row():
1175+
background_removal = gr.Checkbox(label="remove background",value=False)
1176+
save_background_removal_masks = gr.Checkbox(label="save the foreground masks",value=False)
1177+
pre_depth_background_removal = gr.Checkbox(label="pre-depth background removal",value=False)
1178+
11171179
with gr.Box():
11181180
gr.HTML("Information, comment and share @ <a href='https://github.com/thygate/stable-diffusion-webui-depthmap-script'>https://github.com/thygate/stable-diffusion-webui-depthmap-script</a>")
11191181

@@ -1188,6 +1250,10 @@ def on_ui_tabs():
11881250
clipthreshold_far,
11891251
clipthreshold_near,
11901252
inpaint,
1253+
background_removal_model,
1254+
background_removal,
1255+
pre_depth_background_removal,
1256+
save_background_removal_masks,
11911257
vid_format,
11921258
vid_ssaa
11931259
],
@@ -1224,6 +1290,26 @@ def on_ui_tabs():
12241290
script_callbacks.on_ui_settings(on_ui_settings)
12251291
script_callbacks.on_ui_tabs(on_ui_tabs)
12261292

1293+
def batched_background_removal(inimages, model_name):
    """Run rembg background removal over a batch of images.

    Args:
        inimages: sequence of input images handed to ``rembg.remove``.
        model_name: name of the rembg model used to create the session
            (the UI offers 'u2net', 'u2netp', 'u2net_human_seg', 'silueta').

    Returns:
        A list of PIL images produced by ``remove``. When more than one
        input image is given, the first one is treated as the grid image
        and skipped, so the result is one element shorter than the input.
    """
    print('creating background masks')

    # Point U2NET_HOME at a local "models/rem_bg" folder under the current
    # working directory (presumably where rembg stores its model weights --
    # confirm against the rembg documentation).
    weights_dir = Path().resolve() / "models/rem_bg"
    os.makedirs(weights_dir, exist_ok=True)
    os.environ["U2NET_HOME"] = str(weights_dir)

    # One session is created up front and shared by every image in the batch.
    session = new_session(model_name)

    # With several images the first entry is assumed to be the grid image.
    # NOTE(review): callers may already have dropped the grid image before
    # calling this function -- verify the indexing against run_depthmap.
    first = 1 if len(inimages) > 1 else 0
    results = []
    for source_image in inimages[first:]:
        cutout = np.array(remove(source_image, session=session))
        results.append(Image.fromarray(cutout))

    # Explicitly drop the session reference (possibly redundant).
    del session
    return results
12271313

12281314
def download_file(filename, url):
12291315
print("Downloading", url, "to", filename)

0 commit comments

Comments
 (0)