diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f09b280a..6a8e0420 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,7 +17,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.12"] steps: - uses: actions/checkout@v3 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b313588a..722e29bc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: rev: v4.4.0 hooks: - id: check-yaml - - repo: https://github.com/psf/black - rev: 22.3.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.0 hooks: - - id: black + - id: ruff-format \ No newline at end of file diff --git a/examples/jig_driver.py b/examples/jig_driver.py index 6d56cdf6..ff100b3f 100644 --- a/examples/jig_driver.py +++ b/examples/jig_driver.py @@ -2,6 +2,7 @@ This file is just a test playground that shows how the update jig classes will fit together. """ + from __future__ import annotations from dataclasses import dataclass, field from fixate import ( @@ -10,6 +11,8 @@ MuxGroup, PinValueAddressHandler, VirtualSwitch, + RelayMatrixMux, + Signal, ) @@ -58,3 +61,84 @@ class JigMuxGroup(MuxGroup): jig.mux.mux_two("sig5") jig.mux.mux_three("On") jig.mux.mux_three(False) + + +# VirtualMuxes can be created with type annotations to provide the signal map +from typing import Literal, Annotated, Union + +# a signal is a typing Annotation +# the first Literal is the signal name, the rest are the pin names +# the signal name MUST be a Literal +# multiple signals can be combined with a Union +# assigning annotations to variables is possible +# fmt: off +MuxOneSigDef = Union[ + Annotated[Literal["sig_a1"], "a0", "a2"], + Annotated[Literal["sig_a2"], "a1"], +] + +# alternative syntax without the Union keyword +MuxTwoSigDef = (Annotated[Literal["sig_b1"], "b0", "b2"] | + Annotated[Literal["sig_b2"], "b1"]) + +# if defining only a single signal, the Union is omitted in the definition +SingleSignalDef = Annotated[Literal["sig_c1"], "c0", "c1"] +# fmt: on + + +# VirtualMuxes can now be created with type annotations to provide the signal map +# this only works when subclassing +class MyMux(VirtualMux[MuxOneSigDef]): + """A helpful description for my mux that is used in this jig driver""" + + +muxa = MyMux() + +muxa("sig_a1") +muxa("sig_a2") + +# using the wrong signal name will be caught at runtime and by the type checker +try: + muxa("unknown signal name") +except ValueError as e: + print(f"An Exception would have occurred: {e}") +else: + raise ValueError("Should have raised an exception") + + +class MultiPinSwitch(VirtualMux[SingleSignalDef]): + """This acts like a switch, but has to coordinate two pins""" + + +ls = MultiPinSwitch() +ls("sig_c1") +ls("") + +# further generic types can be created by subclassing from VirtualMux using a TypeVar +# compared to the above way of subclassing, this way lets you reuse the class + + +class MyGenericMux[S: Signal](VirtualMux[S]): + ... + + def extra_method(self) -> None: + print("Extra method") + + +class MyConcreteMux(MyGenericMux[MuxTwoSigDef]): + pass + + +generic_mux = MyConcreteMux() +generic_mux("sig_b1") +generic_mux("sig_b2") + + +# RelayMatrixMux is an example of a reusable generic class +class MyRelayMatrixMux(RelayMatrixMux[MuxOneSigDef]): + pass + + +rmm = MyRelayMatrixMux() +rmm("sig_a1") +rmm("sig_a2") diff --git a/mypy.ini b/mypy.ini index 757f7204..4beb01b4 100644 --- a/mypy.ini +++ b/mypy.ini @@ -94,4 +94,4 @@ follow_imports = silent [mypy-fixate.core.checks,fixate.core.common,fixate.core.config_util,fixate.core.jig_mapping] follow_imports = silent [mypy-fixate.drivers,fixate.drivers.dso.helper,fixate.drivers.funcgen.helper,fixate.drivers.funcgen.rigol_dg1022,fixate.drivers.pps,fixate.drivers.pps.helper,fixate.drivers.ftdi] -follow_imports = silent +follow_imports = silent \ No newline at end of file diff --git a/src/fixate/_switching.py b/src/fixate/_switching.py index f9209d7e..4eb2b8ad 100644 --- a/src/fixate/_switching.py +++ b/src/fixate/_switching.py @@ -32,28 +32,33 @@ import itertools import time from typing import ( - Generic, Optional, Callable, Sequence, - TypeVar, Generator, Union, Collection, Dict, FrozenSet, Iterable, + Literal, + Annotated, + get_origin, + get_args, ) from dataclasses import dataclass from functools import reduce from operator import or_ + Signal = str +EmptySignal = Literal[""] Pin = str PinList = Sequence[Pin] PinSet = FrozenSet[Pin] +MapList = Sequence[Sequence[Union[Signal, Pin]]] SignalMap = Dict[Signal, PinSet] -TreeDef = Sequence[Union[Signal, "TreeDef"]] +TreeDef = Sequence[Union[Optional[Signal], "TreeDef"]] @dataclass(frozen=True) @@ -88,14 +93,21 @@ def __or__(self, other: PinUpdate) -> PinUpdate: PinUpdateCallback = Callable[[PinUpdate, bool], None] -class VirtualMux: +from types import get_original_bases, resolve_bases + + +class VirtualMux[S: Signal]: + map_tree: Optional[TreeDef] = None + map_list: Optional[Sequence[Sequence[str]]] = None pin_list: PinList = () clearing_time: float = 0.0 ########################################################################### # These methods are the public API for the class - def __init__(self, update_pins: Optional[PinUpdateCallback] = None): + # digest all the typing information if there is any to set pin_list and map_list + self._digest_type_hints() + self._last_update_time = time.monotonic() self._update_pins: PinUpdateCallback @@ -129,7 +141,9 @@ def __init__(self, update_pins: Optional[PinUpdateCallback] = None): if hasattr(self, "default_signal"): raise ValueError("'default_signal' should not be set on a VirtualMux") - def __call__(self, signal: Signal, trigger_update: bool = True) -> None: + def __call__( + self, signal: Union[S, EmptySignal], trigger_update: bool = True + ) -> None: """ Convenience to avoid having to type jig.mux..multiplex. @@ -138,7 +152,9 @@ def __call__(self, signal: Signal, trigger_update: bool = True) -> None: """ self.multiplex(signal, trigger_update) - def multiplex(self, signal: Signal, trigger_update: bool = True) -> None: + def multiplex( + self, signal: Union[S, EmptySignal], trigger_update: bool = True + ) -> None: """ Update the multiplexer state to signal. @@ -230,13 +246,13 @@ def _map_signals(self) -> SignalMap: Avoid subclassing. Consider creating helper functions to build map_tree or map_list. """ - if hasattr(self, "map_tree"): + if self.map_tree is not None: return self._map_tree(self.map_tree, self.pin_list, fixed_pins=frozenset()) - elif hasattr(self, "map_list"): + elif self.map_list is not None: return {sig: frozenset(pins) for sig, *pins in self.map_list} else: raise ValueError( - "VirtualMux subclass must define either map_tree or map_list" + "VirtualMux subclass must define either map_tree or map_list or provide a type to VirtualMux" ) def _map_tree(self, tree: TreeDef, pins: PinList, fixed_pins: PinSet) -> SignalMap: @@ -441,6 +457,56 @@ def _default_update_pins( """ print(pin_updates, trigger_update) + def _digest_type_hints(self) -> None: + # digest all the typing information if there is any + + # original bases are effectively the "as written" class + # they are types, not classes + bases = get_original_bases(self.__class__) + # resolved bases are what actually exist at runtime + resolved_bases = resolve_bases(bases) + first_resolved_base = resolved_bases[0] + # now check that we are trying to get typing information out of the correct class + assert issubclass( + first_resolved_base, VirtualMux + ), f"First parent class of {self.__class__} should be VirtualMux subclass, not {first_resolved_base}" + + args = get_args(bases[0]) + # if we found typing annotations, use them to define the pins and signals + if args: + plist, mlist = self._unpack_muxdef(args[0]) + self.pin_list = plist + self.map_list = mlist + + @staticmethod + def _unpack_muxdef(muxdef: type) -> tuple[PinList, MapList]: + # muxdef is the signal definition + if get_origin(muxdef) == Union: + signals = get_args(muxdef) + elif get_origin(muxdef) == Annotated: + # Union FORCES you to have two or more types, so this handles the case of only one pin + signals = (muxdef,) + else: + raise TypeError("Signal definition must be Union or Annotated") + + map_list: list[tuple[str]] = [] + pin_list: list[str] = [] + for s in signals: + assert get_origin(s) == Annotated, "Signal definition must be Annotated" + # get_args gives Literal + sigdef, *pins = get_args(s) + assert ( + get_origin(sigdef) == Literal + ), "Signal definition must be string literal" + # get_args gives members of Literal + (signame,) = get_args(sigdef) + assert isinstance(signame, Signal), "Signal name must be signal type" + assert all(isinstance(p, Pin) for p in pins), "Pins must be pin type" + pin_list.extend(pins) + map_list.append((signame, *pins)) + + return pin_list, map_list + class VirtualSwitch(VirtualMux): """ @@ -482,7 +548,7 @@ def __init__( super().__init__(update_pins) -class RelayMatrixMux(VirtualMux): +class RelayMatrixMux[S: Signal](VirtualMux[S]): clearing_time = 0.01 def _calculate_pins( @@ -684,10 +750,7 @@ def active_signals(self) -> list[str]: return [str(mux) for mux in self.get_multiplexers()] -JigSpecificMuxGroup = TypeVar("JigSpecificMuxGroup", bound=MuxGroup) - - -class JigDriver(Generic[JigSpecificMuxGroup]): +class JigDriver[JigSpecificMuxGroup: MuxGroup]: """ Combine multiple VirtualMux's and multiple AddressHandler's. @@ -771,10 +834,7 @@ def _validate(self) -> None: ) -_T = TypeVar("_T") - - -def _generate_bit_sets(bits: Sequence[_T]) -> Generator[set[_T], None, None]: +def _generate_bit_sets[_T](bits: Sequence[_T]) -> Generator[set[_T], None, None]: """ Create subsets of bits, representing bits of a list of integers diff --git a/test/test_switching.py b/test/test_switching.py index 527f1d2c..0c1e2c77 100644 --- a/test/test_switching.py +++ b/test/test_switching.py @@ -15,6 +15,8 @@ JigDriver, ) +from typing import Literal, Union, TypeVar, get_args, get_origin, Annotated + import pytest ################################################################ @@ -592,3 +594,157 @@ def test_pin_update_or(): 2.0, ) assert expected == a | b + + +# fmt: off +MuxASigDef = Union[ + Annotated[Literal["sig_a1"], "a0", "a1"], + Annotated[Literal["sig_a2"], "a1"] +] +# fmt: on + + +def test_typed_mux_using_subclass(): + class SubMux(VirtualMux[MuxASigDef]): + pass + + sm = SubMux(update_pins=print) + assert sm._signal_map == MuxA()._signal_map + assert sm._pin_set == MuxA()._pin_set + + +def test_typed_relaymux_using_subclass(): + class SubRelayMux(RelayMatrixMux[MuxASigDef]): + pass + + srm = SubRelayMux() + assert srm._signal_map == MuxA()._signal_map + assert srm._pin_set == MuxA()._pin_set + + +def test_typed_mux_generic_subclass(): + T = TypeVar("T", bound=str) + + class GenericSubMux(VirtualMux[T]): + pass + + class ConcreteMux(GenericSubMux[MuxASigDef]): + pass + + gsm = ConcreteMux() + assert gsm._signal_map == MuxA()._signal_map + assert gsm._pin_set == MuxA()._pin_set + + +def test_typed_mux_one_signal(): + + muxbdef = Annotated[Literal["sig1"], "a0"] + + class MuxB(VirtualMux[muxbdef]): + ... + + clear = PinSetState(off=frozenset({"a0"})) + a1 = PinSetState(on=frozenset({"a0"})) + + updates = [] + muxb = MuxB(update_pins=lambda x, y: updates.append((x, y))) + muxb("sig1") + assert updates.pop() == (PinUpdate(PinSetState(), a1), True) + + muxb("") + assert updates.pop() == (PinUpdate(PinSetState(), clear), True) + + +def test_annotated_preserve_pin_defs(): + annotated = Annotated[Literal["sig_a1"], "a0", "a1"] + sigdef, *pins = get_args(annotated) + + +def test_annotated_raises_on_missing_pin_def(): + with pytest.raises(TypeError): + annotated = Annotated[Literal["sig_a1"]] + + +def test_annotation_bad_pindefs(): + BadMuxDef = Union[ + Annotated[Literal["sig_a1"], "a0", "a1"], + Annotated[Literal["sig_a1"], "a0", 1], + ] + + class BadMux(VirtualMux[BadMuxDef]): + pass + + with pytest.raises(AssertionError): + mux = BadMux() + + +def test_annotation_bad_brackets(): + """ + We put the brackets in the wrong spot and accidentally defined + one of the signals as one of the pins of the previous signal + """ + BadMuxDef = Union[ + Annotated[Literal["sig_a1"], "a0", "a1", Annotated[Literal["sig_a2"], "a1"]], + Annotated[Literal["sig_a1"], "a0", "a1"], + ] + + class BadMux(VirtualMux[BadMuxDef]): + pass + + with pytest.raises(AssertionError): + mux = BadMux() + + +def test_annotated_get_origin(): + # Annotated behaviour is different between python versions + # fails 3.8, passes >=3.9 + assert get_origin(Annotated[Literal["sig_a1"], "a0", "a1"]) == Annotated + + +def test_annotated_get_args(): + assert get_args(Annotated[Literal["sig_a1"], "a0", "a1"]) == ( + Literal["sig_a1"], + "a0", + "a1", + ) + + +@pytest.mark.skip( + reason="Revisit this idea once we have a way to stop Generic breaking getattr" +) +def test_typed_mux_class_getitem(): + clear = PinSetState(off=frozenset({"a0", "a1"})) + a1 = PinSetState(on=frozenset({"a0", "a1"})) + a2 = PinSetState(on=frozenset({"a1"}), off=frozenset({"a0"})) + + updates_class_mux = [] + updates_mux_a = [] + + class_mux = VirtualMux[MuxASigDef](lambda x, y: updates_class_mux.append((x, y))) + mux_a = MuxA(lambda x, y: updates_mux_a.append((x, y))) + assert mux_a._signal_map == class_mux._signal_map + assert mux_a._pin_set == class_mux._pin_set + + class_mux("sig_a1") + mux_a("sig_a1") + assert ( + updates_class_mux.pop() + == updates_mux_a.pop() + == (PinUpdate(PinSetState(), a1), True) + ) + + class_mux.multiplex("sig_a2", trigger_update=False) + mux_a.multiplex("sig_a2", trigger_update=False) + assert ( + updates_class_mux.pop() + == updates_mux_a.pop() + == (PinUpdate(PinSetState(), a2), False) + ) + + class_mux("") + mux_a("") + assert ( + updates_class_mux.pop() + == updates_mux_a.pop() + == (PinUpdate(PinSetState(), clear), True) + )