Open
Description
Here is some code I wrote for the IO team to draw heatmaps from larray arrays. I am unsure whether it should be included in larray itself.
def heatmap(arr: Array, y_axes=None, x_axes=None, numhaxes=1, axes_names=True, ax=None, **kwargs):
    """Plot an ND array as a heatmap.

    By default it uses the last array axis as the X axis and the other array axes as Y axes (like the viewer
    table). When several array axes end up in the same direction they are combined; only the first axis in each
    "direction" has its name and labels shown, and the blocks of that first axis are separated by grid lines.

    Parameters
    ----------
    arr : Array
        data to display. Arrays with fewer than 2 dimensions get an extra anonymous axis.
    y_axes : int, str, Axis, tuple or AxisCollection, optional
        axis or axes to use on the Y axis. Defaults to all array axes except the X axes (see `x_axes` and
        `numhaxes`).
    x_axes : int, str, Axis, tuple or AxisCollection, optional
        axis or axes to use on the X axis. Defaults to all array axes except `y_axes`.
    numhaxes : int, optional
        if neither x_axes nor y_axes is specified, use the last `numhaxes` array axes as X axes. Defaults to 1.
    axes_names : bool, optional
        whether or not to show axes names. Defaults to True.
    ax : matplotlib axes object, optional
        axes to draw on. Defaults to new axes created with ``plt.subplots()``.
    **kwargs
        any extra keyword argument is passed to pcolormesh. Likely of interest are cmap, vmin, vmax, antialiased
        or shading.

    Returns
    -------
    matplotlib.AxesSubplot
        the axes the heatmap was drawn on.
    """
    if arr.ndim < 2:
        arr = arr.expand(Axis([''], ''))

    # Normalize explicitly given axes first, so that a single int/str/Axis also yields an AxisCollection.
    # A scalar key presumably makes arr.axes[...] return a lone Axis, which could not be concatenated or
    # sliced below.
    if y_axes is not None:
        y_axes = _heatmap_axes(arr, y_axes)
    if x_axes is not None:
        x_axes = _heatmap_axes(arr, x_axes)
    if y_axes is None:
        y_axes = arr.axes - x_axes if x_axes is not None else arr.axes[:-numhaxes]
    if x_axes is None:
        x_axes = arr.axes - y_axes

    arr = arr.transpose(y_axes + x_axes).combine_axes([y_axes, x_axes])
    # block size is the size of the other (non first) combined axes, i.e. the number of cells spanned by each
    # label of the first axis in that direction
    x_block_size = x_axes[1:].size
    y_block_size = y_axes[1:].size

    if ax is None:
        _, ax = plt.subplots()
    ax.pcolormesh(arr, **kwargs)

    # place major ticks in the middle of blocks so that labels are centered
    xtick_positions, xtick_labels = _heatmap_block_ticks(ax.xaxis, x_axes[0].labels, x_block_size)
    ax.set_xticks(xtick_positions)
    ax.set_xticklabels(xtick_labels, rotation=0)
    ytick_positions, ytick_labels = _heatmap_block_ticks(ax.yaxis, y_axes[0].labels, y_block_size)
    ax.set_yticks(ytick_positions)
    ax.set_yticklabels(ytick_labels, rotation=90, va='center')

    # display the first row at the top and the X labels above the plot (like a table)
    ax.invert_yaxis()
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')

    # enable grid lines for minor ticks on axes when we have several "levels" for that axis
    if len(x_axes) > 1:
        # place minor ticks for grid lines between each block on the main axis
        ax.set_xticks(np.arange(x_block_size, x_axes.size, x_block_size), minor=True)
        ax.grid(True, axis='x', which='minor')
        # hide ticks on x axis
        ax.tick_params(axis='x', which='both', bottom=False, top=False)
    if len(y_axes) > 1:
        ax.set_yticks(np.arange(y_block_size, y_axes.size, y_block_size), minor=True)
        ax.grid(True, axis='y', which='minor')
        # hide ticks on y axis
        ax.tick_params(axis='y', which='both', left=False, right=False)

    # set axes names
    if axes_names:
        ax.set_xlabel(x_axes[0].name)
        ax.set_ylabel(y_axes[0].name)
    return ax


def _heatmap_axes(arr, axes):
    """Return the axes of `arr` designated by `axes` (int, str, Axis, tuple or AxisCollection) as a collection."""
    if isinstance(axes, (int, str, Axis)):
        axes = [axes]
    return arr.axes[axes]


def _heatmap_block_ticks(axis, labels, block_size):
    """Compute major tick positions and labels centered on the blocks of the first (outermost) combined axis.

    Parameters
    ----------
    axis : matplotlib Axis
        ``ax.xaxis`` or ``ax.yaxis``; its current major locator chooses which labels to show when there are
        too many of them.
    labels : sequence
        labels of the first array axis in that direction.
    block_size : int
        number of cells spanned by each label.

    Returns
    -------
    tuple of (list of float, list)
        tick positions (in cell units) and the matching labels.
    """
    num_labels = len(labels)
    # Query the locator in label-index units (0 .. num_labels) rather than in combined-cell units. This keeps
    # the selected ticks mapped to first-axis labels even when several axes are combined (block_size > 1).
    auto_ticks = axis.get_major_locator().tick_values(0, num_labels)
    if num_labels >= len(auto_ticks):
        # Too many labels to show them all, so only label the positions chosen by the locator. Truncate them
        # to integer label indices, then drop out-of-range values (locators may return ticks beyond the
        # limits) and duplicates, keeping the original order.
        indices = list(dict.fromkeys(int(t) for t in auto_ticks if 0 <= t < num_labels))
    else:
        indices = range(num_labels)
    positions = [i * block_size + block_size / 2 for i in indices]
    return positions, [labels[i] for i in indices]