|
24 | 24 | from monai.transforms import LoadImage, Orientation
|
25 | 25 | from monai.utils import set_determinism
|
26 | 26 | from scripts.sample import LDMSampler, check_input
|
27 |
| -from scripts.utils import define_instance, load_autoencoder_ckpt, load_diffusion_ckpt |
| 27 | +from scripts.utils import define_instance |
28 | 28 | from scripts.utils_plot import find_label_center_loc, get_xyz_plot, show_image
|
29 | 29 |
|
30 | 30 |
|
@@ -76,23 +76,23 @@ def main():
|
76 | 76 | files = [
|
77 | 77 | {
|
78 | 78 | "path": "models/autoencoder_epoch273.pt",
|
79 |
| - "url": "https://drive.google.com/file/d/1jQefG0yJPzSvTG5rIJVHNqDReBTvVmZ0/view?usp=drive_link", |
| 79 | + "url": "https://drive.google.com/file/d/1Ojw25lFO8QbHkxazdK4CgZTyp3GFNZGz/view?usp=sharing", |
80 | 80 | },
|
81 | 81 | {
|
82 | 82 | "path": "models/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt",
|
83 |
| - "url": "https://drive.google.com/file/d/1FtOHBGUF5dLZNHtiuhf5EH448EQGGs-_/view?usp=sharing", |
| 83 | + "url": "https://drive.google.com/file/d/1lklNv4MTdI_9bwFRMd98QQ7JLerR5gC_/view?usp=drive_link", |
84 | 84 | },
|
85 | 85 | {
|
86 | 86 | "path": "models/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt",
|
87 |
| - "url": "https://drive.google.com/file/d/1izr52Whkk56OevNTk2QzI86eJV9TTaLk/view?usp=sharing", |
| 87 | + "url": "https://drive.google.com/file/d/1mLYeqeZ819_WpZPlAInhcWuCIHgn3QNT/view?usp=drive_link", |
88 | 88 | },
|
89 | 89 | {
|
90 | 90 | "path": "models/mask_generation_autoencoder.pt",
|
91 |
| - "url": "https://drive.google.com/file/d/1FzWrpv6ornYUaPiAWGOOxhRx2P9Wnynm/view?usp=drive_link", |
| 91 | + "url": "https://drive.google.com/file/d/19JnX-C6QAg4RfghTwpPnj4KEWhtawpCy/view?usp=drive_link", |
92 | 92 | },
|
93 | 93 | {
|
94 | 94 | "path": "models/mask_generation_diffusion_unet.pt",
|
95 |
| - "url": "https://drive.google.com/file/d/11SA9RUZ6XmCOJr5v6w6UW1kDzr6hlymw/view?usp=drive_link", |
| 95 | + "url": "https://drive.google.com/file/d/1yOQvlhXFGY1ZYavADM3N34vgg5AEitda/view?usp=drive_link", |
96 | 96 | },
|
97 | 97 | {
|
98 | 98 | "path": "configs/candidate_masks_flexible_size_and_spacing_3000.json",
|
@@ -155,29 +155,27 @@ def main():
|
155 | 155 | device = torch.device("cuda")
|
156 | 156 |
|
157 | 157 | autoencoder = define_instance(args, "autoencoder_def").to(device)
|
158 |
| - checkpoint_autoencoder = load_autoencoder_ckpt(args.trained_autoencoder_path) |
| 158 | + checkpoint_autoencoder = torch.load(args.trained_autoencoder_path) |
159 | 159 | autoencoder.load_state_dict(checkpoint_autoencoder)
|
160 | 160 |
|
161 | 161 | diffusion_unet = define_instance(args, "diffusion_unet_def").to(device)
|
162 | 162 | checkpoint_diffusion_unet = torch.load(args.trained_diffusion_path)
|
163 |
| - new_dict = load_diffusion_ckpt(diffusion_unet.state_dict(), checkpoint_diffusion_unet["unet_state_dict"]) |
164 |
| - diffusion_unet.load_state_dict(new_dict, strict=True) |
| 163 | + diffusion_unet.load_state_dict(checkpoint_diffusion_unet["unet_state_dict"], strict=True) |
165 | 164 | scale_factor = checkpoint_diffusion_unet["scale_factor"].to(device)
|
166 | 165 |
|
167 | 166 | controlnet = define_instance(args, "controlnet_def").to(device)
|
168 | 167 | checkpoint_controlnet = torch.load(args.trained_controlnet_path)
|
169 |
| - new_dict = load_diffusion_ckpt(controlnet.state_dict(), checkpoint_controlnet["controlnet_state_dict"]) |
170 | 168 | monai.networks.utils.copy_model_state(controlnet, diffusion_unet.state_dict())
|
171 |
| - controlnet.load_state_dict(new_dict, strict=True) |
| 169 | + controlnet.load_state_dict(checkpoint_controlnet["controlnet_state_dict"], strict=True) |
172 | 170 |
|
173 | 171 | mask_generation_autoencoder = define_instance(args, "mask_generation_autoencoder_def").to(device)
|
174 |
| - checkpoint_mask_generation_autoencoder = load_autoencoder_ckpt(args.trained_mask_generation_autoencoder_path) |
| 172 | + checkpoint_mask_generation_autoencoder = torch.load(args.trained_mask_generation_autoencoder_path) |
175 | 173 | mask_generation_autoencoder.load_state_dict(checkpoint_mask_generation_autoencoder)
|
176 | 174 |
|
177 | 175 | mask_generation_diffusion_unet = define_instance(args, "mask_generation_diffusion_def").to(device)
|
178 | 176 | checkpoint_mask_generation_diffusion_unet = torch.load(args.trained_mask_generation_diffusion_path)
|
179 |
| - mask_generation_diffusion_unet.load_old_state_dict(checkpoint_mask_generation_diffusion_unet) |
180 |
| - mask_generation_scale_factor = args.mask_generation_scale_factor |
| 177 | + mask_generation_diffusion_unet.load_state_dict(checkpoint_mask_generation_diffusion_unet["unet_state_dict"]) |
| 178 | + mask_generation_scale_factor = checkpoint_mask_generation_diffusion_unet["scale_factor"] |
181 | 179 |
|
182 | 180 | print("All the trained model weights have been loaded.")
|
183 | 181 |
|
|
0 commit comments