diff --git a/pyproject.toml b/pyproject.toml index cd124a4..948ffb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,11 +32,7 @@ classifiers = [ # NB: Keep this in sync with environment.yml AND dev-environment.yml! requires-python = ">=3.9" -dependencies = [ - "jpype1 >= 1.3.0", - "jgo", - "cjdk", -] +dependencies = ["jpype1 >= 1.3.0", "jgo", "cjdk", "stubgenj"] [project.optional-dependencies] # NB: Keep this in sync with dev-environment.yml! @@ -50,9 +46,17 @@ dev = [ "pandas", "ruff", "toml", - "validate-pyproject[all]" + "validate-pyproject[all]", ] +[project.scripts] +scyjava-stubgen = "scyjava._stubs._cli:main" + +[project.entry-points.hatch] +scyjava = "scyjava._stubs._hatchling_plugin" +[project.entry-points."distutils.commands"] +build_py = "scyjava_stubgen.build:build_py" + [project.urls] homepage = "https://github.com/scijava/scyjava" documentation = "https://github.com/scijava/scyjava/blob/main/README.md" diff --git a/src/scyjava/_jvm.py b/src/scyjava/_jvm.py index 224ac61..2f0e4a8 100644 --- a/src/scyjava/_jvm.py +++ b/src/scyjava/_jvm.py @@ -363,7 +363,7 @@ def is_awt_initialized() -> bool: return False Thread = scyjava.jimport("java.lang.Thread") threads = Thread.getAllStackTraces().keySet() - return any(t.getName().startsWith("AWT-") for t in threads) + return any(str(t.getName()).startswith("AWT-") for t in threads) def when_jvm_starts(f) -> None: diff --git a/src/scyjava/_stubs/__init__.py b/src/scyjava/_stubs/__init__.py new file mode 100644 index 0000000..d6a5e7c --- /dev/null +++ b/src/scyjava/_stubs/__init__.py @@ -0,0 +1,4 @@ +from ._dynamic_import import setup_java_imports +from ._genstubs import generate_stubs + +__all__ = ["setup_java_imports", "generate_stubs"] diff --git a/src/scyjava/_stubs/_cli.py b/src/scyjava/_stubs/_cli.py new file mode 100644 index 0000000..3a4d7df --- /dev/null +++ b/src/scyjava/_stubs/_cli.py @@ -0,0 +1,154 @@ +"""The scyjava-stubs executable.""" + +from __future__ import annotations + +import argparse +import importlib +import importlib.util +import logging +import sys +from pathlib import Path + +from ._genstubs import generate_stubs + + +def main() -> None: + """The main entry point for the scyjava-stubs executable.""" + logging.basicConfig(level="INFO") + parser = argparse.ArgumentParser( + description="Generate Python Type Stubs for Java classes." + ) + parser.add_argument( + "endpoints", + type=str, + nargs="+", + help="Maven endpoints to install and use (e.g. org.myproject:myproject:1.0.0)", + ) + parser.add_argument( + "--prefix", + type=str, + help="package prefixes to generate stubs for (e.g. org.myproject), " + "may be used multiple times. If not specified, prefixes are gleaned from the " + "downloaded artifacts.", + action="append", + default=[], + metavar="PREFIX", + dest="prefix", + ) + path_group = parser.add_mutually_exclusive_group() + path_group.add_argument( + "--output-dir", + type=str, + default=None, + help="Filesystem path to write stubs to.", + ) + path_group.add_argument( + "--output-python-path", + type=str, + default=None, + help="Python path to write stubs to (e.g. 'scyjava.types').", + ) + parser.add_argument( + "--convert-strings", + dest="convert_strings", + action="store_true", + default=False, + help="convert java.lang.String to python str in return types. " + "consult the JPype documentation on the convertStrings flag for details", + ) + parser.add_argument( + "--no-javadoc", + dest="with_javadoc", + action="store_false", + default=True, + help="do not generate docstrings from JavaDoc where available", + ) + + rt_group = parser.add_mutually_exclusive_group() + rt_group.add_argument( + "--runtime-imports", + dest="runtime_imports", + action="store_true", + default=True, + help="Add runtime imports to the generated stubs. ", + ) + rt_group.add_argument( + "--no-runtime-imports", dest="runtime_imports", action="store_false" + ) + + parser.add_argument( + "--remove-namespace-only-stubs", + dest="remove_namespace_only_stubs", + action="store_true", + default=False, + help="Remove stubs that export no names beyond a single __module_protocol__. " + "This leaves some folders as PEP420 implicit namespace folders.", + ) + + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) + + args = parser.parse_args() + output_dir = _get_ouput_dir(args.output_dir, args.output_python_path) + if not output_dir.exists(): + output_dir.mkdir(parents=True, exist_ok=True) + + generate_stubs( + endpoints=args.endpoints, + prefixes=args.prefix, + output_dir=output_dir, + convert_strings=args.convert_strings, + include_javadoc=args.with_javadoc, + add_runtime_imports=args.runtime_imports, + remove_namespace_only_stubs=args.remove_namespace_only_stubs, + ) + + +def _get_ouput_dir(output_dir: str | None, python_path: str | None) -> Path: + if out_dir := output_dir: + return Path(out_dir) + if pp := python_path: + return _glean_path(pp) + try: + import scyjava + + return Path(scyjava.__file__).parent / "types" + except ImportError: + return Path("stubs") + + +def _glean_path(pp: str) -> Path: + try: + importlib.import_module(pp.split(".")[0]) + except ModuleNotFoundError: + # the top level module doesn't exist: + raise ValueError(f"Module {pp} does not exist. Cannot install stubs there.") + + try: + spec = importlib.util.find_spec(pp) + except ModuleNotFoundError as e: + # at least one of the middle levels doesn't exist: + raise NotImplementedError(f"Cannot install stubs to {pp}: {e}") + + new_ns = None + if not spec: + # if we get here, it means everything but the last level exists: + parent, new_ns = pp.rsplit(".", 1) + spec = importlib.util.find_spec(parent) + + if not spec: + # if we get here, it means the last level doesn't exist: + raise ValueError(f"Module {pp} does not exist. Cannot install stubs there.") + + search_locations = spec.submodule_search_locations + if not spec.loader and search_locations: + # namespace package with submodules + return Path(search_locations[0]) + if spec.origin: + return Path(spec.origin).parent + if new_ns and search_locations: + # namespace package with submodules + return Path(search_locations[0]) / new_ns + + raise ValueError(f"Error finding module {pp}. Cannot install stubs there.") diff --git a/src/scyjava/_stubs/_dynamic_import.py b/src/scyjava/_stubs/_dynamic_import.py new file mode 100644 index 0000000..cc62641 --- /dev/null +++ b/src/scyjava/_stubs/_dynamic_import.py @@ -0,0 +1,102 @@ +import ast +from logging import warning +from pathlib import Path +from typing import Any, Callable, Sequence + + +def setup_java_imports( + module_name: str, + module_file: str, + endpoints: Sequence[str] = (), + base_prefix: str = "", +) -> tuple[list[str], Callable[[str], Any]]: + """Setup a module to dynamically import Java class names. + + This function creates a `__getattr__` function that, when called, will dynamically + import the requested class from the Java namespace corresponding to the calling + module. + + :param module_name: The dotted name/identifier of the module that is calling this + function (usually `__name__` in the calling module). + :param module_file: The path to the module file (usually `__file__` in the calling + module). + :param endpoints: A list of Java endpoints to add to the scyjava configuration. + :param base_prefix: The base prefix for the Java package name. This is used when + determining the Java class path for the requested class. The java class path + will be truncated to only the part including the base_prefix and after. This + makes it possible to embed a module in a subpackage (like `scyjava.types`) and + still have the correct Java class path. + :return: A 2-tuple containing: + - A list of all classes in the module (as defined in the stub file), to be + assigned to `__all__`. + - A callable that takes a class name and returns a proxy for the Java class. + This callable should be assigned to `__getattr__` in the calling module. + The proxy object, when called, will start the JVM, import the Java class, + and return an instance of the class. The JVM will *only* be started when + the object is called. + + Example: + If the module calling this function is named `scyjava.types.org.scijava.parsington`, + then it should invoke this function as: + + .. code-block:: python + + from scyjava._stubs import setup_java_imports + + __all__, __getattr__ = setup_java_imports( + __name__, + __file__, + endpoints=["org.scijava:parsington:3.1.0"], + base_prefix="org" + ) + """ + import scyjava + import scyjava.config + + for ep in endpoints: + if ep not in scyjava.config.endpoints: + scyjava.config.endpoints.append(ep) + + module_all = [] + try: + my_stub = Path(module_file).with_suffix(".pyi") + stub_ast = ast.parse(my_stub.read_text()) + module_all = sorted( + { + node.name + for node in stub_ast.body + if isinstance(node, ast.ClassDef) and not node.name.startswith("__") + } + ) + except (OSError, SyntaxError): + warning( + f"Failed to read stub file {my_stub!r}. Falling back to empty __all__.", + stacklevel=3, + ) + + def module_getattr(name: str, mod_name: str = module_name) -> Any: + if module_all and name not in module_all: + raise AttributeError(f"module {module_name!r} has no attribute {name!r}") + + # cut the mod_name to only the part including the base_prefix and after + if base_prefix in mod_name: + mod_name = mod_name[mod_name.index(base_prefix) :] + + class_path = f"{mod_name}.{name}" + + class ProxyMeta(type): + def __repr__(self) -> str: + return f"" + + class Proxy(metaclass=ProxyMeta): + def __new__(_cls_, *args: Any, **kwargs: Any) -> Any: + cls = scyjava.jimport(class_path) + return cls(*args, **kwargs) + + Proxy.__name__ = name + Proxy.__qualname__ = name + Proxy.__module__ = module_name + Proxy.__doc__ = f"Proxy for {class_path}" + return Proxy + + return module_all, module_getattr diff --git a/src/scyjava/_stubs/_genstubs.py b/src/scyjava/_stubs/_genstubs.py new file mode 100644 index 0000000..cb6bef2 --- /dev/null +++ b/src/scyjava/_stubs/_genstubs.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +import ast +import logging +import os +import shutil +import subprocess +from importlib import import_module +from itertools import chain +from pathlib import Path, PurePath +import sys +from typing import TYPE_CHECKING, Any +from unittest.mock import patch +from zipfile import ZipFile + +import scyjava +import scyjava.config + +if TYPE_CHECKING: + from collections.abc import Sequence + +logger = logging.getLogger(__name__) + + +def generate_stubs( + endpoints: Sequence[str], + prefixes: Sequence[str] = (), + output_dir: str | Path = "stubs", + convert_strings: bool = True, + include_javadoc: bool = True, + add_runtime_imports: bool = True, + remove_namespace_only_stubs: bool = False, +) -> None: + """Generate stubs for the given maven endpoints. + + Parameters + ---------- + endpoints : Sequence[str] + The maven endpoints to generate stubs for. This should be a list of GAV + coordinates, e.g. ["org.apache.commons:commons-lang3:3.12.0"]. + prefixes : Sequence[str], optional + The prefixes to generate stubs for. This should be a list of Java class + prefixes that you expect to find in the endpoints. For example, + ["org.apache.commons"]. If not provided, the prefixes will be + automatically determined from the jar files provided by endpoints. + output_dir : str | Path, optional + The directory to write the generated stubs to. Defaults to "stubs". + convert_strings : bool, optional + Whether to cast Java strings to Python strings in the stubs. Defaults to True. + NOTE: This leads to type stubs that may not be strictly accurate at runtime. + The actual runtime type of strings is determined by whether jpype.startJVM is + called with the `convertStrings` argument set to True or False. By setting + this `convert_strings` argument to true, the type stubs will be generated as if + `convertStrings` is set to True: that is, all string types will be listed as + `str` rather than `java.lang.String | str`. This is a safer default (as `str`) + is a subtype of `java.lang.String`), but may lead to type errors in some cases. + include_javadoc : bool, optional + Whether to include Javadoc in the generated stubs. Defaults to True. + add_runtime_imports : bool, optional + Whether to add runtime imports to the generated stubs. Defaults to True. + This is useful if you want to use the stubs as a runtime package with type + safety. + remove_namespace_only_stubs : bool, optional + Whether to remove stubs that export no names beyond a single + `__module_protocol__`. This leaves some folders as PEP420 implicit namespace + folders. Defaults to False. Setting this to `True` is useful if you want to + merge the generated stubs with other stubs in the same namespace. Without this, + the `__init__.pyi` for any given module will be whatever whatever the *last* + stub generator wrote to it (and therefore inaccurate). + """ + try: + import stubgenj + except ImportError as e: + raise ImportError( + "stubgenj is not installed, but is required to generate java stubs. " + "Please install it with `pip/conda install stubgenj`." + ) from e + + import jpype + + startJVM = jpype.startJVM + + scyjava.config.endpoints.extend(endpoints) + + def _patched_start(*args: Any, **kwargs: Any) -> None: + kwargs.setdefault("convertStrings", convert_strings) + startJVM(*args, **kwargs) + + with patch.object(jpype, "startJVM", new=_patched_start): + scyjava.start_jvm() + + _prefixes = set(prefixes) + if not _prefixes: + cp = jpype.getClassPath(env=False) + ep_artifacts = tuple(ep.split(":")[1] for ep in endpoints) + for j in cp.split(os.pathsep): + if Path(j).name.startswith(ep_artifacts): + _prefixes.update(list_top_level_packages(j)) + + prefixes = sorted(_prefixes) + logger.info(f"Using endpoints: {scyjava.config.endpoints!r}") + logger.info(f"Generating stubs for: {prefixes}") + logger.info(f"Writing stubs to: {output_dir}") + + metapath = sys.meta_path + try: + import jpype.imports + + jmodules = [import_module(prefix) for prefix in prefixes] + finally: + # remove the jpype.imports magic from the import system + # if it wasn't there to begin with + sys.meta_path = metapath + + stubgenj.generateJavaStubs( + jmodules, + useStubsSuffix=False, + outputDir=str(output_dir), + jpypeJPackageStubs=False, + includeJavadoc=include_javadoc, + ) + + output_dir = Path(output_dir) + if add_runtime_imports: + logger.info("Adding runtime imports to generated stubs") + + for stub in output_dir.rglob("*.pyi"): + stub_ast = ast.parse(stub.read_text()) + members = {node.name for node in stub_ast.body if hasattr(node, "name")} + if members == {"__module_protocol__"}: + # this is simply a module stub... no exports + if remove_namespace_only_stubs: + logger.info("Removing namespace only stub %s", stub) + stub.unlink() + continue + if add_runtime_imports: + real_import = stub.with_suffix(".py") + base_prefix = stub.relative_to(output_dir).parts[0] + real_import.write_text( + INIT_TEMPLATE.format( + endpoints=repr(endpoints), + base_prefix=repr(base_prefix), + ) + ) + + ruff_check(output_dir.absolute()) + + +# the "real" init file that goes into the stub package +INIT_TEMPLATE = """\ +# this file was autogenerated by scyjava-stubgen +# it creates a __getattr__ function that will dynamically import +# the requested class from the Java namespace corresponding to this module. +# see scyjava._stubs for implementation details. +from scyjava._stubs import setup_java_imports + +__all__, __getattr__ = setup_java_imports( + __name__, + __file__, + endpoints={endpoints}, + base_prefix={base_prefix}, +) +""" + + +def ruff_check(output: Path, select: str = "E,W,F,I,UP,C4,B,RUF,TC,TID") -> None: + """Run ruff check and format on the generated stubs.""" + if not shutil.which("ruff"): + return + + py_files = [str(x) for x in chain(output.rglob("*.py"), output.rglob("*.pyi"))] + logger.info( + "Running ruff check on %d generated stubs in % s", + len(py_files), + str(output), + ) + subprocess.run( + [ + "ruff", + "check", + *py_files, + "--quiet", + "--fix-only", + "--unsafe-fixes", + f"--select={select}", + ] + ) + logger.info("Running ruff format") + subprocess.run(["ruff", "format", *py_files, "--quiet"]) + + +def list_top_level_packages(jar_path: str) -> set[str]: + """Inspect a JAR file and return the set of top-level Java package names.""" + packages: set[str] = set() + with ZipFile(jar_path, "r") as jar: + # find all classes + class_dirs = { + entry.parent + for x in jar.namelist() + if (entry := PurePath(x)).suffix == ".class" + } + + roots: set[PurePath] = set() + for p in sorted(class_dirs, key=lambda p: len(p.parts)): + # If none of the already accepted roots is a parent of p, keep p + if not any(root in p.parents for root in roots): + roots.add(p) + packages.update({str(p).replace(os.sep, ".") for p in roots}) + + return packages diff --git a/src/scyjava/_stubs/_hatchling_plugin.py b/src/scyjava/_stubs/_hatchling_plugin.py new file mode 100644 index 0000000..866f829 --- /dev/null +++ b/src/scyjava/_stubs/_hatchling_plugin.py @@ -0,0 +1,65 @@ +"""Hatchling build hook for generating Java stubs. + +To use this hook, add the following to your `pyproject.toml`: + +```toml +[build-system] +requires = ["hatchling", "scyjava"] +build-backend = "hatchling.build" + +[tool.hatch.build.hooks.scyjava] +maven_coordinates = ["org.scijava:parsington:3.1.0"] +prefixes = ["org.scijava"] # optional ... can be auto-determined from the jar files +``` + +This will generate stubs for the given maven coordinates and prefixes. The generated +stubs will be placed in `src/scyjava/types` and will be included in the wheel package. +This hook is only run when building a wheel package. +""" + +import logging +from pathlib import Path + +from hatchling.builders.hooks.plugin.interface import BuildHookInterface +from hatchling.plugin import hookimpl + + +from scyjava._stubs._genstubs import generate_stubs + +logger = logging.getLogger("scyjava") + + +class ScyjavaBuildHook(BuildHookInterface): + """Custom build hook for generating Java stubs.""" + + PLUGIN_NAME = "scyjava" + + def initialize(self, version: str, build_data: dict) -> None: + """Initialize the build hook with the version and build data.""" + if self.target_name != "wheel": + return + + endpoints = self.config.get("maven_coordinates", []) + if not endpoints: + logger.warning("No maven coordinates provided. Skipping stub generation.") + return + + prefixes = self.config.get("prefixes", []) + dest = Path(self.root, "src", "scyjava", "types") + (dest.parent / "py.typed").touch() + + # actually build the stubs + generate_stubs( + endpoints=endpoints, + prefixes=prefixes, + output_dir=dest, + remove_namespace_only_stubs=True, + ) + print(f"Generated stubs for {endpoints} in {dest}") + # add all new packages to the build config + build_data["force_include"].update({str(dest.parent): "scyjava"}) + + +@hookimpl +def hatch_register_build_hook(): + return ScyjavaBuildHook diff --git a/src/scyjava/types/.gitignore b/src/scyjava/types/.gitignore new file mode 100644 index 0000000..5e7d273 --- /dev/null +++ b/src/scyjava/types/.gitignore @@ -0,0 +1,4 @@ +# Ignore everything in this directory +* +# Except this file +!.gitignore diff --git a/tests/test_stubgen.py b/tests/test_stubgen.py new file mode 100644 index 0000000..ad4a74a --- /dev/null +++ b/tests/test_stubgen.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING +from unittest.mock import patch + +import jpype +import pytest + +import scyjava +from scyjava._stubs import _cli + +if TYPE_CHECKING: + from pathlib import Path + + +@pytest.mark.skipif( + scyjava.config.mode != scyjava.config.Mode.JPYPE, + reason="Stubgen not supported in JEP", +) +def test_stubgen(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + # run the stubgen command as if it was run from the command line + monkeypatch.setattr( + sys, + "argv", + [ + "scyjava-stubgen", + "org.scijava:parsington:3.1.0", + "--output-dir", + str(tmp_path), + ], + ) + _cli.main() + + # remove the `jpype.imports` magic from the import system if present + mp = [x for x in sys.meta_path if not isinstance(x, jpype.imports._JImportLoader)] + monkeypatch.setattr(sys, "meta_path", mp) + + # add tmp_path to the import path + monkeypatch.setattr(sys, "path", [str(tmp_path)]) + + # first cleanup to make sure we are not importing from the cache + sys.modules.pop("org", None) + sys.modules.pop("org.scijava", None) + sys.modules.pop("org.scijava.parsington", None) + # make sure the stubgen command works and that we can now impmort stuff + + with patch.object(scyjava._jvm, "start_jvm") as mock_start_jvm: + from org.scijava.parsington import Function + + assert Function is not None + # ensure that no calls to start_jvm were made + mock_start_jvm.assert_not_called() + + # only after instantiating the class should we have a call to start_jvm + func = Function(1) + mock_start_jvm.assert_called_once() + assert isinstance(func, jpype.JObject)