diff --git a/qupulse/program/__init__.py b/qupulse/program/__init__.py index b5040676..a1b52b7c 100644 --- a/qupulse/program/__init__.py +++ b/qupulse/program/__init__.py @@ -1,16 +1,17 @@ import contextlib from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, Union, Sequence, ContextManager, Mapping, Tuple, Generic, TypeVar, Iterable -from numbers import Real +from typing import Optional, Union, Sequence, ContextManager, Mapping, Tuple, Generic, TypeVar, Iterable, Dict +from numbers import Real, Number import numpy as np from qupulse._program.waveforms import Waveform -from qupulse.utils.types import MeasurementWindow, TimeType +from qupulse.utils.types import MeasurementWindow, TimeType, FrozenMapping from qupulse._program.volatile import VolatileRepetitionCount from qupulse.parameter_scope import Scope -from qupulse.expressions import sympy as sym_expr +from qupulse.expressions import sympy as sym_expr, Expression +from qupulse.utils.sympy import _lambdify_modules from typing import Protocol, runtime_checkable @@ -30,7 +31,7 @@ class SimpleExpression(Generic[NumVal]): """ base: NumVal - offsets: Sequence[Tuple[str, NumVal]] + offsets: Dict[str, NumVal] def value(self, scope: Mapping[str, NumVal]) -> NumVal: value = self.base @@ -43,7 +44,10 @@ def __add__(self, other): return SimpleExpression(self.base + other, self.offsets) if type(other) == type(self): - return SimpleExpression(self.base + other.base, self.offsets + other.offsets) + offsets = self.offsets.copy() + for name, value in other.offsets.items(): + offsets[name] = value + offsets.get(name, 0) + return SimpleExpression(self.base + other.base, offsets) return NotImplemented @@ -57,22 +61,40 @@ def __rsub__(self, other): (-self).__add__(other) def __neg__(self): - return SimpleExpression(-self.base, tuple((name, -value) for name, value in self.offsets)) + return SimpleExpression(-self.base, {name: -value for name, value in self.offsets.items()}) def __mul__(self, other: NumVal): - if isinstance(other, SimpleExpression): - return NotImplemented - return SimpleExpression(self.base * other, tuple((name, value * other) for name, value in self.offsets)) + if isinstance(other, (float, int, TimeType)): + return SimpleExpression(self.base * other, {name: other * value for name, value in self.offsets.items()}) + + return NotImplemented def __rmul__(self, other): return self.__mul__(other) - def evaluate_in_scope(self, *args, **kwargs): + def __truediv__(self, other): + inv = 1 / other + return self.__mul__(inv) + + @property + def free_symbols(self): + return () + + def _sympy_(self): + return self + + def replace(self, r, s): + return self + + def evaluate_in_scope_(self, *args, **kwargs): # TODO: remove. It is currently required to avoid nesting this class in an expression for the MappedScope # We can maybe replace is with a HardwareScope or something along those lines return self +_lambdify_modules.append({'SimpleExpression': SimpleExpression}) + + RepetitionCount = Union[int, VolatileRepetitionCount, SimpleExpression[int]] HardwareTime = Union[TimeType, SimpleExpression[TimeType]] HardwareVoltage = Union[float, SimpleExpression[float]] diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index 7dc3ddc6..d224648b 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -124,7 +124,7 @@ def inner_scope(self, scope: Scope) -> Scope: process.""" if self._ranges: name, _ = self._ranges[-1] - return MappedScope(scope, FrozenDict({name: SimpleExpression(base=0, offsets=[(name, 1)])})) + return scope.overwrite({name: SimpleExpression(base=0, offsets=[(name, 1)])}) else: return scope