|
| 1 | +# Copyright (c) MONAI Consortium |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | +# Unless required by applicable law or agreed to in writing, software |
| 7 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | +# See the License for the specific language governing permissions and |
| 10 | +# limitations under the License. |
| 11 | + |
| 12 | +import nibabel as nib |
| 13 | +import numpy as np |
| 14 | + |
| 15 | + |
| 16 | +def get_masked_data(label_data, image_data, labels): |
| 17 | + """ |
| 18 | + Extracts and returns the image data corresponding to specified labels within a 3D volume. |
| 19 | +
|
| 20 | + This function efficiently masks the `image_data` array based on the provided `labels` in the `label_data` array. |
| 21 | + The function handles cases with both a large and small number of labels, optimizing performance accordingly. |
| 22 | +
|
| 23 | + Args: |
| 24 | + label_data (np.ndarray): A NumPy array containing label data, representing different anatomical |
| 25 | + regions or classes in a 3D medical image. |
| 26 | + image_data (np.ndarray): A NumPy array containing the image data from which the relevant regions |
| 27 | + will be extracted. |
| 28 | + labels (list of int): A list of integers representing the label values to be used for masking. |
| 29 | +
|
| 30 | + Returns: |
| 31 | + np.ndarray: A NumPy array containing the elements of `image_data` that correspond to the specified |
| 32 | + labels in `label_data`. If no labels are provided, an empty array is returned. |
| 33 | +
|
| 34 | + Raises: |
| 35 | + ValueError: If `image_data` and `label_data` do not have the same shape. |
| 36 | +
|
| 37 | + Example: |
| 38 | + label_int_dict = {"liver": [1], "kidney": [5, 14]} |
| 39 | + masked_data = get_masked_data(label_data, image_data, label_int_dict["kidney"]) |
| 40 | + """ |
| 41 | + |
| 42 | + # Check if the shapes of image_data and label_data match |
| 43 | + if image_data.shape != label_data.shape: |
| 44 | + raise ValueError( |
| 45 | + f"Shape mismatch: image_data has shape {image_data.shape}, " |
| 46 | + f"but label_data has shape {label_data.shape}. They must be the same." |
| 47 | + ) |
| 48 | + |
| 49 | + if not labels: |
| 50 | + return np.array([]) # Return an empty array if no labels are provided |
| 51 | + |
| 52 | + labels = list(set(labels)) # remove duplicate items |
| 53 | + |
| 54 | + # Optimize performance based on the number of labels |
| 55 | + num_label_acceleration_thresh = 3 |
| 56 | + if len(labels) >= num_label_acceleration_thresh: |
| 57 | + # if many labels, np.isin is faster |
| 58 | + mask = np.isin(label_data, labels) |
| 59 | + else: |
| 60 | + # Use logical OR to combine masks if the number of labels is small |
| 61 | + mask = np.zeros_like(label_data, dtype=bool) |
| 62 | + for label in labels: |
| 63 | + mask = np.logical_or(mask, label_data == label) |
| 64 | + |
| 65 | + # Retrieve the masked data |
| 66 | + masked_data = image_data[mask.astype(bool)] |
| 67 | + |
| 68 | + return masked_data |
| 69 | + |
| 70 | + |
| 71 | +def is_outlier(statistics, image_data, label_data, label_int_dict): |
| 72 | + """ |
| 73 | + Perform a quality check on the generated image by comparing its statistics with precomputed thresholds. |
| 74 | +
|
| 75 | + Args: |
| 76 | + statistics (dict): Dictionary containing precomputed statistics including mean +/- 3sigma ranges. |
| 77 | + image_data (np.ndarray): The image data to be checked, typically a 3D NumPy array. |
| 78 | + label_data (np.ndarray): The label data corresponding to the image, used for masking regions of interest. |
| 79 | + label_int_dict (dict): Dictionary mapping label names to their corresponding integer lists. |
| 80 | + e.g., label_int_dict = {"liver": [1], "kidney": [5, 14]} |
| 81 | +
|
| 82 | + Returns: |
| 83 | + dict: A dictionary with labels as keys, each containing the quality check result, |
| 84 | + including whether it's an outlier, the median value, and the thresholds used. |
| 85 | + If no data is found for a label, the median value will be `None` and `is_outlier` will be `False`. |
| 86 | +
|
| 87 | + Example: |
| 88 | + # Example input data |
| 89 | + statistics = { |
| 90 | + "liver": { |
| 91 | + "sigma_6_low": -21.596463547885904, |
| 92 | + "sigma_6_high": 156.27881534763367 |
| 93 | + }, |
| 94 | + "kidney": { |
| 95 | + "sigma_6_low": -15.0, |
| 96 | + "sigma_6_high": 120.0 |
| 97 | + } |
| 98 | + } |
| 99 | + label_int_dict = { |
| 100 | + "liver": [1], |
| 101 | + "kidney": [5, 14] |
| 102 | + } |
| 103 | + image_data = np.random.rand(100, 100, 100) # Replace with actual image data |
| 104 | + label_data = np.zeros((100, 100, 100)) # Replace with actual label data |
| 105 | + label_data[40:60, 40:60, 40:60] = 1 # Example region for liver |
| 106 | + label_data[70:90, 70:90, 70:90] = 5 # Example region for kidney |
| 107 | + result = is_outlier(statistics, image_data, label_data, label_int_dict) |
| 108 | + """ |
| 109 | + outlier_results = {} |
| 110 | + |
| 111 | + for label_name, stats in statistics.items(): |
| 112 | + # Get the thresholds from the statistics |
| 113 | + low_thresh = stats["sigma_6_low"] # or "sigma_12_low" depending on your needs |
| 114 | + high_thresh = stats["sigma_6_high"] # or "sigma_12_high" depending on your needs |
| 115 | + |
| 116 | + # Retrieve the corresponding label integers |
| 117 | + labels = label_int_dict.get(label_name, []) |
| 118 | + masked_data = get_masked_data(label_data, image_data, labels) |
| 119 | + masked_data = masked_data[~np.isnan(masked_data)] |
| 120 | + |
| 121 | + if len(masked_data) == 0 or masked_data.size == 0: |
| 122 | + outlier_results[label_name] = { |
| 123 | + "is_outlier": False, |
| 124 | + "median_value": None, |
| 125 | + "low_thresh": low_thresh, |
| 126 | + "high_thresh": high_thresh, |
| 127 | + } |
| 128 | + continue |
| 129 | + |
| 130 | + # Compute the median of the masked region |
| 131 | + median_value = np.nanmedian(masked_data) |
| 132 | + |
| 133 | + if np.isnan(median_value): |
| 134 | + median_value = None |
| 135 | + is_outlier = False |
| 136 | + else: |
| 137 | + # Determine if the median value is an outlier |
| 138 | + is_outlier = median_value < low_thresh or median_value > high_thresh |
| 139 | + |
| 140 | + outlier_results[label_name] = { |
| 141 | + "is_outlier": is_outlier, |
| 142 | + "median_value": median_value, |
| 143 | + "low_thresh": low_thresh, |
| 144 | + "high_thresh": high_thresh, |
| 145 | + } |
| 146 | + |
| 147 | + return outlier_results |
0 commit comments