Skip to content

Commit 9ec9148

Browse files
authored
Merge pull request #56 from cwstryker/cwstryker/generic_types
One option for adding generic types
2 parents 710f4b7 + 604af69 commit 9ec9148

File tree

4 files changed

+67
-26
lines changed

4 files changed

+67
-26
lines changed

commands2/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from . import button
44
from . import cmd
5+
from . import typing
56

67
from .commandscheduler import CommandScheduler
78
from .conditionalcommand import ConditionalCommand

commands2/profiledpidcommand.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,22 @@
55
# the WPILib BSD license file in the root directory of this project.
66
#
77

8-
from typing import Any, Callable, Union
9-
10-
from .command import Command
11-
from .subsystem import Subsystem
8+
from typing import Any, Generic
129

1310
from wpimath.controller import ProfiledPIDController, ProfiledPIDControllerRadians
1411
from wpimath.trajectory import TrapezoidProfile, TrapezoidProfileRadians
1512

13+
from .command import Command
14+
from .subsystem import Subsystem
15+
from .typing import (
16+
FloatOrFloatSupplier,
17+
FloatSupplier,
18+
TProfiledPIDController,
19+
UseOutputFunction,
20+
)
21+
1622

17-
class ProfiledPIDCommand(Command):
23+
class ProfiledPIDCommand(Command, Generic[TProfiledPIDController]):
1824
"""A command that controls an output with a :class:`.ProfiledPIDController`. Runs forever by default -
1925
to add exit conditions and/or other behavior, subclass this class. The controller calculation and
2026
output are performed synchronously in the command's execute() method.
@@ -24,10 +30,10 @@ class ProfiledPIDCommand(Command):
2430

2531
def __init__(
2632
self,
27-
controller,
28-
measurementSource: Callable[[], float],
29-
goalSource: Union[float, Callable[[], float]],
30-
useOutput: Callable[[float, Any], Any],
33+
controller: TProfiledPIDController,
34+
measurementSource: FloatSupplier,
35+
goalSource: FloatOrFloatSupplier,
36+
useOutput: UseOutputFunction,
3137
*requirements: Subsystem,
3238
):
3339
"""Creates a new ProfiledPIDCommand, which controls the given output with a ProfiledPIDController. Goal
@@ -40,14 +46,15 @@ def __init__(
4046
:param requirements: the subsystems required by this command
4147
"""
4248

49+
super().__init__()
4350
if isinstance(controller, ProfiledPIDController):
4451
self._stateCls = TrapezoidProfile.State
4552
elif isinstance(controller, ProfiledPIDControllerRadians):
4653
self._stateCls = TrapezoidProfileRadians.State
4754
else:
4855
raise ValueError(f"unknown controller type {controller!r}")
4956

50-
self._controller = controller
57+
self._controller: TProfiledPIDController = controller
5158
self._useOutput = useOutput
5259
self._measurement = measurementSource
5360
if isinstance(goalSource, (float, int)):

commands2/profiledpidsubsystem.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22
# Open Source Software; you can modify and/or share it under the terms of
33
# the WPILib BSD license file in the root directory of this project.
44

5-
from typing import Union, cast
5+
from typing import Generic
66

77
from wpimath.trajectory import TrapezoidProfile
88

99
from .subsystem import Subsystem
10+
from .typing import TProfiledPIDController, TTrapezoidProfileState
1011

1112

12-
class ProfiledPIDSubsystem(Subsystem):
13+
class ProfiledPIDSubsystem(
14+
Subsystem, Generic[TProfiledPIDController, TTrapezoidProfileState]
15+
):
1316
"""
1417
A subsystem that uses a :class:`wpimath.controller.ProfiledPIDController`
1518
or :class:`wpimath.controller.ProfiledPIDControllerRadians` to
@@ -19,12 +22,18 @@ class ProfiledPIDSubsystem(Subsystem):
1922

2023
def __init__(
2124
self,
22-
controller,
25+
controller: TProfiledPIDController,
2326
initial_position: float = 0,
2427
):
25-
"""Creates a new PIDSubsystem."""
28+
"""
29+
Creates a new Profiled PID Subsystem using the provided PID Controller
30+
31+
:param controller: the controller that controls the output
32+
:param initial_position: the initial value of the process variable
33+
34+
"""
2635
super().__init__()
27-
self._controller = controller
36+
self._controller: TProfiledPIDController = controller
2837
self._enabled = False
2938
self.setGoal(initial_position)
3039

@@ -38,20 +47,16 @@ def periodic(self):
3847

3948
def getController(
4049
self,
41-
):
50+
) -> TProfiledPIDController:
4251
"""Returns the controller"""
4352
return self._controller
4453

4554
def setGoal(self, goal):
46-
"""
47-
Sets the goal state for the subsystem.
48-
"""
55+
"""Sets the goal state for the subsystem."""
4956
self._controller.setGoal(goal)
5057

51-
def useOutput(self, output: float, setpoint: TrapezoidProfile.State):
52-
"""
53-
Uses the output from the controller object.
54-
"""
58+
def useOutput(self, output: float, setpoint: TTrapezoidProfileState):
59+
"""Uses the output from the controller object."""
5560
raise NotImplementedError(f"{self.__class__} must implement useOutput")
5661

5762
def getMeasurement(self) -> float:
@@ -72,7 +77,5 @@ def disable(self):
7277
self.useOutput(0, TrapezoidProfile.State())
7378

7479
def isEnabled(self) -> bool:
75-
"""
76-
Returns whether the controller is enabled.
77-
"""
80+
"""Returns whether the controller is enabled."""
7881
return self._enabled

commands2/typing.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from typing import Callable, Protocol, TypeVar, Union
2+
3+
from typing_extensions import TypeAlias
4+
from wpimath.controller import ProfiledPIDController, ProfiledPIDControllerRadians
5+
from wpimath.trajectory import TrapezoidProfile, TrapezoidProfileRadians
6+
7+
# Generic Types
8+
TProfiledPIDController = TypeVar(
9+
"TProfiledPIDController", ProfiledPIDControllerRadians, ProfiledPIDController
10+
)
11+
TTrapezoidProfileState = TypeVar(
12+
"TTrapezoidProfileState",
13+
TrapezoidProfileRadians.State,
14+
TrapezoidProfile.State,
15+
)
16+
17+
18+
# Protocols - Structural Typing
19+
class UseOutputFunction(Protocol):
20+
21+
def __init__(self, *args, **kwargs) -> None: ...
22+
23+
def __call__(self, t: float, u: TTrapezoidProfileState) -> None: ...
24+
25+
def accept(self, t: float, u: TTrapezoidProfileState) -> None: ...
26+
27+
28+
# Type Aliases
29+
FloatSupplier: TypeAlias = Callable[[], float]
30+
FloatOrFloatSupplier: TypeAlias = Union[float, Callable[[], float]]

0 commit comments

Comments
 (0)