From daca094d843adb03b99c85d28415663dd7dcee7e Mon Sep 17 00:00:00 2001 From: Rauha Rahkola Date: Tue, 25 Apr 2023 14:13:29 -0700 Subject: [PATCH 1/2] for #30, adding type annotations --- .gitignore | 6 ++ adafruit_bno08x/__init__.py | 165 +++++++++++++++++++----------------- 2 files changed, 93 insertions(+), 78 deletions(-) diff --git a/.gitignore b/.gitignore index db3d538..87b9508 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,12 @@ __pycache__ # Sphinx build-specific files _build +# MyPy-specific type-checking files +.mypy_cache + +# pip install files +/build/ + # This file results from running `pip -e install .` in a local repository *.egg-info diff --git a/adafruit_bno08x/__init__.py b/adafruit_bno08x/__init__.py index f2d999e..64a5b6d 100644 --- a/adafruit_bno08x/__init__.py +++ b/adafruit_bno08x/__init__.py @@ -25,6 +25,8 @@ * `Adafruit's Bus Device library `_ """ +from __future__ import annotations + __version__ = "0.0.0+auto.0" __repo__ = "https:# github.com/adafruit/Adafruit_CircuitPython_BNO08x.git" @@ -36,6 +38,12 @@ # TODO: Remove on release from .debug import channels, reports +# For IDE type recognition +try: + from typing import Any, Dict, List, Optional, Tuple, Union +except ImportError: + pass + # TODO: shorten names # Channel 0: the SHTP command channel BNO_CHANNEL_SHTP_COMMAND = const(0) @@ -206,12 +214,12 @@ class PacketError(Exception): pass # pylint:disable=unnecessary-pass -def _elapsed(start_time): +def _elapsed(start_time: float) -> float: return time.monotonic() - start_time ############ PACKET PARSING ########################### -def _parse_sensor_report_data(report_bytes): +def _parse_sensor_report_data(report_bytes: bytearray) -> Tuple[Any, int]: """Parses reports with only 16-bit fields""" data_offset = 4 # this may not always be true report_id = report_bytes[0] @@ -235,11 +243,11 @@ def _parse_sensor_report_data(report_bytes): return (results_tuple, accuracy) -def _parse_step_couter_report(report_bytes): +def _parse_step_couter_report(report_bytes: bytearray) -> int: return unpack_from(" str: classification_bitfield = unpack_from(" Tuple[Any, ...]: return unpack_from(" Dict[str, int]: activities = [ "Unknown", "In-Vehicle", # look @@ -292,12 +300,12 @@ def _parse_activity_classifier_report(report_bytes): return classification -def _parse_shake_report(report_bytes): +def _parse_shake_report(report_bytes: bytearray) -> bool: shake_bitfield = unpack_from(" 0 -def parse_sensor_id(buffer): +def parse_sensor_id(buffer: bytearray) -> Tuple[int, ...]: """Parse the fields of a product id report""" if not buffer[0] == _SHTP_REPORT_PRODUCT_ID_RESPONSE: raise AttributeError("Wrong report id for sensor id: %s" % hex(buffer[0])) @@ -311,8 +319,7 @@ def parse_sensor_id(buffer): return (sw_part_number, sw_major, sw_minor, sw_patch, sw_build_number) -def _parse_command_response(report_bytes): - +def _parse_command_response(report_bytes: bytearray) -> Tuple[Any, Any]: # CMD response report: # 0 Report ID = 0xF1 # 1 Sequence number @@ -327,8 +334,11 @@ def _parse_command_response(report_bytes): def _insert_command_request_report( - command, buffer, next_sequence_number, command_params=None -): + command: int, + buffer: bytearray, + next_sequence_number: int, + command_params: Optional[List[int]] = None, +) -> None: if command_params and len(command_params) > 9: raise AttributeError( "Command request reports can only have up to 9 arguments but %d were given" @@ -346,14 +356,14 @@ def _insert_command_request_report( buffer[3 + idx] = param -def _report_length(report_id): +def _report_length(report_id: int) -> int: if report_id < 0xF0: # it's a sensor report return _AVAIL_SENSOR_REPORTS[report_id][2] return _REPORT_LENGTHS[report_id] -def _separate_batch(packet, report_slices): +def _separate_batch(packet: Packet, report_slices: List[Any]) -> None: # get first report id, loop up its report length # read that many bytes, parse them next_byte_index = 0 @@ -377,13 +387,12 @@ def _separate_batch(packet, report_slices): class Packet: """A class representing a Hillcrest LaboratorySensor Hub Transport packet""" - def __init__(self, packet_bytes): + def __init__(self, packet_bytes: bytearray) -> None: self.header = self.header_from_buffer(packet_bytes) data_end_index = self.header.data_length + _BNO_HEADER_LEN self.data = packet_bytes[_BNO_HEADER_LEN:data_end_index] - def __str__(self): - + def __str__(self) -> str: length = self.header.packet_byte_count outstr = "\n\t\t********** Packet *************\n" outstr += "DBG::\t\t HEADER:\n" @@ -441,17 +450,17 @@ def __str__(self): return outstr @property - def report_id(self): + def report_id(self) -> int: """The Packet's Report ID""" return self.data[0] @property - def channel_number(self): + def channel_number(self) -> int: """The packet channel""" return self.header.channel_number @classmethod - def header_from_buffer(cls, packet_bytes): + def header_from_buffer(cls, packet_bytes: bytearray) -> PacketHeader: """Creates a `PacketHeader` object from a given buffer""" packet_byte_count = unpack_from(" bool: """Returns True if the header is an error condition""" if header.channel_number > 5: @@ -482,32 +491,29 @@ class BNO08X: # pylint: disable=too-many-instance-attributes, too-many-public-m """ - def __init__(self, reset=None, debug=False): - self._debug = debug - self._reset = reset + def __init__(self, reset: Optional[dict] = None, debug: bool = False) -> None: + self._debug: bool = debug + self._reset: Optional[dict] = reset self._dbg("********** __init__ *************") - self._data_buffer = bytearray(DATA_BUFFER_SIZE) - self._command_buffer = bytearray(12) - self._packet_slices = [] + self._data_buffer: bytearray = bytearray(DATA_BUFFER_SIZE) + self._command_buffer: bytearray = bytearray(12) + self._packet_slices: List[Any] = [] # TODO: this is wrong there should be one per channel per direction - self._sequence_number = [0, 0, 0, 0, 0, 0] - self._two_ended_sequence_numbers = { - "send": {}, # holds the next seq number to send with the report id as a key - "receive": {}, - } - self._dcd_saved_at = -1 - self._me_calibration_started_at = -1 + self._sequence_number: List[int] = [0, 0, 0, 0, 0, 0] + self._two_ended_sequence_numbers: Dict[int, int] = {} + self._dcd_saved_at: float = -1 + self._me_calibration_started_at: float = -1 self._calibration_complete = False self._magnetometer_accuracy = 0 self._wait_for_initialize = True self._init_complete = False self._id_read = False # for saving the most recent reading when decoding several packets - self._readings = {} + self._readings: Dict[int, Any] = {} self.initialize() - def initialize(self): + def initialize(self) -> None: """Initialize the sensor""" for _ in range(3): self.hard_reset() @@ -521,7 +527,7 @@ def initialize(self): raise RuntimeError("Could not read ID") @property - def magnetic(self): + def magnetic(self) -> Optional[Tuple[float, float, float]]: """A tuple of the current magnetic field measurements on the X, Y, and Z axes""" self._process_available_packets() # decorator? try: @@ -530,7 +536,7 @@ def magnetic(self): raise RuntimeError("No magfield report found, is it enabled?") from None @property - def quaternion(self): + def quaternion(self) -> Optional[Tuple[float, float, float, float]]: """A quaternion representing the current rotation vector""" self._process_available_packets() try: @@ -539,7 +545,7 @@ def quaternion(self): raise RuntimeError("No quaternion report found, is it enabled?") from None @property - def geomagnetic_quaternion(self): + def geomagnetic_quaternion(self) -> Optional[Tuple[float, float, float, float]]: """A quaternion representing the current geomagnetic rotation vector""" self._process_available_packets() try: @@ -550,7 +556,7 @@ def geomagnetic_quaternion(self): ) from None @property - def game_quaternion(self): + def game_quaternion(self) -> Optional[Tuple[float, float, float, float]]: """A quaternion representing the current rotation vector expressed as a quaternion with no specific reference for heading, while roll and pitch are referenced against gravity. To prevent sudden jumps in heading due to corrections, the `game_quaternion` property is not @@ -564,7 +570,7 @@ def game_quaternion(self): ) from None @property - def steps(self): + def steps(self) -> Optional[int]: """The number of steps detected since the sensor was initialized""" self._process_available_packets() try: @@ -573,7 +579,7 @@ def steps(self): raise RuntimeError("No steps report found, is it enabled?") from None @property - def linear_acceleration(self): + def linear_acceleration(self) -> Optional[Tuple[float, float, float]]: """A tuple representing the current linear acceleration values on the X, Y, and Z axes in meters per second squared""" self._process_available_packets() @@ -583,7 +589,7 @@ def linear_acceleration(self): raise RuntimeError("No lin. accel report found, is it enabled?") from None @property - def acceleration(self): + def acceleration(self) -> Optional[Tuple[float, float, float]]: """A tuple representing the acceleration measurements on the X, Y, and Z axes in meters per second squared""" self._process_available_packets() @@ -593,7 +599,7 @@ def acceleration(self): raise RuntimeError("No accel report found, is it enabled?") from None @property - def gravity(self): + def gravity(self) -> Optional[Tuple[float, float, float]]: """A tuple representing the gravity vector in the X, Y, and Z components axes in meters per second squared""" self._process_available_packets() @@ -603,7 +609,7 @@ def gravity(self): raise RuntimeError("No gravity report found, is it enabled?") from None @property - def gyro(self): + def gyro(self) -> Optional[Tuple[float, float, float]]: """A tuple representing Gyro's rotation measurements on the X, Y, and Z axes in radians per second""" self._process_available_packets() @@ -613,7 +619,7 @@ def gyro(self): raise RuntimeError("No gyro report found, is it enabled?") from None @property - def shake(self): + def shake(self) -> Optional[bool]: """True if a shake was detected on any axis since the last time it was checked This property has a "latching" behavior where once a shake is detected, it will stay in a @@ -631,7 +637,7 @@ def shake(self): raise RuntimeError("No shake report found, is it enabled?") from None @property - def stability_classification(self): + def stability_classification(self) -> Optional[str]: """Returns the sensor's assessment of it's current stability, one of: * "Unknown" - The sensor is unable to classify the current stability @@ -653,7 +659,7 @@ def stability_classification(self): ) from None @property - def activity_classification(self): + def activity_classification(self) -> Optional[dict]: """Returns the sensor's assessment of the activity that is creating the motions\ that it is sensing, one of: @@ -678,7 +684,7 @@ def activity_classification(self): ) from None @property - def raw_acceleration(self): + def raw_acceleration(self) -> Optional[Tuple[int, int, int]]: """Returns the sensor's raw, unscaled value from the accelerometer registers""" self._process_available_packets() try: @@ -690,7 +696,7 @@ def raw_acceleration(self): ) from None @property - def raw_gyro(self): + def raw_gyro(self) -> Optional[Tuple[int, int, int]]: """Returns the sensor's raw, unscaled value from the gyro registers""" self._process_available_packets() try: @@ -700,7 +706,7 @@ def raw_gyro(self): raise RuntimeError("No raw gyro report found, is it enabled?") from None @property - def raw_magnetic(self): + def raw_magnetic(self) -> Optional[Tuple[int, int, int]]: """Returns the sensor's raw, unscaled value from the magnetometer registers""" self._process_available_packets() try: @@ -709,7 +715,7 @@ def raw_magnetic(self): except KeyError: raise RuntimeError("No raw magnetic report found, is it enabled?") from None - def begin_calibration(self): + def begin_calibration(self) -> None: """Begin the sensor's self-calibration routine""" # start calibration for accel, gyro, and mag self._send_me_command( @@ -728,7 +734,7 @@ def begin_calibration(self): self._calibration_complete = False @property - def calibration_status(self): + def calibration_status(self) -> int: """Get the status of the self-calibration""" self._send_me_command( [ @@ -745,8 +751,7 @@ def calibration_status(self): ) return self._magnetometer_accuracy - def _send_me_command(self, subcommand_params): - + def _send_me_command(self, subcommand_params: Optional[List[int]]) -> None: start_time = time.monotonic() local_buffer = self._command_buffer _insert_command_request_report( @@ -762,7 +767,7 @@ def _send_me_command(self, subcommand_params): if self._me_calibration_started_at > start_time: break - def save_calibration_data(self): + def save_calibration_data(self) -> None: """Save the self-calibration data""" # send a DCD save command start_time = time.monotonic() @@ -782,7 +787,7 @@ def save_calibration_data(self): ############### private/helper methods ############### # # decorator? - def _process_available_packets(self, max_packets=None): + def _process_available_packets(self, max_packets: Optional[int] = None) -> None: processed_count = 0 while self._data_ready: if max_packets and processed_count > max_packets: @@ -800,7 +805,9 @@ def _process_available_packets(self, max_packets=None): self._dbg("") self._dbg(" ** DONE! **") - def _wait_for_packet_type(self, channel_number, report_id=None, timeout=5.0): + def _wait_for_packet_type( + self, channel_number: int, report_id: Optional[int] = None, timeout: float = 5.0 + ) -> Packet: if report_id: report_id_str = " with report id %s" % hex(report_id) else: @@ -825,7 +832,7 @@ def _wait_for_packet_type(self, channel_number, report_id=None, timeout=5.0): raise RuntimeError("Timed out waiting for a packet on channel", channel_number) - def _wait_for_packet(self, timeout=_PACKET_READ_TIMEOUT): + def _wait_for_packet(self, timeout: float = _PACKET_READ_TIMEOUT) -> Packet: start_time = time.monotonic() while _elapsed(start_time) < timeout: if not self._data_ready: @@ -837,12 +844,12 @@ def _wait_for_packet(self, timeout=_PACKET_READ_TIMEOUT): # update the cached sequence number so we know what to increment from # TODO: this is wrong there should be one per channel per direction # and apparently per report as well - def _update_sequence_number(self, new_packet): + def _update_sequence_number(self, new_packet: Packet) -> None: channel = new_packet.channel_number seq = new_packet.header.sequence_number self._sequence_number[channel] = seq - def _handle_packet(self, packet): + def _handle_packet(self, packet: Packet) -> None: # split out reports first try: _separate_batch(packet, self._packet_slices) @@ -852,7 +859,7 @@ def _handle_packet(self, packet): print(packet) raise error - def _handle_control_report(self, report_id, report_bytes): + def _handle_control_report(self, report_id: int, report_bytes: bytearray) -> None: if report_id == _SHTP_REPORT_PRODUCT_ID_RESPONSE: ( sw_part_number, @@ -876,7 +883,7 @@ def _handle_control_report(self, report_id, report_bytes): if report_id == _COMMAND_RESPONSE: self._handle_command_response(report_bytes) - def _handle_command_response(self, report_bytes): + def _handle_command_response(self, report_bytes: bytearray) -> None: (report_body, response_values) = _parse_command_response(report_bytes) ( @@ -899,7 +906,7 @@ def _handle_command_response(self, report_bytes): else: raise RuntimeError("Unable to save calibration data") - def _process_report(self, report_id, report_bytes): + def _process_report(self, report_id: int, report_bytes: bytearray) -> None: if report_id >= 0xF0: self._handle_control_report(report_id, report_bytes) return @@ -948,8 +955,10 @@ def _process_report(self, report_id, report_bytes): # TODO: Make this a Packet creation @staticmethod def _get_feature_enable_report( - feature_id, report_interval=_DEFAULT_REPORT_INTERVAL, sensor_specific_config=0 - ): + feature_id: int, + report_interval: Any = _DEFAULT_REPORT_INTERVAL, + sensor_specific_config: int = 0, + ) -> bytearray: set_feature_report = bytearray(17) set_feature_report[0] = _SET_FEATURE_COMMAND set_feature_report[1] = feature_id @@ -961,7 +970,7 @@ def _get_feature_enable_report( # TODO: add docs for available features # TODO2: I think this should call an fn that imports all the bits for the given feature # so we're not carrying around stuff for extra features - def enable_feature(self, feature_id): + def enable_feature(self, feature_id: int) -> None: """Used to enable a given feature of the BNO08x""" self._dbg("\n********** Enabling feature id:", feature_id, "**********") @@ -989,7 +998,7 @@ def enable_feature(self, feature_id): return raise RuntimeError("Was not able to enable feature", feature_id) - def _check_id(self): + def _check_id(self) -> bool: self._dbg("\n********** READ ID **********") if self._id_read: return True @@ -1012,7 +1021,7 @@ def _check_id(self): return False - def _parse_sensor_id(self): + def _parse_sensor_id(self) -> Optional[int]: if not self._data_buffer[4] == _SHTP_REPORT_PRODUCT_ID_RESPONSE: return None @@ -1030,21 +1039,21 @@ def _parse_sensor_id(self): # TODO: this is only one of the numbers! return sw_part_number - def _dbg(self, *args, **kwargs): + def _dbg(self, *args: Any, **kwargs: Any) -> None: if self._debug: print("DBG::\t\t", *args, **kwargs) - def _get_data(self, index, fmt_string): + def _get_data(self, index: int, fmt_string: str) -> Any: # index arg is not including header, so add 4 into data buffer data_index = index + 4 return unpack_from(fmt_string, self._data_buffer, offset=data_index)[0] # pylint:disable=no-self-use @property - def _data_ready(self): + def _data_ready(self) -> None: raise RuntimeError("Not implemented") - def hard_reset(self): + def hard_reset(self) -> None: """Hardware reset the sensor to an initial unconfigured state""" if not self._reset: return @@ -1058,7 +1067,7 @@ def hard_reset(self): self._reset.value = True time.sleep(0.01) - def soft_reset(self): + def soft_reset(self) -> None: """Reset the sensor to an initial unconfigured state""" self._dbg("Soft resetting...", end="") data = bytearray(1) @@ -1077,15 +1086,15 @@ def soft_reset(self): self._dbg("OK!") # all is good! - def _send_packet(self, channel, data): + def _send_packet(self, channel: int, data: bytearray) -> Optional[int]: raise RuntimeError("Not implemented") - def _read_packet(self): + def _read_packet(self) -> Optional[Packet]: raise RuntimeError("Not implemented") - def _increment_report_seq(self, report_id): + def _increment_report_seq(self, report_id: int) -> None: current = self._two_ended_sequence_numbers.get(report_id, 0) self._two_ended_sequence_numbers[report_id] = (current + 1) % 256 - def _get_report_seq_id(self, report_id): + def _get_report_seq_id(self, report_id: int) -> int: return self._two_ended_sequence_numbers.get(report_id, 0) From 58bf8499d4493b657e633d10c985b816a565da75 Mon Sep 17 00:00:00 2001 From: foamyguy Date: Tue, 30 May 2023 09:36:05 -0500 Subject: [PATCH 2/2] fix a few types --- .gitignore | 6 ------ adafruit_bno08x/__init__.py | 15 +++++++++------ 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 87b9508..db3d538 100644 --- a/.gitignore +++ b/.gitignore @@ -32,12 +32,6 @@ __pycache__ # Sphinx build-specific files _build -# MyPy-specific type-checking files -.mypy_cache - -# pip install files -/build/ - # This file results from running `pip -e install .` in a local repository *.egg-info diff --git a/adafruit_bno08x/__init__.py b/adafruit_bno08x/__init__.py index 64a5b6d..bf6aa50 100644 --- a/adafruit_bno08x/__init__.py +++ b/adafruit_bno08x/__init__.py @@ -41,6 +41,7 @@ # For IDE type recognition try: from typing import Any, Dict, List, Optional, Tuple, Union + from digitalio import DigitalInOut except ImportError: pass @@ -219,7 +220,7 @@ def _elapsed(start_time: float) -> float: ############ PACKET PARSING ########################### -def _parse_sensor_report_data(report_bytes: bytearray) -> Tuple[Any, int]: +def _parse_sensor_report_data(report_bytes: bytearray) -> Tuple[Tuple, int]: """Parses reports with only 16-bit fields""" data_offset = 4 # this may not always be true report_id = report_bytes[0] @@ -272,7 +273,7 @@ def _parse_get_feature_response_report(report_bytes: bytearray) -> Tuple[Any, .. # 4 Page Number + EOS # 5 Most likely state # 6-15 Classification (10 x Page Number) + confidence -def _parse_activity_classifier_report(report_bytes: bytearray) -> Dict[str, int]: +def _parse_activity_classifier_report(report_bytes: bytearray) -> Dict[str, str]: activities = [ "Unknown", "In-Vehicle", # look @@ -491,9 +492,11 @@ class BNO08X: # pylint: disable=too-many-instance-attributes, too-many-public-m """ - def __init__(self, reset: Optional[dict] = None, debug: bool = False) -> None: + def __init__( + self, reset: Optional[DigitalInOut] = None, debug: bool = False + ) -> None: self._debug: bool = debug - self._reset: Optional[dict] = reset + self._reset: Optional[DigitalInOut] = reset self._dbg("********** __init__ *************") self._data_buffer: bytearray = bytearray(DATA_BUFFER_SIZE) self._command_buffer: bytearray = bytearray(12) @@ -503,7 +506,7 @@ def __init__(self, reset: Optional[dict] = None, debug: bool = False) -> None: self._sequence_number: List[int] = [0, 0, 0, 0, 0, 0] self._two_ended_sequence_numbers: Dict[int, int] = {} self._dcd_saved_at: float = -1 - self._me_calibration_started_at: float = -1 + self._me_calibration_started_at: float = -1.0 self._calibration_complete = False self._magnetometer_accuracy = 0 self._wait_for_initialize = True @@ -956,7 +959,7 @@ def _process_report(self, report_id: int, report_bytes: bytearray) -> None: @staticmethod def _get_feature_enable_report( feature_id: int, - report_interval: Any = _DEFAULT_REPORT_INTERVAL, + report_interval: int = _DEFAULT_REPORT_INTERVAL, sensor_specific_config: int = 0, ) -> bytearray: set_feature_report = bytearray(17)