diff --git a/commands2/__init__.py b/commands2/__init__.py index 1c243a7..2b64950 100644 --- a/commands2/__init__.py +++ b/commands2/__init__.py @@ -2,6 +2,7 @@ from . import button from . import cmd +from . import typing from .commandscheduler import CommandScheduler from .conditionalcommand import ConditionalCommand diff --git a/commands2/profiledpidcommand.py b/commands2/profiledpidcommand.py index c2ece07..d7ce298 100644 --- a/commands2/profiledpidcommand.py +++ b/commands2/profiledpidcommand.py @@ -5,16 +5,22 @@ # the WPILib BSD license file in the root directory of this project. # -from typing import Any, Callable, Union - -from .command import Command -from .subsystem import Subsystem +from typing import Any, Generic from wpimath.controller import ProfiledPIDController, ProfiledPIDControllerRadians from wpimath.trajectory import TrapezoidProfile, TrapezoidProfileRadians +from .command import Command +from .subsystem import Subsystem +from .typing import ( + FloatOrFloatSupplier, + FloatSupplier, + TProfiledPIDController, + UseOutputFunction, +) + -class ProfiledPIDCommand(Command): +class ProfiledPIDCommand(Command, Generic[TProfiledPIDController]): """A command that controls an output with a :class:`.ProfiledPIDController`. Runs forever by default - to add exit conditions and/or other behavior, subclass this class. The controller calculation and output are performed synchronously in the command's execute() method. @@ -24,10 +30,10 @@ class ProfiledPIDCommand(Command): def __init__( self, - controller, - measurementSource: Callable[[], float], - goalSource: Union[float, Callable[[], float]], - useOutput: Callable[[float, Any], Any], + controller: TProfiledPIDController, + measurementSource: FloatSupplier, + goalSource: FloatOrFloatSupplier, + useOutput: UseOutputFunction, *requirements: Subsystem, ): """Creates a new ProfiledPIDCommand, which controls the given output with a ProfiledPIDController. Goal @@ -40,6 +46,7 @@ def __init__( :param requirements: the subsystems required by this command """ + super().__init__() if isinstance(controller, ProfiledPIDController): self._stateCls = TrapezoidProfile.State elif isinstance(controller, ProfiledPIDControllerRadians): @@ -47,7 +54,7 @@ def __init__( else: raise ValueError(f"unknown controller type {controller!r}") - self._controller = controller + self._controller: TProfiledPIDController = controller self._useOutput = useOutput self._measurement = measurementSource if isinstance(goalSource, (float, int)): diff --git a/commands2/profiledpidsubsystem.py b/commands2/profiledpidsubsystem.py index f2c7069..4a05baa 100644 --- a/commands2/profiledpidsubsystem.py +++ b/commands2/profiledpidsubsystem.py @@ -2,14 +2,17 @@ # Open Source Software; you can modify and/or share it under the terms of # the WPILib BSD license file in the root directory of this project. -from typing import Union, cast +from typing import Generic from wpimath.trajectory import TrapezoidProfile from .subsystem import Subsystem +from .typing import TProfiledPIDController, TTrapezoidProfileState -class ProfiledPIDSubsystem(Subsystem): +class ProfiledPIDSubsystem( + Subsystem, Generic[TProfiledPIDController, TTrapezoidProfileState] +): """ A subsystem that uses a :class:`wpimath.controller.ProfiledPIDController` or :class:`wpimath.controller.ProfiledPIDControllerRadians` to @@ -19,12 +22,18 @@ class ProfiledPIDSubsystem(Subsystem): def __init__( self, - controller, + controller: TProfiledPIDController, initial_position: float = 0, ): - """Creates a new PIDSubsystem.""" + """ + Creates a new Profiled PID Subsystem using the provided PID Controller + + :param controller: the controller that controls the output + :param initial_position: the initial value of the process variable + + """ super().__init__() - self._controller = controller + self._controller: TProfiledPIDController = controller self._enabled = False self.setGoal(initial_position) @@ -38,20 +47,16 @@ def periodic(self): def getController( self, - ): + ) -> TProfiledPIDController: """Returns the controller""" return self._controller def setGoal(self, goal): - """ - Sets the goal state for the subsystem. - """ + """Sets the goal state for the subsystem.""" self._controller.setGoal(goal) - def useOutput(self, output: float, setpoint: TrapezoidProfile.State): - """ - Uses the output from the controller object. - """ + def useOutput(self, output: float, setpoint: TTrapezoidProfileState): + """Uses the output from the controller object.""" raise NotImplementedError(f"{self.__class__} must implement useOutput") def getMeasurement(self) -> float: @@ -72,7 +77,5 @@ def disable(self): self.useOutput(0, TrapezoidProfile.State()) def isEnabled(self) -> bool: - """ - Returns whether the controller is enabled. - """ + """Returns whether the controller is enabled.""" return self._enabled diff --git a/commands2/typing.py b/commands2/typing.py new file mode 100644 index 0000000..6168330 --- /dev/null +++ b/commands2/typing.py @@ -0,0 +1,30 @@ +from typing import Callable, Protocol, TypeVar, Union + +from typing_extensions import TypeAlias +from wpimath.controller import ProfiledPIDController, ProfiledPIDControllerRadians +from wpimath.trajectory import TrapezoidProfile, TrapezoidProfileRadians + +# Generic Types +TProfiledPIDController = TypeVar( + "TProfiledPIDController", ProfiledPIDControllerRadians, ProfiledPIDController +) +TTrapezoidProfileState = TypeVar( + "TTrapezoidProfileState", + TrapezoidProfileRadians.State, + TrapezoidProfile.State, +) + + +# Protocols - Structural Typing +class UseOutputFunction(Protocol): + + def __init__(self, *args, **kwargs) -> None: ... + + def __call__(self, t: float, u: TTrapezoidProfileState) -> None: ... + + def accept(self, t: float, u: TTrapezoidProfileState) -> None: ... + + +# Type Aliases +FloatSupplier: TypeAlias = Callable[[], float] +FloatOrFloatSupplier: TypeAlias = Union[float, Callable[[], float]]