diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a0ee535..8927e46 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ default_language_version: python: python3 repos: - repo: https://github.com/ambv/black - rev: stable + rev: 23.3.0 hooks: - id: black - repo: https://github.com/pre-commit/pre-commit-hooks @@ -10,6 +10,6 @@ repos: hooks: - id: flake8 - repo: https://github.com/kynan/nbstripout - rev: 0.3.9 + rev: 0.6.1 hooks: - id: nbstripout diff --git a/data_prototype/containers.py b/data_prototype/containers.py index 7575c97..4d87446 100644 --- a/data_prototype/containers.py +++ b/data_prototype/containers.py @@ -117,7 +117,9 @@ def query( coord_transform: _MatplotlibTransform, size: Tuple[int, int], ) -> Tuple[Dict[str, Any], Union[str, int]]: - return {k: np.random.randn(*d.shape) for k, d in self._desc.items()}, str(uuid.uuid4()) + return {k: np.random.randn(*d.shape) for k, d in self._desc.items()}, str( + uuid.uuid4() + ) def describe(self) -> Dict[str, Desc]: return dict(self._desc) @@ -127,9 +129,15 @@ class FuncContainer: def __init__( self, # TODO: is this really the best spelling?! - xfuncs: Optional[Dict[str, Tuple[Tuple[Union[str, int], ...], Callable[[Any], Any]]]] = None, - yfuncs: Optional[Dict[str, Tuple[Tuple[Union[str, int], ...], Callable[[Any], Any]]]] = None, - xyfuncs: Optional[Dict[str, Tuple[Tuple[Union[str, int], ...], Callable[[Any, Any], Any]]]] = None, + xfuncs: Optional[ + Dict[str, Tuple[Tuple[Union[str, int], ...], Callable[[Any], Any]]] + ] = None, + yfuncs: Optional[ + Dict[str, Tuple[Tuple[Union[str, int], ...], Callable[[Any], Any]]] + ] = None, + xyfuncs: Optional[ + Dict[str, Tuple[Tuple[Union[str, int], ...], Callable[[Any, Any], Any]]] + ] = None, ): """ A container that wraps several functions. They are split into 3 categories: @@ -274,7 +282,10 @@ def query( coord_transform: _MatplotlibTransform, size: Tuple[int, int], ) -> Tuple[Dict[str, Any], Union[str, int]]: - return {self._index_name: self._data.index.values, self._col_name: self._data.values}, self._hash_key + return { + self._index_name: self._data.index.values, + self._col_name: self._data.values, + }, self._hash_key def describe(self) -> Dict[str, Desc]: return dict(self._desc) diff --git a/data_prototype/conversion_node.py b/data_prototype/conversion_node.py index fefc8bd..834896c 100644 --- a/data_prototype/conversion_node.py +++ b/data_prototype/conversion_node.py @@ -50,7 +50,9 @@ def from_nodes(cls, *nodes: ConversionNode, trim_keys=False): return cls(required, tuple(output), trim_keys, nodes) def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: - return super().evaluate({k: v for n in self.nodes for k, v in n.evaluate(input).items()}) + return super().evaluate( + {k: v for n in self.nodes for k, v in n.evaluate(input).items()} + ) @dataclass @@ -66,7 +68,9 @@ def from_mapping(cls, mapping: dict[str, str], trim_keys=False): return cls(required, tuple(output), trim_keys, mapping) def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: - return super().evaluate({**input, **{out: input[inp] for (inp, out) in self.mapping.items()}}) + return super().evaluate( + {**input, **{out: input[inp] for (inp, out) in self.mapping.items()}} + ) @dataclass @@ -91,7 +95,10 @@ def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: return super().evaluate( { **input, - **{k: func(**{p: input[p] for p in sig.parameters}) for (k, (func, sig)) in self._sigs.items()}, + **{ + k: func(**{p: input[p] for p in sig.parameters}) + for (k, (func, sig)) in self._sigs.items() + }, } ) diff --git a/data_prototype/patches.py b/data_prototype/patches.py index d993f2f..b062024 100644 --- a/data_prototype/patches.py +++ b/data_prototype/patches.py @@ -50,7 +50,11 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs): @_stale_wrapper def draw(self, renderer): - self._update_wrapped(self._query_and_transform(renderer, xunits=self._xunits, yunits=self._yunits)) + self._update_wrapped( + self._query_and_transform( + renderer, xunits=self._xunits, yunits=self._yunits + ) + ) return self._wrapped_instance.draw(renderer) def _update_wrapped(self, data): @@ -75,7 +79,14 @@ class RectangleWrapper(PatchWrapper): ) _xunits = ("x", "width") _yunits = ("y", "height") - required_keys = PatchWrapper.required_keys | {"x", "y", "width", "height", "angle", "rotation_point"} + required_keys = PatchWrapper.required_keys | { + "x", + "y", + "width", + "height", + "angle", + "rotation_point", + } def _update_wrapped(self, data): for k, v in data.items(): diff --git a/data_prototype/tests/test_containers.py b/data_prototype/tests/test_containers.py index 4fadefb..fb2fc7d 100644 --- a/data_prototype/tests/test_containers.py +++ b/data_prototype/tests/test_containers.py @@ -10,7 +10,9 @@ @pytest.fixture def ac(): - return containers.ArrayContainer(a=np.arange(5), b=np.arange(42, dtype=float).reshape(6, 7)) + return containers.ArrayContainer( + a=np.arange(5), b=np.arange(42, dtype=float).reshape(6, 7) + ) def _verify_describe(container): diff --git a/data_prototype/wrappers.py b/data_prototype/wrappers.py index e447e63..14ac51c 100644 --- a/data_prototype/wrappers.py +++ b/data_prototype/wrappers.py @@ -12,7 +12,10 @@ from matplotlib.patches import StepPatch as _StepPatch from matplotlib.text import Text as _Text import matplotlib.transforms as mtransforms -from matplotlib.collections import LineCollection as _LineCollection, PathCollection as _PathCollection +from matplotlib.collections import ( + LineCollection as _LineCollection, + PathCollection as _PathCollection, +) from matplotlib.artist import Artist as _Artist from data_prototype.containers import DataContainer, _MatplotlibTransform @@ -60,7 +63,9 @@ def identity(**kwargs): (_,) = kwargs.values() return _ - identity.__signature__ = inspect.Signature([inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD)]) + identity.__signature__ = inspect.Signature( + [inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD)] + ) return identity @@ -116,7 +121,9 @@ def draw(self, renderer): def _update_wrapped(self, data): raise NotImplementedError - def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str]) -> Dict[str, Any]: + def _query_and_transform( + self, renderer, *, xunits: List[str], yunits: List[str] + ) -> Dict[str, Any]: """ Helper to centralize the data querying and python-side transforms @@ -152,7 +159,9 @@ def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str] self._cache[cache_key] = transformed_data return transformed_data - def __init__(self, data, converters: ConversionNode | list[ConversionNode] | None, **kwargs): + def __init__( + self, data, converters: ConversionNode | list[ConversionNode] | None, **kwargs + ): super().__init__(**kwargs) self.data = data self._cache = LFUCache(64) @@ -180,9 +189,18 @@ def __getattr__(self, key): return getattr(self._wrapped_instance, key) def __setattr__(self, key, value): - if key in ("_wrapped_instance", "data", "_cache", "_converters", "stale", "_sigs"): + if key in ( + "_wrapped_instance", + "data", + "_cache", + "_converters", + "stale", + "_sigs", + ): super().__setattr__(key, value) - elif hasattr(self, "_wrapped_instance") and hasattr(self._wrapped_instance, key): + elif hasattr(self, "_wrapped_instance") and hasattr( + self._wrapped_instance, key + ): setattr(self._wrapped_instance, key, value) else: super().__setattr__(key, value) @@ -190,13 +208,24 @@ def __setattr__(self, key, value): class LineWrapper(ProxyWrapper): _wrapped_class = _Line2D - _privtized_methods = ("set_xdata", "set_ydata", "set_data", "get_xdata", "get_ydata", "get_data") + _privtized_methods = ( + "set_xdata", + "set_ydata", + "set_data", + "get_xdata", + "get_ydata", + "get_data", + ) required_keys = {"x", "y"} def __init__(self, data: DataContainer, converters=None, /, **kwargs): super().__init__(data, converters) - self._wrapped_instance = self._wrapped_class(np.array([]), np.array([]), **kwargs) - self._converters.insert(-1, RenameConversionNode.from_mapping({"x": "xdata", "y": "ydata"})) + self._wrapped_instance = self._wrapped_class( + np.array([]), np.array([]), **kwargs + ) + self._converters.insert( + -1, RenameConversionNode.from_mapping({"x": "xdata", "y": "ydata"}) + ) setters = [f[4:] for f in dir(self._wrapped_class) if f.startswith("set_")] self._converters[-1] = LimitKeysConversionNode.from_keys(setters) @@ -252,7 +281,9 @@ class ImageWrapper(ProxyWrapper): _wrapped_class = _AxesImage required_keys = {"xextent", "yextent", "image"} - def __init__(self, data: DataContainer, converters=None, /, cmap=None, norm=None, **kwargs): + def __init__( + self, data: DataContainer, converters=None, /, cmap=None, norm=None, **kwargs + ): converters = converters or [] if cmap is not None or norm is not None: if converters is not None and "image" in converters: @@ -261,7 +292,11 @@ def __init__(self, data: DataContainer, converters=None, /, cmap=None, norm=None cmap = mpl.colormaps["viridis"] if norm is None: raise ValueError("not sure how to do autoscaling yet") - converters.append(FunctionConversionNode.from_funcs({"image": lambda image: cmap(norm(image))})) + converters.append( + FunctionConversionNode.from_funcs( + {"image": lambda image: cmap(norm(image))} + ) + ) super().__init__(data, converters) kwargs.setdefault("origin", "lower") self._wrapped_instance = self._wrapped_class(None, **kwargs) @@ -341,7 +376,9 @@ def __setattr__(self, key, value): super().__setattr__(key, value) if hasattr(self, "_wrapped_instances"): # We can end up with out wrapped instance as part of init - children_have_attrs = [hasattr(c, key) for c in self._wrapped_instances.values()] + children_have_attrs = [ + hasattr(c, key) for c in self._wrapped_instances.values() + ] if any(children_have_attrs): if not all(children_have_attrs): raise Exception("mixed attributes 😱") @@ -356,7 +393,9 @@ def get_children(self): class ErrorbarWrapper(MultiProxyWrapper): required_keys = {"x", "y"} - expected_keys = {f"{axis}{dirc}" for axis in ["x", "y"] for dirc in ["upper", "lower"]} + expected_keys = { + f"{axis}{dirc}" for axis in ["x", "y"] for dirc in ["upper", "lower"] + } def __init__(self, data: DataContainer, converters=None, /, **kwargs): super().__init__(data, converters) @@ -387,7 +426,9 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs): def draw(self, renderer): self._update_wrapped( self._query_and_transform( - renderer, xunits=["x", "xupper", "xlower"], yunits=["y", "yupper", "ylower"] + renderer, + xunits=["x", "xupper", "xlower"], + yunits=["y", "yupper", "ylower"], ), ) for k, v in self._wrapped_instances.items(): diff --git a/docs/source/conf.py b/docs/source/conf.py index 03e060d..c6fc468 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -166,7 +166,11 @@ def matplotlib_reduced_latex_scraper(block, block_vars, gallery_conf, **kwargs): # html_logo = "_static/logo2.svg" html_theme_options = { - "logo": {"link": "index", "image_light": "images/logo2.svg", "image_dark": "images/logo_dark.svg"}, + "logo": { + "link": "index", + "image_light": "images/logo2.svg", + "image_dark": "images/logo_dark.svg", + }, } # Add any paths that contain custom static files (such as style sheets) here, @@ -214,7 +218,13 @@ def matplotlib_reduced_latex_scraper(block, block_vars, gallery_conf, **kwargs): # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, "data_prototype.tex", "data_prototype Documentation", "Contributors", "manual"), + ( + master_doc, + "data_prototype.tex", + "data_prototype Documentation", + "Contributors", + "manual", + ), ] @@ -222,7 +232,9 @@ def matplotlib_reduced_latex_scraper(block, block_vars, gallery_conf, **kwargs): # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [(master_doc, "data_prototype", "data_prototype Documentation", [author], 1)] +man_pages = [ + (master_doc, "data_prototype", "data_prototype Documentation", [author], 1) +] # -- Options for Texinfo output ------------------------------------------- diff --git a/examples/2Dfunc.py b/examples/2Dfunc.py index 623a7d1..deaad36 100644 --- a/examples/2Dfunc.py +++ b/examples/2Dfunc.py @@ -21,7 +21,10 @@ xyfuncs={ "xextent": ((2,), lambda x, y: [x[0], x[-1]]), "yextent": ((2,), lambda x, y: [y[0], y[-1]]), - "image": (("N", "M"), lambda x, y: np.sin(x).reshape(1, -1) * np.cos(y).reshape(-1, 1)), + "image": ( + ("N", "M"), + lambda x, y: np.sin(x).reshape(1, -1) * np.cos(y).reshape(-1, 1), + ), }, ) norm = Normalize(vmin=-1, vmax=1) diff --git a/examples/data_frame.py b/examples/data_frame.py index 00eaff7..19f8392 100644 --- a/examples/data_frame.py +++ b/examples/data_frame.py @@ -16,7 +16,9 @@ th = np.linspace(0, 4 * np.pi, 256) -dc1 = DataFrameContainer(pd.DataFrame({"x": th, "y": np.cos(th)}), index_name=None, col_names=lambda n: n) +dc1 = DataFrameContainer( + pd.DataFrame({"x": th, "y": np.cos(th)}), index_name=None, col_names=lambda n: n +) df = pd.DataFrame( { diff --git a/examples/errorbar.py b/examples/errorbar.py index 6e0451e..fcf03ff 100644 --- a/examples/errorbar.py +++ b/examples/errorbar.py @@ -21,7 +21,9 @@ xupper = x + 0.5 xlower = x - 0.5 -ac = ArrayContainer(x=x, y=y, yupper=yupper, ylower=ylower, xlower=xlower, xupper=xupper) +ac = ArrayContainer( + x=x, y=y, yupper=yupper, ylower=ylower, xlower=xlower, xupper=xupper +) fig, ax = plt.subplots() diff --git a/examples/hist.py b/examples/hist.py index 1853862..4fd45d3 100644 --- a/examples/hist.py +++ b/examples/hist.py @@ -13,7 +13,9 @@ from data_prototype.wrappers import StepWrapper from data_prototype.containers import HistContainer -hc = HistContainer(np.concatenate([np.random.randn(5000), 0.1 * np.random.randn(500) + 5]), 25) +hc = HistContainer( + np.concatenate([np.random.randn(5000), 0.1 * np.random.randn(500) + 5]), 25 +) fig, (ax1, ax2) = plt.subplots(1, 2, layout="constrained") diff --git a/examples/lissajous.py b/examples/lissajous.py index 0a9c710..6563e31 100644 --- a/examples/lissajous.py +++ b/examples/lissajous.py @@ -46,7 +46,9 @@ def query( ) -> Tuple[Dict[str, Any], Union[str, int]]: def next_time(): cur_time = time.time() - cur_time = np.array([cur_time, cur_time - 0.1, cur_time - 0.2, cur_time - 0.3]) + cur_time = np.array( + [cur_time, cur_time - 0.1, cur_time - 0.2, cur_time - 0.3] + ) phase = 15 * np.pi * (self.scale * cur_time % 60) / 150 marker_obj = mmarkers.MarkerStyle("o") @@ -54,7 +56,9 @@ def next_time(): "x": np.cos(5 * phase), "y": np.sin(3 * phase), "sizes": np.array([256]), - "paths": [marker_obj.get_path().transformed(marker_obj.get_transform())], + "paths": [ + marker_obj.get_path().transformed(marker_obj.get_transform()) + ], "edgecolors": "k", "facecolors": ["#4682b4ff", "#82b446aa", "#46b48288", "#8246b433"], "time": cur_time[0], diff --git a/examples/widgets.py b/examples/widgets.py index 32a824f..cfa25e3 100644 --- a/examples/widgets.py +++ b/examples/widgets.py @@ -35,7 +35,9 @@ def get_needed_keys(f, offset=1): s, # this line binds the correct sliders to the functions # and makes lambdas that match the API FuncContainer needs - lambda x, keys=get_needed_keys(f), f=f: f(x, *(sliders[k].val for k in keys)), + lambda x, keys=get_needed_keys(f), f=f: f( + x, *(sliders[k].val for k in keys) + ), ) for k, (s, f) in xfuncs.items() }, @@ -104,7 +106,8 @@ def _query_hash(self, coord_transform, size): "y": ( ("N",), # the y data needs all three sliders - lambda t, amplitude, frequency, phase: amplitude * np.sin(2 * np.pi * frequency * t + phase), + lambda t, amplitude, frequency, phase: amplitude + * np.sin(2 * np.pi * frequency * t + phase), ), # the color data has to take the x (because reasons), but just # needs the phase @@ -118,7 +121,9 @@ def _query_hash(self, coord_transform, size): lw = LineWrapper( fc, # color map phase (scaled to 2pi and wrapped to [0, 1]) - FunctionConversionNode.from_funcs({"color": lambda color: cmap((color / (2 * np.pi)) % 1)}), + FunctionConversionNode.from_funcs( + {"color": lambda color: cmap((color / (2 * np.pi)) % 1)} + ), lw=5, ) ax.add_artist(lw) @@ -127,6 +132,8 @@ def _query_hash(self, coord_transform, size): # Create a `matplotlib.widgets.Button` to reset the sliders to initial values. resetax = fig.add_axes([0.8, 0.025, 0.1, 0.04]) button = Button(resetax, "Reset", hovercolor="0.975") -button.on_clicked(lambda event: [sld.reset() for sld in (freq_slider, amp_slider, phase_slider)]) +button.on_clicked( + lambda event: [sld.reset() for sld in (freq_slider, amp_slider, phase_slider)] +) plt.show() diff --git a/pyproject.toml b/pyproject.toml index 3239179..8ce676f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [tool.black] -line-length = 115 +line-length = 88 include = '\.pyi?$' exclude = ''' /( diff --git a/setup.py b/setup.py index f164436..3e0bc45 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,11 @@ with open(path.join(here, "requirements.txt")) as requirements_file: # Parse requirements.txt, ignoring any commented-out lines. - requirements = [line for line in requirements_file.read().splitlines() if not line.startswith("#")] + requirements = [ + line + for line in requirements_file.read().splitlines() + if not line.startswith("#") + ] setup(