MAISI Quality check #1789


Merged: 27 commits, Aug 21, 2024

Commits (27):
- 66ae578 add quality check (Can-Zhao, Aug 17, 2024)
- 28f5c96 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 17, 2024)
- 7c0bd11 add quality check (Can-Zhao, Aug 17, 2024)
- 91ad068 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 17, 2024)
- 494ed36 refactor (Can-Zhao, Aug 17, 2024)
- ad5afba add docstring (Can-Zhao, Aug 18, 2024)
- 2d87cd5 Merge branch 'main' into maisi_quality (Can-Zhao, Aug 20, 2024)
- 74d4e87 rm unused import, correct typo, add statistics file (Can-Zhao, Aug 20, 2024)
- 38daf8a [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 20, 2024)
- 38aacc3 np.nanmedian (Can-Zhao, Aug 20, 2024)
- b10550e Merge branch 'maisi_quality' of https://github.com/Can-Zhao/tutorials… (Can-Zhao, Aug 20, 2024)
- d1e096c add logging (Can-Zhao, Aug 20, 2024)
- 0db8d90 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 20, 2024)
- 4e0c77e add description on num_label_acceleration_thresh (Can-Zhao, Aug 20, 2024)
- dd57eb3 Merge branch 'maisi_quality' of https://github.com/Can-Zhao/tutorials… (Can-Zhao, Aug 20, 2024)
- 36b3f0e add description on quality check (Can-Zhao, Aug 20, 2024)
- 83d8811 add description on input FOV (Can-Zhao, Aug 20, 2024)
- 4e80b50 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 20, 2024)
- 2fdedb0 add description on input FOV (Can-Zhao, Aug 20, 2024)
- e1d0821 add description on input FOV (Can-Zhao, Aug 20, 2024)
- 88d84cd Merge branch 'maisi_quality' of https://github.com/Can-Zhao/tutorials… (Can-Zhao, Aug 20, 2024)
- c38039d typo (Can-Zhao, Aug 20, 2024)
- 828c4e3 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 20, 2024)
- b4b2159 typo (Can-Zhao, Aug 20, 2024)
- 6caf7ad add description on input FOV (Can-Zhao, Aug 20, 2024)
- 7622ad6 add description on input FOV (Can-Zhao, Aug 20, 2024)
- 8266e9d update checking on input FOV (Can-Zhao, Aug 20, 2024)
11 changes: 7 additions & 4 deletions generation/maisi/README.md
@@ -54,27 +54,30 @@ MAISI is based on the following papers:
Network definition is stored in [./configs/config_maisi.json](./configs/config_maisi.json). Training and inference should use the same [./configs/config_maisi.json](./configs/config_maisi.json).

### 2. Model Inference
#### Inference parameters:
The information for the inference input, like body region and anatomy to generate, is stored in [./configs/config_infer.json](./configs/config_infer.json). Please feel free to play with it. Here are the details of the parameters.

- `"num_output_samples"`: int, the number of image/mask pairs to generate.
- `"spacing"`: voxel size of the generated images. E.g., if set to `[1.5, 1.5, 2.0]`, it will generate images with a voxel size of 1.5x1.5x2.0 mm.
- `"output_size"`: volume size of the generated images. E.g., if set to `[512, 512, 256]`, it will generate images of size 512x512x256. Each dimension needs to be divisible by 16. If your GPU memory is limited, reduce these values. Note that `"spacing"` and `"output_size"` together determine the output field of view (FOV). For example, if they are set to `[1.5, 1.5, 2.0]` mm and `[512, 512, 256]`, the FOV is 768x768x512 mm. We recommend an FOV of at least 256 mm in the x- and y-axes for the head, and at least 384 mm for other body regions such as the abdomen. There is no such restriction on the z-axis.
- `"controllable_anatomy_size"`: a list of controllable anatomies and their size scales (0--1). E.g., if set to `[["liver", 0.5],["hepatic tumor", 0.3]]`, the generated image will contain a liver of median size (around the 50th percentile) and a relatively small hepatic tumor (around the 30th percentile). The output will contain paired images and segmentation masks for the controllable anatomies.
- `"body_region"`: If "controllable_anatomy_size" is not specified, "body_region" will be used to constrain the region of generated images. It needs to be chosen from "head", "chest", "thorax", "abdomen", "pelvis", "lower".
- `"anatomy_list"`: If "controllable_anatomy_size" is not specified, the output will contain paired image and segmentation mask for the anatomy in "./configs/label_dict.json".
- `"autoencoder_sliding_window_infer_size"`: to save GPU memory, sliding window inference is used when decoding latents into images if `"output_size"` is large. This is the patch size of the sliding window. Smaller values reduce GPU memory usage but increase inference time. Each dimension needs to be divisible by 16.
- `"autoencoder_sliding_window_infer_overlap"`: float between 0 and 1. Larger values reduce stitching artifacts between patches during sliding window inference but increase inference time. If you do not observe seam lines in the generated images, you can use a smaller value to save inference time.
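The FOV recommendation above can be expressed as a quick sanity check. The `check_fov` helper below is illustrative only (it is not part of this repo), assuming the spacing/output-size/body-region semantics described in the parameter list:

```python
# Illustrative sketch (not part of the repo): validate the recommended
# minimum field of view (FOV) implied by "spacing" and "output_size".

def check_fov(spacing, output_size, body_region):
    """Return True if the x/y FOV meets the recommended minimum in mm."""
    fov = [s * n for s, n in zip(spacing, output_size)]  # FOV per axis, in mm
    min_xy = 256.0 if body_region == "head" else 384.0   # recommended minimums
    # Only x and y are constrained; there is no restriction on the z-axis.
    return all(axis_fov >= min_xy for axis_fov in fov[:2])

# 1.5 mm spacing x 512 voxels = 768 mm FOV, which satisfies the 384 mm
# recommendation for the abdomen.
ok = check_fov([1.5, 1.5, 2.0], [512, 512, 256], "abdomen")
```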


Please refer to [maisi_inference_tutorial.ipynb](maisi_inference_tutorial.ipynb) for a tutorial on MAISI model inference.

#### Execute Inference:
Run the inference script as follows:
```bash
export MONAI_DATA_DIRECTORY=<dir_you_will_download_data>
python -m scripts.inference -c ./configs/config_maisi.json -i ./configs/config_infer.json -e ./configs/environment.json --random-seed 0
```

#### Quality Check:
We have implemented a quality check function for the generated CT images. The main idea is to verify that the Hounsfield unit (HU) intensities of major organs in a generated CT image stay within expected ranges. For each training image used for the diffusion network, we computed the median HU value of several major organs, summarized the statistics of these medians, and saved them to [./configs/image_median_statistics.json](./configs/image_median_statistics.json). During inference, for each generated image, we compute the median HU values of the same organs and check whether they fall within the normal range.
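A minimal self-contained sketch of this idea (simplified from the PR's `scripts/quality_check.py`; the threshold numbers below are rounded `sigma_6` bounds for the liver from the statistics file):

```python
import numpy as np

def organ_median_in_range(image, label, organ_ints, low, high):
    """Simplified version of the PR's check: is the organ's median HU in [low, high]?"""
    mask = np.isin(label, organ_ints)      # voxels belonging to this organ
    values = image[mask]
    values = values[~np.isnan(values)]     # drop NaNs before taking the median
    if values.size == 0:                   # organ absent: treat as not an outlier
        return True
    median = float(np.nanmedian(values))
    return low <= median <= high

# Synthetic example: liver voxels (label 1) around 80 HU fall inside the
# approximate liver range from configs/image_median_statistics.json.
image = np.full((4, 4, 4), 80.0)
label = np.zeros((4, 4, 4), dtype=int)
label[1:3, 1:3, 1:3] = 1
in_range = organ_median_in_range(image, label, [1], -21.6, 156.3)
```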

### 3. Model Training
Training data preparation can be found in [./data/README.md](./data/README.md)

72 changes: 72 additions & 0 deletions generation/maisi/configs/image_median_statistics.json
@@ -0,0 +1,72 @@
{
    "liver": {
        "min_median": -14.0,
        "max_median": 1000.0,
        "percentile_0_5": 9.530000000000001,
        "percentile_99_5": 162.0,
        "sigma_6_low": -21.596463547885904,
        "sigma_6_high": 156.27881534763367,
        "sigma_12_low": -110.53410299564568,
        "sigma_12_high": 245.21645479539342
    },
    "spleen": {
        "min_median": -69.0,
        "max_median": 1000.0,
        "percentile_0_5": 16.925000000000004,
        "percentile_99_5": 184.07500000000073,
        "sigma_6_low": -43.133891656525165,
        "sigma_6_high": 177.40494997185993,
        "sigma_12_low": -153.4033124707177,
        "sigma_12_high": 287.6743707860525
    },
    "pancreas": {
        "min_median": -124.0,
        "max_median": 1000.0,
        "percentile_0_5": -29.0,
        "percentile_99_5": 145.92000000000007,
        "sigma_6_low": -56.59382515620725,
        "sigma_6_high": 149.50627399318438,
        "sigma_12_low": -159.64387473090306,
        "sigma_12_high": 252.5563235678802
    },
    "kidney": {
        "min_median": -165.5,
        "max_median": 819.0,
        "percentile_0_5": -40.0,
        "percentile_99_5": 254.61999999999898,
        "sigma_6_low": -130.56375604853028,
        "sigma_6_high": 267.28163511081016,
        "sigma_12_low": -329.4864516282005,
        "sigma_12_high": 466.20433069048045
    },
    "lung": {
        "min_median": -1000.0,
        "max_median": 65.0,
        "percentile_0_5": -937.0,
        "percentile_99_5": -366.9500000000007,
        "sigma_6_low": -1088.5583843889117,
        "sigma_6_high": -551.8503346949108,
        "sigma_12_low": -1356.912409235912,
        "sigma_12_high": -283.4963098479103
    },
    "bone": {
        "min_median": 77.5,
        "max_median": 1000.0,
        "percentile_0_5": 136.45499999999998,
        "percentile_99_5": 551.6350000000002,
        "sigma_6_low": 71.39901958080469,
        "sigma_6_high": 471.9957615639765,
        "sigma_12_low": -128.8993514107812,
        "sigma_12_high": 672.2941325555623
    },
    "brain": {
        "min_median": -1000.0,
        "max_median": 238.0,
        "percentile_0_5": -951.0,
        "percentile_99_5": 126.25,
        "sigma_6_low": -304.8208236135867,
        "sigma_6_high": 369.5118535139189,
        "sigma_12_low": -641.9871621773394,
        "sigma_12_high": 706.6781920776717
    }
}
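Our reading of these fields (not stated explicitly in the PR): for each organ, the per-training-image median HU values are summarized by their min/max, 0.5th/99.5th percentiles, and symmetric bands around the mean, where `sigma_6_*` appears to span mean ± 3σ (a 6σ-wide band) and `sigma_12_*` mean ± 6σ. A hedged sketch of how such an entry might be produced:

```python
import numpy as np

def summarize_medians(medians):
    """Build an entry shaped like the JSON above from per-image organ medians.

    Assumption (ours, not confirmed by the PR): sigma_6_* = mean +/- 3*std and
    sigma_12_* = mean +/- 6*std of the per-image median values.
    """
    m = np.asarray(medians, dtype=float)
    mean, std = m.mean(), m.std()
    return {
        "min_median": float(m.min()),
        "max_median": float(m.max()),
        "percentile_0_5": float(np.percentile(m, 0.5)),
        "percentile_99_5": float(np.percentile(m, 99.5)),
        "sigma_6_low": float(mean - 3 * std),
        "sigma_6_high": float(mean + 3 * std),
        "sigma_12_low": float(mean - 6 * std),
        "sigma_12_high": float(mean + 6 * std),
    }

# Toy example: the sigma bands are symmetric around the mean of the medians.
entry = summarize_medians([0.0, 10.0, 20.0])
```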
2 changes: 1 addition & 1 deletion generation/maisi/scripts/diff_model_infer.py
@@ -27,7 +27,7 @@

from .diff_model_setting import initialize_distributed, load_config, setup_logging
from .sample import ReconModel
from .utils import define_instance, load_autoencoder_ckpt
from .utils import define_instance


def set_random_seed(seed: int) -> int:
2 changes: 1 addition & 1 deletion generation/maisi/scripts/infer_controlnet.py
@@ -24,7 +24,7 @@
from monai.utils import RankFilter

from .sample import ldm_conditional_sample_one_image
from .utils import define_instance, load_autoencoder_ckpt, prepare_maisi_controlnet_json_dataloader, setup_ddp
from .utils import define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp


@torch.inference_mode()
147 changes: 147 additions & 0 deletions generation/maisi/scripts/quality_check.py
@@ -0,0 +1,147 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np


def get_masked_data(label_data, image_data, labels):
    """
    Extracts and returns the image data corresponding to specified labels within a 3D volume.

    This function efficiently masks the `image_data` array based on the provided `labels` in the `label_data` array.
    The function handles cases with both a large and small number of labels, optimizing performance accordingly.

    Args:
        label_data (np.ndarray): A NumPy array containing label data, representing different anatomical
            regions or classes in a 3D medical image.
        image_data (np.ndarray): A NumPy array containing the image data from which the relevant regions
            will be extracted.
        labels (list of int): A list of integers representing the label values to be used for masking.

    Returns:
        np.ndarray: A NumPy array containing the elements of `image_data` that correspond to the specified
            labels in `label_data`. If no labels are provided, an empty array is returned.

    Raises:
        ValueError: If `image_data` and `label_data` do not have the same shape.

    Example:
        label_int_dict = {"liver": [1], "kidney": [5, 14]}
        masked_data = get_masked_data(label_data, image_data, label_int_dict["kidney"])
    """

    # Check if the shapes of image_data and label_data match
    if image_data.shape != label_data.shape:
        raise ValueError(
            f"Shape mismatch: image_data has shape {image_data.shape}, "
            f"but label_data has shape {label_data.shape}. They must be the same."
        )

    if not labels:
        return np.array([])  # Return an empty array if no labels are provided

    labels = list(set(labels))  # remove duplicate items

    # Optimize performance based on the number of labels
    num_label_acceleration_thresh = 3
    if len(labels) >= num_label_acceleration_thresh:
        # If there are many labels, np.isin is faster
        mask = np.isin(label_data, labels)
    else:
        # Use logical OR to combine masks if the number of labels is small
        mask = np.zeros_like(label_data, dtype=bool)
        for label in labels:
            mask = np.logical_or(mask, label_data == label)

    # Retrieve the masked data (the mask is already boolean)
    masked_data = image_data[mask]

    return masked_data


def is_outlier(statistics, image_data, label_data, label_int_dict):
    """
    Perform a quality check on the generated image by comparing its statistics with precomputed thresholds.

    Args:
        statistics (dict): Dictionary containing precomputed statistics including mean +/- 3sigma ranges.
        image_data (np.ndarray): The image data to be checked, typically a 3D NumPy array.
        label_data (np.ndarray): The label data corresponding to the image, used for masking regions of interest.
        label_int_dict (dict): Dictionary mapping label names to their corresponding integer lists.
            e.g., label_int_dict = {"liver": [1], "kidney": [5, 14]}

    Returns:
        dict: A dictionary with labels as keys, each containing the quality check result,
            including whether it's an outlier, the median value, and the thresholds used.
            If no data is found for a label, the median value will be `None` and `is_outlier` will be `False`.

    Example:
        # Example input data
        statistics = {
            "liver": {
                "sigma_6_low": -21.596463547885904,
                "sigma_6_high": 156.27881534763367
            },
            "kidney": {
                "sigma_6_low": -15.0,
                "sigma_6_high": 120.0
            }
        }
        label_int_dict = {
            "liver": [1],
            "kidney": [5, 14]
        }
        image_data = np.random.rand(100, 100, 100)  # Replace with actual image data
        label_data = np.zeros((100, 100, 100))  # Replace with actual label data
        label_data[40:60, 40:60, 40:60] = 1  # Example region for liver
        label_data[70:90, 70:90, 70:90] = 5  # Example region for kidney
        result = is_outlier(statistics, image_data, label_data, label_int_dict)
    """
    outlier_results = {}

    for label_name, stats in statistics.items():
        # Get the thresholds from the statistics
        low_thresh = stats["sigma_6_low"]  # or "sigma_12_low" depending on your needs
        high_thresh = stats["sigma_6_high"]  # or "sigma_12_high" depending on your needs

        # Retrieve the corresponding label integers
        labels = label_int_dict.get(label_name, [])
        masked_data = get_masked_data(label_data, image_data, labels)
        masked_data = masked_data[~np.isnan(masked_data)]

        if masked_data.size == 0:
            outlier_results[label_name] = {
                "is_outlier": False,
                "median_value": None,
                "low_thresh": low_thresh,
                "high_thresh": high_thresh,
            }
            continue

        # Compute the median of the masked region
        median_value = np.nanmedian(masked_data)

        if np.isnan(median_value):
            median_value = None
            is_outlier = False
        else:
            # Determine if the median value is an outlier
            is_outlier = median_value < low_thresh or median_value > high_thresh

        outlier_results[label_name] = {
            "is_outlier": is_outlier,
            "median_value": median_value,
            "low_thresh": low_thresh,
            "high_thresh": high_thresh,
        }

    return outlier_results