Skip to content

Commit c7031c3

Browse files
committed
RF: Add simpler Analyze-compatible helper
1 parent d541fdd commit c7031c3

File tree

1 file changed

+78
-3
lines changed

1 file changed

+78
-3
lines changed

nibabel/nifti1.py

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2153,10 +2153,10 @@ def get_data_dtype(self, finalize=False):
21532153

21542154
datatype = None
21552155
if self._dtype_alias == 'compat':
2156-
datatype = _get_smallest_dtype(self._dataobj)
2156+
datatype = _get_analyze_compat_dtype(self._dataobj)
21572157
descrip = "an Analyze-compatible dtype"
21582158
elif self._dtype_alias == 'smallest':
2159-
datatype = _get_smallest_dtype(self._dataobj, ftypes=())
2159+
datatype = _get_smallest_dtype(self._dataobj)
21602160
descrip = "an integer type with fewer than 64 bits"
21612161
else:
21622162
raise ValueError(f"Unknown dtype alias {self._dtype_alias}.")
@@ -2266,7 +2266,7 @@ def save(img, filename):
22662266
def _get_smallest_dtype(
22672267
arr,
22682268
itypes=(np.uint8, np.int16, np.int32),
2269-
ftypes=(np.float32,),
2269+
ftypes=(),
22702270
):
22712271
""" Return the smallest "sensible" dtype that will hold the array data
22722272
@@ -2310,3 +2310,78 @@ def _get_smallest_dtype(
23102310
dtinfo = info(dt)
23112311
if dtinfo.min <= mn and mx <= dtinfo.max:
23122312
return np.dtype(dt)
2313+
2314+
2315+
def _get_analyze_compat_dtype(arr):
2316+
""" Return an Analyze-compatible dtype that ``arr`` can be safely cast to
2317+
2318+
Analyze-compatible types are returned without inspection:
2319+
2320+
>>> _get_analyze_compat_dtype(np.uint8([0, 1]))
2321+
dtype('uint8')
2322+
>>> _get_analyze_compat_dtype(np.int16([0, 1]))
2323+
dtype('int16')
2324+
>>> _get_analyze_compat_dtype(np.int32([0, 1]))
2325+
dtype('int32')
2326+
>>> _get_analyze_compat_dtype(np.float32([0, 1]))
2327+
dtype('float32')
2328+
2329+
Signed ``int8`` are cast to ``uint8`` or ``int16`` based on value ranges:
2330+
2331+
>>> _get_analyze_compat_dtype(np.int8([0, 1]))
2332+
dtype('uint8')
2333+
>>> _get_analyze_compat_dtype(np.int8([-1, 1]))
2334+
dtype('int16')
2335+
2336+
Unsigned ``uint16`` are cast to ``int16`` or ``int32`` based on value ranges:
2337+
2338+
>>> _get_analyze_compat_dtype(np.uint16([32767]))
2339+
dtype('int16')
2340+
>>> _get_analyze_compat_dtype(np.uint16([65535]))
2341+
dtype('int32')
2342+
2343+
``int32`` is returned for integer types and ``float32`` for floating point types:
2344+
2345+
>>> _get_analyze_compat_dtype(np.array([-1, 1]))
2346+
dtype('int32')
2347+
>>> _get_analyze_compat_dtype(np.array([-1., 1.]))
2348+
dtype('float32')
2349+
2350+
If the value ranges exceed 4 bytes or cannot be cast, then a ``ValueError`` is raised:
2351+
2352+
>>> _get_analyze_compat_dtype(np.array([0, 4294967295]))
2353+
Traceback (most recent call last):
2354+
...
2355+
ValueError: Cannot find analyze-compatible dtype for array with dtype=int64 (min=0, max=4294967295)
2356+
2357+
>>> _get_analyze_compat_dtype([0., 2.e40])
2358+
Traceback (most recent call last):
2359+
...
2360+
ValueError: Cannot find analyze-compatible dtype for array with dtype=float64 (min=0.0, max=2e+40)
2361+
2362+
Note that real-valued complex arrays cannot be safely cast.
2363+
2364+
>>> _get_analyze_compat_dtype(np.array([1+0j]))
2365+
Traceback (most recent call last):
2366+
...
2367+
ValueError: Cannot find analyze-compatible dtype for array with dtype=complex128 (min=(1+0j), max=(1+0j))
2368+
"""
2369+
arr = np.asanyarray(arr)
2370+
dtype = arr.dtype
2371+
if dtype in (np.uint8, np.int16, np.int32, np.float32):
2372+
return dtype
2373+
2374+
if dtype == np.int8:
2375+
return np.dtype('uint8' if arr.min() >= 0 else 'int16')
2376+
elif dtype == np.uint16:
2377+
return np.dtype('int16' if arr.max() <= np.iinfo(np.int16).max else 'int32')
2378+
2379+
mn, mx = arr.min(), arr.max()
2380+
if np.can_cast(mn, np.int32) and np.can_cast(mx, np.int32):
2381+
return np.dtype('int32')
2382+
if np.can_cast(mn, np.float32) and np.can_cast(mx, np.float32):
2383+
return np.dtype('float32')
2384+
2385+
raise ValueError(
2386+
f"Cannot find analyze-compatible dtype for array with {dtype=!s} (min={mn}, max={mx})"
2387+
)

0 commit comments

Comments
 (0)