Commit 7f397e5
update maisi ckpt link, update load functions (#1792)
Fixes #1772.

### Description

Update the MAISI checkpoint links, and update the load functions: the new checkpoints match the MAISI network definitions directly, so they are loaded with a plain `torch.load`, and the key-remapping helpers `load_autoencoder_ckpt` and `load_diffusion_ckpt` are removed.

### Checks

- [ ] Avoid including large-size files in the PR.
- [ ] Clean up long text outputs from code cells in the notebook.
- [ ] For security purposes, please check the contents and remove any sensitive info such as user names and private keys.
- [ ] Ensure that (1) hyperlinks and markdown anchors work, (2) tutorial repo files are referenced with relative paths, and (3) figures and graphs are placed in the `./figure` folder.
- [ ] Notebook runs automatically via `./runner.sh -t <path to .ipynb file>`.

Signed-off-by: Can-Zhao <volcanofly@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent: 88d22e8 · Commit: 7f397e5

6 files changed: +89, -172 lines
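The change common to the three Python scripts below is the same in each place: the updated checkpoints match the MAISI network definitions key-for-key, so a plain `torch.load` feeds `load_state_dict` directly, and the remapping helpers `load_autoencoder_ckpt` / `load_diffusion_ckpt` (removed from `scripts/utils.py`) are no longer involved. A minimal sketch of the new pattern; `load_autoencoder` is a hypothetical wrapper name, and the `map_location` argument is an added assumption for machines without the GPU the scripts default to:

```python
import torch

from scripts.utils import define_instance  # tutorial helper: builds a network from the config


def load_autoencoder(args, device):
    """Sketch of the post-commit loading path: no key remapping required."""
    autoencoder = define_instance(args, "autoencoder_def").to(device)
    # The updated checkpoint's keys already match AutoencoderKlMaisi,
    # so torch.load + load_state_dict replaces load_autoencoder_ckpt.
    checkpoint = torch.load(args.trained_autoencoder_path, map_location=device)
    autoencoder.load_state_dict(checkpoint)
    return autoencoder
```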

generation/maisi/configs/config_maisi.json
Lines changed: 1 addition & 1 deletion

```diff
@@ -32,7 +32,7 @@
         "use_checkpointing": false,
         "use_convtranspose": false,
         "norm_float16": true,
-        "num_splits": 16,
+        "num_splits": 8,
         "dim_split": 1
     },
     "diffusion_unet_def": {
```

generation/maisi/maisi_inference_tutorial.ipynb
Lines changed: 74 additions & 65 deletions

(Large diffs are not rendered by default.)

generation/maisi/scripts/diff_model_infer.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -59,7 +59,7 @@ def load_models(args: argparse.Namespace, device: torch.device, logger: logging.
     """
     autoencoder = define_instance(args, "autoencoder_def").to(device)
     try:
-        checkpoint_autoencoder = load_autoencoder_ckpt(args.trained_autoencoder_path)
+        checkpoint_autoencoder = torch.load(args.trained_autoencoder_path)
         autoencoder.load_state_dict(checkpoint_autoencoder)
     except Exception:
         logger.error("The trained_autoencoder_path does not exist!")
```

generation/maisi/scripts/infer_controlnet.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -98,7 +98,7 @@ def main():
     if args.trained_autoencoder_path is not None:
         if not os.path.exists(args.trained_autoencoder_path):
             raise ValueError("Please download the autoencoder checkpoint.")
-        autoencoder_ckpt = load_autoencoder_ckpt(args.trained_autoencoder_path)
+        autoencoder_ckpt = torch.load(args.trained_autoencoder_path)
         autoencoder.load_state_dict(autoencoder_ckpt)
         logger.info(f"Load trained diffusion model from {args.trained_autoencoder_path}.")
     else:
```

generation/maisi/scripts/inference.py
Lines changed: 12 additions & 14 deletions

```diff
@@ -24,7 +24,7 @@
 from monai.transforms import LoadImage, Orientation
 from monai.utils import set_determinism
 from scripts.sample import LDMSampler, check_input
-from scripts.utils import define_instance, load_autoencoder_ckpt, load_diffusion_ckpt
+from scripts.utils import define_instance
 from scripts.utils_plot import find_label_center_loc, get_xyz_plot, show_image


@@ -76,23 +76,23 @@ def main():
     files = [
         {
             "path": "models/autoencoder_epoch273.pt",
-            "url": "https://drive.google.com/file/d/1jQefG0yJPzSvTG5rIJVHNqDReBTvVmZ0/view?usp=drive_link",
+            "url": "https://drive.google.com/file/d/1Ojw25lFO8QbHkxazdK4CgZTyp3GFNZGz/view?usp=sharing",
         },
         {
             "path": "models/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt",
-            "url": "https://drive.google.com/file/d/1FtOHBGUF5dLZNHtiuhf5EH448EQGGs-_/view?usp=sharing",
+            "url": "https://drive.google.com/file/d/1lklNv4MTdI_9bwFRMd98QQ7JLerR5gC_/view?usp=drive_link",
         },
         {
             "path": "models/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt",
-            "url": "https://drive.google.com/file/d/1izr52Whkk56OevNTk2QzI86eJV9TTaLk/view?usp=sharing",
+            "url": "https://drive.google.com/file/d/1mLYeqeZ819_WpZPlAInhcWuCIHgn3QNT/view?usp=drive_link",
         },
         {
             "path": "models/mask_generation_autoencoder.pt",
-            "url": "https://drive.google.com/file/d/1FzWrpv6ornYUaPiAWGOOxhRx2P9Wnynm/view?usp=drive_link",
+            "url": "https://drive.google.com/file/d/19JnX-C6QAg4RfghTwpPnj4KEWhtawpCy/view?usp=drive_link",
         },
         {
             "path": "models/mask_generation_diffusion_unet.pt",
-            "url": "https://drive.google.com/file/d/11SA9RUZ6XmCOJr5v6w6UW1kDzr6hlymw/view?usp=drive_link",
+            "url": "https://drive.google.com/file/d/1yOQvlhXFGY1ZYavADM3N34vgg5AEitda/view?usp=drive_link",
         },
         {
             "path": "configs/candidate_masks_flexible_size_and_spacing_3000.json",
@@ -155,29 +155,27 @@ def main():
     device = torch.device("cuda")

     autoencoder = define_instance(args, "autoencoder_def").to(device)
-    checkpoint_autoencoder = load_autoencoder_ckpt(args.trained_autoencoder_path)
+    checkpoint_autoencoder = torch.load(args.trained_autoencoder_path)
     autoencoder.load_state_dict(checkpoint_autoencoder)

     diffusion_unet = define_instance(args, "diffusion_unet_def").to(device)
     checkpoint_diffusion_unet = torch.load(args.trained_diffusion_path)
-    new_dict = load_diffusion_ckpt(diffusion_unet.state_dict(), checkpoint_diffusion_unet["unet_state_dict"])
-    diffusion_unet.load_state_dict(new_dict, strict=True)
+    diffusion_unet.load_state_dict(checkpoint_diffusion_unet["unet_state_dict"], strict=True)
     scale_factor = checkpoint_diffusion_unet["scale_factor"].to(device)

     controlnet = define_instance(args, "controlnet_def").to(device)
     checkpoint_controlnet = torch.load(args.trained_controlnet_path)
-    new_dict = load_diffusion_ckpt(controlnet.state_dict(), checkpoint_controlnet["controlnet_state_dict"])
     monai.networks.utils.copy_model_state(controlnet, diffusion_unet.state_dict())
-    controlnet.load_state_dict(new_dict, strict=True)
+    controlnet.load_state_dict(checkpoint_controlnet["controlnet_state_dict"], strict=True)

     mask_generation_autoencoder = define_instance(args, "mask_generation_autoencoder_def").to(device)
-    checkpoint_mask_generation_autoencoder = load_autoencoder_ckpt(args.trained_mask_generation_autoencoder_path)
+    checkpoint_mask_generation_autoencoder = torch.load(args.trained_mask_generation_autoencoder_path)
     mask_generation_autoencoder.load_state_dict(checkpoint_mask_generation_autoencoder)

     mask_generation_diffusion_unet = define_instance(args, "mask_generation_diffusion_def").to(device)
     checkpoint_mask_generation_diffusion_unet = torch.load(args.trained_mask_generation_diffusion_path)
-    mask_generation_diffusion_unet.load_old_state_dict(checkpoint_mask_generation_diffusion_unet)
-    mask_generation_scale_factor = args.mask_generation_scale_factor
+    mask_generation_diffusion_unet.load_state_dict(checkpoint_mask_generation_diffusion_unet["unet_state_dict"])
+    mask_generation_scale_factor = checkpoint_mask_generation_diffusion_unet["scale_factor"]

     print("All the trained model weights have been loaded.")
```
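Because `inference.py` now reads both the weights and the scale factor straight from the diffusion checkpoints, a small sanity check (not part of the commit; the path is the one listed in the download table above) can verify a downloaded file before running inference:

```python
import torch

# Diffusion UNet checkpoint from the updated download list above.
ckpt = torch.load(
    "models/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt",
    map_location="cpu",
)
# The updated inference.py indexes both of these keys directly.
assert "unet_state_dict" in ckpt, "missing model weights"
assert "scale_factor" in ckpt, "missing scale factor"
```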

generation/maisi/scripts/utils.py
Lines changed: 0 additions & 90 deletions

```diff
@@ -669,96 +669,6 @@ def __call__(self, img: NdarrayOrTensor):
         return out


-def load_autoencoder_ckpt(load_autoencoder_path):
-    """
-    Load a state dict from an autoencoder checkpoint trained with
-    [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).
-
-    The loaded state dict is for
-    monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi.
-
-    Args:
-        load_autoencoder_path (str): Path to the autoencoder checkpoint file.
-
-    Returns:
-        dict: Processed state dictionary for the autoencoder.
-    """
-    checkpoint_autoencoder = torch.load(load_autoencoder_path)
-    new_state_dict = {}
-    for k, v in checkpoint_autoencoder.items():
-        if "decoder" in k and "conv" in k:
-            new_key = (
-                k.replace("conv.weight", "conv.conv.weight")
-                if "conv.weight" in k
-                else k.replace("conv.bias", "conv.conv.bias")
-            )
-            new_state_dict[new_key] = v
-        elif "encoder" in k and "conv" in k:
-            new_key = (
-                k.replace("conv.weight", "conv.conv.weight")
-                if "conv.weight" in k
-                else k.replace("conv.bias", "conv.conv.bias")
-            )
-            new_state_dict[new_key] = v
-        else:
-            new_state_dict[k] = v
-    checkpoint_autoencoder = new_state_dict
-    return checkpoint_autoencoder
-
-
-def load_diffusion_ckpt(new_state_dict: dict, old_state_dict: dict, verbose=False) -> dict:
-    """
-    Load a state dict from a DiffusionModelUNet trained with
-    [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).
-
-    The loaded state dict is for
-    monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi.
-
-    Args:
-        new_state_dict: state dict from the new model.
-        old_state_dict: state dict from the old model.
-    """
-    if verbose:
-        # print all new_state_dict keys that are not in old_state_dict
-        for k in new_state_dict:
-            if k not in old_state_dict:
-                logging.info(f"New key {k} not found in old state dict")
-        # and vice versa
-        for k in old_state_dict:
-            if k not in new_state_dict:
-                logging.info(f"Old key {k} not found in new state dict")
-
-    # copy over all matching keys
-    for k in new_state_dict:
-        if k in old_state_dict:
-            new_state_dict[k] = old_state_dict.pop(k)
-
-    # fix the attention blocks
-    # attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k]
-    attention_blocks = [k.replace(".attn.to_k.weight", "") for k in new_state_dict if "attn.to_k.weight" in k]
-    for block in attention_blocks:
-        new_state_dict[f"{block}.attn.to_q.weight"] = old_state_dict.pop(f"{block}.to_q.weight")
-        new_state_dict[f"{block}.attn.to_k.weight"] = old_state_dict.pop(f"{block}.to_k.weight")
-        new_state_dict[f"{block}.attn.to_v.weight"] = old_state_dict.pop(f"{block}.to_v.weight")
-        new_state_dict[f"{block}.attn.to_q.bias"] = old_state_dict.pop(f"{block}.to_q.bias")
-        new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias")
-        new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias")
-
-        # projection
-        new_state_dict[f"{block}.attn.out_proj.weight"] = old_state_dict.pop(f"{block}.proj_attn.weight")
-        new_state_dict[f"{block}.attn.out_proj.bias"] = old_state_dict.pop(f"{block}.proj_attn.bias")
-
-    # fix the upsample conv blocks which were renamed postconv
-    for k in new_state_dict:
-        if "postconv" in k:
-            old_name = k.replace("postconv", "conv")
-            # new_state_dict[k] = old_state_dict[old_name]
-            new_state_dict[k] = old_state_dict.pop(old_name)
-    if len(old_state_dict.keys()) > 0:
-        logging.info(f"{old_state_dict.keys()} remaining***********")
-    return new_state_dict
-
-
 def KL_loss(z_mu, z_sigma):
     """
     Compute the Kullback-Leibler (KL) divergence loss for a variational autoencoder (VAE).
```
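The two deleted helpers existed only to remap checkpoints trained with the older MONAI GenerativeModels layout. Anyone still holding such a legacy autoencoder checkpoint can convert it once, up front, by inlining the renaming that `load_autoencoder_ckpt` performed; a one-off sketch, with `old_ckpt.pt` / `new_ckpt.pt` as placeholder paths:

```python
import torch

# One-time conversion of a legacy (MONAI GenerativeModels) autoencoder
# checkpoint so it loads through the simplified torch.load path above.
old_sd = torch.load("old_ckpt.pt", map_location="cpu")  # placeholder path
new_sd = {}
for k, v in old_sd.items():
    if ("encoder" in k or "decoder" in k) and "conv" in k:
        # The same renaming the removed load_autoencoder_ckpt applied.
        k = k.replace("conv.weight", "conv.conv.weight").replace("conv.bias", "conv.conv.bias")
    new_sd[k] = v
torch.save(new_sd, "new_ckpt.pt")  # placeholder path
```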
