All checks were successful
Build AppImage / Explore-Gitea-Actions (push) Successful in 1m33s
closes #209 Co-authored-by: Dominik Demuth <dominik.demuth@physik.tu-darmstadt.de> Reviewed-on: #222
226 lines
6.6 KiB
Python
226 lines
6.6 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Callable, Any
|
|
import operator
|
|
|
|
from inspect import signature, Parameter
|
|
|
|
|
|
class ModelFactory:
    """Build a (possibly nested) ``MultiModel`` from a list of function entries."""

    @staticmethod
    def create_from_list(
            funcs: list,
            left=None,
            func_order: list[int] | None = None,
            param_len: list[int] | None = None,
            left_cnt: int = 0,
    ):
        """Combine the active entries of ``funcs`` left-to-right into one model.

        Each entry is a dict; this method reads the keys ``'active'``,
        ``'children'``, ``'func'``, ``'cnt'``, ``'op'`` and ``'pos'``.
        ``'func'`` is the model object (it must expose ``params``), ``'cnt'``
        is the index used to label its parameters and ``'op'`` is the
        operator joining it to everything accumulated before it (passed on to
        ``MultiModel``). An entry with children is first combined with its
        children (the entry first, then each child joined by the child's own
        ``'op'``), and that group is then joined by the entry's ``'op'``.

        Args:
            funcs: function entries, processed in order; inactive ones are skipped.
            left: model to start from; ``None`` starts with the first active entry.
            func_order: list to append leaf ``'cnt'`` values to (new list if None).
            param_len: list to append leaf parameter counts to (new list if None).
            left_cnt: index associated with ``left``.

        Returns:
            ``(model, func_order, param_len, left_cnt)``. ``model`` is ``None``
            when there is no active entry. ``func_order``/``param_len`` list
            the ``'cnt'`` and ``len(params)`` of every leaf processed here, in
            the order their parameters appear in the combined model.
        """
        if func_order is None:
            func_order = []

        if param_len is None:
            param_len = []

        for func in funcs:
            if not func['active']:
                continue

            if func['children']:
                # Combine the entry with its children into one sub-model. A
                # childless copy of the entry goes first, so the recursive call
                # treats it as a leaf and already records it in
                # func_order/param_len -- it must not be recorded again here.
                parent = func.copy()
                parent['children'] = []
                right, _, _, right_cnt = ModelFactory.create_from_list(
                    [parent] + func['children'],
                    left_cnt=func['pos'],
                    func_order=func_order,
                    param_len=param_len,
                )
            else:
                right = func['func']
                right_cnt = func['cnt']
                # Only leaves are recorded (see the comment in the branch above).
                func_order.append(func['cnt'])
                param_len.append(len(func['func'].params))

            if left is None:
                left = right
                left_cnt = right_cnt
            else:
                left = MultiModel(left, right, func['op'], left_idx=left_cnt, right_idx=right_cnt)

        return left, func_order, param_len, left_cnt
