Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ classifiers = [

dependencies = [
"aio-pika ~= 9.4, < 9.5",
"omotes-sdk-protocol ~= 1.1",
"omotes-sdk-protocol ~= 1.2",
"pyesdl ~= 24.2",
"pamqp ~= 3.3",
"celery ~= 5.3",
Expand Down Expand Up @@ -75,7 +75,7 @@ enabled = true
starting_version = "0.0.1"

[tool.pytest.ini_options]
addopts = "--cov=omotes_sdk --cov-report html --cov-report term-missing --cov-fail-under 62"
addopts = "--cov=omotes_sdk --cov-report html --cov-report term-missing --cov-fail-under 60"

[tool.coverage.run]
source = ["src"]
Expand Down
181 changes: 168 additions & 13 deletions src/omotes_sdk/workflow_type.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import logging
import pprint
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import List, Optional, Dict, Union, Any, Type, TypeVar, cast
from typing import List, Optional, Dict, Union, Any, Type, TypeVar, cast, Literal
from typing_extensions import Self, override

from omotes_sdk_protocol.workflow_pb2 import (
Expand Down Expand Up @@ -49,6 +50,10 @@ class WorkflowParameter(ABC):
"""Optional description (displayed below the input field)."""
type_name: str = ""
"""Parameter type name, set in child class."""
constraints: List[WorkflowParameterPb.Constraint] = field(
default_factory=list, hash=False, compare=False
)
"""Optional list of non-ESDL workflow parameters."""

@staticmethod
@abstractmethod
Expand Down Expand Up @@ -130,6 +135,54 @@ def to_pb_value(value: ParamsDictValues) -> PBStructCompatibleTypes:
"""
... # pragma: no cover

def check_parameter_constraint(
self,
value1: ParamsDictValues,
value2: ParamsDictValues,
check: WorkflowParameterPb.Constraint,
) -> Literal[True]:
"""Check if the values adhere to the parameter constraint.

:param value1: The left-hand value to be checked.
:param value2: The right-hand value to the checked.
:param check: The parameter constraint to check between `value1` and `value2`
:return: Always true if the function returns noting the parameter constraint is adhered to.
:raises RuntimeError: In case the parameter constraint is not adhered to.
"""
supported_types = (float, int, datetime, timedelta)
if not isinstance(value1, supported_types) or not isinstance(value2, supported_types):
raise RuntimeError(
f"Values {value1}, {value2} are of a type that are not supported "
f"by parameter constraint {check}"
)

same_type_required = (datetime, timedelta)
if (
isinstance(value1, same_type_required) or isinstance(value2, same_type_required)
) and type(value1) is not type(value2):
raise RuntimeError(
f"Values {value1}, {value2} are required to be of the same type to be"
f"supported by parameter constraint {check}"
)

if check.relation == WorkflowParameterPb.Constraint.RelationType.GREATER:
result = value1 > value2 # type: ignore[operator]
elif check.relation == WorkflowParameterPb.Constraint.RelationType.GREATER_OR_EQ:
result = value1 >= value2 # type: ignore[operator]
elif check.relation == WorkflowParameterPb.Constraint.RelationType.SMALLER:
result = value1 < value2 # type: ignore[operator]
elif check.relation == WorkflowParameterPb.Constraint.RelationType.SMALLER_OR_EQ:
result = value1 <= value2 # type: ignore[operator]
else:
raise RuntimeError("Unknown parameter constraint. Please implement.")

if not result:
raise RuntimeError(
f"Check failed for constraint {check.relation} with "
f"{self.key_name}: {value1} and {check.other_key_name}: {value2}"
)
return result


@dataclass(eq=True, frozen=True)
class StringEnumOption:
Expand Down Expand Up @@ -196,6 +249,7 @@ def from_pb_message(
description=parameter_pb.description,
default=parameter_type_pb.default,
enum_options=[],
constraints=list(parameter_pb.constraints),
)
for enum_option_pb in parameter_type_pb.enum_options:
if parameter_type_pb.enum_options and parameter.enum_options is not None:
Expand All @@ -221,6 +275,16 @@ def from_json_config(cls, json_config: Dict) -> Self:
if "enum_options" in json_config and not isinstance(json_config["enum_options"], List):
raise TypeError("'enum_options' for StringParameter must be a 'list'")

if "constraints" in json_config:
if not isinstance(json_config["constraints"], list):
raise TypeError("'constraints' for StringParameter must be a 'list'")

parsed_constraints = [
convert_json_to_parameter_constraint(constraint)
for constraint in json_config["constraints"]
]
json_config["constraints"] = parsed_constraints

if "enum_options" in json_config:
enum_options = []
for enum_option in json_config["enum_options"]:
Expand Down Expand Up @@ -316,6 +380,7 @@ def from_pb_message(
title=parameter_pb.title,
description=parameter_pb.description,
default=parameter_type_pb.default,
constraints=list(parameter_pb.constraints),
)

@classmethod
Expand All @@ -331,6 +396,17 @@ def from_json_config(cls, json_config: Dict) -> Self:
f"'default' for BooleanParameter must be in 'bool' format:"
f" '{json_config['default']}'"
)

if "constraints" in json_config:
if not isinstance(json_config["constraints"], list):
raise TypeError("'constraints' for BooleanParameter must be a 'list'")

parsed_constraints = [
convert_json_to_parameter_constraint(constraint)
for constraint in json_config["constraints"]
]
json_config["constraints"] = parsed_constraints

return cls(**json_config)

@staticmethod
Expand Down Expand Up @@ -415,6 +491,7 @@ def from_pb_message(
maximum=(
parameter_type_pb.maximum if parameter_type_pb.HasField("maximum") else None
), # protobuf has '0' default value for int instead of None
constraints=list(parameter_pb.constraints),
)

@classmethod
Expand All @@ -432,6 +509,17 @@ def from_json_config(cls, json_config: Dict) -> Self:
f"'{int_param}' for IntegerParameter must be in 'int' format:"
f" '{json_config[int_param]}'"
)

if "constraints" in json_config:
if not isinstance(json_config["constraints"], list):
raise TypeError("'constraints' for IntegerParameter must be a 'list'")

parsed_constraints = [
convert_json_to_parameter_constraint(constraint)
for constraint in json_config["constraints"]
]
json_config["constraints"] = parsed_constraints

return cls(**json_config)

@staticmethod
Expand Down Expand Up @@ -526,6 +614,7 @@ def from_pb_message(
maximum=(
parameter_type_pb.maximum if parameter_type_pb.HasField("maximum") else None
), # protobuf has '0' default value for int instead of None
constraints=list(parameter_pb.constraints),
)

@classmethod
Expand All @@ -548,6 +637,16 @@ def from_json_config(cls, json_config: Dict) -> Self:
f" '{json_config[float_param]}'"
)

if "constraints" in json_config:
if not isinstance(json_config["constraints"], list):
raise TypeError("'constraints' for FloatParameter must be a 'list'")

parsed_constraints = [
convert_json_to_parameter_constraint(constraint)
for constraint in json_config["constraints"]
]
json_config["constraints"] = parsed_constraints

return cls(**json_config)

@staticmethod
Expand Down Expand Up @@ -636,6 +735,7 @@ def from_pb_message(
title=parameter_pb.title,
description=parameter_pb.description,
default=default,
constraints=list(parameter_pb.constraints),
)

@classmethod
Expand All @@ -656,6 +756,16 @@ def from_json_config(cls, json_config: Dict) -> Self:
)
json_config["default"] = default

if "constraints" in json_config:
if not isinstance(json_config["constraints"], list):
raise TypeError("'constraints' for DateTimeParameter must be a 'list'")

parsed_constraints = [
convert_json_to_parameter_constraint(constraint)
for constraint in json_config["constraints"]
]
json_config["constraints"] = parsed_constraints

return cls(**json_config)

@staticmethod
Expand Down Expand Up @@ -752,6 +862,7 @@ def from_pb_message(
if parameter_type_pb.HasField("maximum")
else None
),
constraints=list(parameter_pb.constraints),
)

@classmethod
Expand Down Expand Up @@ -779,6 +890,16 @@ def from_json_config(cls, json_config: Dict) -> Self:
elif duration_param in json_config:
args[duration_param] = timedelta(seconds=json_config[duration_param])

if "constraints" in json_config:
if not isinstance(json_config["constraints"], list):
raise TypeError("'constraints' for StringParameter must be a 'list'")

parsed_constraints = [
convert_json_to_parameter_constraint(constraint)
for constraint in json_config["constraints"]
]
args["constraints"] = parsed_constraints

return cls(**args)

@staticmethod
Expand Down Expand Up @@ -843,6 +964,33 @@ def to_pb_value(value: ParamsDictValues) -> float:
}


def convert_str_to_parameter_relation(
parameter_constraint_name: str,
) -> WorkflowParameterPb.Constraint.RelationType.ValueType:
"""Translate the name of a parameter constraint to the relevant enum.

:param parameter_constraint_name: String name of the parameter constraint.
:return: The parameter constraint as an enum value of `Constraint.RelationType`
:raises RuntimeError: In case the parameter constraint name is unknown.
"""
return WorkflowParameterPb.Constraint.RelationType.Value(parameter_constraint_name.upper())


def convert_json_to_parameter_constraint(
parameter_constraint_json: dict,
) -> WorkflowParameterPb.Constraint:
"""Convert a json document containing a parameter constraint definition to a `Constraint`.

:param parameter_constraint_json: The json document which contains the parameter constraint
definition.
:return: The converted parameter constraint definition.
"""
return WorkflowParameterPb.Constraint(
other_key_name=parameter_constraint_json["other_key_name"],
relation=convert_str_to_parameter_relation(parameter_constraint_json["relation"]),
)


@dataclass(eq=True, frozen=True)
class WorkflowType:
"""Define a type of workflow this SDK supports."""
Expand All @@ -854,7 +1002,6 @@ class WorkflowType:
workflow_parameters: Optional[List[WorkflowParameter]] = field(
default=None, hash=False, compare=False
)
"""Optional list of non-ESDL workflow parameters."""


class WorkflowTypeManager:
Expand Down Expand Up @@ -910,6 +1057,7 @@ def to_pb_message(self) -> AvailableWorkflows:
key_name=_parameter.key_name,
title=_parameter.title,
description=_parameter.description,
constraints=_parameter.constraints,
)
parameter_type_to_pb_type_oneof = {
StringParameter: parameter_pb.string_parameter,
Expand Down Expand Up @@ -938,6 +1086,7 @@ def from_pb_message(cls, available_workflows_pb: AvailableWorkflows) -> Self:
:return: WorkflowTypeManager instance.
"""
workflow_types = []
workflow_pb: Workflow
for workflow_pb in available_workflows_pb.workflows:
workflow_parameters: List[WorkflowParameter] = []
for parameter_pb in workflow_pb.parameters:
Expand All @@ -956,6 +1105,7 @@ def from_pb_message(cls, available_workflows_pb: AvailableWorkflows) -> Self:
workflow_parameters.append(parameter)
else:
raise RuntimeError(f"Unknown PB class {type(one_of_parameter_type_pb)}")

workflow_types.append(
WorkflowType(
workflow_type_name=workflow_pb.type_name,
Expand All @@ -974,20 +1124,20 @@ def from_json_config_file(cls, json_config_file_path: str) -> Self:
"""
with open(json_config_file_path, "r") as f:
json_config_dict = json.load(f)
logger.debug("Loading workflow config: %s", pprint.pformat(json_config_dict))
workflow_types = []
for _workflow in json_config_dict:
workflow_parameters = []
if "workflow_parameters" in _workflow:
for parameter_config in _workflow["workflow_parameters"]:
parameter_type_name = parameter_config["parameter_type"]
parameter_config.pop("parameter_type")

for parameter_type_class in PARAMETER_CLASS_TO_PB_CLASS:
if parameter_type_class.type_name == parameter_type_name:
workflow_parameters.append(
parameter_type_class.from_json_config(parameter_config)
)
break
for parameter_config in _workflow.get("workflow_parameters", []):
parameter_type_name = parameter_config["parameter_type"]
parameter_config.pop("parameter_type")

for parameter_type_class in PARAMETER_CLASS_TO_PB_CLASS:
if parameter_type_class.type_name == parameter_type_name:
workflow_parameters.append(
parameter_type_class.from_json_config(parameter_config)
)
break

workflow_types.append(
WorkflowType(
Expand Down Expand Up @@ -1023,6 +1173,11 @@ def convert_params_dict_to_struct(workflow: WorkflowType, params_dict: ParamsDic

normalized_dict[parameter.key_name] = parameter.to_pb_value(param_value)

for constraint in parameter.constraints:
other_value = params_dict[constraint.other_key_name]

parameter.check_parameter_constraint(param_value, other_value, constraint)

params_dict_struct = Struct()
params_dict_struct.update(normalized_dict)

Expand Down
Loading