diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..c7e10e3 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,21 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +indent_style = space +indent_size = 4 +trim_trailing_whitespace = true +insert_final_newline = true + +[*.py] +max_line_length = 120 + +[*.{yml,yaml,toml,json}] +indent_size = 2 + +[*.md] +trim_trailing_whitespace = false + +[Makefile] +indent_style = tab diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..38ce88e --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,72 @@ +name: Tests & Lint + +on: + push: + branches: ["**"] + tags-ignore: ["**"] + pull_request: + +permissions: + contents: read + +jobs: + # --------------------------------------------------------------------------- + # Lint and format check (fast feedback) + # --------------------------------------------------------------------------- + lint: + name: Ruff lint & format + runs-on: ubuntu-latest + permissions: + contents: read + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install ruff + run: pip install ruff + + - name: Ruff lint + run: ruff check src/ + + - name: Ruff format check + run: ruff format src/ --check + + # --------------------------------------------------------------------------- + # Unit tests + # --------------------------------------------------------------------------- + test: + name: pytest / Python ${{ matrix.python-version }} / ${{ matrix.os }} + runs-on: ${{ matrix.os }} + needs: lint + permissions: + contents: read + + strategy: + fail-fast: false + matrix: + python-version: ["3.12"] + os: [ubuntu-latest, windows-latest, macos-latest] + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install package with test dependencies + run: pip install -e ".[test]" + + - name: Run tests with coverage + run: pytest unit-test/ -v --cov=src --cov-report=term-missing --cov-report=xml + + - name: Upload coverage report + uses: actions/upload-artifact@v4 + if: matrix.os == 'ubuntu-latest' + with: + name: coverage-report + path: coverage.xml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..dcd7685 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,32 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.2 + hooks: + # Run ruff linter + - id: ruff + args: [--fix] + # Run ruff formatter + - id: ruff-format + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-toml + - id: check-merge-conflict + - id: check-added-large-files + args: [--maxkb=5120] + - id: debug-statements + - id: no-commit-to-branch + args: [--branch, main] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.15.0 + hooks: + - id: mypy + additional_dependencies: [] + args: [--ignore-missing-imports, --non-interactive] + pass_filenames: false + args: [src/openlifu_sdk, --ignore-missing-imports] diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..9bad61b --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,59 @@ +# Changelog + +All notable changes to `openlifu-sdk` are documented here. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +--- + +## [Unreleased] + +### Added +- `CONTRIBUTING.md` with development setup, coding conventions, and release process. +- `CHANGELOG.md` (this file). +- `.editorconfig` for consistent editor settings across contributors. +- `.pre-commit-config.yaml` with ruff (lint + format) and general file checks. +- `ruff`, `mypy`, and `pytest` configuration in `pyproject.toml`. +- CI workflow (`.github/workflows/test.yml`) that runs lint and tests on every push + and pull request across Ubuntu, Windows, and macOS. + +### Fixed +- **`LIFUUart`** — removed library-level `log.propagate = False`, hard-coded + `log.setLevel(logging.ERROR)`, and manual `StreamHandler` setup. Libraries must not + configure the root logging hierarchy. Logger name changed from `"UART"` to + `__name__`. +- **`LIFUTXDevice`** — same logging configuration fix; logger name changed from + `"TXDevice"` to `__name__`. +- **`LIFUInterface`** — `signal_connect`, `signal_disconnect`, `signal_data_received`, + `hvcontroller`, and `txdevice` were declared as class-level attributes, causing all + instances to share the same signal objects. These are now proper instance attributes + created in `__init__`. Instance attribute types are also correctly annotated with + `Optional[...]`. +- **`pyproject.toml`** — the `[test]` and `[dev]` optional-dependency groups were + identical. `[dev]` now adds `pre-commit`, `ruff`, and `mypy` on top of the test + dependencies. +- **All modules** — replaced f-string logging calls (`logger.error(f"msg {x}")`) with + lazy `%`-style formatting (`logger.error("msg %s", x)`) throughout + `LIFUTXDevice.py`, `LIFUHVController.py`, `LIFUUart.py`, `LIFUInterface.py`, and + `LIFUUserConfig.py`. +- **`LIFUTXDevice.py`** — fixed a bare `logging.warn` statement (useless expression / + root-logger call) in `calc_pulse_pattern`; replaced with `logger.warning(...)`. +- **`LIFUTXDevice.py`** — moved the mid-file `LIFUConfig` import block to the top of + the file (alongside all other imports) to comply with PEP 8 and ruff E402. +- **`LIFUTXDevice.write_config_json`** — `raise ValueError(...)` inside an `except` + block now chains the original exception (`raise ... from e`) per PEP 3134 / ruff B904. + +--- + +## [1.0.1] — 2025-XX-XX + +_(Previous release — see git history for details.)_ + +## [1.0.0] — 2025-XX-XX + +_(Initial public release.)_ + +[Unreleased]: https://github.com/OpenwaterHealth/openlifu-sdk/compare/1.0.1...HEAD +[1.0.1]: https://github.com/OpenwaterHealth/openlifu-sdk/compare/1.0.0...1.0.1 +[1.0.0]: https://github.com/OpenwaterHealth/openlifu-sdk/releases/tag/1.0.0 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..885ae22 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,167 @@ +# Contributing to openlifu-sdk + +Thank you for your interest in contributing! This guide covers environment setup, coding conventions, testing, and the pull request process. + +## Table of Contents + +- [Development setup](#development-setup) +- [Project structure](#project-structure) +- [Coding conventions](#coding-conventions) +- [Running tests](#running-tests) +- [Linting and formatting](#linting-and-formatting) +- [Pre-commit hooks](#pre-commit-hooks) +- [Pull request process](#pull-request-process) +- [Release process](#release-process) + +--- + +## Development setup + +1. **Fork and clone** the repository: + + ```bash + git clone https://github.com//openlifu-sdk.git + cd openlifu-sdk + ``` + +2. **Create a virtual environment** (Python 3.12+): + + ```bash + python -m venv .venv + source .venv/bin/activate # Windows: .venv\Scripts\activate + ``` + +3. **Install in editable mode** with dev dependencies: + + ```bash + pip install -e ".[dev]" + ``` + +4. **Install pre-commit hooks** (optional but recommended): + + ```bash + pre-commit install + ``` + +--- + +## Project structure + +``` +src/openlifu_sdk/ # Library source (src layout) +unit-test/ # Automated unit tests (pytest) +examples/ # Hardware-targeting interactive scripts (not automated) +.github/workflows/ # CI workflows +``` + +The package uses a `src/` layout (PEP 517). All library code lives under `src/openlifu_sdk/`. + +--- + +## Coding conventions + +- **Python version:** 3.12+. Use modern syntax (`X | Y` unions, `match`, etc.). +- **Imports:** Use `from __future__ import annotations` at the top of every module. +- **Line length:** 120 characters (enforced by ruff). +- **Logging:** + - Use `logging.getLogger(__name__)` — never hard-code logger names. + - Never set `logger.propagate = False` or add handlers in library code. + - Use `%s` lazy formatting: `logger.info("msg %s", value)` — never f-strings. +- **Type annotations:** Add return type annotations to all public methods. +- **Docstrings:** Google style with `Args:` / `Returns:` / `Raises:` sections. +- **Exceptions:** Raise with chaining inside `except` blocks (`raise ... from err`). +- **Signals:** Use `LIFUSignal` for event callbacks; never share signal objects at class level — always create them in `__init__`. + +--- + +## Running tests + +```bash +# All tests, verbose +pytest unit-test/ -v + +# With coverage +pytest unit-test/ -v --cov=src --cov-report=term-missing + +# Single test file +pytest unit-test/test_tx_device.py -v +``` + +Tests in `unit-test/` use `unittest.mock.MagicMock(spec=LIFUUart)` to mock hardware — **no physical device is required**. + +The `examples/` scripts require real hardware and are not run in CI. + +--- + +## Linting and formatting + +This project uses [ruff](https://docs.astral.sh/ruff/) for both linting and formatting. + +```bash +# Check for lint errors +ruff check src/ + +# Auto-fix lint errors +ruff check src/ --fix + +# Format code +ruff format src/ + +# Check formatting without modifying files +ruff format src/ --check +``` + +All lint and format checks run automatically in CI on every push and pull request. + +--- + +## Pre-commit hooks + +[pre-commit](https://pre-commit.com/) runs ruff, mypy, and general file checks before every commit. + +```bash +# Install hooks (one-time, after cloning) +pre-commit install + +# Run manually on all files +pre-commit run --all-files +``` + +--- + +## Pull request process + +1. Create a branch from `main`: + + ```bash + git checkout -b feat/my-feature + ``` + +2. Make your changes, write or update tests, and ensure all checks pass: + + ```bash + pytest unit-test/ -v + ruff check src/ && ruff format src/ --check + ``` + +3. Open a pull request against `main`. CI will run automatically. +4. Address reviewer feedback. +5. A maintainer will merge once CI is green and the PR is approved. + +--- + +## Release process + +Releases are driven by git tags: + +| Tag pattern | Outcome | +|---|---| +| `1.2.3` | Full release — wheel built, GitHub Release created, published to PyPI | +| `pre-1.2.3` | Pre-release — wheel built, GitHub pre-release created, **not** published to PyPI | + +Steps for maintainers: + +1. Ensure `CHANGELOG.md` is up to date. +2. Push a tag: `git tag 1.2.3 && git push origin 1.2.3` +3. The `release-build.yml` workflow builds the wheel and creates the GitHub Release automatically. +4. The `publish-pypi.yml` workflow publishes to PyPI when a non-pre-release GitHub Release is published. diff --git a/README.md b/README.md index 5582507..144495a 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,28 @@ # openlifu-sdk +[![Tests](https://github.com/OpenwaterHealth/openlifu-sdk/actions/workflows/test.yml/badge.svg)](https://github.com/OpenwaterHealth/openlifu-sdk/actions/workflows/test.yml) +[![PyPI](https://img.shields.io/pypi/v/openlifu-sdk)](https://pypi.org/project/openlifu-sdk/) +[![Python](https://img.shields.io/pypi/pyversions/openlifu-sdk)](https://pypi.org/project/openlifu-sdk/) +[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE) + Openwater LIFU SDK — standalone hardware I/O interface library. -This package provides the low-level communication layer for Openwater LIFU devices, -including the TX module and HV controller. +This package provides the low-level communication layer for Openwater LIFU (Low-Intensity Focused Ultrasound) devices, including the TX beamformer module and the HV (high-voltage) controller console. + +## Features + +- Full USB-serial (VCP) communication with the LIFU TX module and HV console +- Solution programming: delays, apodizations, pulse sequences +- Voltage safety enforcement (duty-cycle / sequence-time / voltage table checks) +- Asynchronous USB hot-plug monitoring +- Firmware update (DFU) for TX modules over USB and I²C +- Demo/test mode (no hardware required) for CI and unit testing +- Bundled libusb DLLs for Windows (win32 / win64) + +## Requirements + +- Python ≥ 3.12 +- `numpy`, `pandas`, `xarray`, `pyserial`, `pyusb` ## Installation @@ -11,12 +30,74 @@ including the TX module and HV controller. pip install openlifu-sdk ``` -Or for development: +For development (includes linting and testing tools): ```bash +git clone https://github.com/OpenwaterHealth/openlifu-sdk.git +cd openlifu-sdk pip install -e ".[dev]" ``` +## Quick Start + +```python +from openlifu_sdk import LIFUInterface + +# Connect to hardware (set TX_test_mode=True / HV_test_mode=True for demo mode) +interface = LIFUInterface() +tx_connected, hv_connected = interface.is_device_connected() +print(f"TX: {tx_connected} HV: {hv_connected}") + +# Program a sonication solution +solution = { + "name": "example", + "voltage": 20.0, + "pulse": {"frequency": 500e3, "duration": 2e-5, "amplitude": 1.0}, + "delays": [[0.0] * 64], # 64-channel delay array (seconds) + "apodizations": [[1.0] * 64], # 64-channel apodization array + "sequence": { + "pulse_interval": 0.1, + "pulse_count": 10, + "pulse_train_interval": 1.0, + "pulse_train_count": 1, + }, +} + +with interface: + interface.set_solution(solution) + interface.start_sonication() + # ... wait for sonication to complete ... + interface.stop_sonication() +``` + +## Architecture Overview + +``` +openlifu_sdk/ +├── __init__.py # Public API: LIFUInterface, LIFUInterfaceStatus +├── io/ +│ ├── LIFUInterface.py # High-level orchestration: solution safety, sonication control +│ ├── LIFUTXDevice.py # TX beamformer: register map, pulse/delay profiles, DFU +│ ├── LIFUHVController.py # HV console: voltage, fans, temperature, LEDs +│ ├── LIFUUart.py # USB-serial transport: framing, CRC, async hot-plug +│ ├── LIFUConfig.py # Protocol constants (packet types, commands) +│ ├── LIFUSignal.py # Qt-style observer/signal pattern +│ ├── LIFUDFU.py # Firmware update (USB DFU + I2C DFU via UART passthrough) +│ └── LIFUUserConfig.py # Device user-config wire format (header + JSON) +└── util/ + ├── units.py # SI unit conversion utilities + └── annotations.py # Typed annotation helpers for dataclass fields +``` + +## Examples + +See the [`examples/`](examples/) directory for hardware-targeted scripts covering: +- Basic connectivity and ping (`test_transmitter.py`) +- Register read/write (`test_registers.py`) +- Solution programming (`test_solution.py`) +- Firmware update (`test_fw_update.py`) +- Async mode (`test_async.py`) + ## Building a wheel ```bash @@ -24,15 +105,26 @@ pip install build python -m build ``` -## Usage +## Running tests -```python -from openlifu_sdk import LIFUInterface +```bash +pytest unit-test/ -v +``` -interface = LIFUInterface() -tx_connected, hv_connected = interface.is_device_connected() +With coverage: + +```bash +pytest unit-test/ -v --cov=src --cov-report=term-missing ``` -## Examples +## Contributing + +See [CONTRIBUTING.md](CONTRIBUTING.md) for development setup, coding conventions, and the pull request process. + +## Changelog + +See [CHANGELOG.md](CHANGELOG.md) for a history of notable changes. + +## License -See the `examples/` directory for usage scripts. +MIT — see [LICENSE](LICENSE). diff --git a/pyproject.toml b/pyproject.toml index 4578dd9..2dee544 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,9 @@ dev = [ "pytest>=6", "pytest-cov>=3", "pytest-mock", + "pre-commit", + "ruff", + "mypy", ] @@ -67,4 +70,67 @@ openlifu_sdk = [ # Uses tags like v1.2.3 or 1.2.3 -> 1.2.3 tag_regex = "^(?:pre-)?v?(?P\\d+\\.\\d+\\.\\d+)$" version_scheme = "no-guess-dev" -local_scheme = "no-local-version" \ No newline at end of file +local_scheme = "no-local-version" + + +# --------------------------------------------------------------------------- +# Ruff – linting and formatting +# --------------------------------------------------------------------------- +[tool.ruff] +target-version = "py312" +line-length = 120 + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "UP", # pyupgrade + "B", # flake8-bugbear + "G", # flake8-logging-format (catches f-string logging) + "LOG", # flake8-logging + "C4", # flake8-comprehensions + "PIE", # flake8-pie + "RUF", # ruff-specific +] +ignore = [ + "E501", # line too long – handled by formatter + "B008", # do not perform function calls in default arguments + "UP007", # use X | Y for unions – fine but deferred +] + +[tool.ruff.lint.isort] +known-first-party = ["openlifu_sdk"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" + + +# --------------------------------------------------------------------------- +# Mypy – static type checking +# --------------------------------------------------------------------------- +[tool.mypy] +python_version = "3.12" +warn_return_any = false +warn_unused_configs = true +ignore_missing_imports = true +# Start non-strict; tighten incrementally +strict = false + + +# --------------------------------------------------------------------------- +# Pytest +# --------------------------------------------------------------------------- +[tool.pytest.ini_options] +testpaths = ["unit-test"] +addopts = "-v" + +[tool.coverage.run] +source = ["src"] +omit = ["*/firmware/*", "*/libusb/*"] + +[tool.coverage.report] +show_missing = true +skip_covered = false diff --git a/src/openlifu_sdk/__init__.py b/src/openlifu_sdk/__init__.py index d87852a..99f3f52 100644 --- a/src/openlifu_sdk/__init__.py +++ b/src/openlifu_sdk/__init__.py @@ -1,8 +1,24 @@ from __future__ import annotations +from openlifu_sdk.exceptions import ( + CommunicationError, + ConfigurationError, + DeviceNotConnectedError, + FirmwareUpdateError, + OpenLIFUError, + SolutionValidationError, +) from openlifu_sdk.io.LIFUInterface import LIFUInterface, LIFUInterfaceStatus +from openlifu_sdk.transport import Transport __all__ = [ + "CommunicationError", + "ConfigurationError", + "DeviceNotConnectedError", + "FirmwareUpdateError", "LIFUInterface", "LIFUInterfaceStatus", + "OpenLIFUError", + "SolutionValidationError", + "Transport", ] diff --git a/src/openlifu_sdk/beamforming/__init__.py b/src/openlifu_sdk/beamforming/__init__.py new file mode 100644 index 0000000..85e8ec1 --- /dev/null +++ b/src/openlifu_sdk/beamforming/__init__.py @@ -0,0 +1,33 @@ +"""Beamforming register-map package for the TX7332 ultrasound beamformer IC.""" + +from __future__ import annotations + +from openlifu_sdk.beamforming.tx7332 import ( + Tx7332DelayProfile, + Tx7332PulseProfile, + Tx7332Registers, + TxDeviceRegisters, + calc_pulse_pattern, + get_delay_location, + get_pattern_location, + get_register_value, + pack_registers, + print_regs, + set_register_value, + swap_byte_order, +) + +__all__ = [ + "Tx7332DelayProfile", + "Tx7332PulseProfile", + "Tx7332Registers", + "TxDeviceRegisters", + "calc_pulse_pattern", + "get_delay_location", + "get_pattern_location", + "get_register_value", + "pack_registers", + "print_regs", + "set_register_value", + "swap_byte_order", +] diff --git a/src/openlifu_sdk/beamforming/tx7332.py b/src/openlifu_sdk/beamforming/tx7332.py new file mode 100644 index 0000000..a8a42f6 --- /dev/null +++ b/src/openlifu_sdk/beamforming/tx7332.py @@ -0,0 +1,984 @@ +"""TX7332 beamformer register-map constants, computation helpers, and dataclasses. + +This module contains everything specific to the Texas Instruments TX7332 ultrasound +beamformer IC: + +- Hardware register address constants +- Register-map computation helpers (delay/pulse profile calculations) +- Dataclasses: Tx7332DelayProfile, Tx7332PulseProfile, Tx7332Registers, TxDeviceRegisters + +These are separated from the UART communication layer (LIFUTXDevice) to make the +register-map logic independently testable without hardware. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Annotated, Literal + +import numpy as np + +from openlifu_sdk.util.annotations import OpenLIFUFieldData +from openlifu_sdk.util.units import getunitconversion + +logger = logging.getLogger(__name__) + +DEFAULT_NUM_TRANSMITTERS = 2 +TRANSMITTERS_PER_MODULE = 2 +ADDRESS_GLOBAL_MODE = 0x0 +ADDRESS_STANDBY = 0x1 +ADDRESS_DYNPWR_2 = 0x6 +ADDRESS_LDO_PWR_1 = 0xB +ADDRESS_TRSW_TURNOFF = 0xC +ADDRESS_DYNPWR_1 = 0xF +ADDRESS_LDO_PWR_2 = 0x14 +ADDRESS_TRSW_TURNON = 0x15 +ADDRESS_DELAY_SEL = 0x16 +ADDRESS_PATTERN_MODE = 0x18 +ADDRESS_PATTERN_REPEAT = 0x19 +ADDRESS_PATTERN_SEL_G2 = 0x1E +ADDRESS_PATTERN_SEL_G1 = 0x1F +ADDRESS_TRSW = 0x1A +ADDRESS_APODIZATION = 0x1B +ADDRESSES_GLOBAL = [ + ADDRESS_GLOBAL_MODE, + ADDRESS_STANDBY, + ADDRESS_DYNPWR_2, + ADDRESS_LDO_PWR_1, + ADDRESS_TRSW_TURNOFF, + ADDRESS_DYNPWR_1, + ADDRESS_LDO_PWR_2, + ADDRESS_TRSW_TURNON, + ADDRESS_DELAY_SEL, + ADDRESS_PATTERN_MODE, + ADDRESS_PATTERN_REPEAT, + ADDRESS_PATTERN_SEL_G1, + ADDRESS_PATTERN_SEL_G2, + ADDRESS_TRSW, + ADDRESS_APODIZATION, +] +ADDRESSES_DELAY_DATA = list(range(0x20, 0x11F + 1)) +ADDRESSES_PATTERN_DATA = list(range(0x120, 0x19F + 1)) +ADDRESSES = ADDRESSES_GLOBAL + ADDRESSES_DELAY_DATA + ADDRESSES_PATTERN_DATA +NUM_CHANNELS = 32 +MAX_REGISTER = 0x19F +REGISTER_BYTES = 4 +REGISTER_WIDTH = REGISTER_BYTES * 8 +DELAY_ORDER = [ + [32, 30], + [28, 26], + [24, 22], + [20, 18], + [31, 29], + [27, 25], + [23, 21], + [19, 17], + [16, 14], + [12, 10], + [8, 6], + [4, 2], + [15, 13], + [11, 9], + [7, 5], + [3, 1], +] +DELAY_ORDER_REVERSED = [[33 - c for c in row] for row in DELAY_ORDER] +DELAY_CHANNEL_MAP = {} +for row, channels in enumerate(DELAY_ORDER_REVERSED): + for i, channel in enumerate(channels): + DELAY_CHANNEL_MAP[channel] = {"row": row, "lsb": 16 * (1 - i)} +DELAY_PROFILE_OFFSET = 16 +VALID_DELAY_PROFILES = list(range(1, 17)) +DELAY_WIDTH = 13 +APODIZATION_CHANNEL_ORDER = [ + 17, + 19, + 21, + 23, + 25, + 27, + 29, + 31, + 18, + 20, + 22, + 24, + 26, + 28, + 30, + 32, + 1, + 3, + 5, + 7, + 9, + 11, + 13, + 15, + 2, + 4, + 6, + 8, + 10, + 12, + 14, + 16, +] +APODIZATION_CHANNEL_ORDER_REVERSED = [33 - c for c in APODIZATION_CHANNEL_ORDER] +DEFAULT_PATTERN_DUTY_CYCLE = 0.66 +PATTERN_PROFILE_OFFSET = 4 +NUM_PATTERN_PROFILES = 32 +VALID_PATTERN_PROFILES = list(range(1, NUM_PATTERN_PROFILES + 1)) +MAX_PATTERN_PERIODS = 16 +PATTERN_PERIOD_ORDER = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]] +PATTERN_LENGTH_WIDTH = 5 +MAX_PATTERN_PERIOD_LENGTH = 30 +PATTERN_LEVEL_WIDTH = 3 +PATTERN_MAP = {} +for row, periods in enumerate(PATTERN_PERIOD_ORDER): + for i, period in enumerate(periods): + PATTERN_MAP[period] = { + "row": row, + "lsb_lvl": i * (PATTERN_LEVEL_WIDTH + PATTERN_LENGTH_WIDTH), + "lsb_period": i * (PATTERN_LENGTH_WIDTH + PATTERN_LEVEL_WIDTH) + PATTERN_LEVEL_WIDTH, + } +MAX_REPEAT = 2**5 - 1 +MAX_ELASTIC_REPEAT = 2**16 - 1 +DEFAULT_TAIL_COUNT = 29 +DEFAULT_CLK_FREQ = 10e6 +ELASTIC_MODE_PULSE_LENGTH_ADJUST = 125e-6 +ProfileOpts = Literal["active", "configured", "all"] +TriggerModeOpts = Literal["sequence", "continuous", "single"] +DEFAULT_PULSE_WIDTH_US = 20 +HW_ID_DATA_LENGTH = 12 +TEMPERATURE_DATA_LENGTH = 4 + +if TYPE_CHECKING: + pass + + +def get_delay_location(channel: int, profile: int = 1): + """ + Gets the address and least significant bit of a delay + + :param channel: Channel number + :param profile: Delay profile number + :returns: Register address and least significant bit of the delay location + """ + if channel not in DELAY_CHANNEL_MAP: + raise ValueError(f"Invalid channel {channel}.") + channel_map = DELAY_CHANNEL_MAP[channel] + if profile not in VALID_DELAY_PROFILES: + raise ValueError(f"Invalid Profile {profile}") + address = ADDRESSES_DELAY_DATA[0] + (profile - 1) * DELAY_PROFILE_OFFSET + channel_map["row"] + lsb = channel_map["lsb"] + return address, lsb + + +def set_register_value(reg_value: int, value: int, lsb: int = 0, width: int | None = None): + """ + Sets the value of a parameter in a register integer + + :param reg_value: Register value + :param value: New value of the parameter + :param lsb: Least significant bit of the parameter + :param width: Width of the parameter (bits) + :returns: New register value + """ + if width is None: + width = REGISTER_WIDTH - lsb + mask = (1 << width) - 1 + if value < 0 or value > mask: + raise ValueError(f"Value {value} does not fit in {width} bits") + return (reg_value & ~(mask << lsb)) | ((int(value) & mask) << lsb) + + +def get_register_value(reg_value: int, lsb: int = 0, width: int | None = None): + """ + Extracts the value of a parameter from a register integer + + :param reg_value: Register value + :param lsb: Least significant bit of the parameter + :param width: Width of the parameter (bits) + :returns: Value of the parameter + """ + if width is None: + width = REGISTER_WIDTH - lsb + mask = (1 << width) - 1 + return (reg_value >> lsb) & mask + + +def calc_pulse_pattern( + frequency: float, duty_cycle: float = DEFAULT_PATTERN_DUTY_CYCLE, bf_clk: float = DEFAULT_CLK_FREQ +): + """ + Calculates the pattern for a given frequency and duty cycle + + The pattern is calculated to represent a single cycle of a pulse with the specified frequency and duty cycle. + If the pattern requires more than 16 periods, the clock divider is increased to reduce the period length. + + :param frequency: Frequency of the pattern in Hz + :param duty_cycle: Duty cycle of the pattern + :param bf_clk: Clock frequency of the BF system in Hz + :returns: Tuple of lists of levels and lengths, and the clock divider setting + """ + clk_div_n = 0 + while clk_div_n < 6: + clk_n = bf_clk / (2**clk_div_n) + period_samples = int(clk_n / frequency) + first_half_period_samples = int(period_samples / 2) + second_half_period_samples = period_samples - first_half_period_samples + first_on_samples = int(first_half_period_samples * duty_cycle) + if first_on_samples < 2: + logger.warning("Duty cycle too short. Setting to minimum of 2 samples") + first_on_samples = 2 + first_off_samples = first_half_period_samples - first_on_samples + second_on_samples = max(2, int(second_half_period_samples * duty_cycle)) + if second_on_samples < 2: + logger.warning("Duty cycle too short. Setting to minimum of 2 samples") + second_on_samples = 2 + second_off_samples = second_half_period_samples - second_on_samples + if first_off_samples > 0 and first_off_samples < 2: + first_off_samples = 0 + first_on_samples = first_half_period_samples + if second_off_samples > 0 and first_off_samples < 2: + second_off_samples = 0 + second_on_samples = second_half_period_samples + levels = [1, 0, -1, 0] + per_lengths = [] + per_levels = [] + for i, samples in enumerate([first_on_samples, first_off_samples, second_on_samples, second_off_samples]): + while samples > 0: + if samples > MAX_PATTERN_PERIOD_LENGTH + 2: + if samples == MAX_PATTERN_PERIOD_LENGTH + 3: + per_lengths.append(MAX_PATTERN_PERIOD_LENGTH - 1) + samples -= MAX_PATTERN_PERIOD_LENGTH + 1 + else: + per_lengths.append(MAX_PATTERN_PERIOD_LENGTH) + samples -= MAX_PATTERN_PERIOD_LENGTH + 2 + per_levels.append(levels[i]) + else: + per_lengths.append(samples - 2) + per_levels.append(levels[i]) + samples = 0 + if len(per_levels) <= MAX_PATTERN_PERIODS: + t = (np.arange(np.sum(np.array(per_lengths) + 2)) * (1 / clk_n)).tolist() + y = np.concatenate([[yi] * (ni + 2) for yi, ni in zip(per_levels, per_lengths, strict=False)]).tolist() + pattern = {"levels": per_levels, "lengths": per_lengths, "clk_div_n": clk_div_n, "t": t, "y": y} + return pattern + else: + clk_div_n += 1 + raise ValueError(f"Pattern requires too many periods ({len(per_levels)} > {MAX_PATTERN_PERIODS})") + + +def get_pattern_location(period: int, profile: int = 1): + """ + Gets the address and least significant bit of a pattern period + + :param period: Pattern period number + :param profile: Pattern profile number + :returns: Register address and least significant bit of the pattern period location + """ + if period not in PATTERN_MAP: + raise ValueError(f"Invalid period {period}.") + if profile not in VALID_PATTERN_PROFILES: + raise ValueError(f"Invalid profile {profile}.") + address = ADDRESSES_PATTERN_DATA[0] + (profile - 1) * PATTERN_PROFILE_OFFSET + PATTERN_MAP[period]["row"] + lsb_lvl = PATTERN_MAP[period]["lsb_lvl"] + lsb_period = PATTERN_MAP[period]["lsb_period"] + return address, lsb_lvl, lsb_period + + +def print_regs(d): + for addr, val in sorted(d.items()): + if isinstance(val, list): + for i, v in enumerate(val): + print(f"0x{addr:X}[+{i:d}]:x{v:08X}") + else: + print(f"0x{addr:X}:x{val:08X}") + + +def pack_registers(regs, pack_single: bool = False): + """ + Packs registers into contiguous blocks + + :param regs: Dictionary of registers + :param pack_single: Pack single registers into arrays. Default True. + :returns: Dictionary of packed registers. + """ + addresses = sorted(regs.keys()) + if len(addresses) == 0: + return {} + last_addr = -255 + burst_addr = -255 + packed = {} + for addr in addresses: + if addr == last_addr + 1 and burst_addr in packed: + packed[burst_addr].append(regs[addr]) + else: + packed[addr] = [regs[addr]] + burst_addr = addr + last_addr = addr + if not pack_single: + for addr, val in packed.items(): + if len(val) == 1: + packed[addr] = val[0] + return packed + + +def swap_byte_order(regs): + """ + Swaps the byte order of the registers + + :param regs: Dictionary of registers + :returns: Dictionary of registers with swapped byte order + """ + swapped = {} + for addr, val in regs.items(): + if isinstance(val, list): + swapped[addr] = [int.from_bytes(v.to_bytes(REGISTER_BYTES, "big"), "little") for v in val] + else: + swapped[addr] = int.from_bytes(val.to_bytes(REGISTER_BYTES, "big"), "little") + return swapped + + +@dataclass +class Tx7332DelayProfile: + profile: Annotated[int, OpenLIFUFieldData("Profile Index (1-16)", "Index of the delay profile (1-16)")] + """Index of the delay profile (1-16). The Tx7332 support 16 unique delay profiles.""" + + delays: Annotated[list[float], OpenLIFUFieldData("Delay values", "Delay values for transducer elements")] + """Delay values for transducer elements""" + + apodizations: Annotated[ + list[int] | None, OpenLIFUFieldData("Apodizations", "Apodization values for transducer elements") + ] = None + """Apodization values for transducer elements""" + + units: Annotated[str, OpenLIFUFieldData("Units", "Time units used for delay values")] = "s" + """Time units used for delay values""" + + def __post_init__(self): + self.num_elements = len(self.delays) + if self.apodizations is None: + self.apodizations = [1] * self.num_elements + if len(self.apodizations) != self.num_elements: + raise ValueError(f"Apodizations list must have {self.num_elements} elements") + if self.profile not in VALID_DELAY_PROFILES: + raise ValueError(f"Invalid Profile {self.profile}") + + +@dataclass +class Tx7332PulseProfile: + profile: Annotated[int, OpenLIFUFieldData("Profile index (1-32)", "Index of the pulse profile (1-32)")] + """Index of the pulse profile (1-32). The Tx7332 supports 32 unique pulse profiles.""" + + frequency: Annotated[float, OpenLIFUFieldData("Frequency (Hz)", "Center frequency of the pulse (Hz)")] + """Center frequency of the pulse (Hz)""" + + cycles: Annotated[int, OpenLIFUFieldData("Number of cycles", "Number of cycles in the pulse")] + """Number of cycles in the pulse""" + + duty_cycle: Annotated[ + float, OpenLIFUFieldData("Duty cycle (0-1)", "Pulse duty cycle for the generated square wave (0-1)") + ] = DEFAULT_PATTERN_DUTY_CYCLE + """Pulse duty cycle for the generated square wave (0-1). By default 0.66 is used to approximate a sinusoidal wave.""" + + tail_count: Annotated[ + int, + OpenLIFUFieldData( + "Tail count (cycles)", "Clock cycles to actively drive the pulser to ground after the pulse ends" + ), + ] = DEFAULT_TAIL_COUNT + """Clock cycles to actively drive the pulser to ground after the pulse ends. Default 29""" + + invert: Annotated[ + bool, OpenLIFUFieldData("Invert polarity?", "Flag indicating whether to invert the pulse amplitude") + ] = False + """Invert the pulse amplitude. Default False""" + + def __post_init__(self): + if self.profile not in VALID_PATTERN_PROFILES: + raise ValueError(f"Invalid profile {self.profile}.") + + +@dataclass +class Tx7332Registers: + bf_clk: Annotated[float, OpenLIFUFieldData("Clock Frequency (Hz)", "The beamformer clock frequency in Hz.")] = ( + DEFAULT_CLK_FREQ + ) + """The beamformer clock frequency in Hz. This much match the hardware clock frequency in order for calculated register values to produce the correct pulse and delay timting. Default is 64 MHz.""" + + _delay_profiles_list: Annotated[ + list[Tx7332DelayProfile], OpenLIFUFieldData("Delay profiles list", "Internal list of available delay profiles") + ] = field(default_factory=list) + """Internal list of available delay profiles""" + + _pulse_profiles_list: Annotated[ + list[Tx7332PulseProfile], OpenLIFUFieldData("Pulse profiles list", "Internal list of available pulse profiles") + ] = field(default_factory=list) + """Internal list of available pulse profiles""" + + active_delay_profile: Annotated[ + int | None, OpenLIFUFieldData("Active delay profile", "Index of the currently active delay profile") + ] = None + """Index of the currently active delay profile""" + + active_pulse_profile: Annotated[ + int | None, OpenLIFUFieldData("Active pulse profile", "Index of the currently active pulse profile") + ] = None + """Index of the currently active pulse profile""" + + def __post_init__(self): + delay_profile_indices = self.configured_delay_profiles() + if len(delay_profile_indices) != len(set(delay_profile_indices)): + raise ValueError("Duplicate delay profiles found") + if self.active_delay_profile is not None and self.active_delay_profile not in delay_profile_indices: + raise ValueError(f"Delay profile {self.active_delay_profile} not found") + pulse_profile_indices = self.configured_pulse_profiles() + if len(pulse_profile_indices) != len(set(pulse_profile_indices)): + raise ValueError("Duplicate pulse profiles found") + if self.active_pulse_profile is not None and self.active_pulse_profile not in pulse_profile_indices: + raise ValueError(f"Pulse profile {self.active_pulse_profile} not found") + + def add_delay_profile(self, p: Tx7332DelayProfile, activate: bool | None = None): + if p.num_elements != NUM_CHANNELS: + raise ValueError(f"Delay profile must have {NUM_CHANNELS} elements") + profile_indices = self.configured_delay_profiles() + if p.profile in profile_indices: + i = profile_indices.index(p.profile) + self._delay_profiles_list[i] = p + else: + self._delay_profiles_list.append(p) + if activate is None: + activate = self.active_delay_profile is None + if activate: + self.active_delay_profile = p.profile + + def add_pulse_profile(self, p: Tx7332PulseProfile, activate: bool | None = None): + profile_indices = self.configured_pulse_profiles() + if p.profile in profile_indices: + i = profile_indices.index(p.profile) + self._pulse_profiles_list[i] = p + else: + self._pulse_profiles_list.append(p) + if activate is None: + activate = self.active_pulse_profile is None + if activate: + self.active_pulse_profile = p.profile + + def remove_delay_profile(self, profile: int): + profile_indices = self.configured_delay_profiles() + if profile not in profile_indices: + raise ValueError(f"Delay profile {profile} not found") + index = profile_indices.index(profile) + del self._delay_profiles_list[index] + if self.active_delay_profile == index: + self.active_delay_profile = None + + def remove_pulse_profile(self, profile: int): + profiles = self.configured_pulse_profiles() + if profile not in profiles: + raise ValueError(f"Pulse profile {profile} not found") + index = profiles.index(profile) + del self._pulse_profiles_list[index] + if self.active_pulse_profile == index: + self.active_pulse_profile = None + + def get_delay_profile(self, profile: int | None = None) -> Tx7332DelayProfile: + if profile is None: + profile = self.active_delay_profile + profiles = self.configured_delay_profiles() + if profile not in profiles: + raise ValueError(f"Delay profile {profile} not found") + index = profiles.index(profile) + return self._delay_profiles_list[index] + + def configured_delay_profiles(self) -> list[int]: + return [p.profile for p in self._delay_profiles_list] + + def get_pulse_profile(self, profile: int | None = None) -> Tx7332PulseProfile: + if profile is None: + profile = self.active_pulse_profile + profiles = self.configured_pulse_profiles() + if profile not in profiles: + raise ValueError(f"Pulse profile {profile} not found") + index = profiles.index(profile) + return self._pulse_profiles_list[index] + + def configured_pulse_profiles(self) -> list[int]: + return [p.profile for p in self._pulse_profiles_list] + + def activate_delay_profile(self, profile: int): + if profile not in self.configured_delay_profiles(): + raise ValueError(f"Delay profile {profile} not configured") + self.active_delay_profile = profile + + def activate_pulse_profile(self, profile: int): + if profile not in self.configured_pulse_profiles(): + raise ValueError(f"Pulse profile {profile} not configured") + self.active_pulse_profile = profile + + def get_delay_control_registers(self, profile: int | None = None) -> dict[int, int]: + if profile is None: + profile = self.active_delay_profile + delay_profile = self.get_delay_profile(profile) + apod_register = 0 + for i, apod in enumerate(delay_profile.apodizations): + apod_register = set_register_value( + apod_register, 1 - apod, lsb=APODIZATION_CHANNEL_ORDER_REVERSED.index(i + 1), width=1 + ) + delay_sel_register = 0 + delay_sel_register = set_register_value(delay_sel_register, delay_profile.profile - 1, lsb=12, width=4) + delay_sel_register = set_register_value(delay_sel_register, delay_profile.profile - 1, lsb=28, width=4) + return {ADDRESS_DELAY_SEL: delay_sel_register, ADDRESS_APODIZATION: apod_register} + + def get_pulse_control_registers(self, profile: int | None = None, pulse_invert: bool = False) -> dict[int, int]: + if profile is None: + profile = self.active_pulse_profile + pulse_profile = self.get_pulse_profile(profile) + if pulse_profile.profile not in VALID_PATTERN_PROFILES: + raise ValueError(f"Invalid profile {pulse_profile.profile}.") + pattern = calc_pulse_pattern(pulse_profile.frequency, pulse_profile.duty_cycle, bf_clk=self.bf_clk) + clk_div_n = pattern["clk_div_n"] + clk_div = 2**clk_div_n + clk_n = self.bf_clk / clk_div + cycles = int(pulse_profile.cycles) + if cycles > (MAX_REPEAT + 1): + # Use elastic repeat + pulse_duration_samples = self.bf_clk * ( + (cycles / pulse_profile.frequency) + ELASTIC_MODE_PULSE_LENGTH_ADJUST + ) + repeat = 0 + elastic_repeat = int(pulse_duration_samples / 16) + period_samples = int(clk_n / pulse_profile.frequency) + cycles = 16 * elastic_repeat / period_samples + y = pattern["y"] * int(cycles + 1) + y = y[: (16 * elastic_repeat)] + y = y + ([0] * pulse_profile.tail_count) + np.arange(len(y)) * (1 / clk_n) + elastic_mode = 1 + if elastic_repeat > MAX_ELASTIC_REPEAT: + raise ValueError("Pattern duration too long for elastic repeat") + else: + repeat = cycles - 1 + elastic_repeat = 0 + elastic_mode = 0 + y = pattern["y"] * (repeat + 1) + y = np.array(y + [0] * pulse_profile.tail_count) + reg_mode = 0x02000003 + reg_mode = set_register_value(reg_mode, clk_div_n, lsb=3, width=3) + reg_mode = set_register_value(reg_mode, int(pulse_profile.invert ^ pulse_invert), lsb=6, width=1) + reg_repeat = 0 + reg_repeat = set_register_value(reg_repeat, repeat, lsb=1, width=5) + reg_repeat = set_register_value(reg_repeat, pulse_profile.tail_count, lsb=6, width=5) + reg_repeat = set_register_value(reg_repeat, elastic_mode, lsb=11, width=1) + reg_repeat = set_register_value(reg_repeat, elastic_repeat, lsb=12, width=16) + reg_pat_sel = 0 + reg_pat_sel = set_register_value(reg_pat_sel, pulse_profile.profile - 1, lsb=0, width=6) + registers = { + ADDRESS_PATTERN_MODE: reg_mode, + ADDRESS_PATTERN_REPEAT: reg_repeat, + ADDRESS_PATTERN_SEL_G1: reg_pat_sel, + ADDRESS_PATTERN_SEL_G2: reg_pat_sel, + } + return registers + + def get_delay_data_registers( + self, profile: int | None = None, pack: bool = False, pack_single: bool = False + ) -> dict[int, int]: + if profile is None: + profile = self.active_delay_profile + delay_profile = self.get_delay_profile(profile) + data_registers = {} + for channel in range(1, NUM_CHANNELS + 1): + address, lsb = get_delay_location(channel, delay_profile.profile) + if address not in data_registers: + data_registers[address] = 0 + delay_value = int( + delay_profile.delays[channel - 1] * getunitconversion(delay_profile.units, "s") * self.bf_clk + ) + data_registers[address] = set_register_value( + data_registers[address], delay_value, lsb=lsb, width=DELAY_WIDTH + ) + if pack: + data_registers = pack_registers(data_registers, pack_single=pack_single) + return data_registers + + def get_pulse_data_registers( + self, profile: int | None = None, pack: bool = False, pack_single: bool = False + ) -> dict[int, int]: + if profile is None: + profile = self.active_pulse_profile + profile_index = self.get_pulse_profile(profile) + data_registers = {} + pattern = calc_pulse_pattern(profile_index.frequency, profile_index.duty_cycle, bf_clk=self.bf_clk) + levels = pattern["levels"] + lengths = pattern["lengths"] + nperiods = len(levels) + level_lut = { + -1: 0b01, + 0: 0b11, + 1: 0b10, + } # Map levels to register values 0b11 drive to ground 0b00 high impedance + for i, (level, length) in enumerate(zip(levels, lengths, strict=False)): + address, lsb_lvl, lsb_length = get_pattern_location(i + 1, profile_index.profile) + if address not in data_registers: + data_registers[address] = 0 + data_registers[address] = set_register_value( + data_registers[address], level_lut[level], lsb=lsb_lvl, width=PATTERN_LEVEL_WIDTH + ) + data_registers[address] = set_register_value( + data_registers[address], length, lsb=lsb_length, width=PATTERN_LENGTH_WIDTH + ) + if nperiods < MAX_PATTERN_PERIODS: + address, lsb_lvl, lsb_length = get_pattern_location(nperiods + 1, profile_index.profile) + if address not in data_registers: + data_registers[address] = 0 + data_registers[address] = set_register_value( + data_registers[address], 0b111, lsb=lsb_lvl, width=PATTERN_LEVEL_WIDTH + ) + data_registers[address] = set_register_value( + data_registers[address], 0, lsb=lsb_length, width=PATTERN_LENGTH_WIDTH + ) + if pack: + data_registers = pack_registers(data_registers, pack_single=pack_single) + return data_registers + + def get_registers( + self, + profiles: ProfileOpts = "configured", + pack: bool = False, + pack_single: bool = False, + pulse_invert: bool = False, + ) -> dict[int, int]: + if len(self._delay_profiles_list) == 0: + raise ValueError("No delay profiles have been configured") + if len(self._pulse_profiles_list) == 0: + raise ValueError("No pulse profiles have been configured") + if self.active_delay_profile is None: + raise ValueError("No delay profile activated") + if self.active_pulse_profile is None: + raise ValueError("No pulse profile activated") + registers = dict.fromkeys(ADDRESSES_GLOBAL, 0) + registers.update(self.get_delay_control_registers()) + registers.update(self.get_pulse_control_registers(pulse_invert=pulse_invert)) + if profiles == "active": + delay_data = self.get_delay_data_registers() + pulse_data = self.get_pulse_data_registers() + else: + if profiles == "all": + delay_data = dict.fromkeys(ADDRESSES_DELAY_DATA, 0) + pulse_data = dict.fromkeys(ADDRESSES_PATTERN_DATA, 0) + else: + delay_data = {} + pulse_data = {} + for delay_profile in self._delay_profiles_list: + delay_data.update(self.get_delay_data_registers(profile=delay_profile.profile)) + for profile_index in self._pulse_profiles_list: + pulse_data.update(self.get_pulse_data_registers(profile=profile_index.profile)) + registers.update(delay_data) + registers.update(pulse_data) + if pack: + registers = pack_registers(registers, pack_single=pack_single) + return registers + + +@dataclass +class TxDeviceRegisters: + bf_clk: Annotated[int, OpenLIFUFieldData("Clock Frequency (Hz)", "The beamformer clock frequency in Hz.")] = ( + DEFAULT_CLK_FREQ + ) + """The beamformer clock frequency in Hz. This much match the hardware clock frequency in order for calculated register values to produce the correct pulse and delay timting. Default is 64 MHz.""" + + _delay_profiles_list: Annotated[ + list[Tx7332DelayProfile], OpenLIFUFieldData("Delay profiles list", "Internal list of available delay profiles") + ] = field(default_factory=list) + """Internal list of available delay profiles""" + + _profiles_list: Annotated[ + list[Tx7332PulseProfile], OpenLIFUFieldData("Pulse profiles list", "Internal list of available pulse profiles") + ] = field(default_factory=list) + """Internal list of available pulse profiles""" + + active_delay_profile: Annotated[ + int | None, OpenLIFUFieldData("Active delay profile", "Index of the currently active delay profile") + ] = None + """Index of the currently active delay profile""" + + active_profile: Annotated[ + int | None, OpenLIFUFieldData("Active pulse profile", "Index of the currently active pulse profile") + ] = None + """Index of the currently active pulse profile""" + + num_transmitters: Annotated[ + int, OpenLIFUFieldData("Number of transmitters", "The number of transmitters available on the device") + ] = DEFAULT_NUM_TRANSMITTERS + """The number of transmitters available on the device""" + + module_invert: Annotated[ + list[bool] | bool, + OpenLIFUFieldData( + "Module Invert", + "List of flags indicating whether to invert the pulse amplitude for each module or a single flag for all modules", + ), + ] = False + """List of flags indicating whether to invert the pulse amplitude for each module or a single flag for all modules""" + + def __post_init__(self): + self.transmitters = tuple([Tx7332Registers(bf_clk=self.bf_clk) for _ in range(self.num_transmitters)]) + + def add_pulse_profile(self, pulse_profile: Tx7332PulseProfile, activate: bool | None = None): + """ + Add a pulse profile + + :param p: Pulse profile + :param activate: Activate the pulse profile + """ + profiles = self.configured_pulse_profiles() + if pulse_profile.profile in profiles: + i = profiles.index(pulse_profile.profile) + self._profiles_list[i] = pulse_profile + else: + self._profiles_list.append(pulse_profile) + if activate is None: + activate = self.active_profile is None + if activate: + self.active_profile = pulse_profile.profile + for tx in self.transmitters: + tx.add_pulse_profile(pulse_profile, activate=activate) + + def add_delay_profile(self, delay_profile: Tx7332DelayProfile, activate: bool | None = None): + """ + Add a delay profile + + :param p: Delay profile + :param activate: Activate the delay profile + """ + if delay_profile.num_elements != NUM_CHANNELS * self.num_transmitters: + raise ValueError(f"Delay profile must have {NUM_CHANNELS * self.num_transmitters} elements") + profiles = self.configured_delay_profiles() + if delay_profile.profile in profiles: + i = profiles.index(delay_profile.profile) + self._delay_profiles_list[i] = delay_profile + else: + self._delay_profiles_list.append(delay_profile) + if activate is None: + activate = self.active_delay_profile is None + if activate: + self.active_delay_profile = delay_profile.profile + for i, tx in enumerate(self.transmitters): + module = i // 2 + chip = (i + 1) % 2 + start_channel = module * NUM_CHANNELS * 2 + chip * NUM_CHANNELS + profiles = np.arange(start_channel, start_channel + NUM_CHANNELS, dtype=int) + tx_delays = np.array(delay_profile.delays)[profiles].tolist() + tx_apodizations = np.array(delay_profile.apodizations)[profiles].tolist() + txp = Tx7332DelayProfile(delay_profile.profile, tx_delays, tx_apodizations, delay_profile.units) + tx.add_delay_profile(txp, activate=activate) + + def remove_delay_profile(self, profile: int): + """ + Remove a delay profile + + :param profile: Delay profile number + """ + profiles = self.configured_delay_profiles() + if profile not in profiles: + raise ValueError(f"Delay profile {profile} not found") + i = profiles.index(profile) + del self._delay_profiles_list[i] + if self.active_delay_profile == profile: + self.active_delay_profile = None + for tx in self.transmitters: + tx.remove_delay_profile(profile) + + def remove_pulse_profile(self, profile: int): + """ + Remove a pulse profile + + :param profile: Pulse profile number + """ + profiles = self.configured_pulse_profiles() + if profile not in profiles: + raise ValueError(f"Pulse profile {profile} not found") + i = profiles.index(profile) + del self._profiles_list[i] + if self.active_profile == profile: + self.active_profile = None + for tx in self.transmitters: + tx.remove_pulse_profile(profile) + + def get_delay_profile(self, profile: int | None = None) -> Tx7332DelayProfile: + """ + Retrieve a delay profile + + :param profile: Delay profile number + :return: Delay profile + """ + if profile is None: + profile = self.active_delay_profile + profiles = self.configured_delay_profiles() + if profile not in profiles: + raise ValueError(f"Delay profile {profile} not found") + i = profiles.index(profile) + return self._delay_profiles_list[i] + + def configured_delay_profiles(self) -> list[int]: + """ + Get the configured delay profiles + + :return: List of delay profiles + """ + return [p.profile for p in self._delay_profiles_list] + + def get_pulse_profile(self, profile: int | None = None) -> Tx7332PulseProfile: + """ + Retrieve a pulse profile + + :param profile: Pulse profile number + :return: Pulse profile + """ + if profile is None: + profile = self.active_profile + profiles = self.configured_pulse_profiles() + if profile not in profiles: + raise ValueError(f"Pulse profile {profile} not found") + i = profiles.index(profile) + return self._profiles_list[i] + + def configured_pulse_profiles(self) -> list[int]: + """ + Get the configured pulse profiles + + :return: List of pulse profiles + """ + return [p.profile for p in self._profiles_list] + + def activate_delay_profile(self, profile: int = 1): + """ + Activates a delay profile + + :param profile: Delay profile number + """ + for tx in self.transmitters: + tx.activate_delay_profile(profile) + self.active_delay_profile = profile + + def activate_pulse_profile(self, profile: int = 1): + """ + Activates a pulse profile + + :param profile: Pulse profile number + """ + for tx in self.transmitters: + tx.activate_pulse_profile(profile) + self.active_profile = profile + + def recompute_delay_profiles(self): + """ + Recompute the delay profiles + """ + for tx in self.transmitters: + profiles = tx.configured_delay_profiles() + for profile in profiles: + tx.remove_delay_profile(profile) + for dp in self._delay_profiles_list: + self.add_delay_profile(dp, activate=dp.profile == self.active_delay_profile) + + def recompute_pulse_profiles(self): + """ + Recompute the pulse profiles + """ + for tx in self.transmitters: + profiles = tx.configured_pulse_profiles() + for profile in profiles: + tx.remove_pulse_profile(profile) + for pp in self._profiles_list: + tx.add_pulse_profile(pp, activate=pp.profile == self.active_profile) + + def get_registers( + self, + profiles: ProfileOpts = "configured", + recompute: bool = False, + pack: bool = False, + pack_single: bool = False, + ) -> list[dict[int, int]]: + """ + Get the registers for all transmitters + + :param profiles: Profile options + :param recompute: Recompute the registers + :return: List of registers for each transmitter + """ + if recompute: + self.recompute_delay_profiles() + self.recompute_pulse_profiles() + if isinstance(self.module_invert, bool): + tx_invert = [self.module_invert] * self.num_transmitters + else: + tx_invert = list( + np.array(self.module_invert * TRANSMITTERS_PER_MODULE, dtype=bool) + .reshape(TRANSMITTERS_PER_MODULE, -1) + .T.reshape(-1) + ) + return [ + tx.get_registers(profiles, pack=pack, pack_single=pack_single, pulse_invert=inv) + for tx, inv in zip(self.transmitters, tx_invert, strict=False) + ] + + def get_delay_control_registers(self, profile: int | None = None) -> list[dict[int, int]]: + """ + Get the delay control registers for all transmitters + + :param profile: Delay profile number + :return: List of delay control registers for each transmitter + """ + if profile is None: + profile = self.active_delay_profile + return [tx.get_delay_control_registers(profile) for tx in self.transmitters] + + def get_pulse_control_registers(self, profile: int | None = None) -> list[dict[int, int]]: + """ + Get the pulse control registers for all transmitters + + :param profile: Pulse profile number + :return: List of pulse control registers for each transmitter + """ + if profile is None: + profile = self.active_profile + if isinstance(self.module_invert, bool): + tx_invert = [self.module_invert] * self.num_transmitters + else: + tx_invert = list(np.array(self.module_invert * TRANSMITTERS_PER_MODULE, dtype=bool).T.reshape(-1)) + return [ + tx.get_pulse_control_registers(profile, pulse_invert=inv) + for tx, inv in zip(self.transmitters, tx_invert, strict=False) + ] + + def get_delay_data_registers( + self, profile: int | None = None, pack: bool = False, pack_single: bool = False + ) -> list[dict[int, int]]: + """ + Get the delay data registers for all transmitters + + :param profile: Delay profile number + :return: List of delay data registers for each transmitter + """ + if profile is None: + profile = self.active_delay_profile + return [tx.get_delay_data_registers(profile, pack=pack, pack_single=pack_single) for tx in self.transmitters] + + def get_pulse_data_registers( + self, profile: int | None = None, pack: bool = False, pack_single: bool = False + ) -> list[dict[int, int]]: + """ + Get the pulse data registers for all transmitters + + :param profile: Pulse profile number + :return: List of pulse data registers for each transmitter + """ + if profile is None: + profile = self.active_profile + return [tx.get_pulse_data_registers(profile, pack=pack, pack_single=pack_single) for tx in self.transmitters] diff --git a/src/openlifu_sdk/exceptions.py b/src/openlifu_sdk/exceptions.py new file mode 100644 index 0000000..83fe29b --- /dev/null +++ b/src/openlifu_sdk/exceptions.py @@ -0,0 +1,31 @@ +"""SDK-level exception hierarchy for openlifu-sdk. + +Catch :class:`OpenLIFUError` to handle any SDK-specific error. +Sub-classes provide finer-grained handling. +""" + +from __future__ import annotations + + +class OpenLIFUError(Exception): + """Base class for all openlifu-sdk exceptions.""" + + +class DeviceNotConnectedError(OpenLIFUError): + """Raised when a hardware operation is attempted but the device is not connected.""" + + +class CommunicationError(OpenLIFUError): + """Raised when a transport-level communication failure occurs (CRC mismatch, timeout, etc.).""" + + +class SolutionValidationError(OpenLIFUError, ValueError): + """Raised when a sonication solution fails safety validation (voltage, duty cycle, etc.).""" + + +class FirmwareUpdateError(OpenLIFUError): + """Raised when a firmware update (DFU) operation fails.""" + + +class ConfigurationError(OpenLIFUError): + """Raised when device configuration is invalid or cannot be parsed.""" diff --git a/src/openlifu_sdk/io/LIFUConfig.py b/src/openlifu_sdk/io/LIFUConfig.py index d145b1d..55c81d0 100644 --- a/src/openlifu_sdk/io/LIFUConfig.py +++ b/src/openlifu_sdk/io/LIFUConfig.py @@ -1,4 +1,3 @@ - # Packet Types from __future__ import annotations diff --git a/src/openlifu_sdk/io/LIFUDFU.py b/src/openlifu_sdk/io/LIFUDFU.py index 69d77fb..b739d21 100644 --- a/src/openlifu_sdk/io/LIFUDFU.py +++ b/src/openlifu_sdk/io/LIFUDFU.py @@ -14,9 +14,10 @@ import struct import sys import time -from pathlib import Path -from typing import TYPE_CHECKING, Callable +from collections.abc import Callable from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING from openlifu_sdk.io.LIFUConfig import OW_ERROR, OW_I2C_PASSTHRU @@ -29,9 +30,10 @@ # Optional USB DFU dependencies (module 0 only) # --------------------------------------------------------------------------- try: + import usb.backend.libusb1 as _usb_libusb1 import usb.core as _usb_core import usb.util as _usb_util - import usb.backend.libusb1 as _usb_libusb1 + _USB_DFU_AVAILABLE = True except ImportError: _usb_core = None @@ -67,43 +69,42 @@ def _find_bundled_libusb_dll() -> str | None: # 2. Development / editable install: search up the directory tree ms_dir = "MS64" if arch_dir == "win64" else "MS32" for parent in Path(__file__).parents: - dev_candidate = ( - parent / "libusb-1.0.29" / "VS2022" / ms_dir / "dll" / "libusb-1.0.dll" - ) + dev_candidate = parent / "libusb-1.0.29" / "VS2022" / ms_dir / "dll" / "libusb-1.0.dll" if dev_candidate.is_file(): return str(dev_candidate) return None + # --------------------------------------------------------------------------- # DFU protocol constants (shared by USB and I2C paths) # --------------------------------------------------------------------------- # USB DFU virtual addresses (must match usbd_dfu_if.c) USB_DFU_VERSION_VIRT_ADDR = 0xFFFFFF00 -USB_DFU_VERSION_READ_LEN = 64 +USB_DFU_VERSION_READ_LEN = 64 # I2C DFU command bytes (must match i2c_dfu_if.h) -I2C_DFU_SLAVE_ADDR = 0x72 -I2C_DFU_CMD_DNLOAD = 0x01 -I2C_DFU_CMD_ERASE = 0x02 -I2C_DFU_CMD_GETSTATUS = 0x03 -I2C_DFU_CMD_MANIFEST = 0x04 -I2C_DFU_CMD_RESET = 0x05 -I2C_DFU_CMD_GETVERSION = 0x06 -I2C_DFU_STATUS_OK = 0x00 -I2C_DFU_STATUS_BUSY = 0x01 -I2C_DFU_STATUS_ERROR = 0x02 +I2C_DFU_SLAVE_ADDR = 0x72 +I2C_DFU_CMD_DNLOAD = 0x01 +I2C_DFU_CMD_ERASE = 0x02 +I2C_DFU_CMD_GETSTATUS = 0x03 +I2C_DFU_CMD_MANIFEST = 0x04 +I2C_DFU_CMD_RESET = 0x05 +I2C_DFU_CMD_GETVERSION = 0x06 +I2C_DFU_STATUS_OK = 0x00 +I2C_DFU_STATUS_BUSY = 0x01 +I2C_DFU_STATUS_ERROR = 0x02 I2C_DFU_STATUS_BAD_ADDR = 0x03 -I2C_DFU_STATUS_FLASH_ERR= 0x04 -I2C_DFU_STATE_DNBUSY = 0x01 -I2C_DFU_STATE_ERROR = 0x04 +I2C_DFU_STATUS_FLASH_ERR = 0x04 +I2C_DFU_STATE_DNBUSY = 0x01 +I2C_DFU_STATE_ERROR = 0x04 # Maximum data bytes per write_block call. The enclosing OW_I2C_PASSTHRU UART # packet carries (1 cmd + 4 addr + 2 len) = 7 bytes of I2C-DFU header, so the # total packet payload is I2C_DFU_MAX_XFER_SIZE + 7. The master firmware hard- # rejects any UART packet with data_len > DATA_MAX_SIZE (2048), so this value # must be ≤ 2041. Use 512 for a safe, standard I2C block size. -I2C_DFU_MAX_XFER_SIZE = 512 +I2C_DFU_MAX_XFER_SIZE = 512 I2C_DFU_VERSION_STR_MAX = 32 @@ -116,6 +117,7 @@ class DeviceProfile: app_default_address: int | None = None reset_virt_addr: int | None = 0xFFFFFF08 + # Built-in profiles TRANSMITTER_PROFILE = DeviceProfile( name="transmitter", @@ -136,20 +138,21 @@ class DeviceProfile: ) # OW_I2C_PASSTHRU sub-commands (must match firmware if_commands.c handler) -_PASSTHRU_WRITE = 0x00 # write only -_PASSTHRU_WRITE_READ = 0x01 # write then delay 5 ms then read +_PASSTHRU_WRITE = 0x00 # write only +_PASSTHRU_WRITE_READ = 0x01 # write then delay 5 ms then read # Signed package format (must match dfu-test.py) -_PKG_MAGIC = 0x314B4750 # 'PGK1' -_PKG_VERSION = 1 -_PKG_HDR_NOCRC = " int: """Compute CRC32 compatible with the STM32 CRC peripheral (poly=0x04C11DB7).""" poly = 0x04C11DB7 @@ -176,10 +179,9 @@ def parse_signed_package(pkg: bytes) -> dict: if len(pkg) < hdr_size: raise ValueError("signed package too small") - (magic, version, declared_hdr_size, - fw_address, fw_len, - meta_address, meta_len, - payload_crc, header_crc) = struct.unpack(_PKG_HDR_FULL, pkg[:hdr_size]) + (magic, version, declared_hdr_size, fw_address, fw_len, meta_address, meta_len, payload_crc, header_crc) = ( + struct.unpack(_PKG_HDR_FULL, pkg[:hdr_size]) + ) if magic != _PKG_MAGIC: raise ValueError(f"signed package magic mismatch: 0x{magic:08X}") @@ -188,30 +190,24 @@ def parse_signed_package(pkg: bytes) -> dict: if declared_hdr_size != hdr_size: raise ValueError("signed package header size mismatch") - calc_hdr_crc = stm32_crc32(pkg[:hdr_size - 4]) + calc_hdr_crc = stm32_crc32(pkg[: hdr_size - 4]) if header_crc != calc_hdr_crc: - raise ValueError( - f"header CRC mismatch: pkg=0x{header_crc:08X}, calc=0x{calc_hdr_crc:08X}" - ) + raise ValueError(f"header CRC mismatch: pkg=0x{header_crc:08X}, calc=0x{calc_hdr_crc:08X}") payload_len = fw_len + meta_len payload = pkg[hdr_size:] if len(payload) != payload_len: - raise ValueError( - f"payload size mismatch: expected {payload_len}, got {len(payload)}" - ) + raise ValueError(f"payload size mismatch: expected {payload_len}, got {len(payload)}") calc_payload_crc = stm32_crc32(payload) if payload_crc != calc_payload_crc: - raise ValueError( - f"payload CRC mismatch: pkg=0x{payload_crc:08X}, calc=0x{calc_payload_crc:08X}" - ) + raise ValueError(f"payload CRC mismatch: pkg=0x{payload_crc:08X}, calc=0x{calc_payload_crc:08X}") return { - "fw_address": fw_address, + "fw_address": fw_address, "meta_address": meta_address, - "fw": payload[:fw_len], - "meta": payload[fw_len:], + "fw": payload[:fw_len], + "meta": payload[fw_len:], } @@ -219,6 +215,7 @@ def parse_signed_package(pkg: bytes) -> dict: # USB DFU client (module 0) # --------------------------------------------------------------------------- + class STM32USBDFU: """Minimal STM32 DfuSe USB client using PyUSB. @@ -229,33 +226,36 @@ class STM32USBDFU: """ # DFU class requests - DFU_DNLOAD = 1 - DFU_UPLOAD = 2 + DFU_DNLOAD = 1 + DFU_UPLOAD = 2 DFU_GETSTATUS = 3 DFU_CLRSTATUS = 4 - DFU_ABORT = 6 + DFU_ABORT = 6 # DfuSe DNLOAD block 0 sub-commands CMD_SET_ADDRESS_POINTER = 0x21 - CMD_ERASE = 0x41 + CMD_ERASE = 0x41 # DFU state values - STATE_DFU_DNLOAD_SYNC = 3 - STATE_DFU_DNLOAD_BUSY = 4 - STATE_DFU_DNLOAD_IDLE = 5 - STATE_DFU_MANIFEST_SYNC = 6 - STATE_DFU_MANIFEST = 7 + STATE_DFU_DNLOAD_SYNC = 3 + STATE_DFU_DNLOAD_BUSY = 4 + STATE_DFU_DNLOAD_IDLE = 5 + STATE_DFU_MANIFEST_SYNC = 6 + STATE_DFU_MANIFEST = 7 STATE_DFU_MANIFEST_WAIT_RESET = 8 - STATE_DFU_ERROR = 10 - - def __init__(self, vid: int = 0x0483, pid: int = 0xDF11, - transfer_size: int = 1024, timeout_ms: int = 4000, - libusb_dll: str | None = None, - device_profile: "DeviceProfile" | None = None): + STATE_DFU_ERROR = 10 + + def __init__( + self, + vid: int = 0x0483, + pid: int = 0xDF11, + transfer_size: int = 1024, + timeout_ms: int = 4000, + libusb_dll: str | None = None, + device_profile: DeviceProfile | None = None, + ): if not _USB_DFU_AVAILABLE: - raise RuntimeError( - "PyUSB not available. Install with: pip install pyusb" - ) + raise RuntimeError("PyUSB not available. Install with: pip install pyusb") self.vid = vid self.pid = pid self.transfer_size = transfer_size @@ -278,38 +278,26 @@ def _get_backend(self): if self._backend is not None: return self._backend if self.libusb_dll: - self._backend = _usb_libusb1.get_backend( - find_library=lambda _: self.libusb_dll - ) + self._backend = _usb_libusb1.get_backend(find_library=lambda _: self.libusb_dll) elif _libusb_package is not None: - self._backend = _usb_libusb1.get_backend( - find_library=_libusb_package.find_library - ) + self._backend = _usb_libusb1.get_backend(find_library=_libusb_package.find_library) else: bundled_dll = _find_bundled_libusb_dll() if bundled_dll: logger.debug("Using bundled libusb DLL: %s", bundled_dll) - self._backend = _usb_libusb1.get_backend( - find_library=lambda _: bundled_dll - ) + self._backend = _usb_libusb1.get_backend(find_library=lambda _: bundled_dll) else: self._backend = _usb_libusb1.get_backend() return self._backend - def open(self) -> "STM32USBDFU": - self.dev = _usb_core.find( - idVendor=self.vid, idProduct=self.pid, backend=self._get_backend() - ) + def open(self) -> STM32USBDFU: + self.dev = _usb_core.find(idVendor=self.vid, idProduct=self.pid, backend=self._get_backend()) if self.dev is None: - raise RuntimeError( - f"USB DFU device not found: VID=0x{self.vid:04X}, PID=0x{self.pid:04X}" - ) + raise RuntimeError(f"USB DFU device not found: VID=0x{self.vid:04X}, PID=0x{self.pid:04X}") self.dev.set_configuration() cfg = self.dev.get_active_configuration() for intf in cfg: - if (intf.bInterfaceClass == 0xFE - and intf.bInterfaceSubClass == 0x01 - and intf.bInterfaceProtocol == 0x02): + if intf.bInterfaceClass == 0xFE and intf.bInterfaceSubClass == 0x01 and intf.bInterfaceProtocol == 0x02: self.intf = intf break if self.intf is None: @@ -333,7 +321,7 @@ def close(self) -> None: self.dev = None self.intf = None - def __enter__(self) -> "STM32USBDFU": + def __enter__(self) -> STM32USBDFU: return self.open() def __exit__(self, *args) -> None: @@ -342,16 +330,12 @@ def __exit__(self, *args) -> None: # --- low-level USB control transfers --- def _ctrl_out(self, req: int, value: int, data: bytes = b"") -> int: - return self.dev.ctrl_transfer( - 0x21, req, value, self.intf.bInterfaceNumber, - data, timeout=self.timeout_ms - ) + return self.dev.ctrl_transfer(0x21, req, value, self.intf.bInterfaceNumber, data, timeout=self.timeout_ms) def _ctrl_in(self, req: int, value: int, length: int) -> bytes: - return bytes(self.dev.ctrl_transfer( - 0xA1, req, value, self.intf.bInterfaceNumber, - length, timeout=self.timeout_ms - )) + return bytes( + self.dev.ctrl_transfer(0xA1, req, value, self.intf.bInterfaceNumber, length, timeout=self.timeout_ms) + ) def get_status(self) -> dict: raw = self._ctrl_in(self.DFU_GETSTATUS, 0, 6) @@ -400,9 +384,7 @@ def _wait_while_busy(self) -> dict: def _dnload(self, block_num: int, payload: bytes) -> dict: self._recover_idle() try: - self._ctrl_out( - self.DFU_DNLOAD, block_num, bytes(payload) if payload else b"" - ) + self._ctrl_out(self.DFU_DNLOAD, block_num, bytes(payload) if payload else b"") except Exception as e: if "timeout" not in str(e).lower(): raise @@ -428,9 +410,9 @@ def get_version(self) -> str: self.abort() return raw.rstrip(b"\x00").decode("ascii", errors="replace") - def write_memory(self, address: int, data: bytes, - page_erase: bool = True, - progress_callback: Callable | None = None) -> None: + def write_memory( + self, address: int, data: bytes, page_erase: bool = True, progress_callback: Callable | None = None + ) -> None: """Write data to target flash, optionally erasing each 2 KB page first. IMPORTANT: All page erases are performed before any data is written. @@ -446,8 +428,7 @@ def write_memory(self, address: int, data: bytes, # enforce alignment expectations from device: address and chunk lengths if (address % getattr(self, "program_alignment", 1)) != 0: raise RuntimeError( - f"write_memory: start address 0x{address:08X} not aligned to " - f"{self.program_alignment} bytes" + f"write_memory: start address 0x{address:08X} not aligned to {self.program_alignment} bytes" ) # Phase 1: erase all required pages up-front (before setting the @@ -468,12 +449,12 @@ def write_memory(self, address: int, data: bytes, block = 2 written = 0 for offset in range(0, total, self.transfer_size): - chunk = data[offset:offset + self.transfer_size] + chunk = data[offset : offset + self.transfer_size] # pad final chunk to program_alignment if required by bootloader align = getattr(self, "program_alignment", 1) if align > 1 and (len(chunk) % align) != 0: pad_len = align - (len(chunk) % align) - chunk = chunk + (b"\xFF" * pad_len) + chunk = chunk + (b"\xff" * pad_len) self._dnload(block, chunk) block += 1 written += len(chunk) @@ -495,6 +476,7 @@ def manifest(self) -> None: # I2C DFU client via OW master passthrough (modules 1+) # --------------------------------------------------------------------------- + class STM32I2CDFUviaMaster: """I2C DFU client that routes all I2C transactions through the USB-master module via the ``OW_I2C_PASSTHRU`` UART packet type. @@ -513,9 +495,7 @@ class STM32I2CDFUviaMaster: data = raw bytes to write """ - def __init__(self, uart: "LIFUUart", - i2c_addr: int = I2C_DFU_SLAVE_ADDR, - write_read_delay_s: float = 0.005): + def __init__(self, uart: LIFUUart, i2c_addr: int = I2C_DFU_SLAVE_ADDR, write_read_delay_s: float = 0.005): self._uart = uart self._addr = i2c_addr self._wr_delay = write_read_delay_s @@ -535,12 +515,10 @@ def _write(self, payload: bytes) -> None: self._uart.clear_buffer() if r is None or r.packet_type == OW_ERROR: raise RuntimeError( - f"I2C passthrough write failed (addr=0x{self._addr:02X}, " - f"payload={payload[:8].hex()}...)" + f"I2C passthrough write failed (addr=0x{self._addr:02X}, payload={payload[:8].hex()}...)" ) - def _exchange(self, payload: bytes, read_len: int, - pre_read_delay_s: float | None = None) -> bytes: + def _exchange(self, payload: bytes, read_len: int, pre_read_delay_s: float | None = None) -> bytes: """Write *payload* to the I2C slave, wait, then read *read_len* bytes back. The firmware inserts a fixed 5 ms gap between write and read. @@ -560,12 +538,8 @@ def _exchange(self, payload: bytes, read_len: int, ) self._uart.clear_buffer() if r is None or r.packet_type == OW_ERROR: - raise RuntimeError( - f"I2C passthrough exchange failed (addr=0x{self._addr:02X}, " - f"want_rx={read_len})" - ) - return bytes(r.data[:read_len]) if (r.data and len(r.data) >= read_len) \ - else bytes(read_len) + raise RuntimeError(f"I2C passthrough exchange failed (addr=0x{self._addr:02X}, want_rx={read_len})") + return bytes(r.data[:read_len]) if (r.data and len(r.data) >= read_len) else bytes(read_len) # --- DFU protocol commands --- @@ -580,12 +554,8 @@ def _wait_while_busy(self, timeout_s: float = 10.0) -> dict: while time.monotonic() < deadline: st = self.get_status() if st["state"] == I2C_DFU_STATE_ERROR or st["status"] in _ERROR_STATUSES: - raise RuntimeError( - f"I2C DFU error: status=0x{st['status']:02X}, " - f"state=0x{st['state']:02X}" - ) - if (st["status"] != I2C_DFU_STATUS_BUSY - and st["state"] != I2C_DFU_STATE_DNBUSY): + raise RuntimeError(f"I2C DFU error: status=0x{st['status']:02X}, state=0x{st['state']:02X}") + if st["status"] != I2C_DFU_STATUS_BUSY and st["state"] != I2C_DFU_STATE_DNBUSY: return st time.sleep(0.020) raise TimeoutError(f"I2C DFU timed out after {timeout_s:.0f} s") @@ -608,13 +578,12 @@ def write_block(self, address: int, data: bytes) -> None: self._write(payload) self._wait_while_busy(timeout_s=10.0) - def write_memory(self, address: int, data: bytes, - progress_callback: Callable | None = None) -> None: + def write_memory(self, address: int, data: bytes, progress_callback: Callable | None = None) -> None: """Write arbitrary-length data in ``I2C_DFU_MAX_XFER_SIZE``-byte chunks.""" total = len(data) written = 0 for offset in range(0, total, I2C_DFU_MAX_XFER_SIZE): - chunk = data[offset:offset + I2C_DFU_MAX_XFER_SIZE] + chunk = data[offset : offset + I2C_DFU_MAX_XFER_SIZE] self.write_block(address + offset, chunk) written += len(chunk) if progress_callback: @@ -634,9 +603,7 @@ def get_version(self) -> str: read_len = 2 + I2C_DFU_VERSION_STR_MAX raw = self._exchange(bytes([I2C_DFU_CMD_GETVERSION]), read_len) if raw[0] not in (I2C_DFU_STATUS_OK, I2C_DFU_STATUS_BUSY): - raise RuntimeError( - f"I2C DFU GETVERSION failed: status=0x{raw[0]:02X}" - ) + raise RuntimeError(f"I2C DFU GETVERSION failed: status=0x{raw[0]:02X}") return raw[2:].split(b"\x00")[0].decode("ascii", errors="replace") @@ -644,6 +611,7 @@ def get_version(self) -> str: # High-level firmware update manager # --------------------------------------------------------------------------- + class LIFUDFUManager: """Orchestrates firmware updates for a single LIFU transmitter module. @@ -658,13 +626,12 @@ class LIFUDFUManager: ) """ - def __init__(self, uart: "LIFUUart"): + def __init__(self, uart: LIFUUart): self._uart = uart # --- per-transport helpers --- - def get_bootloader_version_usb(self, vid: int = 0x0483, pid: int = 0xDF11, - libusb_dll: str | None = None) -> str: + def get_bootloader_version_usb(self, vid: int = 0x0483, pid: int = 0xDF11, libusb_dll: str | None = None) -> str: """Read bootloader version string from module 0 via USB DFU.""" with STM32USBDFU(vid=vid, pid=pid, libusb_dll=libusb_dll) as dfu: return dfu.get_version() @@ -674,11 +641,15 @@ def get_bootloader_version_i2c(self, i2c_addr: int = I2C_DFU_SLAVE_ADDR) -> str: dfu = STM32I2CDFUviaMaster(uart=self._uart, i2c_addr=i2c_addr) return dfu.get_version() - def program_usb(self, package_file: str, - vid: int = 0x0483, pid: int = 0xDF11, - libusb_dll: str | None = None, - device_type: str = "transmitter", - progress_callback: Callable | None = None) -> None: + def program_usb( + self, + package_file: str, + vid: int = 0x0483, + pid: int = 0xDF11, + libusb_dll: str | None = None, + device_type: str = "transmitter", + progress_callback: Callable | None = None, + ) -> None: """Program a signed package to module 0 via USB DFU. The module must already be in DFU bootloader mode. @@ -689,31 +660,24 @@ def program_usb(self, package_file: str, logger.info( "USB DFU: fw %d B @ 0x%08X, meta %d B @ 0x%08X", - len(pkg["fw"]), pkg["fw_address"], - len(pkg["meta"]), pkg["meta_address"], + len(pkg["fw"]), + pkg["fw_address"], + len(pkg["meta"]), + pkg["meta_address"], ) if device_type not in ("transmitter", "console"): - raise ValueError( - f"Unknown device_type {device_type!r}; expected 'transmitter' or 'console'." - ) + raise ValueError(f"Unknown device_type {device_type!r}; expected 'transmitter' or 'console'.") profile = TRANSMITTER_PROFILE if device_type == "transmitter" else CONSOLE_PROFILE - with STM32USBDFU(vid=vid, pid=pid, libusb_dll=libusb_dll, - device_profile=profile) as dfu: - dfu.write_memory( - pkg["fw_address"], pkg["fw"], - page_erase=True, progress_callback=progress_callback - ) - dfu.write_memory( - pkg["meta_address"], pkg["meta"], - page_erase=True, progress_callback=progress_callback - ) + with STM32USBDFU(vid=vid, pid=pid, libusb_dll=libusb_dll, device_profile=profile) as dfu: + dfu.write_memory(pkg["fw_address"], pkg["fw"], page_erase=True, progress_callback=progress_callback) + dfu.write_memory(pkg["meta_address"], pkg["meta"], page_erase=True, progress_callback=progress_callback) logger.info("USB DFU: sending manifest...") dfu.manifest() logger.info("USB DFU: programming complete.") - def program_i2c(self, package_file: str, - i2c_addr: int = I2C_DFU_SLAVE_ADDR, - progress_callback: Callable | None = None) -> None: + def program_i2c( + self, package_file: str, i2c_addr: int = I2C_DFU_SLAVE_ADDR, progress_callback: Callable | None = None + ) -> None: """Program a signed package to a slave module via I2C passthrough. The slave must already be in DFU bootloader mode at *i2c_addr*. @@ -732,29 +696,32 @@ def program_i2c(self, package_file: str, logger.info( "I2C DFU: fw %d B @ 0x%08X, meta %d B @ 0x%08X", - len(pkg["fw"]), pkg["fw_address"], - len(pkg["meta"]), pkg["meta_address"], + len(pkg["fw"]), + pkg["fw_address"], + len(pkg["meta"]), + pkg["meta_address"], ) dfu = STM32I2CDFUviaMaster(uart=self._uart, i2c_addr=i2c_addr) logger.info("I2C DFU: mass erasing application region...") dfu.mass_erase() logger.info("I2C DFU: erasing metadata page @ 0x%08X...", pkg["meta_address"]) dfu.erase_page(pkg["meta_address"]) - dfu.write_memory( - pkg["fw_address"], pkg["fw"], - progress_callback=progress_callback - ) + dfu.write_memory(pkg["fw_address"], pkg["fw"], progress_callback=progress_callback) logger.info("I2C DFU: writing metadata...") - dfu.write_memory( - pkg["meta_address"], pkg["meta"] - ) + dfu.write_memory(pkg["meta_address"], pkg["meta"]) logger.info("I2C DFU: sending manifest...") dfu.manifest() logger.info("I2C DFU: programming complete.") - def _wait_for_usb_dfu(self, vid: int, pid: int, libusb_dll: str | None, - timeout_s: float = 30.0, poll_interval_s: float = 1.0, - device_profile: "DeviceProfile" | None = None) -> str: + def _wait_for_usb_dfu( + self, + vid: int, + pid: int, + libusb_dll: str | None, + timeout_s: float = 30.0, + poll_interval_s: float = 1.0, + device_profile: DeviceProfile | None = None, + ) -> str: """Poll for the USB DFU device until it enumerates or *timeout_s* elapses. Returns the bootloader version string once the device is found. @@ -763,13 +730,11 @@ def _wait_for_usb_dfu(self, vid: int, pid: int, libusb_dll: str | None, # Pre-flight: verify the libusb backend can be loaded before entering # the poll loop. If the DLL is missing or the path is wrong this fails # immediately with a clear message instead of silently timing out. - _probe = STM32USBDFU(vid=vid, pid=pid, libusb_dll=libusb_dll, - device_profile=device_profile) + _probe = STM32USBDFU(vid=vid, pid=pid, libusb_dll=libusb_dll, device_profile=device_profile) backend = _probe._get_backend() if backend is None: raise RuntimeError( - "libusb backend not available — install libusb or pass --libusb-dll " - "pointing to a valid libusb-1.0.dll." + "libusb backend not available — install libusb or pass --libusb-dll pointing to a valid libusb-1.0.dll." ) deadline = time.monotonic() + timeout_s @@ -786,51 +751,45 @@ def _wait_for_usb_dfu(self, vid: int, pid: int, libusb_dll: str | None, if dev is None: remaining = deadline - time.monotonic() - logger.debug( - "USB DFU not found yet (attempt %d, %.0f s remaining)...", - attempt, max(remaining, 0) - ) + logger.debug("USB DFU not found yet (attempt %d, %.0f s remaining)...", attempt, max(remaining, 0)) time.sleep(poll_interval_s) continue # Phase 2: device is present — open it and read the version string. elapsed = timeout_s - (deadline - time.monotonic()) - logger.info( - "USB DFU device found after %.1f s (attempt %d)", elapsed, attempt - ) + logger.info("USB DFU device found after %.1f s (attempt %d)", elapsed, attempt) try: - with STM32USBDFU(vid=vid, pid=pid, libusb_dll=libusb_dll, - device_profile=device_profile) as dfu: + with STM32USBDFU(vid=vid, pid=pid, libusb_dll=libusb_dll, device_profile=device_profile) as dfu: version = dfu.get_version() return version except Exception as e: - # Device enumerated but version read failed (e.g. DFU state + # Device enumerated but version read failed (e.g. DFU state # machine not ready yet or bootloader doesn't support virtual # version address). Log visibly and return a placeholder so # the update can still proceed. logger.warning( - "USB DFU device found but version read failed: %s — " - "proceeding with version='unknown'", e + "USB DFU device found but version read failed: %s — proceeding with version='unknown'", e ) return "unknown" raise RuntimeError( - f"USB DFU device (VID=0x{vid:04X}, PID=0x{pid:04X}) did not " - f"enumerate within {timeout_s:.0f} s" + f"USB DFU device (VID=0x{vid:04X}, PID=0x{pid:04X}) did not enumerate within {timeout_s:.0f} s" ) - def update_module(self, - module: int, - package_file: str, - enter_dfu_fn: Callable, - vid: int = 0x0483, - pid: int = 0xDF11, - libusb_dll: str | None = None, - i2c_addr: int = I2C_DFU_SLAVE_ADDR, - dfu_wait_s: float = 3.0, - dfu_enum_timeout_s: float = 30.0, - device_type: str = "transmitter", - progress_callback: Callable | None = None) -> None: + def update_module( + self, + module: int, + package_file: str, + enter_dfu_fn: Callable, + vid: int = 0x0483, + pid: int = 0xDF11, + libusb_dll: str | None = None, + i2c_addr: int = I2C_DFU_SLAVE_ADDR, + dfu_wait_s: float = 3.0, + dfu_enum_timeout_s: float = 30.0, + device_type: str = "transmitter", + progress_callback: Callable | None = None, + ) -> None: """High-level firmware update for a single module. Steps: @@ -869,9 +828,7 @@ def update_module(self, elif device_type == "console": # Console/host DFU is only valid for the USB master (module 0) if module != 0: - raise ValueError( - f"Console DFU is only supported for module 0; got module {module}" - ) + raise ValueError(f"Console DFU is only supported for module 0; got module {module}") enter_dfu_fn() else: raise ValueError(f"Unsupported device_type {device_type!r} for DFU entry") @@ -881,22 +838,23 @@ def update_module(self, time.sleep(dfu_wait_s) if module == 0: - logger.info( - "Waiting for USB DFU device (timeout %ds)...", dfu_enum_timeout_s - ) + logger.info("Waiting for USB DFU device (timeout %ds)...", dfu_enum_timeout_s) try: profile = TRANSMITTER_PROFILE if device_type == "transmitter" else CONSOLE_PROFILE bl_version = self._wait_for_usb_dfu( - vid=vid, pid=pid, libusb_dll=libusb_dll, - timeout_s=dfu_enum_timeout_s, device_profile=profile, + vid=vid, + pid=pid, + libusb_dll=libusb_dll, + timeout_s=dfu_enum_timeout_s, + device_profile=profile, ) except RuntimeError as e: - raise RuntimeError( - f"Module 0 did not enter USB DFU mode: {e}" - ) from e + raise RuntimeError(f"Module 0 did not enter USB DFU mode: {e}") from e logger.info("USB DFU bootloader version: %s", bl_version) self.program_usb( - package_file, vid=vid, pid=pid, + package_file, + vid=vid, + pid=pid, libusb_dll=libusb_dll, device_type=device_type, progress_callback=progress_callback, @@ -904,22 +862,19 @@ def update_module(self, else: logger.info( "Verifying I2C DFU entry (module %d, addr=0x%02X via master)...", - module, i2c_addr, + module, + i2c_addr, ) try: bl_version = self.get_bootloader_version_i2c(i2c_addr=i2c_addr) except Exception as e: - raise RuntimeError( - f"Module {module} did not enter I2C DFU mode at " - f"0x{i2c_addr:02X}: {e}" - ) from e + raise RuntimeError(f"Module {module} did not enter I2C DFU mode at 0x{i2c_addr:02X}: {e}") from e if not bl_version: - raise RuntimeError( - f"Module {module} I2C DFU bootloader returned an empty version string" - ) + raise RuntimeError(f"Module {module} I2C DFU bootloader returned an empty version string") logger.info("I2C DFU bootloader version: %s", bl_version) self.program_i2c( - package_file, i2c_addr=i2c_addr, + package_file, + i2c_addr=i2c_addr, progress_callback=progress_callback, ) diff --git a/src/openlifu_sdk/io/LIFUHVController.py b/src/openlifu_sdk/io/LIFUHVController.py index 6b3751c..0346a0f 100644 --- a/src/openlifu_sdk/io/LIFUHVController.py +++ b/src/openlifu_sdk/io/LIFUHVController.py @@ -126,9 +126,7 @@ def get_version(self) -> str: if not self.uart.is_connected(): raise ValueError("Console Device not connected") - r = self.uart.send_packet( - id=None, packetType=OW_CMD, command=OW_CMD_VERSION - ) + r = self.uart.send_packet(id=None, packetType=OW_CMD, command=OW_CMD_VERSION) self.uart.clear_buffer() # r.print_packet() if r.data_len == 3: @@ -136,10 +134,10 @@ def get_version(self) -> str: elif r.data_len and r.data: try: # Decode only the valid length, strip trailing NULs and whitespace - ver_str = r.data[:r.data_len].decode('utf-8', errors='ignore').rstrip('\x00').strip() - ver = ver_str if ver_str else 'v0.0.0' + ver_str = r.data[: r.data_len].decode("utf-8", errors="ignore").rstrip("\x00").strip() + ver = ver_str if ver_str else "v0.0.0" except Exception: - ver = 'v0.0.0' + ver = "v0.0.0" logger.info(ver) return ver @@ -178,9 +176,7 @@ def echo(self, echo_data=None) -> tuple[bytes, int]: if echo_data is not None and not isinstance(echo_data, bytes | bytearray): raise TypeError("echo_data must be a byte array") - r = self.uart.send_packet( - id=None, packetType=OW_CMD, command=OW_CMD_ECHO, data=echo_data - ) + r = self.uart.send_packet(id=None, packetType=OW_CMD, command=OW_CMD_ECHO, data=echo_data) self.uart.clear_buffer() # r.print_packet() if r.data_len > 0: @@ -211,9 +207,7 @@ def toggle_led(self) -> bool: if not self.uart.is_connected(): raise ValueError("Console Device not connected") - r = self.uart.send_packet( - id=None, packetType=OW_CMD, command=OW_CMD_TOGGLE_LED - ) + r = self.uart.send_packet(id=None, packetType=OW_CMD, command=OW_CMD_TOGGLE_LED) self.uart.clear_buffer() # r.print_packet() if r.packet_type == OW_ERROR: @@ -282,9 +276,7 @@ def get_temperature1(self) -> float: return 0 # Send the GET_TEMP command - r = self.uart.send_packet( - id=None, packetType=OW_POWER, command=OW_POWER_GET_TEMP1 - ) + r = self.uart.send_packet(id=None, packetType=OW_POWER, command=OW_POWER_GET_TEMP1) self.uart.clear_buffer() # r.print_packet() @@ -321,9 +313,7 @@ def get_temperature2(self) -> float: return 0 # Send the GET_TEMP command - r = self.uart.send_packet( - id=None, packetType=OW_POWER, command=OW_POWER_GET_TEMP2 - ) + r = self.uart.send_packet(id=None, packetType=OW_POWER, command=OW_POWER_GET_TEMP2) self.uart.clear_buffer() # r.print_packet() @@ -350,9 +340,7 @@ def turn_12v_off(self): logger.info("Turning off 12V.") - r = self.uart.send_packet( - id=None, packetType=OW_POWER, command=OW_POWER_12V_OFF - ) + r = self.uart.send_packet(id=None, packetType=OW_POWER, command=OW_POWER_12V_OFF) self.uart.clear_buffer() # r.print_packet() @@ -382,9 +370,7 @@ def turn_12v_on(self): logger.info("Turning on 12V.") - r = self.uart.send_packet( - id=None, packetType=OW_POWER, command=OW_POWER_12V_ON - ) + r = self.uart.send_packet(id=None, packetType=OW_POWER, command=OW_POWER_12V_ON) self.uart.clear_buffer() # r.print_packet() @@ -414,9 +400,7 @@ def get_12v_status(self): logger.info("Get 12V voltage status.") - r = self.uart.send_packet( - id=None, packetType=OW_POWER, command=OW_POWER_GET_12VON - ) + r = self.uart.send_packet(id=None, packetType=OW_POWER, command=OW_POWER_GET_12VON) self.uart.clear_buffer() # r.print_packet() @@ -452,9 +436,7 @@ def turn_hv_on(self): logger.info("Turning on high voltage.") - r = self.uart.send_packet( - id=None, packetType=OW_POWER, command=OW_POWER_HV_ON, timeout=30 - ) + r = self.uart.send_packet(id=None, packetType=OW_POWER, command=OW_POWER_HV_ON, timeout=30) self.uart.clear_buffer() # r.print_packet() @@ -487,9 +469,7 @@ def turn_hv_off(self): logger.info("Turning off high voltage.") - r = self.uart.send_packet( - id=None, packetType=OW_POWER, command=OW_POWER_HV_OFF - ) + r = self.uart.send_packet(id=None, packetType=OW_POWER, command=OW_POWER_HV_OFF) self.uart.clear_buffer() # r.print_packet() @@ -519,9 +499,7 @@ def get_hv_status(self): logger.info("Get high voltage status.") - r = self.uart.send_packet( - id=None, packetType=OW_POWER, command=OW_POWER_GET_HVON - ) + r = self.uart.send_packet(id=None, packetType=OW_POWER, command=OW_POWER_GET_HVON) self.uart.clear_buffer() # r.print_packet() @@ -564,26 +542,22 @@ def set_voltage(self, voltage: float) -> bool: if voltage is None: voltage = 0 elif not (5.0 <= voltage <= 100.0): - raise ValueError( - "Voltage input must be within the valid range 5 to 100 Volts)." - ) + raise ValueError("Voltage input must be within the valid range 5 to 100 Volts).") try: - #dac_input = int(((voltage) / 162) * 4095) + # dac_input = int(((voltage) / 162) * 4095) # logger.info("Setting DAC Value %d.", dac_input) # Pack the 12-bit DAC input into two bytes - #data = bytes( + # data = bytes( # [ # (dac_input >> 8) & 0xFF, # High byte (most significant bits) # dac_input & 0xFF, # Low byte (least significant bits) # ] - #) + # ) - data = struct.pack('>f', voltage) + data = struct.pack(">f", voltage) - r = self.uart.send_packet( - id=None, packetType=OW_POWER, command=OW_POWER_SET_HV, data=data - ) + r = self.uart.send_packet(id=None, packetType=OW_POWER, command=OW_POWER_SET_HV, data=data) self.uart.clear_buffer() # r.print_packet() @@ -656,9 +630,7 @@ def set_dacs(self, hvp: int, hvm: int, hrp: int, hrm: int) -> bool: ] ) - r = self.uart.send_packet( - id=None, packetType=OW_POWER, command=OW_POWER_SET_DACS, data=data - ) + r = self.uart.send_packet(id=None, packetType=OW_POWER, command=OW_POWER_SET_DACS, data=data) self.uart.clear_buffer() # r.print_packet() @@ -695,9 +667,7 @@ def get_voltage(self) -> float: logger.info("Getting current output voltage.") - r = self.uart.send_packet( - id=None, packetType=OW_POWER, command=OW_POWER_GET_HV - ) + r = self.uart.send_packet(id=None, packetType=OW_POWER, command=OW_POWER_GET_HV) self.uart.clear_buffer() # r.print_packet() @@ -772,7 +742,7 @@ def set_fan_speed(self, fan_id: int = 0, fan_speed: int = 50) -> int: logger.error("Error setting Fan Speed") return -1 - logger.info(f"Set fan speed to {fan_speed}") + logger.info("Set fan speed to %s", fan_speed) return fan_speed except ValueError as v: @@ -808,9 +778,7 @@ def get_fan_speed(self, fan_id: int = 0) -> int: logger.info("Getting current output voltage.") - r = self.uart.send_packet( - id=None, addr=fan_id, packetType=OW_POWER, command=OW_POWER_GET_FAN - ) + r = self.uart.send_packet(id=None, addr=fan_id, packetType=OW_POWER, command=OW_POWER_GET_FAN) self.uart.clear_buffer() # r.print_packet() @@ -821,7 +789,7 @@ def get_fan_speed(self, fan_id: int = 0) -> int: elif r.data_len == 1: fan_value = r.data[0] - logger.info(f"Output fan speed is {fan_value}") + logger.info("Output fan speed is %s", fan_value) return fan_value else: logger.error("Error getting output voltage from device") @@ -852,9 +820,7 @@ def set_rgb_led(self, rgb_state: int) -> int: raise ValueError("High voltage controller not connected") if rgb_state not in [0, 1, 2, 3]: - raise ValueError( - "Invalid RGB state. Must be 0 (OFF), 1 (RED), 2 (BLUE), or 3 (GREEN)" - ) + raise ValueError("Invalid RGB state. Must be 0 (OFF), 1 (RED), 2 (BLUE), or 3 (GREEN)") try: if self.uart.demo_mode: @@ -876,7 +842,7 @@ def set_rgb_led(self, rgb_state: int) -> int: logger.error("Error setting RGB LED state") return -1 - logger.info(f"Set RGB LED state to {rgb_state}") + logger.info("Set RGB LED state to %s", rgb_state) return rgb_state except ValueError as v: @@ -906,9 +872,7 @@ def get_rgb_led(self) -> int: logger.info("Getting current RGB LED state.") - r = self.uart.send_packet( - id=None, packetType=OW_POWER, command=OW_POWER_GET_RGB - ) + r = self.uart.send_packet(id=None, packetType=OW_POWER, command=OW_POWER_GET_RGB) self.uart.clear_buffer() @@ -917,7 +881,7 @@ def get_rgb_led(self) -> int: return -1 rgb_state = r.reserved - logger.info(f"Current RGB LED state is {rgb_state}") + logger.info("Current RGB LED state is %s", rgb_state) return rgb_state except ValueError as v: @@ -948,7 +912,13 @@ def get_vmon_values(self) -> list[dict]: if self.uart.demo_mode: # Return demo data for 8 channels return [ - {"channel": i, "raw_adc": 2048, "reserved": 0, "voltage": 12.5 + i, "converted_voltage": 25.0 + i * 2} + { + "channel": i, + "raw_adc": 2048, + "reserved": 0, + "voltage": 12.5 + i, + "converted_voltage": 25.0 + i * 2, + } for i in range(8) ] @@ -957,9 +927,7 @@ def get_vmon_values(self) -> list[dict]: return [] # Send the voltage monitor command - r = self.uart.send_packet( - id=None, packetType=OW_POWER, command=OW_POWER_VMON - ) + r = self.uart.send_packet(id=None, packetType=OW_POWER, command=OW_POWER_VMON) self.uart.clear_buffer() # r.print_packet() @@ -975,18 +943,22 @@ def get_vmon_values(self) -> list[dict]: # Unpack data for 8 channels for channel_num in range(8): - channels.append({ - "channel": channel_num, - "raw_adc": raw_values[channel_num], - "voltage": round(voltages[channel_num], 3), - "converted_voltage": round(converted_voltages[channel_num], 3) - }) + channels.append( + { + "channel": channel_num, + "raw_adc": raw_values[channel_num], + "voltage": round(voltages[channel_num], 3), + "converted_voltage": round(converted_voltages[channel_num], 3), + } + ) offset += 10 # Move to next channel (2 + 4 + 4 = 10 bytes) return channels else: - raise ValueError(f"Invalid data length received for voltage monitor: expected 96 bytes, got {r.data_len}") + raise ValueError( + f"Invalid data length received for voltage monitor: expected 96 bytes, got {r.data_len}" + ) except ValueError as v: logger.error("ValueError: %s", v) @@ -1008,7 +980,7 @@ def set_raw_dac(self, dac_id: int = 0, dac_value: int = 0) -> int: if not self.uart.is_connected(): raise ValueError("High voltage controller not connected") - if dac_id not in [0, 1, 2 ,3]: + if dac_id not in [0, 1, 2, 3]: raise ValueError("Invalid DAC ID. Must be 0, 1, 2, or 3") if dac_value not in range(4096): @@ -1039,7 +1011,7 @@ def set_raw_dac(self, dac_id: int = 0, dac_value: int = 0) -> int: logger.error("Error setting DAC value") return -1 - logger.info(f"Set DAC value to {dac_value}") + logger.info("Set DAC value to %s", dac_value) return dac_value except ValueError as v: @@ -1069,7 +1041,7 @@ def hv_enable(self, enable: bool = False) -> bool: if self.uart.demo_mode: return True - logger.info(f"{'Enabling' if enable else 'Disabling'} high voltage output.") + logger.info("%s high voltage output.", "Enabling" if enable else "Disabling") r = self.uart.send_packet( id=None, @@ -1086,7 +1058,7 @@ def hv_enable(self, enable: bool = False) -> bool: logger.error("Error setting HV enable state") return False - logger.info(f"High voltage output {'enabled' if enable else 'disabled'} successfully.") + logger.info("High voltage output %s successfully.", "enabled" if enable else "disabled") return True except ValueError as v: diff --git a/src/openlifu_sdk/io/LIFUInterface.py b/src/openlifu_sdk/io/LIFUInterface.py index fd500be..71808df 100644 --- a/src/openlifu_sdk/io/LIFUInterface.py +++ b/src/openlifu_sdk/io/LIFUInterface.py @@ -4,7 +4,6 @@ import importlib.metadata import logging from enum import Enum -from typing import Dict, List, Optional import numpy as np import pandas as pd @@ -15,31 +14,32 @@ from openlifu_sdk.io.LIFUUart import LIFUUart REF_MAX_SEQUENCE_TIMES = { - "default": [2*60, 5*60, 10*60], # users to use default values - "stress_test": [5*60, 9*60, 15*60] # QA to use stress test values + "default": [2 * 60, 5 * 60, 10 * 60], # users to use default values + "stress_test": [5 * 60, 9 * 60, 15 * 60], # QA to use stress test values } REF_MAX_DUTY_CYCLES = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5] MAX_VOLTAGE_BY_DUTY_CYCLE_AND_SEQUENCE_TIME = { "evt2": [ - [45, 45, 45], # 0.05 - [40, 40, 40], # 0.1 - [40, 40, 35], # 0.2 - [40, 35, 30], # 0.3 - [35, 30, 25], # 0.4 - [30, 25, 20] # 0.5 + [45, 45, 45], # 0.05 + [40, 40, 40], # 0.1 + [40, 40, 35], # 0.2 + [40, 35, 30], # 0.3 + [35, 30, 25], # 0.4 + [30, 25, 20], # 0.5 ], "evt0": [ - [65, 65, 65], # 0.05 - [65, 65, 50], # 0.1 - [50, 40, 35], # 0.2 - [45, 35, 30], # 0.3 - [35, 30, 25], # 0.4 - [30, 25, 20] # 0.5 + [65, 65, 65], # 0.05 + [65, 65, 50], # 0.1 + [50, 40, 35], # 0.2 + [45, 35, 30], # 0.3 + [35, 30, 25], # 0.4 + [30, 25, 20], # 0.5 ], } + class LIFUInterfaceStatus(Enum): STATUS_COMMS_ERROR = -1 STATUS_SYS_OFF = 0 @@ -52,28 +52,26 @@ class LIFUInterfaceStatus(Enum): STATUS_FINISHED = 7 STATUS_ERROR = 8 + logger = logging.getLogger(__name__) + class LIFUInterface: - signal_connect: LIFUSignal = LIFUSignal() - signal_disconnect: LIFUSignal = LIFUSignal() - signal_data_received: LIFUSignal = LIFUSignal() - hvcontroller: HVController = None - txdevice: TxDevice = None - - def __init__(self, - vid: int = 0x0483, - tx_pid: int = 0x57AF, - con_pid: int = 0x57A0, - baudrate: int = 921600, - timeout: int = 10, - TX_test_mode: bool = False, - HV_test_mode: bool = False, - run_async: bool = False, - ext_power_supply: bool = False, - module_invert: bool | List[bool] = False, - voltage_table_selection: Optional[str] = None, - sequence_time_selection: Optional[str] = None) -> None: + def __init__( + self, + vid: int = 0x0483, + tx_pid: int = 0x57AF, + con_pid: int = 0x57A0, + baudrate: int = 921600, + timeout: int = 10, + TX_test_mode: bool = False, + HV_test_mode: bool = False, + run_async: bool = False, + ext_power_supply: bool = False, + module_invert: bool | list[bool] = False, + voltage_table_selection: str | None = None, + sequence_time_selection: str | None = None, + ) -> None: """ Initialize the LIFUInterface with given parameters and store them in the class. @@ -83,9 +81,20 @@ def __init__(self, con_pid (int): Product ID for console device. baudrate (int): Communication baud rate. timeout (int): Read timeout in seconds. - test_mode (bool): Enable test mode. + TX_test_mode (bool): Enable TX test/demo mode. + HV_test_mode (bool): Enable HV test/demo mode. run_async (bool): Enable asynchronous operation. + ext_power_supply (bool): Use external power supply (skip HVController). + module_invert (bool | List[bool]): Invert signal on modules. + voltage_table_selection (str | None): EVT version override for voltage table. + sequence_time_selection (str | None): Sequence time set override. """ + # Per-instance signals — must be created here, not at class level, to + # prevent all instances from sharing the same signal objects. + self.signal_connect: LIFUSignal = LIFUSignal() + self.signal_disconnect: LIFUSignal = LIFUSignal() + self.signal_data_received: LIFUSignal = LIFUSignal() + # Store parameters in instance variables self.vid = vid self.tx_pid = tx_pid @@ -94,8 +103,10 @@ def __init__(self, self.timeout = timeout self._test_mode = TX_test_mode self._async_mode = run_async - self._tx_uart = None - self._hv_uart = None + self._tx_uart: LIFUUart | None = None + self._hv_uart: LIFUUart | None = None + self.txdevice: TxDevice | None = None + self.hvcontroller: HVController | None = None self.status = LIFUInterfaceStatus.STATUS_SYS_OFF self.voltage_table = None @@ -104,8 +115,22 @@ def __init__(self, self.sequence_time_selection = sequence_time_selection # Create a TXDevice instance as part of the interface - logger.debug("Initializing TX Module of LIFUInterface with VID: %s, PID: %s, baudrate: %s, timeout: %s", vid, tx_pid, baudrate, timeout) - self._tx_uart = LIFUUart(vid=vid, pid=tx_pid, baudrate=baudrate, timeout=timeout, desc="TX", demo_mode=TX_test_mode, async_mode=run_async) + logger.debug( + "Initializing TX Module of LIFUInterface with VID: %s, PID: %s, baudrate: %s, timeout: %s", + vid, + tx_pid, + baudrate, + timeout, + ) + self._tx_uart = LIFUUart( + vid=vid, + pid=tx_pid, + baudrate=baudrate, + timeout=timeout, + desc="TX", + demo_mode=TX_test_mode, + async_mode=run_async, + ) self.txdevice = TxDevice(uart=self._tx_uart, module_invert=module_invert) if ext_power_supply: @@ -113,8 +138,22 @@ def __init__(self, self.hvcontroller = None else: # Create a LIFUHVController instance as part of the interface - logger.debug("Initializing Console of LIFUInterface with VID: %s, PID: %s, baudrate: %s, timeout: %s", vid, con_pid, baudrate, timeout) - self._hv_uart = LIFUUart(vid=vid, pid=con_pid, baudrate=baudrate, timeout=timeout, desc="HV", demo_mode=HV_test_mode, async_mode=run_async) + logger.debug( + "Initializing Console of LIFUInterface with VID: %s, PID: %s, baudrate: %s, timeout: %s", + vid, + con_pid, + baudrate, + timeout, + ) + self._hv_uart = LIFUUart( + vid=vid, + pid=con_pid, + baudrate=baudrate, + timeout=timeout, + desc="HV", + demo_mode=HV_test_mode, + async_mode=run_async, + ) self.hvcontroller = HVController(uart=self._hv_uart) # Connect signals to internal handlers @@ -137,7 +176,9 @@ def _resolve_voltage_chart_evt_version(self, voltage_table: str) -> list[list[in else: evt_version = voltage_table.lower() if evt_version not in MAX_VOLTAGE_BY_DUTY_CYCLE_AND_SEQUENCE_TIME: - raise ValueError(f"Invalid voltage_table option '{voltage_table}'. Valid options are: {tuple(MAX_VOLTAGE_BY_DUTY_CYCLE_AND_SEQUENCE_TIME.keys())}") + raise ValueError( + f"Invalid voltage_table option '{voltage_table}'. Valid options are: {tuple(MAX_VOLTAGE_BY_DUTY_CYCLE_AND_SEQUENCE_TIME.keys())}" + ) return MAX_VOLTAGE_BY_DUTY_CYCLE_AND_SEQUENCE_TIME[evt_version] @@ -148,7 +189,9 @@ def _resolve_max_sequence_time_set(self, sequence_time: str) -> list[int]: else: sequence_time = sequence_time.lower() if sequence_time not in REF_MAX_SEQUENCE_TIMES: - raise ValueError(f"Invalid sequence_time option '{sequence_time}'. Valid options are: {tuple(REF_MAX_SEQUENCE_TIMES.keys())}") + raise ValueError( + f"Invalid sequence_time option '{sequence_time}'. Valid options are: {tuple(REF_MAX_SEQUENCE_TIMES.keys())}" + ) return REF_MAX_SEQUENCE_TIMES[sequence_time] async def start_monitoring(self, interval: int = 1) -> None: @@ -156,8 +199,7 @@ async def start_monitoring(self, interval: int = 1) -> None: try: if self._hv_uart is not None: await asyncio.gather( - self._tx_uart.monitor_usb_status(interval), - self._hv_uart.monitor_usb_status(interval) + self._tx_uart.monitor_usb_status(interval), self._hv_uart.monitor_usb_status(interval) ) else: await self._tx_uart.monitor_usb_status(interval) @@ -191,7 +233,7 @@ def is_device_connected(self) -> tuple: hv_connected = self.hvcontroller.is_connected() return tx_connected, hv_connected - def get_max_voltage(self, solution: Dict) -> float: + def get_max_voltage(self, solution: dict) -> float: """ Get the maximum voltage for a given solution. @@ -201,7 +243,7 @@ def get_max_voltage(self, solution: Dict) -> float: Returns: float: The maximum voltage for the solution. """ - + sequence_duty_cycle = self.get_sequence_duty_cycle(solution) sequence_duration = self.get_sequence_duration(solution) @@ -225,18 +267,16 @@ def get_max_voltage_table(self) -> pd.DataFrame: """ data = { "Duty Cycle (%)": [f"<={100 * dc:0.1f}%" for dc in REF_MAX_DUTY_CYCLES], - } + } for i, duration in enumerate(self.sequence_time): col_name = f"<={duration // 60} min" - data[col_name] = [ - self.voltage_table[j][i] for j in range(len(REF_MAX_DUTY_CYCLES)) - ] - max_voltage = pd.DataFrame(data).set_index("Duty Cycle (%)") + data[col_name] = [self.voltage_table[j][i] for j in range(len(REF_MAX_DUTY_CYCLES))] + max_voltage = pd.DataFrame(data).set_index("Duty Cycle (%)") max_voltage.Name = "Maximum Voltage (V)" max_voltage.Description = "This table shows the maximum voltage for different duty cycles and sequence times." return max_voltage - def check_solution(self, solution: Dict) -> None: + def check_solution(self, solution: dict) -> None: """ Check if the solution is valid. Args: @@ -244,26 +284,32 @@ def check_solution(self, solution: Dict) -> None: Raises: ValueError: If the solution is invalid. """ - + self.voltage_table = self._resolve_voltage_chart_evt_version(self.voltage_table_selection) self.sequence_time = self._resolve_max_sequence_time_set(self.sequence_time_selection) sequence_duty_cycle = self.get_sequence_duty_cycle(solution) duty_cycles_limits = np.array(REF_MAX_DUTY_CYCLES) if sequence_duty_cycle > duty_cycles_limits.max(): - raise ValueError(f"Sequence duty cycle ({100*sequence_duty_cycle:0.1f} %) exceeds maximum allowed duty cycle ({100*duty_cycles_limits.max():0.1f} %).") + raise ValueError( + f"Sequence duty cycle ({100 * sequence_duty_cycle:0.1f} %) exceeds maximum allowed duty cycle ({100 * duty_cycles_limits.max():0.1f} %)." + ) duty_cycle_index = np.where(duty_cycles_limits >= sequence_duty_cycle)[0][0] sequence_duration = self.get_sequence_duration(solution) duration_limits = np.array(self.sequence_time) if sequence_duration > duration_limits.max(): - raise ValueError(f"Sequence duration ({sequence_duration:0.0f} s) exceeds maximum allowed duration ({duration_limits.max()} s).") + raise ValueError( + f"Sequence duration ({sequence_duration:0.0f} s) exceeds maximum allowed duration ({duration_limits.max()} s)." + ) duration_index = np.where(duration_limits >= sequence_duration)[0][0] max_voltage = self.voltage_table[duty_cycle_index][duration_index] - if solution['voltage'] > max_voltage: - raise ValueError(f"Voltage ({solution['voltage']:0.1f}V) exceeds maximum allowed voltage ({max_voltage:0.1f}V) for duty cycle ({100*sequence_duty_cycle:0.1f} <= {100*duty_cycles_limits[duty_cycle_index]}%) and sequence time ({sequence_duration:0.0f} <= {duration_limits[duration_index]}s).") + if solution["voltage"] > max_voltage: + raise ValueError( + f"Voltage ({solution['voltage']:0.1f}V) exceeds maximum allowed voltage ({max_voltage:0.1f}V) for duty cycle ({100 * sequence_duty_cycle:0.1f} <= {100 * duty_cycles_limits[duty_cycle_index]}%) and sequence time ({sequence_duration:0.0f} <= {duration_limits[duration_index]}s)." + ) - def get_sequence_duty_cycle(self, solution: Dict) -> float: + def get_sequence_duty_cycle(self, solution: dict) -> float: """ Get the duty cycle of the sequence in the solution. @@ -273,14 +319,15 @@ def get_sequence_duty_cycle(self, solution: Dict) -> float: Returns: float: The duty cycle of the sequence. """ - - if solution['sequence']['pulse_train_interval'] == 0: - return solution['pulse']['duration'] / solution['sequence']['pulse_interval'] + if solution["sequence"]["pulse_train_interval"] == 0: + return solution["pulse"]["duration"] / solution["sequence"]["pulse_interval"] else: - return (solution['pulse']['duration'] * solution['sequence']['pulse_count']) / solution['sequence']['pulse_train_interval'] + return (solution["pulse"]["duration"] * solution["sequence"]["pulse_count"]) / solution["sequence"][ + "pulse_train_interval" + ] - def get_sequence_duration(self, solution: Dict) -> float: + def get_sequence_duration(self, solution: dict) -> float: """ Get the duration of the sequence in the solution. @@ -290,23 +337,27 @@ def get_sequence_duration(self, solution: Dict) -> float: Returns: float: The duration of the sequence. """ - - if solution['sequence']['pulse_train_interval'] == 0: - return solution['sequence']['pulse_interval'] * solution['sequence']['pulse_count'] * solution['sequence']['pulse_train_count'] + if solution["sequence"]["pulse_train_interval"] == 0: + return ( + solution["sequence"]["pulse_interval"] + * solution["sequence"]["pulse_count"] + * solution["sequence"]["pulse_train_count"] + ) else: - return solution['sequence']['pulse_train_interval'] * solution['sequence']['pulse_train_count'] + return solution["sequence"]["pulse_train_interval"] * solution["sequence"]["pulse_train_count"] - def set_module_invert(self, module_invert: bool | List[bool]) -> None: + def set_module_invert(self, module_invert: bool | list[bool]) -> None: if self.txdevice is not None: self.txdevice.set_module_invert(module_invert) - def set_solution(self, - solution: Dict, - profile_index:int=1, - profile_increment:bool=True, - trigger_mode: TriggerModeOpts = "sequence", - ) -> None: + def set_solution( + self, + solution: dict, + profile_index: int = 1, + profile_increment: bool = True, + trigger_mode: TriggerModeOpts = "sequence", + ) -> None: """ Load a solution to the device. @@ -317,11 +368,14 @@ def set_solution(self, trigger_mode (TriggerModeOpts): The trigger mode to use (defaults to "sequence") module_invert (List[bool]|bool): Invert the signal on all modules (singleton) or specific modules (list) (defaults to False) """ - self.check_solution(solution) - if "transducer" in solution and solution["transducer"] is not None and "module_invert" in solution["transducer"]: + if ( + "transducer" in solution + and solution["transducer"] is not None + and "module_invert" in solution["transducer"] + ): self.txdevice.set_module_invert(solution["transducer"]["module_invert"]) else: self.txdevice.set_module_invert(False) @@ -334,22 +388,22 @@ def set_solution(self, else: solution_name = "Solution" - voltage = solution['voltage'] + voltage = solution["voltage"] logger.info("Loading %s...", solution_name) # Convert solution data and send to the device self.txdevice.set_solution( - pulse = solution['pulse'], - delays = solution['delays'], - apodizations= solution['apodizations'], - sequence= solution['sequence'], - profile_index=profile_index, - profile_increment=profile_increment, - trigger_mode=trigger_mode + pulse=solution["pulse"], + delays=solution["delays"], + apodizations=solution["apodizations"], + sequence=solution["sequence"], + profile_index=profile_index, + profile_increment=profile_increment, + trigger_mode=trigger_mode, ) self.set_status(LIFUInterfaceStatus.STATUS_READY) if self.hvcontroller is not None: - logger.info(f"Setting HV to {voltage} V...") + logger.info("Setting HV to %s V...", voltage) self.hvcontroller.set_voltage(voltage) logger.info("%s loaded successfully.", solution_name) diff --git a/src/openlifu_sdk/io/LIFUTXDevice.py b/src/openlifu_sdk/io/LIFUTXDevice.py index 5772b9a..bd5cbb0 100644 --- a/src/openlifu_sdk/io/LIFUTXDevice.py +++ b/src/openlifu_sdk/io/LIFUTXDevice.py @@ -4,116 +4,77 @@ import logging import re import struct -import time -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Annotated, Dict, List, Literal, Optional +from typing import TYPE_CHECKING, Literal import numpy as np -from openlifu_sdk.io.LIFUUart import LIFUUart -from openlifu_sdk.io.LIFUUserConfig import LifuUserConfig -from openlifu_sdk.util.annotations import OpenLIFUFieldData -from openlifu_sdk.util.units import getunitconversion - -DEFAULT_NUM_TRANSMITTERS = 2 -TRANSMITTERS_PER_MODULE = 2 -ADDRESS_GLOBAL_MODE = 0x0 -ADDRESS_STANDBY = 0x1 -ADDRESS_DYNPWR_2 = 0x6 -ADDRESS_LDO_PWR_1 = 0xB -ADDRESS_TRSW_TURNOFF = 0xC -ADDRESS_DYNPWR_1 = 0xF -ADDRESS_LDO_PWR_2 = 0x14 -ADDRESS_TRSW_TURNON = 0x15 -ADDRESS_DELAY_SEL = 0x16 -ADDRESS_PATTERN_MODE = 0x18 -ADDRESS_PATTERN_REPEAT = 0x19 -ADDRESS_PATTERN_SEL_G2 = 0x1E -ADDRESS_PATTERN_SEL_G1 = 0x1F -ADDRESS_TRSW = 0x1A -ADDRESS_APODIZATION = 0x1B -ADDRESSES_GLOBAL = [ADDRESS_GLOBAL_MODE, - ADDRESS_STANDBY, - ADDRESS_DYNPWR_2, - ADDRESS_LDO_PWR_1, - ADDRESS_TRSW_TURNOFF, - ADDRESS_DYNPWR_1, - ADDRESS_LDO_PWR_2, - ADDRESS_TRSW_TURNON, - ADDRESS_DELAY_SEL, - ADDRESS_PATTERN_MODE, - ADDRESS_PATTERN_REPEAT, - ADDRESS_PATTERN_SEL_G1, - ADDRESS_PATTERN_SEL_G2, - ADDRESS_TRSW, - ADDRESS_APODIZATION] -ADDRESSES_DELAY_DATA = list(range(0x20, 0x11F+1)) -ADDRESSES_PATTERN_DATA = list(range(0x120, 0x19F+1)) -ADDRESSES = ADDRESSES_GLOBAL + ADDRESSES_DELAY_DATA + ADDRESSES_PATTERN_DATA -NUM_CHANNELS = 32 -MAX_REGISTER = 0x19F -REGISTER_BYTES = 4 -REGISTER_WIDTH = REGISTER_BYTES*8 -DELAY_ORDER = [[32, 30], - [28, 26], - [24, 22], - [20, 18], - [31, 29], - [27, 25], - [23, 21], - [19, 17], - [16, 14], - [12, 10], - [8, 6], - [4, 2], - [15, 13], - [11, 9], - [7, 5], - [3, 1]] -DELAY_ORDER_REVERSED = [[33 - c for c in row] for row in DELAY_ORDER] -DELAY_CHANNEL_MAP = {} -for row, channels in enumerate(DELAY_ORDER_REVERSED): - for i, channel in enumerate(channels): - DELAY_CHANNEL_MAP[channel] = {'row': row, 'lsb': 16*(1-i)} -DELAY_PROFILE_OFFSET = 16 -VALID_DELAY_PROFILES = list(range(1, 17)) -DELAY_WIDTH = 13 -APODIZATION_CHANNEL_ORDER = [17, 19, 21, 23, 25, 27, 29, 31, 18, 20, 22, 24, 26, 28, 30, 32, 1, 3, 5, 7, 9, 11, 13, 15, 2, 4, 6, 8, 10, 12, 14, 16] -APODIZATION_CHANNEL_ORDER_REVERSED = [33 - c for c in APODIZATION_CHANNEL_ORDER] -DEFAULT_PATTERN_DUTY_CYCLE = 0.66 -PATTERN_PROFILE_OFFSET = 4 -NUM_PATTERN_PROFILES = 32 -VALID_PATTERN_PROFILES = list(range(1, NUM_PATTERN_PROFILES+1)) -MAX_PATTERN_PERIODS = 16 -PATTERN_PERIOD_ORDER = [[1, 2, 3, 4], - [5, 6, 7, 8], - [9, 10, 11, 12], - [13, 14, 15, 16]] -PATTERN_LENGTH_WIDTH = 5 -MAX_PATTERN_PERIOD_LENGTH = 30 -PATTERN_LEVEL_WIDTH = 3 -PATTERN_MAP = {} -for row, periods in enumerate(PATTERN_PERIOD_ORDER): - for i, period in enumerate(periods): - PATTERN_MAP[period] = {'row': row, 'lsb_lvl': i*(PATTERN_LEVEL_WIDTH+PATTERN_LENGTH_WIDTH), 'lsb_period': i*(PATTERN_LENGTH_WIDTH+PATTERN_LEVEL_WIDTH)+PATTERN_LEVEL_WIDTH} -MAX_REPEAT = 2**5-1 -MAX_ELASTIC_REPEAT = 2**16-1 -DEFAULT_TAIL_COUNT = 29 -DEFAULT_CLK_FREQ = 10e6 -ELASTIC_MODE_PULSE_LENGTH_ADJUST = 125e-6 -ProfileOpts = Literal['active', 'configured', 'all'] -TriggerModeOpts = Literal['sequence', 'continuous','single'] -DEFAULT_PULSE_WIDTH_US = 20 -HW_ID_DATA_LENGTH = 12 -TEMPERATURE_DATA_LENGTH = 4 - +from openlifu_sdk.beamforming.tx7332 import ( # noqa: F401 -- re-exported for backward compatibility + ADDRESS_APODIZATION, + ADDRESS_DELAY_SEL, + ADDRESS_DYNPWR_1, + ADDRESS_DYNPWR_2, + ADDRESS_GLOBAL_MODE, + ADDRESS_LDO_PWR_1, + ADDRESS_LDO_PWR_2, + ADDRESS_PATTERN_MODE, + ADDRESS_PATTERN_REPEAT, + ADDRESS_PATTERN_SEL_G1, + ADDRESS_PATTERN_SEL_G2, + ADDRESS_STANDBY, + ADDRESS_TRSW, + ADDRESS_TRSW_TURNOFF, + ADDRESS_TRSW_TURNON, + ADDRESSES, + ADDRESSES_DELAY_DATA, + ADDRESSES_GLOBAL, + ADDRESSES_PATTERN_DATA, + APODIZATION_CHANNEL_ORDER, + APODIZATION_CHANNEL_ORDER_REVERSED, + DEFAULT_CLK_FREQ, + DEFAULT_NUM_TRANSMITTERS, + DEFAULT_PATTERN_DUTY_CYCLE, + DEFAULT_TAIL_COUNT, + DELAY_CHANNEL_MAP, + DELAY_ORDER, + DELAY_ORDER_REVERSED, + DELAY_PROFILE_OFFSET, + DELAY_WIDTH, + ELASTIC_MODE_PULSE_LENGTH_ADJUST, + MAX_ELASTIC_REPEAT, + MAX_PATTERN_PERIOD_LENGTH, + MAX_PATTERN_PERIODS, + MAX_REGISTER, + MAX_REPEAT, + NUM_CHANNELS, + PATTERN_LENGTH_WIDTH, + PATTERN_LEVEL_WIDTH, + PATTERN_MAP, + PATTERN_PERIOD_ORDER, + PATTERN_PROFILE_OFFSET, + REGISTER_BYTES, + REGISTER_WIDTH, + TRANSMITTERS_PER_MODULE, + VALID_DELAY_PROFILES, + VALID_PATTERN_PROFILES, + Tx7332DelayProfile, + Tx7332PulseProfile, + Tx7332Registers, + TxDeviceRegisters, + calc_pulse_pattern, + get_delay_location, + get_pattern_location, + get_register_value, + pack_registers, + print_regs, + set_register_value, + swap_byte_order, +) from openlifu_sdk.io.LIFUConfig import ( OW_CMD, OW_CMD_ASYNC, OW_CMD_DFU, OW_CMD_ECHO, OW_CMD_GET_AMBIENT, - OW_CTRL_GET_MODULE_COUNT, OW_CMD_GET_TEMP, OW_CMD_HWID, OW_CMD_PING, @@ -122,6 +83,7 @@ OW_CMD_USR_CFG, OW_CMD_VERSION, OW_CONTROLLER, + OW_CTRL_GET_MODULE_COUNT, OW_CTRL_GET_SWTRIG, OW_CTRL_SET_SWTRIG, OW_CTRL_START_SWTRIG, @@ -131,8 +93,8 @@ OW_TX7332_DEMO, OW_TX7332_DEVICE_COUNT, OW_TX7332_ENUM, - OW_TX7332_RREG, OW_TX7332_RBLOCK, + OW_TX7332_RREG, OW_TX7332_VWBLOCK, OW_TX7332_VWREG, OW_TX7332_WBLOCK, @@ -141,20 +103,21 @@ TRIGGER_MODE_SEQUENCE, TRIGGER_MODE_SINGLE, ) +from openlifu_sdk.io.LIFUUart import LIFUUart +from openlifu_sdk.io.LIFUUserConfig import LifuUserConfig + +DEFAULT_PULSE_WIDTH_US = 20 +HW_ID_DATA_LENGTH = 12 +TEMPERATURE_DATA_LENGTH = 4 +ProfileOpts = Literal["active", "configured", "all"] +TriggerModeOpts = Literal["sequence", "continuous", "single"] if TYPE_CHECKING: pass -logger = logging.getLogger("TXDevice") -logger.setLevel(logging.INFO) -logger.propagate = False +logger = logging.getLogger(__name__) -if not logger.handlers: - handler = logging.StreamHandler() - formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") - handler.setFormatter(formatter) - logger.addHandler(handler) class TxDevice: def __init__(self, uart: LIFUUart, module_invert: bool | list[bool] = False): @@ -209,7 +172,7 @@ def close(self): if self.uart and self.uart.is_connected(): self.uart.disconnect() - def ping(self, module:int=0) -> bool: + def ping(self, module: int = 0) -> bool: """ Send a ping command to the TX device to verify connectivity. @@ -243,7 +206,7 @@ def ping(self, module:int=0) -> bool: logger.error("Unexpected error during process: %s", e) raise # Re-raise the exception for the caller to handle - def get_version(self, module:int=0) -> str: + def get_version(self, module: int = 0) -> str: """ Retrieve the firmware version of the TX device. @@ -256,26 +219,26 @@ def get_version(self, module:int=0) -> str: """ try: if self.uart.demo_mode: - return 'v0.1.1' + return "v0.1.1" if not self.uart.is_connected(): logger.error("TX Device not connected") - return 'v0.0.0' + return "v0.0.0" r = self.uart.send_packet(id=None, packetType=OW_CMD, command=OW_CMD_VERSION, addr=module) self.uart.clear_buffer() r.print_packet() if r.data_len == 3: - ver = f'v{r.data[0]}.{r.data[1]}.{r.data[2]}' + ver = f"v{r.data[0]}.{r.data[1]}.{r.data[2]}" elif r.data_len and r.data: try: # Decode only the valid length, strip trailing NULs and whitespace - ver_str = r.data[:r.data_len].decode('utf-8', errors='ignore').rstrip('\x00').strip() - ver = ver_str if ver_str else 'v0.0.0' + ver_str = r.data[: r.data_len].decode("utf-8", errors="ignore").rstrip("\x00").strip() + ver = ver_str if ver_str else "v0.0.0" except Exception: - ver = 'v0.0.0' + ver = "v0.0.0" else: - ver = 'v0.0.0' + ver = "v0.0.0" logger.debug(ver) return ver except ValueError as v: @@ -286,7 +249,7 @@ def get_version(self, module:int=0) -> str: logger.error("Unexpected error during process: %s", e) raise # Re-raise the exception for the caller to handle - def echo(self, module:int=0, echo_data = None) -> tuple[bytes, int]: + def echo(self, module: int = 0, echo_data=None) -> tuple[bytes, int]: """ Send an echo command to the device with data and receive the same data in response. @@ -334,7 +297,7 @@ def echo(self, module:int=0, echo_data = None) -> tuple[bytes, int]: logger.error("Unexpected error during echo process: %s", e) raise # Re-raise the exception for the caller to handle - def toggle_led(self, module:int=0) -> bool: + def toggle_led(self, module: int = 0) -> bool: """ Toggle the LED on the TX device. @@ -350,7 +313,7 @@ def toggle_led(self, module:int=0) -> bool: logger.error("TX Device not connected") return False - r = self.uart.send_packet(id=None, packetType=OW_CMD, command=OW_CMD_TOGGLE_LED, addr=module) + self.uart.send_packet(id=None, packetType=OW_CMD, command=OW_CMD_TOGGLE_LED, addr=module) self.uart.clear_buffer() # r.print_packet() return True @@ -363,7 +326,7 @@ def toggle_led(self, module:int=0) -> bool: logger.error("Unexpected error during process: %s", e) raise # Re-raise the exception for the caller to handle - def get_hardware_id(self, module:int=0) -> str: + def get_hardware_id(self, module: int = 0) -> str: """ Retrieve the hardware ID of the TX device. @@ -397,15 +360,15 @@ def get_hardware_id(self, module:int=0) -> str: logger.error("Unexpected error during process: %s", e) raise # Re-raise the exception for the caller to handle - def read_config(self, module:int=0) -> Optional[LifuUserConfig]: + def read_config(self, module: int = 0) -> LifuUserConfig | None: """ Read the user configuration from device flash. - + The configuration is stored as JSON with metadata (magic, version, sequence, CRC). - + Returns: LifuUserConfig: Parsed configuration object, or None on error - + Raises: ValueError: If the UART is not connected Exception: If an error occurs during communication @@ -425,7 +388,7 @@ def read_config(self, module:int=0) -> Optional[LifuUserConfig]: packetType=OW_CMD, addr=module, command=OW_CMD_USR_CFG, - reserved=0 # 0 = READ + reserved=0, # 0 = READ ) self.uart.clear_buffer() @@ -436,10 +399,10 @@ def read_config(self, module:int=0) -> Optional[LifuUserConfig]: # Parse wire format response try: config = LifuUserConfig.from_wire_bytes(r.data) - logger.debug(f"Read config: seq={config.header.seq}, json_len={config.header.json_len}") + logger.debug("Read config: seq=%s, json_len=%s", config.header.seq, config.header.json_len) return config except Exception as e: - logger.error(f"Failed to parse config response: {e}") + logger.error("Failed to parse config response: %s", e) return None except ValueError as v: @@ -450,20 +413,20 @@ def read_config(self, module:int=0) -> Optional[LifuUserConfig]: logger.error("Unexpected error reading config: %s", e) raise - def write_config(self, config: LifuUserConfig, module:int=0) -> Optional[LifuUserConfig]: + def write_config(self, config: LifuUserConfig, module: int = 0) -> LifuUserConfig | None: """ Write user configuration to device flash. - + Can pass either: - Full wire format (header + JSON) - Raw JSON bytes (device will parse as JSON) - + Args: config: LifuUserConfig object to write - + Returns: LifuUserConfig: Updated configuration from device (with new seq/crc), or None on error - + Raises: ValueError: If the UART is not connected Exception: If an error occurs during communication @@ -478,9 +441,9 @@ def write_config(self, config: LifuUserConfig, module:int=0) -> Optional[LifuUse # Convert config to wire format bytes wire_data = config.to_wire_bytes() - - logger.debug(f"Writing config to device: {len(wire_data)} bytes") - + + logger.debug("Writing config to device: %s bytes", len(wire_data)) + # Send write command (reserved=1 for WRITE) r = self.uart.send_packet( id=None, @@ -488,7 +451,7 @@ def write_config(self, config: LifuUserConfig, module:int=0) -> Optional[LifuUse command=OW_CMD_USR_CFG, addr=module, reserved=1, # 1 = WRITE - data=wire_data + data=wire_data, ) self.uart.clear_buffer() @@ -501,12 +464,13 @@ def write_config(self, config: LifuUserConfig, module:int=0) -> Optional[LifuUse # JSON data we just wrote (which is not echoed back by the firmware). try: from openlifu_sdk.io.LIFUUserConfig import LifuUserConfigHeader + updated_header = LifuUserConfigHeader.from_bytes(r.data[:16]) updated_config = LifuUserConfig(header=updated_header, json_data=config.json_data) - logger.debug(f"Config written successfully: new seq={updated_config.header.seq}") + logger.debug("Config written successfully: new seq=%s", updated_config.header.seq) return updated_config except Exception as e: - logger.error(f"Failed to parse write response: {e}") + logger.error("Failed to parse write response: %s", e) return None except ValueError as v: @@ -517,19 +481,19 @@ def write_config(self, config: LifuUserConfig, module:int=0) -> Optional[LifuUse logger.error("Unexpected error writing config: %s", e) raise - def write_config_json(self, json_str: str, module:int=0) -> Optional[LifuUserConfig]: + def write_config_json(self, json_str: str, module: int = 0) -> LifuUserConfig | None: """ Write user configuration from a JSON string. - + This is a convenience method that creates a LifuUserConfig from JSON and writes it to the device. - + Args: json_str: JSON string to write - + Returns: LifuUserConfig: Updated configuration from device, or None on error - + Raises: ValueError: If JSON is invalid or UART is not connected Exception: If an error occurs during communication @@ -539,10 +503,10 @@ def write_config_json(self, json_str: str, module:int=0) -> Optional[LifuUserCon config.set_json_str(json_str) return self.write_config(module=module, config=config) except json.JSONDecodeError as e: - logger.error(f"Invalid JSON: {e}") - raise ValueError(f"Invalid JSON: {e}") + logger.error("Invalid JSON: %s", e) + raise ValueError(f"Invalid JSON: {e}") from e - def get_temperature(self, module:int=1) -> float: + def get_temperature(self, module: int = 1) -> float: """ Retrieve the temperature reading from the TX device. @@ -569,7 +533,7 @@ def get_temperature(self, module:int=1) -> float: # Check if the data length matches a float (4 bytes) if r.data_len == TEMPERATURE_DATA_LENGTH: # Unpack the float value from the received data (assuming little-endian) - temperature = struct.unpack(' float: logger.error("Unexpected error during process: %s", e) raise # Re-raise the exception for the caller to handle - def get_ambient_temperature(self, module:int=0) -> float: + def get_ambient_temperature(self, module: int = 0) -> float: """ Retrieve the ambient temperature reading from the TX device. @@ -610,7 +574,7 @@ def get_ambient_temperature(self, module:int=0) -> float: # Check if the data length matches a float (4 bytes) if r.data_len == TEMPERATURE_DATA_LENGTH: # Unpack the float value from the received data (assuming little-endian) - temperature = struct.unpack(' float: raise # Re-raise the exception for the caller to handle return 0 - def set_trigger(self, - pulse_interval: float, - pulse_count: int = 1, - pulse_width: int = DEFAULT_PULSE_WIDTH_US, - pulse_train_interval: float = 0.0, - pulse_train_count: int = 1, - trigger_mode: TriggerModeOpts = "sequence", - profile_index: int = 0, - profile_increment: bool = True) -> dict: + def set_trigger( + self, + pulse_interval: float, + pulse_count: int = 1, + pulse_width: int = DEFAULT_PULSE_WIDTH_US, + pulse_train_interval: float = 0.0, + pulse_train_count: int = 1, + trigger_mode: TriggerModeOpts = "sequence", + profile_index: int = 0, + profile_increment: bool = True, + ) -> dict: """ Set the trigger configuration on the TX device. @@ -662,23 +628,26 @@ def set_trigger(self, if pulse_train_interval > 0 and (pulse_train_interval < pulse_interval * pulse_count): raise ValueError("Pulse train interval cannot be less than pulse interval * pulse count") - logger.info(f"Setting trigger with parameters: " - f"pulse_interval={pulse_interval}, " - f"pulse_count={pulse_count}, " - f"pulse_width={pulse_width}, " - f"pulse_train_interval={pulse_train_interval}, " - f"pulse_train_count={pulse_train_count}, " - f"trigger_mode={trigger_mode}") + logger.info( + "Setting trigger with parameters: pulse_interval=%s, pulse_count=%s, " + "pulse_width=%s, pulse_train_interval=%s, pulse_train_count=%s, trigger_mode=%s", + pulse_interval, + pulse_count, + pulse_width, + pulse_train_interval, + pulse_train_count, + trigger_mode, + ) trigger_json = { - "TriggerFrequencyHz": 1/pulse_interval, + "TriggerFrequencyHz": 1 / pulse_interval, "TriggerPulseCount": pulse_count, "TriggerPulseWidthUsec": pulse_width, "TriggerPulseTrainInterval": pulse_train_interval * 1000000, "TriggerPulseTrainCount": pulse_train_count, "TriggerMode": trigger_mode_int, "ProfileIndex": 0, - "ProfileIncrement": 0 + "ProfileIncrement": 0, } return self.set_trigger_json(data=trigger_json) @@ -711,21 +680,23 @@ def set_trigger_json(self, data=None) -> dict: try: json_string = json.dumps(data) except json.JSONDecodeError as e: - logger.error(f"Data must be valid JSON: {e}") + logger.error("Data must be valid JSON: %s", e) return None - payload = json_string.encode('utf-8') + payload = json_string.encode("utf-8") - r = self.uart.send_packet(id=None, packetType=OW_CONTROLLER, command=OW_CTRL_SET_SWTRIG, addr=0, data=payload) + r = self.uart.send_packet( + id=None, packetType=OW_CONTROLLER, command=OW_CTRL_SET_SWTRIG, addr=0, data=payload + ) self.uart.clear_buffer() if r.packet_type != OW_ERROR and r.data_len > 0: # Parse response as JSON, if possible try: - response_json = json.loads(r.data.decode('utf-8')) + response_json = json.loads(r.data.decode("utf-8")) return response_json except json.JSONDecodeError as e: - logger.error(f"Error decoding JSON: {e}") + logger.error("Error decoding JSON: %s", e) return None else: return None @@ -759,9 +730,9 @@ def get_trigger_json(self) -> dict: self.uart.clear_buffer() data_object = None try: - data_object = json.loads(r.data.decode('utf-8')) + data_object = json.loads(r.data.decode("utf-8")) except json.JSONDecodeError as e: - logger.error(f"Error decoding JSON: {e}") + logger.error("Error decoding JSON: %s", e) return data_object except ValueError as v: logger.error("ValueError: %s", v) @@ -800,7 +771,7 @@ def get_trigger(self): "pulse_train_count": trigger_json["TriggerPulseTrainCount"], "mode": mode, "profile_index": trigger_json["ProfileIndex"], - "profile_increment": bool(trigger_json["ProfileIncrement"]) + "profile_increment": bool(trigger_json["ProfileIncrement"]), } return trigger_dict @@ -822,7 +793,9 @@ def start_trigger(self) -> bool: if not self.uart.is_connected(): raise ValueError("TX Device not connected") - r = self.uart.send_packet(id=None, packetType=OW_CONTROLLER, command=OW_CTRL_START_SWTRIG, addr=0, data=None) + r = self.uart.send_packet( + id=None, packetType=OW_CONTROLLER, command=OW_CTRL_START_SWTRIG, addr=0, data=None + ) self.uart.clear_buffer() # r.print_packet() if r.packet_type == OW_ERROR: @@ -861,13 +834,7 @@ def stop_trigger(self) -> bool: raise ValueError("TX Device not connected") # Send the STOP_SWTRIG command to the device - r = self.uart.send_packet( - id=None, - packetType=OW_CONTROLLER, - command=OW_CTRL_STOP_SWTRIG, - addr=0, - data=None - ) + r = self.uart.send_packet(id=None, packetType=OW_CONTROLLER, command=OW_CTRL_STOP_SWTRIG, addr=0, data=None) # Clear the UART buffer to prepare for further communication self.uart.clear_buffer() @@ -889,7 +856,7 @@ def stop_trigger(self) -> bool: logger.error("Unexpected error during process: %s", e) raise # Re-raise the exception for the caller to handle - def soft_reset(self, module:int=0) -> bool: + def soft_reset(self, module: int = 0) -> bool: """ Perform a soft reset on the TX device. @@ -918,7 +885,7 @@ def soft_reset(self, module:int=0) -> bool: logger.error("Unexpected error during process: %s", e) raise # Re-raise the exception for the caller to handle - def enter_dfu(self, module:int=0) -> bool: + def enter_dfu(self, module: int = 0) -> bool: """ Perform a soft reset to enter DFU mode on TX device. @@ -977,9 +944,9 @@ def async_mode(self, enable: bool | None = None) -> bool: if enable is not None: if enable: - payload = struct.pack(' bool: logger.error("ValueError: %s", v) raise - def get_tx_module_count(self) -> int: """ Retrieve the number of detected Transmit modules. @@ -1033,8 +999,7 @@ def get_tx_module_count(self) -> int: logger.error("Unexpected error during process: %s", e) raise # Re-raise the exception for the caller to handle - def enum_tx7332_devices(self, - num_devices: int | None = None) -> int: + def enum_tx7332_devices(self, num_devices: int | None = None) -> int: """ Enumerate TX7332 devices connected to the TX device. @@ -1067,7 +1032,9 @@ def enum_tx7332_devices(self, logger.error("Error enumerating TX devices.") if num_devices is not None and num_detected_devices != num_devices: raise ValueError(f"Expected {num_devices} devices, but detected {num_detected_devices} devices") - self.tx_registers = TxDeviceRegisters(num_transmitters=num_detected_devices, module_invert=self.module_invert) + self.tx_registers = TxDeviceRegisters( + num_transmitters=num_detected_devices, module_invert=self.module_invert + ) logger.info("TX Device Count: %d", num_detected_devices) return num_detected_devices except ValueError as v: @@ -1078,7 +1045,7 @@ def enum_tx7332_devices(self, logger.error("Unexpected error during process: %s", e) raise # Re-raise the exception for the caller to handle - def set_module_invert(self, module_invert: bool | List[bool]) -> None: + def set_module_invert(self, module_invert: bool | list[bool]) -> None: """ Set the module invert configuration for the TX device. @@ -1089,7 +1056,7 @@ def set_module_invert(self, module_invert: bool | List[bool]) -> None: if self.tx_registers is not None: self.tx_registers.module_invert = module_invert - def demo_tx7332(self, identifier:int) -> bool: + def demo_tx7332(self, identifier: int) -> bool: """ Sets all TX7332 chip registers with a test waveform. @@ -1122,7 +1089,7 @@ def demo_tx7332(self, identifier:int) -> bool: logger.error("Unexpected error during process: %s", e) raise # Re-raise the exception for the caller to handle - def write_register(self, identifier:int, address: int, value: int) -> bool: + def write_register(self, identifier: int, address: int, value: int) -> bool: """ Write a value to a register in the TX device. @@ -1151,19 +1118,13 @@ def write_register(self, identifier:int, address: int, value: int) -> bool: # Pack the address and value into the required format try: - data = struct.pack(' bool: logger.error("Error writing TX register value") return False - logger.debug(f"Successfully wrote value 0x{value:08X} to register 0x{address:04X}") + logger.debug("Successfully wrote value 0x%08X to register 0x%04X", value, address) return True except ValueError as v: @@ -1184,7 +1145,7 @@ def write_register(self, identifier:int, address: int, value: int) -> bool: logger.error("Unexpected error during process: %s", e) raise # Re-raise the exception for the caller to handle - def read_register(self, identifier:int, address: int) -> int: + def read_register(self, identifier: int, address: int) -> int: """ Read a register value from the TX device. @@ -1211,19 +1172,13 @@ def read_register(self, identifier:int, address: int) -> int: # Pack the address into the required format try: - data = struct.pack(' int: # Verify data length and unpack the register value if r.data_len == 4: try: - value = struct.unpack(' int: logger.error("Unexpected error during process: %s", e) raise # Re-raise the exception for the caller to handle - def write_block(self, identifier: int, start_address: int, reg_values: List[int]) -> bool: + def write_block(self, identifier: int, start_address: int, reg_values: list[int]) -> bool: """ Write a block of register values to the TX device. @@ -1289,7 +1244,7 @@ def write_block(self, identifier: int, start_address: int, reg_values: List[int] # Configure chunking for large blocks max_regs_per_block = 62 # Maximum registers per block due to payload size num_chunks = (len(reg_values) + max_regs_per_block - 1) // max_regs_per_block - logger.debug(f"Write Block: Total chunks = {num_chunks}") + logger.debug("Write Block: Total chunks = %s", num_chunks) # Write each chunk for i in range(num_chunks): @@ -1299,19 +1254,17 @@ def write_block(self, identifier: int, start_address: int, reg_values: List[int] # Pack the chunk into the required data format try: - data_format = ' Optional[List[int]]: + def read_block(self, identifier: int, start_address: int, count: int) -> list[int] | None: """ Read a block of consecutive register values from the TX device. @@ -1364,14 +1317,10 @@ def read_block(self, identifier: int, start_address: int, count: int) -> Optiona raise ValueError(f"count must be 1-62, got {count}") # Request payload: uint16_t start_addr, uint8_t count, uint8_t reserved - data = struct.pack(' Optiona expected_len = count * 4 if r.data_len != expected_len: - logger.error(f"Unexpected data length: {r.data_len}, expected {expected_len}") + logger.error("Unexpected data length: %s, expected %s", r.data_len, expected_len) return None - values = list(struct.unpack(f'<{count}I', r.data)) - logger.debug(f"read_block: {count} regs from 0x{start_address:04X} on tx {identifier}") + values = list(struct.unpack(f"<{count}I", r.data)) + logger.debug("read_block: %s regs from 0x%04X on tx %s", count, start_address, identifier) return values except ValueError as v: @@ -1425,18 +1374,14 @@ def write_register_verify(self, address: int, value: int) -> bool: # Pack the address and value into the required format try: - data = struct.pack(' bool: logger.error("Error verifying writing TX register value") return False - logger.debug(f"Successfully wrote value 0x{value:08X} to register 0x{address:04X}") + logger.debug("Successfully wrote value 0x%08X to register 0x%04X", value, address) return True except ValueError as v: @@ -1458,7 +1403,7 @@ def write_register_verify(self, address: int, value: int) -> bool: logger.error("Unexpected error during process: %s", e) raise # Re-raise the exception for the caller to handle - def write_block_verify(self, start_address: int, reg_values: List[int]) -> bool: + def write_block_verify(self, start_address: int, reg_values: list[int]) -> bool: """ Write a block of register values to the TX device with verification. @@ -1493,7 +1438,7 @@ def write_block_verify(self, start_address: int, reg_values: List[int]) -> bool: # Configure chunking for large blocks max_regs_per_block = 62 # Maximum registers per block due to payload size num_chunks = (len(reg_values) + max_regs_per_block - 1) // max_regs_per_block - logger.debug(f"Write Block: Total chunks = {num_chunks}") + logger.debug("Write Block: Total chunks = %s", num_chunks) # Write each chunk for i in range(num_chunks): @@ -1503,19 +1448,17 @@ def write_block_verify(self, start_address: int, reg_values: List[int]) -> bool: # Pack the chunk into the required data format try: - data_format = ' bool: # Check for errors in the response if r.packet_type == OW_ERROR: - logger.error(f"Error verifying writing TX block at chunk {i}") + logger.error("Error verifying writing TX block at chunk %s", i) return False logger.debug("Block write successful") @@ -1537,14 +1480,16 @@ def write_block_verify(self, start_address: int, reg_values: List[int]) -> bool: logger.error("Unexpected error during process: %s", e) raise # Re-raise the exception for the caller to handle - def set_solution(self, - pulse: Dict, - delays: np.ndarray, - apodizations: np.ndarray, - sequence: Dict, - trigger_mode: TriggerModeOpts = "sequence", - profile_index: int = 1, - profile_increment: bool = True): + def set_solution( + self, + pulse: dict, + delays: np.ndarray, + apodizations: np.ndarray, + sequence: dict, + trigger_mode: TriggerModeOpts = "sequence", + profile_index: int = 1, + profile_increment: bool = True, + ): """ Set the solution parameters on the TX device. @@ -1568,6 +1513,7 @@ def set_solution(self, n_required_devices = int(n_elements / NUM_CHANNELS) n_detected_tx = self.enum_tx7332_devices(num_devices=n_required_devices) n_modules = n_detected_tx / TRANSMITTERS_PER_MODULE + logger.debug("Detected %s TX devices (%s modules)", n_detected_tx, n_modules) if n_required_devices != n_detected_tx: errmsg = f"Number of detected TX devices ({n_detected_tx}) does not match required ({n_required_devices})" logger.exception(errmsg) @@ -1578,18 +1524,16 @@ def set_solution(self, if n > 1: raise NotImplementedError("Multiple foci not supported yet") for profile in range(n): - duty_cycle=DEFAULT_PATTERN_DUTY_CYCLE * max(apodizations[profile,:]) * pulse["amplitude"] + duty_cycle = DEFAULT_PATTERN_DUTY_CYCLE * max(apodizations[profile, :]) * pulse["amplitude"] pulse_profile = Tx7332PulseProfile( - profile=profile+1, + profile=profile + 1, frequency=pulse["frequency"], cycles=int(pulse["duration"] * pulse["frequency"]), - duty_cycle=duty_cycle + duty_cycle=duty_cycle, ) self.tx_registers.add_pulse_profile(pulse_profile) delay_profile = Tx7332DelayProfile( - profile=profile+1, - delays=delays[profile,:], - apodizations=apodizations[profile, :] + profile=profile + 1, delays=delays[profile, :], apodizations=apodizations[profile, :] ) self.tx_registers.add_delay_profile(delay_profile) self.set_trigger( @@ -1599,15 +1543,22 @@ def set_solution(self, pulse_train_count=sequence["pulse_train_count"], trigger_mode=trigger_mode, profile_index=profile_index, - profile_increment=profile_increment + profile_increment=profile_increment, ) self.apply_all_registers() - # Buffer the pulse and delay profiles in the microcontroller(s), so that they can be used to switch profiles on trigger detection - delay_control_registers = {profile:self.tx_registers.get_delay_control_registers(profile) for profile in self.tx_registers.configured_delay_profiles()} - pulse_control_registers = {profile:self.tx_registers.get_pulse_control_registers(profile) for profile in self.tx_registers.configured_pulse_profiles()} - - + # Buffer the pulse and delay profiles in the microcontroller(s) so that they + # can be used to switch profiles on trigger detection. These dicts are computed + # here as a placeholder for future firmware commands; they are not yet sent. + _delay_ctrl = { + profile: self.tx_registers.get_delay_control_registers(profile) + for profile in self.tx_registers.configured_delay_profiles() + } + _pulse_ctrl = { + profile: self.tx_registers.get_pulse_control_registers(profile) + for profile in self.tx_registers.configured_pulse_profiles() + } + logger.debug("Buffered %s delay and %s pulse profiles", len(_delay_ctrl), len(_pulse_ctrl)) def apply_all_registers(self): """ @@ -1626,7 +1577,7 @@ def apply_all_registers(self): for txi, txregs in enumerate(registers): for addr, reg_values in txregs.items(): if not self.write_block(identifier=txi, start_address=addr, reg_values=reg_values): - logger.error(f"Error applying TX CHIP ID: {txi} registers") + logger.error("Error applying TX CHIP ID: %s registers", txi) return False return True @@ -1658,11 +1609,13 @@ def write_ti_config_to_tx_device(self, file_path: str, txchip_id: int) -> bool: # Write each register to the TX device for group, addr, value in parsed_registers: - logger.debug(f"Writing to {group:<20} | Address: 0x{addr:02X} | Value: 0x{value:08X}") + logger.debug("Writing to %-20s | Address: 0x%02X | Value: 0x%08X", group, addr, value) if not self.write_register(identifier=txchip_id, address=addr, value=value): logger.error( - f"Failed to write to TX CHIP ID: {txchip_id} | " - f"Register: 0x{addr:02X} | Value: 0x{value:08X}" + "Failed to write to TX CHIP ID: %s | Register: 0x%02X | Value: 0x%08X", + txchip_id, + addr, + value, ) return False @@ -1670,13 +1623,13 @@ def write_ti_config_to_tx_device(self, file_path: str, txchip_id: int) -> bool: return True except FileNotFoundError as e: - logger.error(f"TI configuration file not found: {file_path}. Error: {e}") + logger.error("TI configuration file not found: %s. Error: %s", file_path, e) raise except ValueError as e: - logger.error(f"Invalid input or device state: {e}") + logger.error("Invalid input or device state: %s", e) raise except Exception as e: - logger.error(f"Unexpected error while writing TI config to TX Device: {e}") + logger.error("Unexpected error while writing TI config to TX Device: %s", e) raise # ------------------------------------------------------------------ @@ -1698,10 +1651,7 @@ def get_module_count(self) -> int: logger.error("TX Device not connected") return 0 - r = self.uart.send_packet( - id=None, packetType=OW_CMD, - command=OW_CTRL_GET_MODULE_COUNT, addr=0 - ) + r = self.uart.send_packet(id=None, packetType=OW_CMD, command=OW_CTRL_GET_MODULE_COUNT, addr=0) self.uart.clear_buffer() if r.packet_type != OW_ERROR and r.data_len >= 1: @@ -1710,9 +1660,7 @@ def get_module_count(self) -> int: return count # Fallback: TX7332 chip count / 2 - logger.info( - "OW_CTRL_GET_MODULE_COUNT not supported; falling back to TX7332 count" - ) + logger.info("OW_CTRL_GET_MODULE_COUNT not supported; falling back to TX7332 count") module_count = self.get_tx_module_count() return module_count @@ -1720,13 +1668,18 @@ def get_module_count(self) -> int: logger.error("Error getting module count: %s", e) return 0 - def update_firmware(self, module: int, package_file: str, - vid: int = 0x0483, pid: int = 0xDF11, - libusb_dll: str | None = None, - i2c_addr: int = 0x72, - dfu_wait_s: float = 5.0, - device_type: str = "transmitter", - progress_callback=None) -> bool: + def update_firmware( + self, + module: int, + package_file: str, + vid: int = 0x0483, + pid: int = 0xDF11, + libusb_dll: str | None = None, + i2c_addr: int = 0x72, + dfu_wait_s: float = 5.0, + device_type: str = "transmitter", + progress_callback=None, + ) -> bool: """Update firmware on a single module. Module 0 (USB master): host → USB DFU. @@ -1777,733 +1730,6 @@ def print(self) -> None: Raises: None """ - print("TX Device Information") # noqa: T201 - print(" UART Port:") # noqa: T201 + print("TX Device Information") + print(" UART Port:") self.uart.print() - -def get_delay_location(channel:int, profile:int=1): - """ - Gets the address and least significant bit of a delay - - :param channel: Channel number - :param profile: Delay profile number - :returns: Register address and least significant bit of the delay location - """ - if channel not in DELAY_CHANNEL_MAP: - raise ValueError(f"Invalid channel {channel}.") - channel_map = DELAY_CHANNEL_MAP[channel] - if profile not in VALID_DELAY_PROFILES: - raise ValueError(f"Invalid Profile {profile}") - address = ADDRESSES_DELAY_DATA[0] + (profile-1) * DELAY_PROFILE_OFFSET + channel_map['row'] - lsb = channel_map['lsb'] - return address, lsb - -def set_register_value(reg_value:int, value:int, lsb:int=0, width: int | None=None): - """ - Sets the value of a parameter in a register integer - - :param reg_value: Register value - :param value: New value of the parameter - :param lsb: Least significant bit of the parameter - :param width: Width of the parameter (bits) - :returns: New register value - """ - if width is None: - width = REGISTER_WIDTH - lsb - mask = (1 << width) - 1 - if value < 0 or value > mask: - raise ValueError(f"Value {value} does not fit in {width} bits") - return (reg_value & ~(mask << lsb)) | ((int(value) & mask) << lsb) - -def get_register_value(reg_value:int, lsb:int=0, width: int | None=None): - """ - Extracts the value of a parameter from a register integer - - :param reg_value: Register value - :param lsb: Least significant bit of the parameter - :param width: Width of the parameter (bits) - :returns: Value of the parameter - """ - if width is None: - width = REGISTER_WIDTH - lsb - mask = (1 << width) - 1 - return (reg_value >> lsb) & mask - -def calc_pulse_pattern(frequency:float, duty_cycle:float=DEFAULT_PATTERN_DUTY_CYCLE, bf_clk:float=DEFAULT_CLK_FREQ): - """ - Calculates the pattern for a given frequency and duty cycle - - The pattern is calculated to represent a single cycle of a pulse with the specified frequency and duty cycle. - If the pattern requires more than 16 periods, the clock divider is increased to reduce the period length. - - :param frequency: Frequency of the pattern in Hz - :param duty_cycle: Duty cycle of the pattern - :param bf_clk: Clock frequency of the BF system in Hz - :returns: Tuple of lists of levels and lengths, and the clock divider setting - """ - clk_div_n = 0 - while clk_div_n < 6: - clk_n = bf_clk / (2**clk_div_n) - period_samples = int(clk_n / frequency) - first_half_period_samples = int(period_samples / 2) - second_half_period_samples = period_samples - first_half_period_samples - first_on_samples = int(first_half_period_samples * duty_cycle) - if first_on_samples < 2: - logging.warning("Duty cycle too short. Setting to minimum of 2 samples") - first_on_samples = 2 - first_off_samples = first_half_period_samples - first_on_samples - second_on_samples = max(2, int(second_half_period_samples * duty_cycle)) - if second_on_samples < 2: - logging.warning("Duty cycle too short. Setting to minimum of 2 samples") - second_on_samples = 2 - second_off_samples = second_half_period_samples - second_on_samples - if first_off_samples > 0 and first_off_samples < 2: - logging.warn - first_off_samples = 0 - first_on_samples = first_half_period_samples - if second_off_samples > 0 and first_off_samples < 2: - second_off_samples = 0 - second_on_samples = second_half_period_samples - levels = [1, 0, -1, 0] - per_lengths = [] - per_levels = [] - for i, samples in enumerate([first_on_samples, first_off_samples, second_on_samples, second_off_samples]): - while samples > 0: - if samples > MAX_PATTERN_PERIOD_LENGTH+2: - if samples == MAX_PATTERN_PERIOD_LENGTH+3: - per_lengths.append(MAX_PATTERN_PERIOD_LENGTH-1) - samples -= (MAX_PATTERN_PERIOD_LENGTH+1) - else: - per_lengths.append(MAX_PATTERN_PERIOD_LENGTH) - samples -= (MAX_PATTERN_PERIOD_LENGTH+2) - per_levels.append(levels[i]) - else: - per_lengths.append(samples-2) - per_levels.append(levels[i]) - samples = 0 - if len(per_levels) <= MAX_PATTERN_PERIODS: - t = (np.arange(np.sum(np.array(per_lengths)+2))*(1/clk_n)).tolist() - y = np.concatenate([[yi]*(ni+2) for yi,ni in zip(per_levels, per_lengths)]).tolist() - pattern = {'levels': per_levels, - 'lengths': per_lengths, - 'clk_div_n': clk_div_n, - 't': t, - 'y': y} - return pattern - else: - clk_div_n += 1 - raise ValueError(f"Pattern requires too many periods ({len(per_levels)} > {MAX_PATTERN_PERIODS})") - -def get_pattern_location(period:int, profile:int=1): - """ - Gets the address and least significant bit of a pattern period - - :param period: Pattern period number - :param profile: Pattern profile number - :returns: Register address and least significant bit of the pattern period location - """ - if period not in PATTERN_MAP: - raise ValueError(f"Invalid period {period}.") - if profile not in VALID_PATTERN_PROFILES: - raise ValueError(f"Invalid profile {profile}.") - address = ADDRESSES_PATTERN_DATA[0] + (profile-1) * PATTERN_PROFILE_OFFSET + PATTERN_MAP[period]['row'] - lsb_lvl = PATTERN_MAP[period]['lsb_lvl'] - lsb_period = PATTERN_MAP[period]['lsb_period'] - return address, lsb_lvl, lsb_period - -def print_regs(d): - for addr, val in sorted(d.items()): - if isinstance(val, list): - for i, v in enumerate(val): - print(f'0x{addr:X}[+{i:d}]:x{v:08X}') # noqa: T201 - else: - print(f'0x{addr:X}:x{val:08X}') # noqa: T201 - -def pack_registers(regs, pack_single:bool=False): - """ - Packs registers into contiguous blocks - - :param regs: Dictionary of registers - :param pack_single: Pack single registers into arrays. Default True. - :returns: Dictionary of packed registers. - """ - addresses = sorted(regs.keys()) - if len(addresses) == 0: - return {} - last_addr = -255 - burst_addr = -255 - packed = {} - for addr in addresses: - if addr == last_addr+1 and burst_addr in packed: - packed[burst_addr].append(regs[addr]) - else: - packed[addr] = [regs[addr]] - burst_addr = addr - last_addr = addr - if not pack_single: - for addr, val in packed.items(): - if len(val) == 1: - packed[addr] = val[0] - return packed - -def swap_byte_order(regs): - """ - Swaps the byte order of the registers - - :param regs: Dictionary of registers - :returns: Dictionary of registers with swapped byte order - """ - swapped = {} - for addr, val in regs.items(): - if isinstance(val, list): - swapped[addr] = [int.from_bytes(v.to_bytes(REGISTER_BYTES, 'big'), 'little') for v in val] - else: - swapped[addr] = int.from_bytes(val.to_bytes(REGISTER_BYTES, 'big'), 'little') - return swapped - -@dataclass -class Tx7332DelayProfile: - profile: Annotated[int, OpenLIFUFieldData("Profile Index (1-16)", "Index of the delay profile (1-16)")] - """Index of the delay profile (1-16). The Tx7332 support 16 unique delay profiles.""" - - delays: Annotated[List[float], OpenLIFUFieldData("Delay values", "Delay values for transducer elements")] - """Delay values for transducer elements""" - - apodizations: Annotated[List[int] | None, OpenLIFUFieldData("Apodizations", "Apodization values for transducer elements")] = None - """Apodization values for transducer elements""" - - units: Annotated[str, OpenLIFUFieldData("Units", "Time units used for delay values")] = 's' - """Time units used for delay values""" - - def __post_init__(self): - self.num_elements = len(self.delays) - if self.apodizations is None: - self.apodizations = [1]*self.num_elements - if len(self.apodizations) != self.num_elements: - raise ValueError(f"Apodizations list must have {self.num_elements} elements") - if self.profile not in VALID_DELAY_PROFILES: - raise ValueError(f"Invalid Profile {self.profile}") - -@dataclass -class Tx7332PulseProfile: - profile: Annotated[int, OpenLIFUFieldData("Profile index (1-32)", "Index of the pulse profile (1-32)")] - """Index of the pulse profile (1-32). The Tx7332 supports 32 unique pulse profiles.""" - - frequency: Annotated[float, OpenLIFUFieldData("Frequency (Hz)", "Center frequency of the pulse (Hz)")] - """Center frequency of the pulse (Hz)""" - - cycles: Annotated[int, OpenLIFUFieldData("Number of cycles", "Number of cycles in the pulse")] - """Number of cycles in the pulse""" - - duty_cycle: Annotated[float, OpenLIFUFieldData("Duty cycle (0-1)", "Pulse duty cycle for the generated square wave (0-1)")] = DEFAULT_PATTERN_DUTY_CYCLE - """Pulse duty cycle for the generated square wave (0-1). By default 0.66 is used to approximate a sinusoidal wave.""" - - tail_count: Annotated[int, OpenLIFUFieldData("Tail count (cycles)", "Clock cycles to actively drive the pulser to ground after the pulse ends")] = DEFAULT_TAIL_COUNT - """Clock cycles to actively drive the pulser to ground after the pulse ends. Default 29""" - - invert: Annotated[bool, OpenLIFUFieldData("Invert polarity?", "Flag indicating whether to invert the pulse amplitude")] = False - """Invert the pulse amplitude. Default False""" - - def __post_init__(self): - if self.profile not in VALID_PATTERN_PROFILES: - raise ValueError(f"Invalid profile {self.profile}.") - -@dataclass -class Tx7332Registers: - bf_clk: Annotated[float, OpenLIFUFieldData("Clock Frequency (Hz)", "The beamformer clock frequency in Hz.")] = DEFAULT_CLK_FREQ - """The beamformer clock frequency in Hz. This much match the hardware clock frequency in order for calculated register values to produce the correct pulse and delay timting. Default is 64 MHz.""" - - _delay_profiles_list: Annotated[List[Tx7332DelayProfile], OpenLIFUFieldData("Delay profiles list", "Internal list of available delay profiles")] = field(default_factory=list) - """Internal list of available delay profiles""" - - _pulse_profiles_list: Annotated[List[Tx7332PulseProfile], OpenLIFUFieldData("Pulse profiles list", "Internal list of available pulse profiles")] = field(default_factory=list) - """Internal list of available pulse profiles""" - - active_delay_profile: Annotated[int | None, OpenLIFUFieldData("Active delay profile", "Index of the currently active delay profile")] = None - """Index of the currently active delay profile""" - - active_pulse_profile: Annotated[int | None, OpenLIFUFieldData("Active pulse profile", "Index of the currently active pulse profile")] = None - """Index of the currently active pulse profile""" - - def __post_init__(self): - delay_profile_indices = self.configured_delay_profiles() - if len(delay_profile_indices) != len(set(delay_profile_indices)): - raise ValueError("Duplicate delay profiles found") - if self.active_delay_profile is not None and self.active_delay_profile not in delay_profile_indices: - raise ValueError(f"Delay profile {self.active_delay_profile} not found") - pulse_profile_indices = self.configured_pulse_profiles() - if len(pulse_profile_indices) != len(set(pulse_profile_indices)): - raise ValueError("Duplicate pulse profiles found") - if self.active_pulse_profile is not None and self.active_pulse_profile not in pulse_profile_indices: - raise ValueError(f"Pulse profile {self.active_pulse_profile} not found") - - def add_delay_profile(self, p: Tx7332DelayProfile, activate: bool | None=None): - if p.num_elements != NUM_CHANNELS: - raise ValueError(f"Delay profile must have {NUM_CHANNELS} elements") - profile_indices = self.configured_delay_profiles() - if p.profile in profile_indices: - i = profile_indices.index(p.profile) - self._delay_profiles_list[i] = p - else: - self._delay_profiles_list.append(p) - if activate is None: - activate = self.active_delay_profile is None - if activate: - self.active_delay_profile = p.profile - - def add_pulse_profile(self, p: Tx7332PulseProfile, activate: bool | None=None): - profile_indices = self.configured_pulse_profiles() - if p.profile in profile_indices: - i = profile_indices.index(p.profile) - self._pulse_profiles_list[i] = p - else: - self._pulse_profiles_list.append(p) - if activate is None: - activate = self.active_pulse_profile is None - if activate: - self.active_pulse_profile = p.profile - - def remove_delay_profile(self, profile:int): - profile_indices = self.configured_delay_profiles() - if profile not in profile_indices: - raise ValueError(f"Delay profile {profile} not found") - index = profile_indices.index(profile) - del self._delay_profiles_list[index] - if self.active_delay_profile == index: - self.active_delay_profile = None - - def remove_pulse_profile(self, profile:int): - profiles = self.configured_pulse_profiles() - if profile not in profiles: - raise ValueError(f"Pulse profile {profile} not found") - index = profiles.index(profile) - del self._pulse_profiles_list[index] - if self.active_pulse_profile == index: - self.active_pulse_profile = None - - def get_delay_profile(self, profile: int | None=None) -> Tx7332DelayProfile: - if profile is None: - profile = self.active_delay_profile - profiles = self.configured_delay_profiles() - if profile not in profiles: - raise ValueError(f"Delay profile {profile} not found") - index = profiles.index(profile) - return self._delay_profiles_list[index] - - def configured_delay_profiles(self) -> List[int]: - return [p.profile for p in self._delay_profiles_list] - - def get_pulse_profile(self, profile: int | None=None) -> Tx7332PulseProfile: - if profile is None: - profile = self.active_pulse_profile - profiles = self.configured_pulse_profiles() - if profile not in profiles: - raise ValueError(f"Pulse profile {profile} not found") - index = profiles.index(profile) - return self._pulse_profiles_list[index] - - def configured_pulse_profiles(self) -> List[int]: - return [p.profile for p in self._pulse_profiles_list] - - def activate_delay_profile(self, profile:int): - if profile not in self.configured_delay_profiles(): - raise ValueError(f"Delay profile {profile} not configured") - self.active_delay_profile = profile - - def activate_pulse_profile(self, profile:int): - if profile not in self.configured_pulse_profiles(): - raise ValueError(f"Pulse profile {profile} not configured") - self.active_pulse_profile = profile - - def get_delay_control_registers(self, profile: int | None=None) -> Dict[int,int]: - if profile is None: - profile = self.active_delay_profile - delay_profile = self.get_delay_profile(profile) - apod_register = 0 - for i, apod in enumerate(delay_profile.apodizations): - apod_register = set_register_value(apod_register, 1-apod, lsb=APODIZATION_CHANNEL_ORDER_REVERSED.index(i+1), width=1) - delay_sel_register = 0 - delay_sel_register = set_register_value(delay_sel_register, delay_profile.profile-1, lsb=12, width=4) - delay_sel_register = set_register_value(delay_sel_register, delay_profile.profile-1, lsb=28, width=4) - return {ADDRESS_DELAY_SEL: delay_sel_register, - ADDRESS_APODIZATION: apod_register} - - def get_pulse_control_registers(self, profile: int | None=None, pulse_invert: bool = False) -> Dict[int,int]: - if profile is None: - profile = self.active_pulse_profile - pulse_profile = self.get_pulse_profile(profile) - if pulse_profile.profile not in VALID_PATTERN_PROFILES: - raise ValueError(f"Invalid profile {pulse_profile.profile}.") - pattern = calc_pulse_pattern(pulse_profile.frequency, pulse_profile.duty_cycle, bf_clk=self.bf_clk) - clk_div_n = pattern['clk_div_n'] - clk_div = 2**clk_div_n - clk_n = self.bf_clk / clk_div - cycles = int(pulse_profile.cycles) - if cycles > (MAX_REPEAT+1): - # Use elastic repeat - pulse_duration_samples = self.bf_clk * ((cycles / pulse_profile.frequency) + ELASTIC_MODE_PULSE_LENGTH_ADJUST) - repeat = 0 - elastic_repeat = int(pulse_duration_samples / 16) - period_samples = int(clk_n / pulse_profile.frequency) - cycles = 16*elastic_repeat / period_samples - y = pattern['y']*int(cycles+1) - y = y[:(16*elastic_repeat)] - y = y + ([0]*pulse_profile.tail_count) - t = np.arange(len(y))*(1/clk_n) - elastic_mode = 1 - if elastic_repeat > MAX_ELASTIC_REPEAT: - raise ValueError("Pattern duration too long for elastic repeat") - else: - repeat = cycles-1 - elastic_repeat = 0 - elastic_mode = 0 - y = pattern['y']*(repeat+1) - y = np.array(y + [0]*pulse_profile.tail_count) - reg_mode = 0x02000003 - reg_mode = set_register_value(reg_mode, clk_div_n, lsb=3, width=3) - reg_mode = set_register_value(reg_mode, int(pulse_profile.invert^pulse_invert), lsb=6, width=1) - reg_repeat = 0 - reg_repeat = set_register_value(reg_repeat, repeat, lsb=1, width=5) - reg_repeat = set_register_value(reg_repeat, pulse_profile.tail_count, lsb=6, width=5) - reg_repeat = set_register_value(reg_repeat, elastic_mode, lsb=11, width=1) - reg_repeat = set_register_value(reg_repeat, elastic_repeat, lsb=12, width=16) - reg_pat_sel = 0 - reg_pat_sel = set_register_value(reg_pat_sel, pulse_profile.profile-1, lsb=0, width=6) - registers = {ADDRESS_PATTERN_MODE: reg_mode, - ADDRESS_PATTERN_REPEAT: reg_repeat, - ADDRESS_PATTERN_SEL_G1: reg_pat_sel, - ADDRESS_PATTERN_SEL_G2: reg_pat_sel} - return registers - - def get_delay_data_registers(self, profile: int | None=None, pack: bool=False, pack_single: bool=False) -> Dict[int,int]: - if profile is None: - profile = self.active_delay_profile - delay_profile = self.get_delay_profile(profile) - data_registers = {} - for channel in range(1, NUM_CHANNELS+1): - address, lsb = get_delay_location(channel, delay_profile.profile) - if address not in data_registers: - data_registers[address] = 0 - delay_value = int(delay_profile.delays[channel-1] * getunitconversion(delay_profile.units, 's') * self.bf_clk) - data_registers[address] = set_register_value(data_registers[address], delay_value, lsb=lsb, width=DELAY_WIDTH) - if pack: - data_registers = pack_registers(data_registers, pack_single=pack_single) - return data_registers - - def get_pulse_data_registers(self, profile: int | None=None, pack: bool=False, pack_single: bool=False) -> Dict[int,int]: - if profile is None: - profile = self.active_pulse_profile - profile_index = self.get_pulse_profile(profile) - data_registers = {} - pattern = calc_pulse_pattern(profile_index.frequency, profile_index.duty_cycle, bf_clk=self.bf_clk) - levels = pattern['levels'] - lengths = pattern['lengths'] - nperiods = len(levels) - level_lut = {-1: 0b01, 0: 0b11, 1: 0b10} # Map levels to register values 0b11 drive to ground 0b00 high impedance - for i, (level, length) in enumerate(zip(levels, lengths)): - address, lsb_lvl, lsb_length = get_pattern_location(i+1, profile_index.profile) - if address not in data_registers: - data_registers[address] = 0 - data_registers[address] = set_register_value(data_registers[address], level_lut[level], lsb=lsb_lvl, width=PATTERN_LEVEL_WIDTH) - data_registers[address] = set_register_value(data_registers[address], length, lsb=lsb_length, width=PATTERN_LENGTH_WIDTH) - if nperiods< MAX_PATTERN_PERIODS: - address, lsb_lvl, lsb_length = get_pattern_location(nperiods+1, profile_index.profile) - if address not in data_registers: - data_registers[address] = 0 - data_registers[address] = set_register_value(data_registers[address], 0b111, lsb=lsb_lvl, width=PATTERN_LEVEL_WIDTH) - data_registers[address] = set_register_value(data_registers[address], 0, lsb=lsb_length, width=PATTERN_LENGTH_WIDTH) - if pack: - data_registers = pack_registers(data_registers, pack_single=pack_single) - return data_registers - - def get_registers(self, profiles: ProfileOpts = "configured", pack: bool=False, pack_single: bool=False, pulse_invert: bool=False) -> Dict[int,int]: - if len(self._delay_profiles_list) == 0: - raise ValueError("No delay profiles have been configured") - if len(self._pulse_profiles_list) == 0: - raise ValueError("No pulse profiles have been configured") - if self.active_delay_profile is None: - raise ValueError("No delay profile activated") - if self.active_pulse_profile is None: - raise ValueError("No pulse profile activated") - registers = {addr:0x0 for addr in ADDRESSES_GLOBAL} - registers.update(self.get_delay_control_registers()) - registers.update(self.get_pulse_control_registers(pulse_invert=pulse_invert)) - if profiles == "active": - delay_data = self.get_delay_data_registers() - pulse_data = self.get_pulse_data_registers() - else: - if profiles == "all": - delay_data = {addr:0x0 for addr in ADDRESSES_DELAY_DATA} - pulse_data = {addr:0x0 for addr in ADDRESSES_PATTERN_DATA} - else: - delay_data = {} - pulse_data = {} - for delay_profile in self._delay_profiles_list: - delay_data.update(self.get_delay_data_registers(profile=delay_profile.profile)) - for profile_index in self._pulse_profiles_list: - pulse_data.update(self.get_pulse_data_registers(profile=profile_index.profile)) - registers.update(delay_data) - registers.update(pulse_data) - if pack: - registers = pack_registers(registers, pack_single=pack_single) - return registers - -@dataclass -class TxDeviceRegisters: - bf_clk: Annotated[int, OpenLIFUFieldData("Clock Frequency (Hz)", "The beamformer clock frequency in Hz.")] = DEFAULT_CLK_FREQ - """The beamformer clock frequency in Hz. This much match the hardware clock frequency in order for calculated register values to produce the correct pulse and delay timting. Default is 64 MHz.""" - - _delay_profiles_list: Annotated[List[Tx7332DelayProfile], OpenLIFUFieldData("Delay profiles list", "Internal list of available delay profiles")] = field(default_factory=list) - """Internal list of available delay profiles""" - - _profiles_list: Annotated[List[Tx7332PulseProfile], OpenLIFUFieldData("Pulse profiles list", "Internal list of available pulse profiles")] = field(default_factory=list) - """Internal list of available pulse profiles""" - - active_delay_profile: Annotated[int | None, OpenLIFUFieldData("Active delay profile", "Index of the currently active delay profile")] = None - """Index of the currently active delay profile""" - - active_profile: Annotated[int | None, OpenLIFUFieldData("Active pulse profile", "Index of the currently active pulse profile")] = None - """Index of the currently active pulse profile""" - - num_transmitters: Annotated[int, OpenLIFUFieldData("Number of transmitters", "The number of transmitters available on the device")] = DEFAULT_NUM_TRANSMITTERS - """The number of transmitters available on the device""" - - module_invert: Annotated[List[bool] | bool, OpenLIFUFieldData("Module Invert", "List of flags indicating whether to invert the pulse amplitude for each module or a single flag for all modules")] = False - """List of flags indicating whether to invert the pulse amplitude for each module or a single flag for all modules""" - - def __post_init__(self): - self.transmitters = tuple([Tx7332Registers(bf_clk=self.bf_clk) for _ in range(self.num_transmitters)]) - - def add_pulse_profile(self, pulse_profile: Tx7332PulseProfile, activate: bool | None=None): - """ - Add a pulse profile - - :param p: Pulse profile - :param activate: Activate the pulse profile - """ - profiles = self.configured_pulse_profiles() - if pulse_profile.profile in profiles: - i = profiles.index(pulse_profile.profile) - self._profiles_list[i] = pulse_profile - else: - self._profiles_list.append(pulse_profile) - if activate is None: - activate = self.active_profile is None - if activate: - self.active_profile = pulse_profile.profile - for tx in self.transmitters: - tx.add_pulse_profile(pulse_profile, activate = activate) - - def add_delay_profile(self, delay_profile: Tx7332DelayProfile, activate: bool | None=None): - """ - Add a delay profile - - :param p: Delay profile - :param activate: Activate the delay profile - """ - if delay_profile.num_elements != NUM_CHANNELS*self.num_transmitters: - raise ValueError(f"Delay profile must have {NUM_CHANNELS*self.num_transmitters} elements") - profiles = self.configured_delay_profiles() - if delay_profile.profile in profiles: - i = profiles.index(delay_profile.profile) - self._delay_profiles_list[i] = delay_profile - else: - self._delay_profiles_list.append(delay_profile) - if activate is None: - activate = self.active_delay_profile is None - if activate: - self.active_delay_profile = delay_profile.profile - for i, tx in enumerate(self.transmitters): - module = i // 2 - chip = (i+1) % 2 - start_channel = module*NUM_CHANNELS*2 + chip*NUM_CHANNELS - profiles = np.arange(start_channel, start_channel+NUM_CHANNELS, dtype=int) - tx_delays = np.array(delay_profile.delays)[profiles].tolist() - tx_apodizations = np.array(delay_profile.apodizations)[profiles].tolist() - txp = Tx7332DelayProfile(delay_profile.profile, tx_delays, tx_apodizations, delay_profile.units) - tx.add_delay_profile(txp, activate = activate) - - def remove_delay_profile(self, profile:int): - """ - Remove a delay profile - - :param profile: Delay profile number - """ - profiles = self.configured_delay_profiles() - if profile not in profiles: - raise ValueError(f"Delay profile {profile} not found") - i = profiles.index(profile) - del self._delay_profiles_list[i] - if self.active_delay_profile == profile: - self.active_delay_profile = None - for tx in self.transmitters: - tx.remove_delay_profile(profile) - - def remove_pulse_profile(self, profile:int): - """ - Remove a pulse profile - - :param profile: Pulse profile number - """ - profiles = self.configured_pulse_profiles() - if profile not in profiles: - raise ValueError(f"Pulse profile {profile} not found") - i = profiles.index(profile) - del self._profiles_list[i] - if self.active_profile == profile: - self.active_profile = None - for tx in self.transmitters: - tx.remove_pulse_profile(profile) - - def get_delay_profile(self, profile:int | None=None) -> Tx7332DelayProfile: - """ - Retrieve a delay profile - - :param profile: Delay profile number - :return: Delay profile - """ - if profile is None: - profile = self.active_delay_profile - profiles = self.configured_delay_profiles() - if profile not in profiles: - raise ValueError(f"Delay profile {profile} not found") - i = profiles.index(profile) - return self._delay_profiles_list[i] - - def configured_delay_profiles(self) -> List[int]: - """ - Get the configured delay profiles - - :return: List of delay profiles - """ - return [p.profile for p in self._delay_profiles_list] - - def get_pulse_profile(self, profile:int | None=None) -> Tx7332PulseProfile: - """ - Retrieve a pulse profile - - :param profile: Pulse profile number - :return: Pulse profile - """ - if profile is None: - profile = self.active_profile - profiles = self.configured_pulse_profiles() - if profile not in profiles: - raise ValueError(f"Pulse profile {profile} not found") - i = profiles.index(profile) - return self._profiles_list[i] - - def configured_pulse_profiles(self) -> List[int]: - """ - Get the configured pulse profiles - - :return: List of pulse profiles - """ - return [p.profile for p in self._profiles_list] - - def activate_delay_profile(self, profile:int=1): - """ - Activates a delay profile - - :param profile: Delay profile number - """ - for tx in self.transmitters: - tx.activate_delay_profile(profile) - self.active_delay_profile = profile - - def activate_pulse_profile(self, profile:int=1): - """ - Activates a pulse profile - - :param profile: Pulse profile number - """ - for tx in self.transmitters: - tx.activate_pulse_profile(profile) - self.active_profile = profile - - def recompute_delay_profiles(self): - """ - Recompute the delay profiles - """ - for tx in self.transmitters: - profiles = tx.configured_delay_profiles() - for profile in profiles: - tx.remove_delay_profile(profile) - for dp in self._delay_profiles_list: - self.add_delay_profile(dp, activate = dp.profile == self.active_delay_profile) - - def recompute_pulse_profiles(self): - """ - Recompute the pulse profiles - """ - for tx in self.transmitters: - profiles = tx.configured_pulse_profiles() - for profile in profiles: - tx.remove_pulse_profile(profile) - for pp in self._profiles_list: - tx.add_pulse_profile(pp, activate = pp.profile == self.active_profile) - - def get_registers(self, profiles: ProfileOpts = "configured", recompute: bool = False, pack: bool=False, pack_single:bool=False) -> List[Dict[int,int]]: - """ - Get the registers for all transmitters - - :param profiles: Profile options - :param recompute: Recompute the registers - :return: List of registers for each transmitter - """ - if recompute: - self.recompute_delay_profiles() - self.recompute_pulse_profiles() - if isinstance(self.module_invert, bool): - tx_invert = [self.module_invert] * self.num_transmitters - else: - tx_invert = list(np.array(self.module_invert*TRANSMITTERS_PER_MODULE, dtype=bool).reshape(TRANSMITTERS_PER_MODULE, -1).T.reshape(-1)) - return [tx.get_registers(profiles, pack=pack, pack_single=pack_single, pulse_invert=inv) for tx, inv in zip(self.transmitters, tx_invert)] - - def get_delay_control_registers(self, profile:int | None=None) -> List[Dict[int,int]]: - """ - Get the delay control registers for all transmitters - - :param profile: Delay profile number - :return: List of delay control registers for each transmitter - """ - if profile is None: - profile = self.active_delay_profile - return [tx.get_delay_control_registers(profile) for tx in self.transmitters] - - def get_pulse_control_registers(self, profile:int | None=None) -> List[Dict[int,int]]: - """ - Get the pulse control registers for all transmitters - - :param profile: Pulse profile number - :return: List of pulse control registers for each transmitter - """ - if profile is None: - profile = self.active_profile - if isinstance(self.module_invert, bool): - tx_invert = [self.module_invert] * self.num_transmitters - else: - tx_invert = list(np.array(self.module_invert*TRANSMITTERS_PER_MODULE, dtype=bool).T.reshape(-1)) - return [tx.get_pulse_control_registers(profile, pulse_invert=inv) for tx, inv in zip(self.transmitters, tx_invert)] - - def get_delay_data_registers(self, profile:int | None=None, pack: bool=False, pack_single: bool=False) -> List[Dict[int,int]]: - """ - Get the delay data registers for all transmitters - - :param profile: Delay profile number - :return: List of delay data registers for each transmitter - """ - if profile is None: - profile = self.active_delay_profile - return [tx.get_delay_data_registers(profile, pack=pack, pack_single=pack_single) for tx in self.transmitters] - - def get_pulse_data_registers(self, profile:int | None=None, pack: bool=False, pack_single: bool=False) -> List[Dict[int,int]]: - """ - Get the pulse data registers for all transmitters - - :param profile: Pulse profile number - :return: List of pulse data registers for each transmitter - """ - if profile is None: - profile = self.active_profile - return [tx.get_pulse_data_registers(profile, pack=pack, pack_single=pack_single) for tx in self.transmitters] diff --git a/src/openlifu_sdk/io/LIFUUart.py b/src/openlifu_sdk/io/LIFUUart.py index 714c843..f0707ad 100644 --- a/src/openlifu_sdk/io/LIFUUart.py +++ b/src/openlifu_sdk/io/LIFUUart.py @@ -18,53 +18,269 @@ OW_END_BYTE = 0xDD ID_COUNTER = 0 # Initializing the ID counter -# Set up logging -log = logging.getLogger("UART") -log.setLevel(logging.ERROR) -log.propagate = False - -handler = logging.StreamHandler() -handler.setLevel(logging.ERROR) -formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") # Format output with timestamp -handler.setFormatter(formatter) -log.addHandler(handler) +log = logging.getLogger(__name__) # CRC16-ccitt lookup table crc16_tab = [ - 0x0000, 0x1021, 0x2042, 0x3063, 0x4084, 0x50a5, 0x60c6, 0x70e7, - 0x8108, 0x9129, 0xa14a, 0xb16b, 0xc18c, 0xd1ad, 0xe1ce, 0xf1ef, - 0x1231, 0x0210, 0x3273, 0x2252, 0x52b5, 0x4294, 0x72f7, 0x62d6, - 0x9339, 0x8318, 0xb37b, 0xa35a, 0xd3bd, 0xc39c, 0xf3ff, 0xe3de, - 0x2462, 0x3443, 0x0420, 0x1401, 0x64e6, 0x74c7, 0x44a4, 0x5485, - 0xa56a, 0xb54b, 0x8528, 0x9509, 0xe5ee, 0xf5cf, 0xc5ac, 0xd58d, - 0x3653, 0x2672, 0x1611, 0x0630, 0x76d7, 0x66f6, 0x5695, 0x46b4, - 0xb75b, 0xa77a, 0x9719, 0x8738, 0xf7df, 0xe7fe, 0xd79d, 0xc7bc, - 0x48c4, 0x58e5, 0x6886, 0x78a7, 0x0840, 0x1861, 0x2802, 0x3823, - 0xc9cc, 0xd9ed, 0xe98e, 0xf9af, 0x8948, 0x9969, 0xa90a, 0xb92b, - 0x5af5, 0x4ad4, 0x7ab7, 0x6a96, 0x1a71, 0x0a50, 0x3a33, 0x2a12, - 0xdbfd, 0xcbdc, 0xfbbf, 0xeb9e, 0x9b79, 0x8b58, 0xbb3b, 0xab1a, - 0x6ca6, 0x7c87, 0x4ce4, 0x5cc5, 0x2c22, 0x3c03, 0x0c60, 0x1c41, - 0xedae, 0xfd8f, 0xcdec, 0xddcd, 0xad2a, 0xbd0b, 0x8d68, 0x9d49, - 0x7e97, 0x6eb6, 0x5ed5, 0x4ef4, 0x3e13, 0x2e32, 0x1e51, 0x0e70, - 0xff9f, 0xefbe, 0xdfdd, 0xcffc, 0xbf1b, 0xaf3a, 0x9f59, 0x8f78, - 0x9188, 0x81a9, 0xb1ca, 0xa1eb, 0xd10c, 0xc12d, 0xf14e, 0xe16f, - 0x1080, 0x00a1, 0x30c2, 0x20e3, 0x5004, 0x4025, 0x7046, 0x6067, - 0x83b9, 0x9398, 0xa3fb, 0xb3da, 0xc33d, 0xd31c, 0xe37f, 0xf35e, - 0x02b1, 0x1290, 0x22f3, 0x32d2, 0x4235, 0x5214, 0x6277, 0x7256, - 0xb5ea, 0xa5cb, 0x95a8, 0x8589, 0xf56e, 0xe54f, 0xd52c, 0xc50d, - 0x34e2, 0x24c3, 0x14a0, 0x0481, 0x7466, 0x6447, 0x5424, 0x4405, - 0xa7db, 0xb7fa, 0x8799, 0x97b8, 0xe75f, 0xf77e, 0xc71d, 0xd73c, - 0x26d3, 0x36f2, 0x0691, 0x16b0, 0x6657, 0x7676, 0x4615, 0x5634, - 0xd94c, 0xc96d, 0xf90e, 0xe92f, 0x99c8, 0x89e9, 0xb98a, 0xa9ab, - 0x5844, 0x4865, 0x7806, 0x6827, 0x18c0, 0x08e1, 0x3882, 0x28a3, - 0xcb7d, 0xdb5c, 0xeb3f, 0xfb1e, 0x8bf9, 0x9bd8, 0xabbb, 0xbb9a, - 0x4a75, 0x5a54, 0x6a37, 0x7a16, 0x0af1, 0x1ad0, 0x2ab3, 0x3a92, - 0xfd2e, 0xed0f, 0xdd6c, 0xcd4d, 0xbdaa, 0xad8b, 0x9de8, 0x8dc9, - 0x7c26, 0x6c07, 0x5c64, 0x4c45, 0x3ca2, 0x2c83, 0x1ce0, 0x0cc1, - 0xef1f, 0xff3e, 0xcf5d, 0xdf7c, 0xaf9b, 0xbfba, 0x8fd9, 0x9ff8, - 0x6e17, 0x7e36, 0x4e55, 0x5e74, 0x2e93, 0x3eb2, 0x0ed1, 0x1ef0, + 0x0000, + 0x1021, + 0x2042, + 0x3063, + 0x4084, + 0x50A5, + 0x60C6, + 0x70E7, + 0x8108, + 0x9129, + 0xA14A, + 0xB16B, + 0xC18C, + 0xD1AD, + 0xE1CE, + 0xF1EF, + 0x1231, + 0x0210, + 0x3273, + 0x2252, + 0x52B5, + 0x4294, + 0x72F7, + 0x62D6, + 0x9339, + 0x8318, + 0xB37B, + 0xA35A, + 0xD3BD, + 0xC39C, + 0xF3FF, + 0xE3DE, + 0x2462, + 0x3443, + 0x0420, + 0x1401, + 0x64E6, + 0x74C7, + 0x44A4, + 0x5485, + 0xA56A, + 0xB54B, + 0x8528, + 0x9509, + 0xE5EE, + 0xF5CF, + 0xC5AC, + 0xD58D, + 0x3653, + 0x2672, + 0x1611, + 0x0630, + 0x76D7, + 0x66F6, + 0x5695, + 0x46B4, + 0xB75B, + 0xA77A, + 0x9719, + 0x8738, + 0xF7DF, + 0xE7FE, + 0xD79D, + 0xC7BC, + 0x48C4, + 0x58E5, + 0x6886, + 0x78A7, + 0x0840, + 0x1861, + 0x2802, + 0x3823, + 0xC9CC, + 0xD9ED, + 0xE98E, + 0xF9AF, + 0x8948, + 0x9969, + 0xA90A, + 0xB92B, + 0x5AF5, + 0x4AD4, + 0x7AB7, + 0x6A96, + 0x1A71, + 0x0A50, + 0x3A33, + 0x2A12, + 0xDBFD, + 0xCBDC, + 0xFBBF, + 0xEB9E, + 0x9B79, + 0x8B58, + 0xBB3B, + 0xAB1A, + 0x6CA6, + 0x7C87, + 0x4CE4, + 0x5CC5, + 0x2C22, + 0x3C03, + 0x0C60, + 0x1C41, + 0xEDAE, + 0xFD8F, + 0xCDEC, + 0xDDCD, + 0xAD2A, + 0xBD0B, + 0x8D68, + 0x9D49, + 0x7E97, + 0x6EB6, + 0x5ED5, + 0x4EF4, + 0x3E13, + 0x2E32, + 0x1E51, + 0x0E70, + 0xFF9F, + 0xEFBE, + 0xDFDD, + 0xCFFC, + 0xBF1B, + 0xAF3A, + 0x9F59, + 0x8F78, + 0x9188, + 0x81A9, + 0xB1CA, + 0xA1EB, + 0xD10C, + 0xC12D, + 0xF14E, + 0xE16F, + 0x1080, + 0x00A1, + 0x30C2, + 0x20E3, + 0x5004, + 0x4025, + 0x7046, + 0x6067, + 0x83B9, + 0x9398, + 0xA3FB, + 0xB3DA, + 0xC33D, + 0xD31C, + 0xE37F, + 0xF35E, + 0x02B1, + 0x1290, + 0x22F3, + 0x32D2, + 0x4235, + 0x5214, + 0x6277, + 0x7256, + 0xB5EA, + 0xA5CB, + 0x95A8, + 0x8589, + 0xF56E, + 0xE54F, + 0xD52C, + 0xC50D, + 0x34E2, + 0x24C3, + 0x14A0, + 0x0481, + 0x7466, + 0x6447, + 0x5424, + 0x4405, + 0xA7DB, + 0xB7FA, + 0x8799, + 0x97B8, + 0xE75F, + 0xF77E, + 0xC71D, + 0xD73C, + 0x26D3, + 0x36F2, + 0x0691, + 0x16B0, + 0x6657, + 0x7676, + 0x4615, + 0x5634, + 0xD94C, + 0xC96D, + 0xF90E, + 0xE92F, + 0x99C8, + 0x89E9, + 0xB98A, + 0xA9AB, + 0x5844, + 0x4865, + 0x7806, + 0x6827, + 0x18C0, + 0x08E1, + 0x3882, + 0x28A3, + 0xCB7D, + 0xDB5C, + 0xEB3F, + 0xFB1E, + 0x8BF9, + 0x9BD8, + 0xABBB, + 0xBB9A, + 0x4A75, + 0x5A54, + 0x6A37, + 0x7A16, + 0x0AF1, + 0x1AD0, + 0x2AB3, + 0x3A92, + 0xFD2E, + 0xED0F, + 0xDD6C, + 0xCD4D, + 0xBDAA, + 0xAD8B, + 0x9DE8, + 0x8DC9, + 0x7C26, + 0x6C07, + 0x5C64, + 0x4C45, + 0x3CA2, + 0x2C83, + 0x1CE0, + 0x0CC1, + 0xEF1F, + 0xFF3E, + 0xCF5D, + 0xDF7C, + 0xAF9B, + 0xBFBA, + 0x8FD9, + 0x9FF8, + 0x6E17, + 0x7E36, + 0x4E55, + 0x5E74, + 0x2E93, + 0x3EB2, + 0x0ED1, + 0x1EF0, ] + def util_crc16(buf): crc = 0xFFFF @@ -73,6 +289,7 @@ def util_crc16(buf): return crc + class UartPacket: def __init__(self, id=None, packet_type=None, command=None, addr=None, reserved=None, data=None, buffer=None): if buffer: @@ -91,12 +308,12 @@ def calculate_crc(self) -> int: crc_value = 0xFFFF packet = bytearray() packet.append(OW_START_BYTE) - packet.extend(self.id.to_bytes(2, 'big')) + packet.extend(self.id.to_bytes(2, "big")) packet.append(self.packet_type) packet.append(self.command) packet.append(self.addr) packet.append(self.reserved) - packet.extend(self.data_len.to_bytes(2, 'big')) + packet.extend(self.data_len.to_bytes(2, "big")) if self.data_len > 0: packet.extend(self.data) crc_value = util_crc16(packet[1:]) @@ -105,16 +322,16 @@ def calculate_crc(self) -> int: def to_bytes(self) -> bytes: buffer = bytearray() buffer.append(OW_START_BYTE) - buffer.extend(self.id.to_bytes(2, 'big')) + buffer.extend(self.id.to_bytes(2, "big")) buffer.append(self.packet_type) buffer.append(self.command) buffer.append(self.addr) buffer.append(self.reserved) - buffer.extend(self.data_len.to_bytes(2, 'big')) + buffer.extend(self.data_len.to_bytes(2, "big")) if self.data_len > 0: buffer.extend(self.data) crc_value = util_crc16(buffer[1:]) - buffer.extend(crc_value.to_bytes(2, 'big')) + buffer.extend(crc_value.to_bytes(2, "big")) buffer.append(OW_END_BYTE) return bytes(buffer) @@ -122,31 +339,32 @@ def from_buffer(self, buffer: bytes): if buffer[0] != OW_START_BYTE or buffer[-1] != OW_END_BYTE: raise ValueError("Invalid buffer format") - self.id = int.from_bytes(buffer[1:3], 'big') + self.id = int.from_bytes(buffer[1:3], "big") self.packet_type = buffer[3] self.command = buffer[4] self.addr = buffer[5] self.reserved = buffer[6] - self.data_len = int.from_bytes(buffer[7:9], 'big') - self.data = bytearray(buffer[9:9+self.data_len]) - crc_value = util_crc16(buffer[1:9+self.data_len]) - self.crc = int.from_bytes(buffer[9+self.data_len:11+self.data_len], 'big') + self.data_len = int.from_bytes(buffer[7:9], "big") + self.data = bytearray(buffer[9 : 9 + self.data_len]) + crc_value = util_crc16(buffer[1 : 9 + self.data_len]) + self.crc = int.from_bytes(buffer[9 + self.data_len : 11 + self.data_len], "big") if self.crc != crc_value: raise ValueError("CRC mismatch") def print_packet(self): log.info("UartPacket:") - log.info(f" Packet ID: {self.id}") - log.info(f" Packet Type: {hex(self.packet_type)}") - log.info(f" Command: {hex(self.command)}") - log.info(f" Address: {hex(self.addr)}") - log.info(f" Reserved: {hex(self.reserved)}") - log.info(f" Data Length: {self.data_len}") + log.info(" Packet ID: %s", self.id) + log.info(" Packet Type: %s", hex(self.packet_type)) + log.info(" Command: %s", hex(self.command)) + log.info(" Address: %s", hex(self.addr)) + log.info(" Reserved: %s", hex(self.reserved)) + log.info(" Data Length: %s", self.data_len) if self.data_len > 0: - log.info(f" Data: {self.data.hex()}") + log.info(" Data: %s", self.data.hex()) else: log.info(" Data: None") - log.info(f" CRC: {hex(self.crc)}") + log.info(" CRC: %s", hex(self.crc)) + class LIFUUart: def __init__(self, vid, pid, baudrate=921600, timeout=10, align=0, desc="VCP", demo_mode=False, async_mode=False): @@ -198,11 +416,7 @@ def connect(self): self.signal_connect.emit(self.descriptor, "demo_mode") return try: - self.serial = serial.Serial( - port=self.port, - baudrate=self.baudrate, - timeout=self.timeout - ) + self.serial = serial.Serial(port=self.port, baudrate=self.baudrate, timeout=self.timeout) log.info("Connected to UART on port %s.", self.port) self.signal_connect.emit(self.descriptor, self.port) @@ -290,7 +504,7 @@ def list_vcp_with_vid_pid(self): """Find the USB device by VID and PID.""" ports = serial.tools.list_ports.comports() for port in ports: - if hasattr(port, 'vid') and hasattr(port, 'pid') and port.vid == self.vid and port.pid == self.pid: + if hasattr(port, "vid") and hasattr(port, "pid") and port.vid == self.vid and port.pid == self.pid: return port.device return None @@ -355,7 +569,7 @@ def _read_data(self, timeout=20): if packet.id in self.response_queues: self.response_queues[packet.id].put(packet) elif packet.packet_type == OW_DATA and packet.id == 0: - text = packet.data.decode('utf-8', errors='replace') + text = packet.data.decode("utf-8", errors="replace") self.signal_data_received.emit(self.descriptor, text) else: log.warning("Received an unsolicited packet with ID %d", packet.id) @@ -422,14 +636,7 @@ def read_packet(self, timeout=20) -> UartPacket: packet = UartPacket(buffer=raw_data) except Exception as e: log.error("Error parsing packet: %s", e) - packet = UartPacket( - id=0, - packet_type=OW_ERROR, - command=0, - addr=0, - reserved=0, - data=[] - ) + packet = UartPacket(id=0, packet_type=OW_ERROR, command=0, addr=0, reserved=0, data=[]) raise e return packet @@ -454,21 +661,21 @@ def send_packet(self, id=None, packetType=OW_ACK, command=OW_CMD_NOP, addr=0, re payload_length = len(payload) else: payload_length = 0 - payload = b'' + payload = b"" packet = bytearray() packet.append(OW_START_BYTE) - packet.extend(id.to_bytes(2, 'big')) + packet.extend(id.to_bytes(2, "big")) packet.append(packetType) packet.append(command) packet.append(addr) packet.append(reserved) - packet.extend(payload_length.to_bytes(2, 'big')) + packet.extend(payload_length.to_bytes(2, "big")) if payload_length > 0: packet.extend(payload) crc_value = util_crc16(packet[1:]) # Exclude start byte - packet.extend(crc_value.to_bytes(2, 'big')) + packet.extend(crc_value.to_bytes(2, "big")) packet.append(OW_END_BYTE) self._tx(packet) diff --git a/src/openlifu_sdk/io/LIFUUserConfig.py b/src/openlifu_sdk/io/LIFUUserConfig.py index 457eed2..354bd8d 100644 --- a/src/openlifu_sdk/io/LIFUUserConfig.py +++ b/src/openlifu_sdk/io/LIFUUserConfig.py @@ -1,18 +1,20 @@ -import struct import json -from dataclasses import dataclass -from typing import Optional, Any, Dict import logging +import struct +from dataclasses import dataclass +from typing import Any logger = logging.getLogger(__name__) # Constants from C code LIFU_MAGIC = 0x4C494655 # 'LIFU' -LIFU_VER = 0x00010002 # v1.0.0 +LIFU_VER = 0x00010002 # v1.0.0 + @dataclass class LifuUserConfigHeader: """Wire format header for Lifu config""" + magic: int version: int seq: int @@ -20,31 +22,19 @@ class LifuUserConfigHeader: json_len: int @classmethod - def from_bytes(cls, data: bytes) -> 'LifuUserConfigHeader': + def from_bytes(cls, data: bytes) -> "LifuUserConfigHeader": """Parse header from wire format bytes (little-endian)""" if len(data) < 16: raise ValueError(f"Header data too short: {len(data)} bytes, need 16") - + # Parse: uint32 magic, uint32 version, uint32 seq, uint16 crc, uint16 json_len - magic, version, seq, crc, json_len = struct.unpack(' bytes: """Convert header to wire format bytes (little-endian)""" - return struct.pack(' bool: """Check if magic and version are valid""" @@ -54,89 +44,89 @@ def is_valid(self) -> bool: class LifuUserConfig: """ Encapsulates the Lifu configuration stored in device flash. - + The configuration is stored as a JSON blob with metadata including: - magic number for validation - version for compatibility - sequence number (monotonically increasing) - CRC for integrity """ - - def __init__(self, header: Optional[LifuUserConfigHeader] = None, json_data: Optional[Dict[str, Any]] = None): + + def __init__(self, header: LifuUserConfigHeader | None = None, json_data: dict[str, Any] | None = None): """ Initialize LifuUserConfig - + Args: header: Configuration header metadata json_data: Configuration JSON data as a dictionary """ - self.header = header if header else LifuUserConfigHeader( - magic=LIFU_MAGIC, - version=LIFU_VER, - seq=0, - crc=0, - json_len=0 + self.header = ( + header if header else LifuUserConfigHeader(magic=LIFU_MAGIC, version=LIFU_VER, seq=0, crc=0, json_len=0) ) self.json_data = json_data if json_data is not None else {} @classmethod - def from_wire_bytes(cls, data: bytes) -> 'LifuUserConfig': + def from_wire_bytes(cls, data: bytes) -> "LifuUserConfig": """ Parse configuration from wire format bytes - + Wire format: [header: 16 bytes][json: json_len bytes] - + Args: data: Raw bytes from device - + Returns: LifuUserConfig instance - + Raises: ValueError: If data is invalid or malformed """ if len(data) < 16: raise ValueError(f"Wire data too short: {len(data)} bytes") - + header = LifuUserConfigHeader.from_bytes(data[:16]) - + if not header.is_valid(): raise ValueError(f"Invalid magic (0x{header.magic:08X}) or version (0x{header.version:08X})") - + # Extract JSON bytes json_bytes_end = 16 + header.json_len if len(data) < json_bytes_end: - logger.warning(f"JSON data truncated: expected {header.json_len} bytes, got {len(data) - 16}") + logger.warning( + "JSON data truncated: expected %s bytes, got %s", + header.json_len, + len(data) - 16, + ) json_bytes = data[16:] else: json_bytes = data[16:json_bytes_end] - + # Parse JSON (handle null terminator if present) - json_str = json_bytes.rstrip(b'\x00').decode('utf-8', errors='ignore') - + json_str = json_bytes.rstrip(b"\x00").decode("utf-8", errors="ignore") + try: json_data = json.loads(json_str) if json_str else {} except json.JSONDecodeError as e: - logger.warning(f"Failed to parse JSON: {e}. Using empty config.") + logger.warning("Failed to parse JSON: %s. Using empty config.", e) json_data = {} - + return cls(header=header, json_data=json_data) def to_wire_bytes(self) -> bytes: """ Convert configuration to wire format for sending to device - + Returns: bytes: Wire format [header][json_bytes] """ # Convert JSON to bytes - json_str = json.dumps(self.json_data, separators=(',', ':')) - json_bytes = json_str.encode('utf-8') - + json_str = json.dumps(self.json_data, separators=(",", ":")) + json_bytes = json_str.encode("utf-8") + # Update header with JSON length self.header.json_len = len(json_bytes) - + # Build wire format return self.header.to_bytes() + json_bytes @@ -147,10 +137,10 @@ def get_json_str(self) -> str: def set_json_str(self, json_str: str): """ Set configuration from JSON string - + Args: json_str: JSON string to parse - + Raises: json.JSONDecodeError: If JSON is invalid """ @@ -164,14 +154,16 @@ def set(self, key: str, value: Any): """Set a configuration value by key""" self.json_data[key] = value - def update(self, updates: Dict[str, Any]): + def update(self, updates: dict[str, Any]): """Update multiple configuration values""" self.json_data.update(updates) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Get the configuration as a dictionary""" return self.json_data.copy() def __repr__(self) -> str: - return (f"LifuUserConfig(seq={self.header.seq}, crc=0x{self.header.crc:04X}, " - f"json_len={self.header.json_len}, data={self.json_data})") + return ( + f"LifuUserConfig(seq={self.header.seq}, crc=0x{self.header.crc:04X}, " + f"json_len={self.header.json_len}, data={self.json_data})" + ) diff --git a/src/openlifu_sdk/transport.py b/src/openlifu_sdk/transport.py new file mode 100644 index 0000000..7151112 --- /dev/null +++ b/src/openlifu_sdk/transport.py @@ -0,0 +1,77 @@ +"""Transport protocol definition for openlifu-sdk. + +:class:`Transport` is a :class:`typing.Protocol` that formalises the interface +expected by :class:`~openlifu_sdk.io.LIFUTXDevice.TxDevice` and +:class:`~openlifu_sdk.io.LIFUHVController.HVController`. + +Any object that implements these attributes and methods — including the concrete +:class:`~openlifu_sdk.io.LIFUUart.LIFUUart` and any test double — can be passed +as the ``uart`` argument without inheriting from a base class. + +Example — creating a minimal test double:: + + from openlifu_sdk.transport import Transport + + class FakeTransport: + demo_mode = True + asyncMode = False + + def is_connected(self) -> bool: + return True + + def check_usb_status(self) -> None: + pass + + def send_packet(self, **kwargs): + ... # return a mock packet + + def clear_buffer(self) -> None: + pass + + def disconnect(self) -> None: + pass + + assert isinstance(FakeTransport(), Transport) # runtime check works +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +if TYPE_CHECKING: + pass + + +@runtime_checkable +class Transport(Protocol): + """Structural interface fulfilled by :class:`~openlifu_sdk.io.LIFUUart.LIFUUart`. + + Any class that exposes these attributes and methods is a valid transport. + No explicit inheritance is required (duck typing / structural subtyping). + """ + + #: When ``True`` the transport simulates hardware responses without a real device. + demo_mode: bool + + #: When ``True`` the transport operates in asynchronous (non-blocking) mode. + asyncMode: bool + + def is_connected(self) -> bool: + """Return ``True`` if the physical device port is open and connected.""" + ... + + def check_usb_status(self) -> None: + """Probe USB ports and update the connection state.""" + ... + + def send_packet(self, **kwargs: Any) -> Any: + """Send a packet to the device and return the response packet.""" + ... + + def clear_buffer(self) -> None: + """Flush any pending data from the receive buffer.""" + ... + + def disconnect(self) -> None: + """Close the underlying serial port.""" + ... diff --git a/src/openlifu_sdk/util/units.py b/src/openlifu_sdk/util/units.py index 6f5ad3e..4f268c9 100644 --- a/src/openlifu_sdk/util/units.py +++ b/src/openlifu_sdk/util/units.py @@ -6,42 +6,43 @@ def getunittype(unit): unit = unit.lower() - if unit in ['micron', 'microns']: - return 'distance' - elif unit in ['minute', 'minutes', 'min', 'mins', 'hour', 'hours', 'hr', 'hrs', 'day', 'days', 'd']: - return 'time' - elif unit in ['rad', 'deg', 'radian', 'radians', 'degree', 'degrees', '°']: - return 'angle' - elif 'sec' in unit: - return 'time' - elif 'meter' in unit or 'micron' in unit: - return 'distance' - elif unit.endswith('s'): - return 'time' - elif unit.endswith('m'): - return 'distance' - elif unit.endswith(('m2', 'm^2')): - return 'area' - elif unit.endswith(('m3', 'm^3')): - return 'volume' - elif unit.endswith('hz'): - return 'frequency' - elif unit.endswith('pa'): - return 'pressure' - elif unit.endswith('w'): - return 'watt' + if unit in ["micron", "microns"]: + return "distance" + elif unit in ["minute", "minutes", "min", "mins", "hour", "hours", "hr", "hrs", "day", "days", "d"]: + return "time" + elif unit in ["rad", "deg", "radian", "radians", "degree", "degrees", "°"]: + return "angle" + elif "sec" in unit: + return "time" + elif "meter" in unit or "micron" in unit: + return "distance" + elif unit.endswith("s"): + return "time" + elif unit.endswith("m"): + return "distance" + elif unit.endswith(("m2", "m^2")): + return "area" + elif unit.endswith(("m3", "m^3")): + return "volume" + elif unit.endswith("hz"): + return "frequency" + elif unit.endswith("pa"): + return "pressure" + elif unit.endswith("w"): + return "watt" else: - return 'other' + return "other" + def getunitconversion(from_unit, to_unit, unitratio=None, constant=None): if not from_unit: return 1.0 if unitratio is not None and constant is not None: - if '/' not in unitratio: - raise ValueError('Conversion unit ratio must have a \'/\' symbol') + if "/" not in unitratio: + raise ValueError("Conversion unit ratio must have a '/' symbol") - unitn, unitd = unitratio.split('/') + unitn, unitd = unitratio.split("/") type0 = getunittype(from_unit) type1 = getunittype(to_unit) typen = getunittype(unitn) @@ -54,27 +55,27 @@ def getunitconversion(from_unit, to_unit, unitratio=None, constant=None): elif type0 == type1: scl = getunitconversion(from_unit, to_unit) else: - raise ValueError(f'Unit type mismatch {type0} -> ({typen}/{typed}) -> {type1}') + raise ValueError(f"Unit type mismatch {type0} -> ({typen}/{typed}) -> {type1}") else: - slash0 = from_unit.find('/') - slash1 = to_unit.find('/') + slash0 = from_unit.find("/") + slash1 = to_unit.find("/") if slash0 != -1 and slash1 != -1: num0 = from_unit[:slash0] - denom0 = from_unit[slash0+1:] + denom0 = from_unit[slash0 + 1 :] num1 = to_unit[:slash1] - denom1 = to_unit[slash1+1:] + denom1 = to_unit[slash1 + 1 :] scl = getunitconversion(num0, num1) / getunitconversion(denom0, denom1) elif slash0 == -1 and slash1 == -1: type0 = getunittype(from_unit) type1 = getunittype(to_unit) if type0 != type1: - raise ValueError(f'Unit type mismatch ({type0}) vs ({type1})') + raise ValueError(f"Unit type mismatch ({type0}) vs ({type1})") - if type0 == 'other': + if type0 == "other": if from_unit[-1] != to_unit[-1]: - raise ValueError(f'Cannot convert {from_unit} to {to_unit}') + raise ValueError(f"Cannot convert {from_unit} to {to_unit}") i = 0 while i < min(len(from_unit), len(to_unit)) and from_unit[-i:] == to_unit[-i:]: @@ -89,40 +90,41 @@ def getunitconversion(from_unit, to_unit, unitratio=None, constant=None): scl1 = getsiscale(to_unit, type0) scl = scl0 / scl1 else: - raise ValueError(f'Unit ratio mismatch ({from_unit} vs {to_unit})') + raise ValueError(f"Unit ratio mismatch ({from_unit} vs {to_unit})") return scl + def getsiscale(unit, type): type = type.lower() - if type in ['distance', 'area', 'volume']: - idx = unit.find('meters') + if type in ["distance", "area", "volume"]: + idx = unit.find("meters") if idx == -1: - idx = unit.find('meter') + idx = unit.find("meter") if idx == -1: - if unit.lower() == 'micron': + if unit.lower() == "micron": idx = 6 else: - idx = unit.rfind('m') + idx = unit.rfind("m") if idx == -1: idx = len(unit) - elif type == 'time': - idx = unit.find('seconds') + elif type == "time": + idx = unit.find("seconds") if idx == -1: - idx = unit.find('second') + idx = unit.find("second") if idx == -1: - idx = unit.find('sec') + idx = unit.find("sec") if idx == -1: - idx = unit.rfind('s') + idx = unit.rfind("s") if idx == -1: idx = len(unit) - elif type == 'angle': + elif type == "angle": idx = len(unit) - elif type == 'frequency' or type == "pressure": + elif type == "frequency" or type == "pressure": idx = len(unit) - 2 elif type == "watt": @@ -138,43 +140,43 @@ def getsiscale(unit, type): else: scl = 1.0 - if prefix == 'pico' or prefix == 'p': + if prefix == "pico" or prefix == "p": scl = 1.0e-12 - elif prefix == 'nano' or prefix == 'n': + elif prefix == "nano" or prefix == "n": scl = 1.0e-9 - elif prefix == 'micro' or prefix == 'u' or prefix == '\u00b5' or prefix == '\u03bc': + elif prefix == "micro" or prefix == "u" or prefix == "\u00b5" or prefix == "\u03bc": scl = 1.0e-6 - elif prefix == 'milli' or prefix == 'm': + elif prefix == "milli" or prefix == "m": scl = 1.0e-3 - elif prefix == 'centi' or prefix == 'c': + elif prefix == "centi" or prefix == "c": scl = 1.0e-2 - elif prefix == '': + elif prefix == "": scl = 1.0 - elif prefix == 'kilo' or prefix == 'k': + elif prefix == "kilo" or prefix == "k": scl = 1.0e3 - elif prefix == 'mega' or prefix == 'M': + elif prefix == "mega" or prefix == "M": scl = 1.0e6 - elif prefix == 'giga' or prefix == 'G': + elif prefix == "giga" or prefix == "G": scl = 1.0e9 - elif prefix == 'tera' or prefix == 'T': + elif prefix == "tera" or prefix == "T": scl = 1.0e12 - elif prefix == 'min' or prefix == 'minute': + elif prefix == "min" or prefix == "minute": scl = 60.0 - elif prefix == 'hour' or prefix == 'hr': + elif prefix == "hour" or prefix == "hr": scl = 60.0 * 60.0 - elif prefix == 'day' or prefix == 'd': + elif prefix == "day" or prefix == "d": scl = 60.0 * 60.0 * 24.0 - elif prefix == 'rad' or prefix == 'radian' or prefix == 'radians': + elif prefix == "rad" or prefix == "radian" or prefix == "radians": scl = 1.0 - elif prefix == 'deg' or prefix == 'degree' or prefix == 'degrees' or prefix == '\u00b0': + elif prefix == "deg" or prefix == "degree" or prefix == "degrees" or prefix == "\u00b0": scl = 2 * 3.14159265358979323846 / 360 elif prefix: - raise ValueError(f'Unknown prefix {prefix}') + raise ValueError(f"Unknown prefix {prefix}") - if type == 'area': - scl = scl ** 2.0 - elif type == 'volume': - scl = scl ** 3.0 + if type == "area": + scl = scl**2.0 + elif type == "volume": + scl = scl**3.0 return scl @@ -191,9 +193,9 @@ def rescale_data_arr(data_arr: Dataset, units: str) -> Dataset: rescaled: The rescaled xarray to new units. """ rescaled = data_arr.copy(deep=True) - scale = getunitconversion(data_arr.attrs['units'], units) + scale = getunitconversion(data_arr.attrs["units"], units) rescaled.data *= scale - rescaled.attrs['units'] = units + rescaled.attrs["units"] = units return rescaled @@ -212,12 +214,12 @@ def rescale_coords(data_arr: Dataset, units: str) -> Dataset: rescaled = data_arr.copy(deep=True) for coord_key in data_arr.coords: curr_coord_attrs = rescaled[coord_key].attrs - if 'units' in curr_coord_attrs: - curr_coord_units = curr_coord_attrs['units'] + if "units" in curr_coord_attrs: + curr_coord_units = curr_coord_attrs["units"] scale = getunitconversion(curr_coord_units, units) - curr_coord_rescaled = scale*rescaled[coord_key].data + curr_coord_rescaled = scale * rescaled[coord_key].data rescaled = rescaled.assign_coords({coord_key: (coord_key, curr_coord_rescaled, curr_coord_attrs)}) - rescaled[coord_key].attrs['units'] = units + rescaled[coord_key].attrs["units"] = units return rescaled @@ -237,7 +239,7 @@ def get_ndgrid_from_arr(data_arr: Dataset) -> np.ndarray: ordered_key = data_arr[first_data_key].dims all_coord = [] for coord_key in ordered_key: - if 'units' in data_arr[coord_key].attrs: + if "units" in data_arr[coord_key].attrs: all_coord += [data_arr.coords[coord_key].data] ndgrid = np.stack(np.meshgrid(*all_coord, indexing="ij"), axis=-1)