|
|
|
|
|
|
class MultiModel:
    """Binary combination ``left <op> right`` of two fit models.

    ``left`` and ``right`` are either other ``MultiModel`` instances or leaf
    models exposing ``name``, ``params`` and a ``func(x, *params, **kwargs)``
    callable (and optionally ``bounds``). The combined model exposes the same
    attributes. Parameter names of leaf models get the suffix ``(idx)`` and
    their default-valued keyword arguments the suffix ``_idx`` (except
    ``complex_mode``), so the same leaf model can appear more than once.
    """

    # Display string per supported operator, used to build ``name``.
    op_repr = {
        operator.add: ' + ',
        operator.mul: ' * ',
        operator.sub: ' - ',
        operator.truediv: ' / ',
    }

    # Operator lookup when ``op`` is given as a symbol.
    str_op = {
        '+': operator.add,
        '*': operator.mul,
        '-': operator.sub,
        '/': operator.truediv,
    }

    # Operator lookup when ``op`` is given as an integer code.
    int_op = {
        0: operator.add,
        1: operator.mul,
        2: operator.sub,
        3: operator.truediv,
    }

    def __init__(self,
                 left: Any,
                 right: Any,
                 op: str | Callable | int = '+',
                 left_idx: int | None = 0,
                 right_idx: int | None = 1,
                 ):
        """Create the combined model.

        Args:
            left: left operand (``MultiModel`` or leaf model).
            right: right operand (``MultiModel`` or leaf model).
            op: ``'+'``, ``'*'``, ``'-'``, ``'/'``; or ``0``-``3`` for the same
                operators in that order; or any callable taking the left and
                right results.
            left_idx: suffix index for a leaf ``left``.
            right_idx: suffix index for a leaf ``right``; ``None`` means
                ``left_idx + 1`` (which then requires ``left_idx`` to be an int).

        Raises:
            ValueError: if ``op`` is not a known symbol/code or a callable.
        """
        self._left = left
        self._right = right

        self._op = None

        if isinstance(op, str):
            self._op = MultiModel.str_op.get(op, None)
        elif isinstance(op, int):
            self._op = MultiModel.int_op.get(op, None)
        elif callable(op):
            self._op = op

        if self._op is None:
            raise ValueError('Invalid binary operator.')

        if right_idx is None:
            right_idx = left_idx + 1

        self.name = '('
        self.params = []
        self.bounds = []
        self._kwargs_right = {}
        self._kwargs_left = {}
        self.fun_kwargs = {}
        self.idx = (left_idx, right_idx)

        # mapping of external kwarg names to the names the sub-model expects
        self._ext_int_kw = {}

        self._get_parameter(left, 'l', left_idx)
        # positional parameters [0, _param_left) belong to the left operand
        self._param_left = len(left.params)

        try:
            self.name += MultiModel.op_repr[self._op]
        except KeyError:
            # custom callables have no short symbol
            self.name += str(op)

        self._get_parameter(right, 'r', right_idx)
        self.name += ')'

        # positional parameters [_param_left, _param_len) belong to the right operand
        self._param_len = len(self.params)

    def __str__(self):
        return self.name

    def _get_parameter(self, func, pos, idx):
        """Append parameters, name, bounds and keyword defaults of one operand.

        ``pos`` (``'l'`` or ``'r'``) selects which side's keyword store is
        filled. A nested ``MultiModel`` already uses external names, which are
        taken over unchanged. For a leaf, parameter names get ``(idx)`` and
        every keyword argument with a default (read from the signature of
        ``func.func``) gets ``_idx``, except ``complex_mode``.
        """
        kw_dict = {'l': self._kwargs_left, 'r': self._kwargs_right}[pos]

        if isinstance(func, MultiModel):
            strcnt = ''
            kw_dict.update(func.fun_kwargs)
            self.fun_kwargs.update(kw_dict)
            # external names of a nested model map onto themselves
            self._ext_int_kw.update({k: k for k in kw_dict})
        else:
            defaults = {k: v.default for k, v in signature(func.func).parameters.items()
                        if v.default is not Parameter.empty}

            for k, v in defaults.items():
                # complex_mode keeps its name, so one external 'complex_mode'
                # value is forwarded to every sub-model accepting it
                key_ = k if k == 'complex_mode' else f'{k}_{idx}'
                kw_dict[key_] = v
                self.fun_kwargs[key_] = v
                self._ext_int_kw[key_] = k

            strcnt = f'({idx})'

        self.params += [pp + strcnt for pp in func.params]
        self.name += func.name + strcnt

        try:
            self.bounds.extend(func.bounds)
        except AttributeError:
            # operands without bounds get (None, None) for every parameter
            self.bounds.extend([(None, None)] * len(func.params))

    def _forward_kwargs(self, side_kwargs: dict, kwargs: dict) -> dict:
        """Translate supplied external keywords to one operand's internal names.

        Only keywords owned by that operand (keys of ``side_kwargs``) and
        actually present in ``kwargs`` are forwarded; omitted ones are left
        out so the operand falls back to its own default instead of this
        method raising ``KeyError``. Every key is mapped through
        ``_ext_int_kw``, which also covers ``complex_mode`` (mapped to itself),
        so keywords that merely start with ``complex_mode`` are routed to
        their own name.
        """
        return {self._ext_int_kw[k_ext]: kwargs[k_ext]
                for k_ext in side_kwargs if k_ext in kwargs}

    def _left_arguments(self, *args, **kwargs):
        """Return ``(positional, keyword)`` arguments for the left operand."""
        return args[:self._param_left], self._forward_kwargs(self._kwargs_left, kwargs)

    def _right_arguments(self, *args, **kwargs):
        """Return ``(positional, keyword)`` arguments for the right operand."""
        return (args[self._param_left:self._param_len],
                self._forward_kwargs(self._kwargs_right, kwargs))

    def func(self, x, *args, **kwargs):
        """Evaluate ``op(left(x, ...), right(x, ...))`` with the combined parameters."""
        pl, kw_left = self._left_arguments(*args, **kwargs)
        l_func = self._left.func(x, *pl, **kw_left)

        pr, kw_right = self._right_arguments(*args, **kwargs)
        r_func = self._right.func(x, *pr, **kw_right)

        return self._op(l_func, r_func)

    def left_func(self, x, *args, **kwargs):
        """Evaluate only the left operand with its own arguments."""
        return self._left.func(x, *args, **kwargs)

    def right_func(self, x, *args, **kwargs):
        """Evaluate only the right operand with its own arguments."""
        return self._right.func(x, *args, **kwargs)

    def subs(self, x, *args, **kwargs):
        """ Iterator over all sub-functions (depth-first and left-to-right) """
        pl, kw_left = self._left_arguments(*args, **kwargs)
        if isinstance(self._left, MultiModel):
            yield from self._left.subs(x, *pl, **kw_left)
        else:
            yield self._left.func(x, *pl, **kw_left)

        pr, kw_right = self._right_arguments(*args, **kwargs)
        if isinstance(self._right, MultiModel):
            yield from self._right.subs(x, *pr, **kw_right)
        else:
            yield self._right.func(x, *pr, **kw_right)

    def sub_name(self):
        """Yield the labelled names of all leaf operands, in the order of ``subs``."""
        if isinstance(self._left, MultiModel):
            yield from self._left.sub_name()
        elif hasattr(self._left, 'name'):
            yield f'{self._left.name}({self.idx[0]})'
        else:
            yield self.name + '(lhs)'

        if isinstance(self._right, MultiModel):
            yield from self._right.sub_name()
        elif hasattr(self._right, 'name'):
            yield f'{self._right.name}({self.idx[1]})'
        else:
            yield self.name + '(rhs)'