From 0ca727c253329958a4a8169895692cb3ba335e90 Mon Sep 17 00:00:00 2001 From: Eyal-Danieli Date: Thu, 1 Jan 2026 11:02:24 +0200 Subject: [PATCH] update function yamls --- cli/cli.py | 8 +- cli/common/generate_item_yaml.py | 19 +- cli/common/test_suite.py | 142 ++--- cli/common/update_readme.py | 27 +- cli/functions/function_to_item.py | 4 +- cli/functions/item_to_function.py | 59 +- cli/marketplace/build.py | 71 ++- cli/utils/helpers.py | 52 +- cli/utils/path_iterator.py | 6 +- functions/src/aggregate/aggregate.py | 142 +++-- functions/src/aggregate/function.yaml | 48 +- functions/src/aggregate/test_aggregate.py | 57 +- .../src/arc_to_parquet/arc_to_parquet.py | 76 +-- functions/src/arc_to_parquet/function.yaml | 35 +- .../src/arc_to_parquet/test_arc_to_parquet.py | 46 +- functions/src/auto_trainer/auto_trainer.py | 29 +- functions/src/auto_trainer/function.yaml | 48 +- .../src/auto_trainer/test_auto_trainer.py | 19 +- functions/src/azureml_serving/function.yaml | 60 +- functions/src/azureml_utils/azureml_utils.py | 50 +- functions/src/azureml_utils/function.yaml | 127 ++-- .../src/azureml_utils/test_azureml_utils.py | 10 +- .../src/batch_inference/batch_inference.py | 39 +- functions/src/batch_inference/function.yaml | 46 +- .../batch_inference/test_batch_inference.py | 3 +- .../batch_inference_v2/batch_inference_v2.py | 171 +++--- .../src/batch_inference_v2/function.yaml | 56 +- functions/src/batch_inference_v2/item.yaml | 2 +- .../test_batch_inference_v2.py | 128 ++-- functions/src/describe/describe.py | 17 +- functions/src/describe/function.yaml | 34 +- functions/src/describe/test_describe.py | 6 +- functions/src/describe_dask/describe_dask.py | 22 +- functions/src/describe_dask/function.yaml | 35 +- .../src/describe_dask/test_describe_dask.py | 40 +- .../src/describe_spark/describe_spark.py | 551 +++++++++++------- functions/src/describe_spark/function.yaml | 320 +++++----- .../feature_selection/feature_selection.py | 7 +- 
functions/src/feature_selection/function.yaml | 43 +- functions/src/feature_selection/item.yaml | 2 +- functions/src/gen_class_data/function.yaml | 33 +- .../src/gen_class_data/gen_class_data.py | 30 +- .../src/gen_class_data/test_gen_class_data.py | 14 +- functions/src/github_utils/function.yaml | 54 +- functions/src/github_utils/github_utils.py | 12 +- .../src/hugging_face_serving/function.yaml | 27 +- .../hugging_face_serving.py | 6 +- .../test_hugging_face_serving.py | 6 +- functions/src/load_dataset/function.yaml | 64 +- functions/src/mlflow_utils/function.yaml | 37 +- functions/src/mlflow_utils/mlflow_utils.py | 7 +- .../src/mlflow_utils/test_mlflow_utils.py | 17 +- functions/src/model_server/function.yaml | 33 +- functions/src/model_server/model_server.py | 11 +- .../src/model_server/test_model_server.py | 38 +- .../src/model_server_tester/function.yaml | 40 +- .../model_server_tester.py | 10 +- functions/src/noise_reduction/function.yaml | 135 ++--- .../src/noise_reduction/noise_reduction.py | 28 +- functions/src/onnx_utils/function.yaml | 69 +-- functions/src/onnx_utils/onnx_utils.py | 16 +- functions/src/open_archive/function.yaml | 31 +- functions/src/open_archive/item.yaml | 2 +- functions/src/open_archive/open_archive.py | 100 ++-- .../src/open_archive/test_open_archive.py | 54 +- functions/src/pii_recognizer/function.yaml | 73 ++- .../src/pii_recognizer/pii_recognizer.py | 48 +- .../src/pii_recognizer/test_pii_recognizer.py | 10 +- functions/src/pyannote_audio/function.yaml | 64 +- .../src/pyannote_audio/pyannote_audio.py | 18 +- .../src/question_answering/function.yaml | 142 +++-- .../question_answering/question_answering.py | 69 +-- .../test_question_answering.py | 16 +- functions/src/send_email/function.yaml | 49 +- functions/src/send_email/send_email.py | 10 +- functions/src/silero_vad/function.yaml | 147 +++-- functions/src/silero_vad/silero_vad.py | 49 +- .../src/sklearn_classifier/function.yaml | 43 +- 
.../sklearn_classifier/sklearn_classifier.py | 11 +- .../test_sklearn_classifier.py | 61 +- .../src/sklearn_classifier_dask/function.yaml | 45 +- .../sklearn_classifier_dask.py | 19 +- .../structured_data_generator/function.yaml | 41 +- .../structured_data_generator.py | 6 +- .../test_structured_data_generator.py | 11 +- functions/src/test_classifier/function.yaml | 60 +- .../src/test_classifier/test_classifier.py | 9 +- .../src/text_to_audio_generator/function.yaml | 58 +- .../test_text_to_audio_generator.py | 2 +- .../text_to_audio_generator.py | 43 +- functions/src/tf2_serving/function.yaml | 66 +-- functions/src/tf2_serving/tf2_serving.py | 19 +- functions/src/transcribe/test_transcribe.py | 8 +- functions/src/transcribe/transcribe.py | 170 +++--- functions/src/translate/function.yaml | 67 ++- functions/src/translate/item.yaml | 2 +- functions/src/translate/test_translate.py | 5 +- functions/src/translate/translate.py | 16 +- functions/src/v2_model_server/function.yaml | 96 +-- .../src/v2_model_server/v2_model_server.py | 11 +- functions/src/v2_model_tester/function.yaml | 40 +- .../src/v2_model_tester/v2_model_tester.py | 9 +- modules/src/agent_deployer/agent_deployer.py | 17 +- .../src/agent_deployer/test_agent_deployer.py | 26 +- modules/src/count_events/count_events.py | 13 +- modules/src/count_events/item.yaml | 2 +- modules/src/count_events/test_count_events.py | 15 +- modules/src/evidently_iris/evidently_iris.py | 25 +- modules/src/evidently_iris/item.yaml | 2 +- .../src/evidently_iris/test_evidently_iris.py | 4 +- .../histogram_data_drift.py | 23 +- modules/src/histogram_data_drift/item.yaml | 2 +- .../test_histogram_data_drift.py | 47 +- .../src/openai_proxy_app/openai_proxy_app.py | 31 +- .../openai_proxy_app/test_openai_proxy_app.py | 10 +- modules/src/vllm_module/test_vllm_module.py | 7 +- modules/src/vllm_module/vllm_module.py | 63 +- pyproject.toml | 8 +- steps/src/verify_schema/test_verify_schema.py | 23 +- 
steps/src/verify_schema/verify_schema.py | 9 +- 120 files changed, 2755 insertions(+), 2716 deletions(-) diff --git a/cli/cli.py b/cli/cli.py index e8e6922fe..8d31ad38f 100644 --- a/cli/cli.py +++ b/cli/cli.py @@ -14,17 +14,19 @@ # import click +from cli.common.generate_item_yaml import generate_item_yaml +from cli.common.test_suite import test_suite +from cli.common.update_readme import update_readme from cli.functions.function_to_item import function_to_item_cli from cli.functions.item_to_function import item_to_function_cli from cli.marketplace.build import build_marketplace_cli -from cli.common.test_suite import test_suite -from cli.common.update_readme import update_readme -from cli.common.generate_item_yaml import generate_item_yaml + @click.group() def cli(): pass + cli.add_command(generate_item_yaml, name="generate-item-yaml") cli.add_command(item_to_function_cli, name="item-to-function") cli.add_command(function_to_item_cli, name="function-to-item") diff --git a/cli/common/generate_item_yaml.py b/cli/common/generate_item_yaml.py index e97089ad3..093d19fac 100644 --- a/cli/common/generate_item_yaml.py +++ b/cli/common/generate_item_yaml.py @@ -1,6 +1,7 @@ import sys -from pathlib import Path from datetime import datetime, timezone +from pathlib import Path + import click from jinja2 import Environment, FileSystemLoader @@ -14,14 +15,18 @@ @click.command() @click.argument("type", type=click.Choice(list(TEMPLATES.keys()))) @click.argument("name") -@click.option("--overwrite", is_flag=True, help="Replace existing file instead of raising an error.") +@click.option( + "--overwrite", + is_flag=True, + help="Replace existing file instead of raising an error.", +) def generate_item_yaml(type: str, name: str, overwrite: bool = False): """ - Generate an item.yaml file from a template. + Generate an item.yaml file from a template. 
-type: one of the supported types (currently only `function` or `module`) -name: the function/module name (also used as the directory name) -overwrite: whether to overwrite existing item.yaml file + type: one of the supported types (currently only `function` or `module`) + name: the function/module name (also used as the directory name) + overwrite: whether to overwrite existing item.yaml file """ # Construct the target path path = Path(f"{type}s/src/{name}").resolve() @@ -53,4 +58,4 @@ def generate_item_yaml(type: str, name: str, overwrite: bool = False): if __name__ == "__main__": - generate_item_yaml() \ No newline at end of file + generate_item_yaml() diff --git a/cli/common/test_suite.py b/cli/common/test_suite.py index 52dc1c5ae..9e1e7b983 100644 --- a/cli/common/test_suite.py +++ b/cli/common/test_suite.py @@ -12,23 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import re import subprocess +import sys from abc import ABC, abstractmethod from dataclasses import dataclass, field from pathlib import Path from subprocess import CompletedProcess -from typing import List, Union, Optional -import sys + import click import yaml -import re from cli.utils.helpers import ( - is_item_dir, + get_item_yaml_values, install_pipenv, install_python, install_requirements, - get_item_yaml_values, + is_item_dir, ) from cli.utils.path_iterator import PathIterator @@ -45,11 +45,13 @@ default=False, help="When true, test suite will stop running after the first test ran", ) -def test_suite(root_directory: str, - suite: str, - stop_on_failure: bool, - multi_processing: bool = False, - function_name: str = None): +def test_suite( + root_directory: str, + suite: str, + stop_on_failure: bool, + multi_processing: bool = False, + function_name: str = None, +): if not suite: click.echo("-s/--suite is required") exit(1) @@ -101,25 +103,19 @@ def test_example(root_dir="."): @dataclass class TestResult: status: str - 
status_code: Optional[int] + status_code: int | None meta_data: dict = field(default_factory=dict) @classmethod - def passed( - cls, status_code: Optional[int] = None, meta_data: Optional[dict] = None - ): + def passed(cls, status_code: int | None = None, meta_data: dict | None = None): return cls(status="Passed", status_code=status_code, meta_data=meta_data) @classmethod - def failed( - cls, status_code: Optional[int] = None, meta_data: Optional[dict] = None - ): + def failed(cls, status_code: int | None = None, meta_data: dict | None = None): return cls(status="Failed", status_code=status_code, meta_data=meta_data) @classmethod - def ignored( - cls, status_code: Optional[int] = None, meta_data: Optional[dict] = None - ): + def ignored(cls, status_code: int | None = None, meta_data: dict | None = None): return cls(status="Ignored", status_code=status_code, meta_data=meta_data) @@ -129,11 +125,11 @@ def __init__(self, stop_on_failure: bool = True): self.test_results = [] @abstractmethod - def discover(self, path: Union[str, Path]) -> List[str]: + def discover(self, path: str | Path) -> list[str]: pass @abstractmethod - def run(self, path: Union[str, Path]) -> TestResult: + def run(self, path: str | Path) -> TestResult: pass @abstractmethod @@ -145,27 +141,32 @@ def after_run(self): pass @abstractmethod - def before_each(self, path: Union[str, Path]): + def before_each(self, path: str | Path): pass @abstractmethod - def after_each(self, path: Union[str, Path], test_result: TestResult): + def after_each(self, path: str | Path, test_result: TestResult): pass - def _run(self, path: Union[str, Path], multiprocess, function_name): + def _run(self, path: str | Path, multiprocess, function_name): import multiprocessing as mp + process_count = 1 if multiprocess: process_count = mp.cpu_count() - 1 - print("running tests with {} process".format(process_count)) + print(f"running tests with {process_count} process") discovered_functions = self.discover(path) if function_name is 
not None: - click.echo("running test with name {}".format(function_name)) - discovered_functions = [fn for fn in discovered_functions if Path(function_name).stem == Path(fn).stem] + click.echo(f"running test with name {function_name}") + discovered_functions = [ + fn + for fn in discovered_functions + if Path(function_name).stem == Path(fn).stem + ] for path in discovered_functions: if re.match(".+/test_*", path): discovered_functions.remove(path) - print("a function name cannot start with test, please rename {} ".format(path)) + print(f"a function name cannot start with test, please rename {path} ") self.before_run() @@ -191,7 +192,7 @@ def __init__(self, stop_on_failure: bool = True, clean_env_artifacts: bool = Tru self.clean_env_artifacts = clean_env_artifacts self.results = [] - def discover(self, path: Union[str, Path]) -> List[str]: + def discover(self, path: str | Path) -> list[str]: path = Path(path) testable = [] item_yaml_path = path / "item.yaml" @@ -228,15 +229,21 @@ def discover(self, path: Union[str, Path]) -> List[str]: def before_run(self): install_pipenv() - def before_each(self, path: Union[str, Path]): + def before_each(self, path: str | Path): pass - def run(self, path: Union[str, Path]): - print("PY run path {}".format(path)) + def run(self, path: str | Path): + print(f"PY run path {path}") install_python(path) - item_requirements = list(get_item_yaml_values(path, 'requirements')['requirements']) - mlrun_version = list(get_item_yaml_values(path, "mlrunVersion")["mlrunVersion"])[0] - install_requirements(path, ["pytest", f"mlrun=={mlrun_version}"] + item_requirements) + item_requirements = list( + get_item_yaml_values(path, "requirements")["requirements"] + ) + mlrun_version = list( + get_item_yaml_values(path, "mlrunVersion")["mlrunVersion"] + )[0] + install_requirements( + path, ["pytest", f"mlrun=={mlrun_version}"] + item_requirements + ) click.echo(f"Running tests for {path}...") completed_process: CompletedProcess = subprocess.run( f"cd 
{path} ; pipenv run python -m pytest", @@ -256,7 +263,7 @@ def run(self, path: Union[str, Path]): meta_data=meta_data, ) - def after_each(self, path: Union[str, Path], test_result: TestResult): + def after_each(self, path: str | Path, test_result: TestResult): if self.clean_env_artifacts: clean_pipenv(path) @@ -314,11 +321,11 @@ def after_run(self): sys.exit(1) @staticmethod - def is_test_py(path: Union[str, Path]) -> bool: + def is_test_py(path: str | Path) -> bool: return ( - path.is_file() - and path.name.startswith("test_") - and path.name.endswith(".py") + path.is_file() + and path.name.startswith("test_") + and path.name.endswith(".py") ) @@ -328,7 +335,7 @@ def __init__(self, stop_on_failure: bool = True, clean_env_artifacts: bool = Tru self.clean_env_artifacts = clean_env_artifacts self.results = [] - def discover(self, path: Union[str, Path]) -> List[str]: + def discover(self, path: str | Path) -> list[str]: path = Path(path) testables = [] @@ -357,34 +364,34 @@ def discover(self, path: Union[str, Path]) -> List[str]: ) exit(0) testables.sort() - click.echo( - "tests list " + str(testables) - ) + click.echo("tests list " + str(testables)) return testables def before_run(self): install_pipenv() - def before_each(self, path: Union[str, Path]): + def before_each(self, path: str | Path): pass # def run(self, path: Union[str, Path]) -> TestResult: - def run(self, path: Union[str, Path]) -> TestResult: - print("IPYNB run path {}".format(path)) + def run(self, path: str | Path) -> TestResult: + print(f"IPYNB run path {path}") install_python(path) - item_requirements = list(get_item_yaml_values(path, 'requirements')['requirements']) + item_requirements = list( + get_item_yaml_values(path, "requirements")["requirements"] + ) install_requirements(path, ["papermill"] + item_requirements) click.echo(f"Running tests for {path}...") running_ipynb = Path(path).name + ".ipynb" click.echo(f"Running notebook {running_ipynb}") - command = f'pipenv run papermill 
{running_ipynb} out.ipynb --log-output' + command = f"pipenv run papermill {running_ipynb} out.ipynb --log-output" completed_process: CompletedProcess = subprocess.run( f"cd {path} ;echo {command} ; {command}", stdout=sys.stdout, stderr=subprocess.PIPE, cwd=path, - shell=True + shell=True, ) meta_data = {"completed_process": completed_process, "test_path": path} @@ -438,7 +445,7 @@ def after_run(self): if failed_tests: exit(1) - def after_each(self, path: Union[str, Path], test_result: TestResult): + def after_each(self, path: str | Path, test_result: TestResult): if self.clean_env_artifacts: clean_pipenv(path) @@ -454,22 +461,19 @@ def after_each(self, path: Union[str, Path], test_result: TestResult): click.echo(complete_subprocess.stderr.decode("utf-8")) exit(test_result.status_code) - def _run(self, path: Union[str, Path], multi_processing, function_name): + def _run(self, path: str | Path, multi_processing, function_name): super()._run(path, multi_processing, function_name) @staticmethod def is_test_ipynb(path: Path): - return ( - path.is_file() - and path.name.endswith(".ipynb") - ) + return path.is_file() and path.name.endswith(".ipynb") class TestItemYamls(TestSuite): def __init__(self, stop_on_failure: bool = True): super().__init__(stop_on_failure) - def discover(self, path: Union[str, Path]) -> List[str]: + def discover(self, path: str | Path) -> list[str]: path = Path(path) testables = [] @@ -493,9 +497,9 @@ def discover(self, path: Union[str, Path]) -> List[str]: return testables - def run(self, path: Union[str, Path]) -> TestResult: + def run(self, path: str | Path) -> TestResult: path = Path(path) - item = yaml.full_load(open(path, "r")) + item = yaml.full_load(open(path)) directory = path.parent if item.get("spec")["filename"]: @@ -572,10 +576,10 @@ def after_run(self): if failed_tests: exit(1) - def before_each(self, path: Union[str, Path]): + def before_each(self, path: str | Path): pass - def after_each(self, path: Union[str, Path], test_result: 
TestResult): + def after_each(self, path: str | Path, test_result: TestResult): if self.stop_on_failure: if test_result.status == "Failed": message = test_result.meta_data["message"] @@ -583,7 +587,7 @@ def after_each(self, path: Union[str, Path], test_result: TestResult): click.echo(f"Error: {message}") exit(1) - def _run(self, path: Union[str, Path]): + def _run(self, path: str | Path): super()._run(path) @@ -599,20 +603,24 @@ def clean_pipenv(directory: str): # load item yaml def load_item(path): - with open(path, 'r') as stream: + with open(path) as stream: data = yaml.load(stream=stream, Loader=yaml.FullLoader) return data def is_test_valid_by_item(item_posix_path): - full_path = str(item_posix_path.absolute())+'/item.yaml' + full_path = str(item_posix_path.absolute()) + "/item.yaml" data = load_item(full_path) if data.get("test_valid") is not None: test_valid = data.get("test_valid") test_name = data.get("name") if not test_valid: - click.echo("==================== Test {} Not valid ====================".format(test_name)) - click.echo("==================== enable test_valid in item.yaml ====================") + click.echo( + f"==================== Test {test_name} Not valid ====================" + ) + click.echo( + "==================== enable test_valid in item.yaml ====================" + ) return test_valid else: return True diff --git a/cli/common/update_readme.py b/cli/common/update_readme.py index f6e582bb6..7816ebaa5 100644 --- a/cli/common/update_readme.py +++ b/cli/common/update_readme.py @@ -14,8 +14,8 @@ import sys +from collections.abc import Iterable from pathlib import Path -from typing import Iterable, List, Tuple import click import yaml @@ -28,6 +28,7 @@ "steps": ("Name", "Description", "Class Name", "Categories"), } + @click.command("update-readme") @click.option("-c", "--channel", default="master", help="Name of build channel") @click.option( @@ -35,12 +36,14 @@ multiple=True, required=True, help="Asset types to process (e.g: functions). 
" - "Pass multiple: --asset functions --asset modules", + "Pass multiple: --asset functions --asset modules", +) +@click.option( + "--check", + is_flag=True, + help="Do not write; exit non‑zero if README(s) would change.", ) -@click.option("--check", is_flag=True, - help="Do not write; exit non‑zero if README(s) would change.") -def update_readme(channel: str, asset: Iterable[str], - check: bool) -> None: +def update_readme(channel: str, asset: Iterable[str], check: bool) -> None: """ Regenerate the README tables for asset types from their item.yaml files. """ @@ -102,7 +105,11 @@ def _rows_for_asset_type(channel: str, asset_dir: Path, columns) -> list: kind = (data.get("spec", {}).get("kind", "")).strip() class_name = (data.get("className", "")).strip() cats = data.get("categories") or [] - cats_str = ", ".join(c.strip() for c in cats) if isinstance(cats, list) else str(cats).strip() + cats_str = ( + ", ".join(c.strip() for c in cats) + if isinstance(cats, list) + else str(cats).strip() + ) # Link the name to its source directory # Construct the relative path from the repo root for the asset rel_path = asset_dir.relative_to(Path(".").resolve()) @@ -135,7 +142,11 @@ def _build_table_md(rows, columns) -> str: "| " + " | ".join("---" for _ in columns) + " |", ] for r in rows: - lines.append("| " + " | ".join((cell or "").replace("\n", " ").strip() for cell in r) + " |") + lines.append( + "| " + + " | ".join((cell or "").replace("\n", " ").strip() for cell in r) + + " |" + ) return "\n".join(lines) diff --git a/cli/functions/function_to_item.py b/cli/functions/function_to_item.py index c3c870d75..e31364961 100644 --- a/cli/functions/function_to_item.py +++ b/cli/functions/function_to_item.py @@ -14,7 +14,6 @@ # from datetime import datetime from pathlib import Path -from typing import Union import click import yaml @@ -70,8 +69,7 @@ def function_to_item(path: str): exit(0) -def function_yaml_to_item(function_path: Union[str, Path]) -> dict: - +def 
function_yaml_to_item(function_path: str | Path) -> dict: function_path = Path(function_path) function_yaml = yaml.full_load(open(function_path)) diff --git a/cli/functions/item_to_function.py b/cli/functions/item_to_function.py index be84c0dce..80d95cd00 100644 --- a/cli/functions/item_to_function.py +++ b/cli/functions/item_to_function.py @@ -13,12 +13,11 @@ # limitations under the License. # from pathlib import Path -from typing import Optional, Union import click import semver import yaml -from black import format_str, FileMode +from black import FileMode, format_str from mlrun import code_to_function from yaml import full_load @@ -55,17 +54,21 @@ help="If -b/--bump_version is enabled, increase the minor version in the item.yaml file", ) def item_to_function_cli( - item_path: str, output_path: Optional[str], code_output: bool, format_code: bool, bump_version: bool + item_path: str, + output_path: str | None, + code_output: bool, + format_code: bool, + bump_version: bool, ): item_to_function(item_path, output_path, code_output, format_code, bump_version) def item_to_function( - item_path: str, - output_path: Optional[str] = None, - code_output: bool = False, - format_code: bool = True, - bump_version: bool = False, + item_path: str, + output_path: str | None = None, + code_output: bool = False, + format_code: bool = True, + bump_version: bool = False, ): item_path = Path(item_path) if item_path.is_dir(): @@ -74,17 +77,21 @@ def item_to_function( # That means we are in a specific item directory if item_path.exists(): _output_path = output_path or item_path.parent / "function.yaml" - create_function_yaml(item_path, _output_path, code_output, format_code, bump_version) + create_function_yaml( + item_path, _output_path, code_output, format_code, bump_version + ) # That means we need to search for items inside this direcotry else: for inner_dir in PathIterator( - root=item_path.parent, - rule=is_item_dir, - as_path=True, + root=item_path.parent, + rule=is_item_dir, + 
as_path=True, ): try: _output_path = output_path or (inner_dir / "function.yaml") - create_function_yaml(inner_dir, _output_path, code_output, format_code, bump_version) + create_function_yaml( + inner_dir, _output_path, code_output, format_code, bump_version + ) except Exception as e: print(e) click.echo(f"{inner_dir.name}: Failed to generate function.yaml") @@ -114,16 +121,16 @@ def _get_item_yaml(item_path: Path) -> dict: elif not item_path.exists(): raise FileNotFoundError(f"{item_path} not found") - item_yaml = full_load(open(item_path, "r")) + item_yaml = full_load(open(item_path)) return item_path, item_yaml def create_function_yaml( - item_path: Union[str, Path], - output_path: Optional[str] = None, - code_output: bool = False, - format_code: bool = True, - bump_version: bool = False, + item_path: str | Path, + output_path: str | None = None, + code_output: bool = False, + format_code: bool = True, + bump_version: bool = False, ): item_path = Path(item_path) if bump_version: @@ -157,11 +164,15 @@ def create_function_yaml( labels=item_yaml.get("labels", {}), with_doc=True, ) + + # Store only the file name in the function spec for portability. 
+ function_object.spec.filename = Path(filename).name + function_object.metadata.project = "" # remove build info from object - function_object.spec.build.code_origin = '' - function_object.spec.build.origin_filename = '' - if 'state_thresholds' not in spec: + function_object.spec.build.code_origin = "" + function_object.spec.build.origin_filename = "" + if "state_thresholds" not in spec: function_object.spec.state_thresholds = None custom_fields = spec.get("customFields", {}) @@ -194,7 +205,7 @@ def create_function_yaml( function_object.export(target=str(output_path.resolve())) if code_output and format_code: - with open(_code_output, "r") as file: + with open(_code_output) as file: code = file.read() code = format_str(code, mode=FileMode()) with open(_code_output, "w") as file: @@ -206,5 +217,5 @@ def bump_function_yaml_version(item_path: Path): item_ver = item_yaml.get("version", "0.0.0") new_ver = semver.Version.parse(item_ver).bump_minor() item_yaml["version"] = str(new_ver) - with open(item_path, 'w') as file: + with open(item_path, "w") as file: yaml.safe_dump(item_yaml, file, default_flow_style=False) diff --git a/cli/marketplace/build.py b/cli/marketplace/build.py index 206886631..0d65dacce 100644 --- a/cli/marketplace/build.py +++ b/cli/marketplace/build.py @@ -18,7 +18,6 @@ import subprocess import uuid from pathlib import Path -from typing import Dict, List, Optional, Set, Union import click import yaml @@ -26,9 +25,14 @@ from sphinx.cmd.build import main as sphinx_build_cmd from sphinx.ext.apidoc import main as sphinx_apidoc_cmd -from cli.utils.helpers import (PROJECT_ROOT, get_item_yaml_values, - get_mock_requirements, is_item_dir, render_jinja) from cli.marketplace.changelog import ChangeLog +from cli.utils.helpers import ( + PROJECT_ROOT, + get_item_yaml_values, + get_mock_requirements, + is_item_dir, + render_jinja, +) from cli.utils.path_iterator import PathIterator _verbose = False @@ -192,7 +196,7 @@ def build_marketplace( 
write_change_log(marketplace_root / "README.md", change_log) -def print_file_tree(title: str, path: Union[str, Path]): +def print_file_tree(title: str, path: str | Path): click.echo(f"\n\n -- {title}:") path = Path(path) lines = ["---------------------------------", f"\t{path.resolve()}"] @@ -210,7 +214,7 @@ def print_file_tree(title: str, path: Union[str, Path]): def write_change_log(readme_path: Path, change_log: ChangeLog): readme_path.touch(exist_ok=True) - content = open(readme_path, "r").read() + content = open(readme_path).read() if change_log.changes_available: with open(readme_path, "w") as f: compiled_change_log = change_log.compile() @@ -218,7 +222,7 @@ def write_change_log(readme_path: Path, change_log: ChangeLog): f.write(content) -def write_index_html(marketplace_root: Union[str, Path]): +def write_index_html(marketplace_root: str | Path): marketplace_root = Path(marketplace_root) index_path = marketplace_root / "index.html" template_path = PROJECT_ROOT / "cli" / "marketplace" / "index.html" @@ -238,7 +242,12 @@ def copy_resources(marketplace_dir, temp_docs): def update_or_create_items( - source_dir, source_name, marketplace_dir, temp_docs, change_log, force_update: bool = False + source_dir, + source_name, + marketplace_dir, + temp_docs, + change_log, + force_update: bool = False, ): click.echo("Creating items...") for item_dir in PathIterator(root=source_dir, rule=is_item_dir, as_path=True): @@ -248,9 +257,9 @@ def update_or_create_items( def build_catalog_json( - marketplace_dir: Union[str, Path], - source_directory: Union[str, Path], - catalog_path: Union[str, Path], + marketplace_dir: str | Path, + source_directory: str | Path, + catalog_path: str | Path, change_log: ChangeLog, in_channel_directory: bool = True, with_assets: bool = False, @@ -275,7 +284,7 @@ def build_catalog_json( channel = marketplace_dir.name source = marketplace_dir.parent.name - catalog = json.load(open(catalog_path, "r")) if catalog_path.exists() else {} + catalog = 
json.load(open(catalog_path)) if catalog_path.exists() else {} funcs = catalog if in_channel_directory: @@ -325,7 +334,7 @@ def update_item_in_catalog(directory: Path, with_assets: bool) -> dict: """ source_yaml_path = directory / "src" / "item.yaml" - item_yaml = yaml.full_load(open(source_yaml_path, "r")) + item_yaml = yaml.full_load(open(source_yaml_path)) item_yaml["generationDate"] = str(item_yaml["generationDate"]) if with_assets: add_assets(item_yaml) @@ -360,7 +369,7 @@ def update_or_create_item( force_update: bool = False, ): # Copy source directories to target directories, if target already has the directory, archive previous version - item_yaml = yaml.full_load(open(item_dir / "item.yaml", "r")) + item_yaml = yaml.full_load(open(item_dir / "item.yaml")) source_version = item_yaml["version"] relative_path = "../../../" @@ -369,9 +378,7 @@ def update_or_create_item( target_version = marketplace_item / source_version if target_version.exists() and not force_update: - latest_item_yaml = yaml.full_load( - open(target_latest / "src" / "item.yaml", "r") - ) + latest_item_yaml = yaml.full_load(open(target_latest / "src" / "item.yaml")) if item_yaml["hidden"] == latest_item_yaml.get("hidden"): click.echo("Source version already exists in target directory!") return @@ -432,8 +439,7 @@ def update_or_create_item( source_py_name = item_yaml.get("spec", {}).get("filename", "") if source_py_name.endswith(".py") and (item_dir / source_py_name).exists(): - - with open((item_dir / source_py_name), "r") as f: + with open(item_dir / source_py_name) as f: source_code = f.read() render_jinja( @@ -447,7 +453,7 @@ def update_or_create_item( {"source_code": source_code}, ) - with open((item_dir / "item.yaml"), "r") as f: + with open(item_dir / "item.yaml") as f: source_code = f.read() render_jinja( @@ -466,7 +472,7 @@ def update_or_create_item( asset_yaml_path = item_dir / f"{asset_name}.yaml" if asset_yaml_path.exists(): - with open(asset_yaml_path, "r") as f: + with 
open(asset_yaml_path) as f: source_code = f.read() render_jinja( templates / "yaml.html", @@ -490,7 +496,7 @@ def update_html_resource_paths( item_name: str = None, ): if html_path.exists(): - with open(html_path, "r", encoding="utf8") as html: + with open(html_path, encoding="utf8") as html: parsed = BeautifulSoup(html.read(), features="html.parser") # Update back to docs link (from source page) @@ -516,9 +522,9 @@ def update_html_resource_paths( nodes = parsed.find_all(lambda node: "_sources" in node.get("href", "")) for node in nodes: # fix path and remove example from name: - node[ - "href" - ] = f'../{node["href"].replace("_sources", "src").replace("_example", "")}' + node["href"] = ( + f"../{node['href'].replace('_sources', 'src').replace('_example', '')}" + ) else: # Removing download option from documentation: nodes = parsed.find_all( @@ -551,7 +557,7 @@ def patch_temp_docs(source_dir, temp_docs): for directory in PathIterator(root=source_dir, rule=is_item_dir): directory = Path(directory) - with open(directory / "item.yaml", "r") as f: + with open(directory / "item.yaml") as f: item = yaml.full_load(f) example_file = directory / item["example"] @@ -576,7 +582,7 @@ def build_temp_project(source_dir, temp_root): item_count += 1 click.echo(f"[Temporary project] Now processing: {directory / 'item.yaml'}") - with open(directory / "item.yaml", "r") as f: + with open(directory / "item.yaml") as f: item = yaml.full_load(f) filename = item.get("spec")["filename"] @@ -594,8 +600,8 @@ def build_temp_project(source_dir, temp_root): def collect_values_from_items( - source_dir: Union[Path, str], tags_set: Set[str] -) -> Dict[str, List[str]]: + source_dir: Path | str, tags_set: set[str] +) -> dict[str, list[str]]: """ Collecting all tags values from item.yaml files. If the `with_requirements` flag is on than also collecting requirements from ite.yaml and requirements.txt files. 
@@ -626,9 +632,7 @@ def collect_values_from_items( return tags -def sphinx_quickstart( - temp_root: Union[str, Path], requirements: Optional[List[str]] = None -): +def sphinx_quickstart(temp_root: str | Path, requirements: list[str] | None = None): """ Generate required files for a Sphinx project. sphinx-quickstart is an interactive tool that asks some questions about your project and then @@ -694,5 +698,8 @@ def build_temp_docs(temp_root, temp_docs, source_dir): sphinx_apidoc_cmd(cmd.split(" ")) - shutil.copytree(PROJECT_ROOT / "cli" / "marketplace" / "_static" / "css", temp_docs / '_static/css') + shutil.copytree( + PROJECT_ROOT / "cli" / "marketplace" / "_static" / "css", + temp_docs / "_static/css", + ) click.echo("[Sphinx] Done autodoc") diff --git a/cli/utils/helpers.py b/cli/utils/helpers.py index fabccbf7a..df67c8c3e 100644 --- a/cli/utils/helpers.py +++ b/cli/utils/helpers.py @@ -15,10 +15,10 @@ import os import pathlib import subprocess -from pathlib import Path -from typing import Union, List, Set, Dict import sys from glob import iglob +from pathlib import Path + import yaml from jinja2 import Template @@ -35,13 +35,11 @@ def is_function_dir(path: Path) -> bool: # dir_name = path.name # ipynb_found = any((f.name.endswith(".ipynb") for f in path.iterdir())) # py_found = any((f.name.endswith(".py") for f in path.iterdir())) - return any((f.name == "function.yaml" for f in path.iterdir())) + return any(f.name == "function.yaml" for f in path.iterdir()) -def render_jinja( - template_path: Union[str, Path], output_path: Union[str, Path], data: dict -): - with open(template_path, "r") as t: +def render_jinja(template_path: str | Path, output_path: str | Path, data: dict): + with open(template_path) as t: template_text = t.read() template = Template(template_text) @@ -54,7 +52,7 @@ def render_jinja( def install_pipenv(): print("Installing pipenv...") pipenv_install: subprocess.CompletedProcess = subprocess.run( - f"export PIP_NO_INPUT=1;pip install 
pipenv==2023.10.24", + "export PIP_NO_INPUT=1;pip install pipenv==2023.10.24", stdout=sys.stdout, stderr=subprocess.PIPE, shell=True, @@ -62,12 +60,16 @@ def install_pipenv(): exit_on_non_zero_return(pipenv_install) -def install_python(directory: Union[str, Path]): +def install_python(directory: str | Path): print(f"Installing python for {directory} ...") - install_command = f"pipenv --rm;pipenv --python 3.10.17" - if (os.environ.get('CONDA_DEFAULT_ENV') is not None) and (os.environ.get('CONDA_PREFIX') is not None): + install_command = "pipenv --rm;pipenv --python 3.10.17" + if (os.environ.get("CONDA_DEFAULT_ENV") is not None) and ( + os.environ.get("CONDA_PREFIX") is not None + ): print("conda env detected using conda to get pipenv python version") - install_command = f"pipenv --rm;pipenv --python=$(conda run which python) --site-packages" + install_command = ( + "pipenv --rm;pipenv --python=$(conda run which python) --site-packages" + ) python_install: subprocess.CompletedProcess = subprocess.run( install_command, stdout=sys.stdout, @@ -81,7 +83,7 @@ def install_python(directory: Union[str, Path]): stderr = python_install.stderr.decode("utf8") stderr = stderr.split("\n") python_location = [l for l in stderr if "Virtualenv location: " in l] - if python_location.count(python_location)>0: + if python_location.count(python_location) > 0: python_location = ( python_location[0].split("Virtualenv location: ")[-1] + "bin/python" ) @@ -90,7 +92,7 @@ def install_python(directory: Union[str, Path]): return python_location -def _run_subprocess(cmd: List[str], directory): +def _run_subprocess(cmd: list[str], directory): completed_process: subprocess.CompletedProcess = subprocess.run( cmd, stdout=sys.stdout, @@ -103,14 +105,14 @@ def _run_subprocess(cmd: List[str], directory): def install_requirements( directory: str, - requirements: Union[List[str], Set[str]], + requirements: list[str] | set[str], ): """ Installing requirements from a requirements list/set and from a 
requirements.txt file if found in directory :param directory: The relevant directory were the requirements are installed and collected :param requirements: Requirement list/set with or without bounds """ - requirements_file = Path(directory) / 'requirements.txt' + requirements_file = Path(directory) / "requirements.txt" if not requirements and not requirements_file.exists(): print(f"No requirements found for {directory}...") @@ -120,7 +122,7 @@ def install_requirements( print(f"Installing requirements from {requirements_file}...") cmd = ["pipenv", "install", "--skip-lock", "-r", str(requirements_file)] _run_subprocess(cmd, directory) - with open(requirements_file, "r") as f: + with open(requirements_file) as f: mlrun_version = [l.replace("\n", "") for l in f.readlines() if "mlrun" in l] # remove mlrun from requirements if installed with version limits: if mlrun_version and any([c in mlrun_version[0] for c in "<>=~"]): @@ -133,8 +135,8 @@ def install_requirements( def get_item_yaml_values( - item_path: pathlib.Path, keys: Union[str, Set[str]] -) -> Dict[str, Set[str]]: + item_path: pathlib.Path, keys: str | set[str] +) -> dict[str, set[str]]: """ Getting value from item.yaml requested field. @@ -153,7 +155,7 @@ def get_item_yaml_values( item_path = Path(item_path) if item_path.is_dir(): item_path = item_path / "item.yaml" - with open(item_path, "r") as f: + with open(item_path) as f: item = yaml.full_load(f) if key in item: values = item.get(key, "") @@ -174,7 +176,7 @@ def get_item_yaml_values( return values_dict -def get_mock_requirements(source_dir: Union[str, Path]) -> List[str]: +def get_mock_requirements(source_dir: str | Path) -> list[str]: """ Getting all requirements from .py files inside all the subdirectories of the given source dir. Only the files with the same name as their parent directory are taken in consideration. 
@@ -197,13 +199,13 @@ def get_mock_requirements(source_dir: Union[str, Path]) -> List[str]: # Skipping test files continue # Getting all packages: - with open(filename, 'r') as f: + with open(filename) as f: lines = list(filter(None, f.read().split("\n"))) for line in lines: - words = line.split(' ') + words = line.split(" ") words = [w for w in words if w] - if words and (words[0] == 'from' or words[0] == 'import'): - mock_reqs.add(words[1].split('.')[0]) + if words and (words[0] == "from" or words[0] == "import"): + mock_reqs.add(words[1].split(".")[0]) return sorted(mock_reqs) diff --git a/cli/utils/path_iterator.py b/cli/utils/path_iterator.py index 0aaccc2b7..2ea62588d 100644 --- a/cli/utils/path_iterator.py +++ b/cli/utils/path_iterator.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from collections.abc import Callable from pathlib import Path -from typing import Optional, Callable, Union class PathIterator: @@ -27,8 +27,8 @@ class PathIterator: def __init__( self, - root: Union[str, Path], - rule: Optional[Callable[[Path], bool]] = None, + root: str | Path, + rule: Callable[[Path], bool] | None = None, recursive: bool = False, absolute: bool = True, as_path: bool = False, diff --git a/functions/src/aggregate/aggregate.py b/functions/src/aggregate/aggregate.py index 1e9d8502d..f3f555569 100644 --- a/functions/src/aggregate/aggregate.py +++ b/functions/src/aggregate/aggregate.py @@ -15,51 +15,52 @@ # Generated by nuclio.export.NuclioExporter import os + import pandas as pd from mlrun.datastore import DataItem -from typing import Union - - -def aggregate(context, - df_artifact: Union[DataItem, pd.core.frame.DataFrame], - save_to: str = 'aggregated-df.pq', - keys: list = None, - metrics: list = None, - labels: list = None, - metric_aggregations: list = ['mean'], - label_aggregations: list = ['max'], - suffix: str = '', - window: int = 3, - center: bool = False, - inplace: bool = 
False, - drop_na: bool = True, - files_to_select: int = 1): + +def aggregate( + context, + df_artifact: DataItem | pd.core.frame.DataFrame, + save_to: str = "aggregated-df.pq", + keys: list = None, + metrics: list = None, + labels: list = None, + metric_aggregations: list = ["mean"], + label_aggregations: list = ["max"], + suffix: str = "", + window: int = 3, + center: bool = False, + inplace: bool = False, + drop_na: bool = True, + files_to_select: int = 1, +): """Time-series aggregation function - + Will perform a rolling aggregation on {df_artifact}, over {window} by the selected {keys} applying {metric_aggregations} on {metrics} and {label_aggregations} on {labels}. adding {suffix} to the feature names. - + if not {inplace}, will return the original {df_artifact}, joined by the aggregated result. :param context: After running a job, you need to be able to track it. To gain the maximum value, MLRun uses the job context object inside the code. This provides access to job metadata, parameters, inputs, secrets, and API for logging and monitoring the results, as well as log text, files, artifacts, and labels. - - :param df_artifact: MLRun input pointing to pandas dataframe (csv/parquet file path) or a + + :param df_artifact: MLRun input pointing to pandas dataframe (csv/parquet file path) or a directory containing parquet files. * When given a directory the latest {files_to_select} will be selected :param save_to: Where to save the result dataframe. * If relative will add to the {artifact_path} :param keys: Subset of indexes from the source dataframe to aggregate by (default=all) - :param metrics: Array containing a list of metrics to run the aggregations on. (default=None) - :param labels: Array containing a list of labels to run the aggregations on. (default=None) + :param metrics: Array containing a list of metrics to run the aggregations on. (default=None) + :param labels: Array containing a list of labels to run the aggregations on. 
(default=None) :param metric_aggregations: Array containing a list of aggregation function names to run on {metrics}. (Ex: 'mean', 'std') (default='mean') :param label_aggregations: Array containing a list of aggregation function names to run on {metrics}. - (Ex: 'max', 'min') (default='max') + (Ex: 'max', 'min') (default='max') :param suffix: Suffix to add to the feature name, E.g: __ (Ex: 'last_60_minutes') (default='') :param window: Window size to perform the rolling aggregate on. (default=3) @@ -70,70 +71,99 @@ def aggregate(context, :param drop_na: Will drop na lines due to the Rolling. :param files_to_select: Specifies the number of *latest* files to select (and concat) for aggregation. """ - + from_model = type(df_artifact) == pd.DataFrame if from_model: - context.logger.info('Aggregating from Buffer') + context.logger.info("Aggregating from Buffer") input_df = df_artifact else: - if df_artifact.url.endswith('/'): # is a directory? - mpath = [os.path.join(df_artifact.url, file) for file in df_artifact.listdir() if file.endswith(('parquet', 'pq'))] + if df_artifact.url.endswith("/"): # is a directory? 
+ mpath = [ + os.path.join(df_artifact.url, file) + for file in df_artifact.listdir() + if file.endswith(("parquet", "pq")) + ] files_by_updated = sorted(mpath, key=os.path.getmtime, reverse=True) context.logger.info(files_by_updated) latest = files_by_updated[:files_to_select] - context.logger.info(f'Aggregating {latest}') + context.logger.info(f"Aggregating {latest}") input_df = pd.concat([context.get_dataitem(df).as_df() for df in latest]) else: # A regular artifact - context.logger.info(f'Aggregating {df_artifact.url}') + context.logger.info(f"Aggregating {df_artifact.url}") input_df = df_artifact.as_df() - + if not (metrics or labels): - raise ValueError('please specify metrics or labels param') - + raise ValueError("please specify metrics or labels param") + if keys: current_index = input_df.index.names indexes_to_drop = [col for col in input_df.index.names if col not in keys] df = input_df.reset_index(level=indexes_to_drop) else: df = input_df - + if metrics: - metrics_df = df.loc[:, metrics].rolling(window=window, center=center).aggregate(metric_aggregations) - metrics_df.columns = ['_'.join(col).strip() for col in metrics_df.columns.values] - + metrics_df = ( + df.loc[:, metrics] + .rolling(window=window, center=center) + .aggregate(metric_aggregations) + ) + metrics_df.columns = [ + "_".join(col).strip() for col in metrics_df.columns.values + ] + if suffix: - metrics_df.columns = [f'{metric}_{suffix}' for metric in metrics_df.columns] - + metrics_df.columns = [f"{metric}_{suffix}" for metric in metrics_df.columns] + if not inplace: - final_df = pd.merge(input_df, metrics_df, suffixes=('', suffix), left_index=True, right_index=True) + final_df = pd.merge( + input_df, + metrics_df, + suffixes=("", suffix), + left_index=True, + right_index=True, + ) else: final_df = metrics_df if labels: - labels_df = df.loc[:, labels].rolling(window=window, - center=center).aggregate(label_aggregations) - labels_df.columns = ['_'.join(col).strip() for col in 
labels_df.columns.values] - + labels_df = ( + df.loc[:, labels] + .rolling(window=window, center=center) + .aggregate(label_aggregations) + ) + labels_df.columns = ["_".join(col).strip() for col in labels_df.columns.values] + if suffix: - labels_df.columns = [f'{label}_{suffix}' for label in labels_df.columns] - + labels_df.columns = [f"{label}_{suffix}" for label in labels_df.columns] + if metrics: - final_df = pd.merge(final_df, labels_df, suffixes=('', suffix), left_index=True, right_index=True) + final_df = pd.merge( + final_df, + labels_df, + suffixes=("", suffix), + left_index=True, + right_index=True, + ) else: if not inplace: - final_df = pd.merge(input_df, labels_df, suffixes=('', suffix), left_index=True, right_index=True) + final_df = pd.merge( + input_df, + labels_df, + suffixes=("", suffix), + left_index=True, + right_index=True, + ) else: final_df = labels_df - + if drop_na: final_df = final_df.dropna() - - context.logger.info('Logging artifact') + + context.logger.info("Logging artifact") if not from_model: - context.log_dataset(key='aggregate', - df=final_df, - format='parquet', - local_path=save_to) + context.log_dataset( + key="aggregate", df=final_df, format="parquet", local_path=save_to + ) else: return final_df - diff --git a/functions/src/aggregate/function.yaml b/functions/src/aggregate/function.yaml index 4782ee7ea..ba8b0656a 100644 --- a/functions/src/aggregate/function.yaml +++ b/functions/src/aggregate/function.yaml @@ -1,4 +1,18 @@ +metadata: + tag: '' + name: aggregate + categories: + - data-preparation +verbose: false +kind: job spec: + image: mlrun/mlrun + disable_auto_mount: false + build: + origin_filename: '' + functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Generated by nuclio.export.NuclioExporter

import os

import pandas as pd
from mlrun.datastore import DataItem


def aggregate(
    context,
    df_artifact: DataItem | pd.core.frame.DataFrame,
    save_to: str = "aggregated-df.pq",
    keys: list = None,
    metrics: list = None,
    labels: list = None,
    metric_aggregations: list = ["mean"],
    label_aggregations: list = ["max"],
    suffix: str = "",
    window: int = 3,
    center: bool = False,
    inplace: bool = False,
    drop_na: bool = True,
    files_to_select: int = 1,
):
    """Time-series aggregation function

    Will perform a rolling aggregation on {df_artifact}, over {window} by the selected {keys}
    applying {metric_aggregations} on {metrics} and {label_aggregations} on {labels}. adding {suffix} to the
    feature names.

    if not {inplace}, will return the original {df_artifact}, joined by the aggregated result.

    :param context: After running a job, you need to be able to track it. To gain the maximum value, MLRun uses the
                    job context object inside the code. This provides access to job metadata, parameters,
                    inputs, secrets, and API for logging and monitoring the results, as well as log text, files,
                    artifacts, and labels.

    :param df_artifact: MLRun input pointing to pandas dataframe (csv/parquet file path) or a
                        directory containing parquet files.
                        * When given a directory the latest {files_to_select} will be selected
                        * An in-memory pandas DataFrame (or a subclass of it) is used as is
    :param save_to:     Where to save the result dataframe.
                        * If relative will add to the {artifact_path}
    :param keys:        Subset of indexes from the source dataframe to aggregate by (default=all)
    :param metrics:     Array containing a list of metrics to run the aggregations on. (default=None)
    :param labels:      Array containing a list of labels to run the aggregations on. (default=None)
    :param metric_aggregations: Array containing a list of aggregation function names to run on {metrics}.
                        (Ex: 'mean', 'std') (default='mean')
    :param label_aggregations:  Array containing a list of aggregation function names to run on {labels}.
                        (Ex: 'max', 'min') (default='max')
    :param suffix:      Suffix to add to the feature name, E.g: <Feature_Name>_<Agg_Function>_<Suffix>
                        (Ex: 'last_60_minutes') (default='')
    :param window:      Window size to perform the rolling aggregate on. (default=3)
    :param center:      If True, Sets the value for the central sample in the window,
                        If False, will set the value to the last sample. (default=False)
    :param inplace:     If True, will return only the aggregated results.
                        If False, will join the aggregated results with the original dataframe
    :param drop_na:     Will drop na lines due to the Rolling.
    :param files_to_select: Specifies the number of *latest* files to select (and concat) for aggregation.

    :returns: The aggregated dataframe when {df_artifact} is a pandas DataFrame. Otherwise the result is
              logged as the `aggregate` dataset and nothing is returned.

    :raises ValueError: If neither {metrics} nor {labels} is specified.
    """
    # NOTE: the list defaults in the signature are kept as they are for interface compatibility;
    # they are only read, never mutated, inside this function.

    # Validate the request before loading any data so a bad call fails fast.
    if not (metrics or labels):
        raise ValueError("please specify metrics or labels param")

    # isinstance (rather than an exact type comparison) so that DataFrame subclasses are also
    # handled as in-memory input instead of being accessed like a DataItem.
    from_model = isinstance(df_artifact, pd.DataFrame)
    if from_model:
        context.logger.info("Aggregating from Buffer")
        input_df = df_artifact
    elif df_artifact.url.endswith("/"):  # is a directory?
        mpath = [
            os.path.join(df_artifact.url, file)
            for file in df_artifact.listdir()
            if file.endswith(("parquet", "pq"))
        ]
        # NOTE(review): os.path.getmtime presumably requires the url to be a local path - confirm.
        files_by_updated = sorted(mpath, key=os.path.getmtime, reverse=True)
        context.logger.info(files_by_updated)
        latest = files_by_updated[:files_to_select]
        context.logger.info(f"Aggregating {latest}")
        input_df = pd.concat([context.get_dataitem(df).as_df() for df in latest])
    else:  # A regular artifact
        context.logger.info(f"Aggregating {df_artifact.url}")
        input_df = df_artifact.as_df()

    if keys:
        # Move every index level that is not a requested key back into the columns.
        indexes_to_drop = [col for col in input_df.index.names if col not in keys]
        df = input_df.reset_index(level=indexes_to_drop)
    else:
        df = input_df

    def _rolling_features(columns, aggregations):
        """Roll over `columns`, apply `aggregations` and flatten the resulting column names."""
        features = (
            df.loc[:, columns]
            .rolling(window=window, center=center)
            .aggregate(aggregations)
        )
        # Each resulting column is a (feature, aggregation) pair; flatten it to "<feature>_<aggregation>".
        features.columns = ["_".join(col).strip() for col in features.columns.values]
        if suffix:
            features.columns = [f"{name}_{suffix}" for name in features.columns]
        return features

    # With inplace the result holds only the aggregated features, otherwise they are
    # joined (by index) onto the original dataframe.
    final_df = None if inplace else input_df
    for columns, aggregations in (
        (metrics, metric_aggregations),
        (labels, label_aggregations),
    ):
        if not columns:
            continue
        features = _rolling_features(columns, aggregations)
        if final_df is None:
            final_df = features
        else:
            final_df = pd.merge(
                final_df,
                features,
                suffixes=("", suffix),
                left_index=True,
                right_index=True,
            )

    if drop_na:
        final_df = final_df.dropna()

    context.logger.info("Logging artifact")
    if not from_model:
        context.log_dataset(
            key="aggregate", df=final_df, format="parquet", local_path=save_to
        )
    else:
        return final_df
 + code_origin: '' + filename: aggregate.py entry_points: aggregate: parameters: @@ -8,10 +22,9 @@ spec: access to job metadata, parameters, inputs, secrets, and API for logging and monitoring the results, as well as log text, files, artifacts, and labels. - name: df_artifact - type: Union[DataItem, pd.core.frame.DataFrame] - doc: MLRun input pointing to pandas dataframe (csv/parquet file path) or a directory - containing parquet files. * When given a directory the latest {files_to_select} - will be selected + doc: MLRun input pointing to pandas dataframe (csv/parquet file path) or a + directory containing parquet files. * When given a directory the latest + {files_to_select} will be selected - name: save_to type: str doc: Where to save the result dataframe. * If relative will add to the {artifact_path} @@ -22,11 +35,11 @@ spec: default: null - name: metrics type: list - doc: 'Array containing a list of metrics to run the aggregations on. (default=None) ' + doc: Array containing a list of metrics to run the aggregations on. (default=None) default: null - name: labels type: list - doc: 'Array containing a list of labels to run the aggregations on. (default=None) ' + doc: Array containing a list of labels to run the aggregations on. (default=None) default: null - name: metric_aggregations type: list @@ -37,7 +50,7 @@ spec: - name: label_aggregations type: list doc: 'Array containing a list of aggregation function names to run on {metrics}. - (Ex: ''max'', ''min'') (default=''max'') ' + (Ex: ''max'', ''min'') (default=''max'')' default: - max - name: suffix @@ -67,6 +80,7 @@ spec: type: int doc: Specifies the number of *latest* files to select (and concat) for aggregation. default: 1 + name: aggregate doc: 'Time-series aggregation function @@ -81,23 +95,9 @@ spec: if not {inplace}, will return the original {df_artifact}, joined by the aggregated result.' 
- has_varargs: false - name: aggregate has_kwargs: false - lineno: 24 - disable_auto_mount: false + has_varargs: false + lineno: 23 + command: '' description: Rolling aggregation over Metrics and Lables according to specifications default_handler: aggregate - image: mlrun/mlrun - command: '' - build: - functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Generated by nuclio.export.NuclioExporter

import os
import pandas as pd
from mlrun.datastore import DataItem

from typing import Union


def aggregate(context,
              df_artifact: Union[DataItem, pd.core.frame.DataFrame],
              save_to: str = 'aggregated-df.pq',
              keys: list = None,
              metrics: list = None,
              labels: list = None,
              metric_aggregations: list = ['mean'],
              label_aggregations: list = ['max'],
              suffix: str = '',
              window: int = 3,
              center: bool = False,
              inplace: bool = False,
              drop_na: bool = True,
              files_to_select: int = 1):
    """Time-series aggregation function

    Will perform a rolling aggregation on {df_artifact}, over {window} by the selected {keys}
    applying {metric_aggregations} on {metrics} and {label_aggregations} on {labels}. adding {suffix} to the
    feature names.

    if not {inplace}, will return the original {df_artifact}, joined by the aggregated result.

    :param context: After running a job, you need to be able to track it. To gain the maximum value, MLRun uses the
                    job context object inside the code. This provides access to job metadata, parameters,
                    inputs, secrets, and API for logging and monitoring the results, as well as log text, files,
                    artifacts, and labels.

    :param df_artifact: MLRun input pointing to pandas dataframe (csv/parquet file path) or a
                        directory containing parquet files.
                        * When given a directory the latest {files_to_select} will be selected
                        * An in-memory pandas DataFrame (or a subclass of it) is used as is
    :param save_to:     Where to save the result dataframe.
                        * If relative will add to the {artifact_path}
    :param keys:        Subset of indexes from the source dataframe to aggregate by (default=all)
    :param metrics:     Array containing a list of metrics to run the aggregations on. (default=None)
    :param labels:      Array containing a list of labels to run the aggregations on. (default=None)
    :param metric_aggregations: Array containing a list of aggregation function names to run on {metrics}.
                        (Ex: 'mean', 'std') (default='mean')
    :param label_aggregations:  Array containing a list of aggregation function names to run on {labels}.
                        (Ex: 'max', 'min') (default='max')
    :param suffix:      Suffix to add to the feature name, E.g: <Feature_Name>_<Agg_Function>_<Suffix>
                        (Ex: 'last_60_minutes') (default='')
    :param window:      Window size to perform the rolling aggregate on. (default=3)
    :param center:      If True, Sets the value for the central sample in the window,
                        If False, will set the value to the last sample. (default=False)
    :param inplace:     If True, will return only the aggregated results.
                        If False, will join the aggregated results with the original dataframe
    :param drop_na:     Will drop na lines due to the Rolling.
    :param files_to_select: Specifies the number of *latest* files to select (and concat) for aggregation.

    :returns: The aggregated dataframe when {df_artifact} is a pandas DataFrame. Otherwise the result is
              logged as the `aggregate` dataset and nothing is returned.

    :raises ValueError: If neither {metrics} nor {labels} is specified.
    """
    # NOTE: the list defaults in the signature are kept for interface compatibility;
    # they are only read, never mutated, inside this function.

    # Validate the request before loading any data so a bad call fails fast.
    if not (metrics or labels):
        raise ValueError('please specify metrics or labels param')

    # isinstance instead of an exact type comparison: DataFrame subclasses must take the
    # in-memory path too, otherwise they are accessed like a DataItem and fail.
    from_model = isinstance(df_artifact, pd.DataFrame)
    if from_model:
        context.logger.info('Aggregating from Buffer')
        input_df = df_artifact
    else:
        if df_artifact.url.endswith('/'):   # is a directory?
            mpath = [os.path.join(df_artifact.url, file)
                     for file in df_artifact.listdir()
                     if file.endswith(('parquet', 'pq'))]
            # NOTE(review): os.path.getmtime presumably requires the url to be a local path - confirm.
            files_by_updated = sorted(mpath, key=os.path.getmtime, reverse=True)
            context.logger.info(files_by_updated)
            latest = files_by_updated[:files_to_select]
            context.logger.info(f'Aggregating {latest}')
            input_df = pd.concat([context.get_dataitem(df).as_df() for df in latest])
        else:  # A regular artifact
            context.logger.info(f'Aggregating {df_artifact.url}')
            input_df = df_artifact.as_df()

    if keys:
        # Move every index level that is not a requested key back into the columns.
        indexes_to_drop = [col for col in input_df.index.names if col not in keys]
        df = input_df.reset_index(level=indexes_to_drop)
    else:
        df = input_df

    if metrics:
        metrics_df = df.loc[:, metrics].rolling(window=window, center=center).aggregate(metric_aggregations)
        # Flatten each (feature, aggregation) column pair to "<feature>_<aggregation>".
        metrics_df.columns = ['_'.join(col).strip() for col in metrics_df.columns.values]

        if suffix:
            metrics_df.columns = [f'{metric}_{suffix}' for metric in metrics_df.columns]

        if not inplace:
            final_df = pd.merge(input_df, metrics_df, suffixes=('', suffix), left_index=True, right_index=True)
        else:
            final_df = metrics_df

    if labels:
        labels_df = df.loc[:, labels].rolling(window=window, center=center).aggregate(label_aggregations)
        labels_df.columns = ['_'.join(col).strip() for col in labels_df.columns.values]

        if suffix:
            labels_df.columns = [f'{label}_{suffix}' for label in labels_df.columns]

        if metrics:
            # Metrics were already aggregated: extend that result with the label features.
            final_df = pd.merge(final_df, labels_df, suffixes=('', suffix), left_index=True, right_index=True)
        elif not inplace:
            final_df = pd.merge(input_df, labels_df, suffixes=('', suffix), left_index=True, right_index=True)
        else:
            final_df = labels_df

    if drop_na:
        final_df = final_df.dropna()

    context.logger.info('Logging artifact')
    if not from_model:
        context.log_dataset(key='aggregate',
                            df=final_df,
                            format='parquet',
                            local_path=save_to)
    else:
        return final_df

 - code_origin: '' - origin_filename: '' -verbose: false -metadata: - categories: - - data-preparation - name: aggregate - tag: '' -kind: job diff --git a/functions/src/aggregate/test_aggregate.py b/functions/src/aggregate/test_aggregate.py index 87248ac50..694da13ad 100644 --- a/functions/src/aggregate/test_aggregate.py +++ b/functions/src/aggregate/test_aggregate.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from pathlib import Path import os + from mlrun import code_to_function, import_function AGGREGATE_PATH = "artifacts/aggregate.pq" @@ -21,26 +21,27 @@ def test_run_local_aggregate(): - fn = code_to_function(name='code_to_function', - filename="aggregate.py", - handler="aggregate", - kind="local", - ) + fn = code_to_function( + name="code_to_function", + filename="aggregate.py", + handler="aggregate", + kind="local", + ) fn.run( params={ - 'metrics': ['cpu_utilization'], - 'labels': ['is_error'], - 'metric_aggs': ['mean', 'sum'], - 'label_aggs': ['max'], - 'suffix': 'daily', - 'inplace': False, - 'window': 5, - 'center': True, - 'save_to': AGGREGATE_PATH, - 'files_to_select': 2 + "metrics": ["cpu_utilization"], + "labels": ["is_error"], + "metric_aggs": ["mean", "sum"], + "label_aggs": ["max"], + "suffix": "daily", + "inplace": False, + "window": 5, + "center": True, + "save_to": AGGREGATE_PATH, + "files_to_select": 2, }, local=True, - inputs={'df_artifact': DATA} + inputs={"df_artifact": DATA}, ) assert os.path.exists("code-to-function-aggregate/0/aggregate.pq") == True @@ -49,18 +50,18 @@ def test_import_function_aggregate(): fn = import_function("function.yaml") fn.run( params={ - 'metrics': ['cpu_utilization'], - 'labels': ['is_error'], - 'metric_aggs': ['mean', 'sum'], - 'label_aggs': ['max'], - 'suffix': 'daily', - 'inplace': False, - 'window': 5, - 'center': True, - 'save_to': AGGREGATE_PATH, - 'files_to_select': 2, + "metrics": ["cpu_utilization"], + "labels": 
["is_error"], + "metric_aggs": ["mean", "sum"], + "label_aggs": ["max"], + "suffix": "daily", + "inplace": False, + "window": 5, + "center": True, + "save_to": AGGREGATE_PATH, + "files_to_select": 2, }, local=True, - inputs={'df_artifact': DATA}, + inputs={"df_artifact": DATA}, ) assert os.path.exists("aggregate-aggregate/0/aggregate.pq") == True diff --git a/functions/src/arc_to_parquet/arc_to_parquet.py b/functions/src/arc_to_parquet/arc_to_parquet.py index d9275b7ca..ebae092a1 100644 --- a/functions/src/arc_to_parquet/arc_to_parquet.py +++ b/functions/src/arc_to_parquet/arc_to_parquet.py @@ -12,28 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os + +import numpy as np import pandas as pd -import pyarrow.parquet as pq import pyarrow as pa -import numpy as np - - -from mlrun.execution import MLClientCtx +import pyarrow.parquet as pq from mlrun.datastore import DataItem - -from typing import List -import os - +from mlrun.execution import MLClientCtx def _chunk_readwrite( - archive_url, - dest_path, - chunksize, - header, - encoding, - dtype, - dataset + archive_url, dest_path, chunksize, header, encoding, dtype, dataset ): """stream read and write archives @@ -46,9 +36,15 @@ def _chunk_readwrite( """ pqwriter = None header = [] - for i, df in enumerate(pd.read_csv(archive_url, chunksize=chunksize, - names=header, encoding=encoding, - dtype=dtype)): + for i, df in enumerate( + pd.read_csv( + archive_url, + chunksize=chunksize, + names=header, + encoding=encoding, + dtype=dtype, + ) + ): table = pa.Table.from_pandas(df) if i == 0: if dataset: @@ -56,7 +52,9 @@ def _chunk_readwrite( else: pqwriter = pq.ParquetWriter(dest_path, table.schema) if dataset: - pq.write_to_dataset(table, root_path=dest_path, partition_cols=partition_cols) + pq.write_to_dataset( + table, root_path=dest_path, partition_cols=partition_cols + ) else: pqwriter.write_table(table) if pqwriter: @@ -66,19 +64,19 @@ def 
_chunk_readwrite( def arc_to_parquet( - context: MLClientCtx, - archive_url: DataItem, - header: List[str] = [None], - chunksize: int = 0, - dtype=None, - encoding: str = "latin-1", - key: str = "data", - dataset: str = "None", - part_cols=[], - file_ext: str = "parquet", - index: bool = False, - refresh_data: bool = False, - stats: bool = False + context: MLClientCtx, + archive_url: DataItem, + header: list[str] = [None], + chunksize: int = 0, + dtype=None, + encoding: str = "latin-1", + key: str = "data", + dataset: str = "None", + part_cols=[], + file_ext: str = "parquet", + index: bool = False, + refresh_data: bool = False, + stats: bool = False, ) -> None: """Open a file/object archive and save as a parquet file or dataset @@ -123,12 +121,14 @@ def arc_to_parquet( if not exists: context.logger.info("destination file does not exist, downloading") if chunksize > 0: - header = _chunk_readwrite(archive_url, dest_path, chunksize, - encoding, dtype, dataset) - context.log_dataset(key=key, stats=stats, format='parquet', - target_path=dest_path) + header = _chunk_readwrite( + archive_url, dest_path, chunksize, encoding, dtype, dataset + ) + context.log_dataset( + key=key, stats=stats, format="parquet", target_path=dest_path + ) else: df = pd.read_csv(archive_url) context.log_dataset(key, df=df, format=file_ext, index=index) else: - context.logger.info("destination file already exists, nothing done") \ No newline at end of file + context.logger.info("destination file already exists, nothing done") diff --git a/functions/src/arc_to_parquet/function.yaml b/functions/src/arc_to_parquet/function.yaml index ca2c31921..7feaf4ece 100644 --- a/functions/src/arc_to_parquet/function.yaml +++ b/functions/src/arc_to_parquet/function.yaml @@ -1,17 +1,22 @@ +metadata: + tag: '' + name: arc-to-parquet + categories: + - utils +verbose: false kind: job spec: image: mlrun/mlrun + disable_auto_mount: false build: - functionSourceCode: 
IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKaW1wb3J0IHBhbmRhcyBhcyBwZAppbXBvcnQgcHlhcnJvdy5wYXJxdWV0IGFzIHBxCmltcG9ydCBweWFycm93IGFzIHBhCmltcG9ydCBudW1weSBhcyBucAoKCmZyb20gbWxydW4uZXhlY3V0aW9uIGltcG9ydCBNTENsaWVudEN0eApmcm9tIG1scnVuLmRhdGFzdG9yZSBpbXBvcnQgRGF0YUl0ZW0KCmZyb20gdHlwaW5nIGltcG9ydCBMaXN0CmltcG9ydCBvcwoKCgpkZWYgX2NodW5rX3JlYWR3cml0ZSgKICAgICAgICBhcmNoaXZlX3VybCwKICAgICAgICBkZXN0X3BhdGgsCiAgICAgICAgY2h1bmtzaXplLAogICAgICAgIGhlYWRlciwKICAgICAgICBlbmNvZGluZywKICAgICAgICBkdHlwZSwKICAgICAgICBkYXRhc2V0Cik6CiAgICAiIiJzdHJlYW0gcmVhZCBhbmQgd3JpdGUgYXJjaGl2ZXMKCiAgICBwYW5kYXMgcmVhZHMgYW5kIHBhcnF1ZXQgd3JpdGVzCgogICAgbm90ZXMKICAgIC0tLS0tCiAgICAqIGRlc3RfcGF0aCBjYW4gYmUgZWl0aGVyIGEgZmlsZS5wYXJxdWV0LCBvciBpbiBodGUgY2FzZSBvZiBwYXJ0aXRpb25lZCBwYXJxdWV0CiAgICAgIGl0IHdpbGwgYmUgb25seSB0aGUgZGVzdGluYXRpb24gZm9sZGVyIG9mIHRoZSBwYXJxdWV0IHBhcnRpdGlvbiBmaWxlcwogICAgIiIiCiAgICBwcXdyaXRlciA9IE5vbmUKICAgIGhlYWRlciA9IFtdCiAgICBmb3IgaSwgZGYgaW4gZW51bWVyYXRlKHBkLnJlYWRfY3N2KGFyY2hpdmVfdXJsLCBjaHVua3NpemU9Y2h1bmtzaXplLAogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBuYW1lcz1oZWFkZXIsIGVuY29kaW5nPWVuY29kaW5nLAogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBkdHlwZT1kdHlwZSkpOgogICAgICAgIHRhYmxlID0gcGEuVGFibGUuZnJvbV9wYW5kYXMoZGYpCiAgICAgICAgaWYgaSA9PSAwOgogICAgICAgICAgICBp
ZiBkYXRhc2V0OgogICAgICAgICAgICAgICAgaGVhZGVyID0gbnAuY29weSh0YWJsZS5zY2hlbWEpCiAgICAgICAgICAgIGVsc2U6CiAgICAgICAgICAgICAgICBwcXdyaXRlciA9IHBxLlBhcnF1ZXRXcml0ZXIoZGVzdF9wYXRoLCB0YWJsZS5zY2hlbWEpCiAgICAgICAgaWYgZGF0YXNldDoKICAgICAgICAgICAgcHEud3JpdGVfdG9fZGF0YXNldCh0YWJsZSwgcm9vdF9wYXRoPWRlc3RfcGF0aCwgcGFydGl0aW9uX2NvbHM9cGFydGl0aW9uX2NvbHMpCiAgICAgICAgZWxzZToKICAgICAgICAgICAgcHF3cml0ZXIud3JpdGVfdGFibGUodGFibGUpCiAgICBpZiBwcXdyaXRlcjoKICAgICAgICBwcXdyaXRlci5jbG9zZSgpCgogICAgcmV0dXJuIGhlYWRlcgoKCmRlZiBhcmNfdG9fcGFycXVldCgKICAgICAgICBjb250ZXh0OiBNTENsaWVudEN0eCwKICAgICAgICBhcmNoaXZlX3VybDogRGF0YUl0ZW0sCiAgICAgICAgaGVhZGVyOiBMaXN0W3N0cl0gPSBbTm9uZV0sCiAgICAgICAgY2h1bmtzaXplOiBpbnQgPSAwLAogICAgICAgIGR0eXBlPU5vbmUsCiAgICAgICAgZW5jb2Rpbmc6IHN0ciA9ICJsYXRpbi0xIiwKICAgICAgICBrZXk6IHN0ciA9ICJkYXRhIiwKICAgICAgICBkYXRhc2V0OiBzdHIgPSAiTm9uZSIsCiAgICAgICAgcGFydF9jb2xzPVtdLAogICAgICAgIGZpbGVfZXh0OiBzdHIgPSAicGFycXVldCIsCiAgICAgICAgaW5kZXg6IGJvb2wgPSBGYWxzZSwKICAgICAgICByZWZyZXNoX2RhdGE6IGJvb2wgPSBGYWxzZSwKICAgICAgICBzdGF0czogYm9vbCA9IEZhbHNlCikgLT4gTm9uZToKICAgICIiIk9wZW4gYSBmaWxlL29iamVjdCBhcmNoaXZlIGFuZCBzYXZlIGFzIGEgcGFycXVldCBmaWxlIG9yIGRhdGFzZXQKCiAgICBOb3RlcwogICAgLS0tLS0KICAgICogdGhpcyBmdW5jdGlvbiBpcyB0eXBpY2FsbHkgZm9yIGxhcmdlIGZpbGVzLCBwbGVhc2UgYmUgc3VyZSB0byBjaGVjayBhbGwgc2V0dGluZ3MKICAgICogcGFydGl0aW9uaW5nIHJlcXVpcmVzIHByZWNpc2Ugc3BlY2lmaWNhdGlvbiBvZiBjb2x1bW4gdHlwZXMuCiAgICAqIHRoZSBhcmNoaXZlX3VybCBjYW4gYmUgYW55IGZpbGUgcmVhZGFibGUgYnkgcGFuZGFzIHJlYWRfY3N2LCB3aGljaCBpbmNsdWRlcyB0YXIgZmlsZXMKICAgICogaWYgdGhlIGBkYXRhc2V0YCBwYXJhbWV0ZXIgaXMgbm90IGVtcHR5LCB0aGVuIGEgcGFydGl0aW9uZWQgZGF0YXNldCB3aWxsIGJlIGNyZWF0ZWQKICAgIGluc3RlYWQgb2YgYSBzaW5nbGUgZmlsZSBpbiB0aGUgZm9sZGVyIGBkYXRhc2V0YAogICAgKiBpZiBhIGtleSBleGlzdHMgYWxyZWFkeSB0aGVuIGl0IHdpbGwgbm90IGJlIHJlLWFjcXVpcmVkIHVubGVzcyB0aGUgYHJlZnJlc2hfZGF0YWAgcGFyYW0KICAgIGlzIHNldCB0byBgVHJ1ZWAuICBUaGlzIGlzIGluIGNhc2UgdGhlIG9yaWdpbmFsIGZpbGUgaXMgY29ycnVwdCwgb3IgYSByZWZyZXNoIGlzCiAgICByZXF1aXJlZC4KCiAgICA6cGFyYW0gY29udGV4dDogICAgICAgIHRoZSBm
dW5jdGlvbiBjb250ZXh0CiAgICA6cGFyYW0gYXJjaGl2ZV91cmw6ICAgIE1MUnVuIGRhdGEgaW5wdXQgKERhdGFJdGVtIG9iamVjdCkKICAgIDpwYXJhbSBjaHVua3NpemU6ICAgICAgKDApIHdoZW4gPiAwLCByb3cgc2l6ZSAoY2h1bmspIHRvIHJldHJpZXZlCiAgICAgICAgICAgICAgICAgICAgICAgICAgIHBlciBpdGVyYXRpb24KICAgIDpwYXJhbSBkdHlwZSAgICAgICAgICAgZGVzdGluYXRpb24gZGF0YSB0eXBlIG9mIHNwZWNpZmllZCBjb2x1bW5zCiAgICA6cGFyYW0gZW5jb2RpbmcgICAgICAgICgibGF0aW4tOCIpIGZpbGUgZW5jb2RpbmcKICAgIDpwYXJhbSBrZXk6ICAgICAgICAgICAga2V5IGluIGFydGlmYWN0IHN0b3JlICh3aGVuIGxvZ19kYXRhPVRydWUpCiAgICA6cGFyYW0gZGF0YXNldDogICAgICAgIChOb25lKSBpZiBub3QgTm9uZSB0aGVuICJ0YXJnZXRfcGF0aC9kYXRhc2V0IgogICAgICAgICAgICAgICAgICAgICAgICAgICBpcyBmb2xkZXIgZm9yIHBhcnRpdGlvbmVkIGZpbGVzCiAgICA6cGFyYW0gcGFydF9jb2xzOiAgICAgIChbXSkgbGlzdCBvZiBwYXJ0aXRpb25pbmcgY29sdW1ucwogICAgOnBhcmFtIGZpbGVfZXh0OiAgICAgICAocGFycXVldCkgY3N2L3BhcnF1ZXQgZmlsZSBleHRlbnNpb24KICAgIDpwYXJhbSBpbmRleDogICAgICAgICAgKEZhbHNlKSBwYW5kYXMgc2F2ZSBpbmRleCBvcHRpb24KICAgIDpwYXJhbSByZWZyZXNoX2RhdGE6ICAgKEZhbHNlKSBvdmVyd3JpdGUgZXhpc3RpbmcgZGF0YSBhdCB0aGF0IGxvY2F0aW9uCiAgICA6cGFyYW0gc3RhdHM6ICAgICAgICAgIChOb25lKSBjYWxjdWxhdGUgdGFibGUgc3RhdHMgd2hlbiBsb2dnaW5nIGFydGlmYWN0CiAgICAiIiIKICAgIGJhc2VfcGF0aCA9IGNvbnRleHQuYXJ0aWZhY3RfcGF0aAogICAgb3MubWFrZWRpcnMoYmFzZV9wYXRoLCBleGlzdF9vaz1UcnVlKQoKICAgIGFyY2hpdmVfdXJsID0gYXJjaGl2ZV91cmwubG9jYWwoKQoKICAgIGlmIGRhdGFzZXQgaXMgbm90IE5vbmU6CiAgICAgICAgZGVzdF9wYXRoID0gb3MucGF0aC5qb2luKGJhc2VfcGF0aCwgZGF0YXNldCkKICAgICAgICBleGlzdHMgPSBvcy5wYXRoLmlzZGlyKGRlc3RfcGF0aCkKICAgIGVsc2U6CiAgICAgICAgZGVzdF9wYXRoID0gb3MucGF0aC5qb2luKGJhc2VfcGF0aCwga2V5ICsgZiIue2ZpbGVfZXh0fSIpCiAgICAgICAgZXhpc3RzID0gb3MucGF0aC5pc2ZpbGUoZGVzdF9wYXRoKQoKICAgIGlmIG5vdCBleGlzdHM6CiAgICAgICAgY29udGV4dC5sb2dnZXIuaW5mbygiZGVzdGluYXRpb24gZmlsZSBkb2VzIG5vdCBleGlzdCwgZG93bmxvYWRpbmciKQogICAgICAgIGlmIGNodW5rc2l6ZSA+IDA6CiAgICAgICAgICAgIGhlYWRlciA9IF9jaHVua19yZWFkd3JpdGUoYXJjaGl2ZV91cmwsIGRlc3RfcGF0aCwgY2h1bmtzaXplLAogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGVuY29kaW5nLCBkdHlwZSwgZGF0YXNldCkKICAgICAgICAgICAgY29udGV4
dC5sb2dfZGF0YXNldChrZXk9a2V5LCBzdGF0cz1zdGF0cywgZm9ybWF0PSdwYXJxdWV0JywKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB0YXJnZXRfcGF0aD1kZXN0X3BhdGgpCiAgICAgICAgZWxzZToKICAgICAgICAgICAgZGYgPSBwZC5yZWFkX2NzdihhcmNoaXZlX3VybCkKICAgICAgICAgICAgY29udGV4dC5sb2dfZGF0YXNldChrZXksIGRmPWRmLCBmb3JtYXQ9ZmlsZV9leHQsIGluZGV4PWluZGV4KQogICAgZWxzZToKICAgICAgICBjb250ZXh0LmxvZ2dlci5pbmZvKCJkZXN0aW5hdGlvbiBmaWxlIGFscmVhZHkgZXhpc3RzLCBub3RoaW5nIGRvbmUiKQ== origin_filename: '' + functionSourceCode: IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKaW1wb3J0IG9zCgppbXBvcnQgbnVtcHkgYXMgbnAKaW1wb3J0IHBhbmRhcyBhcyBwZAppbXBvcnQgcHlhcnJvdyBhcyBwYQppbXBvcnQgcHlhcnJvdy5wYXJxdWV0IGFzIHBxCmZyb20gbWxydW4uZGF0YXN0b3JlIGltcG9ydCBEYXRhSXRlbQpmcm9tIG1scnVuLmV4ZWN1dGlvbiBpbXBvcnQgTUxDbGllbnRDdHgKCgpkZWYgX2NodW5rX3JlYWR3cml0ZSgKICAgIGFyY2hpdmVfdXJsLCBkZXN0X3BhdGgsIGNodW5rc2l6ZSwgaGVhZGVyLCBlbmNvZGluZywgZHR5cGUsIGRhdGFzZXQKKToKICAgICIiInN0cmVhbSByZWFkIGFuZCB3cml0ZSBhcmNoaXZlcwoKICAgIHBhbmRhcyByZWFkcyBhbmQgcGFycXVldCB3cml0ZXMKCiAgICBub3RlcwogICAgLS0tLS0KICAgICogZGVzdF9wYXRoIGNhbiBiZSBlaXRoZXIgYSBmaWxlLnBhcnF1ZXQsIG9yIGluIGh0ZSBjYXNlIG9mIHBhcnRpdGlvbmVkIHBhcnF1ZXQKICAgICAgaXQgd2lsbCBiZSBvbmx5IHRoZSBkZXN0aW5hdGlvbiBmb2xkZXIgb2YgdGhlIHBhcnF1ZXQgcGFydGl0aW9uIGZpbGVzCiAgICAiIiIKICAgIHBxd3JpdGVyID0gTm9uZQogICAgaGVhZ
GVyID0gW10KICAgIGZvciBpLCBkZiBpbiBlbnVtZXJhdGUoCiAgICAgICAgcGQucmVhZF9jc3YoCiAgICAgICAgICAgIGFyY2hpdmVfdXJsLAogICAgICAgICAgICBjaHVua3NpemU9Y2h1bmtzaXplLAogICAgICAgICAgICBuYW1lcz1oZWFkZXIsCiAgICAgICAgICAgIGVuY29kaW5nPWVuY29kaW5nLAogICAgICAgICAgICBkdHlwZT1kdHlwZSwKICAgICAgICApCiAgICApOgogICAgICAgIHRhYmxlID0gcGEuVGFibGUuZnJvbV9wYW5kYXMoZGYpCiAgICAgICAgaWYgaSA9PSAwOgogICAgICAgICAgICBpZiBkYXRhc2V0OgogICAgICAgICAgICAgICAgaGVhZGVyID0gbnAuY29weSh0YWJsZS5zY2hlbWEpCiAgICAgICAgICAgIGVsc2U6CiAgICAgICAgICAgICAgICBwcXdyaXRlciA9IHBxLlBhcnF1ZXRXcml0ZXIoZGVzdF9wYXRoLCB0YWJsZS5zY2hlbWEpCiAgICAgICAgaWYgZGF0YXNldDoKICAgICAgICAgICAgcHEud3JpdGVfdG9fZGF0YXNldCgKICAgICAgICAgICAgICAgIHRhYmxlLCByb290X3BhdGg9ZGVzdF9wYXRoLCBwYXJ0aXRpb25fY29scz1wYXJ0aXRpb25fY29scwogICAgICAgICAgICApCiAgICAgICAgZWxzZToKICAgICAgICAgICAgcHF3cml0ZXIud3JpdGVfdGFibGUodGFibGUpCiAgICBpZiBwcXdyaXRlcjoKICAgICAgICBwcXdyaXRlci5jbG9zZSgpCgogICAgcmV0dXJuIGhlYWRlcgoKCmRlZiBhcmNfdG9fcGFycXVldCgKICAgIGNvbnRleHQ6IE1MQ2xpZW50Q3R4LAogICAgYXJjaGl2ZV91cmw6IERhdGFJdGVtLAogICAgaGVhZGVyOiBsaXN0W3N0cl0gPSBbTm9uZV0sCiAgICBjaHVua3NpemU6IGludCA9IDAsCiAgICBkdHlwZT1Ob25lLAogICAgZW5jb2Rpbmc6IHN0ciA9ICJsYXRpbi0xIiwKICAgIGtleTogc3RyID0gImRhdGEiLAogICAgZGF0YXNldDogc3RyID0gIk5vbmUiLAogICAgcGFydF9jb2xzPVtdLAogICAgZmlsZV9leHQ6IHN0ciA9ICJwYXJxdWV0IiwKICAgIGluZGV4OiBib29sID0gRmFsc2UsCiAgICByZWZyZXNoX2RhdGE6IGJvb2wgPSBGYWxzZSwKICAgIHN0YXRzOiBib29sID0gRmFsc2UsCikgLT4gTm9uZToKICAgICIiIk9wZW4gYSBmaWxlL29iamVjdCBhcmNoaXZlIGFuZCBzYXZlIGFzIGEgcGFycXVldCBmaWxlIG9yIGRhdGFzZXQKCiAgICBOb3RlcwogICAgLS0tLS0KICAgICogdGhpcyBmdW5jdGlvbiBpcyB0eXBpY2FsbHkgZm9yIGxhcmdlIGZpbGVzLCBwbGVhc2UgYmUgc3VyZSB0byBjaGVjayBhbGwgc2V0dGluZ3MKICAgICogcGFydGl0aW9uaW5nIHJlcXVpcmVzIHByZWNpc2Ugc3BlY2lmaWNhdGlvbiBvZiBjb2x1bW4gdHlwZXMuCiAgICAqIHRoZSBhcmNoaXZlX3VybCBjYW4gYmUgYW55IGZpbGUgcmVhZGFibGUgYnkgcGFuZGFzIHJlYWRfY3N2LCB3aGljaCBpbmNsdWRlcyB0YXIgZmlsZXMKICAgICogaWYgdGhlIGBkYXRhc2V0YCBwYXJhbWV0ZXIgaXMgbm90IGVtcHR5LCB0aGVuIGEgcGFydGl0aW9uZWQgZGF0YXNldCB3aWxsIGJlIGNyZWF0ZWQKICAgIGluc3RlYWQgb
2YgYSBzaW5nbGUgZmlsZSBpbiB0aGUgZm9sZGVyIGBkYXRhc2V0YAogICAgKiBpZiBhIGtleSBleGlzdHMgYWxyZWFkeSB0aGVuIGl0IHdpbGwgbm90IGJlIHJlLWFjcXVpcmVkIHVubGVzcyB0aGUgYHJlZnJlc2hfZGF0YWAgcGFyYW0KICAgIGlzIHNldCB0byBgVHJ1ZWAuICBUaGlzIGlzIGluIGNhc2UgdGhlIG9yaWdpbmFsIGZpbGUgaXMgY29ycnVwdCwgb3IgYSByZWZyZXNoIGlzCiAgICByZXF1aXJlZC4KCiAgICA6cGFyYW0gY29udGV4dDogICAgICAgIHRoZSBmdW5jdGlvbiBjb250ZXh0CiAgICA6cGFyYW0gYXJjaGl2ZV91cmw6ICAgIE1MUnVuIGRhdGEgaW5wdXQgKERhdGFJdGVtIG9iamVjdCkKICAgIDpwYXJhbSBjaHVua3NpemU6ICAgICAgKDApIHdoZW4gPiAwLCByb3cgc2l6ZSAoY2h1bmspIHRvIHJldHJpZXZlCiAgICAgICAgICAgICAgICAgICAgICAgICAgIHBlciBpdGVyYXRpb24KICAgIDpwYXJhbSBkdHlwZSAgICAgICAgICAgZGVzdGluYXRpb24gZGF0YSB0eXBlIG9mIHNwZWNpZmllZCBjb2x1bW5zCiAgICA6cGFyYW0gZW5jb2RpbmcgICAgICAgICgibGF0aW4tOCIpIGZpbGUgZW5jb2RpbmcKICAgIDpwYXJhbSBrZXk6ICAgICAgICAgICAga2V5IGluIGFydGlmYWN0IHN0b3JlICh3aGVuIGxvZ19kYXRhPVRydWUpCiAgICA6cGFyYW0gZGF0YXNldDogICAgICAgIChOb25lKSBpZiBub3QgTm9uZSB0aGVuICJ0YXJnZXRfcGF0aC9kYXRhc2V0IgogICAgICAgICAgICAgICAgICAgICAgICAgICBpcyBmb2xkZXIgZm9yIHBhcnRpdGlvbmVkIGZpbGVzCiAgICA6cGFyYW0gcGFydF9jb2xzOiAgICAgIChbXSkgbGlzdCBvZiBwYXJ0aXRpb25pbmcgY29sdW1ucwogICAgOnBhcmFtIGZpbGVfZXh0OiAgICAgICAocGFycXVldCkgY3N2L3BhcnF1ZXQgZmlsZSBleHRlbnNpb24KICAgIDpwYXJhbSBpbmRleDogICAgICAgICAgKEZhbHNlKSBwYW5kYXMgc2F2ZSBpbmRleCBvcHRpb24KICAgIDpwYXJhbSByZWZyZXNoX2RhdGE6ICAgKEZhbHNlKSBvdmVyd3JpdGUgZXhpc3RpbmcgZGF0YSBhdCB0aGF0IGxvY2F0aW9uCiAgICA6cGFyYW0gc3RhdHM6ICAgICAgICAgIChOb25lKSBjYWxjdWxhdGUgdGFibGUgc3RhdHMgd2hlbiBsb2dnaW5nIGFydGlmYWN0CiAgICAiIiIKICAgIGJhc2VfcGF0aCA9IGNvbnRleHQuYXJ0aWZhY3RfcGF0aAogICAgb3MubWFrZWRpcnMoYmFzZV9wYXRoLCBleGlzdF9vaz1UcnVlKQoKICAgIGFyY2hpdmVfdXJsID0gYXJjaGl2ZV91cmwubG9jYWwoKQoKICAgIGlmIGRhdGFzZXQgaXMgbm90IE5vbmU6CiAgICAgICAgZGVzdF9wYXRoID0gb3MucGF0aC5qb2luKGJhc2VfcGF0aCwgZGF0YXNldCkKICAgICAgICBleGlzdHMgPSBvcy5wYXRoLmlzZGlyKGRlc3RfcGF0aCkKICAgIGVsc2U6CiAgICAgICAgZGVzdF9wYXRoID0gb3MucGF0aC5qb2luKGJhc2VfcGF0aCwga2V5ICsgZiIue2ZpbGVfZXh0fSIpCiAgICAgICAgZXhpc3RzID0gb3MucGF0aC5pc2ZpbGUoZGVzdF9wYXRoKQoKICAgIGlmIG5vd
CBleGlzdHM6CiAgICAgICAgY29udGV4dC5sb2dnZXIuaW5mbygiZGVzdGluYXRpb24gZmlsZSBkb2VzIG5vdCBleGlzdCwgZG93bmxvYWRpbmciKQogICAgICAgIGlmIGNodW5rc2l6ZSA+IDA6CiAgICAgICAgICAgIGhlYWRlciA9IF9jaHVua19yZWFkd3JpdGUoCiAgICAgICAgICAgICAgICBhcmNoaXZlX3VybCwgZGVzdF9wYXRoLCBjaHVua3NpemUsIGVuY29kaW5nLCBkdHlwZSwgZGF0YXNldAogICAgICAgICAgICApCiAgICAgICAgICAgIGNvbnRleHQubG9nX2RhdGFzZXQoCiAgICAgICAgICAgICAgICBrZXk9a2V5LCBzdGF0cz1zdGF0cywgZm9ybWF0PSJwYXJxdWV0IiwgdGFyZ2V0X3BhdGg9ZGVzdF9wYXRoCiAgICAgICAgICAgICkKICAgICAgICBlbHNlOgogICAgICAgICAgICBkZiA9IHBkLnJlYWRfY3N2KGFyY2hpdmVfdXJsKQogICAgICAgICAgICBjb250ZXh0LmxvZ19kYXRhc2V0KGtleSwgZGY9ZGYsIGZvcm1hdD1maWxlX2V4dCwgaW5kZXg9aW5kZXgpCiAgICBlbHNlOgogICAgICAgIGNvbnRleHQubG9nZ2VyLmluZm8oImRlc3RpbmF0aW9uIGZpbGUgYWxyZWFkeSBleGlzdHMsIG5vdGhpbmcgZG9uZSIpCg== code_origin: '' - command: '' - default_handler: arc_to_parquet - description: retrieve remote archive, open and save as parquet + filename: arc_to_parquet.py entry_points: arc_to_parquet: - name: arc_to_parquet - has_varargs: false + outputs: + - type: None parameters: - name: context type: MLClientCtx @@ -20,7 +25,7 @@ spec: type: DataItem doc: MLRun data input (DataItem object) - name: header - type: List[str] + type: list[str] default: - null - name: chunksize @@ -60,6 +65,7 @@ spec: type: bool doc: (None) calculate table stats when logging artifact default: false + name: arc_to_parquet doc: 'Open a file/object archive and save as a parquet file or dataset @@ -88,13 +94,8 @@ spec: required.' 
has_kwargs: false - outputs: - - type: None - lineno: 68 - disable_auto_mount: false -metadata: - categories: - - utils - name: arc-to-parquet - tag: '' -verbose: false + has_varargs: false + lineno: 66 + command: '' + description: retrieve remote archive, open and save as parquet + default_handler: arc_to_parquet diff --git a/functions/src/arc_to_parquet/test_arc_to_parquet.py b/functions/src/arc_to_parquet/test_arc_to_parquet.py index f0299f57c..ec990b66a 100644 --- a/functions/src/arc_to_parquet/test_arc_to_parquet.py +++ b/functions/src/arc_to_parquet/test_arc_to_parquet.py @@ -16,28 +16,36 @@ DATA_URL = "https://s3.wasabisys.com/iguazio/data/market-palce/arc_to_parquet/higgs-sample.csv.gz" + def test_run_arc_to_parquet(): - fn = code_to_function(name='test_arc_to_parquet', - filename="arc_to_parquet.py", - handler="arc_to_parquet", - kind="local", - ) - run = fn.run(params={"key": "higgs-sample"}, - handler="arc_to_parquet", - inputs={"archive_url": DATA_URL}, - artifact_path='artifacts', - local=False) - - assert(run.outputs['higgs-sample']) + fn = code_to_function( + name="test_arc_to_parquet", + filename="arc_to_parquet.py", + handler="arc_to_parquet", + kind="local", + ) + run = fn.run( + params={"key": "higgs-sample"}, + handler="arc_to_parquet", + inputs={"archive_url": DATA_URL}, + artifact_path="artifacts", + local=False, + ) + + assert run.outputs["higgs-sample"] + def test_run_local_arc_to_parquet(): import os + os.getcwd() fn = import_function("function.yaml") - run = fn.run(params={"key": "higgs-sample"}, - handler="arc_to_parquet", - inputs={"archive_url": DATA_URL}, - artifact_path=os.getcwd()+'/artifacts', - local=True) - - assert(run.outputs['higgs-sample']) \ No newline at end of file + run = fn.run( + params={"key": "higgs-sample"}, + handler="arc_to_parquet", + inputs={"archive_url": DATA_URL}, + artifact_path=os.getcwd() + "/artifacts", + local=True, + ) + + assert run.outputs["higgs-sample"] diff --git 
a/functions/src/auto_trainer/auto_trainer.py b/functions/src/auto_trainer/auto_trainer.py index 7b4764700..d9ad2c8e8 100755 --- a/functions/src/auto_trainer/auto_trainer.py +++ b/functions/src/auto_trainer/auto_trainer.py @@ -13,7 +13,7 @@ # limitations under the License. # from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Union import mlrun import mlrun.datastore @@ -23,7 +23,7 @@ from mlrun.datastore import DataItem from mlrun.execution import MLClientCtx from mlrun.frameworks.auto_mlrun import AutoMLRun -from mlrun.utils.helpers import create_class, create_function +from mlrun.utils.helpers import create_class from sklearn.model_selection import train_test_split PathType = Union[str, Path] @@ -35,7 +35,7 @@ class KWArgsPrefixes: TRAIN = "TRAIN_" -def _get_sub_dict_by_prefix(src: Dict, prefix_key: str) -> Dict[str, Any]: +def _get_sub_dict_by_prefix(src: dict, prefix_key: str) -> dict[str, Any]: """ Collect all the keys from the given dict that starts with the given prefix and creates a new dictionary with these keys. @@ -54,9 +54,9 @@ def _get_sub_dict_by_prefix(src: Dict, prefix_key: str) -> Dict[str, Any]: def _get_dataframe( context: MLClientCtx, dataset: DataItem, - label_columns: Optional[Union[str, List[str]]] = None, - drop_columns: Union[str, List[str], int, List[int]] = None, -) -> Tuple[pd.DataFrame, Optional[Union[str, List[str]]]]: + label_columns: str | list[str] | None = None, + drop_columns: str | list[str] | int | list[int] = None, +) -> tuple[pd.DataFrame, str | list[str] | None]: """ Getting the DataFrame of the dataset and drop the columns accordingly. 
@@ -122,8 +122,8 @@ def train( context: MLClientCtx, dataset: DataItem, model_class: str, - label_columns: Optional[Union[str, List[str]]] = None, - drop_columns: List[str] = None, + label_columns: str | list[str] | None = None, + drop_columns: list[str] = None, model_name: str = "model", tag: str = "", sample_set: DataItem = None, @@ -139,6 +139,7 @@ def train( example:: import mlrun + project = mlrun.get_or_create_project("my-project") project.set_function("hub://auto_trainer", "train") trainer_run = project.run( @@ -210,7 +211,7 @@ def train( # Getting the sample set: if sample_set is None: context.logger.info( - f"Sample set not given, using the whole training set as the sample set" + "Sample set not given, using the whole training set as the sample set" ) sample_set = dataset else: @@ -274,8 +275,8 @@ def evaluate( context: MLClientCtx, model: str, dataset: mlrun.DataItem, - drop_columns: List[str] = None, - label_columns: Optional[Union[str, List[str]]] = None, + drop_columns: list[str] = None, + label_columns: str | list[str] | None = None, **kwargs, ): """ @@ -328,9 +329,9 @@ def predict( context: MLClientCtx, model: str, dataset: mlrun.DataItem, - drop_columns: Union[str, List[str], int, List[int]] = None, - label_columns: Optional[Union[str, List[str]]] = None, - result_set: Optional[str] = None, + drop_columns: str | list[str] | int | list[int] = None, + label_columns: str | list[str] | None = None, + result_set: str | None = None, **kwargs, ): """ diff --git a/functions/src/auto_trainer/function.yaml b/functions/src/auto_trainer/function.yaml index 0920b1033..bb0f13ce8 100644 --- a/functions/src/auto_trainer/function.yaml +++ b/functions/src/auto_trainer/function.yaml @@ -1,22 +1,21 @@ metadata: + tag: '' + name: auto-trainer categories: - machine-learning - model-training - tag: '' - name: auto-trainer +verbose: false +kind: job spec: image: mlrun/mlrun + disable_auto_mount: false build: origin_filename: '' - functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import mlrun
import mlrun.datastore
import mlrun.utils
import pandas as pd
from mlrun import feature_store as fs
from mlrun.datastore import DataItem
from mlrun.execution import MLClientCtx
from mlrun.frameworks.auto_mlrun import AutoMLRun
from mlrun.utils.helpers import create_class, create_function
from sklearn.model_selection import train_test_split

PathType = Union[str, Path]


class KWArgsPrefixes:
    """
    Prefixes that mark which destination an extra keyword argument is meant for.
    """

    # Arguments aimed at a `train` function:
    TRAIN = "TRAIN_"
    # Arguments aimed at the model's `fit` call:
    FIT = "FIT_"
    # Arguments aimed at the model class constructor:
    MODEL_CLASS = "CLASS_"


def _get_sub_dict_by_prefix(src: Dict, prefix_key: str) -> Dict[str, Any]:
    """
    Collect all the keys from the given dict that starts with the given prefix and creates a new dictionary with these
    keys.

    :param src:         The source dict to extract the values from.
    :param prefix_key:  Only keys with this prefix will be returned. The keys in the result dict will be without this
                        prefix.
    """
    return {
        key.replace(prefix_key, ""): val
        for key, val in src.items()
        if key.startswith(prefix_key)
    }


def _get_dataframe(
    context: MLClientCtx,
    dataset: DataItem,
    label_columns: Optional[Union[str, List[str]]] = None,
    drop_columns: Union[str, List[str], int, List[int]] = None,
) -> Tuple[pd.DataFrame, Optional[Union[str, List[str]]]]:
    """
    Getting the DataFrame of the dataset and drop the columns accordingly.

    :param context:         MLRun context.
    :param dataset:         The dataset to train the model on.
                            Can be either a list of lists, dict, URI or a FeatureVector.
    :param label_columns:   The target label(s) of the column(s) in the dataset. for Regression or
                            Classification tasks.
    :param drop_columns:    str/int or a list of strings/ints that represent the column names/indices to drop.

    :returns: A tuple of the dataset as a DataFrame and the label columns (taken from the FeatureVector when they
              were not provided).

    :raises ValueError: If `label_columns` is missing for a dataset that is not a FeatureVector, or if string
                        columns are given to drop from a list / dict dataset.
    """
    # A raw list / dict has no `artifact_url`, so it must be detected before the store URI is parsed:
    is_raw_data = isinstance(dataset, (list, dict))
    store_uri_prefix = None
    if not is_raw_data:
        store_uri_prefix, _ = mlrun.datastore.parse_store_uri(dataset.artifact_url)

    # Normalize a single column name / index into a list, so a string is never iterated character by character:
    if drop_columns is None:
        columns_to_drop = []
    elif isinstance(drop_columns, (str, int)):
        columns_to_drop = [drop_columns]
    else:
        columns_to_drop = list(drop_columns)

    # Getting the dataset:
    if not is_raw_data and mlrun.utils.StorePrefix.FeatureVector == store_uri_prefix:
        label_columns = label_columns or dataset.meta.status.label_column
        context.logger.info(f"label columns: {label_columns}")
        # FeatureVector case:
        try:
            fv = mlrun.datastore.get_store_resource(dataset.artifact_url)
            dataset = fv.get_offline_features(drop_columns=drop_columns).to_dataframe()
        except AttributeError:
            # Leave here for backwards compatibility
            dataset = fs.get_offline_features(
                dataset.meta.uri, drop_columns=drop_columns
            ).to_dataframe()

    elif not label_columns:
        message = (
            "label_columns not provided, mandatory when dataset is not a FeatureVector"
        )
        context.logger.info(message)
        raise ValueError(message)

    elif is_raw_data:
        # list/dict case:
        dataset = pd.DataFrame(dataset)
        # Columns of a raw dataset can only be dropped by their integer index:
        if columns_to_drop:
            if any(isinstance(col, str) for col in columns_to_drop):
                message = "drop_columns must be an integer/list of integers if not provided with a URI/FeatureVector dataset"
                context.logger.error(message)
                raise ValueError(message)
            dataset = dataset.drop(columns_to_drop, axis=1)

    else:
        # simple URL case:
        dataset = dataset.as_df()
        if columns_to_drop:
            # Drop only when every requested column exists, otherwise leave the dataset untouched:
            if all(col in dataset for col in columns_to_drop):
                dataset = dataset.drop(columns_to_drop, axis=1)
            else:
                context.logger.info(
                    "not all of the columns to drop in the dataset, drop columns process skipped"
                )

    return dataset, label_columns


def train(
    context: MLClientCtx,
    dataset: DataItem,
    model_class: str,
    label_columns: Optional[Union[str, List[str]]] = None,
    drop_columns: List[str] = None,
    model_name: str = "model",
    tag: str = "",
    sample_set: DataItem = None,
    test_set: DataItem = None,
    train_test_split_size: float = None,
    random_state: int = None,
    labels: dict = None,
    **kwargs,
):
    """
    Training a model with the given dataset.

    example::

        import mlrun
        project = mlrun.get_or_create_project("my-project")
        project.set_function("hub://auto_trainer", "train")
        trainer_run = project.run(
            name="train",
            handler="train",
            inputs={"dataset": "./path/to/dataset.csv"},
            params={
                "model_class": "sklearn.linear_model.LogisticRegression",
                "label_columns": "label",
                "drop_columns": "id",
                "model_name": "my-model",
                "tag": "v1.0.0",
                "sample_set": "./path/to/sample_set.csv",
                "test_set": "./path/to/test_set.csv",
                "CLASS_solver": "liblinear",
            },
        )

    :param context:                 MLRun context
    :param dataset:                 The dataset to train the model on. Can be either a URI or a FeatureVector
    :param model_class:             The class of the model, e.g. `sklearn.linear_model.LogisticRegression`
    :param label_columns:           The target label(s) of the column(s) in the dataset. for Regression or
                                    Classification tasks. Mandatory when dataset is not a FeatureVector.
    :param drop_columns:            str or a list of strings that represent the columns to drop
    :param model_name:              The model's name to use for storing the model artifact, default to 'model'
    :param tag:                     The model's tag to log with
    :param sample_set:              A sample set of inputs for the model for logging its stats along the model in favour
                                    of model monitoring. Can be either a URI or a FeatureVector
    :param test_set:                The test set to test the model with.
    :param train_test_split_size:   if test_set was provided then this argument is ignored.
                                    Should be between 0.0 and 1.0 and represent the proportion of the dataset to include
                                    in the test split. The size of the Training set is set to the complement of this
                                    value. Default = 0.2
    :param random_state:            Relevant only when using train_test_split_size.
                                    A random state seed to shuffle the data. For more information, see:
                                    https://scikit-learn.org/stable/glossary.html#term-random_state
                                    Notice that here we only pass integer values.
    :param labels:                  Labels to log with the model
    :param kwargs:                  Here you can pass keyword arguments with prefixes,
                                    that will be parsed and passed to the relevant function, by the following prefixes:
                                    - `CLASS_` - for the model class arguments
                                    - `FIT_` - for the `fit` function arguments
                                    - `TRAIN_` - for the `train` function (in xgb or lgbm train function - future)

    :raises NotImplementedError: If `model_class` exposes a `train` attribute (training functions are not supported
                                 yet).
    """
    # Validate inputs:
    # Check if exactly one of them is supplied:
    if test_set is None:
        if train_test_split_size is None:
            context.logger.info(
                "test_set or train_test_split_size are not provided, setting train_test_split_size to 0.2"
            )
            train_test_split_size = 0.2

    elif train_test_split_size:
        context.logger.info(
            "test_set provided, ignoring given train_test_split_size value"
        )
        train_test_split_size = None

    # Get DataFrame by URL or by FeatureVector:
    dataset, label_columns = _get_dataframe(
        context=context,
        dataset=dataset,
        label_columns=label_columns,
        drop_columns=drop_columns,
    )

    # Getting the sample set:
    if sample_set is None:
        context.logger.info(
            "Sample set not given, using the whole training set as the sample set"
        )
        sample_set = dataset
    else:
        sample_set, _ = _get_dataframe(
            context=context,
            dataset=sample_set,
            label_columns=label_columns,
            drop_columns=drop_columns,
        )

    # Parsing kwargs:
    # TODO: Parse the `TRAIN_` prefixed kwargs as well once the xgb / lgbm train function is supported.
    fit_kwargs = _get_sub_dict_by_prefix(src=kwargs, prefix_key=KWArgsPrefixes.FIT)
    model_class_kwargs = _get_sub_dict_by_prefix(
        src=kwargs, prefix_key=KWArgsPrefixes.MODEL_CLASS
    )

    # Check if model or function:
    # NOTE(review): `model_class` is annotated as a str, which has no `train` attribute, so this branch looks
    # unreachable for the documented input - confirm before relying on it.
    if hasattr(model_class, "train"):
        # TODO: Need to create the function from f"{model_class}.train" and call it to start the training.
        raise NotImplementedError
    else:
        # Creating model instance:
        model = create_class(model_class)(**model_class_kwargs)

    x = dataset.drop(label_columns, axis=1)
    y = dataset[label_columns]
    if train_test_split_size:
        x_train, x_test, y_train, y_test = train_test_split(
            x, y, test_size=train_test_split_size, random_state=random_state
        )
    else:
        x_train, y_train = x, y

        test_set = test_set.as_df()
        if drop_columns:
            # Drop from the test set itself - dropping from `dataset` would replace the test data with the
            # training data:
            test_set = test_set.drop(drop_columns, axis=1)

        x_test, y_test = test_set.drop(label_columns, axis=1), test_set[label_columns]

    # The MLRun handler is applied before fitting, so the `fit` call below is the one it wraps:
    AutoMLRun.apply_mlrun(
        model=model,
        model_name=model_name,
        context=context,
        tag=tag,
        sample_set=sample_set,
        y_columns=label_columns,
        test_set=test_set,
        x_test=x_test,
        y_test=y_test,
        artifacts=context.artifacts,
        labels=labels,
    )
    context.logger.info(f"training '{model_name}'")
    model.fit(x_train, y_train, **fit_kwargs)


def evaluate(
    context: MLClientCtx,
    model: str,
    dataset: mlrun.DataItem,
    drop_columns: List[str] = None,
    label_columns: Optional[Union[str, List[str]]] = None,
    **kwargs,
):
    """
    Evaluating a model. Artifacts generated by the MLHandler.

    :param context:                 MLRun context.
    :param model:                   The model Store path.
    :param dataset:                 The dataset to evaluate the model on. Can be either a URI or a FeatureVector.
    :param drop_columns:            str or a list of strings that represent the columns to drop.
    :param label_columns:           The target label(s) of the column(s) in the dataset. for Regression or
                                    Classification tasks. Mandatory when dataset is not a FeatureVector.
                                    Feature store qualified names are reduced to the column name (alias or label name)
                                    they have in the dataframe; plain names are used as they are.
    :param kwargs:                  Here you can pass keyword arguments to the predict function
                                    (PREDICT_ prefix is not required).
    """
    # Get dataset by URL or by FeatureVector:
    dataset, label_columns = _get_dataframe(
        context=context,
        dataset=dataset,
        label_columns=label_columns,
        drop_columns=drop_columns,
    )

    # Parsing label_columns:
    if label_columns:
        if not isinstance(label_columns, list):
            label_columns = [label_columns]
        resolved_label_columns = []
        for lc in label_columns:
            if fs.common.feature_separator in lc:
                # Qualified feature string - keep only the name the column carries in the dataframe:
                _, label_name, alias = fs.common.parse_feature_string(lc)
                resolved_label_columns.append(alias or label_name)
            else:
                # Plain column name - must be kept, otherwise a mixed list would lose this label:
                resolved_label_columns.append(lc)
        label_columns = resolved_label_columns

    x = dataset.drop(label_columns, axis=1)
    y = dataset[label_columns]

    # Loading the model and predicting:
    # NOTE(review): the handler is always loaded under the fixed name "model_LinearRegression", whatever the model
    # type is - confirm this is intended.
    model_handler = AutoMLRun.load_model(
        model_path=model, context=context, model_name="model_LinearRegression"
    )
    AutoMLRun.apply_mlrun(model_handler.model, y_test=y, model_path=model)

    context.logger.info(f"evaluating '{model_handler.model_name}'")
    model_handler.model.predict(x, **kwargs)


def predict(
    context: MLClientCtx,
    model: str,
    dataset: mlrun.DataItem,
    drop_columns: Union[str, List[str], int, List[int]] = None,
    label_columns: Optional[Union[str, List[str]]] = None,
    result_set: Optional[str] = None,
    **kwargs,
):
    """
    Predicting dataset by a model.

    The prediction is logged as a dataset artifact made of the input dataset with the predicted column(s) appended.

    :param context:                 MLRun context.
    :param model:                   The model Store path.
    :param dataset:                 The dataset to predict the model on. Can be either a URI, a FeatureVector or a
                                    sample in a shape of a list/dict.
                                    When passing a sample, pass the dataset as a field in `params` instead of `inputs`.
    :param drop_columns:            str/int or a list of strings/ints that represent the column names/indices to drop.
                                    When the dataset is a list/dict this parameter should be represented by integers.
    :param label_columns:           The target label(s) of the column(s) in the dataset. for Regression or
                                    Classification tasks. Mandatory when dataset is not a FeatureVector.
                                    Used as the names of the predicted columns, so they must not already exist in the
                                    dataset.
    :param result_set:              The db key to set name of the prediction result and the filename.
                                    Default to 'prediction'.
    :param kwargs:                  Here you can pass keyword arguments to the predict function
                                    (PREDICT_ prefix is not required).

    :raises ValueError: If there are more label columns than predicted outputs, or if a label column already exists in
                        the dataset.
    """
    # Get dataset by URL or by FeatureVector:
    dataset, label_columns = _get_dataframe(
        context=context,
        dataset=dataset,
        label_columns=label_columns,
        drop_columns=drop_columns,
    )

    # loading the model, and getting the model handler:
    model_handler = AutoMLRun.load_model(model_path=model, context=context)

    # Normalizing the label columns into a new list, so a list given by the caller is never mutated below:
    if not label_columns:
        label_columns = []
    elif isinstance(label_columns, str):
        label_columns = [label_columns]
    else:
        label_columns = list(label_columns)

    # Predicting:
    context.logger.info(f"making prediction by '{model_handler.model_name}'")
    y_pred = model_handler.model.predict(dataset, **kwargs)

    # Preparing and validating label columns for the dataframe of the prediction result:
    num_predicted = 1 if len(y_pred.shape) == 1 else y_pred.shape[1]

    if num_predicted > len(label_columns):
        if num_predicted == 1:
            label_columns = ["predicted labels"]
        else:
            # Generate names only for the outputs that were left without a label:
            first_unnamed = len(label_columns) + 1
            label_columns.extend(
                f"predicted_label_{i}" for i in range(first_unnamed, num_predicted + 1)
            )
    elif num_predicted < len(label_columns):
        message = f"number of predicted labels: {num_predicted} is smaller than number of label columns: {len(label_columns)}"
        context.logger.error(message)
        raise ValueError(message)

    artifact_name = result_set or "prediction"
    labels_inside_df = set(label_columns) & set(dataset.columns.tolist())
    if labels_inside_df:
        message = f"The labels: {labels_inside_df} are already existed in the dataframe"
        context.logger.error(message)
        raise ValueError(message)

    # The predictions reuse the dataset's index; otherwise `concat` would align a fresh 0..n-1 index against the
    # dataset's own index and produce misaligned rows filled with NaN:
    predictions = pd.DataFrame(y_pred, columns=label_columns, index=dataset.index)
    pred_df = pd.concat([dataset, predictions], axis=1)
    context.log_dataset(artifact_name, pred_df, db_key=result_set)
 + functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pathlib import Path
from typing import Any, Union

import mlrun
import mlrun.datastore
import mlrun.utils
import pandas as pd
from mlrun import feature_store as fs
from mlrun.datastore import DataItem
from mlrun.execution import MLClientCtx
from mlrun.frameworks.auto_mlrun import AutoMLRun
from mlrun.utils.helpers import create_class
from sklearn.model_selection import train_test_split

PathType = Union[str, Path]


class KWArgsPrefixes:
    """
    Key prefixes used to route the handlers' `**kwargs` to the call that should receive them.
    """

    # Arguments for the `train` function (reserved for xgb / lgbm train functions):
    TRAIN = "TRAIN_"
    # Arguments for the model's `fit` function:
    FIT = "FIT_"
    # Arguments for the model class initialization:
    MODEL_CLASS = "CLASS_"


def _get_sub_dict_by_prefix(src: dict, prefix_key: str) -> dict[str, Any]:
    """
    Collect all the keys from the given dict that start with the given prefix and create a new dictionary with these
    keys, stripped of the prefix.

    Only the leading prefix is removed: a key such as `FIT_x_FIT_y` is returned as `x_FIT_y`.

    :param src:         The source dict to extract the values from.
    :param prefix_key:  Only keys with this prefix will be returned. The keys in the result dict will be without this
                        prefix.

    :returns: A new dictionary with the matching items, keyed by the original keys without the prefix.
    """
    # Slicing (rather than `str.replace`) guarantees that later occurrences of the prefix inside the key survive:
    prefix_length = len(prefix_key)
    return {
        key[prefix_length:]: val
        for key, val in src.items()
        if key.startswith(prefix_key)
    }


def _get_dataframe(
    context: MLClientCtx,
    dataset: DataItem,
    label_columns: str | list[str] | None = None,
    drop_columns: str | list[str] | int | list[int] = None,
) -> tuple[pd.DataFrame, str | list[str] | None]:
    """
    Getting the DataFrame of the dataset and drop the columns accordingly.

    :param context:         MLRun context.
    :param dataset:         The dataset to train the model on.
                            Can be either a list of lists, dict, URI or a FeatureVector.
    :param label_columns:   The target label(s) of the column(s) in the dataset. for Regression or
                            Classification tasks. When not given and the dataset is a FeatureVector, the label column
                            of the vector is used.
    :param drop_columns:    str/int or a list of strings/ints that represent the column names/indices to drop.

    :returns: A tuple of the dataset as a DataFrame and the label columns that were resolved for it.

    :raises ValueError: If `label_columns` is missing for a dataset that is not a FeatureVector, or if `drop_columns`
                        holds strings while the dataset is a list/dict sample.
    """
    # A raw sample (list / dict) has no `artifact_url`, so the store prefix can be resolved only for data items:
    is_raw_sample = isinstance(dataset, (list, dict))
    store_uri_prefix = None
    if not is_raw_sample:
        store_uri_prefix, _ = mlrun.datastore.parse_store_uri(dataset.artifact_url)

    # A single column may be given as a scalar. Wrapping it in a list keeps a string from being iterated character
    # by character and keeps the column index 0 from being treated as "nothing to drop":
    if drop_columns is None:
        columns_to_drop = []
    elif isinstance(drop_columns, (str, int)):
        columns_to_drop = [drop_columns]
    else:
        columns_to_drop = list(drop_columns)

    # Getting the dataset:
    if mlrun.utils.StorePrefix.FeatureVector == store_uri_prefix:
        label_columns = label_columns or dataset.meta.status.label_column
        context.logger.info(f"label columns: {label_columns}")
        # FeatureVector case:
        try:
            fv = mlrun.datastore.get_store_resource(dataset.artifact_url)
            dataset = fv.get_offline_features(drop_columns=drop_columns).to_dataframe()
        except AttributeError:
            # Leave here for backwards compatibility
            dataset = fs.get_offline_features(
                dataset.meta.uri, drop_columns=drop_columns
            ).to_dataframe()

    elif not label_columns:
        message = "label_columns not provided, mandatory when dataset is not a FeatureVector"
        context.logger.info(message)
        raise ValueError(message)

    elif is_raw_sample:
        # list/dict case:
        dataset = pd.DataFrame(dataset)
        # Checking if drop_columns provided by integer type:
        if columns_to_drop:
            if any(isinstance(col, str) for col in columns_to_drop):
                message = "drop_columns must be an integer/list of integers if not provided with a URI/FeatureVector dataset"
                context.logger.error(message)
                raise ValueError(message)
            dataset.drop(columns_to_drop, axis=1, inplace=True)

    else:
        # simple URL case:
        dataset = dataset.as_df()
        if columns_to_drop:
            # Dropping is all-or-nothing: a single missing column skips the whole drop.
            if all(col in dataset for col in columns_to_drop):
                dataset = dataset.drop(columns_to_drop, axis=1)
            else:
                context.logger.info(
                    "not all of the columns to drop in the dataset, drop columns process skipped"
                )

    return dataset, label_columns


def train(
    context: MLClientCtx,
    dataset: DataItem,
    model_class: str,
    label_columns: str | list[str] | None = None,
    drop_columns: list[str] = None,
    model_name: str = "model",
    tag: str = "",
    sample_set: DataItem = None,
    test_set: DataItem = None,
    train_test_split_size: float = None,
    random_state: int = None,
    labels: dict = None,
    **kwargs,
):
    """
    Training a model with the given dataset.

    example::

        import mlrun

        project = mlrun.get_or_create_project("my-project")
        project.set_function("hub://auto_trainer", "train")
        trainer_run = project.run(
            name="train",
            handler="train",
            inputs={"dataset": "./path/to/dataset.csv"},
            params={
                "model_class": "sklearn.linear_model.LogisticRegression",
                "label_columns": "label",
                "drop_columns": "id",
                "model_name": "my-model",
                "tag": "v1.0.0",
                "sample_set": "./path/to/sample_set.csv",
                "test_set": "./path/to/test_set.csv",
                "CLASS_solver": "liblinear",
            },
        )

    :param context:                 MLRun context
    :param dataset:                 The dataset to train the model on. Can be either a URI or a FeatureVector
    :param model_class:             The class of the model, e.g. `sklearn.linear_model.LogisticRegression`
    :param label_columns:           The target label(s) of the column(s) in the dataset. for Regression or
                                    Classification tasks. Mandatory when dataset is not a FeatureVector.
    :param drop_columns:            str or a list of strings that represent the columns to drop. The columns are
                                    dropped from the dataset, the sample set and the test set.
    :param model_name:              The model's name to use for storing the model artifact, default to 'model'
    :param tag:                     The model's tag to log with
    :param sample_set:              A sample set of inputs for the model for logging its stats along the model in favour
                                    of model monitoring. Can be either a URI or a FeatureVector
    :param test_set:                The test set to evaluate the trained model on. It is not used for fitting.
    :param train_test_split_size:   if test_set was provided then this argument is ignored.
                                    Should be between 0.0 and 1.0 and represent the proportion of the dataset to include
                                    in the test split. The size of the Training set is set to the complement of this
                                    value. Default = 0.2
    :param random_state:            Relevant only when using train_test_split_size.
                                    A random state seed to shuffle the data. For more information, see:
                                    https://scikit-learn.org/stable/glossary.html#term-random_state
                                    Notice that here we only pass integer values.
    :param labels:                  Labels to log with the model
    :param kwargs:                  Here you can pass keyword arguments with prefixes,
                                    that will be parsed and passed to the relevant function, by the following prefixes:
                                    - `CLASS_` - for the model class arguments
                                    - `FIT_` - for the `fit` function arguments
                                    - `TRAIN_` - for the `train` function (in xgb or lgbm train function - future)

    :raises NotImplementedError: If the given model class exposes a `train` attribute (train functions are not
                                 supported yet).
    """
    # Validate inputs:
    # Check if exactly one of them is supplied:
    if test_set is None:
        if train_test_split_size is None:
            context.logger.info(
                "test_set or train_test_split_size are not provided, setting train_test_split_size to 0.2"
            )
            train_test_split_size = 0.2

    elif train_test_split_size:
        context.logger.info(
            "test_set provided, ignoring given train_test_split_size value"
        )
        train_test_split_size = None

    # Get DataFrame by URL or by FeatureVector:
    dataset, label_columns = _get_dataframe(
        context=context,
        dataset=dataset,
        label_columns=label_columns,
        drop_columns=drop_columns,
    )

    # Getting the sample set:
    if sample_set is None:
        context.logger.info(
            "Sample set not given, using the whole training set as the sample set"
        )
        sample_set = dataset
    else:
        sample_set, _ = _get_dataframe(
            context=context,
            dataset=sample_set,
            label_columns=label_columns,
            drop_columns=drop_columns,
        )

    # Parsing kwargs:
    # TODO: Use in xgb or lgbm train function.
    train_kwargs = _get_sub_dict_by_prefix(src=kwargs, prefix_key=KWArgsPrefixes.TRAIN)
    fit_kwargs = _get_sub_dict_by_prefix(src=kwargs, prefix_key=KWArgsPrefixes.FIT)
    model_class_kwargs = _get_sub_dict_by_prefix(
        src=kwargs, prefix_key=KWArgsPrefixes.MODEL_CLASS
    )

    # Check if model or function:
    if hasattr(model_class, "train"):
        # TODO: Need to call: model(), afterwards to start the train function.
        # model = create_function(f"{model_class}.train")
        raise NotImplementedError
    else:
        # Creating model instance:
        model = create_class(model_class)(**model_class_kwargs)

    x = dataset.drop(label_columns, axis=1)
    y = dataset[label_columns]
    if train_test_split_size:
        x_train, x_test, y_train, y_test = train_test_split(
            x, y, test_size=train_test_split_size, random_state=random_state
        )
    else:
        # A test set was given - the whole dataset is used for training:
        x_train, y_train = x, y

        test_set = test_set.as_df()
        if drop_columns:
            # The columns must be dropped from the test set itself. Dropping them from `dataset` would replace the
            # test set with the training data (or fail, as the columns were already dropped from it above):
            test_set = test_set.drop(drop_columns, axis=1)

        x_test, y_test = test_set.drop(label_columns, axis=1), test_set[label_columns]

    # Applying MLRun on the model before fitting it:
    AutoMLRun.apply_mlrun(
        model=model,
        model_name=model_name,
        context=context,
        tag=tag,
        sample_set=sample_set,
        y_columns=label_columns,
        test_set=test_set,
        x_test=x_test,
        y_test=y_test,
        artifacts=context.artifacts,
        labels=labels,
    )
    context.logger.info(f"training '{model_name}'")
    model.fit(x_train, y_train, **fit_kwargs)


def evaluate(
    context: MLClientCtx,
    model: str,
    dataset: mlrun.DataItem,
    drop_columns: list[str] = None,
    label_columns: str | list[str] | None = None,
    **kwargs,
):
    """
    Evaluating a model. Artifacts generated by the MLHandler.

    :param context:                 MLRun context.
    :param model:                   The model Store path.
    :param dataset:                 The dataset to evaluate the model on. Can be either a URI or a FeatureVector.
    :param drop_columns:            str or a list of strings that represent the columns to drop.
    :param label_columns:           The target label(s) of the column(s) in the dataset. for Regression or
                                    Classification tasks. Mandatory when dataset is not a FeatureVector.
                                    Feature store qualified names are reduced to the column name (alias or label name)
                                    they have in the dataframe; plain names are used as they are.
    :param kwargs:                  Here you can pass keyword arguments to the predict function
                                    (PREDICT_ prefix is not required).
    """
    # Get dataset by URL or by FeatureVector:
    dataset, label_columns = _get_dataframe(
        context=context,
        dataset=dataset,
        label_columns=label_columns,
        drop_columns=drop_columns,
    )

    # Parsing label_columns:
    if label_columns:
        if not isinstance(label_columns, list):
            label_columns = [label_columns]
        resolved_label_columns = []
        for lc in label_columns:
            if fs.common.feature_separator in lc:
                # Qualified feature string - keep only the name the column carries in the dataframe:
                _, label_name, alias = fs.common.parse_feature_string(lc)
                resolved_label_columns.append(alias or label_name)
            else:
                # Plain column name - must be kept, otherwise a mixed list would lose this label:
                resolved_label_columns.append(lc)
        label_columns = resolved_label_columns

    x = dataset.drop(label_columns, axis=1)
    y = dataset[label_columns]

    # Loading the model and predicting:
    # NOTE(review): the handler is always loaded under the fixed name "model_LinearRegression", whatever the model
    # type is - confirm this is intended.
    model_handler = AutoMLRun.load_model(
        model_path=model, context=context, model_name="model_LinearRegression"
    )
    AutoMLRun.apply_mlrun(model_handler.model, y_test=y, model_path=model)

    context.logger.info(f"evaluating '{model_handler.model_name}'")
    model_handler.model.predict(x, **kwargs)


def predict(
    context: MLClientCtx,
    model: str,
    dataset: mlrun.DataItem,
    drop_columns: str | list[str] | int | list[int] = None,
    label_columns: str | list[str] | None = None,
    result_set: str | None = None,
    **kwargs,
):
    """
    Predicting dataset by a model.

    The prediction is logged as a dataset artifact made of the input dataset with the predicted column(s) appended.

    :param context:                 MLRun context.
    :param model:                   The model Store path.
    :param dataset:                 The dataset to predict the model on. Can be either a URI, a FeatureVector or a
                                    sample in a shape of a list/dict.
                                    When passing a sample, pass the dataset as a field in `params` instead of `inputs`.
    :param drop_columns:            str/int or a list of strings/ints that represent the column names/indices to drop.
                                    When the dataset is a list/dict this parameter should be represented by integers.
    :param label_columns:           The target label(s) of the column(s) in the dataset. for Regression or
                                    Classification tasks. Mandatory when dataset is not a FeatureVector.
                                    Used as the names of the predicted columns, so they must not already exist in the
                                    dataset.
    :param result_set:              The db key to set name of the prediction result and the filename.
                                    Default to 'prediction'.
    :param kwargs:                  Here you can pass keyword arguments to the predict function
                                    (PREDICT_ prefix is not required).

    :raises ValueError: If there are more label columns than predicted outputs, or if a label column already exists in
                        the dataset.
    """
    # Get dataset by URL or by FeatureVector:
    dataset, label_columns = _get_dataframe(
        context=context,
        dataset=dataset,
        label_columns=label_columns,
        drop_columns=drop_columns,
    )

    # loading the model, and getting the model handler:
    model_handler = AutoMLRun.load_model(model_path=model, context=context)

    # Normalizing the label columns into a new list, so a list given by the caller is never mutated below:
    if not label_columns:
        label_columns = []
    elif isinstance(label_columns, str):
        label_columns = [label_columns]
    else:
        label_columns = list(label_columns)

    # Predicting:
    context.logger.info(f"making prediction by '{model_handler.model_name}'")
    y_pred = model_handler.model.predict(dataset, **kwargs)

    # Preparing and validating label columns for the dataframe of the prediction result:
    num_predicted = 1 if len(y_pred.shape) == 1 else y_pred.shape[1]

    if num_predicted > len(label_columns):
        if num_predicted == 1:
            label_columns = ["predicted labels"]
        else:
            # Generate names only for the outputs that were left without a label:
            first_unnamed = len(label_columns) + 1
            label_columns.extend(
                f"predicted_label_{i}" for i in range(first_unnamed, num_predicted + 1)
            )
    elif num_predicted < len(label_columns):
        message = f"number of predicted labels: {num_predicted} is smaller than number of label columns: {len(label_columns)}"
        context.logger.error(message)
        raise ValueError(message)

    artifact_name = result_set or "prediction"
    labels_inside_df = set(label_columns) & set(dataset.columns.tolist())
    if labels_inside_df:
        message = f"The labels: {labels_inside_df} are already existed in the dataframe"
        context.logger.error(message)
        raise ValueError(message)

    # The predictions reuse the dataset's index; otherwise `concat` would align a fresh 0..n-1 index against the
    # dataset's own index and produce misaligned rows filled with NaN:
    predictions = pd.DataFrame(y_pred, columns=label_columns, index=dataset.index)
    pred_df = pd.concat([dataset, predictions], axis=1)
    context.log_dataset(artifact_name, pred_df, db_key=result_set)
 code_origin: '' - description: Automatic train, evaluate and predict functions for the ML frameworks - - Scikit-Learn, XGBoost and LightGBM. - disable_auto_mount: false - default_handler: train + filename: auto_trainer.py entry_points: train: - lineno: 121 parameters: - name: context type: MLClientCtx @@ -28,12 +27,11 @@ spec: type: str doc: The class of the model, e.g. `sklearn.linear_model.LogisticRegression` - name: label_columns - type: Optional[Union[str, List[str]]] doc: The target label(s) of the column(s) in the dataset. for Regression or Classification tasks. Mandatory when dataset is not a FeatureVector. default: null - name: drop_columns - type: List[str] + type: list[str] doc: str or a list of strings that represent the columns to drop default: null - name: model_name @@ -70,11 +68,9 @@ spec: type: dict doc: Labels to log with the model default: null - has_varargs: false name: train - has_kwargs: true doc: "Training a model with the given dataset.\n\nexample::\n\n import mlrun\n\ - \ project = mlrun.get_or_create_project(\"my-project\")\n project.set_function(\"\ + \n project = mlrun.get_or_create_project(\"my-project\")\n project.set_function(\"\ hub://auto_trainer\", \"train\")\n trainer_run = project.run(\n \ \ name=\"train\",\n handler=\"train\",\n inputs={\"dataset\"\ : \"./path/to/dataset.csv\"},\n params={\n \"model_class\"\ @@ -83,8 +79,10 @@ spec: : \"my-model\",\n \"tag\": \"v1.0.0\",\n \"sample_set\"\ : \"./path/to/sample_set.csv\",\n \"test_set\": \"./path/to/test_set.csv\"\ ,\n \"CLASS_solver\": \"liblinear\",\n },\n )" + has_kwargs: true + has_varargs: false + lineno: 121 evaluate: - lineno: 273 parameters: - name: context type: MLClientCtx @@ -96,20 +94,19 @@ spec: type: DataItem doc: The dataset to evaluate the model on. Can be either a URI or a FeatureVector. - name: drop_columns - type: List[str] + type: list[str] doc: str or a list of strings that represent the columns to drop. 
default: null - name: label_columns - type: Optional[Union[str, List[str]]] doc: The target label(s) of the column(s) in the dataset. for Regression or Classification tasks. Mandatory when dataset is not a FeatureVector. default: null - has_varargs: false name: evaluate - has_kwargs: true doc: Evaluating a model. Artifacts generated by the MLHandler. + has_kwargs: true + has_varargs: false + lineno: 274 predict: - lineno: 327 parameters: - name: context type: MLClientCtx @@ -123,25 +120,24 @@ spec: or a sample in a shape of a list/dict. When passing a sample, pass the dataset as a field in `params` instead of `inputs`. - name: drop_columns - type: Union[str, List[str], int, List[int]] doc: str/int or a list of strings/ints that represent the column names/indices to drop. When the dataset is a list/dict this parameter should be represented by integers. default: null - name: label_columns - type: Optional[Union[str, List[str]]] doc: The target label(s) of the column(s) in the dataset. for Regression or Classification tasks. Mandatory when dataset is not a FeatureVector. default: null - name: result_set - type: Optional[str] doc: The db key to set name of the prediction result and the filename. Default to 'prediction'. default: null - has_varargs: false name: predict - has_kwargs: true doc: Predicting dataset by a model. + has_kwargs: true + has_varargs: false + lineno: 328 command: '' -kind: job -verbose: false + description: Automatic train, evaluate and predict functions for the ML frameworks + - Scikit-Learn, XGBoost and LightGBM. 
+ default_handler: train diff --git a/functions/src/auto_trainer/test_auto_trainer.py b/functions/src/auto_trainer/test_auto_trainer.py index 9a1ff554c..49eb4101b 100644 --- a/functions/src/auto_trainer/test_auto_trainer.py +++ b/functions/src/auto_trainer/test_auto_trainer.py @@ -14,7 +14,6 @@ # import os import tempfile -from typing import Tuple import mlrun import pandas as pd @@ -78,7 +77,7 @@ def _assert_train_handler(train_run): @pytest.mark.parametrize("model", MODELS) -def test_train(model: Tuple[str, str]): +def test_train(model: tuple[str, str]): dataset, label_columns = _get_dataset(model[1]) is_test_passed = True @@ -115,7 +114,7 @@ def test_train(model: Tuple[str, str]): condition=not _validate_environment_variables(), reason="Project's environment variables are not set", ) -def test_train_evaluate(model: Tuple[str, str]): +def test_train_evaluate(model: tuple[str, str]): dataset, label_columns = _get_dataset(model[1]) is_test_passed = True # Importing function: @@ -156,9 +155,9 @@ def test_train_evaluate(model: Tuple[str, str]): is_test_passed = False assert is_test_passed, "The test failed" - assert ( - evaluate_run and "evaluation-test_set" in evaluate_run.outputs - ), "Missing fields in evaluate_run" + assert evaluate_run and "evaluation-test_set" in evaluate_run.outputs, ( + "Missing fields in evaluate_run" + ) @pytest.mark.parametrize("model", MODELS) @@ -166,7 +165,7 @@ def test_train_evaluate(model: Tuple[str, str]): condition=not _validate_environment_variables(), reason="Project's environment variables are not set", ) -def test_train_predict(model: Tuple[str, str]): +def test_train_predict(model: tuple[str, str]): is_test_passed = True dataset, label_columns = _get_dataset(model[1]) df = pd.read_csv(dataset) @@ -210,6 +209,6 @@ def test_train_predict(model: Tuple[str, str]): is_test_passed = False assert is_test_passed, "The test failed" - assert ( - predict_run and "prediction" in predict_run.outputs - ), "Prediction field must be in the 
output" + assert predict_run and "prediction" in predict_run.outputs, ( + "Prediction field must be in the output" + ) diff --git a/functions/src/azureml_serving/function.yaml b/functions/src/azureml_serving/function.yaml index 978806878..fd996b356 100644 --- a/functions/src/azureml_serving/function.yaml +++ b/functions/src/azureml_serving/function.yaml @@ -1,51 +1,31 @@ -kind: serving metadata: - name: azureml-serving tag: '' - hash: c0f404820b8f0fe92d2d1cfe9dbcc068be1a13bf - project: '' - labels: - author: Iguazio + name: azureml-serving categories: - machine-learning - model-serving +verbose: false +kind: serving spec: - command: '' - args: [] image: mlrun/mlrun - build: - commands: - - python -m pip install azureml-automl-runtime~=1.38.1 - code_origin: "" - origin_filename: "" - description: AzureML serving function disable_auto_mount: false - env: [] - priority_class_name: '' - preemption_mode: prevent + build: + origin_filename: '' + functionSourceCode: IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKIyBEbyBub3QgZGVsZXRlIQoKZnJvbSBtbHJ1bi5ydW50aW1lcyBpbXBvcnQgbnVjbGlvX2luaXRfaG9vawpkZWYgaW5pdF9jb250ZXh0KGNvbnRleHQpOgogICAgbnVjbGlvX2luaXRfaG9vayhjb250ZXh0LCBnbG9iYWxzKCksICdzZXJ2aW5nX3YyJykKCmRlZiBoYW5kbGVyKGNvbnRleHQsIGV2ZW50KToKICAgIHJldHVybiBjb250ZXh0Lm1scn
VuX2hhbmRsZXIoY29udGV4dCwgZXZlbnQpCg== + requirements: + - azureml-automl-runtime~=1.38.1 + code_origin: '' + filename: azureml_serving.py + default_class: mlrun.frameworks.sklearn.PickleModelServer min_replicas: 1 - max_replicas: 4 - base_spec: - apiVersion: nuclio.io/v1 - kind: Function - metadata: - name: azureml-serving - labels: {} - annotations: - nuclio.io/generated_by: function generated from /Users/yonatanshelach/yoni/projects/functions/azureml_serving/azureml_serving.py - spec: - runtime: python - handler: azureml_serving:handler - env: [] - volumes: [] - build: - commands: [] - noBaseImagesPull: true - functionSourceCode: IyBEbyBub3QgZGVsZXRlIQoKZnJvbSBtbHJ1bi5ydW50aW1lcyBpbXBvcnQgbnVjbGlvX2luaXRfaG9vawpkZWYgaW5pdF9jb250ZXh0KGNvbnRleHQpOgogICAgbnVjbGlvX2luaXRfaG9vayhjb250ZXh0LCBnbG9iYWxzKCksICdzZXJ2aW5nX3YyJykKCmRlZiBoYW5kbGVyKGNvbnRleHQsIGV2ZW50KToKICAgIHJldHVybiBjb250ZXh0Lm1scnVuX2hhbmRsZXIoY29udGV4dCwgZXZlbnQpCg== + command: '' + default_handler: '' source: '' + max_replicas: 4 + base_image_pull: false + description: AzureML serving function function_kind: serving_v2 - default_class: mlrun.frameworks.sklearn.PickleModelServer - secret_sources: [] - affinity: null - tolerations: null -verbose: false \ No newline at end of file + function_handler: azureml-serving-nuclio:handler + env: + - name: MLRUN_HTTPDB__NUCLIO__EXPLICIT_ACK + value: enabled diff --git a/functions/src/azureml_utils/azureml_utils.py b/functions/src/azureml_utils/azureml_utils.py index 041af2b87..a8ac6bd7f 100644 --- a/functions/src/azureml_utils/azureml_utils.py +++ b/functions/src/azureml_utils/azureml_utils.py @@ -12,28 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import os import json import logging -from typing import Tuple, List +import os -from mlrun import MLClientCtx, DataItem, get_dataitem -import mlrun.feature_store as f_store import mlrun.datastore +import mlrun.feature_store as f_store import mlrun.utils -from mlrun.datastore.targets import ParquetTarget - from azureml.core.authentication import ServicePrincipalAuthentication -from azureml.core.workspace import Workspace -from azureml.core.experiment import Experiment +from azureml.core.compute import AmlCompute, ComputeTarget +from azureml.core.compute_target import ComputeTargetException from azureml.core.dataset import Dataset +from azureml.core.experiment import Experiment from azureml.core.model import Model -from azureml.core.compute import ComputeTarget, AmlCompute -from azureml.core.compute_target import ComputeTargetException from azureml.core.script_run import ScriptRun - +from azureml.core.workspace import Workspace from azureml.train.automl import AutoMLConfig from azureml.train.automl.run import AutoMLRun +from mlrun import DataItem, MLClientCtx, get_dataitem +from mlrun.datastore.targets import ParquetTarget def _env_or_secret(context, key): @@ -77,7 +74,7 @@ def _load_workspace(context: MLClientCtx) -> Workspace: def _init_experiment( context: MLClientCtx, experiment_name: str -) -> Tuple[Workspace, Experiment]: +) -> tuple[Workspace, Experiment]: """ Initialize workspace and experiment in Azure ML. Uses Service Principal authentication via environment variables. 
@@ -156,9 +153,9 @@ def register_dataset( """ # test for Azure storage connection environment variable or secret: - assert _env_or_secret( - context, "AZURE_STORAGE_CONNECTION_STRING" - ), "AZURE_STORAGE_CONNECTION_STRING secret not set" + assert _env_or_secret(context, "AZURE_STORAGE_CONNECTION_STRING"), ( + "AZURE_STORAGE_CONNECTION_STRING secret not set" + ) # Connect to AzureML experiment and datastore: context.logger.info("Connecting to AzureML experiment default datastore") @@ -177,7 +174,9 @@ def register_dataset( context.logger.info( f"Retrieving feature vector and uploading to Azure blob storage: {blob_path}" ) - f_store.get_offline_features(data.meta.uri, target=ParquetTarget(path=blob_path)) + f_store.get_offline_features( + data.meta.uri, target=ParquetTarget(path=blob_path) + ) else: blob_path += data.suffix # DataItem case: @@ -195,7 +194,7 @@ def register_dataset( ) else: context.logger.info( - f"OpenSSL version must be 1.1. Overriding the OpenSSL version to 1.1" + "OpenSSL version must be 1.1. Overriding the OpenSSL version to 1.1" ) # OpenSSL version must be 1.1 os.environ["CLR_OPENSSL_VERSION_OVERRIDE"] = "1.1" @@ -265,7 +264,7 @@ def upload_model( def _get_top_n_runs( remote_run: AutoMLRun, n: int = 5, primary_metric: str = "accuracy" -) -> List[ScriptRun]: +) -> list[ScriptRun]: """ Get top N complete runs from experiment sorted by primary metric. 
@@ -317,9 +316,9 @@ def _get_model_hp( return {} hp_dicts = spec_dict["objects"] # after training there are two hyper-parameters dicts inside the run object: - assert ( - len(hp_dicts) == 2 - ), "after training there are two hyper-parameters dicts inside the run object" + assert len(hp_dicts) == 2, ( + "after training there are two hyper-parameters dicts inside the run object" + ) result_dict = {} dict_keys = [ ["data_trans_class_name", "data_trans_module", "data_trans_spec_class"], @@ -336,7 +335,6 @@ def _get_model_hp( kwargs_prefix = "param_kwargs" for d, name, keys in zip(hp_dicts, ["data_trans", "train"], dict_keys): for key in keys: - if kwargs_prefix in key: result_dict[key] = d[kwargs_prefix][ key.replace(f"{name}_{kwargs_prefix}_", "") @@ -357,7 +355,7 @@ def submit_training_job( registered_dataset_name: str, automl_settings: dict, training_set: DataItem, - label_column_name: str = '', + label_column_name: str = "", save_n_models: int = 3, show_output: bool = True, ) -> None: @@ -390,7 +388,7 @@ def submit_training_job( if mlrun.utils.StorePrefix.FeatureVector == store_uri_prefix: feature_vector = training_set.meta.uri label_column_name = label_column_name or training_set.meta.status.label_column - context.logger.info(f'label column name: {label_column_name}') + context.logger.info(f"label column name: {label_column_name}") training_set = f_store.get_offline_features(feature_vector).to_dataframe() else: training_set = training_set.as_df() @@ -445,9 +443,7 @@ def submit_training_job( with context.get_child_context(**model_hp_dict) as child: model_key = f"model_{i + 1}_{model_hp_dict['data_trans_class_name'].lower()}_{model_hp_dict['train_class_name'].lower()}" # Log model: - context.logger.info( - f"Logging {model_key} model to MLRun" - ) + context.logger.info(f"Logging {model_key} model to MLRun") child.log_results(metrics) child.log_model( "model", diff --git a/functions/src/azureml_utils/function.yaml b/functions/src/azureml_utils/function.yaml index 
f14a6313f..fcd31ef59 100644 --- a/functions/src/azureml_utils/function.yaml +++ b/functions/src/azureml_utils/function.yaml @@ -1,32 +1,35 @@ +metadata: + tag: '' + name: azureml-utils + categories: + - model-serving + - utils verbose: false +kind: job spec: - command: '' + image: '' + disable_auto_mount: false build: - auto_build: true - code_origin: '' + origin_filename: '' with_mlrun: true + functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import os

import mlrun.datastore
import mlrun.feature_store as f_store
import mlrun.utils
from azureml.core.authentication import ServicePrincipalAuthentication
from azureml.core.compute import AmlCompute, ComputeTarget
from azureml.core.compute_target import ComputeTargetException
from azureml.core.dataset import Dataset
from azureml.core.experiment import Experiment
from azureml.core.model import Model
from azureml.core.script_run import ScriptRun
from azureml.core.workspace import Workspace
from azureml.train.automl import AutoMLConfig
from azureml.train.automl.run import AutoMLRun
from mlrun import DataItem, MLClientCtx, get_dataitem
from mlrun.datastore.targets import ParquetTarget


def _env_or_secret(context, key):
    """
    Resolve a configuration value, preferring the process environment.

    :param context: Object exposing ``get_secret(key)``; it is consulted only
                    when ``key`` is not set as an environment variable.
    :param key:     Name of the environment variable / secret to look up.
    :returns:       ``os.environ[key]`` when present, otherwise the result of
                    ``context.get_secret(key)``.
    """
    return os.environ[key] if key in os.environ else context.get_secret(key)


def _load_workspace(context: MLClientCtx) -> Workspace:
    """
    Return the AzureML Workspace for the given context, building it on first use.

    The workspace is stored on the context as ``_azure_workspace``; when that
    attribute already exists it is returned as-is and no credentials are read.

    :param context: MLRun context. Credentials are resolved through
                    ``_env_or_secret`` using the ``AZURE_*`` keys below.
    :returns:       AzureML Workspace
    """
    try:
        # A previous call already attached the workspace to this context:
        return context._azure_workspace
    except AttributeError:
        pass

    context.logger.info("Loading AzureML Workspace")

    # Service principal credentials:
    tenant = _env_or_secret(context, "AZURE_TENANT_ID")
    principal_id = _env_or_secret(context, "AZURE_SERVICE_PRINCIPAL_ID")
    principal_password = _env_or_secret(
        context, "AZURE_SERVICE_PRINCIPAL_PASSWORD"
    )
    credentials = ServicePrincipalAuthentication(
        tenant_id=tenant,
        service_principal_id=principal_id,
        service_principal_password=principal_password,
    )

    # Workspace coordinates:
    subscription = _env_or_secret(context, "AZURE_SUBSCRIPTION_ID")
    group = _env_or_secret(context, "AZURE_RESOURCE_GROUP")
    name = _env_or_secret(context, "AZURE_WORKSPACE_NAME")
    loaded = Workspace(
        subscription_id=subscription,
        resource_group=group,
        workspace_name=name,
        auth=credentials,
    )

    # Remember the workspace so later calls with this context skip the login:
    context._azure_workspace = loaded
    return loaded


def _init_experiment(
    context: MLClientCtx, experiment_name: str
) -> tuple[Workspace, Experiment]:
    """
    Load the Azure ML workspace and create an experiment object in it.

    The workspace comes from ``_load_workspace``, which authenticates with a
    Service Principal:
    https://docs.microsoft.com/en-us/azure/machine-learning/how-to-setup-authentication#use-service-principal-authentication

    :param context:         MLRun context.
    :param experiment_name: Name of experiment to create in Azure ML.
    :returns:               Azure ML Workspace and Experiment.
    """
    azure_workspace = _load_workspace(context)

    context.logger.info(f"Initializing AzureML experiment {experiment_name}")
    return azure_workspace, Experiment(azure_workspace, experiment_name)


def init_compute(
    context: MLClientCtx,
    cpu_cluster_name: str,
    vm_size: str = "STANDARD_D2_V2",
    max_nodes: int = 1,
) -> ComputeTarget:
    """
    Get the Azure ML compute target named ``cpu_cluster_name``, creating it
    when the lookup raises ``ComputeTargetException``. Waits for the target
    to complete before returning it.

    :param context:          MLRun context.
    :param cpu_cluster_name: Name of Azure ML compute target. Created if does not exist.
    :param vm_size:          Azure machine type, used only when creating the target.
    :param max_nodes:        ``max_nodes`` of the provisioning configuration,
                             used only when creating the target.
    :returns:                Azure ML Compute Target.
    """
    azure_workspace = _load_workspace(context)
    context.logger.info(f"Initializing AzureML compute target {cpu_cluster_name}")

    try:
        # Look the cluster up by name first:
        target = ComputeTarget(workspace=azure_workspace, name=cpu_cluster_name)
    except ComputeTargetException:
        # Lookup failed, provision a new cluster with the requested size:
        provisioning = AmlCompute.provisioning_configuration(
            vm_size=vm_size, max_nodes=max_nodes
        )
        target = ComputeTarget.create(azure_workspace, cpu_cluster_name, provisioning)
    else:
        context.logger.info("Found existing cluster, will use it.")

    target.wait_for_completion(show_output=True)
    return target


def register_dataset(
    context: MLClientCtx,
    dataset_name: str,
    dataset_description: str,
    data: DataItem,
    create_new_version: bool = False,
):
    """
    Register dataset object (can be also an Iguazio FeatureVector) in Azure ML.
    Uploads the data to the workspace's default Azure blob datastore and
    registers that file as a tabular dataset in Azure ML.

    A FeatureVector is written through a ``ParquetTarget``; any other
    ``DataItem`` is uploaded as raw bytes under ``dataset_name`` + its suffix.
    The uploaded blob path is logged as the ``dataset_blob_path`` result.

    :param context:               MLRun context.
    :param dataset_name:          Name of Azure dataset to register.
    :param dataset_description:   Description of Azure dataset to register.
    :param data:                  MLRun FeatureVector or dataset object to upload.
    :param create_new_version:    Register Azure dataset as new version. Must be used when
                                  modifying dataset schema.
    """

    # test for Azure storage connection environment variable or secret:
    assert _env_or_secret(context, "AZURE_STORAGE_CONNECTION_STRING"), (
        "AZURE_STORAGE_CONNECTION_STRING secret not set"
    )

    # Connect to AzureML experiment and datastore:
    context.logger.info("Connecting to AzureML experiment default datastore")

    workspace = _load_workspace(context)
    datastore = workspace.get_default_datastore()

    # Azure blob path (default datastore for workspace):
    blob_path = f"az://{datastore.container_name}/{dataset_name}"

    store_uri_prefix, _ = mlrun.datastore.parse_store_uri(data.artifact_url)
    feature_vector_case = mlrun.utils.StorePrefix.FeatureVector == store_uri_prefix
    # Upload the data source to the blob storage:
    if feature_vector_case:
        # FeatureVector case:
        context.logger.info(
            f"Retrieving feature vector and uploading to Azure blob storage: {blob_path}"
        )
        f_store.get_offline_features(
            data.meta.uri, target=ParquetTarget(path=blob_path)
        )
    else:
        blob_path += data.suffix
        # DataItem case (the message previously claimed a feature vector was
        # being retrieved, which was misleading when reading run logs):
        context.logger.info(
            f"Retrieving data item and uploading to Azure blob storage: {blob_path}"
        )
        data_in_bytes = data.get()
        get_dataitem(blob_path).put(data_in_bytes)

    # Register dataset in AzureML:
    context.logger.info(f"Registering dataset {dataset_name} in Azure ML")
    if data.suffix == ".parquet" or feature_vector_case:
        dataset = Dataset.Tabular.from_parquet_files(
            path=(datastore, f"{dataset_name}.parquet"), validate=False
        )
    else:
        context.logger.info(
            "OpenSSL version must be 1.1. Overriding the OpenSSL version to 1.1"
        )
        # OpenSSL version must be 1.1 (note: this changes the process environment)
        os.environ["CLR_OPENSSL_VERSION_OVERRIDE"] = "1.1"
        dataset = Dataset.Tabular.from_delimited_files(
            path=(datastore, f"{dataset_name}{data.suffix}"), validate=False
        )

    dataset.register(
        workspace=workspace,
        name=dataset_name,
        description=dataset_description,
        create_new_version=create_new_version,
    )

    # Output the uploaded blob path in Azure:
    context.log_result("dataset_blob_path", blob_path)


def download_model(
    context: MLClientCtx,
    model_name: str,
    model_version: int,
    target_dir: str = ".",
) -> None:
    """
    Fetch a registered Azure ML model and download it to the local filesystem.

    The download is issued with ``exist_ok=True``.

    :param context:       MLRun context.
    :param model_name:    Name of trained and registered model.
    :param model_version: Version of model to download.
    :param target_dir:    Target directory to download model.
    """
    azure_workspace = _load_workspace(context)
    context.logger.info(f"Downloading model {model_name}:{model_version}")
    Model(azure_workspace, model_name, version=model_version).download(
        target_dir=target_dir, exist_ok=True
    )


def upload_model(
    context: MLClientCtx,
    model_name: str,
    model_path: str,
    model_description: str = None,
    model_tags: dict = None,
) -> None:
    """
    Register a pre-trained model from the local filesystem in Azure ML.

    :param context:           MLRun context.
    :param model_name:        Name to register the model under.
    :param model_path:        Path to file on local filesystem.
    :param model_description: Description of the model.
    :param model_tags:        KV pairs of model tags.
    """
    azure_workspace = _load_workspace(context)

    context.logger.info(f"Upload model {model_name} from {model_path}")
    registration = {
        "workspace": azure_workspace,
        "model_path": model_path,
        "model_name": model_name,
        "description": model_description,
        "tags": model_tags,
    }
    Model.register(**registration)


def _get_top_n_runs(
    remote_run: AutoMLRun, n: int = 5, primary_metric: str = "accuracy"
) -> list[ScriptRun]:
    """
    Pick the ``n`` best completed child runs, ranked by ``primary_metric``.

    Children whose id contains "setup" or "worker" are ignored. Ranking is
    descending, i.e. a larger metric value is treated as better.

    :param remote_run:     Azure ML Run.
    :param n:              Number of top runs to return.
    :param primary_metric: Metric to sort by.

    :returns:              List of top N runs sorted by primary metric.
    :raises ValueError:    When fewer than ``n`` eligible runs completed.
    """
    excluded_markers = ["setup", "worker"]

    # Keep only the completed children that are not setup/worker runs:
    candidates = []
    for child in remote_run.get_children(status="Completed"):
        if any(marker in child.id for marker in excluded_markers):
            continue
        candidates.append(child)

    # Checking that the required number of runs are done:
    if len(candidates) < n:
        raise ValueError(f"Expected {n} runs but only received {len(candidates)}")

    def metric_of(child):
        return child.get_metrics()[primary_metric]

    # Stable descending sort, best run first:
    candidates.sort(key=metric_of, reverse=True)
    return candidates[:n]


def _get_model_hp(
    run: ScriptRun,
) -> dict:
    """
    Extract the hyper-parameters of a trained AzureML model from its run.

    The run's ``pipeline_spec`` property (a JSON string) is expected to hold
    two entries under ``objects``: the first is read with the ``data_trans_``
    key prefix and the second with the ``train_`` key prefix. Falsy values are
    stored as an empty string.

    :param run: Run object of AzureML trained model.

    :returns:    A flat dictionary of prefixed hyper-parameters, or an empty
                 dictionary when the run has no ``pipeline_spec`` property or
                 the spec has no ``objects`` entry.
    """
    properties = run.properties
    if "pipeline_spec" not in properties:
        return {}
    pipeline_spec = json.loads(properties["pipeline_spec"])

    if "objects" not in pipeline_spec:
        # No hyper-params
        return {}
    stages = pipeline_spec["objects"]
    # after training there are two hyper-parameters dicts inside the run object:
    assert len(stages) == 2, (
        "after training there are two hyper-parameters dicts inside the run object"
    )

    kwargs_field = "param_kwargs"
    # Per stage: the output key prefix and the fields to copy. Fields starting
    # with "param_kwargs_" are read from the stage's nested "param_kwargs" dict.
    layout = [
        ("data_trans", ["class_name", "module", "spec_class"]),
        (
            "train",
            [
                "class_name",
                "module",
                "param_kwargs_C",
                "param_kwargs_class_weight",
                "spec_class",
            ],
        ),
    ]

    hyper_params = {}
    for stage, (prefix, fields) in zip(stages, layout):
        for field in fields:
            if field.startswith(f"{kwargs_field}_"):
                value = stage[kwargs_field][field[len(kwargs_field) + 1 :]]
            else:
                value = stage[field]
            hyper_params[f"{prefix}_{field}"] = value if value else ""

    return hyper_params


def submit_training_job(
    context: MLClientCtx,
    experiment: Experiment,
    compute_target: ComputeTarget,
    register_model_name: str,
    registered_dataset_name: str,
    automl_settings: dict,
    training_set: DataItem,
    label_column_name: str = "",
    save_n_models: int = 3,
    show_output: bool = True,
) -> None:
    """
    Submit training job to Azure AutoML and download trained model
    when completed. Uses previously registered dataset for training.

    For each of the top ``save_n_models`` runs the model is registered in
    Azure, downloaded to ``./<model version>`` and logged to MLRun from a
    child context; the first (best) one is marked as best.

    :param context:                 MLRun context.
    :param experiment:              Azure experiment.
    :param compute_target:          Azure compute target.
    :param register_model_name:     Name of model to register in Azure.
    :param registered_dataset_name: Name of dataset registered in Azure ML.
    :param label_column_name:       Name of target column in dataset. When empty and
                                    ``training_set`` is a FeatureVector, the vector's
                                    label column is used.
    :param automl_settings:         Dictionary of Azure AutoML settings, unpacked into
                                    ``AutoMLConfig``. Must contain ``primary_metric``,
                                    which is used to rank the runs.
    :param training_set:            Training set to log with model. For model
                                    monitoring integration.
    :param show_output:             Displaying Azure logs.
    :param save_n_models:           How many of the top performing models to log.
    """
    # Loading workspace if not provided:
    workspace = _load_workspace(context)

    # Setup experiment:
    context.logger.info("Setting up experiment parameters")
    dataset = Dataset.get_by_name(workspace, name=registered_dataset_name)

    # Get training set to log with model:
    feature_vector = None
    store_uri_prefix, _ = mlrun.datastore.parse_store_uri(training_set.artifact_url)
    if mlrun.utils.StorePrefix.FeatureVector == store_uri_prefix:
        feature_vector = training_set.meta.uri
        label_column_name = label_column_name or training_set.meta.status.label_column
        context.logger.info(f"label column name: {label_column_name}")
        training_set = f_store.get_offline_features(feature_vector).to_dataframe()
    else:
        training_set = training_set.as_df()

    automl_config = AutoMLConfig(
        compute_target=compute_target,
        training_data=dataset,
        verbosity=logging.INFO,
        label_column_name=label_column_name,
        **automl_settings,
    )

    # Run experiment on AzureML:
    context.logger.info("Submitting and running experiment")
    remote_run = experiment.submit(automl_config)
    remote_run.wait_for_completion(show_output=show_output)
    if show_output:
        # Azure log ending row:
        print(f"\n{'*' * 92}\n")
    # Get top N runs to log:
    top_runs = _get_top_n_runs(
        remote_run=remote_run,
        n=save_n_models,
        primary_metric=automl_settings["primary_metric"],
    )

    # Register, download, and log models:
    for i, run in enumerate(top_runs):
        # Register model:
        context.logger.info("Registering model")
        model = run.register_model(
            model_name=register_model_name, model_path="outputs/model.pkl"
        )
        context.logger.info(
            f"Registered model with name '{model.name}', id '{model.id}', version '{model.version}'"
        )

        # Download model locally:
        download_model(
            context=context,
            model_name=register_model_name,
            model_version=model.version,
            target_dir=f"./{model.version}",
        )

        metrics = {k.lower(): val for k, val in run.get_metrics().items()}
        # These entries are dropped before logging. `pop` with a default is used
        # so a run that does not report them no longer fails with KeyError:
        for dropped_metric in ("confusion_matrix", "accuracy_table"):
            metrics.pop(dropped_metric, None)

        # Collect model hyper-parameters:
        model_hp_dict = _get_model_hp(run)
        with context.get_child_context(**model_hp_dict) as child:
            if model_hp_dict:
                model_key = f"model_{i + 1}_{model_hp_dict['data_trans_class_name'].lower()}_{model_hp_dict['train_class_name'].lower()}"
            else:
                # `_get_model_hp` returns an empty dict when the run carries no
                # pipeline spec; fall back to a key built from the rank only.
                model_key = f"model_{i + 1}"
            # Log model:
            context.logger.info(f"Logging {model_key} model to MLRun")
            child.log_results(metrics)
            child.log_model(
                "model",
                db_key=model_key,
                artifact_path=context.artifact_subpath("models"),
                metrics=metrics,
                model_file=f"{model.version}/model.pkl",
                training_set=training_set,
                label_column=label_column_name,
                feature_vector=feature_vector,
                framework="AzureML",
                algorithm=model_hp_dict.get("train_class_name"),
            )
            if i == 0:
                # This also logs the model:
                child.mark_as_best()


def train(
    # MlRun
    context: MLClientCtx,
    dataset: DataItem,
    # Init experiment and compute
    experiment_name: str = "",
    cpu_cluster_name: str = "",
    vm_size: str = "STANDARD_D2_V2",
    max_nodes: int = 1,
    # Register dataset
    dataset_name: str = "",
    dataset_description: str = "",
    create_new_version: bool = False,
    label_column_name: str = "",
    # Submit training job
    register_model_name: str = "",
    save_n_models: int = 1,
    log_azure: bool = True,
    automl_settings: str = None,
) -> None:
    """
    Whole training flow for Azure AutoML. Registers dataset/feature vector,
    submits training job to Azure AutoML, and downloads trained model
    when completed.

    :param context:             MLRun context.

    :param dataset:             MLRun FeatureVector or dataset URI to upload.

    :param experiment_name:     Name of experiment to create in Azure ML.
    :param cpu_cluster_name:    Name of Azure ML compute target. Created if does not exist.
    :param vm_size:             Azure machine type for compute target.
    :param max_nodes:           Maximum number of nodes of the compute target.

    :param dataset_name:        Name of Azure dataset to register.
    :param dataset_description: Description of Azure dataset to register.

    :param create_new_version:  Register Azure dataset as new version. Must be used when
                                modifying dataset schema.
    :param label_column_name:   Target column in dataset.

    :param register_model_name: Name of model to register in Azure.
    :param save_n_models:       How many of the top performing models to log.
    :param log_azure:           Displaying Azure logs.
    :param automl_settings:     Azure AutoML settings, either as a JSON string or as an
                                already-parsed dictionary. When empty, a small default
                                classification configuration is used.
    """
    if not automl_settings:
        automl_settings = {
            "task": "classification",
            "debug_log": "automl_errors.log",
            # "experiment_exit_score": 0.9,
            "enable_early_stopping": False,
            "allowed_models": ["LogisticRegression", "SGD", "SVM"],
            "iterations": 3,
            "iteration_timeout_minutes": 2,
            "max_concurrent_iterations": 2,
            "max_cores_per_iteration": -1,
            "n_cross_validations": 5,
            "primary_metric": "accuracy",
            "featurization": "off",
            "model_explainability": False,
            "enable_voting_ensemble": False,
            "enable_stack_ensemble": False,
        }
    elif isinstance(automl_settings, str):
        # The parameter is documented as a JSON string, while the training step
        # unpacks it as a mapping, so parse it here. Dictionaries pass through.
        automl_settings = json.loads(automl_settings)

    # Init experiment and compute (the workspace itself is not needed here):
    _, experiment = _init_experiment(
        context=context, experiment_name=experiment_name
    )

    compute_target = init_compute(
        context=context,
        cpu_cluster_name=cpu_cluster_name,
        vm_size=vm_size,
        max_nodes=max_nodes,
    )

    # Register dataset
    register_dataset(
        context=context,
        dataset_name=dataset_name,
        dataset_description=dataset_description,
        data=dataset,
        create_new_version=create_new_version,
    )

    # Submit training job
    submit_training_job(
        context,
        experiment=experiment,
        compute_target=compute_target,
        register_model_name=register_model_name,
        registered_dataset_name=dataset_name,
        label_column_name=label_column_name,
        automl_settings=automl_settings,
        training_set=dataset,
        show_output=log_azure,
        save_n_models=save_n_models,
    )
 requirements: - azureml-core==1.54.0.post1 - azureml-train-automl-client==1.54.0.post1 - plotly~=5.23 - functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import json
import logging
from typing import Tuple, List

from mlrun import MLClientCtx, DataItem, get_dataitem
import mlrun.feature_store as f_store
import mlrun.datastore
import mlrun.utils
from mlrun.datastore.targets import ParquetTarget

from azureml.core.authentication import ServicePrincipalAuthentication
from azureml.core.workspace import Workspace
from azureml.core.experiment import Experiment
from azureml.core.dataset import Dataset
from azureml.core.model import Model
from azureml.core.compute import ComputeTarget, AmlCompute
from azureml.core.compute_target import ComputeTargetException
from azureml.core.script_run import ScriptRun

from azureml.train.automl import AutoMLConfig
from azureml.train.automl.run import AutoMLRun


def _env_or_secret(context, key):
    """
    Return the value of ``key`` from the environment, else from the context secrets.

    :param context: Object exposing ``get_secret(key)``; used only when ``key``
                    is not an environment variable.
    :param key:     Name of the environment variable / secret to look up.
    :returns:       The environment value when set, otherwise
                    ``context.get_secret(key)``.
    """
    if key not in os.environ:
        return context.get_secret(key)
    return os.environ[key]


def _load_workspace(context: MLClientCtx) -> Workspace:
    """
    Return the AzureML Workspace for the given context, building it on first use.

    The workspace is stored on the context as ``_azure_workspace``; when that
    attribute already exists it is returned as-is and no credentials are read.

    :param context: MLRun context. Credentials are resolved through
                    ``_env_or_secret`` using the ``AZURE_*`` keys below.
    :returns:       AzureML Workspace
    """
    try:
        # A previous call already attached the workspace to this context:
        return context._azure_workspace
    except AttributeError:
        pass

    context.logger.info("Loading AzureML Workspace")

    # Service principal credentials:
    tenant = _env_or_secret(context, "AZURE_TENANT_ID")
    principal_id = _env_or_secret(context, "AZURE_SERVICE_PRINCIPAL_ID")
    principal_password = _env_or_secret(
        context, "AZURE_SERVICE_PRINCIPAL_PASSWORD"
    )
    credentials = ServicePrincipalAuthentication(
        tenant_id=tenant,
        service_principal_id=principal_id,
        service_principal_password=principal_password,
    )

    # Workspace coordinates:
    subscription = _env_or_secret(context, "AZURE_SUBSCRIPTION_ID")
    group = _env_or_secret(context, "AZURE_RESOURCE_GROUP")
    name = _env_or_secret(context, "AZURE_WORKSPACE_NAME")
    loaded = Workspace(
        subscription_id=subscription,
        resource_group=group,
        workspace_name=name,
        auth=credentials,
    )

    # Remember the workspace so later calls with this context skip the login:
    context._azure_workspace = loaded
    return loaded


def _init_experiment(
    context: MLClientCtx, experiment_name: str
) -> Tuple[Workspace, Experiment]:
    """
    Load the Azure ML workspace and create an experiment object in it.

    The workspace comes from ``_load_workspace``, which authenticates with a
    Service Principal:
    https://docs.microsoft.com/en-us/azure/machine-learning/how-to-setup-authentication#use-service-principal-authentication

    :param context:         MLRun context.
    :param experiment_name: Name of experiment to create in Azure ML.
    :returns:               Azure ML Workspace and Experiment.
    """
    azure_workspace = _load_workspace(context)

    context.logger.info(f"Initializing AzureML experiment {experiment_name}")
    return azure_workspace, Experiment(azure_workspace, experiment_name)


def init_compute(
    context: MLClientCtx,
    cpu_cluster_name: str,
    vm_size: str = "STANDARD_D2_V2",
    max_nodes: int = 1,
) -> ComputeTarget:
    """
    Get the Azure ML compute target named ``cpu_cluster_name``, creating it
    when the lookup raises ``ComputeTargetException``. Waits for the target
    to complete before returning it.

    :param context:          MLRun context.
    :param cpu_cluster_name: Name of Azure ML compute target. Created if does not exist.
    :param vm_size:          Azure machine type, used only when creating the target.
    :param max_nodes:        ``max_nodes`` of the provisioning configuration,
                             used only when creating the target.
    :returns:                Azure ML Compute Target.
    """
    azure_workspace = _load_workspace(context)
    context.logger.info(f"Initializing AzureML compute target {cpu_cluster_name}")

    try:
        # Look the cluster up by name first:
        target = ComputeTarget(workspace=azure_workspace, name=cpu_cluster_name)
    except ComputeTargetException:
        # Lookup failed, provision a new cluster with the requested size:
        provisioning = AmlCompute.provisioning_configuration(
            vm_size=vm_size, max_nodes=max_nodes
        )
        target = ComputeTarget.create(azure_workspace, cpu_cluster_name, provisioning)
    else:
        context.logger.info("Found existing cluster, will use it.")

    target.wait_for_completion(show_output=True)
    return target


def register_dataset(
    context: MLClientCtx,
    dataset_name: str,
    dataset_description: str,
    data: DataItem,
    create_new_version: bool = False,
):
    """
    Register dataset object (can be also an Iguazio FeatureVector) in Azure ML.
    Uploads the data to the workspace's default Azure blob datastore and
    registers that file as a tabular dataset in Azure ML.

    A FeatureVector is written through a ``ParquetTarget``; any other
    ``DataItem`` is uploaded as raw bytes under ``dataset_name`` + its suffix.
    The uploaded blob path is logged as the ``dataset_blob_path`` result.

    :param context:               MLRun context.
    :param dataset_name:          Name of Azure dataset to register.
    :param dataset_description:   Description of Azure dataset to register.
    :param data:                  MLRun FeatureVector or dataset object to upload.
    :param create_new_version:    Register Azure dataset as new version. Must be used when
                                  modifying dataset schema.
    """

    # test for Azure storage connection environment variable or secret:
    assert _env_or_secret(
        context, "AZURE_STORAGE_CONNECTION_STRING"
    ), "AZURE_STORAGE_CONNECTION_STRING secret not set"

    # Connect to AzureML experiment and datastore:
    context.logger.info("Connecting to AzureML experiment default datastore")

    workspace = _load_workspace(context)
    datastore = workspace.get_default_datastore()

    # Azure blob path (default datastore for workspace):
    blob_path = f"az://{datastore.container_name}/{dataset_name}"

    store_uri_prefix, _ = mlrun.datastore.parse_store_uri(data.artifact_url)
    feature_vector_case = mlrun.utils.StorePrefix.FeatureVector == store_uri_prefix
    # Upload the data source to the blob storage:
    if feature_vector_case:
        # FeatureVector case:
        context.logger.info(
            f"Retrieving feature vector and uploading to Azure blob storage: {blob_path}"
        )
        f_store.get_offline_features(data.meta.uri, target=ParquetTarget(path=blob_path))
    else:
        blob_path += data.suffix
        # DataItem case (the message previously claimed a feature vector was
        # being retrieved, which was misleading when reading run logs):
        context.logger.info(
            f"Retrieving data item and uploading to Azure blob storage: {blob_path}"
        )
        data_in_bytes = data.get()
        get_dataitem(blob_path).put(data_in_bytes)

    # Register dataset in AzureML:
    context.logger.info(f"Registering dataset {dataset_name} in Azure ML")
    if data.suffix == ".parquet" or feature_vector_case:
        dataset = Dataset.Tabular.from_parquet_files(
            path=(datastore, f"{dataset_name}.parquet"), validate=False
        )
    else:
        context.logger.info(
            "OpenSSL version must be 1.1. Overriding the OpenSSL version to 1.1"
        )
        # OpenSSL version must be 1.1 (note: this changes the process environment)
        os.environ["CLR_OPENSSL_VERSION_OVERRIDE"] = "1.1"
        dataset = Dataset.Tabular.from_delimited_files(
            path=(datastore, f"{dataset_name}{data.suffix}"), validate=False
        )

    dataset.register(
        workspace=workspace,
        name=dataset_name,
        description=dataset_description,
        create_new_version=create_new_version,
    )

    # Output the uploaded blob path in Azure:
    context.log_result("dataset_blob_path", blob_path)


def download_model(
    context: MLClientCtx,
    model_name: str,
    model_version: int,
    target_dir: str = ".",
) -> None:
    """
    Fetch a registered Azure ML model and store it on the local filesystem.

    The workspace is always resolved from the MLRun context, and the download
    call is issued with ``exist_ok=True``.

    :param context:       MLRun context (used for workspace resolution and logging).
    :param model_name:    Name of the registered model in Azure ML.
    :param model_version: Version of the registered model to fetch.
    :param target_dir:    Local directory the model is downloaded into.
    """
    # Resolve the Azure ML workspace from the MLRun context:
    azure_workspace = _load_workspace(context)
    context.logger.info(f"Downloading model {model_name}:{model_version}")

    # Look up the registered model and download it in a single step:
    Model(azure_workspace, model_name, version=model_version).download(
        target_dir=target_dir, exist_ok=True
    )


def upload_model(
    context: MLClientCtx,
    model_name: str,
    model_path: str,
    model_description: str = None,
    model_tags: dict = None,
) -> None:
    """
    Register a model file from the local filesystem as a model in Azure ML.

    :param context:           MLRun context (used for workspace resolution and logging).
    :param model_name:        Name to register the model under.
    :param model_path:        Path to the model file on the local filesystem.
    :param model_description: Free-text description stored with the model.
    :param model_tags:        Key-value tags stored with the model.
    """
    # Resolve the Azure ML workspace from the MLRun context:
    azure_workspace = _load_workspace(context)

    context.logger.info(f"Upload model {model_name} from {model_path}")

    # Collect the registration arguments, then register the model:
    registration_kwargs = {
        "workspace": azure_workspace,
        "model_path": model_path,
        "model_name": model_name,
        "description": model_description,
        "tags": model_tags,
    }
    Model.register(**registration_kwargs)


def _get_top_n_runs(
    remote_run: AutoMLRun, n: int = 5, primary_metric: str = "accuracy"
) -> List[ScriptRun]:
    """
    Pick the ``n`` best completed child runs of an AutoML run.

    Child runs whose id contains "setup" or "worker" are ignored. The remaining
    runs are ordered by ``primary_metric`` in descending order (highest first).

    :param remote_run:     Azure ML parent run whose children are inspected.
    :param n:              Number of runs to return.
    :param primary_metric: Name of the metric used for ranking.

    :returns:              The ``n`` highest ranked runs, best first.

    :raises ValueError:    If fewer than ``n`` eligible completed runs exist.
    """
    # Keep only the real model runs (skip the setup / worker helper runs):
    skipped_markers = ["setup", "worker"]
    candidates = []
    for child in remote_run.get_children(status="Completed"):
        if any(marker in child.id for marker in skipped_markers):
            continue
        candidates.append(child)

    # Make sure enough runs finished:
    available = len(candidates)
    if available < n:
        raise ValueError(f"Expected {n} runs but only received {available}")

    # Rank by the primary metric, highest value first:
    def _score(candidate):
        return candidate.get_metrics()[primary_metric]

    candidates.sort(key=_score, reverse=True)
    return candidates[:n]


def _get_model_hp(
    run: ScriptRun,
) -> dict:
    """
    Get hyper-parameters of trained AzureML model.
    Combine the hyper-parameters of the data transformation and training to a dictionary.
    The prefix of the dictionary keys corresponds to 'data transformation' ("data_trans_")
    and 'training' ("train_").

    The values are read from the JSON stored under the "pipeline_spec" run property. Any
    expected key that is missing from the spec, or whose value is falsy, is reported as an
    empty string so the result always has the same set of keys.

    :param run: Run object of AzureML trained model.

    :returns:   A flat dictionary of the prefixed hyper-parameters, or an empty dictionary
                when the run has no "pipeline_spec" property or the spec has no "objects".

    :raises AssertionError: If the spec does not hold exactly two hyper-parameter dicts.
    """

    spec_field = "pipeline_spec"
    if spec_field not in run.properties:
        return {}
    spec_dict = json.loads(run.properties[spec_field])

    if "objects" not in spec_dict:
        # No hyper-params
        return {}
    hp_dicts = spec_dict["objects"]
    # After training there are two hyper-parameters dicts inside the run object. Raised
    # explicitly (rather than via `assert`) so the check is not stripped under `python -O`:
    if len(hp_dicts) != 2:
        raise AssertionError(
            "after training there are two hyper-parameters dicts inside the run object"
        )

    dict_keys = [
        ["data_trans_class_name", "data_trans_module", "data_trans_spec_class"],
        [
            "train_class_name",
            "train_module",
            "train_param_kwargs_C",
            "train_param_kwargs_class_weight",
            "train_spec_class",
        ],
    ]

    # creating hyper-params dict with key prefixes for each part:
    kwargs_prefix = "param_kwargs"
    result_dict = {}
    for d, name, keys in zip(hp_dicts, ["data_trans", "train"], dict_keys):
        # Not every spec object carries `param_kwargs` (or every expected kwarg):
        param_kwargs = d.get(kwargs_prefix) or {}
        for key in keys:
            if kwargs_prefix in key:
                value = param_kwargs.get(key.replace(f"{name}_{kwargs_prefix}_", ""))
            else:
                value = d.get(key.replace(f"{name}_", ""))
            # Missing and falsy values are normalized to an empty string:
            result_dict[key] = value or ""

    return result_dict


def submit_training_job(
    context: MLClientCtx,
    experiment: Experiment,
    compute_target: ComputeTarget,
    register_model_name: str,
    registered_dataset_name: str,
    automl_settings: dict,
    training_set: DataItem,
    label_column_name: str = '',
    save_n_models: int = 3,
    show_output: bool = True,
) -> None:
    """
    Submit training job to Azure AutoML and download trained model
    when completed. Uses previously registered dataset for training.

    After the run completes, the top `save_n_models` runs (ranked by
    `automl_settings["primary_metric"]`) are registered in Azure, downloaded
    locally and logged to MLRun, each in its own child context. The best model's
    child context is marked as best.

    :param context:                 MLRun context.
    :param experiment:              Azure experiment.
    :param compute_target:          Azure compute target.
    :param register_model_name:     Name of model to register in Azure.
    :param registered_dataset_name: Name of dataset registered in Azure ML.
    :param automl_settings:         Dictionary of all Azure AutoML settings. It is unpacked into
                                    `AutoMLConfig` and must contain the "primary_metric" key.
    :param training_set:            Training set to log with model. For model
                                    monitoring integration.
    :param label_column_name:       Name of target column in dataset. When empty and the training
                                    set is a feature vector, the vector's label column is used.
    :param save_n_models:           How many of the top performing models to log.
    :param show_output:             Displaying Azure logs.
    """
    # Resolve the Azure ML workspace from the MLRun context:
    workspace = _load_workspace(context)

    # Setup experiment:
    context.logger.info("Setting up experiment parameters")
    dataset = Dataset.get_by_name(workspace, name=registered_dataset_name)

    # Get training set to log with model:
    feature_vector = None
    store_uri_prefix, _ = mlrun.datastore.parse_store_uri(training_set.artifact_url)
    if mlrun.utils.StorePrefix.FeatureVector == store_uri_prefix:
        feature_vector = training_set.meta.uri
        label_column_name = label_column_name or training_set.meta.status.label_column
        context.logger.info(f'label column name: {label_column_name}')
        training_set = f_store.get_offline_features(feature_vector).to_dataframe()
    else:
        training_set = training_set.as_df()

    automl_config = AutoMLConfig(
        compute_target=compute_target,
        training_data=dataset,
        verbosity=logging.INFO,
        label_column_name=label_column_name,
        **automl_settings,
    )

    # Run experiment on AzureML:
    context.logger.info("Submitting and running experiment")
    remote_run = experiment.submit(automl_config)
    remote_run.wait_for_completion(show_output=show_output)
    if show_output:
        # Azure log ending row:
        print(f"\n{'*' * 92}\n")
    # Get top N runs to log:
    top_runs = _get_top_n_runs(
        remote_run=remote_run,
        n=save_n_models,
        primary_metric=automl_settings["primary_metric"],
    )

    # Register, download, and log models:
    for i, run in enumerate(top_runs):
        # Register model:
        context.logger.info("Registering model")
        model = run.register_model(
            model_name=register_model_name, model_path="outputs/model.pkl"
        )
        context.logger.info(
            f"Registered model with name '{model.name}', id '{model.id}', version '{model.version}'"
        )

        # Download model locally:
        download_model(
            context=context,
            model_name=register_model_name,
            model_version=model.version,
            target_dir=f"./{model.version}",
        )

        metrics = {k.lower(): val for k, val in run.get_metrics().items()}
        # These two metrics are excluded from the logged results. They are not reported by
        # every run, so remove them only when present instead of failing with a KeyError
        # after the whole training has already finished:
        metrics.pop("confusion_matrix", None)
        metrics.pop("accuracy_table", None)

        # Collect model hyper-parameters (may be empty when the run has no pipeline spec):
        model_hp_dict = _get_model_hp(run)
        data_trans_name = model_hp_dict.get("data_trans_class_name", "")
        train_name = model_hp_dict.get("train_class_name", "")
        with context.get_child_context(**model_hp_dict) as child:
            model_key = f"model_{i + 1}_{data_trans_name.lower()}_{train_name.lower()}"
            # Log model:
            context.logger.info(
                f"Logging {model_key} model to MLRun"
            )
            child.log_results(metrics)
            child.log_model(
                "model",
                db_key=model_key,
                artifact_path=context.artifact_subpath("models"),
                metrics=metrics,
                model_file=f"{model.version}/model.pkl",
                training_set=training_set,
                label_column=label_column_name,
                feature_vector=feature_vector,
                framework="AzureML",
                algorithm=model_hp_dict.get("train_class_name"),
            )
            if i == 0:
                # The runs are sorted best first, so the first one is the best model.
                # This also logs the model:
                child.mark_as_best()


def train(
    # MlRun
    context: MLClientCtx,
    dataset: DataItem,
    # Init experiment and compute
    experiment_name: str = "",
    cpu_cluster_name: str = "",
    vm_size: str = "STANDARD_D2_V2",
    max_nodes: int = 1,
    # Register dataset
    dataset_name: str = "",
    dataset_description: str = "",
    create_new_version: bool = False,
    label_column_name: str = "",
    # Submit training job
    register_model_name: str = "",
    save_n_models: int = 1,
    log_azure: bool = True,
    automl_settings: str = None,
) -> None:
    """
    Whole training flow for Azure AutoML. Registers dataset/feature vector,
    submits training job to Azure AutoML, and downloads trained model
    when completed.

    :param context:             MLRun context.

    :param dataset:             MLRun FeatureVector or dataset URI to upload. Will drop
                                index before uploading when it is a FeatureVector.

    :param experiment_name:     Name of experiment to create in Azure ML.
    :param cpu_cluster_name:    Name of Azure ML compute target. Created if does not exist.
    :param vm_size:             Azure machine type for compute target.
    :param max_nodes:           Maximum number of concurrent compute targets.

    :param dataset_name:        Name of Azure dataset to register.
    :param dataset_description: Description of Azure dataset to register.

    :param create_new_version:  Register Azure dataset as new version. Must be used when
                                modifying dataset schema.
    :param label_column_name:   Target column in dataset.

    :param register_model_name: Name of model to register in Azure.
    :param save_n_models:       How many of the top performing models to log.
    :param log_azure:           Displaying Azure logs.
    :param automl_settings:     All Azure AutoML settings, given either as a JSON string or as an
                                already parsed dictionary. When empty, a small default
                                classification configuration is used.
    """
    if not automl_settings:
        # Default configuration: a short classification experiment.
        automl_settings = {
            "task": "classification",
            "debug_log": "automl_errors.log",
            # "experiment_exit_score": 0.9,
            "enable_early_stopping": False,
            "allowed_models": ["LogisticRegression", "SGD", "SVM"],
            "iterations": 3,
            "iteration_timeout_minutes": 2,
            "max_concurrent_iterations": 2,
            "max_cores_per_iteration": -1,
            "n_cross_validations": 5,
            "primary_metric": "accuracy",
            "featurization": "off",
            "model_explainability": False,
            "enable_voting_ensemble": False,
            "enable_stack_ensemble": False,
        }
    elif isinstance(automl_settings, str):
        # The settings are documented as a JSON string, but the training job unpacks them
        # as a mapping - so a string must be decoded first:
        automl_settings = json.loads(automl_settings)

    # Init experiment and compute (the returned workspace is not needed here):
    _, experiment = _init_experiment(
        context=context, experiment_name=experiment_name
    )

    compute_target = init_compute(
        context=context,
        cpu_cluster_name=cpu_cluster_name,
        vm_size=vm_size,
        max_nodes=max_nodes,
    )

    # Register dataset
    register_dataset(
        context=context,
        dataset_name=dataset_name,
        dataset_description=dataset_description,
        data=dataset,
        create_new_version=create_new_version,
    )

    # Submit training job
    submit_training_job(
        context,
        experiment=experiment,
        compute_target=compute_target,
        register_model_name=register_model_name,
        registered_dataset_name=dataset_name,
        label_column_name=label_column_name,
        automl_settings=automl_settings,
        training_set=dataset,
        show_output=log_azure,
        save_n_models=save_n_models,
    )
 + code_origin: '' commands: - apt-get update && apt-get install -y --no-install-recommends git - apt install -y liblttng-ust0 + auto_build: true base_image: python:3.9-bullseye - origin_filename: '' - default_handler: train allow_empty_resources: true - disable_auto_mount: false - image: '' + filename: azureml_utils.py entry_points: init_compute: - doc: 'Initialize Azure ML compute target to run experiment. Checks for - - existing compute target and creates new if does not exist.' - name: init_compute - lineno: 102 - has_kwargs: false + outputs: + - doc: Azure ML Compute Target. + type: ComputeTarget parameters: - name: context type: MLClientCtx @@ -42,20 +45,14 @@ spec: type: int doc: Maximum number of concurrent compute targets. default: 1 - outputs: - - doc: Azure ML Compute Target. - type: ComputeTarget - has_varargs: false - register_dataset: - doc: 'Register dataset object (can be also an Iguazio FeatureVector) in Azure - ML. - - Uploads parquet file to Azure blob storage and registers + name: init_compute + doc: 'Initialize Azure ML compute target to run experiment. Checks for - that file as a dataset in Azure ML.' - name: register_dataset - lineno: 138 + existing compute target and creates new if does not exist.' has_kwargs: false + has_varargs: false + lineno: 99 + register_dataset: parameters: - name: context type: MLClientCtx @@ -74,12 +71,19 @@ spec: doc: Register Azure dataset as new version. Must be used when modifying dataset schema. default: false + name: register_dataset + doc: 'Register dataset object (can be also an Iguazio FeatureVector) in Azure + ML. + + Uploads parquet file to Azure blob storage and registers + + that file as a dataset in Azure ML.' + has_kwargs: false has_varargs: false + lineno: 135 download_model: - doc: Download trained model from Azure ML to local filesystem. 
- name: download_model - lineno: 217 - has_kwargs: false + outputs: + - type: None parameters: - name: context type: MLClientCtx @@ -94,14 +98,14 @@ spec: type: str doc: Target directory to download model. default: . - outputs: - - type: None + name: download_model + doc: Download trained model from Azure ML to local filesystem. + has_kwargs: false has_varargs: false + lineno: 216 upload_model: - doc: Upload pre-trained model from local filesystem to Azure ML. - name: upload_model - lineno: 238 - has_kwargs: false + outputs: + - type: None parameters: - name: context type: MLClientCtx @@ -120,16 +124,14 @@ spec: type: dict doc: KV pairs of model tags. default: null - outputs: - - type: None + name: upload_model + doc: Upload pre-trained model from local filesystem to Azure ML. + has_kwargs: false has_varargs: false + lineno: 237 submit_training_job: - doc: 'Submit training job to Azure AutoML and download trained model - - when completed. Uses previously registered dataset for training.' - name: submit_training_job - lineno: 352 - has_kwargs: false + outputs: + - type: None parameters: - name: context type: MLClientCtx @@ -164,18 +166,16 @@ spec: type: bool doc: Displaying Azure logs. default: true - outputs: - - type: None - has_varargs: false - train: - doc: 'Whole training flow for Azure AutoML. Registers dataset/feature vector, - - submits training job to Azure AutoML, and downloads trained model + name: submit_training_job + doc: 'Submit training job to Azure AutoML and download trained model - when completed.' - name: train - lineno: 469 + when completed. Uses previously registered dataset for training.' has_kwargs: false + has_varargs: false + lineno: 350 + train: + outputs: + - type: None parameters: - name: context type: MLClientCtx @@ -233,15 +233,16 @@ spec: type: str doc: JSON string of all Azure AutoML settings. default: null - outputs: - - type: None + name: train + doc: 'Whole training flow for Azure AutoML. 
Registers dataset/feature vector, + + submits training job to Azure AutoML, and downloads trained model + + when completed.' + has_kwargs: false has_varargs: false + lineno: 465 + command: '' description: Azure AutoML integration in MLRun, including utils functions for training models on Azure AutoML platfrom. -kind: job -metadata: - categories: - - model-serving - - utils - tag: '' - name: azureml-utils + default_handler: train diff --git a/functions/src/azureml_utils/test_azureml_utils.py b/functions/src/azureml_utils/test_azureml_utils.py index d6ef80d12..752fc3fee 100644 --- a/functions/src/azureml_utils/test_azureml_utils.py +++ b/functions/src/azureml_utils/test_azureml_utils.py @@ -13,11 +13,11 @@ # limitations under the License. # import os -import tempfile import shutil -import pytest +import tempfile import mlrun +import pytest from mlrun import import_function EXPERIMENT_NAME = "azure-automl-test" @@ -117,7 +117,9 @@ def test_train(): local=True, ) # Get trained models: - num_saved_models = len(azureml_run.status.iterations) - 1 # The first one in the list is the 'columns' + num_saved_models = ( + len(azureml_run.status.iterations) - 1 + ) # The first one in the list is the 'columns' test_pass = num_saved_models == save_n_models except Exception as exception: @@ -125,4 +127,4 @@ def test_train(): _cleanup_environment(artifact_path) - assert test_pass, f'Created {len(model_paths)} models instead of {save_n_models}' + assert test_pass, f"Created {len(model_paths)} models instead of {save_n_models}" diff --git a/functions/src/batch_inference/batch_inference.py b/functions/src/batch_inference/batch_inference.py index 844fdf392..3070c6f72 100644 --- a/functions/src/batch_inference/batch_inference.py +++ b/functions/src/batch_inference/batch_inference.py @@ -15,14 +15,15 @@ import hashlib import json from datetime import datetime -from typing import Any, Dict, List, Tuple, Union -import semver +from typing import Any, Union import mlrun +import semver + if 
semver.compare(mlrun.__version__, "1.5.0") >= 0: raise mlrun.errors.MLRunNotFoundError( - f"When using `mlrun` version >=1.5.0, please use " - f"batch inference `v2` function ('hub://batch_inference_v2')." + "When using `mlrun` version >=1.5.0, please use " + "batch inference `v2` function ('hub://batch_inference_v2')." ) import mlrun.datastore @@ -45,10 +46,10 @@ def _read_dataset_as_dataframe( dataset: DatasetType, - feature_columns: Union[str, List[str]] = None, - label_columns: Union[str, List[str]] = None, - drop_columns: Union[str, List[str], int, List[int]] = None, -) -> Tuple[pd.DataFrame, List[str]]: + feature_columns: str | list[str] = None, + label_columns: str | list[str] = None, + drop_columns: str | list[str] | int | list[int] = None, +) -> tuple[pd.DataFrame, list[str]]: """ Parse the given dataset into a DataFrame and drop the columns accordingly. In addition, the label columns will be parsed and validated as well. @@ -120,7 +121,7 @@ def _read_dataset_as_dataframe( def _prepare_result_set( - x: pd.DataFrame, label_columns: List[str], y_pred: np.ndarray + x: pd.DataFrame, label_columns: list[str], y_pred: np.ndarray ) -> pd.DataFrame: """ Set default label column names and validate given names to prepare the result set - a concatenation of the inputs @@ -204,7 +205,7 @@ def _get_drift_result( tvd: float, hellinger: float, threshold: float, -) -> Tuple[bool, float]: +) -> tuple[bool, float]: """ Calculate the drift result by the following equation: (tvd + hellinger) / 2 @@ -228,7 +229,7 @@ def _perform_drift_analysis( drift_threshold: float, possible_drift_threshold: float, inf_capping: float, -) -> Tuple[Artifact, Artifact, dict]: +) -> tuple[Artifact, Artifact, dict]: """ Perform drift analysis, producing the drift table artifact for logging post prediction. 
@@ -318,9 +319,9 @@ def infer( context: mlrun.MLClientCtx, model: str, dataset: DatasetType, - drop_columns: Union[str, List[str], int, List[int]] = None, - label_columns: Union[str, List[str]] = None, - feature_columns: Union[str, List[str]] = None, + drop_columns: str | list[str] | int | list[int] = None, + label_columns: str | list[str] = None, + feature_columns: str | list[str] = None, log_result_set: bool = True, result_set_name: str = "prediction", batch_id: str = None, @@ -330,7 +331,7 @@ def infer( possible_drift_threshold: float = 0.5, inf_capping: float = 10.0, artifacts_tag: str = "", - **predict_kwargs: Dict[str, Any], + **predict_kwargs: dict[str, Any], ): """ Perform a prediction on a given dataset with the given model. Can perform drift analysis between the sample set @@ -368,7 +369,7 @@ def infer( :param artifacts_tag: Tag to use for all the artifacts resulted from the function. """ # Loading the model: - context.logger.info(f"Loading model...") + context.logger.info("Loading model...") model_handler = AutoMLRun.load_model(model_path=model, context=context) if label_columns is None: label_columns = [ @@ -381,7 +382,7 @@ def infer( ] # Get dataset by object, URL or by FeatureVector: - context.logger.info(f"Loading data...") + context.logger.info("Loading data...") x, label_columns = _read_dataset_as_dataframe( dataset=dataset, feature_columns=feature_columns, @@ -390,7 +391,7 @@ def infer( ) # Predict: - context.logger.info(f"Calculating prediction...") + context.logger.info("Calculating prediction...") y_pred = model_handler.model.predict(x, **predict_kwargs) # Prepare the result set: @@ -399,7 +400,7 @@ def infer( # Check for logging the result set: if log_result_set: # Log the result set: - context.logger.info(f"Logging result set (x | prediction)...") + context.logger.info("Logging result set (x | prediction)...") context.log_dataset( key=result_set_name, df=result_set, diff --git a/functions/src/batch_inference/function.yaml 
b/functions/src/batch_inference/function.yaml index 74b672d4a..0c0ada9cb 100644 --- a/functions/src/batch_inference/function.yaml +++ b/functions/src/batch_inference/function.yaml @@ -1,22 +1,23 @@ -kind: job -verbose: false metadata: - name: batch-inference tag: '' + name: batch-inference categories: - model-serving +verbose: false +kind: job spec: image: mlrun/ml-models + disable_auto_mount: false + build: + origin_filename: '' + with_mlrun: false + functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import hashlib
import json
from datetime import datetime
from typing import Any, Union

import mlrun
import semver

# Import-time guard: this function refuses to load when the installed `mlrun` version is
# 1.5.0 or newer and points the user to the `v2` hub function instead. It runs before the
# remaining `mlrun` sub-module imports below.
if semver.compare(mlrun.__version__, "1.5.0") >= 0:
    raise mlrun.errors.MLRunNotFoundError(
        "When using `mlrun` version >=1.5.0, please use "
        "batch inference `v2` function ('hub://batch_inference_v2')."
    )

import mlrun.datastore
import mlrun.utils
import numpy as np
import pandas as pd
from mlrun import feature_store as fs
from mlrun.artifacts import Artifact
from mlrun.data_types.infer import InferOptions, get_df_stats
from mlrun.frameworks.auto_mlrun import AutoMLRun
from mlrun.model_monitoring.features_drift_table import FeaturesDriftTablePlot
from mlrun.model_monitoring.model_monitoring_batch import (
    VirtualDrift,
    calculate_inputs_statistics,
)

# A union of all supported dataset types:
DatasetType = Union[mlrun.DataItem, list, dict, pd.DataFrame, pd.Series, np.ndarray]


def _read_dataset_as_dataframe(
    dataset: DatasetType,
    feature_columns: str | list[str] = None,
    label_columns: str | list[str] = None,
    drop_columns: str | list[str] | int | list[int] = None,
) -> tuple[pd.DataFrame, list[str]]:
    """
    Parse the given dataset into a DataFrame and drop the columns accordingly. In addition, the label columns will be
    parsed and validated as well.

    :param dataset:         A dataset that will be converted into a DataFrame.
                            Can be either a list of lists, dict, URI or a FeatureVector.
    :param feature_columns: List of feature columns that will be used to build the dataframe when dataset is from
                            type list or numpy array.
    :param label_columns:   The target label(s) of the column(s) in the dataset. for Regression or
                            Classification tasks.
    :param drop_columns:    ``str`` / ``int`` or a list of ``str`` / ``int`` that represent the column names / indices
                            to drop.

    :returns: A tuple of:
              [0] = The parsed dataset as a DataFrame
              [1] = Label columns (always a list, empty when none were given or found).

    :raises MLRunInvalidArgumentError: If a list / numpy array dataset is given without `feature_columns` or with
                                       non-integer `drop_columns`, or if the dataset cannot be parsed into a
                                       DataFrame.
    """
    # Turn the `drop labels` into a list if given:
    if drop_columns is not None and not isinstance(drop_columns, list):
        drop_columns = [drop_columns]

    # Check if the dataset is in fact a Feature Vector:
    if isinstance(dataset, fs.FeatureVector):
        # Try to get the label columns if not provided:
        if label_columns is None:
            label_columns = dataset.status.label_column
        # Get the features and parse to DataFrame:
        dataset = fs.get_offline_features(
            dataset.uri, drop_columns=drop_columns
        ).to_dataframe()
        # The feature store was already asked to drop these columns. Dropping them a second
        # time from the resulting DataFrame would fail on the (now missing) column labels:
        drop_columns = None

    elif isinstance(dataset, (list, np.ndarray)):
        if not feature_columns:
            raise mlrun.errors.MLRunInvalidArgumentError(
                "Feature columns list must be provided when dataset input as from type list or numpy array"
            )
        # Parse the list / numpy array into a DataFrame:
        dataset = pd.DataFrame(dataset, columns=feature_columns)
        # Validate the `drop_columns` is given as integers:
        if drop_columns and not all(isinstance(col, int) for col in drop_columns):
            raise mlrun.errors.MLRunInvalidArgumentError(
                "`drop_columns` must be an integer / list of integers if provided as a list."
            )
    elif isinstance(dataset, mlrun.DataItem):
        # Turn the DataItem to DataFrame:
        dataset = dataset.as_df()
    else:
        # Parse the object (should be a pd.DataFrame / pd.Series, dictionary) into a DataFrame:
        try:
            dataset = pd.DataFrame(dataset)
        except ValueError as e:
            # Chain the original error so the root cause stays visible in the traceback:
            raise mlrun.errors.MLRunInvalidArgumentError(
                f"Could not parse the given dataset of type {type(dataset)} into a pandas DataFrame. "
                f"Received the following error: {e}"
            ) from e
    # Drop columns if needed:
    if drop_columns:
        dataset.drop(drop_columns, axis=1, inplace=True)

    # Turn the `label_columns` into a list by default:
    if label_columns is None:
        label_columns = []
    elif isinstance(label_columns, (str, int)):
        label_columns = [label_columns]
    return dataset, label_columns


def _prepare_result_set(
    x: pd.DataFrame, label_columns: list[str], y_pred: np.ndarray
) -> pd.DataFrame:
    """
    Build the result set: the inputs (x) with the model predictions (y_pred) appended as extra columns.

    :param x:             The inputs.
    :param label_columns: Names for the prediction columns. When the list is empty, default names are generated:
                          "predicted_label" for a single output, "predicted_label_{i}" for multiple outputs.
    :param y_pred:        The model predictions on the inputs.

    :returns: A DataFrame of the inputs followed by the prediction columns, sharing the index of the inputs.

    :raises MLRunInvalidArgumentError: If the number of label columns does not match the number of outputs, or if
                                       a label column name already exists in the inputs.
    """
    # A 1-dimensional prediction array holds a single output column:
    outputs_count = y_pred.shape[1] if len(y_pred.shape) != 1 else 1

    # Fall back to default prediction column names when none were given:
    if len(label_columns) == 0:
        label_columns = (
            ["predicted_label"]
            if outputs_count == 1
            else [f"predicted_label_{i}" for i in range(outputs_count)]
        )

    # The given names must cover the outputs exactly:
    given_count = len(label_columns)
    if outputs_count != given_count:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"The number of predicted labels: {outputs_count} "
            f"is not equal to the given label columns: {given_count}"
        )

    # The prediction columns must not collide with the input columns:
    common_labels = set(x.columns.tolist()).intersection(label_columns)
    if common_labels:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"The labels: {common_labels} are already existed in the given dataset."
        )

    # Align the predictions to the inputs' index and join them side by side:
    predictions = pd.DataFrame(y_pred, columns=label_columns, index=x.index)
    return pd.concat([x, predictions], axis=1)


def _get_sample_set_statistics(
    sample_set: DatasetType = None, model_artifact_feature_stats: dict = None
) -> dict:
    """
    Resolve the sample set statistics used as the reference in the drift analysis. A given sample set takes
    precedence over the statistics that were logged with the model.

    :param sample_set:                   A sample dataset to compare the inputs against in the drift analysis.
    :param model_artifact_feature_stats: The `feature_stats` attribute in the spec of the model artifact, holding the
                                         statistics of the model's original sample set.

    :returns: The sample set statistics.

    :raises MLRunInvalidArgumentError: If neither a sample set nor model statistics were given.
    """
    # A sample set was provided - calculate its statistics:
    if sample_set is not None:
        if isinstance(sample_set, mlrun.DataItem):
            # Turn the DataItem to DataFrame:
            sample_set, _ = _read_dataset_as_dataframe(dataset=sample_set)
        return get_df_stats(df=sample_set, options=InferOptions.Histogram)

    # No sample set - fall back to the statistics logged with the model:
    if model_artifact_feature_stats is not None:
        return model_artifact_feature_stats

    # Nothing to compare to:
    raise mlrun.errors.MLRunInvalidArgumentError(
        "Cannot perform drift analysis as there is no sample set to compare to. The model artifact was not "
        "logged with a sample set and `sample_set` was not provided to the function."
    )


def _get_drift_result(
    tvd: float,
    hellinger: float,
    threshold: float,
) -> tuple[bool, float]:
    """
    Compute the drift result as the mean of the TVD and Hellinger values - (tvd + hellinger) / 2 - and compare it
    to the threshold.

    :param tvd:       The feature's TVD value.
    :param hellinger: The feature's Hellinger value.
    :param threshold: The threshold from which (inclusive) the value is considered a drift.

    :returns: A tuple of:
              [0] = Boolean value as the drift status.
              [1] = The result.
    """
    drift_value = (tvd + hellinger) / 2
    # Reaching the threshold exactly already counts as a drift:
    drift_detected = bool(drift_value >= threshold)
    return drift_detected, drift_value


def _perform_drift_analysis(
    sample_set_statistics: dict,
    inputs: pd.DataFrame,
    drift_threshold: float,
    possible_drift_threshold: float,
    inf_capping: float,
) -> tuple[Artifact, Artifact, dict]:
    """
    Run the drift analysis of the inputs against the sample set statistics and wrap the outcome in artifacts that can
    be logged after the prediction.

    :param sample_set_statistics:    The statistics of the sample set logged along a model.
    :param inputs:                   Input dataset to perform the drift calculation on.
    :param drift_threshold:          The threshold of which to mark drifts.
    :param possible_drift_threshold: The threshold of which to mark possible drifts.
    :param inf_capping:              The value passed to `VirtualDrift` as its `inf_capping`.

    :returns: A tuple of
              [0] = An MLRun artifact (key "drift_table_plot") holding the HTML code of the drift table plot.
              [1] = An MLRun artifact (key "features_drift_results") holding the JSON of the score per feature.
              [2] = A dictionary with the overall "drift_status" and "drift_metric" results to log.

    :raise MLRunInvalidArgumentError: If an input column has no matching feature in the sample set statistics.
    """
    # Statistics of the current inputs:
    current_statistics = calculate_inputs_statistics(
        sample_set_statistics=sample_set_statistics,
        inputs=inputs,
    )

    # Drift metrics and the drift status per feature:
    drift_calculator = VirtualDrift(inf_capping=inf_capping)
    drift_metrics = drift_calculator.compute_drift_from_histograms(
        feature_stats=sample_set_statistics,
        current_stats=current_statistics,
    )
    status_per_feature = drift_calculator.check_for_drift_per_feature(
        metrics_results_dictionary=drift_metrics,
        possible_drift_threshold=possible_drift_threshold,
        drift_detected_threshold=drift_threshold,
    )

    # Every input column must appear in the sample set (only dictionary entries are counted as features):
    sample_features = {
        name
        for name, statistics in sample_set_statistics.items()
        if isinstance(statistics, dict)
    }
    input_features = set(inputs.columns)
    if not input_features.issubset(sample_features):
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"Not all feature names were matching between the inputs and the sample set provided: "
            f"{input_features - sample_features | sample_features - input_features}"
        )

    # Produce the HTML drift table:
    drift_table_html = FeaturesDriftTablePlot().produce(
        features=list(input_features),
        sample_set_statistics=sample_set_statistics,
        inputs_statistics=current_statistics,
        metrics=drift_metrics,
        drift_results=status_per_feature,
    )

    # Collect the score of each feature (non dictionary entries, like the means, are skipped):
    score_per_feature = {}
    for feature, feature_metrics in drift_metrics.items():
        if not isinstance(feature_metrics, dict):
            continue
        _, feature_score = _get_drift_result(
            tvd=feature_metrics["tvd"],
            hellinger=feature_metrics["hellinger"],
            threshold=drift_threshold,
        )
        score_per_feature[feature] = feature_score

    # The overall outcome is taken from the mean metrics:
    overall_status, overall_score = _get_drift_result(
        tvd=drift_metrics["tvd_mean"],
        hellinger=drift_metrics["hellinger_mean"],
        threshold=drift_threshold,
    )

    plot_artifact = Artifact(body=drift_table_html, format="html", key="drift_table_plot")
    scores_artifact = Artifact(
        body=json.dumps(score_per_feature),
        format="json",
        key="features_drift_results",
    )
    return (
        plot_artifact,
        scores_artifact,
        {"drift_status": overall_status, "drift_metric": overall_score},
    )


def infer(
    context: mlrun.MLClientCtx,
    model: str,
    dataset: DatasetType,
    drop_columns: str | list[str] | int | list[int] = None,
    label_columns: str | list[str] = None,
    feature_columns: str | list[str] = None,
    log_result_set: bool = True,
    result_set_name: str = "prediction",
    batch_id: str = None,
    perform_drift_analysis: bool = None,
    sample_set: DatasetType = None,
    drift_threshold: float = 0.7,
    possible_drift_threshold: float = 0.5,
    inf_capping: float = 10.0,
    artifacts_tag: str = "",
    **predict_kwargs: dict[str, Any],
):
    """
    Run the given model on the given dataset and optionally log the result set and a drift analysis. The drift
    analysis compares the sample set statistics to the current data, where the drift score is the mean of the TVD and
    Hellinger scores, checked against the thresholds configured here.

    :param context:                  MLRun context.
    :param model:                    The model Store path.
    :param dataset:                  The dataset to infer through the model. Can be passed in `inputs` as either a
                                     Dataset artifact / Feature vector URI. Or, in `parameters` as a list, dictionary or
                                     numpy array.
    :param drop_columns:             A string / integer or a list of strings / integers that represent the column names
                                     / indices to drop. When the dataset is a list or a numpy array this parameter must
                                     be represented by integers.
    :param label_columns:            The target label(s) of the column(s) in the dataset for Regression or
                                     Classification tasks. When not given, the names of the model artifact's outputs
                                     are used.
    :param feature_columns:          List of feature columns that will be used to build the dataframe when dataset is
                                     from type list or numpy array. When not given, the names of the model artifact's
                                     inputs are used.
    :param log_result_set:           Whether to log the result set - a DataFrame of the given inputs concatenated with
                                     the predictions. Defaulted to True.
    :param result_set_name:          The db key to set name of the prediction result and the filename. Defaulted to
                                     'prediction'.
    :param batch_id:                 The ID of the given batch (inference dataset). If `None`, it will be generated.
                                     It is logged as a result of the run together with the result set (meaning only
                                     when `log_result_set` is on).
    :param perform_drift_analysis:   Whether to perform drift analysis between the sample set of the model object to the
                                     dataset given. By default, None, which means it will perform drift analysis if the
                                     model has a sample set statistics. Perform drift analysis will produce a data drift
                                     table artifact.
    :param sample_set:               A sample dataset to compare the inputs to in the drift analysis. When not given,
                                     the statistics logged with the model artifact are used.
    :param drift_threshold:          The threshold of which to mark drifts. Defaulted to 0.7.
    :param possible_drift_threshold: The threshold of which to mark possible drifts. Defaulted to 0.5.
    :param inf_capping:              The value to set for when it reached infinity. Defaulted to 10.0.
    :param artifacts_tag:            Tag to use for all the artifacts resulted from the function.
    """
    # Load the model and complete the missing column names from the model artifact:
    context.logger.info("Loading model...")
    model_handler = AutoMLRun.load_model(model_path=model, context=context)
    if label_columns is None:
        label_columns = [
            model_output.name
            for model_output in model_handler._model_artifact.spec.outputs
        ]
    if feature_columns is None:
        feature_columns = [
            model_input.name
            for model_input in model_handler._model_artifact.spec.inputs
        ]

    # Read the dataset (object, URL or FeatureVector) as a DataFrame:
    context.logger.info("Loading data...")
    x, label_columns = _read_dataset_as_dataframe(
        dataset=dataset,
        feature_columns=feature_columns,
        label_columns=label_columns,
        drop_columns=drop_columns,
    )

    # Run the model and join the predictions to the inputs:
    context.logger.info("Calculating prediction...")
    y_pred = model_handler.model.predict(x, **predict_kwargs)
    result_set = _prepare_result_set(x=x, label_columns=label_columns, y_pred=y_pred)

    if log_result_set:
        # Log the result set:
        context.logger.info("Logging result set (x | prediction)...")
        context.log_dataset(
            key=result_set_name,
            df=result_set,
            db_key=result_set_name,
            tag=artifacts_tag,
        )
        # Log the batch ID (generated from the current time when it was not given):
        if batch_id is None:
            timestamp = str(datetime.now())
            batch_id = hashlib.sha224(timestamp.encode()).hexdigest()
        context.log_result(key="batch_id", value=batch_id)

    # When not explicitly set, the drift analysis is performed only if the model holds sample set statistics:
    if perform_drift_analysis is None:
        if model_handler._model_artifact.spec.feature_stats is not None:
            perform_drift_analysis = True
    if not perform_drift_analysis:
        return

    context.logger.info("Performing drift analysis...")
    # Get the statistics to compare to (the given sample set is favored over the ones logged with the model):
    sample_set_statistics = _get_sample_set_statistics(
        sample_set=sample_set,
        model_artifact_feature_stats=model_handler._model_artifact.spec.feature_stats,
    )
    # Produce the artifacts and the results:
    plot_artifact, scores_artifact, drift_summary = _perform_drift_analysis(
        sample_set_statistics=sample_set_statistics,
        inputs=result_set,
        drift_threshold=drift_threshold,
        possible_drift_threshold=possible_drift_threshold,
        inf_capping=inf_capping,
    )
    # Log them:
    context.log_artifact(plot_artifact, tag=artifacts_tag)
    context.log_artifact(scores_artifact, tag=artifacts_tag)
    context.log_results(results=drift_summary)
 + code_origin: '' + auto_build: false + allow_empty_resources: true + filename: batch_inference.py entry_points: infer: - name: infer - doc: 'Perform a prediction on a given dataset with the given model. Can perform - drift analysis between the sample set - - statistics stored in the model to the current input data. The drift rule is - the value per-feature mean of the TVD - - and Hellinger scores according to the thresholds configures here.' parameters: - name: context type: MLClientCtx @@ -30,19 +31,16 @@ spec: either a Dataset artifact / Feature vector URI. Or, in `parameters` as a list, dictionary or numpy array. - name: drop_columns - type: Union[str, List[str], int, List[int]] doc: A string / integer or a list of strings / integers that represent the column names / indices to drop. When the dataset is a list or a numpy array this parameter must be represented by integers. default: null - name: label_columns - type: Union[str, List[str]] doc: The target label(s) of the column(s) in the dataset for Regression or Classification tasks. The label column can be accessed from the model object, or the feature vector provided if available. default: null - name: feature_columns - type: Union[str, List[str]] doc: List of feature columns that will be used to build the dataframe when dataset is from type list or numpy array. default: null @@ -90,18 +88,18 @@ spec: type: str doc: Tag to use for all the artifacts resulted from the function. default: '' - lineno: 317 + name: infer + doc: 'Perform a prediction on a given dataset with the given model. Can perform + drift analysis between the sample set + + statistics stored in the model to the current input data. The drift rule is + the value per-feature mean of the TVD + + and Hellinger scores according to the thresholds configures here.' has_kwargs: true has_varargs: false - allow_empty_resources: true - default_handler: infer + lineno: 318 command: '' - build: - functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import hashlib
import json
from datetime import datetime
from typing import Any, Dict, List, Tuple, Union
import semver

import mlrun
# Version guard: refuse to load this function under `mlrun` >= 1.5.0 and point the user to the `v2` hub function.
# The check is placed before the remaining `mlrun` imports below, so it is evaluated first.
if semver.compare(mlrun.__version__, "1.5.0") >= 0:
    raise mlrun.errors.MLRunNotFoundError(
        f"When using `mlrun` version >=1.5.0, please use "
        f"batch inference `v2` function ('hub://batch_inference_v2')."
    )

import mlrun.datastore
import mlrun.utils
import numpy as np
import pandas as pd
from mlrun import feature_store as fs
from mlrun.artifacts import Artifact
from mlrun.data_types.infer import InferOptions, get_df_stats
from mlrun.frameworks.auto_mlrun import AutoMLRun
from mlrun.model_monitoring.features_drift_table import FeaturesDriftTablePlot
from mlrun.model_monitoring.model_monitoring_batch import (
    VirtualDrift,
    calculate_inputs_statistics,
)

# A union of all supported dataset types:
DatasetType = Union[mlrun.DataItem, list, dict, pd.DataFrame, pd.Series, np.ndarray]


def _read_dataset_as_dataframe(
    dataset: DatasetType,
    feature_columns: Union[str, List[str]] = None,
    label_columns: Union[str, List[str]] = None,
    drop_columns: Union[str, List[str], int, List[int]] = None,
) -> Tuple[pd.DataFrame, List[str]]:
    """
    Parse the given dataset into a DataFrame and drop the columns accordingly. In addition, the label columns will be
    parsed into a list.

    :param dataset:         A dataset that will be converted into a DataFrame.
                            Can be either a list of lists, numpy array, dict, DataFrame / Series, DataItem or a
                            FeatureVector.
    :param feature_columns: A feature column name or a list of them, used as the DataFrame column names when the
                            dataset is from type list or numpy array.
    :param label_columns:   The target label(s) of the column(s) in the dataset, for Regression or
                            Classification tasks.
    :param drop_columns:    ``str`` / ``int`` or a list of ``str`` / ``int`` that represent the column names / indices
                            to drop. For a list or numpy array dataset, an integer that is not a column label by itself
                            is treated as the position of the column to drop.

    :returns: A tuple of:
              [0] = The parsed dataset as a DataFrame
              [1] = Label columns (always a list, empty if no label columns are available).

    :raise MLRunInvalidArgumentError: If the feature columns are missing or `drop_columns` are not integers for a list
                                      / numpy array dataset, or if the dataset type could not be parsed into a
                                      DataFrame.
    """
    # Turn the `drop_columns` into a list if given:
    if drop_columns is not None and not isinstance(drop_columns, list):
        drop_columns = [drop_columns]

    # Check if the dataset is in fact a Feature Vector:
    if isinstance(dataset, fs.FeatureVector):
        # Try to get the label columns if not provided:
        if label_columns is None:
            label_columns = dataset.status.label_column
        # Get the features and parse to DataFrame:
        # NOTE(review): `drop_columns` is applied again at the end of this function - confirm that
        # `get_offline_features` does not already remove them (the second drop would then raise a `KeyError`).
        dataset = fs.get_offline_features(
            dataset.uri, drop_columns=drop_columns
        ).to_dataframe()

    elif isinstance(dataset, (list, np.ndarray)):
        if not feature_columns:
            raise mlrun.errors.MLRunInvalidArgumentError(
                "Feature columns list must be provided when dataset input as from type list or numpy array"
            )
        # The signature accepts a single feature name, but pandas requires a collection of names:
        if isinstance(feature_columns, str):
            feature_columns = [feature_columns]
        # Parse the list / numpy array into a DataFrame:
        dataset = pd.DataFrame(dataset, columns=feature_columns)
        if drop_columns:
            # Validate the `drop_columns` is given as integers:
            if not all(isinstance(col, int) for col in drop_columns):
                raise mlrun.errors.MLRunInvalidArgumentError(
                    "`drop_columns` must be an integer / list of integers if provided as a list."
                )
            # `DataFrame.drop` works with labels, while here the integers are column indices. Translate each valid
            # position into its label (all positions refer to the original columns order). Integers that are labels
            # by themselves, or that are out of range, are left as is so they behave exactly as before:
            columns_amount = len(dataset.columns)
            drop_columns = [
                dataset.columns[col]
                if col not in dataset.columns and 0 <= col < columns_amount
                else col
                for col in drop_columns
            ]
    elif isinstance(dataset, mlrun.DataItem):
        # Turn the DataItem to DataFrame:
        dataset = dataset.as_df()
    else:
        # Parse the object (should be a pd.DataFrame / pd.Series, dictionary) into a DataFrame:
        try:
            dataset = pd.DataFrame(dataset)
        except ValueError as e:
            # Keep the original error chained as the cause:
            raise mlrun.errors.MLRunInvalidArgumentError(
                f"Could not parse the given dataset of type {type(dataset)} into a pandas DataFrame. "
                f"Received the following error: {e}"
            ) from e

    # Drop columns if needed:
    if drop_columns:
        dataset.drop(drop_columns, axis=1, inplace=True)

    # Turn the `label_columns` into a list by default:
    if label_columns is None:
        label_columns = []
    elif isinstance(label_columns, (str, int)):
        label_columns = [label_columns]
    return dataset, label_columns


def _prepare_result_set(
    x: pd.DataFrame, label_columns: List[str], y_pred: np.ndarray
) -> pd.DataFrame:
    """
    Build the result set - the inputs (x) with the model predictions (y_pred) concatenated to them as new columns.
    The prediction columns are named by the given label columns or by default names if none were given.

    :param x:             The inputs.
    :param label_columns: A list of strings representing the target column names to give to the predictions. In case
                          the list is empty, a default name will be used: 'predicted_label' for a single prediction
                          column and 'predicted_label_{i}' for multiple ones.
    :param y_pred:        The model predictions on the inputs.

    :returns: The result set.

    :raise MLRunInvalidArgumentError: If the label columns amount does not match the predictions columns amount or if
                                      one of the label columns already exists in the inputs.
    """
    # One dimensional predictions hold a single output, otherwise the second axis holds the outputs:
    outputs_amount = y_pred.shape[1] if len(y_pred.shape) != 1 else 1

    # Fall back to the default names if no label columns were given:
    names = label_columns
    if len(names) == 0:
        names = (
            ["predicted_label"]
            if outputs_amount == 1
            else [f"predicted_label_{i}" for i in range(outputs_amount)]
        )

    # There must be exactly one name per prediction column:
    if len(names) != outputs_amount:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"The number of predicted labels: {outputs_amount} "
            f"is not equal to the given label columns: {len(names)}"
        )

    # The prediction columns must not override columns of the inputs:
    common_labels = set(names) & set(x.columns.tolist())
    if common_labels:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"The labels: {common_labels} are already existed in the given dataset."
        )

    # Align the predictions to the index of the inputs and join them:
    predictions = pd.DataFrame(y_pred, columns=names, index=x.index)
    return pd.concat([x, predictions], axis=1)


def _get_sample_set_statistics(
    sample_set: DatasetType = None, model_artifact_feature_stats: dict = None
) -> dict:
    """
    Resolve the sample set statistics to compare to in the drift analysis. A given sample set is favored - its
    statistics are calculated here. Without it, the statistics that were logged with the model are returned as is.

    :param sample_set:                   A sample dataset to give to compare the inputs in the drift analysis.
    :param model_artifact_feature_stats: The `feature_stats` attribute in the spec of the model artifact, where the
                                         original sample set statistics of the model was used.

    :returns: The sample set statistics.

    :raise MLRunInvalidArgumentError: If no sample set or statistics were given.
    """
    # A given sample set comes first:
    if sample_set is not None:
        sample_df = sample_set
        # Only a DataItem is converted here, any other type is passed on to the statistics calculation as is:
        if isinstance(sample_set, mlrun.DataItem):
            sample_df, _ = _read_dataset_as_dataframe(dataset=sample_set)
        return get_df_stats(df=sample_df, options=InferOptions.Histogram)

    # Otherwise, use the statistics logged with the model:
    if model_artifact_feature_stats is not None:
        return model_artifact_feature_stats

    # Nothing to compare to:
    raise mlrun.errors.MLRunInvalidArgumentError(
        "Cannot perform drift analysis as there is no sample set to compare to. The model artifact was not "
        "logged with a sample set and `sample_set` was not provided to the function."
    )


def _get_drift_result(
    tvd: float,
    hellinger: float,
    threshold: float,
) -> Tuple[bool, float]:
    """
    Average the TVD and Hellinger values into a single drift score and compare the score against the threshold.

    :param tvd:       The feature's TVD value.
    :param hellinger: The feature's Hellinger value.
    :param threshold: The minimal score (inclusive) that is considered a drift.

    :returns: A tuple of:
              [0] = The drift status - True when the score reached the threshold, False otherwise.
              [1] = The score, calculated as: (tvd + hellinger) / 2
    """
    drift_score = (tvd + hellinger) / 2
    # The comparison is inclusive - a score that equals the threshold is reported as a drift:
    has_drifted = bool(drift_score >= threshold)
    return has_drifted, drift_score


def _perform_drift_analysis(
    sample_set_statistics: dict,
    inputs: pd.DataFrame,
    drift_threshold: float,
    possible_drift_threshold: float,
    inf_capping: float,
) -> Tuple[Artifact, Artifact, dict]:
    """
    Run the drift analysis of the inputs against the sample set statistics and wrap the outcome in artifacts that can
    be logged after the prediction.

    :param sample_set_statistics:    The statistics of the sample set logged along a model.
    :param inputs:                   Input dataset to perform the drift calculation on.
    :param drift_threshold:          The threshold of which to mark drifts.
    :param possible_drift_threshold: The threshold of which to mark possible drifts.
    :param inf_capping:              The value passed to `VirtualDrift` as its `inf_capping`.

    :returns: A tuple of
              [0] = An MLRun artifact (key "drift_table_plot") holding the HTML code of the drift table plot.
              [1] = An MLRun artifact (key "features_drift_results") holding the JSON of the score per feature.
              [2] = A dictionary with the overall "drift_status" and "drift_metric" results to log.

    :raise MLRunInvalidArgumentError: If an input column has no matching feature in the sample set statistics.
    """
    # Statistics of the current inputs:
    current_statistics = calculate_inputs_statistics(
        sample_set_statistics=sample_set_statistics,
        inputs=inputs,
    )

    # Drift metrics and the drift status per feature:
    drift_calculator = VirtualDrift(inf_capping=inf_capping)
    drift_metrics = drift_calculator.compute_drift_from_histograms(
        feature_stats=sample_set_statistics,
        current_stats=current_statistics,
    )
    status_per_feature = drift_calculator.check_for_drift_per_feature(
        metrics_results_dictionary=drift_metrics,
        possible_drift_threshold=possible_drift_threshold,
        drift_detected_threshold=drift_threshold,
    )

    # Every input column must appear in the sample set (only dictionary entries are counted as features):
    sample_features = {
        name
        for name, statistics in sample_set_statistics.items()
        if isinstance(statistics, dict)
    }
    input_features = set(inputs.columns)
    if not input_features.issubset(sample_features):
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"Not all feature names were matching between the inputs and the sample set provided: "
            f"{input_features - sample_features | sample_features - input_features}"
        )

    # Produce the HTML drift table:
    drift_table_html = FeaturesDriftTablePlot().produce(
        features=list(input_features),
        sample_set_statistics=sample_set_statistics,
        inputs_statistics=current_statistics,
        metrics=drift_metrics,
        drift_results=status_per_feature,
    )

    # Collect the score of each feature (non dictionary entries, like the means, are skipped):
    score_per_feature = {}
    for feature, feature_metrics in drift_metrics.items():
        if not isinstance(feature_metrics, dict):
            continue
        _, feature_score = _get_drift_result(
            tvd=feature_metrics["tvd"],
            hellinger=feature_metrics["hellinger"],
            threshold=drift_threshold,
        )
        score_per_feature[feature] = feature_score

    # The overall outcome is taken from the mean metrics:
    overall_status, overall_score = _get_drift_result(
        tvd=drift_metrics["tvd_mean"],
        hellinger=drift_metrics["hellinger_mean"],
        threshold=drift_threshold,
    )

    plot_artifact = Artifact(body=drift_table_html, format="html", key="drift_table_plot")
    scores_artifact = Artifact(
        body=json.dumps(score_per_feature),
        format="json",
        key="features_drift_results",
    )
    return (
        plot_artifact,
        scores_artifact,
        {"drift_status": overall_status, "drift_metric": overall_score},
    )


def infer(
    context: mlrun.MLClientCtx,
    model: str,
    dataset: DatasetType,
    drop_columns: Union[str, List[str], int, List[int]] = None,
    label_columns: Union[str, List[str]] = None,
    feature_columns: Union[str, List[str]] = None,
    log_result_set: bool = True,
    result_set_name: str = "prediction",
    batch_id: str = None,
    perform_drift_analysis: bool = None,
    sample_set: DatasetType = None,
    drift_threshold: float = 0.7,
    possible_drift_threshold: float = 0.5,
    inf_capping: float = 10.0,
    artifacts_tag: str = "",
    **predict_kwargs: Dict[str, Any],
):
    """
    Run the given model on the given dataset and optionally log the result set and a drift analysis. The drift
    analysis compares the sample set statistics to the current data, where the drift score is the mean of the TVD and
    Hellinger scores, checked against the thresholds configured here.

    :param context:                  MLRun context.
    :param model:                    The model Store path.
    :param dataset:                  The dataset to infer through the model. Can be passed in `inputs` as either a
                                     Dataset artifact / Feature vector URI. Or, in `parameters` as a list, dictionary or
                                     numpy array.
    :param drop_columns:             A string / integer or a list of strings / integers that represent the column names
                                     / indices to drop. When the dataset is a list or a numpy array this parameter must
                                     be represented by integers.
    :param label_columns:            The target label(s) of the column(s) in the dataset for Regression or
                                     Classification tasks. When not given, the names of the model artifact's outputs
                                     are used.
    :param feature_columns:          List of feature columns that will be used to build the dataframe when dataset is
                                     from type list or numpy array. When not given, the names of the model artifact's
                                     inputs are used.
    :param log_result_set:           Whether to log the result set - a DataFrame of the given inputs concatenated with
                                     the predictions. Defaulted to True.
    :param result_set_name:          The db key to set name of the prediction result and the filename. Defaulted to
                                     'prediction'.
    :param batch_id:                 The ID of the given batch (inference dataset). If `None`, it will be generated.
                                     It is logged as a result of the run together with the result set (meaning only
                                     when `log_result_set` is on).
    :param perform_drift_analysis:   Whether to perform drift analysis between the sample set of the model object to the
                                     dataset given. By default, None, which means it will perform drift analysis if the
                                     model has a sample set statistics. Perform drift analysis will produce a data drift
                                     table artifact.
    :param sample_set:               A sample dataset to compare the inputs to in the drift analysis. When not given,
                                     the statistics logged with the model artifact are used.
    :param drift_threshold:          The threshold of which to mark drifts. Defaulted to 0.7.
    :param possible_drift_threshold: The threshold of which to mark possible drifts. Defaulted to 0.5.
    :param inf_capping:              The value to set for when it reached infinity. Defaulted to 10.0.
    :param artifacts_tag:            Tag to use for all the artifacts resulted from the function.
    """
    # Load the model and complete the missing column names from the model artifact:
    context.logger.info("Loading model...")
    model_handler = AutoMLRun.load_model(model_path=model, context=context)
    if label_columns is None:
        label_columns = [
            model_output.name
            for model_output in model_handler._model_artifact.spec.outputs
        ]
    if feature_columns is None:
        feature_columns = [
            model_input.name
            for model_input in model_handler._model_artifact.spec.inputs
        ]

    # Read the dataset (object, URL or FeatureVector) as a DataFrame:
    context.logger.info("Loading data...")
    x, label_columns = _read_dataset_as_dataframe(
        dataset=dataset,
        feature_columns=feature_columns,
        label_columns=label_columns,
        drop_columns=drop_columns,
    )

    # Run the model and join the predictions to the inputs:
    context.logger.info("Calculating prediction...")
    y_pred = model_handler.model.predict(x, **predict_kwargs)
    result_set = _prepare_result_set(x=x, label_columns=label_columns, y_pred=y_pred)

    if log_result_set:
        # Log the result set:
        context.logger.info("Logging result set (x | prediction)...")
        context.log_dataset(
            key=result_set_name,
            df=result_set,
            db_key=result_set_name,
            tag=artifacts_tag,
        )
        # Log the batch ID (generated from the current time when it was not given):
        if batch_id is None:
            timestamp = str(datetime.now())
            batch_id = hashlib.sha224(timestamp.encode()).hexdigest()
        context.log_result(key="batch_id", value=batch_id)

    # When not explicitly set, the drift analysis is performed only if the model holds sample set statistics:
    if perform_drift_analysis is None:
        if model_handler._model_artifact.spec.feature_stats is not None:
            perform_drift_analysis = True
    if not perform_drift_analysis:
        return

    context.logger.info("Performing drift analysis...")
    # Get the statistics to compare to (the given sample set is favored over the ones logged with the model):
    sample_set_statistics = _get_sample_set_statistics(
        sample_set=sample_set,
        model_artifact_feature_stats=model_handler._model_artifact.spec.feature_stats,
    )
    # Produce the artifacts and the results:
    plot_artifact, scores_artifact, drift_summary = _perform_drift_analysis(
        sample_set_statistics=sample_set_statistics,
        inputs=result_set,
        drift_threshold=drift_threshold,
        possible_drift_threshold=possible_drift_threshold,
        inf_capping=inf_capping,
    )
    # Log them:
    context.log_artifact(plot_artifact, tag=artifacts_tag)
    context.log_artifact(scores_artifact, tag=artifacts_tag)
    context.log_results(results=drift_summary)
 - origin_filename: '' - auto_build: false - code_origin: '' - with_mlrun: false - disable_auto_mount: false description: Batch inference (also knows as prediction) for the common ML frameworks (SciKit-Learn, XGBoost and LightGBM) while performing data drift analysis. + default_handler: infer diff --git a/functions/src/batch_inference/test_batch_inference.py b/functions/src/batch_inference/test_batch_inference.py index d18d27a9b..e37a7d000 100644 --- a/functions/src/batch_inference/test_batch_inference.py +++ b/functions/src/batch_inference/test_batch_inference.py @@ -86,7 +86,6 @@ def train(training_set: pd.DataFrame): reason="Project's environment variables are not set", ) def test_batch_predict(): - project = mlrun.get_or_create_project( "batch-infer-v9-test", context="./", user_project=True ) @@ -132,7 +131,7 @@ def test_batch_predict(): # Check the features drift results json: drift_results_file = batch_predict_run.artifact("features_drift_results").local() - with open(drift_results_file, "r") as json_file: + with open(drift_results_file) as json_file: drift_results = json.load(json_file) assert len(drift_results) == n_features + 1 diff --git a/functions/src/batch_inference_v2/batch_inference_v2.py b/functions/src/batch_inference_v2/batch_inference_v2.py index c12b04972..3c4ade07b 100644 --- a/functions/src/batch_inference_v2/batch_inference_v2.py +++ b/functions/src/batch_inference_v2/batch_inference_v2.py @@ -13,15 +13,16 @@ # limitations under the License. from inspect import signature -from typing import Any, Dict, List, Union, Optional +from typing import Any + import mlrun try: import mlrun.model_monitoring.api except ModuleNotFoundError: raise mlrun.errors.MLRunNotFoundError( - f"Please update your `mlrun` version to >=1.5.0 or use an " - f"older version of the batch inference function." + "Please update your `mlrun` version to >=1.5.0 or use an " + "older version of the batch inference function." 
) import numpy as np @@ -29,7 +30,9 @@ from mlrun.frameworks.auto_mlrun import AutoMLRun -def _prepare_result_set(x: pd.DataFrame, label_columns: List[str], y_pred: np.ndarray) -> pd.DataFrame: +def _prepare_result_set( + x: pd.DataFrame, label_columns: list[str], y_pred: np.ndarray +) -> pd.DataFrame: """ Set default label column names and validate given names to prepare the result set - a concatenation of the inputs (x) and the model predictions (y_pred). @@ -74,63 +77,75 @@ def _prepare_result_set(x: pd.DataFrame, label_columns: List[str], y_pred: np.nd ) -def _get_sample_set_statistics_parameters(context: mlrun.MLClientCtx, - model_endpoint_sample_set: Union[ - mlrun.DataItem, list, dict, pd.DataFrame, pd.Series, np.ndarray], - model_artifact_feature_stats: dict, - feature_columns: Optional[List], - drop_columns: Optional[List], - label_columns: Optional[List]) -> Dict[str, Any]: - statics_input_full_dict = dict(sample_set=model_endpoint_sample_set, - model_artifact_feature_stats=model_artifact_feature_stats, - sample_set_columns=feature_columns, - sample_set_drop_columns=drop_columns, - sample_set_label_columns=label_columns) +def _get_sample_set_statistics_parameters( + context: mlrun.MLClientCtx, + model_endpoint_sample_set: mlrun.DataItem + | list + | dict + | pd.DataFrame + | pd.Series + | np.ndarray, + model_artifact_feature_stats: dict, + feature_columns: list | None, + drop_columns: list | None, + label_columns: list | None, +) -> dict[str, Any]: + statics_input_full_dict = dict( + sample_set=model_endpoint_sample_set, + model_artifact_feature_stats=model_artifact_feature_stats, + sample_set_columns=feature_columns, + sample_set_drop_columns=drop_columns, + sample_set_label_columns=label_columns, + ) get_sample_statics_function = mlrun.model_monitoring.api.get_sample_set_statistics statics_function_input_dict = signature(get_sample_statics_function).parameters # As a result of changes to input parameters in the mlrun-get_sample_set_statistics function, 
# we will now send only the parameters it expects. - statistics_input_filtered = {key: statics_input_full_dict[key] for key in statics_function_input_dict} + statistics_input_filtered = { + key: statics_input_full_dict[key] for key in statics_function_input_dict + } if len(statistics_input_filtered) != len(statics_function_input_dict): - context.logger.warning(f"get_sample_set_statistics is in an older version; " - "some parameters will not be sent to the function." - f" Expected input: {list(statics_function_input_dict.keys())}," - f" actual input: {list(statistics_input_filtered.keys())}") + context.logger.warning( + f"get_sample_set_statistics is in an older version; " + "some parameters will not be sent to the function." + f" Expected input: {list(statics_function_input_dict.keys())}," + f" actual input: {list(statistics_input_filtered.keys())}" + ) return statistics_input_filtered def infer( - context: mlrun.MLClientCtx, - dataset: Union[mlrun.DataItem, list, dict, pd.DataFrame, pd.Series, np.ndarray], - model_path: Union[str, mlrun.DataItem], - drop_columns: Union[str, List[str], int, List[int]] = None, - label_columns: Union[str, List[str]] = None, - feature_columns: Union[str, List[str]] = None, - log_result_set: bool = True, - result_set_name: str = "prediction", - batch_id: str = None, - artifacts_tag: str = "", - # Drift analysis parameters - perform_drift_analysis: bool = None, - endpoint_id: str = "", - # The following model endpoint parameters are relevant only if: - # perform drift analysis is not disabled - # a new model endpoint record is going to be generated - model_endpoint_name: str = "batch-infer", - model_endpoint_sample_set: Union[ - mlrun.DataItem, list, dict, pd.DataFrame, pd.Series, np.ndarray - ] = None, - - # the following parameters are deprecated and will be removed once the versioning mechanism is implemented - # TODO: Remove the following parameters once FHUB-13 is resolved - trigger_monitoring_job: Optional[bool] = None, - 
batch_image_job: Optional[str] = None, - model_endpoint_drift_threshold: Optional[float] = None, - model_endpoint_possible_drift_threshold: Optional[float] = None, - - # prediction kwargs to pass to the model predict function - **predict_kwargs: Dict[str, Any], - + context: mlrun.MLClientCtx, + dataset: mlrun.DataItem | list | dict | pd.DataFrame | pd.Series | np.ndarray, + model_path: str | mlrun.DataItem, + drop_columns: str | list[str] | int | list[int] = None, + label_columns: str | list[str] = None, + feature_columns: str | list[str] = None, + log_result_set: bool = True, + result_set_name: str = "prediction", + batch_id: str = None, + artifacts_tag: str = "", + # Drift analysis parameters + perform_drift_analysis: bool = None, + endpoint_id: str = "", + # The following model endpoint parameters are relevant only if: + # perform drift analysis is not disabled + # a new model endpoint record is going to be generated + model_endpoint_name: str = "batch-infer", + model_endpoint_sample_set: mlrun.DataItem + | list + | dict + | pd.DataFrame + | pd.Series + | np.ndarray = None, + # the following parameters are deprecated and will be removed once the versioning mechanism is implemented + # TODO: Remove the following parameters once FHUB-13 is resolved + trigger_monitoring_job: bool | None = None, + batch_image_job: str | None = None, + model_endpoint_drift_threshold: float | None = None, + model_endpoint_possible_drift_threshold: float | None = None, + # prediction kwargs to pass to the model predict function + **predict_kwargs: dict[str, Any], ): """ Perform a prediction on the provided dataset using the specified model. @@ -192,26 +207,33 @@ def infer( raises MLRunInvalidArgumentError: if both `model_path` and `endpoint_id` are not provided """ - if trigger_monitoring_job: - context.logger.warning("The `trigger_monitoring_job` parameter is deprecated and will be removed once the versioning mechanism is implemented. 
" - "if you are using mlrun<1.7.0, please import the previous version of this function, for example " - "'hub://batch_inference_v2:2.5.0'.") + context.logger.warning( + "The `trigger_monitoring_job` parameter is deprecated and will be removed once the versioning mechanism is implemented. " + "if you are using mlrun<1.7.0, please import the previous version of this function, for example " + "'hub://batch_inference_v2:2.5.0'." + ) if batch_image_job: - context.logger.warning("The `batch_image_job` parameter is deprecated and will be removed once the versioning mechanism is implemented. " - "if you are using mlrun<1.7.0, please import the previous version of this function, for example " - "'hub://batch_inference_v2:2.5.0'.") + context.logger.warning( + "The `batch_image_job` parameter is deprecated and will be removed once the versioning mechanism is implemented. " + "if you are using mlrun<1.7.0, please import the previous version of this function, for example " + "'hub://batch_inference_v2:2.5.0'." + ) if model_endpoint_drift_threshold: - context.logger.warning("The `model_endpoint_drift_threshold` parameter is deprecated and will be removed once the versioning mechanism is implemented. " - "if you are using mlrun<1.7.0, please import the previous version of this function, for example " - "'hub://batch_inference_v2:2.5.0'.") + context.logger.warning( + "The `model_endpoint_drift_threshold` parameter is deprecated and will be removed once the versioning mechanism is implemented. " + "if you are using mlrun<1.7.0, please import the previous version of this function, for example " + "'hub://batch_inference_v2:2.5.0'." + ) if model_endpoint_possible_drift_threshold: - context.logger.warning("The `model_endpoint_possible_drift_threshold` parameter is deprecated and will be removed once the versioning mechanism is implemented. 
" - "if you are using mlrun<1.7.0, please import the previous version of this function, for example " - "'hub://batch_inference_v2:2.5.0'.") + context.logger.warning( + "The `model_endpoint_possible_drift_threshold` parameter is deprecated and will be removed once the versioning mechanism is implemented. " + "if you are using mlrun<1.7.0, please import the previous version of this function, for example " + "'hub://batch_inference_v2:2.5.0'." + ) # Loading the model: - context.logger.info(f"Loading model...") + context.logger.info("Loading model...") if isinstance(model_path, mlrun.DataItem): model_path = model_path.artifact_url if not mlrun.datastore.is_store_uri(model_path): @@ -233,7 +255,7 @@ def infer( ] # Get dataset by object, URL or by FeatureVector: - context.logger.info(f"Loading data...") + context.logger.info("Loading data...") x, label_columns = mlrun.model_monitoring.api.read_dataset_as_dataframe( dataset=dataset, feature_columns=feature_columns, @@ -242,7 +264,7 @@ def infer( ) # Predict: - context.logger.info(f"Calculating prediction...") + context.logger.info("Calculating prediction...") y_pred = model_handler.model.predict(x, **predict_kwargs) # Prepare the result set: @@ -260,8 +282,8 @@ def infer( # Check for performing drift analysis if ( - perform_drift_analysis is None - and model_handler._model_artifact.spec.feature_stats is not None + perform_drift_analysis is None + and model_handler._model_artifact.spec.feature_stats is not None ): perform_drift_analysis = True if perform_drift_analysis: @@ -273,8 +295,11 @@ def infer( model_artifact_feature_stats=model_handler._model_artifact.spec.feature_stats, feature_columns=feature_columns, drop_columns=drop_columns, - label_columns=label_columns) - sample_set_statistics = mlrun.model_monitoring.api.get_sample_set_statistics(**statistics_input_filtered) + label_columns=label_columns, + ) + sample_set_statistics = mlrun.model_monitoring.api.get_sample_set_statistics( + **statistics_input_filtered + ) 
mlrun.model_monitoring.api.record_results( project=context.project, context=context, @@ -283,4 +308,4 @@ def infer( model_endpoint_name=model_endpoint_name, infer_results_df=result_set.copy(), sample_set_statistics=sample_set_statistics, - ) \ No newline at end of file + ) diff --git a/functions/src/batch_inference_v2/function.yaml b/functions/src/batch_inference_v2/function.yaml index 014cb2167..8c327e9d6 100644 --- a/functions/src/batch_inference_v2/function.yaml +++ b/functions/src/batch_inference_v2/function.yaml @@ -1,21 +1,32 @@ +metadata: + tag: '' + name: batch-inference-v2 + categories: + - model-serving verbose: false +kind: job spec: - default_handler: infer + image: mlrun/mlrun + disable_auto_mount: false + build: + origin_filename: '' + with_mlrun: false + functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from inspect import signature
from typing import Any

import mlrun

# The functions below call into `mlrun.model_monitoring.api`. When that module cannot be found, fail right at import
# time with an MLRun error that tells the user what to do (update `mlrun` to >=1.5.0 or use an older version of this
# function) instead of surfacing the bare `ModuleNotFoundError`.
try:
    import mlrun.model_monitoring.api
except ModuleNotFoundError:
    raise mlrun.errors.MLRunNotFoundError(
        "Please update your `mlrun` version to >=1.5.0 or use an "
        "older version of the batch inference function."
    )

import numpy as np
import pandas as pd
from mlrun.frameworks.auto_mlrun import AutoMLRun


def _prepare_result_set(
    x: pd.DataFrame, label_columns: list[str], y_pred: np.ndarray
) -> pd.DataFrame:
    """
    Set default label column names and validate given names to prepare the result set - a concatenation of the inputs
    (x) and the model predictions (y_pred).

    :param x:             The inputs.
    :param label_columns: A list of strings representing the target column names to add to the predictions. Default name
                          will be used in case the list is empty (predicted_label_{i}).
    :param y_pred:        The model predictions on the inputs.

    :returns: The result set.

    raises MLRunInvalidArgumentError: If the labels columns amount do not match the outputs or if one of the label
                                       column already exists in the dataset.
    """
    # Prepare default target columns names if not provided:
    prediction_columns_amount = 1 if len(y_pred.shape) == 1 else y_pred.shape[1]
    if len(label_columns) == 0:
        # Add default label column names:
        if prediction_columns_amount == 1:
            label_columns = ["predicted_label"]
        else:
            label_columns = [
                f"predicted_label_{i}" for i in range(prediction_columns_amount)
            ]

    # Validate the label columns:
    if prediction_columns_amount != len(label_columns):
        # No equality between provided label column names and outputs amount:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"The number of predicted labels: {prediction_columns_amount} "
            f"is not equal to the given label columns: {len(label_columns)}"
        )
    common_labels = set(label_columns) & set(x.columns.tolist())
    if common_labels:
        # Label column exist in the original inputs:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"The labels: {common_labels} are already existed in the given dataset."
        )

    return pd.concat(
        [x, pd.DataFrame(y_pred, columns=label_columns, index=x.index)], axis=1
    )


def _get_sample_set_statistics_parameters(
    context: mlrun.MLClientCtx,
    model_endpoint_sample_set: mlrun.DataItem
    | list
    | dict
    | pd.DataFrame
    | pd.Series
    | np.ndarray,
    model_artifact_feature_stats: dict,
    feature_columns: list | None,
    drop_columns: list | None,
    label_columns: list | None,
) -> dict[str, Any]:
    """
    Prepare the keyword arguments for `mlrun.model_monitoring.api.get_sample_set_statistics`, keeping only the
    ones that the function's signature actually declares.

    :param context:                      MLRun context, used for logging a warning when parameters are dropped.
    :param model_endpoint_sample_set:    The sample set, sent as `sample_set`.
    :param model_artifact_feature_stats: The feature statistics, sent as `model_artifact_feature_stats`.
    :param feature_columns:              The feature columns, sent as `sample_set_columns`.
    :param drop_columns:                 The columns to drop, sent as `sample_set_drop_columns`.
    :param label_columns:                The label columns, sent as `sample_set_label_columns`.

    :returns: A dictionary of the keyword arguments that `get_sample_set_statistics` accepts.
    """
    statics_input_full_dict = dict(
        sample_set=model_endpoint_sample_set,
        model_artifact_feature_stats=model_artifact_feature_stats,
        sample_set_columns=feature_columns,
        sample_set_drop_columns=drop_columns,
        sample_set_label_columns=label_columns,
    )
    get_sample_statics_function = mlrun.model_monitoring.api.get_sample_set_statistics
    statics_function_input_dict = signature(get_sample_statics_function).parameters
    #  As a result of changes to input parameters in the mlrun-get_sample_set_statistics function,
    #  we will now send only the parameters it expects. A parameter the function declares but which is not prepared
    #  here is skipped (rather than raising a `KeyError`) so the function's own default applies.
    statistics_input_filtered = {
        key: statics_input_full_dict[key]
        for key in statics_function_input_dict
        if key in statics_input_full_dict
    }
    # Compare against the prepared parameters (not against the signature): a difference means that at least one
    # prepared parameter is not accepted by the function and therefore is not going to be sent.
    if len(statistics_input_filtered) != len(statics_input_full_dict):
        context.logger.warning(
            "get_sample_set_statistics is in an older version; "
            "some parameters will not be sent to the function."
            f" Expected input: {list(statics_function_input_dict.keys())},"
            f" actual input: {list(statistics_input_filtered.keys())}"
        )
    return statistics_input_filtered


def infer(
    context: mlrun.MLClientCtx,
    dataset: mlrun.DataItem | list | dict | pd.DataFrame | pd.Series | np.ndarray,
    model_path: str | mlrun.DataItem,
    drop_columns: str | list[str] | int | list[int] = None,
    label_columns: str | list[str] = None,
    feature_columns: str | list[str] = None,
    log_result_set: bool = True,
    result_set_name: str = "prediction",
    batch_id: str = None,
    artifacts_tag: str = "",
    # Drift analysis parameters
    perform_drift_analysis: bool = None,
    endpoint_id: str = "",
    # The following model endpoint parameters are relevant only if:
    # perform drift analysis is not disabled
    # a new model endpoint record is going to be generated
    model_endpoint_name: str = "batch-infer",
    model_endpoint_sample_set: mlrun.DataItem
    | list
    | dict
    | pd.DataFrame
    | pd.Series
    | np.ndarray = None,
    # the following parameters are deprecated and will be removed once the versioning mechanism is implemented
    # TODO: Remove the following parameters once FHUB-13 is resolved
    trigger_monitoring_job: bool | None = None,
    batch_image_job: str | None = None,
    model_endpoint_drift_threshold: float | None = None,
    model_endpoint_possible_drift_threshold: float | None = None,
    # prediction kwargs to pass to the model predict function
    **predict_kwargs: dict[str, Any],
):
    """
    Run a batch prediction of the given model on the given dataset.
    The model must already be logged under the current project.

    To apply monitoring tools (e.g., drift analysis), set the perform_drift_analysis parameter to True. A new model
    endpoint record will be created under the given model_endpoint_name. In addition, make sure model monitoring is
    enabled at the project level by calling the project.enable_model_monitoring() function. Monitoring can also be
    applied to an existing model by providing its endpoint id or name, in which case the monitoring tools are applied
    to that endpoint.

    At the moment, this function is supported for `mlrun>=1.5.0` versions.

    :param context:                                 MLRun context.
    :param dataset:                                 The dataset to infer through the model. Provided as an input (DataItem)
                                                    that represents Dataset artifact / Feature vector URI.
                                                    If using MLRun SDK, `dataset` can also be provided as a list, dictionary or
                                                    numpy array.
    :param model_path:                              Model store uri (should start with store://). Provided as an input (DataItem).
                                                    If using MLRun SDK, `model_path` can also be provided as a parameter (string).
                                                    To generate a valid model store URI, please log the model before running this function.
                                                    If `endpoint_id` of existing model endpoint is provided, make sure
                                                    that it has a similar model store path, otherwise the drift analysis
                                                    won't be triggered.
    :param drop_columns:                            A string / integer or a list of strings / integers that represent the column names
                                                    / indices to drop. When the dataset is a list or a numpy array this parameter must
                                                    be represented by integers.
    :param label_columns:                           The target label(s) of the column(s) in the dataset for Regression or
                                                    Classification tasks. When not given, the names of the model artifact
                                                    outputs are used.
    :param feature_columns:                         List of feature columns that will be used to build the dataframe when dataset is
                                                    from type list or numpy array. When not given, the names of the model
                                                    artifact inputs are used.
    :param log_result_set:                          Whether to log the result set - a DataFrame of the given inputs concatenated with
                                                    the predictions. Defaulted to True.
    :param result_set_name:                         The db key to set name of the prediction result and the filename. Defaulted to
                                                    'prediction'.
    :param batch_id:                                The ID of the given batch (inference dataset). If `None`, it will be generated.
                                                    Will be logged as a result of the run.
    :param artifacts_tag:                           Tag to use for prediction set result artifact.
    :param perform_drift_analysis:                  Whether to perform drift analysis between the sample set of the model object to the
                                                    dataset given. By default, None, which means it will perform drift analysis if the
                                                    model already has feature stats that are considered as a reference sample set.
                                                    Performing drift analysis on a new endpoint id will generate a new model endpoint
                                                    record.
    :param endpoint_id:                             Model endpoint unique ID. If `perform_drift_analysis` was set, the endpoint_id
                                                    will be used to perform the analysis on existing model endpoint, or if it does not
                                                    exist a new model endpoint will be created with a newly generated ID.
    :param model_endpoint_name:                     If a new model endpoint is generated, the model name will be presented under this
                                                    endpoint.
    :param model_endpoint_sample_set:               A sample dataset to give to compare the inputs in the drift analysis.
                                                    Can be provided as an input (DataItem) or as a parameter (e.g. string, list, DataFrame).
                                                    The default chosen sample set will always be the one who is set in the model artifact itself.
    :param trigger_monitoring_job:                  Deprecated. Whether to trigger the batch drift analysis after the infer job.
    :param batch_image_job:                         Deprecated. The image that will be used to register the monitoring batch job if
                                                    not exist. By default, the image is mlrun/mlrun.
    :param model_endpoint_drift_threshold:          Deprecated. The threshold of which to mark drifts. Defaulted to 0.7.
    :param model_endpoint_possible_drift_threshold: Deprecated. The threshold of which to mark possible drifts. Defaulted to 0.5.

    raises MLRunInvalidArgumentError: if the given `model_path` is not a store uri
    """
    # Warn (in the order of the signature) about every deprecated parameter that was given a truthy value:
    deprecated_parameters = {
        "trigger_monitoring_job": trigger_monitoring_job,
        "batch_image_job": batch_image_job,
        "model_endpoint_drift_threshold": model_endpoint_drift_threshold,
        "model_endpoint_possible_drift_threshold": model_endpoint_possible_drift_threshold,
    }
    for parameter_name, parameter_value in deprecated_parameters.items():
        if parameter_value:
            context.logger.warning(
                f"The `{parameter_name}` parameter is deprecated and will be removed once the versioning mechanism is implemented. "
                "if you are using mlrun<1.7.0, please import the previous version of this function, for example "
                "'hub://batch_inference_v2:2.5.0'."
            )

    # Loading the model (only store uris are accepted):
    context.logger.info("Loading model...")
    if isinstance(model_path, mlrun.DataItem):
        model_path = model_path.artifact_url
    if not mlrun.datastore.is_store_uri(model_path):
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"The provided model path ({model_path}) is invalid - should start with `store://`. "
            "Please make sure that you have logged the model using `project.log_model()` "
            "which generates a unique store uri for the logged model."
        )
    handler = AutoMLRun.load_model(model_path=model_path, context=context)

    # Columns that were not given are taken from the outputs / inputs logged with the model artifact:
    if label_columns is None:
        label_columns = [
            model_output.name for model_output in handler._model_artifact.spec.outputs
        ]
    if feature_columns is None:
        feature_columns = [
            model_input.name for model_input in handler._model_artifact.spec.inputs
        ]

    # Get dataset by object, URL or by FeatureVector:
    monitoring_api = mlrun.model_monitoring.api
    context.logger.info("Loading data...")
    inputs, label_columns = monitoring_api.read_dataset_as_dataframe(
        dataset=dataset,
        feature_columns=feature_columns,
        label_columns=label_columns,
        drop_columns=drop_columns,
    )

    # Predict and attach the predictions to the inputs:
    context.logger.info("Calculating prediction...")
    predictions = handler.model.predict(inputs, **predict_kwargs)
    result_set = _prepare_result_set(
        x=inputs, label_columns=label_columns, y_pred=predictions
    )

    # Check for logging the result set:
    if log_result_set:
        monitoring_api.log_result(
            context=context,
            result_set_name=result_set_name,
            result_set=result_set,
            artifacts_tag=artifacts_tag,
            batch_id=batch_id,
        )

    # When the user did not decide, drift analysis is performed only if the model was logged with feature stats:
    if perform_drift_analysis is None:
        perform_drift_analysis = (
            handler._model_artifact.spec.feature_stats is not None
        )
    if not perform_drift_analysis:
        return

    context.logger.info("Performing drift analysis...")
    # Get the sample set statistics (either from the sample set or from the statistics logged with the model)
    statistics_parameters = _get_sample_set_statistics_parameters(
        context=context,
        model_endpoint_sample_set=model_endpoint_sample_set,
        model_artifact_feature_stats=handler._model_artifact.spec.feature_stats,
        feature_columns=feature_columns,
        drop_columns=drop_columns,
        label_columns=label_columns,
    )
    sample_set_statistics = monitoring_api.get_sample_set_statistics(
        **statistics_parameters
    )
    # Record a copy of the result set together with the reference statistics on the model endpoint:
    monitoring_api.record_results(
        project=context.project,
        context=context,
        endpoint_id=endpoint_id,
        model_path=model_path,
        model_endpoint_name=model_endpoint_name,
        infer_results_df=result_set.copy(),
        sample_set_statistics=sample_set_statistics,
    )
 + code_origin: '' + auto_build: false + allow_empty_resources: true + filename: batch_inference_v2.py entry_points: infer: - lineno: 102 - name: infer parameters: - name: context type: MLClientCtx doc: MLRun context. - name: dataset - type: Union[DataItem, list, dict, DataFrame, Series, ndarray] doc: The dataset to infer through the model. Provided as an input (DataItem) that represents Dataset artifact / Feature vector URI. If using MLRun SDK, `dataset` can also be provided as a list, dictionary or numpy array. - name: model_path - type: Union[str, DataItem] doc: Model store uri (should start with store://). Provided as an input (DataItem). If using MLRun SDK, `model_path` can also be provided as a parameter (string). To generate a valid model store URI, please log the model before running @@ -23,19 +34,16 @@ spec: make sure that it has a similar model store path, otherwise the drift analysis won't be triggered. - name: drop_columns - type: Union[str, List[str], int, List[int]] doc: A string / integer or a list of strings / integers that represent the column names / indices to drop. When the dataset is a list or a numpy array this parameter must be represented by integers. default: null - name: label_columns - type: Union[str, List[str]] doc: The target label(s) of the column(s) in the dataset for Regression or Classification tasks. The label column can be accessed from the model object, or the feature vector provided if available. default: null - name: feature_columns - type: Union[str, List[str]] doc: List of feature columns that will be used to build the dataframe when dataset is from type list or numpy array. default: null @@ -69,8 +77,9 @@ spec: - name: endpoint_id type: str doc: Model endpoint unique ID. If `perform_drift_analysis` was set, the endpoint_id - will be used either to perform the analysis on existing model endpoint or - to generate a new model endpoint record. 
+ will be used to perform the analysis on existing model endpoint, or if it + does not exist a new model endpoint will be created with a newly generated + ID. default: '' - name: model_endpoint_name type: str @@ -78,31 +87,25 @@ spec: under this endpoint. default: batch-infer - name: model_endpoint_sample_set - type: Union[DataItem, list, dict, DataFrame, Series, ndarray] doc: A sample dataset to give to compare the inputs in the drift analysis. Can be provided as an input (DataItem) or as a parameter (e.g. string, list, DataFrame). The default chosen sample set will always be the one who is set in the model artifact itself. default: null - name: trigger_monitoring_job - type: Optional[bool] doc: Whether to trigger the batch drift analysis after the infer job. default: null - name: batch_image_job - type: Optional[str] doc: The image that will be used to register the monitoring batch job if not exist. By default, the image is mlrun/mlrun. default: null - name: model_endpoint_drift_threshold - type: Optional[float] doc: The threshold of which to mark drifts. Defaulted to 0.7. default: null - name: model_endpoint_possible_drift_threshold - type: Optional[float] doc: The threshold of which to mark possible drifts. Defaulted to 0.5. default: null - has_kwargs: true - has_varargs: false + name: infer doc: 'Perform a prediction on the provided dataset using the specified model. Ensure that the model has already been logged under the current project. @@ -123,21 +126,10 @@ spec: At the moment, this function is supported for `mlrun>=1.5.0` versions.' + has_kwargs: true + has_varargs: false + lineno: 117 command: '' - build: - with_mlrun: false - code_origin: '' - origin_filename: '' - auto_build: false - functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from inspect import signature
from typing import Any, Dict, List, Union, Optional
import mlrun

try:
    import mlrun.model_monitoring.api
except ModuleNotFoundError:
    raise mlrun.errors.MLRunNotFoundError(
        f"Please update your `mlrun` version to >=1.5.0 or use an "
        f"older version of the batch inference function."
    )

import numpy as np
import pandas as pd
from mlrun.frameworks.auto_mlrun import AutoMLRun


def _prepare_result_set(x: pd.DataFrame, label_columns: List[str], y_pred: np.ndarray) -> pd.DataFrame:
    """
    Set default label column names and validate given names to prepare the result set - a concatenation of the inputs
    (x) and the model predictions (y_pred).

    :param x:             The inputs.
    :param label_columns: A list of strings representing the target column names to add to the predictions. Default name
                          will be used in case the list is empty or `None` (predicted_label_{i}). A single name may
                          also be given as a plain string.
    :param y_pred:        The model predictions on the inputs.

    :returns: The result set.

    :raises MLRunInvalidArgumentError: If the labels columns amount do not match the outputs or if one of the label
                                       column already exists in the dataset.
    """
    # Normalize the label columns so the length checks below count names, not characters (a bare string) and do not
    # fail on `None`:
    if label_columns is None:
        label_columns = []
    elif isinstance(label_columns, str):
        label_columns = [label_columns]

    # Prepare default target columns names if not provided (a 1D prediction array is a single output column):
    prediction_columns_amount = 1 if len(y_pred.shape) == 1 else y_pred.shape[1]
    if len(label_columns) == 0:
        # Add default label column names:
        if prediction_columns_amount == 1:
            label_columns = ["predicted_label"]
        else:
            label_columns = [
                f"predicted_label_{i}" for i in range(prediction_columns_amount)
            ]

    # Validate the label columns:
    if prediction_columns_amount != len(label_columns):
        # No equality between provided label column names and outputs amount:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"The number of predicted labels: {prediction_columns_amount} "
            f"is not equal to the given label columns: {len(label_columns)}"
        )
    common_labels = set(label_columns) & set(x.columns.tolist())
    if common_labels:
        # Label column exist in the original inputs:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"The labels: {common_labels} are already existed in the given dataset."
        )

    # Reuse the inputs index so the predictions are aligned row by row with the inputs:
    return pd.concat(
        [x, pd.DataFrame(y_pred, columns=label_columns, index=x.index)], axis=1
    )


def _get_sample_set_statistics_parameters(context: mlrun.MLClientCtx,
                                          model_endpoint_sample_set: Union[
                                              mlrun.DataItem, list, dict, pd.DataFrame, pd.Series, np.ndarray],
                                          model_artifact_feature_stats: dict,
                                          feature_columns: Optional[List],
                                          drop_columns: Optional[List],
                                          label_columns: Optional[List]) -> Dict[str, Any]:
    """
    Build the keyword arguments to pass to `mlrun.model_monitoring.api.get_sample_set_statistics`, keeping only
    the ones that appear in the signature of the installed function.

    :param context:                      MLRun context, used for logging a warning when arguments are dropped.
    :param model_endpoint_sample_set:    Sent as `sample_set`.
    :param model_artifact_feature_stats: Sent as `model_artifact_feature_stats`.
    :param feature_columns:              Sent as `sample_set_columns`.
    :param drop_columns:                 Sent as `sample_set_drop_columns`.
    :param label_columns:                Sent as `sample_set_label_columns`.

    :returns: A dictionary of the keyword arguments that the installed `get_sample_set_statistics` accepts.
    """
    statics_input_full_dict = dict(sample_set=model_endpoint_sample_set,
                                   model_artifact_feature_stats=model_artifact_feature_stats,
                                   sample_set_columns=feature_columns,
                                   sample_set_drop_columns=drop_columns,
                                   sample_set_label_columns=label_columns)
    get_sample_statics_function = mlrun.model_monitoring.api.get_sample_set_statistics
    statics_function_input_dict = signature(get_sample_statics_function).parameters
    #  As a result of changes to input parameters in the mlrun-get_sample_set_statistics function,
    #  we will now send only the parameters it expects. The filtering goes over our own arguments (and not over the
    #  function's parameters) so a parameter of the function that we do not supply is simply left to its default
    #  instead of raising a `KeyError`.
    statistics_input_filtered = {
        key: value
        for key, value in statics_input_full_dict.items()
        if key in statics_function_input_dict
    }
    # Compare against everything we wanted to send - a difference means some arguments were dropped:
    if len(statistics_input_filtered) != len(statics_input_full_dict):
        context.logger.warning(f"get_sample_set_statistics is in an older version; "
                               "some parameters will not be sent to the function."
                               f" Expected input: {list(statics_function_input_dict.keys())},"
                               f" actual input: {list(statistics_input_filtered.keys())}")
    return statistics_input_filtered


def infer(
        context: mlrun.MLClientCtx,
        dataset: Union[mlrun.DataItem, list, dict, pd.DataFrame, pd.Series, np.ndarray],
        model_path: Union[str, mlrun.DataItem],
        drop_columns: Union[str, List[str], int, List[int]] = None,
        label_columns: Union[str, List[str]] = None,
        feature_columns: Union[str, List[str]] = None,
        log_result_set: bool = True,
        result_set_name: str = "prediction",
        batch_id: str = None,
        artifacts_tag: str = "",
        # Drift analysis parameters
        perform_drift_analysis: bool = None,
        endpoint_id: str = "",
        # The following model endpoint parameters are relevant only if:
        # perform drift analysis is not disabled
        # a new model endpoint record is going to be generated
        model_endpoint_name: str = "batch-infer",
        model_endpoint_sample_set: Union[
            mlrun.DataItem, list, dict, pd.DataFrame, pd.Series, np.ndarray
        ] = None,

        # the following parameters are deprecated and will be removed once the versioning mechanism is implemented
        # TODO: Remove the following parameters once FHUB-13 is resolved
        trigger_monitoring_job: Optional[bool] = None,
        batch_image_job: Optional[str] = None,
        model_endpoint_drift_threshold: Optional[float] = None,
        model_endpoint_possible_drift_threshold: Optional[float] = None,

        # prediction kwargs to pass to the model predict function
        **predict_kwargs: Dict[str, Any],

):
    """
    Perform a prediction on the provided dataset using the specified model.
    Ensure that the model has already been logged under the current project.

    If you wish to apply monitoring tools (e.g., drift analysis), set the perform_drift_analysis parameter to True.
    This will create a new model endpoint record under the specified model_endpoint_name.
    Additionally, ensure that model monitoring is enabled at the project level by calling the
    project.enable_model_monitoring() function. You can also apply monitoring to an existing model by providing its
    endpoint id or name, and the monitoring tools will be applied to that endpoint.

    At the moment, this function is supported for `mlrun>=1.5.0` versions.

    :param context:                                 MLRun context.
    :param dataset:                                 The dataset to infer through the model. Provided as an input (DataItem)
                                                    that represents Dataset artifact / Feature vector URI.
                                                    If using MLRun SDK, `dataset` can also be provided as a list, dictionary or
                                                    numpy array.
    :param model_path:                              Model store uri (should start with store://). Provided as an input (DataItem).
                                                    If using MLRun SDK, `model_path` can also be provided as a parameter (string).
                                                    To generate a valid model store URI, please log the model before running this function.
                                                    If `endpoint_id` of existing model endpoint is provided, make sure
                                                    that it has a similar model store path, otherwise the drift analysis
                                                    won't be triggered.
    :param drop_columns:                            A string / integer or a list of strings / integers that represent the column names
                                                    / indices to drop. When the dataset is a list or a numpy array this parameter must
                                                    be represented by integers.
    :param label_columns:                           The target label(s) of the column(s) in the dataset for Regression or
                                                    Classification tasks. The label column can be accessed from the model object, or
                                                    the feature vector provided if available.
    :param feature_columns:                         List of feature columns that will be used to build the dataframe when dataset is
                                                    from type list or numpy array.
    :param log_result_set:                          Whether to log the result set - a DataFrame of the given inputs concatenated with
                                                    the predictions. Defaulted to True.
    :param result_set_name:                         The db key to set name of the prediction result and the filename. Defaulted to
                                                    'prediction'.
    :param batch_id:                                The ID of the given batch (inference dataset). If `None`, it will be generated.
                                                    Will be logged as a result of the run.
    :param artifacts_tag:                           Tag to use for prediction set result artifact.
    :param perform_drift_analysis:                  Whether to perform drift analysis between the sample set of the model object to the
                                                    dataset given. By default, None, which means it will perform drift analysis if the
                                                    model already has feature stats that are considered as a reference sample set.
                                                    Performing drift analysis on a new endpoint id will generate a new model endpoint
                                                    record.
    :param endpoint_id:                             Model endpoint unique ID. If `perform_drift_analysis` was set, the endpoint_id
                                                    will be used either to perform the analysis on existing model endpoint or to
                                                    generate a new model endpoint record.
    :param model_endpoint_name:                     If a new model endpoint is generated, the model name will be presented under this
                                                    endpoint.
    :param model_endpoint_sample_set:               A sample dataset to give to compare the inputs in the drift analysis.
                                                    Can be provided as an input (DataItem) or as a parameter (e.g. string, list, DataFrame).
                                                    The default chosen sample set will always be the one who is set in the model artifact itself.
    :param trigger_monitoring_job:                  Whether to trigger the batch drift analysis after the infer job.
    :param batch_image_job:                         The image that will be used to register the monitoring batch job if not exist.
                                                    By default, the image is mlrun/mlrun.
    :param model_endpoint_drift_threshold:          The threshold of which to mark drifts. Defaulted to 0.7.
    :param model_endpoint_possible_drift_threshold: The threshold of which to mark possible drifts. Defaulted to 0.5.

    raises MLRunInvalidArgumentError: if both `model_path` and `endpoint_id` are not provided
    """


    if trigger_monitoring_job:
        context.logger.warning("The `trigger_monitoring_job` parameter is deprecated and will be removed once the versioning mechanism is implemented. "
                               "if you are using mlrun<1.7.0, please import the previous version of this function, for example "
                               "'hub://batch_inference_v2:2.5.0'.")
    if batch_image_job:
        context.logger.warning("The `batch_image_job` parameter is deprecated and will be removed once the versioning mechanism is implemented. "
                               "if you are using mlrun<1.7.0, please import the previous version of this function, for example "
                               "'hub://batch_inference_v2:2.5.0'.")
    if model_endpoint_drift_threshold:
        context.logger.warning("The `model_endpoint_drift_threshold` parameter is deprecated and will be removed once the versioning mechanism is implemented. "
                               "if you are using mlrun<1.7.0, please import the previous version of this function, for example "
                               "'hub://batch_inference_v2:2.5.0'.")
    if model_endpoint_possible_drift_threshold:
        context.logger.warning("The `model_endpoint_possible_drift_threshold` parameter is deprecated and will be removed once the versioning mechanism is implemented. "
                               "if you are using mlrun<1.7.0, please import the previous version of this function, for example "
                               "'hub://batch_inference_v2:2.5.0'.")

    # Loading the model:
    context.logger.info(f"Loading model...")
    if isinstance(model_path, mlrun.DataItem):
        model_path = model_path.artifact_url
    if not mlrun.datastore.is_store_uri(model_path):
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"The provided model path ({model_path}) is invalid - should start with `store://`. "
            f"Please make sure that you have logged the model using `project.log_model()` "
            f"which generates a unique store uri for the logged model."
        )
    model_handler = AutoMLRun.load_model(model_path=model_path, context=context)

    if label_columns is None:
        label_columns = [
            output.name for output in model_handler._model_artifact.spec.outputs
        ]

    if feature_columns is None:
        feature_columns = [
            input.name for input in model_handler._model_artifact.spec.inputs
        ]

    # Get dataset by object, URL or by FeatureVector:
    context.logger.info(f"Loading data...")
    x, label_columns = mlrun.model_monitoring.api.read_dataset_as_dataframe(
        dataset=dataset,
        feature_columns=feature_columns,
        label_columns=label_columns,
        drop_columns=drop_columns,
    )

    # Predict:
    context.logger.info(f"Calculating prediction...")
    y_pred = model_handler.model.predict(x, **predict_kwargs)

    # Prepare the result set:
    result_set = _prepare_result_set(x=x, label_columns=label_columns, y_pred=y_pred)

    # Check for logging the result set:
    if log_result_set:
        mlrun.model_monitoring.api.log_result(
            context=context,
            result_set_name=result_set_name,
            result_set=result_set,
            artifacts_tag=artifacts_tag,
            batch_id=batch_id,
        )

    # Check for performing drift analysis
    if (
            perform_drift_analysis is None
            and model_handler._model_artifact.spec.feature_stats is not None
    ):
        perform_drift_analysis = True
    if perform_drift_analysis:
        context.logger.info("Performing drift analysis...")
        # Get the sample set statistics (either from the sample set or from the statistics logged with the model)
        statistics_input_filtered = _get_sample_set_statistics_parameters(
            context=context,
            model_endpoint_sample_set=model_endpoint_sample_set,
            model_artifact_feature_stats=model_handler._model_artifact.spec.feature_stats,
            feature_columns=feature_columns,
            drop_columns=drop_columns,
            label_columns=label_columns)
        sample_set_statistics = mlrun.model_monitoring.api.get_sample_set_statistics(**statistics_input_filtered)
        mlrun.model_monitoring.api.record_results(
            project=context.project,
            context=context,
            endpoint_id=endpoint_id,
            model_path=model_path,
            model_endpoint_name=model_endpoint_name,
            infer_results_df=result_set.copy(),
            sample_set_statistics=sample_set_statistics,
        ) - allow_empty_resources: true - disable_auto_mount: false - image: mlrun/mlrun description: Batch inference (also knows as prediction) for the common ML frameworks (SciKit-Learn, XGBoost and LightGBM) while performing data drift analysis. -metadata: - tag: '' - categories: - - model-serving - name: batch-inference-v2 -kind: job + default_handler: infer diff --git a/functions/src/batch_inference_v2/item.yaml b/functions/src/batch_inference_v2/item.yaml index 8b8f01df0..62738b1ec 100644 --- a/functions/src/batch_inference_v2/item.yaml +++ b/functions/src/batch_inference_v2/item.yaml @@ -12,7 +12,7 @@ labels: author: Iguazio maintainers: [] marketplaceType: '' -mlrunVersion: 1.7.0-rc51 +mlrunVersion: 1.7.0 name: batch_inference_v2 platformVersion: 3.6.0 spec: diff --git a/functions/src/batch_inference_v2/test_batch_inference_v2.py b/functions/src/batch_inference_v2/test_batch_inference_v2.py index 6fa657a0d..e34433076 100644 --- a/functions/src/batch_inference_v2/test_batch_inference_v2.py +++ b/functions/src/batch_inference_v2/test_batch_inference_v2.py @@ -13,25 +13,27 @@ # limitations under the License. 
# +import datetime import os import pickle +import shutil import time import uuid + +import mlrun +import mlrun.common.schemas import numpy as np import pandas as pd import pytest +from batch_inference_v2 import infer +from mlrun.frameworks.sklearn import apply_mlrun +from mlrun.model_monitoring.api import get_or_create_model_endpoint +from mlrun.projects import get_or_create_project from sklearn.datasets import make_classification -from sklearn.tree import DecisionTreeClassifier -import datetime from sklearn.model_selection import train_test_split +from sklearn.tree import DecisionTreeClassifier from xgboost import XGBClassifier -from mlrun.frameworks.sklearn import apply_mlrun -from mlrun.projects import get_or_create_project -import mlrun -import mlrun.common.schemas -from batch_inference_v2 import infer -import shutil -from mlrun.model_monitoring.api import get_or_create_model_endpoint + REQUIRED_ENV_VARS = [ "MLRUN_DBPATH", "V3IO_USERNAME", @@ -39,6 +41,7 @@ "V3IO_ACCESS_KEY", ] + def _validate_environment_variables() -> bool: """ Checks that all required Environment variables are set. 
@@ -52,7 +55,7 @@ def generate_data(n_samples: int = 5000, n_features: int = 20): x, y = make_classification(n_samples=n_samples, n_features=n_features, n_classes=2) # Split the data into a training set and a prediction set: - x_train, x_prediction = x[: n_samples // 2], x[n_samples // 2:] + x_train, x_prediction = x[: n_samples // 2], x[n_samples // 2 :] y_train = y[: n_samples // 2] # Randomly drift some features: @@ -86,17 +89,27 @@ def train(training_set: pd.DataFrame): model.fit(training_set, labels) -def assert_batch_predict(n_features, batch_inference_run, with_monitoring=False, project_name="batch-infer-test"): +def assert_batch_predict( + n_features, + batch_inference_run, + with_monitoring=False, + project_name="batch-infer-test", +): # Check the logged results: assert "batch_id" in batch_inference_run.status.results assert len(batch_inference_run.status.artifacts) == 1 - assert len(batch_inference_run.artifact("prediction").as_df().columns) == n_features + 1 + assert ( + len(batch_inference_run.artifact("prediction").as_df().columns) + == n_features + 1 + ) if with_monitoring: # Check that the drift analysis was performed: time.sleep(60) # Retrieve the model endpoint project = get_or_create_project(project_name) - endpoint = get_or_create_model_endpoint(project=project.name, model_endpoint_name="my_cool_endpoint") + endpoint = get_or_create_model_endpoint( + project=project.name, model_endpoint_name="my_cool_endpoint" + ) # Validate that the artifacts were logged in the project artifacts = project.list_artifacts( @@ -119,9 +132,7 @@ def assert_batch_predict(n_features, batch_inference_run, with_monitoring=False, reason="Project's environment variables are not set", ) def test_batch_predict(): - project = get_or_create_project( - "batch-infer-test", context="./", user_project=True - ) + project = get_or_create_project("batch-infer-test", context="./", user_project=True) # Configure test: n_samples = 5000 n_features = 20 @@ -157,19 +168,23 @@ def 
test_batch_predict(): # Enable model monitoring project.set_model_monitoring_credentials( - endpoint_store_connection="v3io", - tsdb_connection="v3io", - stream_path="v3io") + endpoint_store_connection="v3io", tsdb_connection="v3io", stream_path="v3io" + ) # Deploy model monitoring infrastructure project.enable_model_monitoring(wait_for_deployment=True, base_period=1) # Wait until the monitoring application is triggered import time + time.sleep(60) # Check the logged results: - assert_batch_predict(n_features=n_features, batch_inference_run=batch_inference_run, with_monitoring=True) + assert_batch_predict( + n_features=n_features, + batch_inference_run=batch_inference_run, + with_monitoring=True, + ) # Clean resources _delete_project(project=project.metadata.name) @@ -190,7 +205,9 @@ def setup_method(self): current_datetime = datetime.datetime.now() datetime_str = current_datetime.strftime("%Y%m%d_%H%M%S") mlrun.runtimes.utils.global_context.set(None) - self.context = mlrun.get_or_create_ctx(datetime_str, project=self.project.metadata.name, upload_artifacts=True) + self.context = mlrun.get_or_create_ctx( + datetime_str, project=self.project.metadata.name, upload_artifacts=True + ) self.context.artifact_path = self.infer_artifact_path def teardown_method(self): @@ -209,43 +226,70 @@ def _get_model_endpoint_sample_set(self, sample_type, n_features: int = 20): elif sample_type == list: return data.values.tolist() elif sample_type == dict: - return data.to_dict(orient='list') + return data.to_dict(orient="list") elif sample_type == pd.DataFrame: return data elif sample_type == np.ndarray: return data.values - @pytest.mark.parametrize("sample_type", [mlrun.DataItem, list, dict, pd.DataFrame, np.ndarray]) + @pytest.mark.parametrize( + "sample_type", [mlrun.DataItem, list, dict, pd.DataFrame, np.ndarray] + ) def test_infer_sample_types(self, sample_type): n_features = 10 training_set, prediction_set = generate_data(n_features=n_features) - clf = 
XGBClassifier(n_estimators=2, max_depth=2, learning_rate=1, objective="binary:logistic") - x, y = prediction_set, training_set['target_label'] - x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.8, test_size=0.2, random_state=0) + clf = XGBClassifier( + n_estimators=2, max_depth=2, learning_rate=1, objective="binary:logistic" + ) + x, y = prediction_set, training_set["target_label"] + x_train, x_test, y_train, y_test = train_test_split( + x, y, train_size=0.8, test_size=0.2, random_state=0 + ) clf.fit(x_train, y_train) train_set_to_log = x_train.join(y_train) - model = self.project.log_model(f"model-{uuid.uuid4()}", body=pickle.dumps(clf), - model_file=f"model-{uuid.uuid4()}.pkl", framework="xgboost", - training_set=train_set_to_log, label_column="target_label") + model = self.project.log_model( + f"model-{uuid.uuid4()}", + body=pickle.dumps(clf), + model_file=f"model-{uuid.uuid4()}.pkl", + framework="xgboost", + training_set=train_set_to_log, + label_column="target_label", + ) dataset = self.project.log_dataset(f"dataset-{uuid.uuid4()}", df=x_test) z_test = train_set_to_log * 5 - model_endpoint_sample_set = self.project.log_dataset(f"model-endpoint-sample-set{uuid.uuid4()}", df=z_test) + model_endpoint_sample_set = self.project.log_dataset( + f"model-endpoint-sample-set{uuid.uuid4()}", df=z_test + ) sample = self._get_model_endpoint_sample_set( - sample_type=sample_type, n_features=n_features) - infer(context=self.context, - dataset=dataset.to_dataitem().as_df(), model_path=model.uri, - model_endpoint_sample_set=sample, - feature_columns=list(model_endpoint_sample_set.to_dataitem().as_df().columns), - label_columns="target_label", - model_endpoint_name=f"model-endpoint-name-{uuid.uuid4()}", - trigger_monitoring_job=True, - perform_drift_analysis=True) + sample_type=sample_type, n_features=n_features + ) + infer( + context=self.context, + dataset=dataset.to_dataitem().as_df(), + model_path=model.uri, + model_endpoint_sample_set=sample, + 
feature_columns=list( + model_endpoint_sample_set.to_dataitem().as_df().columns + ), + label_columns="target_label", + model_endpoint_name=f"model-endpoint-name-{uuid.uuid4()}", + trigger_monitoring_job=True, + perform_drift_analysis=True, + ) # a workaround until ML-4636 will be solved. - batch_inference_run = self.project.list_runs(name=self.context.name).to_objects()[0] - mlrun.get_run_db().update_run(updates={"status.state": "completed"}, uid=batch_inference_run.uid()) - assert_batch_predict(n_features=n_features, batch_inference_run=batch_inference_run, project_name=self.project_name) + batch_inference_run = self.project.list_runs( + name=self.context.name + ).to_objects()[0] + mlrun.get_run_db().update_run( + updates={"status.state": "completed"}, uid=batch_inference_run.uid() + ) + assert_batch_predict( + n_features=n_features, + batch_inference_run=batch_inference_run, + project_name=self.project_name, + ) def _delete_project(project: str): diff --git a/functions/src/describe/describe.py b/functions/src/describe/describe.py index 27d789f5b..ac8a744dc 100644 --- a/functions/src/describe/describe.py +++ b/functions/src/describe/describe.py @@ -15,7 +15,6 @@ # Generated by nuclio.export.NuclioExporter import warnings -from typing import Union import mlrun import numpy as np @@ -46,7 +45,7 @@ def analyze( context: MLClientCtx, name: str = "dataset", - table: Union[FeatureSet, DataItem] = None, + table: FeatureSet | DataItem = None, label_column: str = None, plots_dest: str = "plots", random_state: int = 1, @@ -129,7 +128,7 @@ def analyze( ) df = feature_set.to_dataframe() else: - context.logger.error(f"Wrong table type.") + context.logger.error("Wrong table type.") return if df.size > MAX_SIZE_OF_DF: @@ -320,8 +319,8 @@ def _create_features_histogram_artifacts( ) fig.update_layout(title_text=f"Histograms of {first_feature_name}") - extra_data[f"histograms"] = context.log_artifact( - PlotlyArtifact(key=f"histograms", figure=fig), + extra_data["histograms"] = 
context.log_artifact( + PlotlyArtifact(key="histograms", figure=fig), local_path=f"{plots_dest}/histograms.html", ) @@ -431,9 +430,9 @@ def _create_features_2d_scatter_artifacts( template="plotly_white", ) - fig.update_layout(title_text=f"Scatter-2d") - extra_data[f"scatter-2d"] = context.log_artifact( - PlotlyArtifact(key=f"scatter-2d", figure=fig), + fig.update_layout(title_text="Scatter-2d") + extra_data["scatter-2d"] = context.log_artifact( + PlotlyArtifact(key="scatter-2d", figure=fig), local_path=f"{plots_dest}/scatter-2d.html", ) @@ -540,7 +539,7 @@ def _create_corr_artifact( ) z = tblcorr.values.tolist() - z_text = [["{:.2f}".format(y) for y in x] for x in z] + z_text = [[f"{y:.2f}" for y in x] for x in z] fig = ff.create_annotated_heatmap( z, x=list(tblcorr.columns), diff --git a/functions/src/describe/function.yaml b/functions/src/describe/function.yaml index a11461774..679eea213 100644 --- a/functions/src/describe/function.yaml +++ b/functions/src/describe/function.yaml @@ -1,7 +1,20 @@ +metadata: + tag: '' + name: describe + categories: + - data-analysis +verbose: false +kind: job spec: + image: mlrun/mlrun + disable_auto_mount: false + build: + origin_filename: '' + functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Generated by nuclio.export.NuclioExporter

import warnings

import mlrun
import numpy as np

warnings.simplefilter(action="ignore", category=FutureWarning)

import mlrun.feature_store as fstore
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff
import plotly.graph_objects as go
from mlrun.artifacts import (
    Artifact,
    DatasetArtifact,
    PlotlyArtifact,
    TableArtifact,
    update_dataset_meta,
)
from mlrun.datastore import DataItem
from mlrun.execution import MLClientCtx
from mlrun.feature_store import FeatureSet
from plotly.subplots import make_subplots

# Format floats with two decimal places wherever pandas applies its display float format.
pd.set_option("display.float_format", lambda x: "%.2f" % x)
# Threshold on `DataFrame.size`; `analyze` randomly down-samples tables that exceed it.
MAX_SIZE_OF_DF = 500000


def analyze(
    context: MLClientCtx,
    name: str = "dataset",
    table: FeatureSet | DataItem = None,
    label_column: str = None,
    plots_dest: str = "plots",
    random_state: int = 1,
    problem_type: str = "classification",
    dask_key: str = "dask_key",
    dask_function: str = None,
    dask_client=None,
) -> None:
    """
    The function will output the following artifacts per
    column within the data frame (based on data types)
    If the data has more than 500,000 sample we
    sample randomly 500,000 samples:

    describe csv
    histograms
    scatter-2d
    violin chart
    correlation-matrix chart
    correlation-matrix csv
    imbalance pie chart
    imbalance-weights-vec csv

    :param context:                 The function context
    :param name:                    Key of dataset to database ("dataset" for default)
    :param table:                   MLRun input pointing to pandas dataframe (csv/parquet file path) or FeatureSet
                                    as param
    :param label_column:            Ground truth column label
    :param plots_dest:              Destination folder of summary plots (relative to artifact_path)
                                    ("plots" for default)
    :param random_state:            When the table has more than 500,000 samples, we sample randomly 500,000 samples
    :param problem_type:            The type of the ML problem the data facing - regression, classification or None
                                    (classification for default)
    :param dask_key:                Key of dataframe in dask client "datasets" attribute
    :param dask_function:           Dask function url (db://..)
    :param dask_client:             Dask client object

    :raises Exception: If a dask client is used, `dask_key` is not one of its datasets and no `table` was given.
    """
    # Flags describing where the data came from and what to do with the dataset artifact at the end:
    # data_item  - the data is read through `table.as_df()` (or was taken from the dask cluster)
    # featureset - the data is read from a feature set
    # creat      - log a new dataset artifact; update - update the metadata of the existing dataset artifact
    data_item, featureset, creat, update = False, False, False, False
    get_from_table = True
    if dask_function or dask_client:
        data_item, creat = True, True
        client = (
            mlrun.import_function(dask_function).client
            if dask_function
            else dask_client
        )

        if dask_key in client.datasets:
            df = client.get_dataset(dask_key)
            get_from_table = False
        elif not table:
            context.logger.info(
                f"only these datasets are available {client.datasets} in client {client}"
            )
            raise Exception("dataset not found on dask cluster")

    if get_from_table:
        if isinstance(table, DataItem):
            table_meta = table.meta
            if table_meta is None:
                data_item, creat, update = True, True, False
            elif table_meta.kind == "dataset":
                data_item, creat, update = True, False, True
            elif table_meta.kind == "FeatureVector":
                data_item, creat, update = True, False, False
            elif table_meta.kind == "FeatureSet":
                featureset, creat, update = True, False, False

        if data_item:
            df = table.as_df()
        elif featureset:
            # The project and the feature set names are taken from fixed positions in the data item path:
            path_parts = table._path.split("/")
            project_name, set_name = path_parts[2], path_parts[4]
            feature_set = fstore.get_feature_set(
                f"store://feature-sets/{project_name}/{set_name}"
            )
            df = feature_set.to_dataframe()
        else:
            context.logger.error("Wrong table type.")
            return

    # `df.size` counts cells (rows * columns), so the number of sampled rows is derived from the columns amount:
    if df.size > MAX_SIZE_OF_DF:
        df = df.sample(n=int(MAX_SIZE_OF_DF / df.shape[1]), random_state=random_state)
    extra_data = {}

    if label_column not in df.columns:
        label_column = None

    extra_data["describe csv"] = context.log_artifact(
        TableArtifact("describe-csv", df=df.describe()),
        local_path=f"{plots_dest}/describe.csv",
    )

    # Every plot is best-effort - a failure is logged as a warning and the remaining plots are still produced:
    plot_steps = (
        (
            "histogram matrix artifact",
            _create_histogram_mat_artifact,
            (context, df, extra_data, label_column, plots_dest),
        ),
        (
            "pairplot histograms",
            _create_features_histogram_artifacts,
            (context, df, extra_data, label_column, plots_dest, problem_type),
        ),
        (
            "pairplot 2d_scatter",
            _create_features_2d_scatter_artifacts,
            (context, df, extra_data, label_column, plots_dest, problem_type),
        ),
        (
            "violin distribution plots",
            _create_violin_artifact,
            (context, df, extra_data, plots_dest),
        ),
        (
            "class imbalance plot",
            _create_imbalance_artifact,
            (context, df, extra_data, label_column, plots_dest, problem_type),
        ),
        (
            "features correlation plot",
            _create_corr_artifact,
            (context, df, extra_data, label_column, plots_dest),
        ),
    )
    for description, create_artifact, arguments in plot_steps:
        try:
            create_artifact(*arguments)
        except Exception as e:
            context.logger.warn(f"Failed to create {description} due to: {e}")

    if not data_item:
        return

    if creat:  # dataset not stored
        artifact = context.log_artifact(
            DatasetArtifact(key="dataset", stats=True, df=df, extra_data=extra_data),
            db_key=name,
        )
        context.logger.info(f"The data set is logged to the project under {name} name")
    else:
        # Reached only for a `DataItem` table. `table` may be None when the data was taken from the dask cluster,
        # but then `creat` is set, so the url must not be read unconditionally.
        artifact = table.artifact_url

    if update:
        update_dataset_meta(artifact, extra_data=extra_data)
        context.logger.info(f"The data set named {name} is updated")

    # TODO : 3-D plot on selected features.
    # TODO : Reintegration plot on selected features.
    # TODO : PCA plot (with options)


def _create_histogram_mat_artifact(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
):
    """
    Log a placeholder artifact for the deprecated histogram matrix.

    No plot is computed: the artifact stored under the key "hist" only carries a
    short HTML note pointing at the "scatter-2d" and "histograms" artifacts.

    :param context:      The function context used to log the artifact
    :param df:           Not used by this helper
    :param extra_data:   Not used by this helper (the placeholder is not added to it)
    :param label_column: Not used by this helper
    :param plots_dest:   Destination folder of the html file
    """
    # The note used to end with an opening "<b>" tag, which left the bold element
    # unclosed; it is now terminated with "</b>".
    deprecation_note = (
        b"<b> Deprecated, see the artifacts scatter-2d and histograms instead</b>"
    )
    context.log_artifact(
        item=Artifact(key="hist", body=deprecation_note),
        local_path=f"{plots_dest}/hist.html",
    )


def _create_features_histogram_artifacts(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
    problem_type: str,
):
    """
    Create and log one "histograms" figure with a drop-down to pick the feature.

    Every column except the label column contributes histogram traces: one trace
    per label value for a classification problem with a label column, otherwise a
    single trace per feature. Only the traces of the first feature start visible.

    :param context:      The function context used to log the artifact
    :param df:           The data to plot
    :param extra_data:   Dictionary that receives the logged artifact under "histograms"
    :param label_column: Name of the label column, or None when there is no label
    :param plots_dest:   Destination folder of the html file
    :param problem_type: "classification" enables the per-label split
    """
    per_label = label_column is not None and problem_type == "classification"
    if per_label:
        label_values = df[label_column].unique()
    feature_names = [name for name in df.columns if name != label_column]

    # Traces are keyed "<feature>@?@<label>" so the drop-down buttons can tell
    # which traces belong to the feature they select.
    traces = {}
    initial_feature = ""
    for position, feature in enumerate(feature_names):
        shown = position == 0
        if per_label:
            for label_value in label_values:
                traces[f"{feature}@?@{label_value}"] = go.Histogram(
                    histfunc="count",
                    x=df.loc[df[label_column] == label_value][feature],
                    name=str(label_value),
                    visible=shown,
                )
        else:
            traces[f"{feature}@?@{1}"] = go.Histogram(
                histfunc="count", x=df[feature], visible=shown
            )
        if shown:
            initial_feature = feature

    fig = go.Figure()
    fig.add_traces(list(traces.values()))

    # One button per feature: it toggles the visibility mask over all traces and
    # carries the feature's value range and the matching title.
    trace_keys = list(traces)
    buttons = []
    for feature in feature_names:
        column = df[feature]
        buttons.append(
            {
                "label": feature,
                "method": "update",
                "args": [
                    {
                        "visible": [
                            trace_key.split("@?@")[0] == feature
                            for trace_key in trace_keys
                        ],
                        "xaxis": {"range": [min(column), max(column)]},
                    },
                    {"title": f"<i><b>Histogram of {feature}</b></i>"},
                ],
            }
        )

    feature_menu = {
        "buttons": buttons,
        "direction": "down",
        "pad": {"r": 10, "t": 10},
        "showactive": True,
        "x": 0.25,
        "xanchor": "left",
        "y": 1.1,
        "yanchor": "top",
    }
    menu_caption = {
        "text": "Select Feature Name ",
        "showarrow": False,
        "x": 0,
        "y": 1.05,
        "yref": "paper",
        "xref": "paper",
        "align": "left",
        "xanchor": "left",
        "yanchor": "top",
        "font": {"color": "blue"},
    }
    fig.update_layout(updatemenus=[feature_menu], annotations=[menu_caption])

    fig.update_layout(
        width=600,
        height=400,
        autosize=False,
        margin={"t": 100, "b": 0, "l": 0, "r": 0},
        template="plotly_white",
    )
    fig.update_layout(title_text=f"<i><b>Histograms of {initial_feature}</b></i>")

    histograms_artifact = PlotlyArtifact(key="histograms", figure=fig)
    extra_data["histograms"] = context.log_artifact(
        histograms_artifact,
        local_path=f"{plots_dest}/histograms.html",
    )


def _create_features_2d_scatter_artifacts(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
    problem_type: str,
):
    """
    Create and log a "scatter-2d" figure with drop-downs selecting the x and y feature.

    The figure starts with the first feature on both axes. For a classification
    problem with a label column there is one trace per (sorted) label value; for a
    regression problem the markers are coloured by the label; without a label column
    a single plain trace is drawn.

    When the frame has no column besides the label column there is nothing to plot
    and the function returns without logging an artifact.

    :param context:      The function context used to log the artifact
    :param df:           The data to plot
    :param extra_data:   Dictionary that receives the logged artifact under "scatter-2d"
    :param label_column: Name of the label column, or None when there is no label
    :param plots_dest:   Destination folder of the html file
    :param problem_type: "classification" or "regression"
    """
    features = [
        column_name for column_name in df.columns if column_name != label_column
    ]
    if not features:
        # max() over an empty sequence and features[0] would both raise.
        return

    # str() keeps non-string column names (e.g. integers) from breaking len().
    max_feature_len = float(max(len(str(elem)) for elem in features))
    first_feature = features[0]
    by_label = label_column is not None and problem_type == "classification"

    fig = go.Figure()
    if by_label:
        # Label values are only needed (and only sorted) for the per-label traces.
        labels = sorted(df[label_column].unique())
        label_frames = [df.loc[df[label_column] == label] for label in labels]
        for label, label_frame in zip(labels, label_frames):
            fig.add_trace(
                go.Scatter(
                    x=label_frame[first_feature],
                    y=label_frame[first_feature],
                    mode="markers",
                    visible=True,
                    showlegend=True,
                    name=str(label),
                )
            )
    elif label_column is None:
        fig.add_trace(
            go.Scatter(
                x=df[first_feature],
                y=df[first_feature],
                mode="markers",
                visible=True,
            )
        )
    elif problem_type == "regression":
        fig.add_trace(
            go.Scatter(
                x=df[first_feature],
                y=df[first_feature],
                mode="markers",
                marker=dict(
                    color=df[label_column], colorscale="Viridis", showscale=True
                ),
                visible=True,
            )
        )

    x_buttons = []
    y_buttons = []

    for ncol in features:
        if by_label:
            # One data series per label trace, followed by the trace indices.
            x_args = [
                {"x": [label_frame[ncol] for label_frame in label_frames]},
                list(range(len(labels))),
            ]
            y_args = [
                {"y": [label_frame[ncol] for label_frame in label_frames]},
                list(range(len(labels))),
            ]
        else:
            x_args = [{"x": [df[ncol]]}]
            y_args = [{"y": [df[ncol]]}]
        x_buttons.append(dict(method="update", label=ncol, args=x_args))
        y_buttons.append(dict(method="update", label=ncol, args=y_args))

    # Pass buttons to the updatemenus argument; the y menu is shifted left by an
    # amount that grows with the longest feature name.
    fig.update_layout(
        updatemenus=[
            dict(buttons=x_buttons, direction="up", x=0.5, y=-0.1),
            dict(buttons=y_buttons, direction="down", x=-max_feature_len / 100, y=0.5),
        ]
    )

    fig.update_layout(
        width=600,
        height=400,
        autosize=False,
        margin=dict(t=100, b=0, l=0, r=0),
        template="plotly_white",
    )

    fig.update_layout(title_text="<i><b>Scatter-2d</b></i>")
    extra_data["scatter-2d"] = context.log_artifact(
        PlotlyArtifact(key="scatter-2d", figure=fig),
        local_path=f"{plots_dest}/scatter-2d.html",
    )


def _create_violin_artifact(
    context: MLClientCtx, df: pd.DataFrame, extra_data: dict, plots_dest: str
):
    """
    Create and log a "violin" figure holding one violin plot per column.

    The plots are laid out on a grid of 5 columns, filled row by row.

    :param context:    The function context used to log the artifact
    :param df:         The data to plot; every column gets its own subplot
    :param extra_data: Dictionary that receives the logged artifact under "violin"
    :param plots_dest: Destination folder of the html file
    """
    cols = 5
    # Ceiling division: the previous "n // cols + 1" added an empty row whenever the
    # number of columns was an exact multiple of 5. At least one row is kept so an
    # empty frame still gets a 1-row grid, as before.
    rows = max(1, -(-df.shape[1] // cols))
    fig = make_subplots(rows=rows, cols=cols)

    for plot_num, column_name in enumerate(df.columns):
        column_data = df[column_name]
        grid_row, grid_col = divmod(plot_num, cols)
        fig.add_trace(
            go.Violin(
                x=[column_name] * column_data.shape[0],
                y=column_data,
                name=column_name,
            ),
            row=grid_row + 1,
            col=grid_col + 1,
        )

    fig["layout"].update(
        height=(rows + 1) * 200,
        width=(cols + 1) * 200,
        title="<i><b>Violin Plots</b></i>",
    )

    fig.update_layout(showlegend=False)
    extra_data["violin"] = context.log_artifact(
        PlotlyArtifact(key="violin", figure=fig),
        local_path=f"{plots_dest}/violin.html",
    )


def _create_imbalance_artifact(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
    problem_type: str,
):
    """
    Create and log the label imbalance artifacts (plot + csv).

    For a classification problem the plot is a pie chart of the relative frequency
    of every label value and the csv holds those frequencies. For any other problem
    type the plot is a histogram of the label and the csv holds the bin edges and
    counts computed by ``np.histogram``. Nothing is logged when there is no label
    column.

    :param context:      The function context used to log the artifacts
    :param df:           The data holding the label column
    :param extra_data:   Dictionary that receives the logged artifacts under
                         "imbalance" and "imbalance-csv"
    :param label_column: Name of the label column, or None when there is no label
    :param plots_dest:   Destination folder of the html and csv files
    :param problem_type: "classification" selects the pie chart, anything else the
                         histogram
    """
    # Compare against None (like the sibling helpers) so that a falsy but valid
    # column name such as 0 is not silently skipped.
    if label_column is None:
        return

    if problem_type == "classification":
        values_column = "count"
        labels_count = df[label_column].value_counts().sort_index()
        # Build the frame explicitly instead of relying on the name pandas gives
        # the value_counts() result, which differs between pandas versions.
        df_labels_count = pd.DataFrame(
            {
                values_column: labels_count / labels_count.sum(),
                label_column: labels_count.index,
            },
            index=labels_count.index,
        )
        fig = px.pie(df_labels_count, names=label_column, values=values_column)
    else:
        fig = px.histogram(
            histfunc="count",
            x=df[label_column],
        )
        bin_counts, bin_edges = np.histogram(df[label_column])
        # There is one more edge than bins, so the counts are padded with a 0.
        df_labels_count = pd.DataFrame(
            {"min_val": bin_edges, "count": bin_counts.tolist() + [0]}
        )

    fig.update_layout(title_text="<i><b>Labels Imbalance</b></i>")
    extra_data["imbalance"] = context.log_artifact(
        PlotlyArtifact(key="imbalance", figure=fig),
        local_path=f"{plots_dest}/imbalance.html",
    )
    extra_data["imbalance-csv"] = context.log_artifact(
        TableArtifact("imbalance-weights-vec", df=df_labels_count),
        local_path=f"{plots_dest}/imbalance-weights-vec.csv",
    )


def _create_corr_artifact(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
):
    """
    Create and log the feature correlation artifacts (csv + annotated heatmap).

    The label column, when given, is dropped before the correlation of the numeric
    columns is computed.

    :param context:      The function context used to log the artifacts
    :param df:           The data to correlate
    :param extra_data:   Dictionary that receives the logged artifacts under
                         "correlation-matrix-csv" and "correlation"
    :param label_column: Name of the label column, or None when there is no label
    :param plots_dest:   Destination folder of the csv and html files
    """
    features_df = df if label_column is None else df.drop([label_column], axis=1)
    corr_table = features_df.corr(numeric_only=True)

    corr_csv = TableArtifact("correlation-matrix-csv", df=corr_table, visible=True)
    extra_data["correlation-matrix-csv"] = context.log_artifact(
        corr_csv,
        local_path=f"{plots_dest}/correlation-matrix.csv",
    )

    matrix = corr_table.values.tolist()
    # Cell annotations: every coefficient rendered with two decimals.
    cell_text = []
    for matrix_row in matrix:
        cell_text.append([f"{value:.2f}" for value in matrix_row])

    heatmap = ff.create_annotated_heatmap(
        matrix,
        x=list(corr_table.columns),
        y=list(corr_table.columns),
        annotation_text=cell_text,
        colorscale="agsunset",
    )
    heatmap["layout"]["yaxis"]["autorange"] = "reversed"  # l -> r
    heatmap.update_layout(title_text="<i><b>Correlation matrix</b></i>")
    heatmap["data"][0]["showscale"] = True

    corr_plot = PlotlyArtifact(key="correlation", figure=heatmap)
    extra_data["correlation"] = context.log_artifact(
        corr_plot,
        local_path=f"{plots_dest}/correlation.html",
    )
 + code_origin: '' + filename: describe.py entry_points: analyze: - has_varargs: false outputs: - type: None parameters: @@ -13,7 +26,6 @@ spec: doc: Key of dataset to database ("dataset" for default) default: dataset - name: table - type: Union[FeatureSet, DataItem] doc: MLRun input pointing to pandas dataframe (csv/parquet file path) or FeatureSet as param default: null @@ -45,6 +57,7 @@ spec: - name: dask_client doc: Dask client object default: null + name: analyze doc: 'The function will output the following artifacts per column within the data frame (based on data types) @@ -70,21 +83,8 @@ spec: imbalance-weights-vec csv' has_kwargs: false - name: analyze - lineno: 46 - image: mlrun/mlrun + has_varargs: false + lineno: 45 command: '' - build: - functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Generated by nuclio.export.NuclioExporter

import warnings
from typing import Union

import mlrun
import numpy as np

warnings.simplefilter(action="ignore", category=FutureWarning)

import mlrun.feature_store as fstore
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff
import plotly.graph_objects as go
from mlrun.artifacts import (
    Artifact,
    DatasetArtifact,
    PlotlyArtifact,
    TableArtifact,
    update_dataset_meta,
)
from mlrun.datastore import DataItem
from mlrun.execution import MLClientCtx
from mlrun.feature_store import FeatureSet
from plotly.subplots import make_subplots

pd.set_option("display.float_format", lambda x: "%.2f" % x)
MAX_SIZE_OF_DF = 500000


def analyze(
    context: MLClientCtx,
    name: str = "dataset",
    table: Union[FeatureSet, DataItem] = None,
    label_column: str = None,
    plots_dest: str = "plots",
    random_state: int = 1,
    problem_type: str = "classification",
    dask_key: str = "dask_key",
    dask_function: str = None,
    dask_client=None,
) -> None:
    """
    The function will output the following artifacts per
    column within the data frame (based on data types)
    If the data has more than 500,000 values we
    sample it randomly down to 500,000 values:

    describe csv
    histograms
    scatter-2d
    violin chart
    correlation-matrix chart
    correlation-matrix csv
    imbalance pie chart
    imbalance-weights-vec csv

    :param context:                 The function context
    :param name:                    Key of dataset to database ("dataset" for default)
    :param table:                   MLRun input pointing to pandas dataframe (csv/parquet file path) or FeatureSet
                                    as param
    :param label_column:            Ground truth column label
    :param plots_dest:              Destination folder of summary plots (relative to artifact_path)
                                    ("plots" for default)
    :param random_state:            Random state used when the table is sampled down to 500,000 values
    :param problem_type:            The type of the ML problem the data facing - regression, classification or None
                                    (classification for default)
    :param dask_key:                Key of dataframe in dask client "datasets" attribute
    :param dask_function:           Dask function url (db://..)
    :param dask_client:             Dask client object

    :raises Exception: When a dask client is used, `dask_key` is not one of its datasets and no `table` was given.
    """
    # data_item / featureset: how the frame is read; creat: log a new dataset
    # artifact at the end; update: update the metadata of an existing dataset.
    data_item, featureset, creat, update = False, False, False, False
    get_from_table = True
    if dask_function or dask_client:
        data_item, creat = True, True
        # A dask function url takes precedence over an explicit client object.
        if dask_function:
            client = mlrun.import_function(dask_function).client
        else:
            client = dask_client

        if dask_key in client.datasets:
            df = client.get_dataset(dask_key)
            data_item, creat, get_from_table = True, True, False
        elif table:
            get_from_table = True
        else:
            context.logger.info(
                f"only these datasets are available {client.datasets} in client {client}"
            )
            raise Exception("dataset not found on dask cluster")

    if get_from_table:
        if isinstance(table, DataItem):
            if table.meta is None:
                data_item, creat, update = True, True, False
            elif table.meta.kind == "dataset":
                data_item, creat, update = True, False, True
            elif table.meta.kind == "FeatureVector":
                data_item, creat, update = True, False, False
            elif table.meta.kind == "FeatureSet":
                featureset, creat, update = True, False, False

        if data_item:
            df = table.as_df()
        elif featureset:
            # The project and feature-set names are taken from fixed positions of
            # the item's path.
            path_parts = table._path.split("/")
            project_name, set_name = path_parts[2], path_parts[4]
            feature_set = fstore.get_feature_set(
                f"store://feature-sets/{project_name}/{set_name}"
            )
            df = feature_set.to_dataframe()
        else:
            context.logger.error("Wrong table type.")
            return

    # df.size counts values (rows x columns), so the number of sampled rows is
    # the value budget divided by the number of columns.
    if df.size > MAX_SIZE_OF_DF:
        df = df.sample(n=int(MAX_SIZE_OF_DF / df.shape[1]), random_state=random_state)
    extra_data = {}

    if label_column not in df.columns:
        label_column = None

    extra_data["describe csv"] = context.log_artifact(
        TableArtifact("describe-csv", df=df.describe()),
        local_path=f"{plots_dest}/describe.csv",
    )

    # Every plot is best-effort: a failure is logged as a warning and the
    # remaining artifacts are still produced.
    try:
        _create_histogram_mat_artifact(
            context, df, extra_data, label_column, plots_dest
        )
    except Exception as e:
        context.logger.warn(f"Failed to create histogram matrix artifact due to: {e}")
    try:
        _create_features_histogram_artifacts(
            context, df, extra_data, label_column, plots_dest, problem_type
        )
    except Exception as e:
        context.logger.warn(f"Failed to create pairplot histograms due to: {e}")
    try:
        _create_features_2d_scatter_artifacts(
            context, df, extra_data, label_column, plots_dest, problem_type
        )
    except Exception as e:
        context.logger.warn(f"Failed to create pairplot 2d_scatter due to: {e}")
    try:
        _create_violin_artifact(context, df, extra_data, plots_dest)
    except Exception as e:
        context.logger.warn(f"Failed to create violin distribution plots due to: {e}")
    try:
        _create_imbalance_artifact(
            context, df, extra_data, label_column, plots_dest, problem_type
        )
    except Exception as e:
        context.logger.warn(f"Failed to create class imbalance plot due to: {e}")
    try:
        _create_corr_artifact(context, df, extra_data, label_column, plots_dest)
    except Exception as e:
        context.logger.warn(f"Failed to create features correlation plot due to: {e}")

    if not data_item:
        return

    # `table` is None when the frame came from the dask client only; reading
    # `table.artifact_url` unconditionally raised AttributeError in that case.
    artifact = table.artifact_url if table is not None else None
    if creat:  # dataset not stored
        artifact = DatasetArtifact(
            key="dataset", stats=True, df=df, extra_data=extra_data
        )
        artifact = context.log_artifact(artifact, db_key=name)
        context.logger.info(f"The data set is logged to the project under {name} name")

    if update:
        update_dataset_meta(artifact, extra_data=extra_data)
        context.logger.info(f"The data set named {name} is updated")

    # TODO : 3-D plot on selected features.
    # TODO : Reintegration plot on selected features.
    # TODO : PCA plot (with options)


def _create_histogram_mat_artifact(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
):
    """
    Log a placeholder artifact for the deprecated histogram matrix.

    No plot is computed: the artifact stored under the key "hist" only carries a
    short HTML note pointing at the "scatter-2d" and "histograms" artifacts.

    :param context:      The function context used to log the artifact
    :param df:           Not used by this helper
    :param extra_data:   Not used by this helper (the placeholder is not added to it)
    :param label_column: Not used by this helper
    :param plots_dest:   Destination folder of the html file
    """
    # The note used to end with an opening "<b>" tag, which left the bold element
    # unclosed; it is now terminated with "</b>".
    deprecation_note = (
        b"<b> Deprecated, see the artifacts scatter-2d and histograms instead</b>"
    )
    context.log_artifact(
        item=Artifact(key="hist", body=deprecation_note),
        local_path=f"{plots_dest}/hist.html",
    )


def _create_features_histogram_artifacts(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
    problem_type: str,
):
    """
    Create and log one "histograms" figure with a drop-down to pick the feature.

    Every column except the label column contributes histogram traces: one trace
    per label value for a classification problem with a label column, otherwise a
    single trace per feature. Only the traces of the first feature start visible.

    :param context:      The function context used to log the artifact
    :param df:           The data to plot
    :param extra_data:   Dictionary that receives the logged artifact under "histograms"
    :param label_column: Name of the label column, or None when there is no label
    :param plots_dest:   Destination folder of the html file
    :param problem_type: "classification" enables the per-label split
    """
    per_label = label_column is not None and problem_type == "classification"
    if per_label:
        label_values = df[label_column].unique()
    feature_names = [name for name in df.columns if name != label_column]

    # Traces are keyed "<feature>@?@<label>" so the drop-down buttons can tell
    # which traces belong to the feature they select.
    traces = {}
    initial_feature = ""
    for position, feature in enumerate(feature_names):
        shown = position == 0
        if per_label:
            for label_value in label_values:
                traces[f"{feature}@?@{label_value}"] = go.Histogram(
                    histfunc="count",
                    x=df.loc[df[label_column] == label_value][feature],
                    name=str(label_value),
                    visible=shown,
                )
        else:
            traces[f"{feature}@?@{1}"] = go.Histogram(
                histfunc="count", x=df[feature], visible=shown
            )
        if shown:
            initial_feature = feature

    fig = go.Figure()
    fig.add_traces(list(traces.values()))

    # One button per feature: it toggles the visibility mask over all traces and
    # carries the feature's value range and the matching title.
    trace_keys = list(traces)
    buttons = []
    for feature in feature_names:
        column = df[feature]
        buttons.append(
            {
                "label": feature,
                "method": "update",
                "args": [
                    {
                        "visible": [
                            trace_key.split("@?@")[0] == feature
                            for trace_key in trace_keys
                        ],
                        "xaxis": {"range": [min(column), max(column)]},
                    },
                    {"title": f"<i><b>Histogram of {feature}</b></i>"},
                ],
            }
        )

    feature_menu = {
        "buttons": buttons,
        "direction": "down",
        "pad": {"r": 10, "t": 10},
        "showactive": True,
        "x": 0.25,
        "xanchor": "left",
        "y": 1.1,
        "yanchor": "top",
    }
    menu_caption = {
        "text": "Select Feature Name ",
        "showarrow": False,
        "x": 0,
        "y": 1.05,
        "yref": "paper",
        "xref": "paper",
        "align": "left",
        "xanchor": "left",
        "yanchor": "top",
        "font": {"color": "blue"},
    }
    fig.update_layout(updatemenus=[feature_menu], annotations=[menu_caption])

    fig.update_layout(
        width=600,
        height=400,
        autosize=False,
        margin={"t": 100, "b": 0, "l": 0, "r": 0},
        template="plotly_white",
    )
    fig.update_layout(title_text=f"<i><b>Histograms of {initial_feature}</b></i>")

    histograms_artifact = PlotlyArtifact(key="histograms", figure=fig)
    extra_data["histograms"] = context.log_artifact(
        histograms_artifact,
        local_path=f"{plots_dest}/histograms.html",
    )


def _create_features_2d_scatter_artifacts(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
    problem_type: str,
):
    """
    Create and log a "scatter-2d" figure with drop-downs selecting the x and y feature.

    The figure starts with the first feature on both axes. For a classification
    problem with a label column there is one trace per (sorted) label value; for a
    regression problem the markers are coloured by the label; without a label column
    a single plain trace is drawn.

    When the frame has no column besides the label column there is nothing to plot
    and the function returns without logging an artifact.

    :param context:      The function context used to log the artifact
    :param df:           The data to plot
    :param extra_data:   Dictionary that receives the logged artifact under "scatter-2d"
    :param label_column: Name of the label column, or None when there is no label
    :param plots_dest:   Destination folder of the html file
    :param problem_type: "classification" or "regression"
    """
    features = [
        column_name for column_name in df.columns if column_name != label_column
    ]
    if not features:
        # max() over an empty sequence and features[0] would both raise.
        return

    # str() keeps non-string column names (e.g. integers) from breaking len().
    max_feature_len = float(max(len(str(elem)) for elem in features))
    first_feature = features[0]
    by_label = label_column is not None and problem_type == "classification"

    fig = go.Figure()
    if by_label:
        # Label values are only needed (and only sorted) for the per-label traces.
        labels = sorted(df[label_column].unique())
        label_frames = [df.loc[df[label_column] == label] for label in labels]
        for label, label_frame in zip(labels, label_frames):
            fig.add_trace(
                go.Scatter(
                    x=label_frame[first_feature],
                    y=label_frame[first_feature],
                    mode="markers",
                    visible=True,
                    showlegend=True,
                    name=str(label),
                )
            )
    elif label_column is None:
        fig.add_trace(
            go.Scatter(
                x=df[first_feature],
                y=df[first_feature],
                mode="markers",
                visible=True,
            )
        )
    elif problem_type == "regression":
        fig.add_trace(
            go.Scatter(
                x=df[first_feature],
                y=df[first_feature],
                mode="markers",
                marker=dict(
                    color=df[label_column], colorscale="Viridis", showscale=True
                ),
                visible=True,
            )
        )

    x_buttons = []
    y_buttons = []

    for ncol in features:
        if by_label:
            # One data series per label trace, followed by the trace indices.
            x_args = [
                {"x": [label_frame[ncol] for label_frame in label_frames]},
                list(range(len(labels))),
            ]
            y_args = [
                {"y": [label_frame[ncol] for label_frame in label_frames]},
                list(range(len(labels))),
            ]
        else:
            x_args = [{"x": [df[ncol]]}]
            y_args = [{"y": [df[ncol]]}]
        x_buttons.append(dict(method="update", label=ncol, args=x_args))
        y_buttons.append(dict(method="update", label=ncol, args=y_args))

    # Pass buttons to the updatemenus argument; the y menu is shifted left by an
    # amount that grows with the longest feature name.
    fig.update_layout(
        updatemenus=[
            dict(buttons=x_buttons, direction="up", x=0.5, y=-0.1),
            dict(buttons=y_buttons, direction="down", x=-max_feature_len / 100, y=0.5),
        ]
    )

    fig.update_layout(
        width=600,
        height=400,
        autosize=False,
        margin=dict(t=100, b=0, l=0, r=0),
        template="plotly_white",
    )

    fig.update_layout(title_text="<i><b>Scatter-2d</b></i>")
    extra_data["scatter-2d"] = context.log_artifact(
        PlotlyArtifact(key="scatter-2d", figure=fig),
        local_path=f"{plots_dest}/scatter-2d.html",
    )


def _create_violin_artifact(
    context: MLClientCtx, df: pd.DataFrame, extra_data: dict, plots_dest: str
):
    """
    Create and log a "violin" figure holding one violin plot per column.

    The plots are laid out on a grid of 5 columns, filled row by row.

    :param context:    The function context used to log the artifact
    :param df:         The data to plot; every column gets its own subplot
    :param extra_data: Dictionary that receives the logged artifact under "violin"
    :param plots_dest: Destination folder of the html file
    """
    cols = 5
    # Ceiling division: the previous "n // cols + 1" added an empty row whenever the
    # number of columns was an exact multiple of 5. At least one row is kept so an
    # empty frame still gets a 1-row grid, as before.
    rows = max(1, -(-df.shape[1] // cols))
    fig = make_subplots(rows=rows, cols=cols)

    for plot_num, column_name in enumerate(df.columns):
        column_data = df[column_name]
        grid_row, grid_col = divmod(plot_num, cols)
        fig.add_trace(
            go.Violin(
                x=[column_name] * column_data.shape[0],
                y=column_data,
                name=column_name,
            ),
            row=grid_row + 1,
            col=grid_col + 1,
        )

    fig["layout"].update(
        height=(rows + 1) * 200,
        width=(cols + 1) * 200,
        title="<i><b>Violin Plots</b></i>",
    )

    fig.update_layout(showlegend=False)
    extra_data["violin"] = context.log_artifact(
        PlotlyArtifact(key="violin", figure=fig),
        local_path=f"{plots_dest}/violin.html",
    )


def _create_imbalance_artifact(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
    problem_type: str,
):
    """
    Create and log the label imbalance artifacts (plot + csv).

    For a classification problem the plot is a pie chart of the relative frequency
    of every label value and the csv holds those frequencies. For any other problem
    type the plot is a histogram of the label and the csv holds the bin edges and
    counts computed by ``np.histogram``. Nothing is logged when there is no label
    column.

    :param context:      The function context used to log the artifacts
    :param df:           The data holding the label column
    :param extra_data:   Dictionary that receives the logged artifacts under
                         "imbalance" and "imbalance-csv"
    :param label_column: Name of the label column, or None when there is no label
    :param plots_dest:   Destination folder of the html and csv files
    :param problem_type: "classification" selects the pie chart, anything else the
                         histogram
    """
    # Compare against None (like the sibling helpers) so that a falsy but valid
    # column name such as 0 is not silently skipped.
    if label_column is None:
        return

    if problem_type == "classification":
        values_column = "count"
        labels_count = df[label_column].value_counts().sort_index()
        # Build the frame explicitly instead of relying on the name pandas gives
        # the value_counts() result, which differs between pandas versions.
        df_labels_count = pd.DataFrame(
            {
                values_column: labels_count / labels_count.sum(),
                label_column: labels_count.index,
            },
            index=labels_count.index,
        )
        fig = px.pie(df_labels_count, names=label_column, values=values_column)
    else:
        fig = px.histogram(
            histfunc="count",
            x=df[label_column],
        )
        bin_counts, bin_edges = np.histogram(df[label_column])
        # There is one more edge than bins, so the counts are padded with a 0.
        df_labels_count = pd.DataFrame(
            {"min_val": bin_edges, "count": bin_counts.tolist() + [0]}
        )

    fig.update_layout(title_text="<i><b>Labels Imbalance</b></i>")
    extra_data["imbalance"] = context.log_artifact(
        PlotlyArtifact(key="imbalance", figure=fig),
        local_path=f"{plots_dest}/imbalance.html",
    )
    extra_data["imbalance-csv"] = context.log_artifact(
        TableArtifact("imbalance-weights-vec", df=df_labels_count),
        local_path=f"{plots_dest}/imbalance-weights-vec.csv",
    )


def _create_corr_artifact(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
):
    """
    Create and log the feature correlation artifacts (csv + annotated heatmap).

    The label column, when given, is dropped before the correlation of the numeric
    columns is computed.

    :param context:      The function context used to log the artifacts
    :param df:           The data to correlate
    :param extra_data:   Dictionary that receives the logged artifacts under
                         "correlation-matrix-csv" and "correlation"
    :param label_column: Name of the label column, or None when there is no label
    :param plots_dest:   Destination folder of the csv and html files
    """
    features_df = df if label_column is None else df.drop([label_column], axis=1)
    corr_table = features_df.corr(numeric_only=True)

    corr_csv = TableArtifact("correlation-matrix-csv", df=corr_table, visible=True)
    extra_data["correlation-matrix-csv"] = context.log_artifact(
        corr_csv,
        local_path=f"{plots_dest}/correlation-matrix.csv",
    )

    matrix = corr_table.values.tolist()
    # Cell annotations: every coefficient rendered with two decimals.
    cell_text = []
    for matrix_row in matrix:
        cell_text.append([f"{value:.2f}" for value in matrix_row])

    heatmap = ff.create_annotated_heatmap(
        matrix,
        x=list(corr_table.columns),
        y=list(corr_table.columns),
        annotation_text=cell_text,
        colorscale="agsunset",
    )
    heatmap["layout"]["yaxis"]["autorange"] = "reversed"  # l -> r
    heatmap.update_layout(title_text="<i><b>Correlation matrix</b></i>")
    heatmap["data"][0]["showscale"] = True

    corr_plot = PlotlyArtifact(key="correlation", figure=heatmap)
    extra_data["correlation"] = context.log_artifact(
        corr_plot,
        local_path=f"{plots_dest}/correlation.html",
    )
 - code_origin: '' - origin_filename: '' description: describe and visualizes dataset stats - disable_auto_mount: false default_handler: analyze -verbose: false -metadata: - tag: '' - name: describe - categories: - - data-analysis -kind: job diff --git a/functions/src/describe/test_describe.py b/functions/src/describe/test_describe.py index 9ffe39abb..4ea56c979 100644 --- a/functions/src/describe/test_describe.py +++ b/functions/src/describe/test_describe.py @@ -15,12 +15,10 @@ import os import shutil from pathlib import Path -from typing import Set -import mlrun import pandas as pd import pytest -from mlrun import code_to_function, import_function, new_function +from mlrun import import_function from mlrun.execution import MLClientCtx from sklearn.datasets import make_classification, make_regression @@ -29,7 +27,7 @@ ARTIFACTS_PATH = os.path.abspath("./artifacts") -def _validate_paths(paths: Set): +def _validate_paths(paths: set): """ Check if all the expected plot are saved """ diff --git a/functions/src/describe_dask/describe_dask.py b/functions/src/describe_dask/describe_dask.py index 3dc382820..a34535a3c 100644 --- a/functions/src/describe_dask/describe_dask.py +++ b/functions/src/describe_dask/describe_dask.py @@ -12,19 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import mlrun import warnings + +import mlrun + warnings.simplefilter(action="ignore", category=FutureWarning) -import pandas as pd import matplotlib.pyplot as plt +import numpy as np +import pandas as pd import seaborn as sns from mlrun.artifacts import PlotArtifact, TableArtifact from mlrun.mlutils.plots import gcf_clear -import numpy as np - pd.set_option("display.float_format", lambda x: "%.2f" % x) + def summarize( context, dask_key: str = "dask_key", @@ -35,7 +37,7 @@ def summarize( dask_client=None, ) -> None: """Summarize a table - + Connects to dask client through the function context, or through an optional user-supplied scheduler. @@ -51,15 +53,17 @@ def summarize( elif dask_client: client = dask_client else: - raise ValueError('dask client was not provided') - + raise ValueError("dask client was not provided") + if dask_key in client.datasets: table = client.get_dataset(dask_key) elif dataset: - #table = dataset.as_df(df_module=dd) + # table = dataset.as_df(df_module=dd) table = dataset.as_df() else: - context.logger.info(f"only these datasets are available {client.datasets} in client {client}") + context.logger.info( + f"only these datasets are available {client.datasets} in client {client}" + ) raise Exception("dataset not found on dask cluster") df = table header = df.columns.values diff --git a/functions/src/describe_dask/function.yaml b/functions/src/describe_dask/function.yaml index baf3ced1d..eaf9b2177 100644 --- a/functions/src/describe_dask/function.yaml +++ b/functions/src/describe_dask/function.yaml @@ -1,17 +1,22 @@ +metadata: + tag: '' + name: describe-dask + categories: + - data-analysis verbose: false +kind: job spec: - disable_auto_mount: false image: mlrun/ml-models - command: '' - default_handler: summarize + disable_auto_mount: false + build: + origin_filename: '' + functionSourceCode: 
IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKaW1wb3J0IHdhcm5pbmdzCgppbXBvcnQgbWxydW4KCndhcm5pbmdzLnNpbXBsZWZpbHRlcihhY3Rpb249Imlnbm9yZSIsIGNhdGVnb3J5PUZ1dHVyZVdhcm5pbmcpCmltcG9ydCBtYXRwbG90bGliLnB5cGxvdCBhcyBwbHQKaW1wb3J0IG51bXB5IGFzIG5wCmltcG9ydCBwYW5kYXMgYXMgcGQKaW1wb3J0IHNlYWJvcm4gYXMgc25zCmZyb20gbWxydW4uYXJ0aWZhY3RzIGltcG9ydCBQbG90QXJ0aWZhY3QsIFRhYmxlQXJ0aWZhY3QKZnJvbSBtbHJ1bi5tbHV0aWxzLnBsb3RzIGltcG9ydCBnY2ZfY2xlYXIKCnBkLnNldF9vcHRpb24oImRpc3BsYXkuZmxvYXRfZm9ybWF0IiwgbGFtYmRhIHg6ICIlLjJmIiAlIHgpCgoKZGVmIHN1bW1hcml6ZSgKICAgIGNvbnRleHQsCiAgICBkYXNrX2tleTogc3RyID0gImRhc2tfa2V5IiwKICAgIGRhdGFzZXQ6IG1scnVuLkRhdGFJdGVtID0gTm9uZSwKICAgIGxhYmVsX2NvbHVtbjogc3RyID0gImxhYmVsIiwKICAgIHBsb3RzX2Rlc3Q6IHN0ciA9ICJwbG90cyIsCiAgICBkYXNrX2Z1bmN0aW9uOiBzdHIgPSBOb25lLAogICAgZGFza19jbGllbnQ9Tm9uZSwKKSAtPiBOb25lOgogICAgIiIiU3VtbWFyaXplIGEgdGFibGUKCiAgICBDb25uZWN0cyB0byBkYXNrIGNsaWVudCB0aHJvdWdoIHRoZSBmdW5jdGlvbiBjb250ZXh0LCBvciB0aHJvdWdoIGFuIG9wdGlvbmFsCiAgICB1c2VyLXN1cHBsaWVkIHNjaGVkdWxlci4KCiAgICA6cGFyYW0gY29udGV4dDogICAgICAgICB0aGUgZnVuY3Rpb24gY29udGV4dAogICAgOnBhcmFtIGRhc2tfa2V5OiAgICAgICAga2V5IG9mIGRhdGFmcmFtZSBpbiBkYXNrIGNsaWVudCAiZGF0YXNldHMiIGF0dHJpYnV0ZQogICAgOnBhcmFtIGxhYmVsX2NvbHVtbjogICAgZ3JvdW5kIHRydXRoIGNvbHVtbiBsYWJlbAogICAgOnBhcmFtIHBsb3RzX2Rlc3Q6ICAgICAgZGVzdGlu
YXRpb24gZm9sZGVyIG9mIHN1bW1hcnkgcGxvdHMgKHJlbGF0aXZlIHRvIGFydGlmYWN0X3BhdGgpCiAgICA6cGFyYW0gZGFza19mdW5jdGlvbjogICBkYXNrIGZ1bmN0aW9uIHVybCAoZGI6Ly8uLikKICAgIDpwYXJhbSBkYXNrX2NsaWVudDogICAgIGRhc2sgY2xpZW50IG9iamVjdAogICAgIiIiCiAgICBpZiBkYXNrX2Z1bmN0aW9uOgogICAgICAgIGNsaWVudCA9IG1scnVuLmltcG9ydF9mdW5jdGlvbihkYXNrX2Z1bmN0aW9uKS5jbGllbnQKICAgIGVsaWYgZGFza19jbGllbnQ6CiAgICAgICAgY2xpZW50ID0gZGFza19jbGllbnQKICAgIGVsc2U6CiAgICAgICAgcmFpc2UgVmFsdWVFcnJvcigiZGFzayBjbGllbnQgd2FzIG5vdCBwcm92aWRlZCIpCgogICAgaWYgZGFza19rZXkgaW4gY2xpZW50LmRhdGFzZXRzOgogICAgICAgIHRhYmxlID0gY2xpZW50LmdldF9kYXRhc2V0KGRhc2tfa2V5KQogICAgZWxpZiBkYXRhc2V0OgogICAgICAgICMgdGFibGUgPSBkYXRhc2V0LmFzX2RmKGRmX21vZHVsZT1kZCkKICAgICAgICB0YWJsZSA9IGRhdGFzZXQuYXNfZGYoKQogICAgZWxzZToKICAgICAgICBjb250ZXh0LmxvZ2dlci5pbmZvKAogICAgICAgICAgICBmIm9ubHkgdGhlc2UgZGF0YXNldHMgYXJlIGF2YWlsYWJsZSB7Y2xpZW50LmRhdGFzZXRzfSBpbiBjbGllbnQge2NsaWVudH0iCiAgICAgICAgKQogICAgICAgIHJhaXNlIEV4Y2VwdGlvbigiZGF0YXNldCBub3QgZm91bmQgb24gZGFzayBjbHVzdGVyIikKICAgIGRmID0gdGFibGUKICAgIGhlYWRlciA9IGRmLmNvbHVtbnMudmFsdWVzCiAgICBleHRyYV9kYXRhID0ge30KCiAgICB0cnk6CiAgICAgICAgZ2NmX2NsZWFyKHBsdCkKICAgICAgICBzbnNwbHQgPSBzbnMucGFpcnBsb3QoZGYsIGh1ZT1sYWJlbF9jb2x1bW4pICAjICwgZGlhZ19rd3M9eyJidyI6IDEuNX0pCiAgICAgICAgZXh0cmFfZGF0YVsiaGlzdG9ncmFtcyJdID0gY29udGV4dC5sb2dfYXJ0aWZhY3QoCiAgICAgICAgICAgIFBsb3RBcnRpZmFjdCgiaGlzdG9ncmFtcyIsIGJvZHk9cGx0LmdjZigpKSwKICAgICAgICAgICAgbG9jYWxfcGF0aD1mIntwbG90c19kZXN0fS9oaXN0Lmh0bWwiLAogICAgICAgICAgICBkYl9rZXk9RmFsc2UsCiAgICAgICAgKQogICAgZXhjZXB0IEV4Y2VwdGlvbiBhcyBlOgogICAgICAgIGNvbnRleHQubG9nZ2VyLmVycm9yKGYiRmFpbGVkIHRvIGNyZWF0ZSBwYWlycGxvdCBoaXN0b2dyYW1zIGR1ZSB0bzoge2V9IikKCiAgICB0cnk6CiAgICAgICAgZ2NmX2NsZWFyKHBsdCkKICAgICAgICBwbG90X2NvbHMgPSAzCiAgICAgICAgcGxvdF9yb3dzID0gaW50KChsZW4oaGVhZGVyKSAtIDEpIC8gcGxvdF9jb2xzKSArIDEKICAgICAgICBmaWcsIGF4ID0gcGx0LnN1YnBsb3RzKHBsb3Rfcm93cywgcGxvdF9jb2xzLCBmaWdzaXplPSgxNSwgNCkpCiAgICAgICAgZmlnLnRpZ2h0X2xheW91dChwYWQ9Mi4wKQogICAgICAgIGZvciBpIGluIHJhbmdlKHBsb3Rfcm93cyAqIHBsb3RfY29scyk6CiAgICAg
ICAgICAgIGlmIGkgPCBsZW4oaGVhZGVyKToKICAgICAgICAgICAgICAgIHNucy52aW9saW5wbG90KAogICAgICAgICAgICAgICAgICAgIHg9ZGZbaGVhZGVyW2ldXSwKICAgICAgICAgICAgICAgICAgICBheD1heFtpbnQoaSAvIHBsb3RfY29scyldW2kgJSBwbG90X2NvbHNdLAogICAgICAgICAgICAgICAgICAgIG9yaWVudD0iaCIsCiAgICAgICAgICAgICAgICAgICAgd2lkdGg9MC43LAogICAgICAgICAgICAgICAgICAgIGlubmVyPSJxdWFydGlsZSIsCiAgICAgICAgICAgICAgICApCiAgICAgICAgICAgIGVsc2U6CiAgICAgICAgICAgICAgICBmaWcuZGVsYXhlcyhheFtpbnQoaSAvIHBsb3RfY29scyldW2kgJSBwbG90X2NvbHNdKQogICAgICAgICAgICBpICs9IDEKICAgICAgICBleHRyYV9kYXRhWyJ2aW9saW4iXSA9IGNvbnRleHQubG9nX2FydGlmYWN0KAogICAgICAgICAgICBQbG90QXJ0aWZhY3QoInZpb2xpbiIsIGJvZHk9cGx0LmdjZigpLCB0aXRsZT0iVmlvbGluIFBsb3QiKSwKICAgICAgICAgICAgbG9jYWxfcGF0aD1mIntwbG90c19kZXN0fS92aW9saW4uaHRtbCIsCiAgICAgICAgICAgIGRiX2tleT1GYWxzZSwKICAgICAgICApCiAgICBleGNlcHQgRXhjZXB0aW9uIGFzIGU6CiAgICAgICAgY29udGV4dC5sb2dnZXIud2FybihmIkZhaWxlZCB0byBjcmVhdGUgdmlvbGluIGRpc3RyaWJ1dGlvbiBwbG90cyBkdWUgdG86IHtlfSIpCgogICAgaWYgbGFiZWxfY29sdW1uOgogICAgICAgIGxhYmVscyA9IGRmLnBvcChsYWJlbF9jb2x1bW4pCiAgICAgICAgaW1idGFibGUgPSBsYWJlbHMudmFsdWVfY291bnRzKG5vcm1hbGl6ZT1UcnVlKS5zb3J0X2luZGV4KCkKICAgICAgICB0cnk6CiAgICAgICAgICAgIGdjZl9jbGVhcihwbHQpCiAgICAgICAgICAgIGJhbGFuY2ViYXIgPSBpbWJ0YWJsZS5wbG90KGtpbmQ9ImJhciIsIHRpdGxlPSJjbGFzcyBpbWJhbGFuY2UgLSBsYWJlbHMiKQogICAgICAgICAgICBiYWxhbmNlYmFyLnNldF94bGFiZWwoImNsYXNzIikKICAgICAgICAgICAgYmFsYW5jZWJhci5zZXRfeWxhYmVsKCJwcm9wb3J0aW9uIG9mIHRvdGFsIikKICAgICAgICAgICAgZXh0cmFfZGF0YVsiaW1iYWxhbmNlIl0gPSBjb250ZXh0LmxvZ19hcnRpZmFjdCgKICAgICAgICAgICAgICAgIFBsb3RBcnRpZmFjdCgiaW1iYWxhbmNlIiwgYm9keT1wbHQuZ2NmKCkpLAogICAgICAgICAgICAgICAgbG9jYWxfcGF0aD1mIntwbG90c19kZXN0fS9pbWJhbGFuY2UuaHRtbCIsCiAgICAgICAgICAgICkKICAgICAgICBleGNlcHQgRXhjZXB0aW9uIGFzIGU6CiAgICAgICAgICAgIGNvbnRleHQubG9nZ2VyLndhcm4oZiJGYWlsZWQgdG8gY3JlYXRlIGNsYXNzIGltYmFsYW5jZSBwbG90IGR1ZSB0bzoge2V9IikKICAgICAgICBjb250ZXh0LmxvZ19hcnRpZmFjdCgKICAgICAgICAgICAgVGFibGVBcnRpZmFjdCgKICAgICAgICAgICAgICAgICJpbWJhbGFuY2Utd2VpZ2h0cy12ZWMiLCBkZj1wZC5EYXRhRnJhbWUoeyJ3ZWlnaHRzIjogaW1i
dGFibGV9KQogICAgICAgICAgICApLAogICAgICAgICAgICBsb2NhbF9wYXRoPWYie3Bsb3RzX2Rlc3R9L2ltYmFsYW5jZS13ZWlnaHRzLXZlYy5jc3YiLAogICAgICAgICAgICBkYl9rZXk9RmFsc2UsCiAgICAgICAgKQoKICAgIHRibGNvcnIgPSBkZi5jb3JyKCkKICAgIG1hc2sgPSBucC56ZXJvc19saWtlKHRibGNvcnIsIGR0eXBlPW5wLmJvb2wpCiAgICBtYXNrW25wLnRyaXVfaW5kaWNlc19mcm9tKG1hc2spXSA9IFRydWUKCiAgICBkZmNvcnIgPSBwZC5EYXRhRnJhbWUoZGF0YT10Ymxjb3JyLCBjb2x1bW5zPWhlYWRlciwgaW5kZXg9aGVhZGVyKQogICAgZGZjb3JyID0gZGZjb3JyW25wLmFyYW5nZShkZmNvcnIuc2hhcGVbMF0pWzosIE5vbmVdID4gbnAuYXJhbmdlKGRmY29yci5zaGFwZVsxXSldCiAgICBjb250ZXh0LmxvZ19hcnRpZmFjdCgKICAgICAgICBUYWJsZUFydGlmYWN0KCJjb3JyZWxhdGlvbi1tYXRyaXgiLCBkZj10Ymxjb3JyLCB2aXNpYmxlPVRydWUpLAogICAgICAgIGxvY2FsX3BhdGg9ZiJ7cGxvdHNfZGVzdH0vY29ycmVsYXRpb24tbWF0cml4LmNzdiIsCiAgICAgICAgZGJfa2V5PUZhbHNlLAogICAgKQoKICAgIHRyeToKICAgICAgICBnY2ZfY2xlYXIocGx0KQogICAgICAgIGF4ID0gcGx0LmF4ZXMoKQogICAgICAgIHNucy5oZWF0bWFwKHRibGNvcnIsIGF4PWF4LCBtYXNrPW1hc2ssIGFubm90PUZhbHNlLCBjbWFwPXBsdC5jbS5SZWRzKQogICAgICAgIGF4LnNldF90aXRsZSgiZmVhdHVyZXMgY29ycmVsYXRpb24iKQogICAgICAgIGV4dHJhX2RhdGFbImNvcnJlbGF0aW9uIl0gPSBjb250ZXh0LmxvZ19hcnRpZmFjdCgKICAgICAgICAgICAgUGxvdEFydGlmYWN0KCJjb3JyZWxhdGlvbiIsIGJvZHk9cGx0LmdjZigpLCB0aXRsZT0iQ29ycmVsYXRpb24gTWF0cml4IiksCiAgICAgICAgICAgIGxvY2FsX3BhdGg9ZiJ7cGxvdHNfZGVzdH0vY29yci5odG1sIiwKICAgICAgICAgICAgZGJfa2V5PUZhbHNlLAogICAgICAgICkKICAgIGV4Y2VwdCBFeGNlcHRpb24gYXMgZToKICAgICAgICBjb250ZXh0LmxvZ2dlci53YXJuKGYiRmFpbGVkIHRvIGNyZWF0ZSBmZWF0dXJlcyBjb3JyZWxhdGlvbiBwbG90IGR1ZSB0bzoge2V9IikKCiAgICBnY2ZfY2xlYXIocGx0KQo= + code_origin: '' + filename: describe_dask.py entry_points: summarize: outputs: - type: None - has_kwargs: false - name: summarize - has_varargs: false - lineno: 28 parameters: - name: context doc: the function context @@ -37,20 +42,16 @@ spec: - name: dask_client doc: dask client object default: null + name: summarize doc: 'Summarize a table Connects to dask client through the function context, or through an optional user-supplied scheduler.' 
- build: - code_origin: '' - origin_filename: '' - functionSourceCode: IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKaW1wb3J0IG1scnVuCmltcG9ydCB3YXJuaW5ncwp3YXJuaW5ncy5zaW1wbGVmaWx0ZXIoYWN0aW9uPSJpZ25vcmUiLCBjYXRlZ29yeT1GdXR1cmVXYXJuaW5nKQppbXBvcnQgcGFuZGFzIGFzIHBkCmltcG9ydCBtYXRwbG90bGliLnB5cGxvdCBhcyBwbHQKaW1wb3J0IHNlYWJvcm4gYXMgc25zCmZyb20gbWxydW4uYXJ0aWZhY3RzIGltcG9ydCBQbG90QXJ0aWZhY3QsIFRhYmxlQXJ0aWZhY3QKZnJvbSBtbHJ1bi5tbHV0aWxzLnBsb3RzIGltcG9ydCBnY2ZfY2xlYXIKaW1wb3J0IG51bXB5IGFzIG5wCgoKcGQuc2V0X29wdGlvbigiZGlzcGxheS5mbG9hdF9mb3JtYXQiLCBsYW1iZGEgeDogIiUuMmYiICUgeCkKCmRlZiBzdW1tYXJpemUoCiAgICBjb250ZXh0LAogICAgZGFza19rZXk6IHN0ciA9ICJkYXNrX2tleSIsCiAgICBkYXRhc2V0OiBtbHJ1bi5EYXRhSXRlbSA9IE5vbmUsCiAgICBsYWJlbF9jb2x1bW46IHN0ciA9ICJsYWJlbCIsCiAgICBwbG90c19kZXN0OiBzdHIgPSAicGxvdHMiLAogICAgZGFza19mdW5jdGlvbjogc3RyID0gTm9uZSwKICAgIGRhc2tfY2xpZW50PU5vbmUsCikgLT4gTm9uZToKICAgICIiIlN1bW1hcml6ZSBhIHRhYmxlCiAgICAKICAgIENvbm5lY3RzIHRvIGRhc2sgY2xpZW50IHRocm91Z2ggdGhlIGZ1bmN0aW9uIGNvbnRleHQsIG9yIHRocm91Z2ggYW4gb3B0aW9uYWwKICAgIHVzZXItc3VwcGxpZWQgc2NoZWR1bGVyLgoKICAgIDpwYXJhbSBjb250ZXh0OiAgICAgICAgIHRoZSBmdW5jdGlvbiBjb250ZXh0CiAgICA6cGFyYW0gZGFza19rZXk6ICAgICAgICBrZXkgb2YgZGF0YWZyYW1lIGluIGRhc2sgY2xpZW50ICJkYXRhc2V0cyIgYXR0cmlidXRlCiAgICA6cGFyYW0gbGFiZWxfY29sdW1uOiAgICBncm91b
mQgdHJ1dGggY29sdW1uIGxhYmVsCiAgICA6cGFyYW0gcGxvdHNfZGVzdDogICAgICBkZXN0aW5hdGlvbiBmb2xkZXIgb2Ygc3VtbWFyeSBwbG90cyAocmVsYXRpdmUgdG8gYXJ0aWZhY3RfcGF0aCkKICAgIDpwYXJhbSBkYXNrX2Z1bmN0aW9uOiAgIGRhc2sgZnVuY3Rpb24gdXJsIChkYjovLy4uKQogICAgOnBhcmFtIGRhc2tfY2xpZW50OiAgICAgZGFzayBjbGllbnQgb2JqZWN0CiAgICAiIiIKICAgIGlmIGRhc2tfZnVuY3Rpb246CiAgICAgICAgY2xpZW50ID0gbWxydW4uaW1wb3J0X2Z1bmN0aW9uKGRhc2tfZnVuY3Rpb24pLmNsaWVudAogICAgZWxpZiBkYXNrX2NsaWVudDoKICAgICAgICBjbGllbnQgPSBkYXNrX2NsaWVudAogICAgZWxzZToKICAgICAgICByYWlzZSBWYWx1ZUVycm9yKCdkYXNrIGNsaWVudCB3YXMgbm90IHByb3ZpZGVkJykKICAgICAgICAKICAgIGlmIGRhc2tfa2V5IGluIGNsaWVudC5kYXRhc2V0czoKICAgICAgICB0YWJsZSA9IGNsaWVudC5nZXRfZGF0YXNldChkYXNrX2tleSkKICAgIGVsaWYgZGF0YXNldDoKICAgICAgICAjdGFibGUgPSBkYXRhc2V0LmFzX2RmKGRmX21vZHVsZT1kZCkKICAgICAgICB0YWJsZSA9IGRhdGFzZXQuYXNfZGYoKQogICAgZWxzZToKICAgICAgICBjb250ZXh0LmxvZ2dlci5pbmZvKGYib25seSB0aGVzZSBkYXRhc2V0cyBhcmUgYXZhaWxhYmxlIHtjbGllbnQuZGF0YXNldHN9IGluIGNsaWVudCB7Y2xpZW50fSIpCiAgICAgICAgcmFpc2UgRXhjZXB0aW9uKCJkYXRhc2V0IG5vdCBmb3VuZCBvbiBkYXNrIGNsdXN0ZXIiKQogICAgZGYgPSB0YWJsZQogICAgaGVhZGVyID0gZGYuY29sdW1ucy52YWx1ZXMKICAgIGV4dHJhX2RhdGEgPSB7fQoKICAgIHRyeToKICAgICAgICBnY2ZfY2xlYXIocGx0KQogICAgICAgIHNuc3BsdCA9IHNucy5wYWlycGxvdChkZiwgaHVlPWxhYmVsX2NvbHVtbikgICMgLCBkaWFnX2t3cz17ImJ3IjogMS41fSkKICAgICAgICBleHRyYV9kYXRhWyJoaXN0b2dyYW1zIl0gPSBjb250ZXh0LmxvZ19hcnRpZmFjdCgKICAgICAgICAgICAgUGxvdEFydGlmYWN0KCJoaXN0b2dyYW1zIiwgYm9keT1wbHQuZ2NmKCkpLAogICAgICAgICAgICBsb2NhbF9wYXRoPWYie3Bsb3RzX2Rlc3R9L2hpc3QuaHRtbCIsCiAgICAgICAgICAgIGRiX2tleT1GYWxzZSwKICAgICAgICApCiAgICBleGNlcHQgRXhjZXB0aW9uIGFzIGU6CiAgICAgICAgY29udGV4dC5sb2dnZXIuZXJyb3IoZiJGYWlsZWQgdG8gY3JlYXRlIHBhaXJwbG90IGhpc3RvZ3JhbXMgZHVlIHRvOiB7ZX0iKQoKICAgIHRyeToKICAgICAgICBnY2ZfY2xlYXIocGx0KQogICAgICAgIHBsb3RfY29scyA9IDMKICAgICAgICBwbG90X3Jvd3MgPSBpbnQoKGxlbihoZWFkZXIpIC0gMSkgLyBwbG90X2NvbHMpICsgMQogICAgICAgIGZpZywgYXggPSBwbHQuc3VicGxvdHMocGxvdF9yb3dzLCBwbG90X2NvbHMsIGZpZ3NpemU9KDE1LCA0KSkKICAgICAgICBmaWcudGlnaHRfbGF5b3V0KHBhZD0yLjApCiAgICAgICAgZm9yI
GkgaW4gcmFuZ2UocGxvdF9yb3dzICogcGxvdF9jb2xzKToKICAgICAgICAgICAgaWYgaSA8IGxlbihoZWFkZXIpOgogICAgICAgICAgICAgICAgc25zLnZpb2xpbnBsb3QoCiAgICAgICAgICAgICAgICAgICAgeD1kZltoZWFkZXJbaV1dLAogICAgICAgICAgICAgICAgICAgIGF4PWF4W2ludChpIC8gcGxvdF9jb2xzKV1baSAlIHBsb3RfY29sc10sCiAgICAgICAgICAgICAgICAgICAgb3JpZW50PSJoIiwKICAgICAgICAgICAgICAgICAgICB3aWR0aD0wLjcsCiAgICAgICAgICAgICAgICAgICAgaW5uZXI9InF1YXJ0aWxlIiwKICAgICAgICAgICAgICAgICkKICAgICAgICAgICAgZWxzZToKICAgICAgICAgICAgICAgIGZpZy5kZWxheGVzKGF4W2ludChpIC8gcGxvdF9jb2xzKV1baSAlIHBsb3RfY29sc10pCiAgICAgICAgICAgIGkgKz0gMQogICAgICAgIGV4dHJhX2RhdGFbInZpb2xpbiJdID0gY29udGV4dC5sb2dfYXJ0aWZhY3QoCiAgICAgICAgICAgIFBsb3RBcnRpZmFjdCgidmlvbGluIiwgYm9keT1wbHQuZ2NmKCksIHRpdGxlPSJWaW9saW4gUGxvdCIpLAogICAgICAgICAgICBsb2NhbF9wYXRoPWYie3Bsb3RzX2Rlc3R9L3Zpb2xpbi5odG1sIiwKICAgICAgICAgICAgZGJfa2V5PUZhbHNlLAogICAgICAgICkKICAgIGV4Y2VwdCBFeGNlcHRpb24gYXMgZToKICAgICAgICBjb250ZXh0LmxvZ2dlci53YXJuKGYiRmFpbGVkIHRvIGNyZWF0ZSB2aW9saW4gZGlzdHJpYnV0aW9uIHBsb3RzIGR1ZSB0bzoge2V9IikKCiAgICBpZiBsYWJlbF9jb2x1bW46CiAgICAgICAgbGFiZWxzID0gZGYucG9wKGxhYmVsX2NvbHVtbikKICAgICAgICBpbWJ0YWJsZSA9IGxhYmVscy52YWx1ZV9jb3VudHMobm9ybWFsaXplPVRydWUpLnNvcnRfaW5kZXgoKQogICAgICAgIHRyeToKICAgICAgICAgICAgZ2NmX2NsZWFyKHBsdCkKICAgICAgICAgICAgYmFsYW5jZWJhciA9IGltYnRhYmxlLnBsb3Qoa2luZD0iYmFyIiwgdGl0bGU9ImNsYXNzIGltYmFsYW5jZSAtIGxhYmVscyIpCiAgICAgICAgICAgIGJhbGFuY2ViYXIuc2V0X3hsYWJlbCgiY2xhc3MiKQogICAgICAgICAgICBiYWxhbmNlYmFyLnNldF95bGFiZWwoInByb3BvcnRpb24gb2YgdG90YWwiKQogICAgICAgICAgICBleHRyYV9kYXRhWyJpbWJhbGFuY2UiXSA9IGNvbnRleHQubG9nX2FydGlmYWN0KAogICAgICAgICAgICAgICAgUGxvdEFydGlmYWN0KCJpbWJhbGFuY2UiLCBib2R5PXBsdC5nY2YoKSksCiAgICAgICAgICAgICAgICBsb2NhbF9wYXRoPWYie3Bsb3RzX2Rlc3R9L2ltYmFsYW5jZS5odG1sIiwKICAgICAgICAgICAgKQogICAgICAgIGV4Y2VwdCBFeGNlcHRpb24gYXMgZToKICAgICAgICAgICAgY29udGV4dC5sb2dnZXIud2FybihmIkZhaWxlZCB0byBjcmVhdGUgY2xhc3MgaW1iYWxhbmNlIHBsb3QgZHVlIHRvOiB7ZX0iKQogICAgICAgIGNvbnRleHQubG9nX2FydGlmYWN0KAogICAgICAgICAgICBUYWJsZUFydGlmYWN0KAogICAgICAgICAgICAgICAgImltYmFsYW5jZS13ZWlna
HRzLXZlYyIsIGRmPXBkLkRhdGFGcmFtZSh7IndlaWdodHMiOiBpbWJ0YWJsZX0pCiAgICAgICAgICAgICksCiAgICAgICAgICAgIGxvY2FsX3BhdGg9ZiJ7cGxvdHNfZGVzdH0vaW1iYWxhbmNlLXdlaWdodHMtdmVjLmNzdiIsCiAgICAgICAgICAgIGRiX2tleT1GYWxzZSwKICAgICAgICApCgogICAgdGJsY29yciA9IGRmLmNvcnIoKQogICAgbWFzayA9IG5wLnplcm9zX2xpa2UodGJsY29yciwgZHR5cGU9bnAuYm9vbCkKICAgIG1hc2tbbnAudHJpdV9pbmRpY2VzX2Zyb20obWFzayldID0gVHJ1ZQoKICAgIGRmY29yciA9IHBkLkRhdGFGcmFtZShkYXRhPXRibGNvcnIsIGNvbHVtbnM9aGVhZGVyLCBpbmRleD1oZWFkZXIpCiAgICBkZmNvcnIgPSBkZmNvcnJbbnAuYXJhbmdlKGRmY29yci5zaGFwZVswXSlbOiwgTm9uZV0gPiBucC5hcmFuZ2UoZGZjb3JyLnNoYXBlWzFdKV0KICAgIGNvbnRleHQubG9nX2FydGlmYWN0KAogICAgICAgIFRhYmxlQXJ0aWZhY3QoImNvcnJlbGF0aW9uLW1hdHJpeCIsIGRmPXRibGNvcnIsIHZpc2libGU9VHJ1ZSksCiAgICAgICAgbG9jYWxfcGF0aD1mIntwbG90c19kZXN0fS9jb3JyZWxhdGlvbi1tYXRyaXguY3N2IiwKICAgICAgICBkYl9rZXk9RmFsc2UsCiAgICApCgogICAgdHJ5OgogICAgICAgIGdjZl9jbGVhcihwbHQpCiAgICAgICAgYXggPSBwbHQuYXhlcygpCiAgICAgICAgc25zLmhlYXRtYXAodGJsY29yciwgYXg9YXgsIG1hc2s9bWFzaywgYW5ub3Q9RmFsc2UsIGNtYXA9cGx0LmNtLlJlZHMpCiAgICAgICAgYXguc2V0X3RpdGxlKCJmZWF0dXJlcyBjb3JyZWxhdGlvbiIpCiAgICAgICAgZXh0cmFfZGF0YVsiY29ycmVsYXRpb24iXSA9IGNvbnRleHQubG9nX2FydGlmYWN0KAogICAgICAgICAgICBQbG90QXJ0aWZhY3QoImNvcnJlbGF0aW9uIiwgYm9keT1wbHQuZ2NmKCksIHRpdGxlPSJDb3JyZWxhdGlvbiBNYXRyaXgiKSwKICAgICAgICAgICAgbG9jYWxfcGF0aD1mIntwbG90c19kZXN0fS9jb3JyLmh0bWwiLAogICAgICAgICAgICBkYl9rZXk9RmFsc2UsCiAgICAgICAgKQogICAgZXhjZXB0IEV4Y2VwdGlvbiBhcyBlOgogICAgICAgIGNvbnRleHQubG9nZ2VyLndhcm4oZiJGYWlsZWQgdG8gY3JlYXRlIGZlYXR1cmVzIGNvcnJlbGF0aW9uIHBsb3QgZHVlIHRvOiB7ZX0iKQoKICAgIGdjZl9jbGVhcihwbHQpCg== + has_kwargs: false + has_varargs: false + lineno: 30 + command: '' description: describe and visualizes dataset stats -metadata: - categories: - - data-analysis - tag: '' - name: describe-dask -kind: job + default_handler: summarize diff --git a/functions/src/describe_dask/test_describe_dask.py b/functions/src/describe_dask/test_describe_dask.py index d5c38b71c..c478ac2b7 100644 --- a/functions/src/describe_dask/test_describe_dask.py 
+++ b/functions/src/describe_dask/test_describe_dask.py @@ -12,21 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from mlrun import code_to_function, new_function, import_function -from pathlib import Path import os -DATA_URL = 'https://s3.wasabisys.com/iguazio/data/iris/iris_dataset.csv' -ARTIFACTS_PATH = 'artifacts' -PLOTS_PATH = ARTIFACTS_PATH + '/plots' +from mlrun import code_to_function, import_function, new_function + +DATA_URL = "https://s3.wasabisys.com/iguazio/data/iris/iris_dataset.csv" +ARTIFACTS_PATH = "artifacts" +PLOTS_PATH = ARTIFACTS_PATH + "/plots" GENERATED_ARTIFACTS = [ - 'correlation', 'correlation-matrix', 'histograms', 'imbalance', 'imbalance-weights-vec', 'violin' + "correlation", + "correlation-matrix", + "histograms", + "imbalance", + "imbalance-weights-vec", + "violin", ] def _create_dask_func(uri): dask_cluster_name = "dask-cluster" - dask_cluster = new_function(dask_cluster_name, kind='dask', image='mlrun/ml-models') + dask_cluster = new_function(dask_cluster_name, kind="dask", image="mlrun/ml-models") dask_cluster.spec.remote = False dask_uri = uri dask_cluster.export(dask_uri) @@ -35,15 +40,15 @@ def _create_dask_func(uri): def test_code_to_function_describe_dask(): dask_uri = "dask_func.yaml" _create_dask_func(dask_uri) - fn = code_to_function(filename="describe_dask.py", kind='local') + fn = code_to_function(filename="describe_dask.py", kind="local") fn.spec.command = "describe_dask.py" run = fn.run( inputs={"dataset": DATA_URL}, params={ - 'update_dataset': True, - 'label_column': 'label', - 'dask_function': dask_uri, + "update_dataset": True, + "label_column": "label", + "dask_function": dask_uri, }, handler="summarize", ) @@ -54,18 +59,17 @@ def test_code_to_function_describe_dask(): def test_import_function_describe_dask(): dask_uri = "dask_func.yaml" _create_dask_func(dask_uri) - fn = import_function('function.yaml') + fn = import_function("function.yaml") run 
= fn.run( - inputs={ - "dataset": DATA_URL}, + inputs={"dataset": DATA_URL}, params={ - 'update_dataset': True, - 'label_column': 'label', - 'dask_function': dask_uri, + "update_dataset": True, + "label_column": "label", + "dask_function": dask_uri, }, handler="summarize", - artifact_path=os.getcwd() + '/artifacts', + artifact_path=os.getcwd() + "/artifacts", local=True, ) diff --git a/functions/src/describe_spark/describe_spark.py b/functions/src/describe_spark/describe_spark.py index 856b2505c..5e5902781 100644 --- a/functions/src/describe_spark/describe_spark.py +++ b/functions/src/describe_spark/describe_spark.py @@ -14,39 +14,45 @@ # # Generated by nuclio.export.NuclioExporter -import mlrun -from mlrun.platforms.iguazio import mount_v3io, mount_v3iod -from mlrun.datastore import DataItem -from mlrun.execution import MLClientCtx - -import os +import warnings from subprocess import run -import pandas as pd -import numpy as np -from pyspark.sql.types import LongType +import numpy as np +import pandas as pd +from mlrun.datastore import DataItem +from mlrun.execution import MLClientCtx from pyspark.sql import SparkSession -import sys -import base64 as b64 -import warnings warnings.filterwarnings("ignore") +import json from itertools import product -import matplotlib -import numpy as np -import json -import pandas as pd -from matplotlib import pyplot as plt -from pkg_resources import resource_filename -import six +import matplotlib from pyspark.sql import DataFrame as SparkDataFrame -from pyspark.sql.functions import (abs as df_abs, col, count, countDistinct, - max as df_max, mean, min as df_min, - sum as df_sum, when - ) -from pyspark.sql.functions import variance, stddev, kurtosis, skewness +from pyspark.sql.functions import ( + abs as df_abs, +) +from pyspark.sql.functions import ( + col, + count, + countDistinct, + kurtosis, + mean, + skewness, + stddev, + variance, + when, +) +from pyspark.sql.functions import ( + max as df_max, +) +from pyspark.sql.functions 
import ( + min as df_min, +) +from pyspark.sql.functions import ( + sum as df_sum, +) def describe(df, bins, corr_reject, config, **kwargs): @@ -65,20 +71,20 @@ def describe(df, bins, corr_reject, config, **kwargs): def pretty_name(x): x *= 100 if x == int(x): - return '%.0f%%' % x + return "%.0f%%" % x else: - return '%.1f%%' % x + return "%.1f%%" % x def corr_matrix(df, columns=None): if columns is None: columns = df.columns - combinations = list(product(columns,columns)) + combinations = list(product(columns, columns)) def separate(l, n): for i in range(0, len(l), n): - yield l[i:i+n] + yield l[i : i + n] - grouped = list(separate(combinations,len(columns))) + grouped = list(separate(combinations, len(columns))) df_cleaned = df.select(*columns).na.drop(how="any") for i in grouped: @@ -88,11 +94,10 @@ def separate(l, n): df_pandas = pd.DataFrame(grouped).applymap(lambda x: x[2]) df_pandas.columns = columns df_pandas.index = columns - + return df_pandas def create_hist_data(df, column, minim, maxim, bins=10): - def create_all_conditions(current_col, column, left_edges, count=1): """ Recursive function that exploits the @@ -105,11 +110,14 @@ def create_all_conditions(current_col, column, left_edges, count=1): if len(left_edges) == 1: next_col = current_col.when(col(column) >= float(left_edges[0]), count) left_edges.pop(0) - return create_all_conditions(next_col, column, left_edges[:], count+1) - next_col = current_col.when((float(left_edges[0]) <= col(column)) - & (col(column) < float(left_edges[1])), count) + return create_all_conditions(next_col, column, left_edges[:], count + 1) + next_col = current_col.when( + (float(left_edges[0]) <= col(column)) + & (col(column) < float(left_edges[1])), + count, + ) left_edges.pop(0) - return create_all_conditions(next_col, column, left_edges[:], count+1) + return create_all_conditions(next_col, column, left_edges[:], count + 1) num_range = maxim - minim bin_width = num_range / float(bins) @@ -117,20 +125,25 @@ def 
create_all_conditions(current_col, column, left_edges, count=1): for _bin in range(bins): left_edges = left_edges + [left_edges[-1] + bin_width] left_edges.pop() - expression_col = when((float(left_edges[0]) <= col(column)) - & (col(column) < float(left_edges[1])), 0) + expression_col = when( + (float(left_edges[0]) <= col(column)) + & (col(column) < float(left_edges[1])), + 0, + ) left_edges_copy = left_edges[:] left_edges_copy.pop(0) - bin_data = (df.select(col(column)) - .na.drop() - .select(col(column), - create_all_conditions(expression_col, - column, - left_edges_copy - ).alias("bin_id") - ) - .groupBy("bin_id").count() - ).toPandas() + bin_data = ( + df.select(col(column)) + .na.drop() + .select( + col(column), + create_all_conditions(expression_col, column, left_edges_copy).alias( + "bin_id" + ), + ) + .groupBy("bin_id") + .count() + ).toPandas() bin_data.index = bin_data["bin_id"] new_index = list(range(bins)) @@ -140,85 +153,102 @@ def create_all_conditions(current_col, column, left_edges, count=1): bin_data["left_edge"] = left_edges bin_data["width"] = bin_width - return bin_data - def describe_integer_1d(df, column, current_result, nrows): - - stats_df = df.select(column).na.drop().agg(mean(col(column)).alias("mean"), - df_min(col(column)).alias("min"), - df_max(col(column)).alias("max"), - variance(col(column)).alias("variance"), - kurtosis(col(column)).alias("kurtosis"), - stddev(col(column)).alias("std"), - skewness(col(column)).alias("skewness"), - df_sum(col(column)).alias("sum") - ).toPandas() - + stats_df = ( + df.select(column) + .na.drop() + .agg( + mean(col(column)).alias("mean"), + df_min(col(column)).alias("min"), + df_max(col(column)).alias("max"), + variance(col(column)).alias("variance"), + kurtosis(col(column)).alias("kurtosis"), + stddev(col(column)).alias("std"), + skewness(col(column)).alias("skewness"), + df_sum(col(column)).alias("sum"), + ) + .toPandas() + ) for x in np.array([0.05, 0.25, 0.5, 0.75, 0.95]): - 
stats_df[pretty_name(x)] = (df.select(column) - .na.drop() - .selectExpr("percentile(`{col}`,CAST({n} AS DOUBLE))" - .format(col=column, n=x)).toPandas().iloc[:,0] - ) + stats_df[pretty_name(x)] = ( + df.select(column) + .na.drop() + .selectExpr(f"percentile(`{column}`,CAST({x} AS DOUBLE))") + .toPandas() + .iloc[:, 0] + ) stats = stats_df.iloc[0].copy() stats.name = column stats["range"] = stats["max"] - stats["min"] stats["iqr"] = stats[pretty_name(0.75)] - stats[pretty_name(0.25)] stats["cv"] = stats["std"] / float(stats["mean"]) - stats["mad"] = (df.select(column) - .na.drop() - .select(df_abs(col(column)-stats["mean"]).alias("delta")) - .agg(df_sum(col("delta"))).toPandas().iloc[0,0] / float(current_result["count"])) + stats["mad"] = df.select(column).na.drop().select( + df_abs(col(column) - stats["mean"]).alias("delta") + ).agg(df_sum(col("delta"))).toPandas().iloc[0, 0] / float( + current_result["count"] + ) stats["type"] = "NUM" - stats['n_zeros'] = df.select(column).where(col(column)==0.0).count() - stats['p_zeros'] = stats['n_zeros'] / float(nrows) + stats["n_zeros"] = df.select(column).where(col(column) == 0.0).count() + stats["p_zeros"] = stats["n_zeros"] / float(nrows) hist_data = create_hist_data(df, column, stats["min"], stats["max"], bins) return stats def describe_float_1d(df, column, current_result, nrows): - stats_df = df.select(column).na.drop().agg(mean(col(column)).alias("mean"), - df_min(col(column)).alias("min"), - df_max(col(column)).alias("max"), - variance(col(column)).alias("variance"), - kurtosis(col(column)).alias("kurtosis"), - stddev(col(column)).alias("std"), - skewness(col(column)).alias("skewness"), - df_sum(col(column)).alias("sum") - ).toPandas() + stats_df = ( + df.select(column) + .na.drop() + .agg( + mean(col(column)).alias("mean"), + df_min(col(column)).alias("min"), + df_max(col(column)).alias("max"), + variance(col(column)).alias("variance"), + kurtosis(col(column)).alias("kurtosis"), + stddev(col(column)).alias("std"), + 
skewness(col(column)).alias("skewness"), + df_sum(col(column)).alias("sum"), + ) + .toPandas() + ) for x in np.array([0.05, 0.25, 0.5, 0.75, 0.95]): - stats_df[pretty_name(x)] = (df.select(column) - .na.drop() - .selectExpr("percentile_approx(`{col}`,CAST({n} AS DOUBLE))" - .format(col=column, n=x)).toPandas().iloc[:,0] - ) + stats_df[pretty_name(x)] = ( + df.select(column) + .na.drop() + .selectExpr(f"percentile_approx(`{column}`,CAST({x} AS DOUBLE))") + .toPandas() + .iloc[:, 0] + ) stats = stats_df.iloc[0].copy() stats.name = column stats["range"] = stats["max"] - stats["min"] stats["iqr"] = stats[pretty_name(0.75)] - stats[pretty_name(0.25)] stats["cv"] = stats["std"] / float(stats["mean"]) - stats["mad"] = (df.select(column) - .na.drop() - .select(df_abs(col(column)-stats["mean"]).alias("delta")) - .agg(df_sum(col("delta"))).toPandas().iloc[0,0] / float(current_result["count"])) + stats["mad"] = df.select(column).na.drop().select( + df_abs(col(column) - stats["mean"]).alias("delta") + ).agg(df_sum(col("delta"))).toPandas().iloc[0, 0] / float( + current_result["count"] + ) stats["type"] = "NUM" - stats['n_zeros'] = df.select(column).where(col(column)==0.0).count() - stats['p_zeros'] = stats['n_zeros'] / float(nrows) + stats["n_zeros"] = df.select(column).where(col(column) == 0.0).count() + stats["p_zeros"] = stats["n_zeros"] / float(nrows) hist_data = create_hist_data(df, column, stats["min"], stats["max"], bins) return stats def describe_date_1d(df, column): - stats_df = df.select(column).na.drop().agg(df_min(col(column)).alias("min"), - df_max(col(column)).alias("max") - ).toPandas() + stats_df = ( + df.select(column) + .na.drop() + .agg(df_min(col(column)).alias("min"), df_max(col(column)).alias("max")) + .toPandas() + ) stats = stats_df.iloc[0].copy() stats.name = column @@ -241,66 +271,102 @@ def guess_json_type(string_value): return type(obj) def describe_categorical_1d(df, column): - value_counts = (df.select(column).na.drop() - .groupBy(column) - 
.agg(count(col(column))) - .orderBy("count({c})".format(c=column),ascending=False) - ).cache() - - stats = (value_counts - .limit(1) - .withColumnRenamed(column, "top") - .withColumnRenamed("count({c})".format(c=column), "freq") - ).toPandas().iloc[0] - - top_50 = value_counts.limit(50).toPandas().sort_values("count({c})".format(c=column), - ascending=False) + value_counts = ( + df.select(column) + .na.drop() + .groupBy(column) + .agg(count(col(column))) + .orderBy(f"count({column})", ascending=False) + ).cache() + + stats = ( + ( + value_counts.limit(1) + .withColumnRenamed(column, "top") + .withColumnRenamed(f"count({column})", "freq") + ) + .toPandas() + .iloc[0] + ) + + top_50 = ( + value_counts.limit(50) + .toPandas() + .sort_values(f"count({column})", ascending=False) + ) top_50_categories = top_50[column].values.tolist() - others_count = pd.Series([df.select(column).na.drop() - .where(~(col(column).isin(*top_50_categories))) - .count() - ], index=["***Other Values***"]) - others_distinct_count = pd.Series([value_counts - .where(~(col(column).isin(*top_50_categories))) - .count() - ], index=["***Other Values Distinct Count***"]) - - top = top_50.set_index(column)["count({c})".format(c=column)] + others_count = pd.Series( + [ + df.select(column) + .na.drop() + .where(~(col(column).isin(*top_50_categories))) + .count() + ], + index=["***Other Values***"], + ) + others_distinct_count = pd.Series( + [value_counts.where(~(col(column).isin(*top_50_categories))).count()], + index=["***Other Values Distinct Count***"], + ) + + top = top_50.set_index(column)[f"count({column})"] top = top.append(others_count) top = top.append(others_distinct_count) stats["value_counts"] = top stats["type"] = "CAT" value_counts.unpersist() - unparsed_valid_jsons = df.select(column).na.drop().rdd.map( - lambda x: guess_json_type(x[column])).filter( - lambda x: x).distinct().collect() + unparsed_valid_jsons = ( + df.select(column) + .na.drop() + .rdd.map(lambda x: 
guess_json_type(x[column])) + .filter(lambda x: x) + .distinct() + .collect() + ) stats["unparsed_json_types"] = unparsed_valid_jsons return stats def describe_constant_1d(df, column): - stats = pd.Series(['CONST'], index=['type'], name=column) - stats["value_counts"] = (df.select(column) - .na.drop() - .limit(1)).toPandas().iloc[:,0].value_counts() + stats = pd.Series(["CONST"], index=["type"], name=column) + stats["value_counts"] = ( + (df.select(column).na.drop().limit(1)).toPandas().iloc[:, 0].value_counts() + ) return stats def describe_unique_1d(df, column): - stats = pd.Series(['UNIQUE'], index=['type'], name=column) - stats["value_counts"] = (df.select(column) - .na.drop() - .limit(50)).toPandas().iloc[:,0].value_counts() + stats = pd.Series(["UNIQUE"], index=["type"], name=column) + stats["value_counts"] = ( + (df.select(column).na.drop().limit(50)).toPandas().iloc[:, 0].value_counts() + ) return stats def describe_1d(df, column, nrows, lookup_config=None): column_type = df.select(column).dtypes[0][1] - if ("array" in column_type) or ("stuct" in column_type) or ("map" in column_type): - raise NotImplementedError("Column {c} is of type {t} and cannot be analyzed".format(c=column, t=column_type)) - - distinct_count = df.select(column).agg(countDistinct(col(column)).alias("distinct_count")).toPandas() - non_nan_count = df.select(column).na.drop().select(count(col(column)).alias("count")).toPandas() - results_data = pd.concat([distinct_count, non_nan_count],axis=1) - results_data["p_unique"] = results_data["distinct_count"] / float(results_data["count"]) + if ( + ("array" in column_type) + or ("stuct" in column_type) + or ("map" in column_type) + ): + raise NotImplementedError( + f"Column {column} is of type {column_type} and cannot be analyzed" + ) + + distinct_count = ( + df.select(column) + .agg(countDistinct(col(column)).alias("distinct_count")) + .toPandas() + ) + non_nan_count = ( + df.select(column) + .na.drop() + 
.select(count(col(column)).alias("count")) + .toPandas() + ) + results_data = pd.concat([distinct_count, non_nan_count], axis=1) + results_data["p_unique"] = results_data["distinct_count"] / float( + results_data["count"] + ) results_data["is_unique"] = results_data["distinct_count"] == nrows results_data["n_missing"] = nrows - results_data["count"] results_data["p_missing"] = results_data["n_missing"] / float(nrows) @@ -325,7 +391,7 @@ def describe_1d(df, column, nrows, lookup_config=None): if result["n_missing"] > 0: result["distinct_count"] = result["distinct_count"] + 1 - if (result["count"] > result["distinct_count"] > 1): + if result["count"] > result["distinct_count"] > 1: try: result["mode"] = result["top"] except KeyError: @@ -339,25 +405,34 @@ def describe_1d(df, column, nrows, lookup_config=None): result["mode"] = "MISSING" if lookup_config: - lookup_object = lookup_config['object'] - col_name_in_db = lookup_config['col_name_in_db'] if 'col_name_in_db' in lookup_config else None + lookup_object = lookup_config["object"] + col_name_in_db = ( + lookup_config["col_name_in_db"] + if "col_name_in_db" in lookup_config + else None + ) try: - matched, unmatched = lookup_object.lookup(df.select(column), col_name_in_db) - result['lookedup_values'] = str(matched.count()) + "/" + str(df.select(column).count()) + matched, unmatched = lookup_object.lookup( + df.select(column), col_name_in_db + ) + result["lookedup_values"] = ( + str(matched.count()) + "/" + str(df.select(column).count()) + ) except: - result['lookedup_values'] = 'FAILED' + result["lookedup_values"] = "FAILED" else: - result['lookedup_values'] = '' + result["lookedup_values"] = "" return result - ldesc = {} for colum in df.columns: if colum in config: - if 'lookup' in config[colum]: - lookup_config = config[colum]['lookup'] - desc = describe_1d(df, colum, table_stats["n"], lookup_config=lookup_config) + if "lookup" in config[colum]: + lookup_config = config[colum]["lookup"] + desc = describe_1d( + df, 
colum, table_stats["n"], lookup_config=lookup_config + ) else: desc = describe_1d(df, colum, table_stats["n"]) else: @@ -377,19 +452,23 @@ def describe_1d(df, column, nrows, lookup_config=None): variable_stats = pd.DataFrame(ldesc) table_stats["nvar"] = len(df.columns) - table_stats["total_missing"] = float(variable_stats.loc["n_missing"].sum()) / (table_stats["n"] * table_stats["nvar"]) + table_stats["total_missing"] = float(variable_stats.loc["n_missing"].sum()) / ( + table_stats["n"] * table_stats["nvar"] + ) memsize = 0 - table_stats['memsize'] = fmt_bytesize(memsize) - table_stats['recordsize'] = fmt_bytesize(memsize / table_stats['n']) - table_stats.update({k: 0 for k in ("NUM", "DATE", "CONST", "CAT", "UNIQUE", "CORR")}) - table_stats.update(dict(variable_stats.loc['type'].value_counts())) - table_stats['REJECTED'] = table_stats['CONST'] + table_stats['CORR'] + table_stats["memsize"] = fmt_bytesize(memsize) + table_stats["recordsize"] = fmt_bytesize(memsize / table_stats["n"]) + table_stats.update( + {k: 0 for k in ("NUM", "DATE", "CONST", "CAT", "UNIQUE", "CORR")} + ) + table_stats.update(dict(variable_stats.loc["type"].value_counts())) + table_stats["REJECTED"] = table_stats["CONST"] + table_stats["CORR"] freq_dict = {} for var in variable_stats: if "value_counts" not in variable_stats[var]: pass - elif not(variable_stats[var]["value_counts"] is np.nan): + elif variable_stats[var]["value_counts"] is not np.nan: freq_dict[var] = variable_stats[var]["value_counts"] else: pass @@ -400,129 +479,155 @@ def describe_1d(df, column, nrows, lookup_config=None): return table_stats, variable_stats.T, freq_dict -import numpy as np -from pyspark.sql.functions import abs as absou SKEWNESS_CUTOFF = 20 -DEFAULT_FLOAT_FORMATTER = u'spark_df_profiling.__default_float_formatter' +DEFAULT_FLOAT_FORMATTER = "spark_df_profiling.__default_float_formatter" def gradient_format(value, limit1, limit2, c1, c2): - def LerpColour(c1,c2,t): - return 
(int(c1[0]+(c2[0]-c1[0])*t),int(c1[1]+(c2[1]-c1[1])*t),int(c1[2]+(c2[2]-c1[2])*t)) - c = LerpColour(c1, c2, (value-limit1)/(limit2-limit1)) - return fmt_color(value,"rgb{}".format(str(c))) + def LerpColour(c1, c2, t): + return ( + int(c1[0] + (c2[0] - c1[0]) * t), + int(c1[1] + (c2[1] - c1[1]) * t), + int(c1[2] + (c2[2] - c1[2]) * t), + ) + + c = LerpColour(c1, c2, (value - limit1) / (limit2 - limit1)) + return fmt_color(value, f"rgb{str(c)}") def fmt_color(text, color): - return(u'{text}'.format(color=color,text=str(text))) + return f'{str(text)}' def fmt_class(text, cls): - return(u'{text}'.format(cls=cls,text=str(text))) + return f'{str(text)}' -def fmt_bytesize(num, suffix='B'): - for unit in ['','Ki','Mi','Gi','Ti','Pi','Ei','Zi']: +def fmt_bytesize(num, suffix="B"): + for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: if num < 0: - num = num*-1 + num = num * -1 if num < 1024.0: return "%3.1f %s%s" % (num, unit, suffix) num /= 1024.0 - return "%.1f %s%s" % (num, 'Yi', suffix) + return "%.1f %s%s" % (num, "Yi", suffix) def fmt_percent(v): - return "{:2.1f}%".format(v*100) + return f"{v * 100:2.1f}%" + def fmt_varname(v): - return u'{0}'.format(v) - - -value_formatters={ - u'freq': (lambda v: gradient_format(v, 0, 62000, (30, 198, 244), (99, 200, 72))), - u'p_missing': fmt_percent, - u'p_infinite': fmt_percent, - u'p_unique': fmt_percent, - u'p_zeros': fmt_percent, - u'memorysize': fmt_bytesize, - u'total_missing': fmt_percent, - DEFAULT_FLOAT_FORMATTER: lambda v: str(float('{:.5g}'.format(v))).rstrip('0').rstrip('.'), - u'correlation_var': lambda v: fmt_varname(v), - u'unparsed_json_types': lambda v: ', '.join([s.__name__ for s in v]) - } + return f"{v}" + + +value_formatters = { + "freq": (lambda v: gradient_format(v, 0, 62000, (30, 198, 244), (99, 200, 72))), + "p_missing": fmt_percent, + "p_infinite": fmt_percent, + "p_unique": fmt_percent, + "p_zeros": fmt_percent, + "memorysize": fmt_bytesize, + "total_missing": fmt_percent, + 
DEFAULT_FLOAT_FORMATTER: lambda v: str(float(f"{v:.5g}")).rstrip("0").rstrip("."), + "correlation_var": lambda v: fmt_varname(v), + "unparsed_json_types": lambda v: ", ".join([s.__name__ for s in v]), +} + def fmt_row_severity(v): - if np.isnan(v) or v<= 0.01: + if np.isnan(v) or v <= 0.01: return "ignore" else: return "alert" + def fmt_skewness(v): - if not np.isnan(v) and (v<-SKEWNESS_CUTOFF or v> SKEWNESS_CUTOFF): + if not np.isnan(v) and (v < -SKEWNESS_CUTOFF or v > SKEWNESS_CUTOFF): return "alert" else: return "" -row_formatters={ - u'p_zeros': fmt_row_severity, - u'p_missing': fmt_row_severity, - u'p_infinite': fmt_row_severity, - u'n_duplicates': fmt_row_severity, - u'skewness': fmt_skewness, + +row_formatters = { + "p_zeros": fmt_row_severity, + "p_missing": fmt_row_severity, + "p_infinite": fmt_row_severity, + "n_duplicates": fmt_row_severity, + "skewness": fmt_skewness, } run(["/bin/bash", "/etc/config/v3io/v3io-spark-operator.sh"]) -def describe_spark(context: MLClientCtx, - dataset: DataItem, - artifact_path, - bins: int=30, - describe_extended: bool=True): - + +def describe_spark( + context: MLClientCtx, + dataset: DataItem, + artifact_path, + bins: int = 30, + describe_extended: bool = True, +): location = dataset.local() - + spark = SparkSession.builder.appName("Spark job").getOrCreate() - - df = spark.read.csv(location, header=True, inferSchema= True) + + df = spark.read.csv(location, header=True, inferSchema=True) kwargs = [] - - float_cols = [item[0] for item in df.dtypes if item[1].startswith('float') or item[1].startswith('double')] - + + float_cols = [ + item[0] + for item in df.dtypes + if item[1].startswith("float") or item[1].startswith("double") + ] + if describe_extended == True: - table, variables, freq = describe(df, bins, float_cols, kwargs) tbl_1 = variables.reset_index() if len(freq) != 0: - tbl_2 = pd.DataFrame.from_dict(freq, orient = "index").sort_index().stack().reset_index() - tbl_2.columns = ['col', 'key', 'val'] - 
tbl_2['Merged'] = [{key: val} for key, val in zip(tbl_2.key, tbl_2.val)] - tbl_2 = tbl_2.groupby('col', as_index=False).agg(lambda x: tuple(x))[['col','Merged']] - - summary = pd.merge(tbl_1, tbl_2, how='left', left_on='index', right_on='col') + tbl_2 = ( + pd.DataFrame.from_dict(freq, orient="index") + .sort_index() + .stack() + .reset_index() + ) + tbl_2.columns = ["col", "key", "val"] + tbl_2["Merged"] = [{key: val} for key, val in zip(tbl_2.key, tbl_2.val)] + tbl_2 = tbl_2.groupby("col", as_index=False).agg(lambda x: tuple(x))[ + ["col", "Merged"] + ] + + summary = pd.merge( + tbl_1, tbl_2, how="left", left_on="index", right_on="col" + ) else: summary = tbl_1 - context.log_dataset("summary_stats", - df=summary, - format="csv", index=False, - artifact_path=context.artifact_subpath('data')) + context.log_dataset( + "summary_stats", + df=summary, + format="csv", + index=False, + artifact_path=context.artifact_subpath("data"), + ) context.log_results(table) - + else: tbl_1 = df.describe().toPandas() - + summary = tbl_1.T - - context.log_dataset("summary_stats", - df=summary, - format="csv", index=False, - artifact_path=context.artifact_subpath('data')) - - spark.stop() + context.log_dataset( + "summary_stats", + df=summary, + format="csv", + index=False, + artifact_path=context.artifact_subpath("data"), + ) + + spark.stop() diff --git a/functions/src/describe_spark/function.yaml b/functions/src/describe_spark/function.yaml index 688f4260b..12223e77c 100644 --- a/functions/src/describe_spark/function.yaml +++ b/functions/src/describe_spark/function.yaml @@ -1,322 +1,264 @@ -kind: job metadata: - name: describe-spark tag: '' - hash: bd54bbf6350fb0dc392ff7f91b4aa6ea3c742e93 - project: '' + name: describe-spark categories: - data-analysis +verbose: false +kind: job spec: - command: '' - args: [] image: iguazio/shell:3.0_b5565_20201026062233_wsdf - env: [] - default_handler: describe_spark + disable_auto_mount: false + build: + origin_filename: '' + functionSourceCode: 
# Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Generated by nuclio.export.NuclioExporter

import warnings
from subprocess import run

import numpy as np
import pandas as pd
from mlrun.datastore import DataItem
from mlrun.execution import MLClientCtx
from pyspark.sql import SparkSession

warnings.filterwarnings("ignore")

import json
from itertools import product

import matplotlib
from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql.functions import (
    abs as df_abs,
)
from pyspark.sql.functions import (
    col,
    count,
    countDistinct,
    kurtosis,
    mean,
    skewness,
    stddev,
    variance,
    when,
)
from pyspark.sql.functions import (
    max as df_max,
)
from pyspark.sql.functions import (
    min as df_min,
)
from pyspark.sql.functions import (
    sum as df_sum,
)


def describe(df, bins, corr_reject, config, **kwargs):
    if not isinstance(df, SparkDataFrame):
        raise TypeError("df must be of type pyspark.sql.DataFrame")

    table_stats = {"n": df.count()}
    if table_stats["n"] == 0:
        raise ValueError("df cannot be empty")

    try:
        matplotlib.style.use("default")
    except:
        pass

    def pretty_name(x):
        """Render a fraction as a percentage label (e.g. ``0.25`` -> ``"25%"``).

        Whole percentages are printed without decimals; any other value is
        printed with exactly one decimal place.
        """
        percent = x * 100
        template = "%.0f%%" if percent == int(percent) else "%.1f%%"
        return template % percent

    def corr_matrix(df, columns=None):
        """Compute the pairwise correlation matrix of ``columns``.

        Rows holding a null in any of the selected columns are dropped first,
        then every cell is obtained with one ``corr`` call on the cleaned
        data, so the cost is ``len(columns) ** 2`` separate computations.

        :param df:      DataFrame exposing ``select``, ``na.drop`` and ``corr``.
        :param columns: Names of the columns to correlate. Defaults to
                        ``df.columns``.

        :returns: A square pandas DataFrame whose index and columns are
                  ``columns``. An empty DataFrame when there are no columns.
        """
        if columns is None:
            columns = df.columns
        columns = list(columns)
        if not columns:
            # Nothing to correlate; previously this surfaced as an opaque
            # "range() arg 3 must not be zero" error.
            return pd.DataFrame()

        df_cleaned = df.select(*columns).na.drop(how="any")

        # Build the matrix directly instead of packing (a, b, corr) tuples and
        # unpacking them again with the deprecated DataFrame.applymap.
        rows = [
            [df_cleaned.corr(str(first), str(second)) for second in columns]
            for first in columns
        ]
        return pd.DataFrame(rows, index=columns, columns=columns)

    def create_hist_data(df, column, minim, maxim, bins=10):
        """Count the non-null values of ``column`` that fall in equal-width bins.

        The span ``maxim - minim`` is divided into ``bins`` intervals of equal
        width. Every interval is closed on the left and open on the right,
        except the last one, which has no upper bound so that values equal to
        ``maxim`` are counted in it.

        :param df:     Spark DataFrame holding ``column``.
        :param column: Name of the numeric column to bin.
        :param minim:  Left edge of the first bin.
        :param maxim:  Value used together with ``minim`` to derive the bin width.
        :param bins:   Number of bins; must be at least 1.

        :returns: A pandas DataFrame with one row per bin (indexed
                  ``0 .. bins - 1``) and the columns ``bin_id``, ``count``,
                  ``left_edge`` and ``width``. Bins without values have a
                  count of 0.

        :raises ValueError: If ``bins`` is smaller than 1.
        """
        if bins < 1:
            raise ValueError("bins must be at least 1")

        bin_width = (maxim - minim) / float(bins)

        # Left edges are accumulated by repeated addition, one per bin.
        left_edges = [minim]
        for _ in range(bins - 1):
            left_edges.append(left_edges[-1] + bin_width)

        def in_bin(position):
            """Build the membership condition of the bin at ``position``."""
            lower = float(left_edges[position])
            if position == len(left_edges) - 1:
                # The last bin is unbounded above so the maximum is included.
                # This also covers bins == 1, which used to raise IndexError.
                return col(column) >= lower
            upper = float(left_edges[position + 1])
            return (lower <= col(column)) & (col(column) < upper)

        # Chain one CASE WHEN branch per bin; rows matching no branch get a
        # null bin id and are discarded by the reindex below.
        bin_expr = when(in_bin(0), 0)
        for bin_id in range(1, bins):
            bin_expr = bin_expr.when(in_bin(bin_id), bin_id)

        bin_data = (
            df.select(col(column))
            .na.drop()
            .select(col(column), bin_expr.alias("bin_id"))
            .groupBy("bin_id")
            .count()
        ).toPandas()

        # Align to the full 0..bins-1 range so empty bins show up with 0.
        bin_data.index = bin_data["bin_id"]
        bin_data = bin_data.reindex(list(range(bins)))
        bin_data["bin_id"] = bin_data.index
        bin_data = bin_data.fillna(0)

        bin_data["left_edge"] = left_edges
        bin_data["width"] = bin_width

        return bin_data

    def describe_integer_1d(df, column, current_result, nrows):
        """Compute summary statistics for an integer column.

        Nulls are dropped before the aggregates and percentiles are computed.
        Percentiles use the exact Spark SQL ``percentile`` function.

        :param df:             Spark DataFrame holding ``column``.
        :param column:         Name of the column to describe.
        :param current_result: Mapping providing the column's ``count``, used
                               as the denominator of ``mad``.
        :param nrows:          Row count used as the denominator of ``p_zeros``.

        :returns: A pandas Series named ``column`` with mean, min, max,
                  variance, kurtosis, std, skewness, sum, the 5/25/50/75/95
                  percentiles, range, iqr, cv, mad, n_zeros, p_zeros and
                  ``type`` set to ``"NUM"``.
        """
        non_null = df.select(column).na.drop()

        stats_df = non_null.agg(
            mean(col(column)).alias("mean"),
            df_min(col(column)).alias("min"),
            df_max(col(column)).alias("max"),
            variance(col(column)).alias("variance"),
            kurtosis(col(column)).alias("kurtosis"),
            stddev(col(column)).alias("std"),
            skewness(col(column)).alias("skewness"),
            df_sum(col(column)).alias("sum"),
        ).toPandas()

        for x in np.array([0.05, 0.25, 0.5, 0.75, 0.95]):
            stats_df[pretty_name(x)] = (
                non_null.selectExpr(f"percentile(`{column}`,CAST({x} AS DOUBLE))")
                .toPandas()
                .iloc[:, 0]
            )

        stats = stats_df.iloc[0].copy()
        stats.name = column
        stats["range"] = stats["max"] - stats["min"]
        stats["iqr"] = stats[pretty_name(0.75)] - stats[pretty_name(0.25)]
        stats["cv"] = stats["std"] / float(stats["mean"])

        # Mean absolute deviation: sum of |value - mean| over the non-null count.
        abs_deviation_sum = (
            non_null.select(df_abs(col(column) - stats["mean"]).alias("delta"))
            .agg(df_sum(col("delta")))
            .toPandas()
            .iloc[0, 0]
        )
        stats["mad"] = abs_deviation_sum / float(current_result["count"])

        stats["type"] = "NUM"
        stats["n_zeros"] = df.select(column).where(col(column) == 0.0).count()
        stats["p_zeros"] = stats["n_zeros"] / float(nrows)

        # NOTE: the histogram that used to be built here was never read, so the
        # extra Spark job it triggered for every numeric column was dropped.
        return stats

    def describe_float_1d(df, column, current_result, nrows):
        """Compute summary statistics for a floating-point column.

        Nulls are dropped before the aggregates and percentiles are computed.
        Percentiles use the Spark SQL ``percentile_approx`` function.

        :param df:             Spark DataFrame holding ``column``.
        :param column:         Name of the column to describe.
        :param current_result: Mapping providing the column's ``count``, used
                               as the denominator of ``mad``.
        :param nrows:          Row count used as the denominator of ``p_zeros``.

        :returns: A pandas Series named ``column`` with mean, min, max,
                  variance, kurtosis, std, skewness, sum, the 5/25/50/75/95
                  percentiles, range, iqr, cv, mad, n_zeros, p_zeros and
                  ``type`` set to ``"NUM"``.
        """
        non_null = df.select(column).na.drop()

        stats_df = non_null.agg(
            mean(col(column)).alias("mean"),
            df_min(col(column)).alias("min"),
            df_max(col(column)).alias("max"),
            variance(col(column)).alias("variance"),
            kurtosis(col(column)).alias("kurtosis"),
            stddev(col(column)).alias("std"),
            skewness(col(column)).alias("skewness"),
            df_sum(col(column)).alias("sum"),
        ).toPandas()

        for x in np.array([0.05, 0.25, 0.5, 0.75, 0.95]):
            stats_df[pretty_name(x)] = (
                non_null.selectExpr(
                    f"percentile_approx(`{column}`,CAST({x} AS DOUBLE))"
                )
                .toPandas()
                .iloc[:, 0]
            )

        stats = stats_df.iloc[0].copy()
        stats.name = column
        stats["range"] = stats["max"] - stats["min"]
        stats["iqr"] = stats[pretty_name(0.75)] - stats[pretty_name(0.25)]
        stats["cv"] = stats["std"] / float(stats["mean"])

        # Mean absolute deviation: sum of |value - mean| over the non-null count.
        abs_deviation_sum = (
            non_null.select(df_abs(col(column) - stats["mean"]).alias("delta"))
            .agg(df_sum(col("delta")))
            .toPandas()
            .iloc[0, 0]
        )
        stats["mad"] = abs_deviation_sum / float(current_result["count"])

        stats["type"] = "NUM"
        stats["n_zeros"] = df.select(column).where(col(column) == 0.0).count()
        stats["p_zeros"] = stats["n_zeros"] / float(nrows)

        # NOTE: the histogram that used to be built here was never read, so the
        # extra Spark job it triggered for every numeric column was dropped.
        return stats

    def describe_date_1d(df, column):
        """Return the min/max summary of a date-like column.

        Nulls are ignored. When the bounds come back as pandas Timestamps
        they are stored as strings and no range is reported; otherwise the
        range ``max - min`` is added. ``type`` is always ``"DATE"``.
        """
        earliest = df_min(col(column)).alias("min")
        latest = df_max(col(column)).alias("max")
        aggregated = df.select(column).na.drop().agg(earliest, latest)

        stats = aggregated.toPandas().iloc[0].copy()
        stats.name = column

        if not isinstance(stats["max"], pd.Timestamp):
            stats["range"] = stats["max"] - stats["min"]
        else:
            # Switch to object dtype so the bounds can be replaced by strings.
            stats = stats.astype(object)
            stats["max"] = str(stats["max"].to_pydatetime())
            stats["min"] = str(stats["min"].to_pydatetime())

        stats["type"] = "DATE"
        return stats

    def guess_json_type(string_value):
        """Return the Python type that ``string_value`` decodes to as JSON.

        :param string_value: Candidate JSON document.

        :returns: ``type`` of the decoded object (e.g. ``dict``, ``list``,
                  ``int``), or ``None`` when the value cannot be decoded.
        """
        try:
            obj = json.loads(string_value)
        except (TypeError, ValueError, RecursionError):
            # TypeError: not a str/bytes value; ValueError: malformed JSON
            # (JSONDecodeError, UnicodeDecodeError); RecursionError: nesting
            # too deep. Anything else (e.g. KeyboardInterrupt) must propagate.
            return None

        return type(obj)

    def describe_categorical_1d(df, column):
        """Compute summary statistics for a categorical column.

        Nulls are ignored. The per-value counts are cached while they are
        queried and released afterwards, even if one of the queries fails.

        :param df:     Spark DataFrame holding ``column``.
        :param column: Name of the column to describe.

        :returns: A pandas Series with ``top`` (most frequent value), ``freq``
                  (its count), ``value_counts`` (counts of the 50 most
                  frequent values plus the ``***Other Values***`` and
                  ``***Other Values Distinct Count***`` entries), ``type`` set
                  to ``"CAT"`` and ``unparsed_json_types`` (distinct Python
                  types the values decode to as JSON).
        """
        count_col = f"count({column})"
        non_null = df.select(column).na.drop()

        value_counts = (
            non_null.groupBy(column)
            .agg(count(col(column)))
            .orderBy(count_col, ascending=False)
        ).cache()
        try:
            stats = (
                value_counts.limit(1)
                .withColumnRenamed(column, "top")
                .withColumnRenamed(count_col, "freq")
                .toPandas()
                .iloc[0]
            )

            top_50 = (
                value_counts.limit(50)
                .toPandas()
                .sort_values(count_col, ascending=False)
            )
            top_50_categories = top_50[column].values.tolist()
            outside_top_50 = ~(col(column).isin(*top_50_categories))

            others_count = pd.Series(
                [non_null.where(outside_top_50).count()],
                index=["***Other Values***"],
            )
            others_distinct_count = pd.Series(
                [value_counts.where(outside_top_50).count()],
                index=["***Other Values Distinct Count***"],
            )
        finally:
            # Release the cached counts whether or not the queries succeeded.
            value_counts.unpersist()

        # Series.append was removed in pandas 2.0; pd.concat is the equivalent
        # that works on every pandas version.
        top = pd.concat(
            [
                top_50.set_index(column)[count_col],
                others_count,
                others_distinct_count,
            ]
        )
        stats["value_counts"] = top
        stats["type"] = "CAT"

        # guess_json_type returns None for non-JSON values; the filter drops them.
        unparsed_valid_jsons = (
            non_null.rdd.map(lambda x: guess_json_type(x[column]))
            .filter(lambda x: x)
            .distinct()
            .collect()
        )
        stats["unparsed_json_types"] = unparsed_valid_jsons
        return stats

    def describe_constant_1d(df, column):
        """Describe a column that holds at most one distinct value.

        :param df:     Spark DataFrame that contains ``column``.
        :param column: name of the column to profile.
        :returns: pandas Series with ``type`` set to ``"CONST"`` and the
                  ``value_counts`` of a single non-null sample row.
        """
        sample = df.select(column).na.drop().limit(1).toPandas()
        summary = pd.Series(["CONST"], index=["type"], name=column)
        summary["value_counts"] = sample.iloc[:, 0].value_counts()
        return summary

    def describe_unique_1d(df, column):
        """Describe a column whose values are all distinct.

        :param df:     Spark DataFrame that contains ``column``.
        :param column: name of the column to profile.
        :returns: pandas Series with ``type`` set to ``"UNIQUE"`` and the
                  ``value_counts`` of up to 50 non-null sample rows.
        """
        sample = df.select(column).na.drop().limit(50).toPandas()
        summary = pd.Series(["UNIQUE"], index=["type"], name=column)
        summary["value_counts"] = sample.iloc[:, 0].value_counts()
        return summary

    def describe_1d(df, column, nrows, lookup_config=None):
        """Compute the profile of a single column.

        Collects the generic counters (distinct, non-null, missing) and appends
        the statistics of the first matching specialised describer: constant,
        integer, float, date, unique or categorical.

        :param df:            Spark DataFrame that contains ``column``.
        :param column:        name of the column to profile.
        :param nrows:         total number of rows in ``df``.
        :param lookup_config: optional mapping with an ``"object"`` entry that
                              exposes ``lookup(df, col_name_in_db)`` and an
                              optional ``"col_name_in_db"`` entry.
        :returns: pandas Series holding the statistics of the column.
        :raises NotImplementedError: for array, struct and map columns.
        """
        column_type = df.select(column).dtypes[0][1]
        # "struct" was previously misspelled "stuct", so struct columns slipped
        # through this guard.
        if any(kind in column_type for kind in ("array", "struct", "map")):
            raise NotImplementedError(
                f"Column {column} is of type {column_type} and cannot be analyzed"
            )

        distinct_count = (
            df.select(column)
            .agg(countDistinct(col(column)).alias("distinct_count"))
            .toPandas()
        )
        non_nan_count = (
            df.select(column)
            .na.drop()
            .select(count(col(column)).alias("count"))
            .toPandas()
        )
        results_data = pd.concat([distinct_count, non_nan_count], axis=1)
        # Element-wise division; float(<Series>) is deprecated in recent pandas.
        results_data["p_unique"] = (
            results_data["distinct_count"] / results_data["count"]
        )
        results_data["is_unique"] = results_data["distinct_count"] == nrows
        results_data["n_missing"] = nrows - results_data["count"]
        results_data["p_missing"] = results_data["n_missing"] / float(nrows)
        results_data["p_infinite"] = 0
        results_data["n_infinite"] = 0
        result = results_data.iloc[0].copy()
        result["memorysize"] = 0
        result.name = column

        is_categorical = False
        if result["distinct_count"] <= 1:
            details = describe_constant_1d(df, column)
        elif column_type in {"tinyint", "smallint", "int", "bigint"}:
            details = describe_integer_1d(df, column, result, nrows)
        # NOTE(review): Spark usually reports decimals as "decimal(p,s)", so the
        # bare "decimal" entry presumably never matches -- confirm before
        # routing decimals to the float describer.
        elif column_type in {"float", "double", "decimal"}:
            details = describe_float_1d(df, column, result, nrows)
        elif column_type in {"date", "timestamp"}:
            details = describe_date_1d(df, column)
        elif result["is_unique"]:
            details = describe_unique_1d(df, column)
        else:
            details = describe_categorical_1d(df, column)
            is_categorical = True

        # pd.concat replaces Series.append, which was removed in pandas 2.0.
        result = pd.concat([result, details])
        if is_categorical and result["n_missing"] > 0:
            # Missing values count as one extra category.
            result["distinct_count"] = result["distinct_count"] + 1

        if result["count"] > result["distinct_count"] > 1:
            try:
                result["mode"] = result["top"]
            except KeyError:
                result["mode"] = 0
        else:
            try:
                result["mode"] = result["value_counts"].index[0]
            except KeyError:
                result["mode"] = 0
            except IndexError:
                result["mode"] = "MISSING"

        if lookup_config:
            lookup_object = lookup_config["object"]
            col_name_in_db = lookup_config.get("col_name_in_db")
            try:
                matched, _unmatched = lookup_object.lookup(
                    df.select(column), col_name_in_db
                )
                result["lookedup_values"] = (
                    f"{matched.count()}/{df.select(column).count()}"
                )
            except Exception:
                # The lookup is best effort: record the failure, keep profiling.
                result["lookedup_values"] = "FAILED"
        else:
            result["lookedup_values"] = ""

        return result

    ldesc = {}
    for colum in df.columns:
        if colum in config:
            if "lookup" in config[colum]:
                lookup_config = config[colum]["lookup"]
                desc = describe_1d(
                    df, colum, table_stats["n"], lookup_config=lookup_config
                )
            else:
                desc = describe_1d(df, colum, table_stats["n"])
        else:
            desc = describe_1d(df, colum, table_stats["n"])
        ldesc.update({colum: desc})

    if corr_reject is not None:
        computable_corrs = [colum for colum in ldesc if ldesc[colum]["type"] in {"NUM"}]

        if len(computable_corrs) > 0:
            corr = corr_matrix(df, columns=computable_corrs)
            for x, corr_x in corr.iterrows():
                for y, corr in corr_x.iteritems():
                    if x == y:
                        break

    variable_stats = pd.DataFrame(ldesc)

    table_stats["nvar"] = len(df.columns)
    table_stats["total_missing"] = float(variable_stats.loc["n_missing"].sum()) / (
        table_stats["n"] * table_stats["nvar"]
    )
    memsize = 0
    table_stats["memsize"] = fmt_bytesize(memsize)
    table_stats["recordsize"] = fmt_bytesize(memsize / table_stats["n"])
    table_stats.update(
        {k: 0 for k in ("NUM", "DATE", "CONST", "CAT", "UNIQUE", "CORR")}
    )
    table_stats.update(dict(variable_stats.loc["type"].value_counts()))
    table_stats["REJECTED"] = table_stats["CONST"] + table_stats["CORR"]

    freq_dict = {}
    for var in variable_stats:
        if "value_counts" not in variable_stats[var]:
            pass
        elif variable_stats[var]["value_counts"] is not np.nan:
            freq_dict[var] = variable_stats[var]["value_counts"]
        else:
            pass
    try:
        variable_stats = variable_stats.drop("value_counts")
    except (ValueError, KeyError):
        pass

    return table_stats, variable_stats.T, freq_dict


# Skewness magnitude above which ``fmt_skewness`` flags a value as "alert".
SKEWNESS_CUTOFF = 20
# Key under which ``value_formatters`` stores the generic float formatter.
DEFAULT_FLOAT_FORMATTER = "spark_df_profiling.__default_float_formatter"


def gradient_format(value, limit1, limit2, c1, c2):
    """Render ``value`` in a colour interpolated linearly between two colours.

    :param value:  value to display; also selects the position on the gradient.
    :param limit1: value mapped to colour ``c1``.
    :param limit2: value mapped to colour ``c2``.
    :param c1:     first colour, indexable as ``(r, g, b)``.
    :param c2:     second colour, indexable as ``(r, g, b)``.
    :returns: HTML span produced by ``fmt_color``.
    """
    ratio = (value - limit1) / (limit2 - limit1)
    # Interpolate each of the three channels and truncate to an int.
    blended = tuple(
        int(c1[channel] + (c2[channel] - c1[channel]) * ratio)
        for channel in range(3)
    )
    return fmt_color(value, "rgb" + str(blended))


def fmt_color(text, color):
    """Wrap ``text`` in an HTML span with an inline CSS ``color`` style.

    :param text:  content of the span; converted with ``str``.
    :param color: CSS colour value inserted into the ``style`` attribute.
    :returns: the HTML snippet as a string.
    """
    template = '<span style="color:{}">{}</span>'
    return template.format(color, str(text))


def fmt_class(text, cls):
    """Wrap ``text`` in an HTML span carrying the CSS class ``cls``.

    :param text: content of the span; converted with ``str``.
    :param cls:  value of the ``class`` attribute.
    :returns: the HTML snippet as a string.
    """
    template = '<span class="{}">{}</span>'
    return template.format(cls, str(text))


def fmt_bytesize(num, suffix="B"):
    """Format a byte count with binary (1024-based) unit prefixes.

    The size check used to be nested under ``if num < 0``, so every
    non-negative size skipped all units and was reported in "Yi"
    (``fmt_bytesize(0)`` gave ``"0.0 YiB"``).

    :param num:    size to format.  Only the magnitude is used, so a negative
                   value is rendered without its sign.
    :param suffix: text appended after the unit prefix.
    :returns: a string such as ``"0.0 B"`` or ``"1.5 KiB"``; sizes of
              ``1024 ** 8`` and above are expressed in ``Yi``.
    """
    size = -num if num < 0 else num
    for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"):
        if size < 1024.0:
            return f"{size:3.1f} {unit}{suffix}"
        size /= 1024.0
    return f"{size:.1f} Yi{suffix}"


def fmt_percent(v):
    """Format the fraction ``v`` as a percentage with one decimal place.

    :param v: fraction to format, e.g. ``0.5`` for fifty percent.
    :returns: a string such as ``"50.0%"``.
    """
    percentage = v * 100
    return format(percentage, "2.1f") + "%"


def fmt_varname(v):
    """Wrap a variable name in an HTML ``<code>`` element.

    :param v: value to display.
    :returns: the HTML snippet as a string.
    """
    return "<code>{}</code>".format(v)


# Maps a statistic name to the callable that renders its value for display.
value_formatters = {
    # Frequency of the top value, coloured along a 0..62000 gradient.
    "freq": lambda freq: gradient_format(
        freq, 0, 62000, (30, 198, 244), (99, 200, 72)
    ),
    # Ratios that are shown as percentages.
    **dict.fromkeys(("p_missing", "p_infinite", "p_unique", "p_zeros"), fmt_percent),
    "memorysize": fmt_bytesize,
    "total_missing": fmt_percent,
    # Five significant digits with trailing zeros and a dangling dot removed.
    DEFAULT_FLOAT_FORMATTER: lambda number: (
        str(float(format(number, ".5g"))).rstrip("0").rstrip(".")
    ),
    "correlation_var": lambda name: fmt_varname(name),
    "unparsed_json_types": lambda types: ", ".join(kind.__name__ for kind in types),
}


def fmt_row_severity(v):
    """Classify a ratio as worth highlighting or not.

    :param v: numeric value to classify.
    :returns: ``"ignore"`` when ``v`` is NaN or at most ``0.01``, otherwise
              ``"alert"``.
    """
    negligible = np.isnan(v) or v <= 0.01
    return "ignore" if negligible else "alert"


def fmt_skewness(v):
    """Flag skewness values whose magnitude exceeds ``SKEWNESS_CUTOFF``.

    :param v: skewness value to classify.
    :returns: ``"alert"`` when ``v`` is a number outside
              ``[-SKEWNESS_CUTOFF, SKEWNESS_CUTOFF]``, otherwise ``""``.
    """
    if np.isnan(v):
        return ""
    if -SKEWNESS_CUTOFF <= v <= SKEWNESS_CUTOFF:
        return ""
    return "alert"


# Maps a statistic name to the callable that picks the severity label of its
# row.
row_formatters = dict.fromkeys(
    ("p_zeros", "p_missing", "p_infinite", "n_duplicates"), fmt_row_severity
)
row_formatters["skewness"] = fmt_skewness

# Module-level side effect: executes the v3io Spark operator shell script as
# soon as this module is loaded.  The result of ``run`` is not inspected, so a
# failing script goes unnoticed here.
# NOTE(review): ``run`` is presumably ``subprocess.run`` -- confirm against the
# imports at the top of the file.
run(["/bin/bash", "/etc/config/v3io/v3io-spark-operator.sh"])


def describe_spark(
    context: MLClientCtx,
    dataset: DataItem,
    artifact_path,
    bins: int = 30,
    describe_extended: bool = True,
):
    """Profile a CSV dataset with Spark and log the summary to the context.

    The dataset is read with header and schema inference enabled.  The summary
    is logged as the ``summary_stats`` CSV dataset; in extended mode the
    table-level statistics are logged as run results as well.

    :param context:           function execution context used for logging.
    :param dataset:           data item pointing at the CSV file to profile.
    :param artifact_path:     accepted for interface compatibility; not used by
                              this function.
    :param bins:              number of histogram bins passed to ``describe``.
    :param describe_extended: when ``True`` run the full ``describe`` profile,
                              otherwise use Spark's built-in
                              ``DataFrame.describe``.
    """
    location = dataset.local()

    spark = SparkSession.builder.appName("Spark job").getOrCreate()

    # The session is stopped in ``finally`` so a failure while reading,
    # profiling or logging no longer leaves it running.
    try:
        df = spark.read.csv(location, header=True, inferSchema=True)

        table = None
        # Kept as an explicit comparison so that non-boolean values behave
        # exactly as before.
        if describe_extended == True:  # noqa: E712
            float_cols = [
                name
                for name, dtype in df.dtypes
                if dtype.startswith(("float", "double"))
            ]
            # No per-column configuration is supplied; an empty mapping
            # disables the lookup step in ``describe``.
            table, variables, freq = describe(df, bins, float_cols, {})

            summary = variables.reset_index()
            if freq:
                # One row per (column, value, count), folded into a tuple of
                # single-entry dicts per column.
                tbl_2 = (
                    pd.DataFrame.from_dict(freq, orient="index")
                    .sort_index()
                    .stack()
                    .reset_index()
                )
                tbl_2.columns = ["col", "key", "val"]
                tbl_2["Merged"] = [
                    {key: val} for key, val in zip(tbl_2.key, tbl_2.val)
                ]
                tbl_2 = tbl_2.groupby("col", as_index=False).agg(
                    lambda x: tuple(x)
                )[["col", "Merged"]]

                summary = pd.merge(
                    summary, tbl_2, how="left", left_on="index", right_on="col"
                )
        else:
            summary = df.describe().toPandas().T

        context.log_dataset(
            "summary_stats",
            df=summary,
            format="csv",
            index=False,
            artifact_path=context.artifact_subpath("data"),
        )

        if table is not None:
            context.log_results(table)
    finally:
        spark.stop()
 + code_origin: '' + filename: describe_spark.py entry_points: describe: - name: describe - doc: '' parameters: - name: df - default: '' - name: bins - default: '' - name: corr_reject - default: '' - name: config - default: '' - outputs: - - default: '' - lineno: 38 - pretty_name: - name: pretty_name + name: describe doc: '' + has_kwargs: true + has_varargs: false + lineno: 58 + pretty_name: parameters: - name: x - default: '' - outputs: - - default: '' - lineno: 51 - corr_matrix: - name: corr_matrix + name: pretty_name doc: '' + has_kwargs: false + has_varargs: false + lineno: 71 + corr_matrix: parameters: - name: df - default: '' - name: columns default: null - outputs: - - default: '' - lineno: 58 - separate: - name: separate + name: corr_matrix doc: '' + has_kwargs: false + has_varargs: false + lineno: 78 + separate: parameters: - name: l - default: '' - name: n - default: '' - outputs: - - default: '' - lineno: 63 - create_hist_data: - name: create_hist_data + name: separate doc: '' + has_kwargs: false + has_varargs: false + lineno: 83 + create_hist_data: parameters: - name: df - default: '' - name: column - default: '' - name: minim - default: '' - name: maxim - default: '' - name: bins default: 10 - outputs: - - default: '' - lineno: 80 + name: create_hist_data + doc: '' + has_kwargs: false + has_varargs: false + lineno: 100 create_all_conditions: - name: create_all_conditions - doc: 'Recursive function that exploits the - - ability to call the Spark SQL Column method - - .when() in a recursive way.' parameters: - name: current_col - default: '' - name: column - default: '' - name: left_edges - default: '' - name: count default: 1 - outputs: - - default: '' - lineno: 82 + name: create_all_conditions + doc: 'Recursive function that exploits the + + ability to call the Spark SQL Column method + + .when() in a recursive way.' 
+ has_kwargs: false + has_varargs: false + lineno: 101 describe_integer_1d: - name: describe_integer_1d - doc: '' parameters: - name: df - default: '' - name: column - default: '' - name: current_result - default: '' - name: nrows - default: '' - outputs: - - default: '' - lineno: 134 - describe_float_1d: - name: describe_float_1d + name: describe_integer_1d doc: '' + has_kwargs: false + has_varargs: false + lineno: 159 + describe_float_1d: parameters: - name: df - default: '' - name: column - default: '' - name: current_result - default: '' - name: nrows - default: '' - outputs: - - default: '' - lineno: 170 - describe_date_1d: - name: describe_date_1d + name: describe_float_1d doc: '' + has_kwargs: false + has_varargs: false + lineno: 202 + describe_date_1d: parameters: - name: df - default: '' - name: column - default: '' - outputs: - - default: '' - lineno: 204 - guess_json_type: - name: guess_json_type + name: describe_date_1d doc: '' + has_kwargs: false + has_varargs: false + lineno: 245 + guess_json_type: parameters: - name: string_value - default: '' - outputs: - - default: '' - lineno: 221 + name: guess_json_type + doc: '' + has_kwargs: false + has_varargs: false + lineno: 265 describe_categorical_1d: + parameters: + - name: df + - name: column name: describe_categorical_1d doc: '' + has_kwargs: false + has_varargs: false + lineno: 273 + describe_constant_1d: parameters: - name: df - default: '' - name: column - default: '' - outputs: - - default: '' - lineno: 229 - describe_constant_1d: name: describe_constant_1d doc: '' + has_kwargs: false + has_varargs: false + lineno: 330 + describe_unique_1d: parameters: - name: df - default: '' - name: column - default: '' - outputs: - - default: '' - lineno: 267 - describe_unique_1d: name: describe_unique_1d doc: '' - parameters: - - name: df - default: '' - - name: column - default: '' - outputs: - - default: '' - lineno: 274 + has_kwargs: false + has_varargs: false + lineno: 337 describe_1d: - name: describe_1d - 
doc: '' parameters: - name: df - default: '' - name: column - default: '' - name: nrows - default: '' - name: lookup_config default: null - outputs: - - default: '' - lineno: 281 - gradient_format: - name: gradient_format + name: describe_1d doc: '' + has_kwargs: false + has_varargs: false + lineno: 344 + gradient_format: parameters: - name: value - default: '' - name: limit1 - default: '' - name: limit2 - default: '' - name: c1 - default: '' - name: c2 - default: '' - outputs: - - default: '' - lineno: 396 - LerpColour: - name: LerpColour + name: gradient_format doc: '' + has_kwargs: false + has_varargs: false + lineno: 487 + LerpColour: parameters: - name: c1 - default: '' - name: c2 - default: '' - name: t - default: '' - outputs: - - default: '' - lineno: 397 - fmt_color: - name: fmt_color + name: LerpColour doc: '' + has_kwargs: false + has_varargs: false + lineno: 488 + fmt_color: parameters: - name: text - default: '' - name: color - default: '' - outputs: - - default: '' - lineno: 403 - fmt_class: - name: fmt_class + name: fmt_color doc: '' + has_kwargs: false + has_varargs: false + lineno: 499 + fmt_class: parameters: - name: text - default: '' - name: cls - default: '' - outputs: - - default: '' - lineno: 407 - fmt_bytesize: - name: fmt_bytesize + name: fmt_class doc: '' + has_kwargs: false + has_varargs: false + lineno: 503 + fmt_bytesize: parameters: - name: num - default: '' - name: suffix default: B - outputs: - - default: '' - lineno: 411 + name: fmt_bytesize + doc: '' + has_kwargs: false + has_varargs: false + lineno: 507 fmt_percent: + parameters: + - name: v name: fmt_percent doc: '' + has_kwargs: false + has_varargs: false + lineno: 517 + fmt_varname: parameters: - name: v - default: '' - outputs: - - default: '' - lineno: 421 - fmt_varname: name: fmt_varname doc: '' + has_kwargs: false + has_varargs: false + lineno: 521 + fmt_row_severity: parameters: - name: v - default: '' - outputs: - - default: '' - lineno: 424 - fmt_row_severity: name: 
fmt_row_severity doc: '' + has_kwargs: false + has_varargs: false + lineno: 539 + fmt_skewness: parameters: - name: v - default: '' - outputs: - - default: '' - lineno: 441 - fmt_skewness: name: fmt_skewness doc: '' - parameters: - - name: v - default: '' - outputs: - - default: '' - lineno: 447 + has_kwargs: false + has_varargs: false + lineno: 546 describe_spark: - name: describe_spark - doc: '' parameters: - name: context type: MLClientCtx - default: '' - name: dataset type: DataItem - default: '' - name: artifact_path - default: '' - name: bins type: int default: 30 - name: describe_extended type: bool default: true - outputs: - - default: '' - lineno: 463 + name: describe_spark + doc: '' + has_kwargs: false + has_varargs: false + lineno: 564 + command: '' description: '' - build: - functionSourceCode: # Generated by nuclio.export.NuclioExporter

import mlrun
from mlrun.platforms.iguazio import mount_v3io, mount_v3iod
from mlrun.datastore import DataItem
from mlrun.execution import MLClientCtx

import os
from subprocess import run
import pandas as pd
import numpy as np

from pyspark.sql.types import LongType
from pyspark.sql import SparkSession

import sys
import base64 as b64
import warnings
warnings.filterwarnings("ignore")

from itertools import product
import matplotlib

import numpy as np
import json
import pandas as pd
from matplotlib import pyplot as plt
from pkg_resources import resource_filename
import six
from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql.functions import (abs as df_abs, col, count, countDistinct,
                                   max as df_max, mean, min as df_min,
                                   sum as df_sum, when
                                   )
from pyspark.sql.functions import variance, stddev, kurtosis, skewness


def describe(df, bins, corr_reject, config, **kwargs):
    """Profile a Spark DataFrame column by column.

    :param df:          Spark DataFrame to profile; it must hold at least one
                        row.
    :param bins:        number of histogram bins used for numeric columns.
    :param corr_reject: when not ``None`` the correlation matrix of the numeric
                        columns is computed.
    :param config:      per-column configuration; ``config[column]["lookup"]``
                        enables the lookup step for that column.
    :param kwargs:      accepted and ignored.
    :returns: ``(table_stats, variable_stats, freq_dict)``: a dict of
              table-level statistics, a pandas DataFrame with one row per
              column, and a dict mapping column name to its value counts.
    :raises TypeError:  when ``df`` is not a Spark DataFrame.
    :raises ValueError: when ``df`` has no rows.
    """
    if not isinstance(df, SparkDataFrame):
        raise TypeError("df must be of type pyspark.sql.DataFrame")

    table_stats = {"n": df.count()}
    if table_stats["n"] == 0:
        raise ValueError("df cannot be empty")

    try:
        matplotlib.style.use("default")
    except Exception:
        # Styling is best effort; profiling does not depend on it.
        pass

    def pretty_name(x):
        """Render a quantile fraction as a percentage label, e.g. 0.25 -> '25%'."""
        x *= 100
        if x == int(x):
            return '%.0f%%' % x
        return '%.1f%%' % x

    def corr_matrix(df, columns=None):
        """Return the pairwise correlation of ``columns`` as a pandas DataFrame.

        Rows containing a null in any of the selected columns are dropped first.
        """
        if columns is None:
            columns = df.columns
        df_cleaned = df.select(*columns).na.drop(how="any")
        values = [
            [df_cleaned.corr(str(left), str(right)) for right in columns]
            for left in columns
        ]
        return pd.DataFrame(values, index=columns, columns=columns)

    def create_hist_data(df, column, minim, maxim, bins=10):
        """Count the non-null values of ``column`` in ``bins`` equal-width bins.

        :returns: pandas DataFrame indexed by bin id with the columns
                  ``bin_id``, ``count``, ``left_edge`` and ``width``.
        """

        def create_all_conditions(current_col, column, left_edges, count=1):
            """
            Recursive function that exploits the
            ability to call the Spark SQL Column method
            .when() in a recursive way.
            """
            left_edges = left_edges[:]
            if len(left_edges) == 0:
                return current_col
            if len(left_edges) == 1:
                # The last bin is closed on the right so the maximum is counted.
                next_col = current_col.when(col(column) >= float(left_edges[0]), count)
                left_edges.pop(0)
                return create_all_conditions(next_col, column, left_edges[:], count + 1)
            next_col = current_col.when((float(left_edges[0]) <= col(column))
                                        & (col(column) < float(left_edges[1])), count)
            left_edges.pop(0)
            return create_all_conditions(next_col, column, left_edges[:], count + 1)

        num_range = maxim - minim
        bin_width = num_range / float(bins)
        left_edges = [minim]
        for _bin in range(bins):
            left_edges = left_edges + [left_edges[-1] + bin_width]
        # Drop the right edge of the last bin; only left edges are kept.
        left_edges.pop()
        expression_col = when((float(left_edges[0]) <= col(column))
                              & (col(column) < float(left_edges[1])), 0)
        left_edges_copy = left_edges[:]
        left_edges_copy.pop(0)
        bin_data = (df.select(col(column))
                    .na.drop()
                    .select(col(column),
                            create_all_conditions(expression_col,
                                                  column,
                                                  left_edges_copy
                                                  ).alias("bin_id")
                            )
                    .groupBy("bin_id").count()
                    ).toPandas()

        bin_data.index = bin_data["bin_id"]
        new_index = list(range(bins))
        bin_data = bin_data.reindex(new_index)
        bin_data["bin_id"] = bin_data.index
        # Bins without any value are absent from the Spark result.
        bin_data = bin_data.fillna(0)

        bin_data["left_edge"] = left_edges
        bin_data["width"] = bin_width

        return bin_data

    def _describe_numeric_1d(df, column, current_result, nrows, percentile_function):
        """Shared implementation of the integer and float describers.

        ``percentile_function`` is the Spark SQL function used for the
        quantiles (``percentile`` or ``percentile_approx``).
        """
        non_null = df.select(column).na.drop()
        stats_df = non_null.agg(mean(col(column)).alias("mean"),
                                df_min(col(column)).alias("min"),
                                df_max(col(column)).alias("max"),
                                variance(col(column)).alias("variance"),
                                kurtosis(col(column)).alias("kurtosis"),
                                stddev(col(column)).alias("std"),
                                skewness(col(column)).alias("skewness"),
                                df_sum(col(column)).alias("sum")
                                ).toPandas()

        for x in np.array([0.05, 0.25, 0.5, 0.75, 0.95]):
            expression = "{fn}(`{col}`,CAST({n} AS DOUBLE))".format(
                fn=percentile_function, col=column, n=x)
            stats_df[pretty_name(x)] = non_null.selectExpr(expression).toPandas().iloc[:, 0]

        stats = stats_df.iloc[0].copy()
        stats.name = column
        stats["range"] = stats["max"] - stats["min"]
        stats["iqr"] = stats[pretty_name(0.75)] - stats[pretty_name(0.25)]
        stats["cv"] = stats["std"] / float(stats["mean"])
        stats["mad"] = (non_null
                        .select(df_abs(col(column) - stats["mean"]).alias("delta"))
                        .agg(df_sum(col("delta"))).toPandas().iloc[0, 0]
                        / float(current_result["count"]))
        stats["type"] = "NUM"
        stats['n_zeros'] = df.select(column).where(col(column) == 0.0).count()
        stats['p_zeros'] = stats['n_zeros'] / float(nrows)

        # NOTE(review): the histogram is computed but its result is not used;
        # the call is kept so behaviour is unchanged -- consider returning the
        # histogram or dropping the call.
        create_hist_data(df, column, stats["min"], stats["max"], bins)

        return stats

    def describe_integer_1d(df, column, current_result, nrows):
        """Describe an integer column using exact percentiles."""
        return _describe_numeric_1d(df, column, current_result, nrows, "percentile")

    def describe_float_1d(df, column, current_result, nrows):
        """Describe a floating-point column using approximate percentiles."""
        return _describe_numeric_1d(df, column, current_result, nrows, "percentile_approx")

    def describe_date_1d(df, column):
        """Describe a date/timestamp column by its minimum and maximum."""
        stats_df = df.select(column).na.drop().agg(df_min(col(column)).alias("min"),
                                                   df_max(col(column)).alias("max")
                                                   ).toPandas()
        stats = stats_df.iloc[0].copy()
        stats.name = column

        if isinstance(stats["max"], pd.Timestamp):
            # Timestamps are stored as strings; no range is computed for them.
            stats = stats.astype(object)
            stats["max"] = str(stats["max"].to_pydatetime())
            stats["min"] = str(stats["min"].to_pydatetime())
        else:
            stats["range"] = stats["max"] - stats["min"]
        stats["type"] = "DATE"
        return stats

    def guess_json_type(string_value):
        """Return the type ``string_value`` decodes to as JSON, or None."""
        try:
            obj = json.loads(string_value)
        except (TypeError, ValueError, RecursionError):
            # Not str/bytes, malformed JSON, or nested too deeply.
            return None

        return type(obj)

    def describe_categorical_1d(df, column):
        """Describe a categorical column: top value, frequencies, JSON types."""
        count_col = "count({c})".format(c=column)
        non_null = df.select(column).na.drop()
        value_counts = (non_null
                        .groupBy(column)
                        .agg(count(col(column)))
                        .orderBy(count_col, ascending=False)
                        ).cache()

        # Release the cached frame even when one of the Spark actions fails.
        try:
            stats = (value_counts
                     .limit(1)
                     .withColumnRenamed(column, "top")
                     .withColumnRenamed(count_col, "freq")
                     ).toPandas().iloc[0]

            top_50 = value_counts.limit(50).toPandas().sort_values(count_col,
                                                                   ascending=False)
            top_50_categories = top_50[column].values.tolist()
            outside_top_50 = ~(col(column).isin(*top_50_categories))

            others_count = pd.Series([non_null.where(outside_top_50).count()],
                                     index=["***Other Values***"])
            others_distinct_count = pd.Series([value_counts.where(outside_top_50).count()],
                                              index=["***Other Values Distinct Count***"])
        finally:
            value_counts.unpersist()

        # pd.concat replaces Series.append, which was removed in pandas 2.0.
        stats["value_counts"] = pd.concat(
            [top_50.set_index(column)[count_col], others_count, others_distinct_count])
        stats["type"] = "CAT"
        stats["unparsed_json_types"] = non_null.rdd.map(
            lambda row: guess_json_type(row[column])).filter(
            lambda json_type: json_type).distinct().collect()
        return stats

    def describe_constant_1d(df, column):
        """Describe a column that holds at most one distinct value."""
        stats = pd.Series(['CONST'], index=['type'], name=column)
        stats["value_counts"] = (df.select(column)
                                 .na.drop()
                                 .limit(1)).toPandas().iloc[:, 0].value_counts()
        return stats

    def describe_unique_1d(df, column):
        """Describe a column whose values are all distinct (50-row sample)."""
        stats = pd.Series(['UNIQUE'], index=['type'], name=column)
        stats["value_counts"] = (df.select(column)
                                 .na.drop()
                                 .limit(50)).toPandas().iloc[:, 0].value_counts()
        return stats

    def describe_1d(df, column, nrows, lookup_config=None):
        """Compute the generic counters of a column plus its type-specific stats.

        :raises NotImplementedError: for array, struct and map columns.
        """
        column_type = df.select(column).dtypes[0][1]
        # "struct" was previously misspelled "stuct", so struct columns slipped
        # through this guard.
        if any(kind in column_type for kind in ("array", "struct", "map")):
            raise NotImplementedError(
                "Column {c} is of type {t} and cannot be analyzed".format(c=column, t=column_type))

        distinct_count = df.select(column).agg(
            countDistinct(col(column)).alias("distinct_count")).toPandas()
        non_nan_count = df.select(column).na.drop().select(
            count(col(column)).alias("count")).toPandas()
        results_data = pd.concat([distinct_count, non_nan_count], axis=1)
        # Element-wise division; float(<Series>) is deprecated in recent pandas.
        results_data["p_unique"] = results_data["distinct_count"] / results_data["count"]
        results_data["is_unique"] = results_data["distinct_count"] == nrows
        results_data["n_missing"] = nrows - results_data["count"]
        results_data["p_missing"] = results_data["n_missing"] / float(nrows)
        results_data["p_infinite"] = 0
        results_data["n_infinite"] = 0
        result = results_data.iloc[0].copy()
        result["memorysize"] = 0
        result.name = column

        is_categorical = False
        if result["distinct_count"] <= 1:
            details = describe_constant_1d(df, column)
        elif column_type in {"tinyint", "smallint", "int", "bigint"}:
            details = describe_integer_1d(df, column, result, nrows)
        # NOTE(review): Spark usually reports decimals as "decimal(p,s)", so the
        # bare "decimal" entry presumably never matches -- confirm.
        elif column_type in {"float", "double", "decimal"}:
            details = describe_float_1d(df, column, result, nrows)
        elif column_type in {"date", "timestamp"}:
            details = describe_date_1d(df, column)
        elif result["is_unique"]:
            details = describe_unique_1d(df, column)
        else:
            details = describe_categorical_1d(df, column)
            is_categorical = True

        # pd.concat replaces Series.append, which was removed in pandas 2.0.
        result = pd.concat([result, details])
        if is_categorical and result["n_missing"] > 0:
            # Missing values count as one extra category.
            result["distinct_count"] = result["distinct_count"] + 1

        if result["count"] > result["distinct_count"] > 1:
            try:
                result["mode"] = result["top"]
            except KeyError:
                result["mode"] = 0
        else:
            try:
                result["mode"] = result["value_counts"].index[0]
            except KeyError:
                result["mode"] = 0
            except IndexError:
                result["mode"] = "MISSING"

        if lookup_config:
            lookup_object = lookup_config['object']
            col_name_in_db = lookup_config.get('col_name_in_db')
            try:
                matched, _unmatched = lookup_object.lookup(df.select(column), col_name_in_db)
                result['lookedup_values'] = (
                    str(matched.count()) + "/" + str(df.select(column).count()))
            except Exception:
                # The lookup is best effort: record the failure, keep profiling.
                result['lookedup_values'] = 'FAILED'
        else:
            result['lookedup_values'] = ''

        return result

    ldesc = {}
    for colum in df.columns:
        lookup_config = None
        if colum in config and 'lookup' in config[colum]:
            lookup_config = config[colum]['lookup']
        ldesc[colum] = describe_1d(df, colum, table_stats["n"], lookup_config=lookup_config)

    if corr_reject is not None:
        computable_corrs = [colum for colum in ldesc if ldesc[colum]["type"] in {"NUM"}]

        if len(computable_corrs) > 0:
            # NOTE(review): the matrix is computed but never used.  The loop
            # that followed only iterated and broke, and relied on
            # Series.iteritems (removed in pandas 2.0), so it was dropped.
            corr_matrix(df, columns=computable_corrs)

    variable_stats = pd.DataFrame(ldesc)

    table_stats["nvar"] = len(df.columns)
    table_stats["total_missing"] = float(variable_stats.loc["n_missing"].sum()) / (
        table_stats["n"] * table_stats["nvar"])
    memsize = 0
    table_stats['memsize'] = fmt_bytesize(memsize)
    table_stats['recordsize'] = fmt_bytesize(memsize / table_stats['n'])
    table_stats.update({k: 0 for k in ("NUM", "DATE", "CONST", "CAT", "UNIQUE", "CORR")})
    table_stats.update(dict(variable_stats.loc['type'].value_counts()))
    table_stats['REJECTED'] = table_stats['CONST'] + table_stats['CORR']

    freq_dict = {}
    for var in variable_stats:
        if "value_counts" not in variable_stats[var]:
            continue
        # Columns without value counts hold NaN in the combined frame.
        if variable_stats[var]["value_counts"] is not np.nan:
            freq_dict[var] = variable_stats[var]["value_counts"]
    try:
        variable_stats = variable_stats.drop("value_counts")
    except (ValueError, KeyError):
        pass

    return table_stats, variable_stats.T, freq_dict

import numpy as np
from pyspark.sql.functions import abs as absou

# Skewness values beyond +/- this cutoff are flagged as "alert" by fmt_skewness.
SKEWNESS_CUTOFF = 20
# Key under which value_formatters registers its generic float formatter.
DEFAULT_FLOAT_FORMATTER = u'spark_df_profiling.__default_float_formatter'


def gradient_format(value, limit1, limit2, c1, c2):
    """Render ``value`` as an HTML span coloured along a linear gradient.

    The colour is interpolated per channel between ``c1`` (reached when
    ``value == limit1``) and ``c2`` (reached when ``value == limit2``).
    Values outside the limits extrapolate linearly; nothing is clamped.

    :param value:  number to display and to position on the gradient.
    :param limit1: value mapped to ``c1``.
    :param limit2: value mapped to ``c2``; ``limit1 == limit2`` raises
                   ``ZeroDivisionError``.
    :param c1:     colour at ``limit1``; its first three items are used.
    :param c2:     colour at ``limit2``; its first three items are used.
    :returns:      the span built by ``fmt_color`` with an ``rgb(...)`` colour.
    """
    fraction = (value - limit1) / (limit2 - limit1)
    blended = tuple(
        int(c1[channel] + (c2[channel] - c1[channel]) * fraction)
        for channel in range(3)
    )
    return fmt_color(value, "rgb{}".format(str(blended)))


def fmt_color(text, color):
    """Wrap ``text`` in an HTML span with an inline ``color`` style.

    :param text:  value to display; converted with ``str``.
    :param color: CSS colour inserted verbatim into the style attribute.
    :returns:     the HTML snippet as a string.
    """
    template = u'<span style="color:{color}">{text}</span>'
    rendered = str(text)
    return template.format(text=rendered, color=color)


def fmt_class(text, cls):
    """Wrap ``text`` in an HTML span carrying the CSS class ``cls``.

    :param text: value to display; converted with ``str``.
    :param cls:  class name inserted verbatim into the ``class`` attribute.
    :returns:    the HTML snippet as a string.
    """
    template = u'<span class="{cls}">{text}</span>'
    rendered = str(text)
    return template.format(text=rendered, cls=cls)


def fmt_bytesize(num, suffix='B'):
    """Format a byte count using binary (1024-based) unit prefixes.

    The value is divided by 1024 until its magnitude drops below 1024 and is
    then printed with one decimal and the matching prefix.

    :param num:    size to format; negative sizes are scaled the same way
                   and keep their sign.
    :param suffix: unit suffix appended after the prefix.
    :returns:      for example ``"0.0 B"``, ``"1.5 KiB"`` or ``"-2.0 KiB"``;
                   anything from 1024**8 upwards is reported in ``Yi``.
    """
    for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
        # The range check must run for every value, not only negative ones;
        # otherwise non-negative sizes skip all units and end up in "Yi".
        if -1024.0 < num < 1024.0:
            return "%3.1f %s%s" % (num, unit, suffix)
        num /= 1024.0
    return "%.1f %s%s" % (num, 'Yi', suffix)


def fmt_percent(v):
    """Format the fraction ``v`` as a percentage with one decimal place.

    :param v: fraction to format (``0.5`` means 50%).
    :returns: string such as ``"50.0%"``.
    """
    scaled = v * 100
    return format(scaled, "2.1f") + "%"

def fmt_varname(v):
    """Wrap a variable name in an HTML ``<code>`` element.

    :param v: name to display; rendered with its default string format.
    :returns: the HTML snippet as a string.
    """
    opening, closing = u'<code>', u'</code>'
    return opening + format(v) + closing


# Maps a statistic name to the one-argument callable that formats its value.
value_formatters = {
    # Frequencies are coloured along a gradient spanning 0..62000.
    u'freq': lambda v: gradient_format(
        v, 0, 62000, (30, 198, 244), (99, 200, 72)
    ),
    u'p_missing': fmt_percent,
    u'p_infinite': fmt_percent,
    u'p_unique': fmt_percent,
    u'p_zeros': fmt_percent,
    u'memorysize': fmt_bytesize,
    u'total_missing': fmt_percent,
    # Five significant digits, with trailing zeros and a dangling dot removed.
    DEFAULT_FLOAT_FORMATTER: lambda v: (
        str(float(format(v, '.5g'))).rstrip('0').rstrip('.')
    ),
    u'correlation_var': lambda v: fmt_varname(v),
    # Comma-separated names of the objects in ``v`` (each must have __name__).
    u'unparsed_json_types': lambda v: ', '.join(s.__name__ for s in v),
}

def fmt_row_severity(v):
    """Classify a row statistic as ``"ignore"`` or ``"alert"``.

    :param v: numeric value to classify.
    :returns: ``"ignore"`` when ``v`` is NaN or at most 0.01, else ``"alert"``.
    """
    negligible = np.isnan(v) or v <= 0.01
    return "ignore" if negligible else "alert"

def fmt_skewness(v):
    """Return ``"alert"`` for strongly skewed values and ``""`` otherwise.

    :param v: skewness value to classify.
    :returns: ``"alert"`` when ``v`` is not NaN and lies strictly outside
              ``[-SKEWNESS_CUTOFF, SKEWNESS_CUTOFF]``; an empty string
              for everything else, including NaN.
    """
    if np.isnan(v):
        return ""
    beyond_cutoff = v > SKEWNESS_CUTOFF or v < -SKEWNESS_CUTOFF
    return "alert" if beyond_cutoff else ""

# Maps a statistic name to the function that returns the label for its row:
# fmt_row_severity yields "ignore"/"alert", fmt_skewness yields "alert"/"".
row_formatters={
    u'p_zeros': fmt_row_severity,
    u'p_missing': fmt_row_severity,
    u'p_infinite': fmt_row_severity,
    u'n_duplicates': fmt_row_severity,
    u'skewness': fmt_skewness,
}

# Module-level side effect: executes the v3io Spark operator shell script as soon
# as this module is loaded, before describe_spark can be called.
# NOTE(review): `run` is presumably subprocess.run imported earlier in the file - confirm.
run(["/bin/bash", "/etc/config/v3io/v3io-spark-operator.sh"])

def describe_spark(context: MLClientCtx,
                   dataset: DataItem,
                   artifact_path,
                   bins: int=30,
                   describe_extended: bool=True):
    """Profile a CSV dataset with Spark and log the result as ``summary_stats``.

    The dataset's local copy is read as CSV (header row, inferred schema).
    The summary table is logged as a CSV dataset under the context's ``data``
    artifact sub-path. In extended mode the table-level statistics returned
    by ``describe`` are also logged as run results.

    :param context:           MLRun function context used for logging.
    :param dataset:           input data item holding the CSV data.
    :param artifact_path:     not used by this function; kept so existing
                              callers that pass it keep working.
    :param bins:              forwarded to ``describe`` (extended mode only).
    :param describe_extended: when equal to ``True`` run the full ``describe``
                              profile, otherwise log Spark's built-in
                              ``DataFrame.describe()`` output, transposed.
    """
    location = dataset.local()

    spark = SparkSession.builder.appName("Spark job").getOrCreate()

    # Always release the Spark session, including when reading, profiling or
    # logging fails part-way through.
    try:
        df = spark.read.csv(location, header=True, inferSchema=True)

        # Equality (not truthiness) is kept on purpose so that non-bool values
        # such as the string "False" keep selecting the basic summary.
        extended = describe_extended == True  # noqa: E712

        table = None
        if extended:
            float_cols = [
                name for name, dtype in df.dtypes
                if dtype.startswith('float') or dtype.startswith('double')
            ]

            # The fourth positional argument has always been an empty list here.
            table, variables, freq = describe(df, bins, float_cols, [])

            summary = variables.reset_index()

            if freq:
                # One row per (column, value, count), folded back into a single
                # tuple of {value: count} dicts per column.
                counts = (
                    pd.DataFrame.from_dict(freq, orient="index")
                    .sort_index()
                    .stack()
                    .reset_index()
                )
                counts.columns = ['col', 'key', 'val']
                counts['Merged'] = [
                    {key: val} for key, val in zip(counts.key, counts.val)
                ]
                counts = counts.groupby('col', as_index=False).agg(
                    lambda x: tuple(x)
                )[['col', 'Merged']]

                summary = pd.merge(
                    summary, counts, how='left', left_on='index', right_on='col'
                )
        else:
            summary = df.describe().toPandas().T

        context.log_dataset("summary_stats",
                            df=summary,
                            format="csv", index=False,
                            artifact_path=context.artifact_subpath('data'))

        if extended:
            context.log_results(table)
    finally:
        spark.stop()

 - commands: [] - code_origin: https://github.com/daniels290813/functions.git#55a79c32be5d233cc11efcf40cd3edbe309bfdef:/home/kali/functions/describe_spark/describe_spark.py - affinity: null -verbose: false + default_handler: describe_spark diff --git a/functions/src/feature_selection/feature_selection.py b/functions/src/feature_selection/feature_selection.py index a046143da..af828ad7f 100644 --- a/functions/src/feature_selection/feature_selection.py +++ b/functions/src/feature_selection/feature_selection.py @@ -23,12 +23,16 @@ import plotly.express as px from mlrun.artifacts import PlotlyArtifact from mlrun.datastore.targets import ParquetTarget + # MLRun utils from mlrun.utils.helpers import create_class + # Feature selection strategies from sklearn.feature_selection import SelectFromModel, SelectKBest + # Scale feature scoresgit st from sklearn.preprocessing import MinMaxScaler + # SKLearn estimators list from sklearn.utils import all_estimators @@ -194,7 +198,7 @@ def feature_selection( selected_models = {} for model_name, model in model_filters.items(): if ".json" in model: - current_model = json.load(open(model, "r")) + current_model = json.load(open(model)) classifier_class = create_class(current_model["META"]["class"]) selected_models[model_name] = classifier_class(**current_model["CLASS"]) elif model in all_sklearn_estimators: @@ -211,7 +215,6 @@ def feature_selection( # Run model filters models_df = pd.DataFrame(index=X.columns) for model_name, model in selected_models.items(): - if model_name == "LogisticRegression": model.set_params(solver="liblinear") diff --git a/functions/src/feature_selection/function.yaml b/functions/src/feature_selection/function.yaml index 1724428d0..8cc5aaa19 100644 --- a/functions/src/feature_selection/function.yaml +++ b/functions/src/feature_selection/function.yaml @@ -1,6 +1,19 @@ +metadata: + tag: '' + name: feature-selection + categories: + - data-preparation + - machine-learning +verbose: false +kind: job spec: + image: 
mlrun/mlrun disable_auto_mount: false - command: '' + build: + origin_filename: '' + functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json

import mlrun
import mlrun.datastore
import mlrun.feature_store as fs
import mlrun.utils
import numpy as np
import pandas as pd
import plotly.express as px
from mlrun.artifacts import PlotlyArtifact
from mlrun.datastore.targets import ParquetTarget

# MLRun utils
from mlrun.utils.helpers import create_class

# Feature selection strategies
from sklearn.feature_selection import SelectFromModel, SelectKBest

# Scale feature scores
from sklearn.preprocessing import MinMaxScaler

# SKLearn estimators list
from sklearn.utils import all_estimators

# Score-function names (resolved under sklearn.feature_selection) that
# feature_selection uses when no stat_filters are passed.
DEFAULT_STAT_FILTERS = ["f_classif", "mutual_info_classif", "chi2", "f_regression"]
# Model filters (filter name -> sklearn estimator name) that feature_selection
# uses when no model_filters are passed.
DEFAULT_MODEL_FILTERS = {
    "LinearSVC": "LinearSVC",
    "LogisticRegression": "LogisticRegression",
    "ExtraTreesClassifier": "ExtraTreesClassifier",
}


def show_values_on_bars(axs, h_v="v", space=0.4):
    """Write each bar's integer size next to it on one or many axes.

    :param axs:   a single axes-like object (exposing ``patches`` and
                  ``text``) or a numpy array of such objects.
    :param h_v:   ``"v"`` labels bars with their height, centred above them;
                  ``"h"`` labels bars with their width, to their right.
                  Any other value leaves the axes untouched.
    :param space: horizontal offset added to the label position in ``"h"`` mode.
    """

    def _annotate(axis):
        if h_v == "v":
            for bar in axis.patches:
                height = bar.get_height()
                x_pos = bar.get_x() + bar.get_width() / 2
                y_pos = bar.get_y() + height
                axis.text(x_pos, y_pos, int(height), ha="center")
            return
        if h_v == "h":
            for bar in axis.patches:
                width = bar.get_width()
                x_pos = bar.get_x() + width + float(space)
                y_pos = bar.get_y() + bar.get_height()
                axis.text(x_pos, y_pos, int(width), ha="left")

    if not isinstance(axs, np.ndarray):
        _annotate(axs)
        return
    for axis in axs.flat:
        _annotate(axis)


def plot_stat(context, stat_name, stat_df):
    """Log a Plotly bar chart of one filter's feature scores.

    :param context:   function context; the chart is logged through
                      ``context.log_artifact``.
    :param stat_name: column of 'stat_df' holding the scores; also used as the
                      artifact key, in the chart title and as the base name of
                      the local HTML file.
    :param stat_df:   dataframe with a 'stat_name' column; its index supplies
                      the bar labels.
    """
    ordered = stat_df.sort_values(stat_name)
    chart = px.bar(
        data_frame=ordered,
        x=stat_name,
        y=ordered.index,
        color=stat_name,
        title=f"{stat_name} feature scores",
    )
    artifact = PlotlyArtifact(key=stat_name, figure=chart)
    context.log_artifact(item=artifact, local_path=f"{stat_name}.html")


def feature_selection(
    context,
    df_artifact,
    k: int = 5,
    min_votes: float = 0.5,
    label_column: str = None,
    stat_filters: list = None,
    model_filters: dict = None,
    max_scaled_scores: bool = True,
    sample_ratio: float = None,
    output_vector_name: str = None,
    ignore_type_errors: bool = False,
):
    """
    Applies selected feature selection statistical functions or models on our 'df_artifact'.

    Each statistical function or model will vote for its best K selected features.
    If a feature has >= 'min_votes' votes, it will be selected.

    :param context:             the function context.
    :param df_artifact:         dataframe to pass as input.
    :param k:                   number of top features to select from each statistical
                                function or model.
    :param min_votes:           minimal number of votes (from a model or by statistical
                                function) needed for a feature to be selected.
                                Can be specified by percentage of votes or absolute
                                number of votes.
    :param label_column:        ground-truth (y) labels.
    :param stat_filters:        statistical functions to apply to the features
                                (from sklearn.feature_selection).
    :param model_filters:       models to use for feature evaluation, can be specified by
                                model name (ex. LinearSVC), formalized json (contains 'CLASS',
                                'FIT', 'META') or a path to such json file.
    :param max_scaled_scores:   produce feature scores table scaled with max_scaler.
    :param sample_ratio:        percentage of the dataset the user wishes to compute the feature selection process on.
    :param output_vector_name:  creates a new feature vector containing only the identified features.
    :param ignore_type_errors:  skips datatypes that are neither float nor int within the feature vector.

    :raises ValueError: if no label column is given and none can be read from a feature vector,
                        if 'k' is smaller than 1 or bigger than the number of dataframe columns,
                        or if object-typed columns exist while 'ignore_type_errors' is False.
    """
    stat_filters = stat_filters or DEFAULT_STAT_FILTERS
    model_filters = model_filters or DEFAULT_MODEL_FILTERS
    # Check if df.meta is valid, if it is, look for a feature vector
    store_uri_prefix, _ = mlrun.datastore.parse_store_uri(df_artifact.artifact_url)
    is_feature_vector = mlrun.utils.StorePrefix.FeatureVector == store_uri_prefix

    # Look inside meta.spec.label_feature to identify the label_column if the user did not specify it
    if label_column is None:
        if is_feature_vector:
            label_column = df_artifact.meta.spec.label_feature.split(".")[1]
        else:
            raise ValueError("No label_column was given, please add a label_column.")

    # Use the feature vector as dataframe
    df = df_artifact.as_df()

    # Ensure k is not bigger than the total number of features
    if k > df.shape[1]:
        raise ValueError(
            f"K cannot be bigger than the total number of features ({df.shape[1]}). Please choose a smaller K."
        )
    elif k < 1:
        raise ValueError("K cannot be smaller than 1. Please choose a bigger K.")

    # Create a sample dataframe of the original feature vector
    if sample_ratio:
        df = (
            df.groupby(label_column)
            .apply(lambda x: x.sample(frac=sample_ratio))
            .reset_index(drop=True)
        )
        df = df.dropna()

    # Set feature vector and labels
    y = df.pop(label_column)
    X = df

    if np.object_ in list(X.dtypes) and ignore_type_errors is False:
        raise ValueError(
            f"{df.select_dtypes(include=['object']).columns.tolist()} are neither float or int."
        )

    # Create selected statistical estimators
    stat_functions_list = {
        stat_name: SelectKBest(
            score_func=create_class(f"sklearn.feature_selection.{stat_name}"), k=k
        )
        for stat_name in stat_filters
    }
    # Filters that only accept non-negative input and therefore get abs(X)
    requires_abs = ["chi2"]

    # Run statistic filters
    selected_features_agg = {}
    stats_df = pd.DataFrame(index=X.columns).dropna()

    for stat_name, stat_func in stat_functions_list.items():
        try:
            # Only the filters listed in requires_abs see abs(X); every other
            # filter is scored on the untouched data.
            params = (abs(X), y) if stat_name in requires_abs else (X, y)
            stat = stat_func.fit(*params)

            # Collect stat function results
            stat_df = pd.DataFrame(
                index=X.columns, columns=[stat_name], data=stat.scores_
            )
            plot_stat(context, stat_name, stat_df)
            stats_df = stats_df.join(stat_df)

            # Select K Best features
            selected_features = X.columns[stat_func.get_support()]
            selected_features_agg[stat_name] = selected_features

        except Exception as e:
            context.logger.info(f"Couldn't calculate {stat_name} because of: {e}")

    # Create models from class name / json file / json params
    all_sklearn_estimators = dict(all_estimators()) if len(model_filters) > 0 else {}
    selected_models = {}
    for model_name, model in model_filters.items():
        if ".json" in model:
            # Close the file deterministically instead of leaking the handle
            with open(model) as model_file:
                current_model = json.load(model_file)
            classifier_class = create_class(current_model["META"]["class"])
            selected_models[model_name] = classifier_class(**current_model["CLASS"])
        elif model in all_sklearn_estimators:
            # Instantiate by the estimator name that was just checked (the dict
            # value); the filter's key may be an arbitrary label.
            selected_models[model_name] = all_sklearn_estimators[model]()

        else:
            try:
                current_model = json.loads(model)
                classifier_class = create_class(current_model["META"]["class"])
                selected_models[model_name] = classifier_class(**current_model["CLASS"])
            except Exception as e:
                context.logger.info(f"unable to load {model} because of: {e}")

    # Run model filters
    models_df = pd.DataFrame(index=X.columns)
    for model_name, model in selected_models.items():
        if model_name == "LogisticRegression":
            model.set_params(solver="liblinear")

        # Train model and get feature importance
        select_from_model = SelectFromModel(model).fit(X, y)
        feature_idx = select_from_model.get_support()
        feature_names = X.columns[feature_idx]
        selected_features_agg[model_name] = feature_names.tolist()

        # Collect model feature importance
        fitted_estimator = select_from_model.estimator_
        if hasattr(fitted_estimator, "coef_"):
            scores = fitted_estimator.coef_
        elif hasattr(fitted_estimator, "feature_importances_"):
            scores = fitted_estimator.feature_importances_
        else:
            # No per-feature scores to report; never reuse a previous filter's scores
            context.logger.info(
                f"{model_name} exposes neither coef_ nor feature_importances_, skipping its scores"
            )
            continue

        scores = np.asarray(scores)
        # 2-D scores (e.g. a coef_ matrix with several rows): keep the first row, as before.
        # 1-D scores are used whole - indexing them with [0] would spread a single
        # value over every feature.
        if scores.ndim > 1:
            scores = scores[0]

        stat_df = pd.DataFrame(index=X.columns, columns=[model_name], data=scores)
        models_df = models_df.join(stat_df)

        plot_stat(context, model_name, stat_df)

    # Create feature_scores DF with stat & model filters scores
    result_matrix_df = pd.concat([stats_df, models_df], axis=1, sort=False)
    context.log_dataset(
        key="feature_scores",
        df=result_matrix_df,
        local_path="feature_scores.parquet",
        format="parquet",
    )
    if max_scaled_scores:
        normalized_df = result_matrix_df.replace([np.inf, -np.inf], np.nan).values
        min_max_scaler = MinMaxScaler()
        normalized_df = min_max_scaler.fit_transform(normalized_df)
        normalized_df = pd.DataFrame(
            data=normalized_df,
            columns=result_matrix_df.columns,
            index=result_matrix_df.index,
        )
        context.log_dataset(
            key="max_scaled_scores_feature_scores",
            df=normalized_df,
            local_path="max_scaled_scores_feature_scores.parquet",
            format="parquet",
        )

    # Create feature count DataFrame (each score column is replaced by a 0/1 vote)
    for test_name in selected_features_agg:
        result_matrix_df[test_name] = [
            1 if x in selected_features_agg[test_name] else 0 for x in X.columns
        ]
    result_matrix_df.loc[:, "num_votes"] = result_matrix_df.sum(axis=1)
    context.log_dataset(
        key="selected_features_count",
        df=result_matrix_df,
        local_path="selected_features_count.parquet",
        format="parquet",
    )

    # How many votes are needed for a feature to be selected?
    if isinstance(min_votes, int):
        votes_needed = min_votes
    else:
        num_filters = len(stat_filters) + len(model_filters)
        votes_needed = int(np.floor(num_filters * max(min(min_votes, 1), 0)))
    context.logger.info(f"votes needed to be selected: {votes_needed}")

    # Create final feature dataframe
    selected_features = result_matrix_df[
        result_matrix_df.num_votes >= votes_needed
    ].index.tolist()
    good_feature_df = df.loc[:, selected_features]
    final_df = pd.concat([good_feature_df, y], axis=1)
    context.log_dataset(
        key="selected_features",
        df=final_df,
        local_path="selected_features.parquet",
        format="parquet",
    )

    # Creating a new feature vector containing only the identified top features
    if is_feature_vector and df_artifact.meta.spec.features and output_vector_name:
        # Selecting the top K features from our top feature dataframe
        # NOTE(review): head(k) takes the first k rows in their existing order;
        # result_matrix_df is not sorted by num_votes here - confirm this is intended.
        selected_features = result_matrix_df.head(k).index

        # Match the selected feature names to the FS Feature annotations
        matched_selections = [
            feature
            for feature in list(df_artifact.meta.spec.features)
            for selected in list(selected_features)
            if feature.endswith(selected)
        ]

        # Defining our new feature vector
        top_features_fv = fs.FeatureVector(
            output_vector_name,
            matched_selections,
            label_feature="labels.label",
            description="feature vector composed strictly of our top features",
        )

        # Saving
        top_features_fv.save()
        top_features_fv.get_offline_features(target=ParquetTarget())

        # Logging our new feature vector URI
        context.log_result("top_features_vector", top_features_fv.uri)
 + code_origin: '' + filename: feature_selection.py entry_points: show_values_on_bars: parameters: @@ -10,20 +23,20 @@ spec: - name: space default: 0.4 name: show_values_on_bars - lineno: 43 + doc: '' has_kwargs: false has_varargs: false - doc: '' + lineno: 47 plot_stat: parameters: - name: context - name: stat_name - name: stat_df name: plot_stat - lineno: 65 + doc: '' has_kwargs: false has_varargs: false - doc: '' + lineno: 69 feature_selection: parameters: - name: context @@ -72,9 +85,6 @@ spec: doc: skips datatypes that are neither float nor int within the feature vector. default: false name: feature_selection - lineno: 80 - has_kwargs: false - has_varargs: false doc: 'Applies selected feature selection statistical functions or models on our ''df_artifact''. @@ -82,18 +92,9 @@ spec: Each statistical function or model will vote for it''s best K selected features. If a feature has >= ''min_votes'' votes, it will be selected.' - image: mlrun/mlrun - build: - origin_filename: '' - functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json

import mlrun
import mlrun.datastore
import mlrun.feature_store as fs
import mlrun.utils
import numpy as np
import pandas as pd
import plotly.express as px
from mlrun.artifacts import PlotlyArtifact
from mlrun.datastore.targets import ParquetTarget
# MLRun utils
from mlrun.utils.helpers import create_class
# Feature selection strategies
from sklearn.feature_selection import SelectFromModel, SelectKBest
# Scale feature scores
from sklearn.preprocessing import MinMaxScaler
# SKLearn estimators list
from sklearn.utils import all_estimators

# Score-function names (resolved under sklearn.feature_selection) that
# feature_selection uses when no stat_filters are passed.
DEFAULT_STAT_FILTERS = ["f_classif", "mutual_info_classif", "chi2", "f_regression"]
# Model filters (filter name -> sklearn estimator name) that feature_selection
# uses when no model_filters are passed.
DEFAULT_MODEL_FILTERS = {
    "LinearSVC": "LinearSVC",
    "LogisticRegression": "LogisticRegression",
    "ExtraTreesClassifier": "ExtraTreesClassifier",
}


def show_values_on_bars(axs, h_v="v", space=0.4):
    """Write each bar's integer size next to it on one or many axes.

    :param axs:   a single axes-like object (exposing ``patches`` and
                  ``text``) or a numpy array of such objects.
    :param h_v:   ``"v"`` labels bars with their height, centred above them;
                  ``"h"`` labels bars with their width, to their right.
                  Any other value leaves the axes untouched.
    :param space: horizontal offset added to the label position in ``"h"`` mode.
    """

    def _annotate(axis):
        if h_v == "v":
            for bar in axis.patches:
                height = bar.get_height()
                x_pos = bar.get_x() + bar.get_width() / 2
                y_pos = bar.get_y() + height
                axis.text(x_pos, y_pos, int(height), ha="center")
            return
        if h_v == "h":
            for bar in axis.patches:
                width = bar.get_width()
                x_pos = bar.get_x() + width + float(space)
                y_pos = bar.get_y() + bar.get_height()
                axis.text(x_pos, y_pos, int(width), ha="left")

    if not isinstance(axs, np.ndarray):
        _annotate(axs)
        return
    for axis in axs.flat:
        _annotate(axis)


def plot_stat(context, stat_name, stat_df):
    """Log a Plotly bar chart of one filter's feature scores.

    :param context:   function context; the chart is logged through
                      ``context.log_artifact``.
    :param stat_name: column of 'stat_df' holding the scores; also used as the
                      artifact key, in the chart title and as the base name of
                      the local HTML file.
    :param stat_df:   dataframe with a 'stat_name' column; its index supplies
                      the bar labels.
    """
    ordered = stat_df.sort_values(stat_name)
    chart = px.bar(
        data_frame=ordered,
        x=stat_name,
        y=ordered.index,
        color=stat_name,
        title=f"{stat_name} feature scores",
    )
    artifact = PlotlyArtifact(key=stat_name, figure=chart)
    context.log_artifact(item=artifact, local_path=f"{stat_name}.html")


def feature_selection(
    context,
    df_artifact,
    k: int = 5,
    min_votes: float = 0.5,
    label_column: str = None,
    stat_filters: list = None,
    model_filters: dict = None,
    max_scaled_scores: bool = True,
    sample_ratio: float = None,
    output_vector_name: str = None,
    ignore_type_errors: bool = False,
):
    """
    Applies selected feature selection statistical functions or models on our 'df_artifact'.

    Each statistical function or model will vote for its best K selected features.
    If a feature has >= 'min_votes' votes, it will be selected.

    :param context:             the function context.
    :param df_artifact:         dataframe to pass as input.
    :param k:                   number of top features to select from each statistical
                                function or model.
    :param min_votes:           minimal number of votes (from a model or by statistical
                                function) needed for a feature to be selected.
                                Can be specified by percentage of votes or absolute
                                number of votes.
    :param label_column:        ground-truth (y) labels.
    :param stat_filters:        statistical functions to apply to the features
                                (from sklearn.feature_selection).
    :param model_filters:       models to use for feature evaluation, can be specified by
                                model name (ex. LinearSVC), formalized json (contains 'CLASS',
                                'FIT', 'META') or a path to such json file.
    :param max_scaled_scores:   produce feature scores table scaled with max_scaler.
    :param sample_ratio:        percentage of the dataset the user wishes to compute the feature selection process on.
    :param output_vector_name:  creates a new feature vector containing only the identified features.
    :param ignore_type_errors:  skips datatypes that are neither float nor int within the feature vector.

    :raises ValueError: if no label column is given and none can be read from a feature vector,
                        if 'k' is smaller than 1 or bigger than the number of dataframe columns,
                        or if object-typed columns exist while 'ignore_type_errors' is False.
    """
    stat_filters = stat_filters or DEFAULT_STAT_FILTERS
    model_filters = model_filters or DEFAULT_MODEL_FILTERS
    # Check if df.meta is valid, if it is, look for a feature vector
    store_uri_prefix, _ = mlrun.datastore.parse_store_uri(df_artifact.artifact_url)
    is_feature_vector = mlrun.utils.StorePrefix.FeatureVector == store_uri_prefix

    # Look inside meta.spec.label_feature to identify the label_column if the user did not specify it
    if label_column is None:
        if is_feature_vector:
            label_column = df_artifact.meta.spec.label_feature.split(".")[1]
        else:
            raise ValueError("No label_column was given, please add a label_column.")

    # Use the feature vector as dataframe
    df = df_artifact.as_df()

    # Ensure k is not bigger than the total number of features
    if k > df.shape[1]:
        raise ValueError(
            f"K cannot be bigger than the total number of features ({df.shape[1]}). Please choose a smaller K."
        )
    elif k < 1:
        raise ValueError("K cannot be smaller than 1. Please choose a bigger K.")

    # Create a sample dataframe of the original feature vector
    if sample_ratio:
        df = (
            df.groupby(label_column)
            .apply(lambda x: x.sample(frac=sample_ratio))
            .reset_index(drop=True)
        )
        df = df.dropna()

    # Set feature vector and labels
    y = df.pop(label_column)
    X = df

    if np.object_ in list(X.dtypes) and ignore_type_errors is False:
        raise ValueError(
            f"{df.select_dtypes(include=['object']).columns.tolist()} are neither float or int."
        )

    # Create selected statistical estimators
    stat_functions_list = {
        stat_name: SelectKBest(
            score_func=create_class(f"sklearn.feature_selection.{stat_name}"), k=k
        )
        for stat_name in stat_filters
    }
    # Filters that only accept non-negative input and therefore get abs(X)
    requires_abs = ["chi2"]

    # Run statistic filters
    selected_features_agg = {}
    stats_df = pd.DataFrame(index=X.columns).dropna()

    for stat_name, stat_func in stat_functions_list.items():
        try:
            # Only the filters listed in requires_abs see abs(X); every other
            # filter is scored on the untouched data.
            params = (abs(X), y) if stat_name in requires_abs else (X, y)
            stat = stat_func.fit(*params)

            # Collect stat function results
            stat_df = pd.DataFrame(
                index=X.columns, columns=[stat_name], data=stat.scores_
            )
            plot_stat(context, stat_name, stat_df)
            stats_df = stats_df.join(stat_df)

            # Select K Best features
            selected_features = X.columns[stat_func.get_support()]
            selected_features_agg[stat_name] = selected_features

        except Exception as e:
            context.logger.info(f"Couldn't calculate {stat_name} because of: {e}")

    # Create models from class name / json file / json params
    all_sklearn_estimators = dict(all_estimators()) if len(model_filters) > 0 else {}
    selected_models = {}
    for model_name, model in model_filters.items():
        if ".json" in model:
            # Close the file deterministically instead of leaking the handle
            with open(model, "r") as model_file:
                current_model = json.load(model_file)
            classifier_class = create_class(current_model["META"]["class"])
            selected_models[model_name] = classifier_class(**current_model["CLASS"])
        elif model in all_sklearn_estimators:
            # Instantiate by the estimator name that was just checked (the dict
            # value); the filter's key may be an arbitrary label.
            selected_models[model_name] = all_sklearn_estimators[model]()

        else:
            try:
                current_model = json.loads(model)
                classifier_class = create_class(current_model["META"]["class"])
                selected_models[model_name] = classifier_class(**current_model["CLASS"])
            except Exception as e:
                context.logger.info(f"unable to load {model} because of: {e}")

    # Run model filters
    models_df = pd.DataFrame(index=X.columns)
    for model_name, model in selected_models.items():

        if model_name == "LogisticRegression":
            model.set_params(solver="liblinear")

        # Train model and get feature importance
        select_from_model = SelectFromModel(model).fit(X, y)
        feature_idx = select_from_model.get_support()
        feature_names = X.columns[feature_idx]
        selected_features_agg[model_name] = feature_names.tolist()

        # Collect model feature importance
        fitted_estimator = select_from_model.estimator_
        if hasattr(fitted_estimator, "coef_"):
            scores = fitted_estimator.coef_
        elif hasattr(fitted_estimator, "feature_importances_"):
            scores = fitted_estimator.feature_importances_
        else:
            # No per-feature scores to report; never reuse a previous filter's scores
            context.logger.info(
                f"{model_name} exposes neither coef_ nor feature_importances_, skipping its scores"
            )
            continue

        scores = np.asarray(scores)
        # 2-D scores (e.g. a coef_ matrix with several rows): keep the first row, as before.
        # 1-D scores are used whole - indexing them with [0] would spread a single
        # value over every feature.
        if scores.ndim > 1:
            scores = scores[0]

        stat_df = pd.DataFrame(index=X.columns, columns=[model_name], data=scores)
        models_df = models_df.join(stat_df)

        plot_stat(context, model_name, stat_df)

    # Create feature_scores DF with stat & model filters scores
    result_matrix_df = pd.concat([stats_df, models_df], axis=1, sort=False)
    context.log_dataset(
        key="feature_scores",
        df=result_matrix_df,
        local_path="feature_scores.parquet",
        format="parquet",
    )
    if max_scaled_scores:
        normalized_df = result_matrix_df.replace([np.inf, -np.inf], np.nan).values
        min_max_scaler = MinMaxScaler()
        normalized_df = min_max_scaler.fit_transform(normalized_df)
        normalized_df = pd.DataFrame(
            data=normalized_df,
            columns=result_matrix_df.columns,
            index=result_matrix_df.index,
        )
        context.log_dataset(
            key="max_scaled_scores_feature_scores",
            df=normalized_df,
            local_path="max_scaled_scores_feature_scores.parquet",
            format="parquet",
        )

    # Create feature count DataFrame (each score column is replaced by a 0/1 vote)
    for test_name in selected_features_agg:
        result_matrix_df[test_name] = [
            1 if x in selected_features_agg[test_name] else 0 for x in X.columns
        ]
    result_matrix_df.loc[:, "num_votes"] = result_matrix_df.sum(axis=1)
    context.log_dataset(
        key="selected_features_count",
        df=result_matrix_df,
        local_path="selected_features_count.parquet",
        format="parquet",
    )

    # How many votes are needed for a feature to be selected?
    if isinstance(min_votes, int):
        votes_needed = min_votes
    else:
        num_filters = len(stat_filters) + len(model_filters)
        votes_needed = int(np.floor(num_filters * max(min(min_votes, 1), 0)))
    context.logger.info(f"votes needed to be selected: {votes_needed}")

    # Create final feature dataframe
    selected_features = result_matrix_df[
        result_matrix_df.num_votes >= votes_needed
    ].index.tolist()
    good_feature_df = df.loc[:, selected_features]
    final_df = pd.concat([good_feature_df, y], axis=1)
    context.log_dataset(
        key="selected_features",
        df=final_df,
        local_path="selected_features.parquet",
        format="parquet",
    )

    # Creating a new feature vector containing only the identified top features
    if is_feature_vector and df_artifact.meta.spec.features and output_vector_name:
        # Selecting the top K features from our top feature dataframe
        # NOTE(review): head(k) takes the first k rows in their existing order;
        # result_matrix_df is not sorted by num_votes here - confirm this is intended.
        selected_features = result_matrix_df.head(k).index

        # Match the selected feature names to the FS Feature annotations
        matched_selections = [
            feature
            for feature in list(df_artifact.meta.spec.features)
            for selected in list(selected_features)
            if feature.endswith(selected)
        ]

        # Defining our new feature vector
        top_features_fv = fs.FeatureVector(
            output_vector_name,
            matched_selections,
            label_feature="labels.label",
            description="feature vector composed strictly of our top features",
        )

        # Saving
        top_features_fv.save()
        top_features_fv.get_offline_features(target=ParquetTarget())

        # Logging our new feature vector URI
        context.log_result("top_features_vector", top_features_fv.uri)
 - code_origin: '' + has_kwargs: false + has_varargs: false + lineno: 84 + command: '' description: Select features through multiple Statistical and Model filters default_handler: feature_selection -kind: job -metadata: - categories: - - data-preparation - - machine-learning - name: feature-selection - tag: '' -verbose: false diff --git a/functions/src/feature_selection/item.yaml b/functions/src/feature_selection/item.yaml index 4f9a3a5dd..8e0911229 100644 --- a/functions/src/feature_selection/item.yaml +++ b/functions/src/feature_selection/item.yaml @@ -12,7 +12,7 @@ labels: author: Iguazio maintainers: [] marketplaceType: '' -mlrunVersion: 1.8.0-rc40 +mlrunVersion: 1.8.0 name: feature-selection platformVersion: 3.6.0 spec: diff --git a/functions/src/gen_class_data/function.yaml b/functions/src/gen_class_data/function.yaml index 1769bec07..b4d175d67 100644 --- a/functions/src/gen_class_data/function.yaml +++ b/functions/src/gen_class_data/function.yaml @@ -1,14 +1,20 @@ metadata: - categories: - - data-generation tag: '' name: gen-class-data + categories: + - data-generation +verbose: false +kind: job spec: - description: Create a binary classification sample dataset and save. 
- default_handler: gen_class_data + image: mlrun/mlrun + disable_auto_mount: false + build: + origin_filename: '' + functionSourceCode: IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKCmltcG9ydCBwYW5kYXMgYXMgcGQKZnJvbSBtbHJ1bi5leGVjdXRpb24gaW1wb3J0IE1MQ2xpZW50Q3R4CmZyb20gc2tsZWFybi5kYXRhc2V0cyBpbXBvcnQgbWFrZV9jbGFzc2lmaWNhdGlvbgoKCmRlZiBnZW5fY2xhc3NfZGF0YSgKICAgIGNvbnRleHQ6IE1MQ2xpZW50Q3R4LAogICAgbl9zYW1wbGVzOiBpbnQsCiAgICBtX2ZlYXR1cmVzOiBpbnQsCiAgICBrX2NsYXNzZXM6IGludCwKICAgIGhlYWRlcjogbGlzdFtzdHJdIHwgTm9uZSwKICAgIGxhYmVsX2NvbHVtbjogc3RyIHwgTm9uZSA9ICJsYWJlbHMiLAogICAgd2VpZ2h0OiBmbG9hdCA9IDAuNSwKICAgIHJhbmRvbV9zdGF0ZTogaW50ID0gMSwKICAgIGtleTogc3RyID0gImNsYXNzaWZpZXItZGF0YSIsCiAgICBmaWxlX2V4dDogc3RyID0gInBhcnF1ZXQiLAogICAgc2tfcGFyYW1zPXt9LAopOgogICAgIiIiQ3JlYXRlIGEgYmluYXJ5IGNsYXNzaWZpY2F0aW9uIHNhbXBsZSBkYXRhc2V0IGFuZCBzYXZlLgogICAgSWYgbm8gZmlsZW5hbWUgaXMgZ2l2ZW4gaXQgd2lsbCBkZWZhdWx0IHRvOgogICAgInNpbWRhdGEte25fc2FtcGxlc31Ye21fZmVhdHVyZXN9LnBhcnF1ZXQiLgoKICAgIEFkZGl0aW9uYWwgc2Npa2l0LWxlYXJuIHBhcmFtZXRlcnMgY2FuIGJlIHNldCB1c2luZyAqKnNrX3BhcmFtcywgcGxlYXNlIHNlZSBodHRwczovL3NjaWtpdC1sZWFybi5vcmcvc3RhYmxlL21vZHVsZXMvZ2VuZXJhdGVkL3NrbGVhcm4uZGF0YXNldHMubWFrZV9jbGFzc2lmaWNhdGlvbi5odG1sIGZvciBtb3JlIGRldGFpbHMuCgogICAgOnBhcmFtIGNvbnRleHQ6ICAgICAgIGZ1bmN0aW9uIGNvbnRleHQKICAg
IDpwYXJhbSBuX3NhbXBsZXM6ICAgICBudW1iZXIgb2Ygcm93cy9zYW1wbGVzCiAgICA6cGFyYW0gbV9mZWF0dXJlczogICAgbnVtYmVyIG9mIGNvbHMvZmVhdHVyZXMKICAgIDpwYXJhbSBrX2NsYXNzZXM6ICAgICBudW1iZXIgb2YgY2xhc3NlcwogICAgOnBhcmFtIGhlYWRlcjogICAgICAgIGhlYWRlciBmb3IgZmVhdHVyZXMgYXJyYXkKICAgIDpwYXJhbSBsYWJlbF9jb2x1bW46ICBjb2x1bW4gbmFtZSBvZiBncm91bmQtdHJ1dGggc2VyaWVzCiAgICA6cGFyYW0gd2VpZ2h0OiAgICAgICAgZnJhY3Rpb24gb2Ygc2FtcGxlIG5lZ2F0aXZlIHZhbHVlIChncm91bmQtdHJ1dGg9MCkKICAgIDpwYXJhbSByYW5kb21fc3RhdGU6ICBybmcgc2VlZCAoc2VlIGh0dHBzOi8vc2Npa2l0LWxlYXJuLm9yZy9zdGFibGUvZ2xvc3NhcnkuaHRtbCN0ZXJtLXJhbmRvbS1zdGF0ZSkKICAgIDpwYXJhbSBrZXk6ICAgICAgICAgICBrZXkgb2YgZGF0YSBpbiBhcnRpZmFjdCBzdG9yZQogICAgOnBhcmFtIGZpbGVfZXh0OiAgICAgIChwcXQpIGV4dGVuc2lvbiBmb3IgcGFycXVldCBmaWxlCiAgICA6cGFyYW0gc2tfcGFyYW1zOiAgICAgYWRkaXRpb25hbCBwYXJhbWV0ZXJzIGZvciBgc2tsZWFybi5kYXRhc2V0cy5tYWtlX2NsYXNzaWZpY2F0aW9uYAogICAgIiIiCiAgICBmZWF0dXJlcywgbGFiZWxzID0gbWFrZV9jbGFzc2lmaWNhdGlvbigKICAgICAgICBuX3NhbXBsZXM9bl9zYW1wbGVzLAogICAgICAgIG5fZmVhdHVyZXM9bV9mZWF0dXJlcywKICAgICAgICB3ZWlnaHRzPXdlaWdodCwKICAgICAgICBuX2NsYXNzZXM9a19jbGFzc2VzLAogICAgICAgIHJhbmRvbV9zdGF0ZT1yYW5kb21fc3RhdGUsCiAgICAgICAgKipza19wYXJhbXMsCiAgICApCgogICAgIyBtYWtlIGRhdGFmcmFtZXMsIGFkZCBjb2x1bW4gbmFtZXMsIGNvbmNhdGVuYXRlIChYLCB5KQogICAgWCA9IHBkLkRhdGFGcmFtZShmZWF0dXJlcykKICAgIGlmIG5vdCBoZWFkZXI6CiAgICAgICAgWC5jb2x1bW5zID0gWyJmZWF0XyIgKyBzdHIoeCkgZm9yIHggaW4gcmFuZ2UobV9mZWF0dXJlcyldCiAgICBlbHNlOgogICAgICAgIFguY29sdW1ucyA9IGhlYWRlcgoKICAgIHkgPSBwZC5EYXRhRnJhbWUobGFiZWxzLCBjb2x1bW5zPVtsYWJlbF9jb2x1bW5dKQogICAgZGF0YSA9IHBkLmNvbmNhdChbWCwgeV0sIGF4aXM9MSkKCiAgICBjb250ZXh0LmxvZ19kYXRhc2V0KGtleSwgZGY9ZGF0YSwgZm9ybWF0PWZpbGVfZXh0LCBpbmRleD1GYWxzZSkK + code_origin: '' + filename: gen_class_data.py entry_points: gen_class_data: - has_kwargs: false parameters: - name: context type: MLClientCtx @@ -23,10 +29,8 @@ spec: type: int doc: number of classes - name: header - type: Optional[List[str]] doc: header for features array - name: label_column - type: Optional[str] doc: column name of 
ground-truth series default: labels - name: weight @@ -48,7 +52,7 @@ spec: - name: sk_params doc: additional parameters for `sklearn.datasets.make_classification` default: {} - lineno: 22 + name: gen_class_data doc: 'Create a binary classification sample dataset and save. If no filename is given it will default to: @@ -59,14 +63,9 @@ spec: Additional scikit-learn parameters can be set using **sk_params, please see https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html for more details.' + has_kwargs: false has_varargs: false - name: gen_class_data + lineno: 21 command: '' - disable_auto_mount: false - image: mlrun/mlrun - build: - origin_filename: '' - functionSourceCode: IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKaW1wb3J0IHBhbmRhcyBhcyBwZApmcm9tIHR5cGluZyBpbXBvcnQgT3B0aW9uYWwsIExpc3QKZnJvbSBza2xlYXJuLmRhdGFzZXRzIGltcG9ydCBtYWtlX2NsYXNzaWZpY2F0aW9uCgpmcm9tIG1scnVuLmV4ZWN1dGlvbiBpbXBvcnQgTUxDbGllbnRDdHgKCgpkZWYgZ2VuX2NsYXNzX2RhdGEoCiAgICAgICAgY29udGV4dDogTUxDbGllbnRDdHgsCiAgICAgICAgbl9zYW1wbGVzOiBpbnQsCiAgICAgICAgbV9mZWF0dXJlczogaW50LAogICAgICAgIGtfY2xhc3NlczogaW50LAogICAgICAgIGhlYWRlcjogT3B0aW9uYWxbTGlzdFtzdHJdXSwKICAgICAgICBsYWJlbF9jb2x1bW46IE9wdGlvbmFsW3N0cl0gPSAibGFiZWxzIiwKICAgICAgICB3ZWlnaHQ6IGZsb2F0ID0gMC41LAogICAgIC
AgIHJhbmRvbV9zdGF0ZTogaW50ID0gMSwKICAgICAgICBrZXk6IHN0ciA9ICJjbGFzc2lmaWVyLWRhdGEiLAogICAgICAgIGZpbGVfZXh0OiBzdHIgPSAicGFycXVldCIsCiAgICAgICAgc2tfcGFyYW1zPXt9Cik6CiAgICAiIiJDcmVhdGUgYSBiaW5hcnkgY2xhc3NpZmljYXRpb24gc2FtcGxlIGRhdGFzZXQgYW5kIHNhdmUuCiAgICBJZiBubyBmaWxlbmFtZSBpcyBnaXZlbiBpdCB3aWxsIGRlZmF1bHQgdG86CiAgICAic2ltZGF0YS17bl9zYW1wbGVzfVh7bV9mZWF0dXJlc30ucGFycXVldCIuCgogICAgQWRkaXRpb25hbCBzY2lraXQtbGVhcm4gcGFyYW1ldGVycyBjYW4gYmUgc2V0IHVzaW5nICoqc2tfcGFyYW1zLCBwbGVhc2Ugc2VlIGh0dHBzOi8vc2Npa2l0LWxlYXJuLm9yZy9zdGFibGUvbW9kdWxlcy9nZW5lcmF0ZWQvc2tsZWFybi5kYXRhc2V0cy5tYWtlX2NsYXNzaWZpY2F0aW9uLmh0bWwgZm9yIG1vcmUgZGV0YWlscy4KCiAgICA6cGFyYW0gY29udGV4dDogICAgICAgZnVuY3Rpb24gY29udGV4dAogICAgOnBhcmFtIG5fc2FtcGxlczogICAgIG51bWJlciBvZiByb3dzL3NhbXBsZXMKICAgIDpwYXJhbSBtX2ZlYXR1cmVzOiAgICBudW1iZXIgb2YgY29scy9mZWF0dXJlcwogICAgOnBhcmFtIGtfY2xhc3NlczogICAgIG51bWJlciBvZiBjbGFzc2VzCiAgICA6cGFyYW0gaGVhZGVyOiAgICAgICAgaGVhZGVyIGZvciBmZWF0dXJlcyBhcnJheQogICAgOnBhcmFtIGxhYmVsX2NvbHVtbjogIGNvbHVtbiBuYW1lIG9mIGdyb3VuZC10cnV0aCBzZXJpZXMKICAgIDpwYXJhbSB3ZWlnaHQ6ICAgICAgICBmcmFjdGlvbiBvZiBzYW1wbGUgbmVnYXRpdmUgdmFsdWUgKGdyb3VuZC10cnV0aD0wKQogICAgOnBhcmFtIHJhbmRvbV9zdGF0ZTogIHJuZyBzZWVkIChzZWUgaHR0cHM6Ly9zY2lraXQtbGVhcm4ub3JnL3N0YWJsZS9nbG9zc2FyeS5odG1sI3Rlcm0tcmFuZG9tLXN0YXRlKQogICAgOnBhcmFtIGtleTogICAgICAgICAgIGtleSBvZiBkYXRhIGluIGFydGlmYWN0IHN0b3JlCiAgICA6cGFyYW0gZmlsZV9leHQ6ICAgICAgKHBxdCkgZXh0ZW5zaW9uIGZvciBwYXJxdWV0IGZpbGUKICAgIDpwYXJhbSBza19wYXJhbXM6ICAgICBhZGRpdGlvbmFsIHBhcmFtZXRlcnMgZm9yIGBza2xlYXJuLmRhdGFzZXRzLm1ha2VfY2xhc3NpZmljYXRpb25gCiAgICAiIiIKICAgIGZlYXR1cmVzLCBsYWJlbHMgPSBtYWtlX2NsYXNzaWZpY2F0aW9uKAogICAgICAgIG5fc2FtcGxlcz1uX3NhbXBsZXMsCiAgICAgICAgbl9mZWF0dXJlcz1tX2ZlYXR1cmVzLAogICAgICAgIHdlaWdodHM9d2VpZ2h0LAogICAgICAgIG5fY2xhc3Nlcz1rX2NsYXNzZXMsCiAgICAgICAgcmFuZG9tX3N0YXRlPXJhbmRvbV9zdGF0ZSwKICAgICAgICAqKnNrX3BhcmFtcykKCiAgICAjIG1ha2UgZGF0YWZyYW1lcywgYWRkIGNvbHVtbiBuYW1lcywgY29uY2F0ZW5hdGUgKFgsIHkpCiAgICBYID0gcGQuRGF0YUZyYW1lKGZlYXR1cmVzKQogICAgaWYgbm90IGhlYWRlcjoKICAgICAgIC
BYLmNvbHVtbnMgPSBbImZlYXRfIiArIHN0cih4KSBmb3IgeCBpbiByYW5nZShtX2ZlYXR1cmVzKV0KICAgIGVsc2U6CiAgICAgICAgWC5jb2x1bW5zID0gaGVhZGVyCgogICAgeSA9IHBkLkRhdGFGcmFtZShsYWJlbHMsIGNvbHVtbnM9W2xhYmVsX2NvbHVtbl0pCiAgICBkYXRhID0gcGQuY29uY2F0KFtYLCB5XSwgYXhpcz0xKQoKICAgIGNvbnRleHQubG9nX2RhdGFzZXQoa2V5LCBkZj1kYXRhLCBmb3JtYXQ9ZmlsZV9leHQsIGluZGV4PUZhbHNlKQo= - code_origin: '' -kind: job -verbose: false + description: Create a binary classification sample dataset and save. + default_handler: gen_class_data diff --git a/functions/src/gen_class_data/gen_class_data.py b/functions/src/gen_class_data/gen_class_data.py index 2e5ab1073..8e8774f00 100644 --- a/functions/src/gen_class_data/gen_class_data.py +++ b/functions/src/gen_class_data/gen_class_data.py @@ -12,25 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import pandas as pd -from typing import Optional, List -from sklearn.datasets import make_classification +import pandas as pd from mlrun.execution import MLClientCtx +from sklearn.datasets import make_classification def gen_class_data( - context: MLClientCtx, - n_samples: int, - m_features: int, - k_classes: int, - header: Optional[List[str]], - label_column: Optional[str] = "labels", - weight: float = 0.5, - random_state: int = 1, - key: str = "classifier-data", - file_ext: str = "parquet", - sk_params={} + context: MLClientCtx, + n_samples: int, + m_features: int, + k_classes: int, + header: list[str] | None, + label_column: str | None = "labels", + weight: float = 0.5, + random_state: int = 1, + key: str = "classifier-data", + file_ext: str = "parquet", + sk_params={}, ): """Create a binary classification sample dataset and save. 
If no filename is given it will default to: @@ -56,7 +55,8 @@ def gen_class_data( weights=weight, n_classes=k_classes, random_state=random_state, - **sk_params) + **sk_params, + ) # make dataframes, add column names, concatenate (X, y) X = pd.DataFrame(features) diff --git a/functions/src/gen_class_data/test_gen_class_data.py b/functions/src/gen_class_data/test_gen_class_data.py index e06eeb16b..deb354dc0 100644 --- a/functions/src/gen_class_data/test_gen_class_data.py +++ b/functions/src/gen_class_data/test_gen_class_data.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from mlrun import code_to_function import os +from mlrun import code_to_function + def test_gen_class_data(): fn = code_to_function( - name='test_gen_class_data', + name="test_gen_class_data", filename="gen_class_data.py", handler="gen_class_data", kind="job", @@ -32,8 +33,11 @@ def test_gen_class_data(): "header": None, "weight": [0.5, 0.5], "sk_params": {"n_informative": 2}, - "file_ext": "csv"}, + "file_ext": "csv", + }, local=True, artifact_path="./artifacts", - ) - assert os.path.isfile(run.status.artifacts[0]['spec']['target_path']), 'dataset is not available' + ) + assert os.path.isfile(run.status.artifacts[0]["spec"]["target_path"]), ( + "dataset is not available" + ) diff --git a/functions/src/github_utils/function.yaml b/functions/src/github_utils/function.yaml index 2d5d93aab..68b5afd8f 100644 --- a/functions/src/github_utils/function.yaml +++ b/functions/src/github_utils/function.yaml @@ -1,64 +1,52 @@ -kind: job metadata: - name: github-utils tag: '' - hash: d8e639af306794ce6f59eb246f0b845c016c9da4 - project: '' - labels: - author: Iguazio + name: github-utils categories: - utils +verbose: false +kind: job spec: - command: '' - args: [] image: mlrun/mlrun - env: [] - default_handler: run_summary_comment + disable_auto_mount: false + build: + origin_filename: '' + functionSourceCode: 
IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKIyBHZW5lcmF0ZWQgYnkgbnVjbGlvLmV4cG9ydC5OdWNsaW9FeHBvcnRlcgoKaW1wb3J0IG9zCgppbXBvcnQgcmVxdWVzdHMKZnJvbSBtbHJ1biBpbXBvcnQgRGF0YUl0ZW0sIGdldF9ydW5fZGIKCgpkZWYgcHJfY29tbWVudCgKICAgIGNvbnRleHQsIHJlcG86IHN0ciwgaXNzdWU6IGludCwgbWVzc2FnZTogc3RyID0gIiIsIG1lc3NhZ2VfZmlsZTogRGF0YUl0ZW0gPSBOb25lCik6CiAgICB0b2tlbiA9IGNvbnRleHQuZ2V0X3NlY3JldCgiR0lUSFVCX1RPS0VOIikgb3Igb3MuZW52aXJvbi5nZXQoIkdJVEhVQl9UT0tFTiIpCiAgICBpZiBtZXNzYWdlX2ZpbGUgYW5kIG5vdCBtZXNzYWdlOgogICAgICAgIG1lc3NhZ2UgPSBtZXNzYWdlX2ZpbGUuZ2V0KCkKICAgIGVsaWYgbm90IG1lc3NhZ2UgYW5kIG5vdCBtZXNzYWdlX2ZpbGU6CiAgICAgICAgcmFpc2UgVmFsdWVFcnJvcigicHIgbWVzc2FnZSBvciBtZXNzYWdlIGZpbGUgbXVzdCBiZSBwcm92aWRlZCIpCgogICAgaGVhZGVycyA9IHsKICAgICAgICAiQWNjZXB0IjogImFwcGxpY2F0aW9uL3ZuZC5naXRodWIudjMranNvbiIsCiAgICAgICAgIkF1dGhvcml6YXRpb24iOiBmInRva2VuIHt0b2tlbn0iLAogICAgfQogICAgdXJsID0gZiJodHRwczovL2FwaS5naXRodWIuY29tL3JlcG9zL3tyZXBvfS9pc3N1ZXMve2lzc3VlfS9jb21tZW50cyIKCiAgICByZXNwID0gcmVxdWVzdHMucG9zdCh1cmw9dXJsLCBqc29uPXsiYm9keSI6IHN0cihtZXNzYWdlKX0sIGhlYWRlcnM9aGVhZGVycykKICAgIGlmIG5vdCByZXNwLm9rOgogICAgICAgIGVycm1zZyA9IGYiYmFkIHByIGNvbW1lbnQgcmVzcCEhXG57cmVzcC50ZXh0fSIKICAgICAgICBjb250ZXh0LmxvZ2dlci5lcnJvcihlcnJtc2cpCiAgICAgICAgcmFpc2UgT1NFcnJvcihlcnJtc2cpCgoKZGVmIHJ1bl9zdW1tYXJ5X2NvbW1lbnQo
Y29udGV4dCwgd29ya2Zsb3dfaWQsIHJlcG86IHN0ciwgaXNzdWU6IGludCwgcHJvamVjdD0iIik6CiAgICBkYiA9IGdldF9ydW5fZGIoKS5jb25uZWN0KCkKICAgIHByb2plY3QgPSBwcm9qZWN0IG9yIGNvbnRleHQucHJvamVjdAogICAgcnVucyA9IGRiLmxpc3RfcnVucyhwcm9qZWN0PXByb2plY3QsIGxhYmVscz1mIndvcmtmbG93PXt3b3JrZmxvd19pZH0iKQoKICAgIGhhZF9lcnJvcnMgPSBpID0gMAogICAgZm9yIHIgaW4gcnVuczoKICAgICAgICBuYW1lID0gclsibWV0YWRhdGEiXVsibmFtZSJdCiAgICAgICAgaWYgclsic3RhdHVzIl0uZ2V0KCJzdGF0ZSIsICIiKSA9PSAiZXJyb3IiOgogICAgICAgICAgICBoYWRfZXJyb3JzICs9IDEKICAgICAgICBpZiBuYW1lID09IGNvbnRleHQubmFtZToKICAgICAgICAgICAgZGVsIHJ1bnNbaV0KICAgICAgICBpICs9IDEKCiAgICBwcmludCgiZXJyb3JzOiIsIGhhZF9lcnJvcnMpCgogICAgaHRtbCA9IGYiIyMjIFJ1biBSZXN1bHRzXG5Xb3JrZmxvdyB7d29ya2Zsb3dfaWR9IGZpbmlzaGVkIHdpdGgge2hhZF9lcnJvcnN9IGVycm9ycyIKICAgIGh0bWwgKz0gIjxicj5jbGljayB0aGUgaHlwZXIgbGlua3MgYmVsb3cgdG8gc2VlIGRldGFpbGVkIHJlc3VsdHM8YnI+IgogICAgaHRtbCArPSBydW5zLnNob3coZGlzcGxheT1GYWxzZSwgc2hvcnQ9VHJ1ZSkKICAgIGlmIHJlcG86CiAgICAgICAgcHJfY29tbWVudChjb250ZXh0LCByZXBvLCBpc3N1ZSwgaHRtbCkKICAgIGVsc2U6CiAgICAgICAgcHJpbnQoInJlcG8gbm90IGRlZmluZWQiKQogICAgICAgIHByaW50KGh0bWwpCg== + code_origin: '' + filename: github_utils.py entry_points: pr_comment: - name: pr_comment - doc: '' parameters: - name: context - default: '' - name: repo type: str - default: '' - name: issue type: int - default: '' - name: message type: str default: '' - name: message_file type: DataItem default: null - outputs: - - default: '' - lineno: 8 - run_summary_comment: - name: run_summary_comment + name: pr_comment doc: '' + has_kwargs: false + has_varargs: false + lineno: 23 + run_summary_comment: parameters: - name: context - default: '' - name: workflow_id - default: '' - name: repo type: str - default: '' - name: issue type: int - default: '' - name: project default: '' - outputs: - - default: '' - lineno: 31 + name: run_summary_comment + doc: '' + has_kwargs: false + has_varargs: false + lineno: 45 + command: '' description: add comments to github pull request - build: - functionSourceCode: 
IyBHZW5lcmF0ZWQgYnkgbnVjbGlvLmV4cG9ydC5OdWNsaW9FeHBvcnRlcgoKaW1wb3J0IHJlcXVlc3RzCmltcG9ydCBvcwpmcm9tIG1scnVuIGltcG9ydCBEYXRhSXRlbSwgZ2V0X3J1bl9kYiwgbWxjb25mCgoKZGVmIHByX2NvbW1lbnQoCiAgICBjb250ZXh0LCByZXBvOiBzdHIsIGlzc3VlOiBpbnQsIG1lc3NhZ2U6IHN0ciA9ICIiLCBtZXNzYWdlX2ZpbGU6IERhdGFJdGVtID0gTm9uZQopOgoKICAgIHRva2VuID0gY29udGV4dC5nZXRfc2VjcmV0KCJHSVRIVUJfVE9LRU4iKSBvciBvcy5lbnZpcm9uLmdldCgiR0lUSFVCX1RPS0VOIikKICAgIGlmIG1lc3NhZ2VfZmlsZSBhbmQgbm90IG1lc3NhZ2U6CiAgICAgICAgbWVzc2FnZSA9IG1lc3NhZ2VfZmlsZS5nZXQoKQogICAgZWxpZiBub3QgbWVzc2FnZSBhbmQgbm90IG1lc3NhZ2VfZmlsZToKICAgICAgICByYWlzZSBWYWx1ZUVycm9yKCJwciBtZXNzYWdlIG9yIG1lc3NhZ2UgZmlsZSBtdXN0IGJlIHByb3ZpZGVkIikKCiAgICBoZWFkZXJzID0gewogICAgICAgICJBY2NlcHQiOiAiYXBwbGljYXRpb24vdm5kLmdpdGh1Yi52Mytqc29uIiwKICAgICAgICAiQXV0aG9yaXphdGlvbiI6IGYidG9rZW4ge3Rva2VufSIsCiAgICB9CiAgICB1cmwgPSBmImh0dHBzOi8vYXBpLmdpdGh1Yi5jb20vcmVwb3Mve3JlcG99L2lzc3Vlcy97aXNzdWV9L2NvbW1lbnRzIgoKICAgIHJlc3AgPSByZXF1ZXN0cy5wb3N0KHVybD11cmwsIGpzb249eyJib2R5Ijogc3RyKG1lc3NhZ2UpfSwgaGVhZGVycz1oZWFkZXJzKQogICAgaWYgbm90IHJlc3Aub2s6CiAgICAgICAgZXJybXNnID0gZiJiYWQgcHIgY29tbWVudCByZXNwISFcbntyZXNwLnRleHR9IgogICAgICAgIGNvbnRleHQubG9nZ2VyLmVycm9yKGVycm1zZykKICAgICAgICByYWlzZSBJT0Vycm9yKGVycm1zZykKCgpkZWYgcnVuX3N1bW1hcnlfY29tbWVudChjb250ZXh0LCB3b3JrZmxvd19pZCwgcmVwbzogc3RyLCBpc3N1ZTogaW50LCBwcm9qZWN0PSIiKToKICAgIGRiID0gZ2V0X3J1bl9kYigpLmNvbm5lY3QoKQogICAgcHJvamVjdCA9IHByb2plY3Qgb3IgY29udGV4dC5wcm9qZWN0CiAgICBydW5zID0gZGIubGlzdF9ydW5zKHByb2plY3Q9cHJvamVjdCwgbGFiZWxzPWYid29ya2Zsb3c9e3dvcmtmbG93X2lkfSIpCgogICAgaGFkX2Vycm9ycyA9IGkgPSAwCiAgICBmb3IgciBpbiBydW5zOgogICAgICAgIG5hbWUgPSByWyJtZXRhZGF0YSJdWyJuYW1lIl0KICAgICAgICBpZiByWyJzdGF0dXMiXS5nZXQoInN0YXRlIiwgIiIpID09ICJlcnJvciI6CiAgICAgICAgICAgIGhhZF9lcnJvcnMgKz0gMQogICAgICAgIGlmIG5hbWUgPT0gY29udGV4dC5uYW1lOgogICAgICAgICAgICBkZWwgcnVuc1tpXQogICAgICAgIGkgKz0gMQoKICAgIHByaW50KCJlcnJvcnM6IiwgaGFkX2Vycm9ycykKCiAgICBodG1sID0gIiMjIyBSdW4gUmVzdWx0c1xuV29ya2Zsb3cge30gZmluaXNoZWQgd2l0aCB7fSBlcnJvcnMiLmZvcm1hdCgKICAgICAgICB3b3JrZmxvd19p
ZCwgaGFkX2Vycm9ycwogICAgKQogICAgaHRtbCArPSAiPGJyPmNsaWNrIHRoZSBoeXBlciBsaW5rcyBiZWxvdyB0byBzZWUgZGV0YWlsZWQgcmVzdWx0czxicj4iCiAgICBodG1sICs9IHJ1bnMuc2hvdyhkaXNwbGF5PUZhbHNlLCBzaG9ydD1UcnVlKQogICAgaWYgcmVwbzoKICAgICAgICBwcl9jb21tZW50KGNvbnRleHQsIHJlcG8sIGlzc3VlLCBodG1sKQogICAgZWxzZToKICAgICAgICBwcmludCgicmVwbyBub3QgZGVmaW5lZCIpCiAgICAgICAgcHJpbnQoaHRtbCkK - commands: [] - code_origin: https://github.com/daniels290813/functions.git#55a79c32be5d233cc11efcf40cd3edbe309bfdef:/home/kali/functions/github_utils/github_utils.py - affinity: null -verbose: false + default_handler: run_summary_comment diff --git a/functions/src/github_utils/github_utils.py b/functions/src/github_utils/github_utils.py index dc70456a9..09ed6a7bb 100644 --- a/functions/src/github_utils/github_utils.py +++ b/functions/src/github_utils/github_utils.py @@ -14,15 +14,15 @@ # # Generated by nuclio.export.NuclioExporter -import requests import os -from mlrun import DataItem, get_run_db, mlconf + +import requests +from mlrun import DataItem, get_run_db def pr_comment( context, repo: str, issue: int, message: str = "", message_file: DataItem = None ): - token = context.get_secret("GITHUB_TOKEN") or os.environ.get("GITHUB_TOKEN") if message_file and not message: message = message_file.get() @@ -39,7 +39,7 @@ def pr_comment( if not resp.ok: errmsg = f"bad pr comment resp!!\n{resp.text}" context.logger.error(errmsg) - raise IOError(errmsg) + raise OSError(errmsg) def run_summary_comment(context, workflow_id, repo: str, issue: int, project=""): @@ -58,9 +58,7 @@ def run_summary_comment(context, workflow_id, repo: str, issue: int, project="") print("errors:", had_errors) - html = "### Run Results\nWorkflow {} finished with {} errors".format( - workflow_id, had_errors - ) + html = f"### Run Results\nWorkflow {workflow_id} finished with {had_errors} errors" html += "
click the hyper links below to see detailed results
" html += runs.show(display=False, short=True) if repo: diff --git a/functions/src/hugging_face_serving/function.yaml b/functions/src/hugging_face_serving/function.yaml index a628d7ab7..3da9128a9 100644 --- a/functions/src/hugging_face_serving/function.yaml +++ b/functions/src/hugging_face_serving/function.yaml @@ -1,31 +1,32 @@ metadata: + tag: '' name: hugging-face-serving categories: - genai - model-serving - tag: '' +verbose: false +kind: serving spec: - default_handler: '' - min_replicas: 1 - source: '' image: mlrun/ml-models + disable_auto_mount: false build: - functionSourceCode: IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKCmZyb20gYWJjIGltcG9ydCBBQkMKZnJvbSBpbXBvcnRsaWIgaW1wb3J0IGltcG9ydF9tb2R1bGUKZnJvbSB0eXBpbmcgaW1wb3J0IExpc3QKCmZyb20gdHJhbnNmb3JtZXJzIGltcG9ydCBwaXBlbGluZQoKaW1wb3J0IG1scnVuLnNlcnZpbmcKClBBQ0tBR0VfTU9EVUxFID0gInRyYW5zZm9ybWVycyIKU0VSSUFMSVpBQkxFX1RZUEVTID0gW2RpY3QsIGxpc3QsIHR1cGxlLCBzdHIsIGludCwgZmxvYXRdCgoKY2xhc3MgSHVnZ2luZ0ZhY2VNb2RlbFNlcnZlcihtbHJ1bi5zZXJ2aW5nLlYyTW9kZWxTZXJ2ZXIsIEFCQyk6CiAgICAiIiIKICAgIEh1Z2dpbmcgRmFjZSBNb2RlbCBzZXJ2aW5nIGNsYXNzLCBpbmhlcml0aW5nIHRoZSBWMk1vZGVsU2VydmVyIGNsYXNzIGZvciBiZWluZyBpbml0aWFsaXplZCBhdXRvbWF0aWNhbGx5IGJ5IHRoZQogICAgbW9kZWwgc2VydmVyIGFuZCBiZSBhYmxlIHRvIHJ1biBsb2NhbGx5IGFzIHBhcnQgb2YgYSBudWNsaW
8gc2VydmVybGVzcyBmdW5jdGlvbiwgb3IgYXMgcGFydCBvZiBhIHJlYWwtdGltZSBwaXBlbGluZS4KICAgICIiIgoKICAgIGRlZiBfX2luaXRfXygKICAgICAgICBzZWxmLAogICAgICAgIGNvbnRleHQ6IG1scnVuLk1MQ2xpZW50Q3R4LAogICAgICAgIG5hbWU6IHN0ciwKICAgICAgICB0YXNrOiBzdHIsCiAgICAgICAgbW9kZWxfcGF0aDogc3RyID0gTm9uZSwKICAgICAgICBtb2RlbF9uYW1lOiBzdHIgPSBOb25lLAogICAgICAgIG1vZGVsX2NsYXNzOiBzdHIgPSBOb25lLAogICAgICAgIHRva2VuaXplcl9uYW1lOiBzdHIgPSBOb25lLAogICAgICAgIHRva2VuaXplcl9jbGFzczogc3RyID0gTm9uZSwKICAgICAgICBmcmFtZXdvcms6IHN0ciA9IE5vbmUsCiAgICAgICAgKipjbGFzc19hcmdzLAogICAgKToKICAgICAgICAiIiIKICAgICAgICBJbml0aWFsaXplIGEgc2VydmluZyBjbGFzcyBmb3IgYSBIdWdnaW5nIGZhY2UgbW9kZWwuCgogICAgICAgIDpwYXJhbSBjb250ZXh0OiAgICAgICAgIFRoZSBtbHJ1biBjb250ZXh0IHRvIHdvcmsgd2l0aAogICAgICAgIDpwYXJhbSBuYW1lOiAgICAgICAgICAgIFRoZSBuYW1lIG9mIHRoaXMgc2VydmVyIHRvIGJlIGluaXRpYWxpemVkCiAgICAgICAgOnBhcmFtIG1vZGVsX3BhdGg6ICAgICAgTm90IGluIHVzZS4gV2hlbiBhZGRpbmcgYSBtb2RlbCBwYXNzIGFueSBzdHJpbmcgdmFsdWUKICAgICAgICA6cGFyYW0gbW9kZWxfbmFtZTogICAgICBUaGUgbW9kZWwncyBuYW1lIGluIHRoZSBIdWdnaW5nIEZhY2UgaHViCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgZS5nLiwgYG5scHRvd24vYmVydC1iYXNlLW11bHRpbGluZ3VhbC11bmNhc2VkLXNlbnRpbWVudGAKICAgICAgICA6cGFyYW0gbW9kZWxfY2xhc3M6ICAgICBUaGUgbW9kZWwncyBjbGFzcyB0eXBlIG9iamVjdCB3aGljaCBjYW4gYmUgcGFzc2VkIGFzIHRoZSBjbGFzcydzIG5hbWUgKHN0cmluZykuCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgTXVzdCBiZSBwcm92aWRlZCBhbmQgdG8gYmUgbWF0Y2hlZCB3aXRoIGBtb2RlbF9uYW1lYC4KICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBlLmcuLCBgQXV0b01vZGVsRm9yU2VxdWVuY2VDbGFzc2lmaWNhdGlvbmAKICAgICAgICA6cGFyYW0gdG9rZW5pemVyX25hbWU6ICBUaGUgdG9rZW5pemVyJ3MgbmFtZSBpbiB0aGUgSHVnZ2luZyBGYWNlIGh1YgogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGUuZy4sIGBubHB0b3duL2JlcnQtYmFzZS1tdWx0aWxpbmd1YWwtdW5jYXNlZC1zZW50aW1lbnRgCiAgICAgICAgOnBhcmFtIHRva2VuaXplcl9jbGFzczogVGhlIG1vZGVsJ3MgY2xhc3MgdHlwZSBvYmplY3Qgd2hpY2ggY2FuIGJlIHBhc3NlZCBhcyB0aGUgY2xhc3MncyBuYW1lIChzdHJpbmcpLgogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIE11c3QgYmUgcHJvdmlkZWQgYW5kIHRvIGJlIG1hdGNoZWQgd2l0aCBgbW9kZWxfbmFtZWAuCi
AgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgZS5nLiwgYEF1dG9Ub2tlbml6ZXJgCiAgICAgICAgOnBhcmFtIGZyYW1ld29yazogICAgICAgVGhlIGZyYW1ld29yayB0byB1c2UsIGVpdGhlciBgInB0ImAgZm9yIFB5VG9yY2ggb3IgYCJ0ZiJgIGZvciBUZW5zb3JGbG93LiBUaGUgc3BlY2lmaWVkCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgZnJhbWV3b3JrIG11c3QgYmUgaW5zdGFsbGVkLgogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIElmIG5vIGZyYW1ld29yayBpcyBzcGVjaWZpZWQsIHdpbGwgZGVmYXVsdCB0byB0aGUgb25lIGN1cnJlbnRseSBpbnN0YWxsZWQuCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgSWYgbm8gZnJhbWV3b3JrIGlzIHNwZWNpZmllZCBhbmQgYm90aCBmcmFtZXdvcmtzIGFyZSBpbnN0YWxsZWQsIHdpbGwgZGVmYXVsdCB0byB0aGUKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBmcmFtZXdvcmsgb2YgdGhlIGBtb2RlbGAsIG9yIHRvIFB5VG9yY2ggaWYgbm8gbW9kZWwgaXMgcHJvdmlkZWQuCiAgICAgICAgOnBhcmFtIGNsYXNzX2FyZ3M6ICAgICAgLQogICAgICAgICIiIgogICAgICAgIHN1cGVyKEh1Z2dpbmdGYWNlTW9kZWxTZXJ2ZXIsIHNlbGYpLl9faW5pdF9fKAogICAgICAgICAgICBjb250ZXh0PWNvbnRleHQsCiAgICAgICAgICAgIG5hbWU9bmFtZSwKICAgICAgICAgICAgbW9kZWxfcGF0aD1tb2RlbF9wYXRoLAogICAgICAgICAgICAqKmNsYXNzX2FyZ3MsCiAgICAgICAgKQogICAgICAgIHNlbGYudGFzayA9IHRhc2sKICAgICAgICBzZWxmLm1vZGVsID0gTm9uZQogICAgICAgIHNlbGYudG9rZW5pemVyID0gTm9uZQogICAgICAgIHNlbGYubW9kZWxfbmFtZSA9IG1vZGVsX25hbWUKICAgICAgICBzZWxmLnRva2VuaXplcl9uYW1lID0gdG9rZW5pemVyX25hbWUKICAgICAgICBzZWxmLm1vZGVsX2NsYXNzID0gbW9kZWxfY2xhc3MKICAgICAgICBzZWxmLnRva2VuaXplcl9jbGFzcyA9IHRva2VuaXplcl9jbGFzcwogICAgICAgIHNlbGYuZnJhbWV3b3JrID0gZnJhbWV3b3JrCiAgICAgICAgc2VsZi5waXBlID0gTm9uZQoKICAgIGRlZiBsb2FkKHNlbGYpOgogICAgICAgICIiImxvYWQgYW5kIGluaXRpYWxpemUgdGhlIG1vZGVsIGFuZC9vciBvdGhlciBlbGVtZW50cyIiIgogICAgICAgIGlmIHNlbGYubW9kZWxfY2xhc3M6CiAgICAgICAgICAgIG1vZGVsX29iamVjdCA9IGdldGF0dHIoaW1wb3J0X21vZHVsZShQQUNLQUdFX01PRFVMRSksIHNlbGYubW9kZWxfY2xhc3MpCiAgICAgICAgICAgIHNlbGYubW9kZWwgPSBtb2RlbF9vYmplY3QuZnJvbV9wcmV0cmFpbmVkKHNlbGYubW9kZWxfbmFtZSkKICAgICAgICBpZiBzZWxmLnRva2VuaXplcl9jbGFzczoKICAgICAgICAgICAgdG9rZW5pemVyX29iamVjdCA9IGdldGF0dHIoCiAgICAgICAgICAgICAgICBpbXBvcnRfbW9kdWxlKFBBQ0tBR0VfTU9EVUxFKSwgc2VsZi50b2tlbml6ZXJfY2xhc3MKIC
AgICAgICAgICAgKQogICAgICAgICAgICBzZWxmLnRva2VuaXplciA9IHRva2VuaXplcl9vYmplY3QuZnJvbV9wcmV0cmFpbmVkKHNlbGYudG9rZW5pemVyX25hbWUpCiAgICAgICAgc2VsZi5waXBlID0gcGlwZWxpbmUoCiAgICAgICAgICAgIHRhc2s9c2VsZi50YXNrLAogICAgICAgICAgICBtb2RlbD1zZWxmLm1vZGVsIG9yIHNlbGYubW9kZWxfbmFtZSwKICAgICAgICAgICAgdG9rZW5pemVyPXNlbGYudG9rZW5pemVyLAogICAgICAgICAgICBmcmFtZXdvcms9c2VsZi5mcmFtZXdvcmssCiAgICAgICAgKQoKICAgIGRlZiBwcmVkaWN0KHNlbGYsIGJvZHk6IGRpY3QpIC0+IExpc3Q6CiAgICAgICAgIiIiR2VuZXJhdGUgbW9kZWwgcHJlZGljdGlvbnMgZnJvbSBzYW1wbGUuIiIiCiAgICAgICAgaWYgc2VsZi5waXBlIGlzIE5vbmU6CiAgICAgICAgICAgIHJhaXNlIFZhbHVlRXJyb3IoIlBsZWFzZSB1c2UgYC5sb2FkKClgIikKICAgICAgICB0cnk6CiAgICAgICAgICAgIGlmIGlzaW5zdGFuY2UoYm9keVsiaW5wdXRzIl1bMF0sIGRpY3QpOgogICAgICAgICAgICAgICAgcmVzdWx0ID0gW3NlbGYucGlwZSgqKl9pbnB1dCkgZm9yIF9pbnB1dCBpbiBib2R5WyJpbnB1dHMiXV0KICAgICAgICAgICAgZWxzZToKICAgICAgICAgICAgICAgIHJlc3VsdCA9IHNlbGYucGlwZShib2R5WyJpbnB1dHMiXSkKICAgICAgICAgICAgIyByZXBsYWNlIGxpc3Qgb2YgbGlzdHMgb2YgZGljdHMgaW50byBhIGxpc3Qgb2YgZGljdHM6CiAgICAgICAgICAgIGlmIGFsbChpc2luc3RhbmNlKHJlcywgbGlzdCkgZm9yIHJlcyBpbiByZXN1bHQpOgogICAgICAgICAgICAgICAgbmV3X3Jlc3VsdCA9IFtyZXNbMF0gZm9yIHJlcyBpbiByZXN1bHRdCiAgICAgICAgICAgICAgICByZXN1bHQgPSBuZXdfcmVzdWx0CgogICAgICAgICAgICBub25fc2VyaWFsaXphYmxlX3R5cGVzID0gW10KICAgICAgICAgICAgZm9yIHJlcyBpbiByZXN1bHQ6CiAgICAgICAgICAgICAgICBmb3Iga2V5LCB2YWwgaW4gcmVzLml0ZW1zKCk6CiAgICAgICAgICAgICAgICAgICAgaWYgdHlwZSh2YWwpIG5vdCBpbiBTRVJJQUxJWkFCTEVfVFlQRVM6CiAgICAgICAgICAgICAgICAgICAgICAgIG5vbl9zZXJpYWxpemFibGVfdHlwZXMuYXBwZW5kKHN0cih0eXBlKHZhbCkpKQogICAgICAgICAgICAgICAgICAgICAgICByZXNba2V5XSA9IHN0cih2YWwpCiAgICAgICAgICAgIGlmIG5vbl9zZXJpYWxpemFibGVfdHlwZXM6CiAgICAgICAgICAgICAgICBzZWxmLmNvbnRleHQubG9nZ2VyLmluZm8oCiAgICAgICAgICAgICAgICAgICAgZiJOb24tc2VyaWFsaXphYmxlIHR5cGVzOiB7bm9uX3NlcmlhbGl6YWJsZV90eXBlc30gd2VyZSBjYXN0ZWQgdG8gc3RyaW5ncyIKICAgICAgICAgICAgICAgICkKICAgICAgICBleGNlcHQgRXhjZXB0aW9uIGFzIGU6CiAgICAgICAgICAgIHJhaXNlIEV4Y2VwdGlvbigiRmFpbGVkIHRvIHByZWRpY3QgJXMiICUgZSkKICAgICAgICByZXR1cm4gcmVzdWx0Cgpmcm9tIG
1scnVuLnJ1bnRpbWVzIGltcG9ydCBudWNsaW9faW5pdF9ob29rCmRlZiBpbml0X2NvbnRleHQoY29udGV4dCk6CiAgICBudWNsaW9faW5pdF9ob29rKGNvbnRleHQsIGdsb2JhbHMoKSwgJ3NlcnZpbmdfdjInKQoKZGVmIGhhbmRsZXIoY29udGV4dCwgZXZlbnQpOgogICAgcmV0dXJuIGNvbnRleHQubWxydW5faGFuZGxlcihjb250ZXh0LCBldmVudCkK - code_origin: '' origin_filename: '' + functionSourceCode: IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKCmZyb20gYWJjIGltcG9ydCBBQkMKZnJvbSBpbXBvcnRsaWIgaW1wb3J0IGltcG9ydF9tb2R1bGUKCmltcG9ydCBtbHJ1bi5zZXJ2aW5nCmZyb20gdHJhbnNmb3JtZXJzIGltcG9ydCBwaXBlbGluZQoKUEFDS0FHRV9NT0RVTEUgPSAidHJhbnNmb3JtZXJzIgpTRVJJQUxJWkFCTEVfVFlQRVMgPSBbZGljdCwgbGlzdCwgdHVwbGUsIHN0ciwgaW50LCBmbG9hdF0KCgpjbGFzcyBIdWdnaW5nRmFjZU1vZGVsU2VydmVyKG1scnVuLnNlcnZpbmcuVjJNb2RlbFNlcnZlciwgQUJDKToKICAgICIiIgogICAgSHVnZ2luZyBGYWNlIE1vZGVsIHNlcnZpbmcgY2xhc3MsIGluaGVyaXRpbmcgdGhlIFYyTW9kZWxTZXJ2ZXIgY2xhc3MgZm9yIGJlaW5nIGluaXRpYWxpemVkIGF1dG9tYXRpY2FsbHkgYnkgdGhlCiAgICBtb2RlbCBzZXJ2ZXIgYW5kIGJlIGFibGUgdG8gcnVuIGxvY2FsbHkgYXMgcGFydCBvZiBhIG51Y2xpbyBzZXJ2ZXJsZXNzIGZ1bmN0aW9uLCBvciBhcyBwYXJ0IG9mIGEgcmVhbC10aW1lIHBpcGVsaW5lLgogICAgIiIiCgogICAgZGVmIF9faW5pdF9fKAogICAgICAgIHNlbGYsCiAgICAgICAgY29udGV4dDogbWxydW4uTUxDbGllbnRDdHgsCiAgICAgICAgbmFtZTogc3RyLAogICAgICAgIHRhc2s6IHN0ciwKICAgICAgICBtb2RlbF9wYXRoOiBzdHIgPSBOb25lLAogICAgICAgIG1vZGVsX25hbWU6I
HN0ciA9IE5vbmUsCiAgICAgICAgbW9kZWxfY2xhc3M6IHN0ciA9IE5vbmUsCiAgICAgICAgdG9rZW5pemVyX25hbWU6IHN0ciA9IE5vbmUsCiAgICAgICAgdG9rZW5pemVyX2NsYXNzOiBzdHIgPSBOb25lLAogICAgICAgIGZyYW1ld29yazogc3RyID0gTm9uZSwKICAgICAgICAqKmNsYXNzX2FyZ3MsCiAgICApOgogICAgICAgICIiIgogICAgICAgIEluaXRpYWxpemUgYSBzZXJ2aW5nIGNsYXNzIGZvciBhIEh1Z2dpbmcgZmFjZSBtb2RlbC4KCiAgICAgICAgOnBhcmFtIGNvbnRleHQ6ICAgICAgICAgVGhlIG1scnVuIGNvbnRleHQgdG8gd29yayB3aXRoCiAgICAgICAgOnBhcmFtIG5hbWU6ICAgICAgICAgICAgVGhlIG5hbWUgb2YgdGhpcyBzZXJ2ZXIgdG8gYmUgaW5pdGlhbGl6ZWQKICAgICAgICA6cGFyYW0gbW9kZWxfcGF0aDogICAgICBOb3QgaW4gdXNlLiBXaGVuIGFkZGluZyBhIG1vZGVsIHBhc3MgYW55IHN0cmluZyB2YWx1ZQogICAgICAgIDpwYXJhbSBtb2RlbF9uYW1lOiAgICAgIFRoZSBtb2RlbCdzIG5hbWUgaW4gdGhlIEh1Z2dpbmcgRmFjZSBodWIKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBlLmcuLCBgbmxwdG93bi9iZXJ0LWJhc2UtbXVsdGlsaW5ndWFsLXVuY2FzZWQtc2VudGltZW50YAogICAgICAgIDpwYXJhbSBtb2RlbF9jbGFzczogICAgIFRoZSBtb2RlbCdzIGNsYXNzIHR5cGUgb2JqZWN0IHdoaWNoIGNhbiBiZSBwYXNzZWQgYXMgdGhlIGNsYXNzJ3MgbmFtZSAoc3RyaW5nKS4KICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBNdXN0IGJlIHByb3ZpZGVkIGFuZCB0byBiZSBtYXRjaGVkIHdpdGggYG1vZGVsX25hbWVgLgogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGUuZy4sIGBBdXRvTW9kZWxGb3JTZXF1ZW5jZUNsYXNzaWZpY2F0aW9uYAogICAgICAgIDpwYXJhbSB0b2tlbml6ZXJfbmFtZTogIFRoZSB0b2tlbml6ZXIncyBuYW1lIGluIHRoZSBIdWdnaW5nIEZhY2UgaHViCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgZS5nLiwgYG5scHRvd24vYmVydC1iYXNlLW11bHRpbGluZ3VhbC11bmNhc2VkLXNlbnRpbWVudGAKICAgICAgICA6cGFyYW0gdG9rZW5pemVyX2NsYXNzOiBUaGUgbW9kZWwncyBjbGFzcyB0eXBlIG9iamVjdCB3aGljaCBjYW4gYmUgcGFzc2VkIGFzIHRoZSBjbGFzcydzIG5hbWUgKHN0cmluZykuCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgTXVzdCBiZSBwcm92aWRlZCBhbmQgdG8gYmUgbWF0Y2hlZCB3aXRoIGBtb2RlbF9uYW1lYC4KICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBlLmcuLCBgQXV0b1Rva2VuaXplcmAKICAgICAgICA6cGFyYW0gZnJhbWV3b3JrOiAgICAgICBUaGUgZnJhbWV3b3JrIHRvIHVzZSwgZWl0aGVyIGAicHQiYCBmb3IgUHlUb3JjaCBvciBgInRmImAgZm9yIFRlbnNvckZsb3cuIFRoZSBzcGVjaWZpZWQKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBmcmFtZXdvcmsgbXVzdCBiZSBpb
nN0YWxsZWQuCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgSWYgbm8gZnJhbWV3b3JrIGlzIHNwZWNpZmllZCwgd2lsbCBkZWZhdWx0IHRvIHRoZSBvbmUgY3VycmVudGx5IGluc3RhbGxlZC4KICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBJZiBubyBmcmFtZXdvcmsgaXMgc3BlY2lmaWVkIGFuZCBib3RoIGZyYW1ld29ya3MgYXJlIGluc3RhbGxlZCwgd2lsbCBkZWZhdWx0IHRvIHRoZQogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGZyYW1ld29yayBvZiB0aGUgYG1vZGVsYCwgb3IgdG8gUHlUb3JjaCBpZiBubyBtb2RlbCBpcyBwcm92aWRlZC4KICAgICAgICA6cGFyYW0gY2xhc3NfYXJnczogICAgICAtCiAgICAgICAgIiIiCiAgICAgICAgc3VwZXIoSHVnZ2luZ0ZhY2VNb2RlbFNlcnZlciwgc2VsZikuX19pbml0X18oCiAgICAgICAgICAgIGNvbnRleHQ9Y29udGV4dCwKICAgICAgICAgICAgbmFtZT1uYW1lLAogICAgICAgICAgICBtb2RlbF9wYXRoPW1vZGVsX3BhdGgsCiAgICAgICAgICAgICoqY2xhc3NfYXJncywKICAgICAgICApCiAgICAgICAgc2VsZi50YXNrID0gdGFzawogICAgICAgIHNlbGYubW9kZWwgPSBOb25lCiAgICAgICAgc2VsZi50b2tlbml6ZXIgPSBOb25lCiAgICAgICAgc2VsZi5tb2RlbF9uYW1lID0gbW9kZWxfbmFtZQogICAgICAgIHNlbGYudG9rZW5pemVyX25hbWUgPSB0b2tlbml6ZXJfbmFtZQogICAgICAgIHNlbGYubW9kZWxfY2xhc3MgPSBtb2RlbF9jbGFzcwogICAgICAgIHNlbGYudG9rZW5pemVyX2NsYXNzID0gdG9rZW5pemVyX2NsYXNzCiAgICAgICAgc2VsZi5mcmFtZXdvcmsgPSBmcmFtZXdvcmsKICAgICAgICBzZWxmLnBpcGUgPSBOb25lCgogICAgZGVmIGxvYWQoc2VsZik6CiAgICAgICAgIiIibG9hZCBhbmQgaW5pdGlhbGl6ZSB0aGUgbW9kZWwgYW5kL29yIG90aGVyIGVsZW1lbnRzIiIiCiAgICAgICAgaWYgc2VsZi5tb2RlbF9jbGFzczoKICAgICAgICAgICAgbW9kZWxfb2JqZWN0ID0gZ2V0YXR0cihpbXBvcnRfbW9kdWxlKFBBQ0tBR0VfTU9EVUxFKSwgc2VsZi5tb2RlbF9jbGFzcykKICAgICAgICAgICAgc2VsZi5tb2RlbCA9IG1vZGVsX29iamVjdC5mcm9tX3ByZXRyYWluZWQoc2VsZi5tb2RlbF9uYW1lKQogICAgICAgIGlmIHNlbGYudG9rZW5pemVyX2NsYXNzOgogICAgICAgICAgICB0b2tlbml6ZXJfb2JqZWN0ID0gZ2V0YXR0cigKICAgICAgICAgICAgICAgIGltcG9ydF9tb2R1bGUoUEFDS0FHRV9NT0RVTEUpLCBzZWxmLnRva2VuaXplcl9jbGFzcwogICAgICAgICAgICApCiAgICAgICAgICAgIHNlbGYudG9rZW5pemVyID0gdG9rZW5pemVyX29iamVjdC5mcm9tX3ByZXRyYWluZWQoc2VsZi50b2tlbml6ZXJfbmFtZSkKICAgICAgICBzZWxmLnBpcGUgPSBwaXBlbGluZSgKICAgICAgICAgICAgdGFzaz1zZWxmLnRhc2ssCiAgICAgICAgICAgIG1vZGVsPXNlbGYubW9kZWwgb3Igc2VsZi5tb2RlbF9uYW1lLAogICAgICAgICAgICB0b2tlbml6ZXI9c
2VsZi50b2tlbml6ZXIsCiAgICAgICAgICAgIGZyYW1ld29yaz1zZWxmLmZyYW1ld29yaywKICAgICAgICApCgogICAgZGVmIHByZWRpY3Qoc2VsZiwgYm9keTogZGljdCkgLT4gbGlzdDoKICAgICAgICAiIiJHZW5lcmF0ZSBtb2RlbCBwcmVkaWN0aW9ucyBmcm9tIHNhbXBsZS4iIiIKICAgICAgICBpZiBzZWxmLnBpcGUgaXMgTm9uZToKICAgICAgICAgICAgcmFpc2UgVmFsdWVFcnJvcigiUGxlYXNlIHVzZSBgLmxvYWQoKWAiKQogICAgICAgIHRyeToKICAgICAgICAgICAgaWYgaXNpbnN0YW5jZShib2R5WyJpbnB1dHMiXVswXSwgZGljdCk6CiAgICAgICAgICAgICAgICByZXN1bHQgPSBbc2VsZi5waXBlKCoqX2lucHV0KSBmb3IgX2lucHV0IGluIGJvZHlbImlucHV0cyJdXQogICAgICAgICAgICBlbHNlOgogICAgICAgICAgICAgICAgcmVzdWx0ID0gc2VsZi5waXBlKGJvZHlbImlucHV0cyJdKQogICAgICAgICAgICAjIHJlcGxhY2UgbGlzdCBvZiBsaXN0cyBvZiBkaWN0cyBpbnRvIGEgbGlzdCBvZiBkaWN0czoKICAgICAgICAgICAgaWYgYWxsKGlzaW5zdGFuY2UocmVzLCBsaXN0KSBmb3IgcmVzIGluIHJlc3VsdCk6CiAgICAgICAgICAgICAgICBuZXdfcmVzdWx0ID0gW3Jlc1swXSBmb3IgcmVzIGluIHJlc3VsdF0KICAgICAgICAgICAgICAgIHJlc3VsdCA9IG5ld19yZXN1bHQKCiAgICAgICAgICAgIG5vbl9zZXJpYWxpemFibGVfdHlwZXMgPSBbXQogICAgICAgICAgICBmb3IgcmVzIGluIHJlc3VsdDoKICAgICAgICAgICAgICAgIGZvciBrZXksIHZhbCBpbiByZXMuaXRlbXMoKToKICAgICAgICAgICAgICAgICAgICBpZiB0eXBlKHZhbCkgbm90IGluIFNFUklBTElaQUJMRV9UWVBFUzoKICAgICAgICAgICAgICAgICAgICAgICAgbm9uX3NlcmlhbGl6YWJsZV90eXBlcy5hcHBlbmQoc3RyKHR5cGUodmFsKSkpCiAgICAgICAgICAgICAgICAgICAgICAgIHJlc1trZXldID0gc3RyKHZhbCkKICAgICAgICAgICAgaWYgbm9uX3NlcmlhbGl6YWJsZV90eXBlczoKICAgICAgICAgICAgICAgIHNlbGYuY29udGV4dC5sb2dnZXIuaW5mbygKICAgICAgICAgICAgICAgICAgICBmIk5vbi1zZXJpYWxpemFibGUgdHlwZXM6IHtub25fc2VyaWFsaXphYmxlX3R5cGVzfSB3ZXJlIGNhc3RlZCB0byBzdHJpbmdzIgogICAgICAgICAgICAgICAgKQogICAgICAgIGV4Y2VwdCBFeGNlcHRpb24gYXMgZToKICAgICAgICAgICAgcmFpc2UgRXhjZXB0aW9uKCJGYWlsZWQgdG8gcHJlZGljdCAlcyIgJSBlKQogICAgICAgIHJldHVybiByZXN1bHQKCmZyb20gbWxydW4ucnVudGltZXMgaW1wb3J0IG51Y2xpb19pbml0X2hvb2sKZGVmIGluaXRfY29udGV4dChjb250ZXh0KToKICAgIG51Y2xpb19pbml0X2hvb2soY29udGV4dCwgZ2xvYmFscygpLCAnc2VydmluZ192MicpCgpkZWYgaGFuZGxlcihjb250ZXh0LCBldmVudCk6CiAgICByZXR1cm4gY29udGV4dC5tbHJ1bl9oYW5kbGVyKGNvbnRleHQsIGV2ZW50KQo= requirements: - 
transformers==4.21.3 - tensorflow==2.9.2 - function_kind: serving_v2 + code_origin: '' + filename: hugging_face_serving.py default_class: HuggingFaceModelServer - base_image_pull: false - max_replicas: 4 + min_replicas: 1 command: '' - disable_auto_mount: false - function_handler: hugging-face-serving-nuclio:handler + default_handler: '' + source: '' + max_replicas: 4 + base_image_pull: false description: Generic Hugging Face model server. + function_kind: serving_v2 + function_handler: hugging-face-serving-nuclio:handler env: - name: MLRUN_HTTPDB__NUCLIO__EXPLICIT_ACK value: enabled -verbose: false -kind: serving diff --git a/functions/src/hugging_face_serving/hugging_face_serving.py b/functions/src/hugging_face_serving/hugging_face_serving.py index 06dc4207f..31ef144d1 100644 --- a/functions/src/hugging_face_serving/hugging_face_serving.py +++ b/functions/src/hugging_face_serving/hugging_face_serving.py @@ -15,11 +15,9 @@ from abc import ABC from importlib import import_module -from typing import List - -from transformers import pipeline import mlrun.serving +from transformers import pipeline PACKAGE_MODULE = "transformers" SERIALIZABLE_TYPES = [dict, list, tuple, str, int, float] @@ -100,7 +98,7 @@ def load(self): framework=self.framework, ) - def predict(self, body: dict) -> List: + def predict(self, body: dict) -> list: """Generate model predictions from sample.""" if self.pipe is None: raise ValueError("Please use `.load()`") diff --git a/functions/src/hugging_face_serving/test_hugging_face_serving.py b/functions/src/hugging_face_serving/test_hugging_face_serving.py index 6fdc02dd3..da1c68ec9 100644 --- a/functions/src/hugging_face_serving/test_hugging_face_serving.py +++ b/functions/src/hugging_face_serving/test_hugging_face_serving.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import mlrun import numpy as np import pytest -import mlrun - CLASS_NAME = "HuggingFaceModelServer" PIPELINES = [ @@ -81,7 +80,7 @@ def test_default_models(pipeline): ) server = serving_function.to_mock_server() result = server.test( - f'/v2/models/{pipeline["task"]}', body={"inputs": [pipeline["example"]]} + f"/v2/models/{pipeline['task']}", body={"inputs": [pipeline["example"]]} ) prediction = result["outputs"][0] assert all( @@ -90,7 +89,6 @@ def test_default_models(pipeline): def test_local_model_serving(): - serving_function = mlrun.import_function("function.yaml") # Adding model: diff --git a/functions/src/load_dataset/function.yaml b/functions/src/load_dataset/function.yaml index 91775a802..5fb3ca19f 100644 --- a/functions/src/load_dataset/function.yaml +++ b/functions/src/load_dataset/function.yaml @@ -1,40 +1,22 @@ -kind: job metadata: - name: load-dataset tag: '' - hash: d05aa41d618533335eeaeab38aa434a14e3e3980 - project: '' - labels: - author: Iguazio - framework: sklearn + name: load-dataset categories: - data-preparation +verbose: false +kind: job spec: - command: '' - args: [] image: mlrun/mlrun + disable_auto_mount: false build: + origin_filename: '' functionSourceCode: 
IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKaW1wb3J0IG51bXB5IGFzIG5wCmltcG9ydCBwYW5kYXMgYXMgcGQKZnJvbSBtbHJ1bi5leGVjdXRpb24gaW1wb3J0IE1MQ2xpZW50Q3R4CgoKZGVmIGxvYWRfZGF0YXNldCgKICAgIGNvbnRleHQ6IE1MQ2xpZW50Q3R4LAogICAgZGF0YXNldDogc3RyLAogICAgbmFtZTogc3RyID0gIiIsCiAgICBmaWxlX2V4dDogc3RyID0gInBhcnF1ZXQiLAogICAgcGFyYW1zOiBkaWN0ID0ge30sCikgLT4gTm9uZToKICAgICIiIkxvYWRzIGEgc2Npa2l0LWxlYXJuIHRveSBkYXRhc2V0IGZvciBjbGFzc2lmaWNhdGlvbiBvciByZWdyZXNzaW9uCgogICAgVGhlIGZvbGxvd2luZyBkYXRhc2V0cyBhcmUgYXZhaWxhYmxlICgnbmFtZScgOiBkZXNyaXB0aW9uKToKCiAgICAgICAgJ2Jvc3RvbicgICAgICAgICAgOiBib3N0b24gaG91c2UtcHJpY2VzIGRhdGFzZXQgKHJlZ3Jlc3Npb24pCiAgICAgICAgJ2lyaXMnICAgICAgICAgICAgOiBpcmlzIGRhdGFzZXQgKGNsYXNzaWZpY2F0aW9uKQogICAgICAgICdkaWFiZXRlcycgICAgICAgIDogZGlhYmV0ZXMgZGF0YXNldCAocmVncmVzc2lvbikKICAgICAgICAnZGlnaXRzJyAgICAgICAgICA6IGRpZ2l0cyBkYXRhc2V0IChjbGFzc2lmaWNhdGlvbikKICAgICAgICAnbGlubmVydWQnICAgICAgICA6IGxpbm5lcnVkIGRhdGFzZXQgKG11bHRpdmFyaWF0ZSByZWdyZXNzaW9uKQogICAgICAgICd3aW5lJyAgICAgICAgICAgIDogd2luZSBkYXRhc2V0IChjbGFzc2lmaWNhdGlvbikKICAgICAgICAnYnJlYXN0X2NhbmNlcicgICA6IGJyZWFzdCBjYW5jZXIgd2lzY29uc2luIGRhdGFzZXQgKGNsYXNzaWZpY2F0aW9uKQoKICAgIFRoZSBzY2lraXQtbGVhcm4gZnVuY3Rpb25zIHJldHVybiBhIGRhdGEgYnVuY2ggaW5jbHVkaW5nIHRoZSBmb2xsb3dpbmcgaXRlbXM6CiAgICAtIGRhdGEgICAgICAgICAgICAgIHRo
ZSBmZWF0dXJlcyBtYXRyaXgKICAgIC0gdGFyZ2V0ICAgICAgICAgICAgdGhlIGdyb3VuZCB0cnV0aCBsYWJlbHMKICAgIC0gREVTQ1IgICAgICAgICAgICAgYSBkZXNjcmlwdGlvbiBvZiB0aGUgZGF0YXNldAogICAgLSBmZWF0dXJlX25hbWVzICAgICBoZWFkZXIgZm9yIGRhdGEKCiAgICBUaGUgZmVhdHVyZXMgKGFuZCB0aGVpciBuYW1lcykgYXJlIHN0b3JlZCB3aXRoIHRoZSB0YXJnZXQgbGFiZWxzIGluIGEgRGF0YUZyYW1lLgoKICAgIEZvciBmdXJ0aGVyIGRldGFpbHMgc2VlIGh0dHBzOi8vc2Npa2l0LWxlYXJuLm9yZy9zdGFibGUvZGF0YXNldHMvaW5kZXguaHRtbCN0b3ktZGF0YXNldHMKCiAgICA6cGFyYW0gY29udGV4dDogICAgZnVuY3Rpb24gZXhlY3V0aW9uIGNvbnRleHQKICAgIDpwYXJhbSBkYXRhc2V0OiAgICBuYW1lIG9mIHRoZSBkYXRhc2V0IHRvIGxvYWQKICAgIDpwYXJhbSBuYW1lOiAgICAgICBhcnRpZmFjdCBuYW1lIChkZWZhdWx0cyB0byBkYXRhc2V0KQogICAgOnBhcmFtIGZpbGVfZXh0OiAgIG91dHB1dCBmaWxlX2V4dDogcGFycXVldCBvciBjc3YKICAgIDpwYXJhbSBwYXJhbXM6ICAgICBwYXJhbXMgb2YgdGhlIHNrbGVhcm4gbG9hZF9kYXRhIG1ldGhvZAogICAgIiIiCiAgICBkYXRhc2V0ID0gc3RyKGRhdGFzZXQpCiAgICBwa2dfbW9kdWxlID0gInNrbGVhcm4uZGF0YXNldHMiCiAgICBmbmFtZSA9IGYibG9hZF97ZGF0YXNldH0iCgogICAgcGtnX21vZHVsZSA9IF9faW1wb3J0X18ocGtnX21vZHVsZSwgZnJvbWxpc3Q9W2ZuYW1lXSkKICAgIGxvYWRfZGF0YV9mbiA9IGdldGF0dHIocGtnX21vZHVsZSwgZm5hbWUpCgogICAgZGF0YSA9IGxvYWRfZGF0YV9mbigqKnBhcmFtcykKICAgIGZlYXR1cmVfbmFtZXMgPSBkYXRhWyJmZWF0dXJlX25hbWVzIl0KCiAgICB4eSA9IG5wLmNvbmNhdGVuYXRlKFtkYXRhWyJkYXRhIl0sIGRhdGFbInRhcmdldCJdLnJlc2hhcGUoLTEsIDEpXSwgYXhpcz0xKQogICAgaWYgaGFzYXR0cihmZWF0dXJlX25hbWVzLCAiYXBwZW5kIik6CiAgICAgICAgZmVhdHVyZV9uYW1lcy5hcHBlbmQoImxhYmVscyIpCiAgICBlbHNlOgogICAgICAgIGZlYXR1cmVfbmFtZXMgPSBucC5hcHBlbmQoZmVhdHVyZV9uYW1lcywgImxhYmVscyIpCiAgICBkZiA9IHBkLkRhdGFGcmFtZShkYXRhPXh5LCBjb2x1bW5zPWZlYXR1cmVfbmFtZXMpCgogICAgY29udGV4dC5sb2dfZGF0YXNldChuYW1lIG9yIGRhdGFzZXQsIGRmPWRmLCBmb3JtYXQ9ZmlsZV9leHQsIGluZGV4PUZhbHNlKQo= - commands: [] code_origin: '' - origin_filename: '' - requirements: [] + filename: load_dataset.py entry_points: load_dataset: - name: load_dataset - doc: "Loads a scikit-learn toy dataset for classification or regression\n\n\ - The following datasets are available ('name' : desription):\n\n 'boston'\ - \ : 
boston house-prices dataset (regression)\n 'iris' \ - \ : iris dataset (classification)\n 'diabetes' : diabetes dataset\ - \ (regression)\n 'digits' : digits dataset (classification)\n\ - \ 'linnerud' : linnerud dataset (multivariate regression)\n 'wine'\ - \ : wine dataset (classification)\n 'breast_cancer' : breast\ - \ cancer wisconsin dataset (classification)\n\nThe scikit-learn functions\ - \ return a data bunch including the following items:\n- data \ - \ the features matrix\n- target the ground truth labels\n- DESCR\ - \ a description of the dataset\n- feature_names header for\ - \ data\n\nThe features (and their names) are stored with the target labels\ - \ in a DataFrame.\n\nFor further details see https://scikit-learn.org/stable/datasets/index.html#toy-datasets" + outputs: + - type: None parameters: - name: context type: MLClientCtx @@ -54,19 +36,23 @@ spec: type: dict doc: params of the sklearn load_data method default: {} - outputs: - - type: None - lineno: 20 - has_varargs: false + name: load_dataset + doc: "Loads a scikit-learn toy dataset for classification or regression\n\n\ + The following datasets are available ('name' : desription):\n\n 'boston'\ + \ : boston house-prices dataset (regression)\n 'iris' \ + \ : iris dataset (classification)\n 'diabetes' : diabetes dataset\ + \ (regression)\n 'digits' : digits dataset (classification)\n\ + \ 'linnerud' : linnerud dataset (multivariate regression)\n 'wine'\ + \ : wine dataset (classification)\n 'breast_cancer' : breast\ + \ cancer wisconsin dataset (classification)\n\nThe scikit-learn functions\ + \ return a data bunch including the following items:\n- data \ + \ the features matrix\n- target the ground truth labels\n- DESCR\ + \ a description of the dataset\n- feature_names header for\ + \ data\n\nThe features (and their names) are stored with the target labels\ + \ in a DataFrame.\n\nFor further details see https://scikit-learn.org/stable/datasets/index.html#toy-datasets" has_kwargs: false + 
has_varargs: false + lineno: 20 + command: '' description: load a toy dataset from scikit-learn default_handler: load_dataset - disable_auto_mount: false - clone_target_dir: '' - env: [] - priority_class_name: '' - preemption_mode: prevent - affinity: null - tolerations: null - security_context: {} -verbose: false diff --git a/functions/src/mlflow_utils/function.yaml b/functions/src/mlflow_utils/function.yaml index 623f054fb..96d04602d 100644 --- a/functions/src/mlflow_utils/function.yaml +++ b/functions/src/mlflow_utils/function.yaml @@ -1,32 +1,33 @@ +metadata: + tag: '' + name: mlflow-utils + categories: + - model-serving + - utils verbose: false +kind: serving spec: - command: '' - source: '' - default_class: MLFlowModelServer - function_kind: serving_v2 + image: mlrun/mlrun + disable_auto_mount: false build: - functionSourceCode: aW1wb3J0IHppcGZpbGUKZnJvbSB0eXBpbmcgaW1wb3J0IEFueSwgRGljdAppbXBvcnQgbWxmbG93CmZyb20gbWxydW4uc2VydmluZy52Ml9zZXJ2aW5nIGltcG9ydCBWMk1vZGVsU2VydmVyCmltcG9ydCBwYW5kYXMgYXMgcGQKCgpjbGFzcyBNTEZsb3dNb2RlbFNlcnZlcihWMk1vZGVsU2VydmVyKToKICAgICIiIgogICAgTUxGbG93IHRyYWNrZXIgTW9kZWwgc2VydmluZyBjbGFzcywgaW5oZXJpdGluZyB0aGUgVjJNb2RlbFNlcnZlciBjbGFzcyBmb3IgYmVpbmcgaW5pdGlhbGl6ZWQgYXV0b21hdGljYWxseSBieSB0aGUgbW9kZWwKICAgIHNlcnZlciBhbmQgYmUgYWJsZSB0byBydW4gbG9jYWxseSBhcyBwYXJ0IG9mIGEgbnVjbGlvIHNlcnZlcmxlc3MgZnVuY3Rpb24sIG9yIGFzIHBhcnQgb2YgYSByZWFsLXRpbWUgcGlwZWxpbmUuCiAgICAiIiIKCiAgICBkZWYgbG9hZChzZWxmKToKICAgICAgICAiIiIKICAgICAgICBsb2FkcyBhIG1vZGVsIHRoYXQgd2FzIGxvZ2dlZCBieSB0aGUgTUxGbG93IHRyYWNrZXIgbW9kZWwKICAgICAgICAiIiIKICAgICAgICAjIFVuemlwIHRoZSBtb2RlbCBkaXIgYW5kIHRoZW4gdXNlIG1sZmxvdydzIGxvYWQgZnVuY3Rpb24KICAgICAgICBtb2RlbF9maWxlLCBfID0gc2VsZi5nZXRfbW9kZWwoIi56aXAiKQogICAgICAgIG1vZGVsX3BhdGhfdW56aXAgPSBtb2RlbF9maWxlLnJlcGxhY2UoIi56aXAiLCAiIikKCiAgICAgICAgd2l0aCB6aXBmaWxlLlppcEZpbGUobW9kZWxfZmlsZSwgInIiKSBhcyB6aXBfcmVmOgogICAgICAgICAgICB6aXBfcmVmLmV4dHJhY3RhbGwobW9kZWxfcGF0aF91bnppcCkKCiAgICAgICAgc2VsZi5tb2RlbCA9IG1sZmxvdy5weWZ1bmMubG9hZF9tb2RlbChtb
2RlbF9wYXRoX3VuemlwKQoKICAgIGRlZiBwcmVkaWN0KHNlbGYsIHJlcXVlc3Q6IERpY3Rbc3RyLCBBbnldKSAtPiBsaXN0OgogICAgICAgICIiIgogICAgICAgIEluZmVyIHRoZSBpbnB1dHMgdGhyb3VnaCB0aGUgbW9kZWwuIFRoZSBpbmZlcnJlZCBkYXRhIHdpbGwKICAgICAgICBiZSByZWFkIGZyb20gdGhlICJpbnB1dHMiIGtleSBvZiB0aGUgcmVxdWVzdC4KCiAgICAgICAgOnBhcmFtIHJlcXVlc3Q6IFRoZSByZXF1ZXN0IHRvIHRoZSBtb2RlbCB1c2luZyB4Z2Jvb3N0J3MgcHJlZGljdC4KICAgICAgICAgICAgICAgIFRoZSBpbnB1dCB0byB0aGUgbW9kZWwgd2lsbCBiZSByZWFkIGZyb20gdGhlICJpbnB1dHMiIGtleS4KCiAgICAgICAgOnJldHVybjogVGhlIG1vZGVsJ3MgcHJlZGljdGlvbiBvbiB0aGUgZ2l2ZW4gaW5wdXQuCiAgICAgICAgIiIiCgogICAgICAgICMgR2V0IHRoZSBpbnB1dHMgYW5kIHNldCB0byBhY2NlcHRlZCB0eXBlOgogICAgICAgIGlucHV0cyA9IHBkLkRhdGFGcmFtZShyZXF1ZXN0WyJpbnB1dHMiXSkKCiAgICAgICAgIyBQcmVkaWN0IHVzaW5nIHRoZSBtb2RlbCdzIHByZWRpY3QgZnVuY3Rpb246CiAgICAgICAgcHJlZGljdGlvbnMgPSBzZWxmLm1vZGVsLnByZWRpY3QoaW5wdXRzKQoKICAgICAgICAjIFJldHVybiBhcyBsaXN0OgogICAgICAgIHJldHVybiBwcmVkaWN0aW9ucy50b2xpc3QoKQoKZnJvbSBtbHJ1bi5ydW50aW1lcyBpbXBvcnQgbnVjbGlvX2luaXRfaG9vawpkZWYgaW5pdF9jb250ZXh0KGNvbnRleHQpOgogICAgbnVjbGlvX2luaXRfaG9vayhjb250ZXh0LCBnbG9iYWxzKCksICdzZXJ2aW5nX3YyJykKCmRlZiBoYW5kbGVyKGNvbnRleHQsIGV2ZW50KToKICAgIHJldHVybiBjb250ZXh0Lm1scnVuX2hhbmRsZXIoY29udGV4dCwgZXZlbnQpCg== + origin_filename: '' + functionSourceCode: 
aW1wb3J0IHppcGZpbGUKZnJvbSB0eXBpbmcgaW1wb3J0IEFueQoKaW1wb3J0IG1sZmxvdwppbXBvcnQgcGFuZGFzIGFzIHBkCmZyb20gbWxydW4uc2VydmluZy52Ml9zZXJ2aW5nIGltcG9ydCBWMk1vZGVsU2VydmVyCgoKY2xhc3MgTUxGbG93TW9kZWxTZXJ2ZXIoVjJNb2RlbFNlcnZlcik6CiAgICAiIiIKICAgIE1MRmxvdyB0cmFja2VyIE1vZGVsIHNlcnZpbmcgY2xhc3MsIGluaGVyaXRpbmcgdGhlIFYyTW9kZWxTZXJ2ZXIgY2xhc3MgZm9yIGJlaW5nIGluaXRpYWxpemVkIGF1dG9tYXRpY2FsbHkgYnkgdGhlIG1vZGVsCiAgICBzZXJ2ZXIgYW5kIGJlIGFibGUgdG8gcnVuIGxvY2FsbHkgYXMgcGFydCBvZiBhIG51Y2xpbyBzZXJ2ZXJsZXNzIGZ1bmN0aW9uLCBvciBhcyBwYXJ0IG9mIGEgcmVhbC10aW1lIHBpcGVsaW5lLgogICAgIiIiCgogICAgZGVmIGxvYWQoc2VsZik6CiAgICAgICAgIiIiCiAgICAgICAgbG9hZHMgYSBtb2RlbCB0aGF0IHdhcyBsb2dnZWQgYnkgdGhlIE1MRmxvdyB0cmFja2VyIG1vZGVsCiAgICAgICAgIiIiCiAgICAgICAgIyBVbnppcCB0aGUgbW9kZWwgZGlyIGFuZCB0aGVuIHVzZSBtbGZsb3cncyBsb2FkIGZ1bmN0aW9uCiAgICAgICAgbW9kZWxfZmlsZSwgXyA9IHNlbGYuZ2V0X21vZGVsKCIuemlwIikKICAgICAgICBtb2RlbF9wYXRoX3VuemlwID0gbW9kZWxfZmlsZS5yZXBsYWNlKCIuemlwIiwgIiIpCgogICAgICAgIHdpdGggemlwZmlsZS5aaXBGaWxlKG1vZGVsX2ZpbGUsICJyIikgYXMgemlwX3JlZjoKICAgICAgICAgICAgemlwX3JlZi5leHRyYWN0YWxsKG1vZGVsX3BhdGhfdW56aXApCgogICAgICAgIHNlbGYubW9kZWwgPSBtbGZsb3cucHlmdW5jLmxvYWRfbW9kZWwobW9kZWxfcGF0aF91bnppcCkKCiAgICBkZWYgcHJlZGljdChzZWxmLCByZXF1ZXN0OiBkaWN0W3N0ciwgQW55XSkgLT4gbGlzdDoKICAgICAgICAiIiIKICAgICAgICBJbmZlciB0aGUgaW5wdXRzIHRocm91Z2ggdGhlIG1vZGVsLiBUaGUgaW5mZXJyZWQgZGF0YSB3aWxsCiAgICAgICAgYmUgcmVhZCBmcm9tIHRoZSAiaW5wdXRzIiBrZXkgb2YgdGhlIHJlcXVlc3QuCgogICAgICAgIDpwYXJhbSByZXF1ZXN0OiBUaGUgcmVxdWVzdCB0byB0aGUgbW9kZWwgdXNpbmcgeGdib29zdCdzIHByZWRpY3QuCiAgICAgICAgICAgICAgICBUaGUgaW5wdXQgdG8gdGhlIG1vZGVsIHdpbGwgYmUgcmVhZCBmcm9tIHRoZSAiaW5wdXRzIiBrZXkuCgogICAgICAgIDpyZXR1cm46IFRoZSBtb2RlbCdzIHByZWRpY3Rpb24gb24gdGhlIGdpdmVuIGlucHV0LgogICAgICAgICIiIgoKICAgICAgICAjIEdldCB0aGUgaW5wdXRzIGFuZCBzZXQgdG8gYWNjZXB0ZWQgdHlwZToKICAgICAgICBpbnB1dHMgPSBwZC5EYXRhRnJhbWUocmVxdWVzdFsiaW5wdXRzIl0pCgogICAgICAgICMgUHJlZGljdCB1c2luZyB0aGUgbW9kZWwncyBwcmVkaWN0IGZ1bmN0aW9uOgogICAgICAgIHByZWRpY3Rpb25zID0gc2VsZi5tb2RlbC5wcmVkaWN0KGlucHV0cykKCiAgICAg
ICAgIyBSZXR1cm4gYXMgbGlzdDoKICAgICAgICByZXR1cm4gcHJlZGljdGlvbnMudG9saXN0KCkKCmZyb20gbWxydW4ucnVudGltZXMgaW1wb3J0IG51Y2xpb19pbml0X2hvb2sKZGVmIGluaXRfY29udGV4dChjb250ZXh0KToKICAgIG51Y2xpb19pbml0X2hvb2soY29udGV4dCwgZ2xvYmFscygpLCAnc2VydmluZ192MicpCgpkZWYgaGFuZGxlcihjb250ZXh0LCBldmVudCk6CiAgICByZXR1cm4gY29udGV4dC5tbHJ1bl9oYW5kbGVyKGNvbnRleHQsIGV2ZW50KQo= requirements: - - mlflow==2.12.2 + - mlflow~=2.22 - lightgbm - xgboost code_origin: '' - origin_filename: '' - image: mlrun/mlrun - base_image_pull: false + filename: mlflow_utils.py + default_class: MLFlowModelServer + min_replicas: 1 + command: '' default_handler: '' + source: '' max_replicas: 4 - disable_auto_mount: false - min_replicas: 1 + base_image_pull: false description: Mlflow model server, and additional utils. + function_kind: serving_v2 function_handler: mlflow-utils-nuclio:handler env: - name: MLRUN_HTTPDB__NUCLIO__EXPLICIT_ACK value: enabled -metadata: - categories: - - model-serving - - utils - name: mlflow-utils - tag: '' -kind: serving diff --git a/functions/src/mlflow_utils/mlflow_utils.py b/functions/src/mlflow_utils/mlflow_utils.py index fb6124bef..cbcc78381 100644 --- a/functions/src/mlflow_utils/mlflow_utils.py +++ b/functions/src/mlflow_utils/mlflow_utils.py @@ -1,8 +1,9 @@ import zipfile -from typing import Any, Dict +from typing import Any + import mlflow -from mlrun.serving.v2_serving import V2ModelServer import pandas as pd +from mlrun.serving.v2_serving import V2ModelServer class MLFlowModelServer(V2ModelServer): @@ -24,7 +25,7 @@ def load(self): self.model = mlflow.pyfunc.load_model(model_path_unzip) - def predict(self, request: Dict[str, Any]) -> list: + def predict(self, request: dict[str, Any]) -> list: """ Infer the inputs through the model. The inferred data will be read from the "inputs" key of the request. 
diff --git a/functions/src/mlflow_utils/test_mlflow_utils.py b/functions/src/mlflow_utils/test_mlflow_utils.py index 70d6ce03f..74dcefdbc 100644 --- a/functions/src/mlflow_utils/test_mlflow_utils.py +++ b/functions/src/mlflow_utils/test_mlflow_utils.py @@ -12,23 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os import tempfile import lightgbm as lgb import mlflow import mlflow.environment_variables import mlflow.xgboost + +# os.environ["MLRUN_IGNORE_ENV_FILE"] = "True" #TODO remove before push +import mlrun +import mlrun.launcher.local import pytest import xgboost as xgb from sklearn import datasets from sklearn.metrics import accuracy_score, log_loss from sklearn.model_selection import train_test_split -import os -# os.environ["MLRUN_IGNORE_ENV_FILE"] = "True" #TODO remove before push - -import mlrun -import mlrun.launcher.local # Important: # unlike mlconf which resets back to default after each test run, the mlflow configurations # and env vars don't, so at the end of each test we need to redo anything we set in that test. @@ -36,6 +36,7 @@ # name (last two using mlconf), failing run mid-way, and a run with no handler. # we also test here importing of runs, artifacts and models from a previous run. 
+ # simple mlflow example of lgb logging def lgb_run(): # prepare train and test data @@ -170,10 +171,10 @@ def test_track_run_with_experiment_name(handler): server = serving_func.to_mock_server() # An example taken randomly - result = server.test(f"/v2/models/{model_name}/predict", {"inputs": [[5.1, 3.5, 1.4, 0.2]]}) + result = server.test( + f"/v2/models/{model_name}/predict", {"inputs": [[5.1, 3.5, 1.4, 0.2]]} + ) print(result) assert result # unset mlflow experiment name to default mlflow.environment_variables.MLFLOW_EXPERIMENT_NAME.unset() - - diff --git a/functions/src/model_server/function.yaml b/functions/src/model_server/function.yaml index 83e80823d..20f85bf67 100644 --- a/functions/src/model_server/function.yaml +++ b/functions/src/model_server/function.yaml @@ -1,27 +1,28 @@ -kind: remote +metadata: + tag: '' + name: model-server + categories: + - model-serving + - machine-learning verbose: false +kind: remote spec: + image: mlrun/mlrun disable_auto_mount: false + build: + origin_filename: '' + functionSourceCode: 
IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKIyBHZW5lcmF0ZWQgYnkgbnVjbGlvLmV4cG9ydC5OdWNsaW9FeHBvcnRlcgoKaW1wb3J0IHdhcm5pbmdzCgppbXBvcnQgbWxydW4KZnJvbSBjbG91ZHBpY2tsZSBpbXBvcnQgbG9hZAoKd2FybmluZ3MuZmlsdGVyd2FybmluZ3MoImlnbm9yZSIpCgppbXBvcnQgbnVtcHkgYXMgbnAKCgpjbGFzcyBDbGFzc2lmaWVyTW9kZWwobWxydW4ucnVudGltZXMuTUxNb2RlbFNlcnZlcik6CiAgICBkZWYgbG9hZChzZWxmKToKICAgICAgICAiIiJMb2FkIG1vZGVsIGZyb20gc3RvcmFnZS4iIiIKICAgICAgICBtb2RlbF9maWxlLCBleHRyYV9kYXRhID0gc2VsZi5nZXRfbW9kZWwoIi5wa2wiKQogICAgICAgIHNlbGYubW9kZWwgPSBsb2FkKG9wZW4obW9kZWxfZmlsZSwgInJiIikpCgogICAgZGVmIHByZWRpY3Qoc2VsZiwgYm9keTogZGljdCkgLT4gbGlzdDoKICAgICAgICAiIiJHZW5lcmF0ZSBtb2RlbCBwcmVkaWN0aW9ucyBmcm9tIHNhbXBsZS4KCiAgICAgICAgOnBhcmFtIGJvZHkgOiBBIGRpY3Qgb2Ygb2JzZXJ2YXRpb25zLCBlYWNoIG9mIHdoaWNoIGlzIGFuIDEtZGltZW5zaW9uYWwgZmVhdHVyZSB2ZWN0b3IuCgogICAgICAgIFJldHVybnMgbW9kZWwgcHJlZGljdGlvbnMgYXMgYSBgTGlzdGAsIG9uZSBmb3IgZWFjaCByb3cgaW4gdGhlIGBib2R5YCBpbnB1dCBgTGlzdGAuCiAgICAgICAgIiIiCiAgICAgICAgdHJ5OgogICAgICAgICAgICBmZWF0cyA9IG5wLmFzYXJyYXkoYm9keVsiaW5zdGFuY2VzIl0pCiAgICAgICAgICAgIHJlc3VsdDogbnAubmRhcnJheSA9IHNlbGYubW9kZWwucHJlZGljdChmZWF0cykKICAgICAgICAgICAgcmVzcCA9IHJlc3VsdC50b2xpc3QoKQogICAgICAgIGV4Y2VwdCBFeGNlcHRpb24gYXMgZToKICAgICAgICAgICAgcmFpc2UgRXhjZXB0aW9uKGYiRmFpbGVkIHRvIHByZWRpY3Qge2V9IikKCiAgICAgICAg
cmV0dXJuIHJlc3AKCmZyb20gbWxydW4ucnVudGltZXMgaW1wb3J0IG51Y2xpb19pbml0X2hvb2sKZGVmIGluaXRfY29udGV4dChjb250ZXh0KToKICAgIG51Y2xpb19pbml0X2hvb2soY29udGV4dCwgZ2xvYmFscygpLCAnc2VydmluZycpCgpkZWYgaGFuZGxlcihjb250ZXh0LCBldmVudCk6CiAgICByZXR1cm4gY29udGV4dC5tbHJ1bl9oYW5kbGVyKGNvbnRleHQsIGV2ZW50KQo= + code_origin: '' + filename: model_server.py min_replicas: 1 - source: '' - description: generic sklearn model server + command: '' default_handler: '' + source: '' max_replicas: 4 - image: mlrun/mlrun + base_image_pull: false + description: generic sklearn model server function_kind: serving function_handler: model-server-nuclio:handler - build: - origin_filename: '' - functionSourceCode: IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKIyBHZW5lcmF0ZWQgYnkgbnVjbGlvLmV4cG9ydC5OdWNsaW9FeHBvcnRlcgoKaW1wb3J0IG1scnVuCgpmcm9tIGNsb3VkcGlja2xlIGltcG9ydCBsb2FkCmZyb20gdHlwaW5nIGltcG9ydCBMaXN0CmZyb20gZGF0ZXRpbWUgaW1wb3J0IGRhdGV0aW1lCmZyb20gc2tsZWFybi5kYXRhc2V0cyBpbXBvcnQgbG9hZF9pcmlzCgppbXBvcnQgd2FybmluZ3MKCndhcm5pbmdzLmZpbHRlcndhcm5pbmdzKCJpZ25vcmUiKQoKaW1wb3J0IG9zCmltcG9ydCBudW1weSBhcyBucAoKCmNsYXNzIENsYXNzaWZpZXJNb2RlbChtbHJ1bi5ydW50aW1lcy5NTE1vZGVsU2VydmVyKToKICAgIGRlZiBsb2FkKHNlbGYpOgogICAgICAgICIiIkxvYWQgbW9kZWwgZnJvbSBzdG9yYWdlLiIiIgogICAgICAgIG1vZGVsX2ZpbGUsIGV4dHJhX2RhdGEgPSBzZWxmL
mdldF9tb2RlbCgiLnBrbCIpCiAgICAgICAgc2VsZi5tb2RlbCA9IGxvYWQob3Blbihtb2RlbF9maWxlLCAicmIiKSkKCiAgICBkZWYgcHJlZGljdChzZWxmLCBib2R5OiBkaWN0KSAtPiBMaXN0OgogICAgICAgICIiIkdlbmVyYXRlIG1vZGVsIHByZWRpY3Rpb25zIGZyb20gc2FtcGxlLgoKICAgICAgICA6cGFyYW0gYm9keSA6IEEgZGljdCBvZiBvYnNlcnZhdGlvbnMsIGVhY2ggb2Ygd2hpY2ggaXMgYW4gMS1kaW1lbnNpb25hbCBmZWF0dXJlIHZlY3Rvci4KCiAgICAgICAgUmV0dXJucyBtb2RlbCBwcmVkaWN0aW9ucyBhcyBhIGBMaXN0YCwgb25lIGZvciBlYWNoIHJvdyBpbiB0aGUgYGJvZHlgIGlucHV0IGBMaXN0YC4KICAgICAgICAiIiIKICAgICAgICB0cnk6CiAgICAgICAgICAgIGZlYXRzID0gbnAuYXNhcnJheShib2R5WyJpbnN0YW5jZXMiXSkKICAgICAgICAgICAgcmVzdWx0OiBucC5uZGFycmF5ID0gc2VsZi5tb2RlbC5wcmVkaWN0KGZlYXRzKQogICAgICAgICAgICByZXNwID0gcmVzdWx0LnRvbGlzdCgpCiAgICAgICAgZXhjZXB0IEV4Y2VwdGlvbiBhcyBlOgogICAgICAgICAgICByYWlzZSBFeGNlcHRpb24oZiJGYWlsZWQgdG8gcHJlZGljdCB7ZX0iKQoKICAgICAgICByZXR1cm4gcmVzcAoKZnJvbSBtbHJ1bi5ydW50aW1lcyBpbXBvcnQgbnVjbGlvX2luaXRfaG9vawpkZWYgaW5pdF9jb250ZXh0KGNvbnRleHQpOgogICAgbnVjbGlvX2luaXRfaG9vayhjb250ZXh0LCBnbG9iYWxzKCksICdzZXJ2aW5nJykKCmRlZiBoYW5kbGVyKGNvbnRleHQsIGV2ZW50KToKICAgIHJldHVybiBjb250ZXh0Lm1scnVuX2hhbmRsZXIoY29udGV4dCwgZXZlbnQpCg== - code_origin: '' - base_image_pull: false - command: '' env: - name: MLRUN_HTTPDB__NUCLIO__EXPLICIT_ACK value: enabled -metadata: - categories: - - model-serving - - machine-learning - name: model-server - tag: '' diff --git a/functions/src/model_server/model_server.py b/functions/src/model_server/model_server.py index cefdff235..3227a289c 100644 --- a/functions/src/model_server/model_server.py +++ b/functions/src/model_server/model_server.py @@ -14,18 +14,13 @@ # # Generated by nuclio.export.NuclioExporter -import mlrun +import warnings +import mlrun from cloudpickle import load -from typing import List -from datetime import datetime -from sklearn.datasets import load_iris - -import warnings warnings.filterwarnings("ignore") -import os import numpy as np @@ -35,7 +30,7 @@ def load(self): model_file, extra_data = self.get_model(".pkl") self.model = load(open(model_file, 
"rb")) - def predict(self, body: dict) -> List: + def predict(self, body: dict) -> list: """Generate model predictions from sample. :param body : A dict of observations, each of which is an 1-dimensional feature vector. diff --git a/functions/src/model_server/test_model_server.py b/functions/src/model_server/test_model_server.py index a11726bc7..a930ab736 100644 --- a/functions/src/model_server/test_model_server.py +++ b/functions/src/model_server/test_model_server.py @@ -12,38 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os +import pickle + +from model_server import ClassifierModel from sklearn.datasets import load_iris -from sklearn.model_selection import train_test_split -from sklearn.preprocessing import StandardScaler from sklearn.linear_model import LogisticRegression from sklearn.metrics import accuracy_score -from model_server import ClassifierModel -import pickle -import mlrun -import os -import requests -import json +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler + def gen_model(): # Getting the data - X,y = load_iris(return_X_y=True) - X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.2,random_state=123) + X, y = load_iris(return_X_y=True) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=123 + ) # transforming the data sc = StandardScaler() X_train = sc.fit_transform(X_train) X_test = sc.transform(X_test) # Getting the model and training it - classifier = LogisticRegression(random_state = 0, solver='lbfgs', multi_class='auto') + classifier = LogisticRegression(random_state=0, solver="lbfgs", multi_class="auto") classifier.fit(X_train, y_train) # saving the model - filename = os.getcwd()+'/model.pkl' - pickle.dump(classifier, open(filename, 'wb')) - return X_test,y_test + filename = os.getcwd() + "/model.pkl" + pickle.dump(classifier, open(filename, "wb")) + return 
X_test, y_test + def test_remote_model_server(): - x,y = gen_model() - my_class = ClassifierModel('iris',model_dir=os.getcwd()) + x, y = gen_model() + my_class = ClassifierModel("iris", model_dir=os.getcwd()) my_class.load() - my_dict = {'instances':x.tolist()} + my_dict = {"instances": x.tolist()} preds = my_class.predict(my_dict) - assert(accuracy_score(y,preds) > 0.8) + assert accuracy_score(y, preds) > 0.8 diff --git a/functions/src/model_server_tester/function.yaml b/functions/src/model_server_tester/function.yaml index 45934c444..ae176c1e6 100644 --- a/functions/src/model_server_tester/function.yaml +++ b/functions/src/model_server_tester/function.yaml @@ -1,35 +1,29 @@ -kind: job metadata: - name: model-server-tester tag: '' - hash: 3b203a2799e44992539eafd32a4b8979bbcc8001 - project: '' - labels: - author: Iguazio + name: model-server-tester categories: - monitoring - model-serving +verbose: false +kind: job spec: - command: '' - args: [] image: mlrun/mlrun - env: [] - default_handler: model_server_tester + disable_auto_mount: false + build: + origin_filename: '' + functionSourceCode: 
IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKIyBHZW5lcmF0ZWQgYnkgbnVjbGlvLmV4cG9ydC5OdWNsaW9FeHBvcnRlcgoKaW1wb3J0IGpzb24KZnJvbSBkYXRldGltZSBpbXBvcnQgZGF0ZXRpbWUKCmltcG9ydCBudW1weSBhcyBucAppbXBvcnQgcmVxdWVzdHMKZnJvbSBtbHJ1bi5hcnRpZmFjdHMgaW1wb3J0IENoYXJ0QXJ0aWZhY3QKZnJvbSBtbHJ1bi5kYXRhc3RvcmUgaW1wb3J0IERhdGFJdGVtCgoKZGVmIG1vZGVsX3NlcnZlcl90ZXN0ZXIoCiAgICBjb250ZXh0LAogICAgdGFibGU6IERhdGFJdGVtLAogICAgYWRkcjogc3RyLAogICAgbGFiZWxfY29sdW1uOiBzdHIgPSAibGFiZWwiLAogICAgbW9kZWw6IHN0ciA9ICIiLAogICAgbWF0Y2hfZXJyOiBib29sID0gRmFsc2UsCiAgICByb3dzOiBpbnQgPSAyMCwKKToKICAgICIiIlRlc3QgYSBtb2RlbCBzZXJ2ZXIKCiAgICA6cGFyYW0gdGFibGU6ICAgICAgICAgY3N2L3BhcnF1ZXQgdGFibGUgd2l0aCB0ZXN0IGRhdGEKICAgIDpwYXJhbSBhZGRyOiAgICAgICAgICBmdW5jdGlvbiBhZGRyZXNzL3VybAogICAgOnBhcmFtIGxhYmVsX2NvbHVtbjogIG5hbWUgb2YgdGhlIGxhYmVsIGNvbHVtbiBpbiB0YWJsZQogICAgOnBhcmFtIG1vZGVsOiAgICAgICAgIHRlc3RlZCBtb2RlbCBuYW1lCiAgICA6cGFyYW0gbWF0Y2hfZXJyOiAgICAgcmFpc2UgZXJyb3Igb24gdmFsaWRhdGlvbiAocmVxdWlyZSBwcm9wZXIgdGVzdCBzZXQpCiAgICA6cGFyYW0gcm93czogICAgICAgICAgbnVtYmVyIG9mIHJvd3MgdG8gdXNlIGZyb20gdGVzdCBzZXQKICAgICIiIgoKICAgIHRhYmxlID0gdGFibGUuYXNfZGYoKQoKICAgIHlfbGlzdCA9IHRhYmxlLnBvcChsYWJlbF9jb2x1bW4pLnZhbHVlcy50b2xpc3QoKQogICAgY29udGV4dC5sb2dnZXIuaW5mbyhmInRlc3Rpbmcgd2l0aCBkYXRhc2V0IGFnYWluc3Qge2FkZHJ9LCBtb2RlbDoge21vZGVs
fSIpCiAgICBpZiByb3dzIGFuZCByb3dzIDwgdGFibGUuc2hhcGVbMF06CiAgICAgICAgdGFibGUgPSB0YWJsZS5zYW1wbGUocm93cykKCiAgICBjb3VudCA9IGVycl9jb3VudCA9IG1hdGNoID0gMAogICAgdGltZXMgPSBbXQogICAgZm9yIHgsIHkgaW4gemlwKHRhYmxlLnZhbHVlcywgeV9saXN0KToKICAgICAgICBjb3VudCArPSAxCiAgICAgICAgZXZlbnRfZGF0YSA9IGpzb24uZHVtcHMoeyJpbnN0YW5jZXMiOiBbeC50b2xpc3QoKV19KQogICAgICAgIGhhZF9lcnIgPSBGYWxzZQogICAgICAgIHRyeToKICAgICAgICAgICAgc3RhcnQgPSBkYXRldGltZS5ub3coKQogICAgICAgICAgICByZXNwID0gcmVxdWVzdHMucHV0KGYie2FkZHJ9L3ttb2RlbH0vcHJlZGljdCIsIGpzb249ZXZlbnRfZGF0YSkKICAgICAgICAgICAgaWYgbm90IHJlc3Aub2s6CiAgICAgICAgICAgICAgICBjb250ZXh0LmxvZ2dlci5lcnJvcihmImJhZCBmdW5jdGlvbiByZXNwISFcbntyZXNwLnRleHR9IikKICAgICAgICAgICAgICAgIGVycl9jb3VudCArPSAxCiAgICAgICAgICAgICAgICBjb250aW51ZQogICAgICAgICAgICB0aW1lcy5hcHBlbmQoKGRhdGV0aW1lLm5vdygpIC0gc3RhcnQpLm1pY3Jvc2Vjb25kcykKCiAgICAgICAgZXhjZXB0IE9TRXJyb3IgYXMgZXJyOgogICAgICAgICAgICBjb250ZXh0LmxvZ2dlci5lcnJvcihmImVycm9yIGluIHJlcXVlc3QsIGRhdGE6e2V2ZW50X2RhdGF9LCBlcnJvcjoge2Vycn0iKQogICAgICAgICAgICBlcnJfY291bnQgKz0gMQogICAgICAgICAgICBjb250aW51ZQoKICAgICAgICB5X3Jlc3AgPSByZXNwLmpzb24oKVswXQogICAgICAgIGlmIHkgPT0geV9yZXNwOgogICAgICAgICAgICBtYXRjaCArPSAxCgogICAgY29udGV4dC5sb2dfcmVzdWx0KCJ0b3RhbF90ZXN0cyIsIGNvdW50KQogICAgY29udGV4dC5sb2dfcmVzdWx0KCJlcnJvcnMiLCBlcnJfY291bnQpCiAgICBjb250ZXh0LmxvZ19yZXN1bHQoIm1hdGNoIiwgbWF0Y2gpCiAgICBpZiBjb3VudCAtIGVycl9jb3VudCA+IDA6CiAgICAgICAgdGltZXNfYXJyID0gbnAuYXJyYXkodGltZXMpCiAgICAgICAgY29udGV4dC5sb2dfcmVzdWx0KCJhdmdfbGF0ZW5jeSIsIGludChucC5tZWFuKHRpbWVzX2FycikpKQogICAgICAgIGNvbnRleHQubG9nX3Jlc3VsdCgibWluX2xhdGVuY3kiLCBpbnQobnAuYW1pbih0aW1lc19hcnIpKSkKICAgICAgICBjb250ZXh0LmxvZ19yZXN1bHQoIm1heF9sYXRlbmN5IiwgaW50KG5wLmFtYXgodGltZXNfYXJyKSkpCgogICAgICAgIGNoYXJ0ID0gQ2hhcnRBcnRpZmFjdCgibGF0ZW5jeSIsIGhlYWRlcj1bIlRlc3QiLCAiTGF0ZW5jeSAobWljcm9zZWMpIl0pCiAgICAgICAgZm9yIGkgaW4gcmFuZ2UobGVuKHRpbWVzKSk6CiAgICAgICAgICAgIGNoYXJ0LmFkZF9yb3coW2kgKyAxLCBpbnQodGltZXNbaV0pXSkKICAgICAgICBjb250ZXh0LmxvZ19hcnRpZmFjdChjaGFydCkKCiAgICBjb250ZXh0LmxvZ2dlci5pbmZvKAogICAgICAg
IGYicnVuIHtjb3VudH0gdGVzdHMsIHtlcnJfY291bnR9IGVycm9ycyBhbmQge21hdGNofSBtYXRjaCBleHBlY3RlZCB2YWx1ZSIKICAgICkKCiAgICBpZiBlcnJfY291bnQ6CiAgICAgICAgcmFpc2UgVmFsdWVFcnJvcihmImZhaWxlZCBvbiB7ZXJyX2NvdW50fSB0ZXN0cyBvZiB7Y291bnR9IikKCiAgICBpZiBtYXRjaF9lcnIgYW5kIG1hdGNoICE9IGNvdW50OgogICAgICAgIHJhaXNlIFZhbHVlRXJyb3IoZiJvbmx5IHttYXRjaH0gcmVzdWx0cyBtYXRjaCBvdXQgb2Yge2NvdW50fSIpCg== + code_origin: '' + filename: model_server_tester.py entry_points: model_server_tester: - name: model_server_tester - doc: Test a model server parameters: - name: context - default: '' - name: table type: DataItem doc: csv/parquet table with test data - default: '' - name: addr type: str doc: function address/url - default: '' - name: label_column type: str doc: name of the label column in table @@ -46,13 +40,11 @@ spec: type: int doc: number of rows to use from test set default: 20 - outputs: - - default: '' - lineno: 14 + name: model_server_tester + doc: Test a model server + has_kwargs: false + has_varargs: false + lineno: 26 + command: '' description: test model servers - build: - functionSourceCode: 
IyBHZW5lcmF0ZWQgYnkgbnVjbGlvLmV4cG9ydC5OdWNsaW9FeHBvcnRlcgoKaW1wb3J0IG9zCmltcG9ydCBwYW5kYXMgYXMgcGQKaW1wb3J0IHJlcXVlc3RzCmltcG9ydCBqc29uCmltcG9ydCBudW1weSBhcyBucAoKZnJvbSBkYXRldGltZSBpbXBvcnQgZGF0ZXRpbWUKZnJvbSBtbHJ1bi5kYXRhc3RvcmUgaW1wb3J0IERhdGFJdGVtCmZyb20gbWxydW4uYXJ0aWZhY3RzIGltcG9ydCBnZXRfbW9kZWwsIENoYXJ0QXJ0aWZhY3QKCgpkZWYgbW9kZWxfc2VydmVyX3Rlc3RlcigKICAgIGNvbnRleHQsCiAgICB0YWJsZTogRGF0YUl0ZW0sCiAgICBhZGRyOiBzdHIsCiAgICBsYWJlbF9jb2x1bW46IHN0ciA9ICJsYWJlbCIsCiAgICBtb2RlbDogc3RyID0gIiIsCiAgICBtYXRjaF9lcnI6IGJvb2wgPSBGYWxzZSwKICAgIHJvd3M6IGludCA9IDIwLAopOgogICAgIiIiVGVzdCBhIG1vZGVsIHNlcnZlcgoKICAgIDpwYXJhbSB0YWJsZTogICAgICAgICBjc3YvcGFycXVldCB0YWJsZSB3aXRoIHRlc3QgZGF0YQogICAgOnBhcmFtIGFkZHI6ICAgICAgICAgIGZ1bmN0aW9uIGFkZHJlc3MvdXJsCiAgICA6cGFyYW0gbGFiZWxfY29sdW1uOiAgbmFtZSBvZiB0aGUgbGFiZWwgY29sdW1uIGluIHRhYmxlCiAgICA6cGFyYW0gbW9kZWw6ICAgICAgICAgdGVzdGVkIG1vZGVsIG5hbWUKICAgIDpwYXJhbSBtYXRjaF9lcnI6ICAgICByYWlzZSBlcnJvciBvbiB2YWxpZGF0aW9uIChyZXF1aXJlIHByb3BlciB0ZXN0IHNldCkKICAgIDpwYXJhbSByb3dzOiAgICAgICAgICBudW1iZXIgb2Ygcm93cyB0byB1c2UgZnJvbSB0ZXN0IHNldAogICAgIiIiCgogICAgdGFibGUgPSB0YWJsZS5hc19kZigpCgogICAgeV9saXN0ID0gdGFibGUucG9wKGxhYmVsX2NvbHVtbikudmFsdWVzLnRvbGlzdCgpCiAgICBjb250ZXh0LmxvZ2dlci5pbmZvKGYidGVzdGluZyB3aXRoIGRhdGFzZXQgYWdhaW5zdCB7YWRkcn0sIG1vZGVsOiB7bW9kZWx9IikKICAgIGlmIHJvd3MgYW5kIHJvd3MgPCB0YWJsZS5zaGFwZVswXToKICAgICAgICB0YWJsZSA9IHRhYmxlLnNhbXBsZShyb3dzKQoKICAgIGNvdW50ID0gZXJyX2NvdW50ID0gbWF0Y2ggPSAwCiAgICB0aW1lcyA9IFtdCiAgICBmb3IgeCwgeSBpbiB6aXAodGFibGUudmFsdWVzLCB5X2xpc3QpOgogICAgICAgIGNvdW50ICs9IDEKICAgICAgICBldmVudF9kYXRhID0ganNvbi5kdW1wcyh7Imluc3RhbmNlcyI6IFt4LnRvbGlzdCgpXX0pCiAgICAgICAgaGFkX2VyciA9IEZhbHNlCiAgICAgICAgdHJ5OgogICAgICAgICAgICBzdGFydCA9IGRhdGV0aW1lLm5vdygpCiAgICAgICAgICAgIHJlc3AgPSByZXF1ZXN0cy5wdXQoZiJ7YWRkcn0ve21vZGVsfS9wcmVkaWN0IiwganNvbj1ldmVudF9kYXRhKQogICAgICAgICAgICBpZiBub3QgcmVzcC5vazoKICAgICAgICAgICAgICAgIGNvbnRleHQubG9nZ2VyLmVycm9yKGYiYmFkIGZ1bmN0aW9uIHJlc3AhIVxue3Jlc3AudGV4dH0iKQogICAgICAgICAgICAgICAgZXJyX2NvdW50ICs9IDEK
ICAgICAgICAgICAgICAgIGNvbnRpbnVlCiAgICAgICAgICAgIHRpbWVzLmFwcGVuZCgoZGF0ZXRpbWUubm93KCkgLSBzdGFydCkubWljcm9zZWNvbmRzKQoKICAgICAgICBleGNlcHQgT1NFcnJvciBhcyBlcnI6CiAgICAgICAgICAgIGNvbnRleHQubG9nZ2VyLmVycm9yKGYiZXJyb3IgaW4gcmVxdWVzdCwgZGF0YTp7ZXZlbnRfZGF0YX0sIGVycm9yOiB7ZXJyfSIpCiAgICAgICAgICAgIGVycl9jb3VudCArPSAxCiAgICAgICAgICAgIGNvbnRpbnVlCgogICAgICAgIHlfcmVzcCA9IHJlc3AuanNvbigpWzBdCiAgICAgICAgaWYgeSA9PSB5X3Jlc3A6CiAgICAgICAgICAgIG1hdGNoICs9IDEKCiAgICBjb250ZXh0LmxvZ19yZXN1bHQoInRvdGFsX3Rlc3RzIiwgY291bnQpCiAgICBjb250ZXh0LmxvZ19yZXN1bHQoImVycm9ycyIsIGVycl9jb3VudCkKICAgIGNvbnRleHQubG9nX3Jlc3VsdCgibWF0Y2giLCBtYXRjaCkKICAgIGlmIGNvdW50IC0gZXJyX2NvdW50ID4gMDoKICAgICAgICB0aW1lc19hcnIgPSBucC5hcnJheSh0aW1lcykKICAgICAgICBjb250ZXh0LmxvZ19yZXN1bHQoImF2Z19sYXRlbmN5IiwgaW50KG5wLm1lYW4odGltZXNfYXJyKSkpCiAgICAgICAgY29udGV4dC5sb2dfcmVzdWx0KCJtaW5fbGF0ZW5jeSIsIGludChucC5hbWluKHRpbWVzX2FycikpKQogICAgICAgIGNvbnRleHQubG9nX3Jlc3VsdCgibWF4X2xhdGVuY3kiLCBpbnQobnAuYW1heCh0aW1lc19hcnIpKSkKCiAgICAgICAgY2hhcnQgPSBDaGFydEFydGlmYWN0KCJsYXRlbmN5IiwgaGVhZGVyPVsiVGVzdCIsICJMYXRlbmN5IChtaWNyb3NlYykiXSkKICAgICAgICBmb3IgaSBpbiByYW5nZShsZW4odGltZXMpKToKICAgICAgICAgICAgY2hhcnQuYWRkX3JvdyhbaSArIDEsIGludCh0aW1lc1tpXSldKQogICAgICAgIGNvbnRleHQubG9nX2FydGlmYWN0KGNoYXJ0KQoKICAgIGNvbnRleHQubG9nZ2VyLmluZm8oCiAgICAgICAgZiJydW4ge2NvdW50fSB0ZXN0cywge2Vycl9jb3VudH0gZXJyb3JzIGFuZCB7bWF0Y2h9IG1hdGNoIGV4cGVjdGVkIHZhbHVlIgogICAgKQoKICAgIGlmIGVycl9jb3VudDoKICAgICAgICByYWlzZSBWYWx1ZUVycm9yKGYiZmFpbGVkIG9uIHtlcnJfY291bnR9IHRlc3RzIG9mIHtjb3VudH0iKQoKICAgIGlmIG1hdGNoX2VyciBhbmQgbWF0Y2ggIT0gY291bnQ6CiAgICAgICAgcmFpc2UgVmFsdWVFcnJvcihmIm9ubHkge21hdGNofSByZXN1bHRzIG1hdGNoIG91dCBvZiB7Y291bnR9IikK - commands: [] - code_origin: https://github.com/daniels290813/functions.git#55a79c32be5d233cc11efcf40cd3edbe309bfdef:/home/kali/functions/model_server_tester/model_server_tester.py - affinity: null -verbose: false + default_handler: model_server_tester diff --git a/functions/src/model_server_tester/model_server_tester.py 
b/functions/src/model_server_tester/model_server_tester.py index 7d83b148d..922030d11 100644 --- a/functions/src/model_server_tester/model_server_tester.py +++ b/functions/src/model_server_tester/model_server_tester.py @@ -14,15 +14,13 @@ # # Generated by nuclio.export.NuclioExporter -import os -import pandas as pd -import requests import json -import numpy as np - from datetime import datetime + +import numpy as np +import requests +from mlrun.artifacts import ChartArtifact from mlrun.datastore import DataItem -from mlrun.artifacts import get_model, ChartArtifact def model_server_tester( diff --git a/functions/src/noise_reduction/function.yaml b/functions/src/noise_reduction/function.yaml index d6d33b8da..e9d494506 100644 --- a/functions/src/noise_reduction/function.yaml +++ b/functions/src/noise_reduction/function.yaml @@ -1,21 +1,27 @@ +metadata: + tag: '' + name: noise-reduction + categories: + - data-preparation + - audio +verbose: false +kind: job spec: + image: '' + disable_auto_mount: false + build: + origin_filename: '' + functionSourceCode: import logging
from abc import ABCMeta, abstractmethod
from multiprocessing import Process, Queue
from pathlib import Path

import librosa
import numpy as np
import torch
from scipy.io import wavfile
from tqdm import tqdm

#: Sentinel sent through the multiprocessing queues. A worker stops when it reads this value from the tasks queue
#: and puts it on the results queue to signal that it is done:
_MULTIPROCESSING_STOP_MARK = "STOP"

# Get the global logger: use the logger of the "noise_reduce" MLRun context when `mlrun` is importable, otherwise
# fall back to the root logger of the standard `logging` module:
try:
    import mlrun

    _LOGGER = mlrun.get_or_create_ctx("noise_reduce").logger
except ModuleNotFoundError:
    _LOGGER = logging.getLogger()


class ReduceNoiseBase(metaclass=ABCMeta):
    """
    Base class for noise reduction.
    This class is aimed to be inherited by specific noise reduction algorithms.
    You must implement the following methods:
    - clean_audio:  The method to clean the audio, where the noise reduction algorithm is implemented.
    - save_audio:   The method to save the audio to a file.
    - load_audio:   The method to load the audio from a file.

    After implementing the above methods, you can use the reduce_noise method to reduce noise from audio files.
    """

    def __init__(
        self,
        target_directory: Path,
        verbose: bool = True,
        silence_threshold: float = None,
    ):
        self.target_directory = Path(target_directory)
        self.verbose = verbose
        self.silence_threshold = silence_threshold

    def reduce_noise(self, audio_file: Path) -> tuple[bool, tuple[str, str]]:
        """
        Reduce noise from the given audio file.

        :param audio_file:  The audio file to reduce noise from.

        :returns: A tuple of:
         - a boolean indicating whether an error occurred
         - a tuple of:
            - audio file name
            - target path in case of success / error message in case of failure.
        """
        try:
            if self.verbose:
                _LOGGER.info(f"Reducing noise from {audio_file.name}.")

            # Load audio data:
            audio = self.load_audio(file=str(audio_file))

            # Perform noise reduction:
            reduced_noise = self.clean_audio(data=audio)

            # Remove silence from the audio if necessary:
            reduced_noise = self.remove_silence(audio=reduced_noise)

            # Prepare target path:
            target_path = self.update_to_wav_suffix(audio_file=audio_file)

            # Save file:
            self.save_audio(
                audio=reduced_noise,
                target_path=target_path,
            )

            if self.verbose:
                _LOGGER.info(f"Saved cleaned audio file to {target_path}.")

            return False, (audio_file.name, str(target_path))
        except Exception as exception:
            if self.verbose:
                _LOGGER.error(f"Failed to reduce noise from {audio_file.name}.")
                _LOGGER.error(f"Error: {exception}")
            # Collect the error:
            return True, (audio_file.name, str(exception))

    @abstractmethod
    def clean_audio(self, data) -> np.ndarray | torch.Tensor:
        """
        Clean the audio from noise. Here you should implement the noise reduction algorithm.

        :param data:    The audio data to clean.

        :returns: The cleaned audio.
        """
        pass

    @abstractmethod
    def save_audio(self, audio: np.ndarray, target_path: Path):
        """
        Save the audio to a file.

        :param audio:       The audio to save.
        :param target_path: The target path to save the audio to.
        """
        pass

    @abstractmethod
    def load_audio(self, file: str) -> tuple[np.ndarray | torch.Tensor, int]:
        """
        Load the audio from a file.

        :param file:    The file to load the audio from.

        :returns: A tuple of:
            - the audio data
            - the sample rate
        """
        pass

    def update_to_wav_suffix(self, audio_file: Path):
        target_path = self.target_directory / audio_file.name
        if target_path.suffix != ".wav":
            old_suffix = target_path.suffix[1:]
            target_path = target_path.with_stem(target_path.stem + f"_{old_suffix}")
            return target_path.with_suffix(".wav")
        else:
            return target_path

    def remove_silence(
        self,
        audio: np.ndarray,
    ):
        """
        Remove silence sections from the audio.

        :param audio:   The audio to remove silence from.

        :returns: The audio without silence.
        """
        if self.silence_threshold is None:
            return audio

        # Get the indices of the non-silent frames:
        non_silent_indices = librosa.effects.split(
            y=audio,
            top_db=self.silence_threshold,
            frame_length=2048,
            hop_length=256,
        )

        # Get the non-silent audio:
        non_silent_audio = np.concatenate(
            [audio[:, start:end] for start, end in non_silent_indices], axis=1
        )

        return non_silent_audio


class ReduceNoise(ReduceNoiseBase):
    """
    Noise reduction based on the `noisereduce` package.

    The audio is loaded with `librosa`, re-scaled to 16-bit integers, cleaned with `noisereduce.reduce_noise` and
    saved as a wav file with `scipy.io.wavfile.write`.
    """

    def __init__(
        self,
        target_directory: Path,
        verbose: bool = True,
        silence_threshold: float = None,
        sample_rate: int = 16000,
        duration: int = None,
        channel: int = None,
    ):
        """
        Initialize the noise reducer.

        :param target_directory:  The directory to save the cleaned audio files to.
        :param verbose:           Whether to log the progress and the failures of each file.
        :param silence_threshold: The threshold to remove silence from the audio. If None, no silence removal is
                                  performed.
        :param sample_rate:       The sample rate to load the audio with. Pass None to keep each file's own rate.
        :param duration:          The duration (in seconds) to load from each file. Pass None to load it all.
        :param channel:           The index of the single channel to clean. The other channels are kept untouched.
                                  Pass None to clean all channels.
        """
        super().__init__(target_directory, verbose, silence_threshold)
        # `load_audio` overwrites `sample_rate` with the rate of the last loaded file, so the requested rate is kept
        # apart. Otherwise a requested rate of None would hold for the first file only and every following file
        # would be resampled to the first file's rate:
        self._requested_sample_rate = sample_rate
        self.sample_rate = sample_rate
        self.duration = duration
        self.channel = channel

    def save_audio(self, audio: np.ndarray, target_path: Path):
        """
        Save the audio as a wav file.

        :param audio:       The audio to save, of shape (samples,) or (channels, samples).
        :param target_path: The target path to save the audio to.
        """
        # The wav writer expects (samples, channels). Checking the number of dimensions (and not `len(audio)`) also
        # covers audio of shape (1, samples), which would otherwise be written as one sample of many channels:
        if audio.ndim > 1:
            audio = audio.T

        wavfile.write(
            filename=target_path,
            rate=self.sample_rate,
            data=audio,
        )

    def load_audio(self, file: str) -> np.ndarray:
        """
        Load the audio from a file and convert it to 16-bit integers.

        All the channels are returned, also when a single `channel` was selected for cleaning, so `clean_audio` can
        put the cleaned channel back next to the untouched ones.

        :param file:    The file to load the audio from.

        :returns: The audio data as an int16 array.
        """
        data, sr = librosa.load(
            path=file,
            sr=self._requested_sample_rate,
            mono=False,  # keep channels separate
            duration=self.duration,
        )
        # Remember the actual sample rate for saving the file:
        self.sample_rate = int(sr)

        # Re-scale to the full 16-bit integer range. All-silent audio (peak of 0) is left as is, as dividing by the
        # peak would fill the audio with NaN values:
        peak = np.max(np.abs(data))
        if peak > 0:
            data *= 32767 / peak
        return data.astype(np.int16)

    def clean_audio(self, data: np.ndarray) -> np.ndarray:
        """
        Reduce the noise of the audio using `noisereduce`.

        :param data:    The audio data as returned from `load_audio`.

        :returns: The cleaned audio. When a `channel` was selected, only that channel is cleaned (in place) and the
                  audio is returned with all of its channels.

        :raises ImportError: If the `noisereduce` package is not installed.
        """
        try:
            import noisereduce
        except ImportError as e:
            raise ImportError("Please install noisereduce package") from e

        if self.channel is None:
            return noisereduce.reduce_noise(y=data, sr=self.sample_rate)

        # Clean the selected channel only and put it back in the data. Mono audio is viewed as a single channel, so
        # the same indexing (and the same out-of-range error) applies to it:
        channels = np.atleast_2d(data)
        channels[self.channel] = noisereduce.reduce_noise(
            y=channels[self.channel], sr=self.sample_rate
        )
        return data


class DFN(ReduceNoiseBase):
    """
    Noise reduction backed by DeepFilterNet (the `df` package).

    The `df.enhance` helpers are imported inside the methods that use them, so the module can be imported without
    DeepFilterNet being installed. The model is loaded once, when the object is created.
    """

    def __init__(
        self,
        target_directory: Path,
        verbose: bool = True,
        silence_threshold: float = None,
        pad: bool = True,
        atten_lim_db: int = None,
        **kwargs,
    ):
        """
        Store the enhancement settings and load the DeepFilterNet model.

        :param target_directory:  The directory to save the cleaned audio files to.
        :param verbose:           Whether to log the progress.
        :param silence_threshold: The threshold to remove silence from the audio. If None, no silence removal is
                                  performed.
        :param pad:               Passed as `pad` to `df.enhance.enhance`.
        :param atten_lim_db:      Passed as `atten_lim_db` to `df.enhance.enhance`.
        :param kwargs:            Additional keyword arguments passed to `df.enhance.load_audio`.

        :raises ImportError: If the `df` package is not installed.
        """
        super().__init__(
            target_directory=target_directory,
            verbose=verbose,
            silence_threshold=silence_threshold,
        )
        self.pad, self.atten_lim_db, self.kwargs = pad, atten_lim_db, kwargs

        # The model loader is needed here only:
        try:
            from df.enhance import init_df
        except ImportError as import_error:
            raise ImportError("Please install deepfilternet packages") from import_error

        if self.verbose:
            _LOGGER.info("Loading DeepFilterNet2 model.")

        # Keep the model and its state, and take the sample rate the model works with from the state:
        self.model, self.df_state, _ = init_df()
        self.sample_rate = self.df_state.sr()

    def save_audio(self, audio: np.ndarray, target_path: Path):
        """
        Save the audio into the target directory, under the file name of the given target path.

        :param audio:       The audio to save.
        :param target_path: The target path, only its file name is used.

        :raises ImportError: If the `df` package is not installed.
        """
        try:
            from df.enhance import save_audio as df_save_audio
        except ImportError as import_error:
            raise ImportError("Please install deepfilternet package") from import_error

        output_directory = str(self.target_directory)
        df_save_audio(
            file=target_path.name,
            audio=audio,
            sr=self.sample_rate,
            output_dir=output_directory,
        )

    def load_audio(self, file: str) -> torch.Tensor:
        """
        Load the audio in the model's sample rate.

        :param file:    The file to load the audio from.

        :returns: The audio data (the second value returned from `df.enhance.load_audio` is dropped).

        :raises ImportError: If the `df` package is not installed.
        """
        try:
            from df.enhance import load_audio as df_load_audio
        except ImportError as import_error:
            raise ImportError("Please install deepfilternet package") from import_error

        waveform, _ = df_load_audio(file=file, sr=self.sample_rate, **self.kwargs)
        return waveform

    def clean_audio(self, data: torch.Tensor) -> torch.Tensor:
        """
        Enhance the audio with the loaded DeepFilterNet model.

        :param data:    The audio data to clean.

        :returns: The enhanced audio.

        :raises ImportError: If the `df` package is not installed.
        """
        try:
            from df.enhance import enhance as df_enhance
        except ImportError as import_error:
            raise ImportError("Please install deepfilternet package") from import_error

        enhanced = df_enhance(
            model=self.model,
            df_state=self.df_state,
            audio=data,
            pad=self.pad,
            atten_lim_db=self.atten_lim_db,
        )
        return enhanced


def _multiprocessing_complete_tasks(
    noise_reduce_type: type[ReduceNoiseBase],
    noise_reduce_arguments: dict,
    tasks_queue: Queue,
    results_queue: Queue,
):
    """
    Complete the tasks in the given queue and put the results in the given results queue. The function will stop when
    the given tasks queue will receive the stop mark. It is aimed to be used with multiprocessing as a process.

    The stop mark is always put in the results queue before returning, also when the noise reducer could not be
    initialized. In that case every task this worker gets is reported as failed with the initialization error, so the
    process collecting the results is never left waiting for a worker that died.

    :param noise_reduce_type:       The noise reduce type to use.
    :param noise_reduce_arguments:  The noisereduce initialization kwargs.
    :param tasks_queue:             A queue to get the tasks from.
    :param results_queue:           A queue to put the results in.
    """
    # Initialize the reduce noise object, keeping the error (if any) to report it per task:
    noise_reducer = None
    initialization_error = None
    try:
        noise_reducer = noise_reduce_type(**noise_reduce_arguments)
    except Exception as exception:
        initialization_error = exception
        _LOGGER.error(f"Failed to initialize the noise reducer. Error: {exception}")

    # Listen to the tasks queue until the stop mark is received:
    for task in iter(tasks_queue.get, _MULTIPROCESSING_STOP_MARK):
        audio_file = Path(task)
        if initialization_error is None:
            # Apply noise reduction and collect the result:
            result = noise_reducer.reduce_noise(audio_file=audio_file)
        else:
            # Same structure as a failure collected by `reduce_noise`:
            result = (True, (audio_file.name, str(initialization_error)))
        results_queue.put(result)

    # Mark the end of the tasks:
    results_queue.put(_MULTIPROCESSING_STOP_MARK)


def reduce_noise_dfn(
    audio_source: str,
    target_directory: str,
    pad: bool = True,
    atten_lim_db: int = None,
    silence_threshold: float = None,
    use_multiprocessing: int = 0,
    verbose: bool = True,
    **kwargs,
):
    """
    Reduce noise from audio files using DeepFilterNet.
    For more information about the noise reduction algorithm see:
    https://github.com/Rikorose/DeepFilterNet
    Notice that the saved files are in wav format, even if the original files are in other format.

    :param audio_source:        path to audio file or directory of audio files
    :param target_directory:    path to target directory to save cleaned audio files
    :param pad:                 whether to pad the audio file with zeros before cleaning
    :param atten_lim_db:        maximum attenuation in dB
    :param silence_threshold:   the threshold to remove silence from the audio, in dB. If None, no silence removal is
                                performed.
    :param use_multiprocessing: Number of processes to use for cleaning the audio files.
                                If 0, no multiprocessing is used.
    :param verbose:             verbosity level. If True, display progress bar and logs.
    :param kwargs:              additional arguments to pass to torchaudio.load(). For more information see:
                                https://pytorch.org/audio/stable/generated/torchaudio.load.html

    :returns: A tuple of:
        - a dictionary of the successfully cleaned files: file name -> path of the cleaned file
        - a dictionary of the failed files: file name -> error message
    """
    if verbose:
        _LOGGER.info("Reducing noise from audio files.")

    # create target directory:
    target_directory = _create_target_directory(target_directory)

    # get audio files:
    audio_files = _get_audio_files(audio_source)

    # The initialization arguments of the noise reducer. `verbose` is forwarded so the per-file logs of the reducer
    # are silenced together with the progress bar:
    noise_reduce_arguments = {
        "target_directory": target_directory,
        "verbose": verbose,
        "pad": pad,
        "atten_lim_db": atten_lim_db,
        "silence_threshold": silence_threshold,
        **kwargs,
    }

    if use_multiprocessing:
        results = _parallel_run(
            noise_reduce_type=DFN,
            noise_reduce_arguments=noise_reduce_arguments,
            n_workers=use_multiprocessing,
            audio_files=audio_files,
            description="Noise-reduction",
            verbose=verbose,
        )
    else:
        results = _run(
            noise_reduce_type=DFN,
            noise_reduce_arguments=noise_reduce_arguments,
            audio_files=audio_files,
            description="Noise-reduction",
            verbose=verbose,
        )

    return _process_results(results, verbose)


def reduce_noise(
    audio_source: str,
    target_directory: str,
    sample_rate: int = 16000,
    duration: int = None,
    channel: int = None,
    silence_threshold: float = None,
    use_multiprocessing: int = 0,
    verbose: bool = True,
):
    """
    Reduce noise from audio file or directory containing audio files.
    The audio files must be in .wav format.
    The cleaned audio files will be saved in the target_directory.
    For information about the noise reduction algorithm see:
    https://github.com/timsainb/noisereduce
    Notice that the saved files are in wav format, even if the original files are in other format.

    :param audio_source:        path to audio file or directory containing audio files
    :param target_directory:    path to directory to save the cleaned audio files.
    :param sample_rate:         Number of samples in one second in the audio file.
                                Pass `None` to keep the original sample rate.
    :param duration:            Duration of the audio file to clean in seconds.
                                Pass `None` to keep the original duration.
    :param channel:             Channel to clean. Pass the number of the channel to clean.
                                To clean all channels pass None.
    :param silence_threshold:   The threshold to remove silence from the audio, in dB.
                                If None, no silence removal is performed.
    :param use_multiprocessing: Number of processes to use for cleaning the audio files.
                                If 0, no multiprocessing is used.
    :param verbose:             Verbosity level. If True, display progress bar and logs.

    :returns: A tuple of:
        - a dictionary of the successfully cleaned files: file name -> path of the cleaned file
        - a dictionary of the failed files: file name -> error message
    """
    if verbose:
        _LOGGER.info("Reducing noise from audio files.")

    # create target directory:
    target_directory = _create_target_directory(target_directory)

    # get audio files:
    audio_files = _get_audio_files(audio_source)

    # The initialization arguments of the noise reducer. `verbose` is forwarded so the per-file logs of the reducer
    # are silenced together with the progress bar:
    noise_reduce_arguments = {
        "target_directory": target_directory,
        "verbose": verbose,
        "sample_rate": sample_rate,
        "duration": duration,
        "channel": channel,
        "silence_threshold": silence_threshold,
    }

    if use_multiprocessing:
        results = _parallel_run(
            noise_reduce_type=ReduceNoise,
            noise_reduce_arguments=noise_reduce_arguments,
            n_workers=use_multiprocessing,
            audio_files=audio_files,
            description="Noise-reduction",
            verbose=verbose,
        )
    else:
        results = _run(
            noise_reduce_type=ReduceNoise,
            noise_reduce_arguments=noise_reduce_arguments,
            audio_files=audio_files,
            description="Noise-reduction",
            verbose=verbose,
        )

    return _process_results(results, verbose)


def _create_target_directory(target_directory: str) -> str:
    """
    Make sure the target directory exists, creating it (with its parents) when the path does not exist yet.

    :param target_directory: The path of the directory.

    :returns: The path of the directory as a string.
    """
    directory = Path(target_directory)
    if directory.exists():
        # An existing path is returned as is:
        return str(directory)

    directory.mkdir(parents=True, exist_ok=True)
    return str(directory)


def _get_audio_files(audio_source: str) -> list[Path]:
    """
    Collect the audio files of the given source.

    :param audio_source: A path to a single audio file or to a directory of audio files.

    :returns: The given file, or the files with an extension found directly in the given directory (sub-directories
              are not searched), sorted by path so the processing order is the same on every run.

    :raises ValueError: If the source is neither an existing file nor an existing directory.
    """
    audio_source = Path(audio_source)
    if audio_source.is_dir():
        # "*.*" matches sub-directories with a dot in their name as well, so only files are kept:
        return sorted(path for path in audio_source.glob("*.*") if path.is_file())
    if audio_source.is_file():
        return [audio_source]
    raise ValueError(
        f"audio_source must be a file or a directory, got {audio_source}"
    )


def _parallel_run(
    noise_reduce_type: type[ReduceNoiseBase],
    noise_reduce_arguments: dict,
    n_workers: int,
    audio_files: list[Path],
    description: str,
    verbose: bool,
) -> list[tuple[bool, tuple[str, str]]]:
    """
    Run multiple noise reduce workers with multiprocessing to reduce the noise of the provided files. Each worker
    initializes its own noise reducer and takes files from a shared tasks queue until it receives a stop mark.

    :param noise_reduce_type:       The noise reduce type to use.
    :param noise_reduce_arguments:  The noisereduce initialization kwargs.
    :param n_workers:               The number of workers to use. It is limited to the number of audio files and
                                    raised to 1 if it is not positive.
    :param audio_files:             The audio files to use.
    :param description:             The description to use for the progress bar.
    :param verbose:                 Verbosity.

    :returns: The collected results. The order follows the completion order of the workers, not the files order.
    """
    # With no files there is nothing to do. Returning here also avoids starting zero workers, in which case no stop
    # mark would ever arrive and the collection loop below would wait forever:
    if not audio_files:
        return []

    # Check the number of workers:
    if n_workers > len(audio_files):
        _LOGGER.warning(
            f"The number of workers ({n_workers}) is larger than the number of audio files ({len(audio_files)}). "
            f"Setting the number of workers to {len(audio_files)}."
        )
        n_workers = len(audio_files)
    # At least one worker is required to send the stop mark the collection loop waits for:
    n_workers = max(n_workers, 1)

    # Initialize the multiprocessing queues:
    tasks_queue = Queue()
    results_queue = Queue()

    # Initialize the multiprocessing processes:
    task_completion_processes = [
        Process(
            target=_multiprocessing_complete_tasks,
            kwargs={
                "noise_reduce_type": noise_reduce_type,
                "noise_reduce_arguments": noise_reduce_arguments,
                "tasks_queue": tasks_queue,
                "results_queue": results_queue,
            },
        )
        for _ in range(n_workers)
    ]

    # Start the multiprocessing processes:
    for p in task_completion_processes:
        p.start()

    # Put the tasks in the queue:
    for audio_file in audio_files:
        tasks_queue.put(audio_file)

    # Put the stop marks in the queue, one per worker:
    for _ in range(n_workers):
        tasks_queue.put(_MULTIPROCESSING_STOP_MARK)

    # Collect the results until every worker sent its stop mark:
    results = []
    stop_marks_counter = 0
    with tqdm(
        desc=description,
        unit="file",
        total=len(audio_files),
        disable=not verbose,
    ) as progressbar:
        while True:
            # Get a result from the queue:
            result: tuple[bool, tuple[str, str]] = results_queue.get()
            if result == _MULTIPROCESSING_STOP_MARK:
                stop_marks_counter += 1
                if stop_marks_counter == n_workers:
                    break
            else:
                # Collect the result:
                results.append(result)
                progressbar.update(1)

    # Wait for the processes to finish:
    for p in task_completion_processes:
        p.join()

    return results


def _run(
    noise_reduce_type: type[ReduceNoiseBase],
    noise_reduce_arguments: dict,
    audio_files: list[Path],
    description: str,
    verbose: bool,
) -> list[tuple[bool, tuple[str, str]]]:
    """
    Reduce the noise of the given audio files one after the other, in the current process.

    :param noise_reduce_type:       The noise reduce type to use.
    :param noise_reduce_arguments:  The noisereduce initialization kwargs.
    :param audio_files:             The audio files to use.
    :param description:             The description to use for the progress bar.
    :param verbose:                 Verbosity. If False, the progress bar is disabled.

    :returns: The collected results, in the order of the given audio files.
    """
    # A single noise reducer serves all the files:
    reducer = noise_reduce_type(**noise_reduce_arguments)

    # Wrap the files with a progress bar and collect the result of each one of them:
    files_with_progress = tqdm(
        audio_files,
        desc=description,
        unit="file",
        total=len(audio_files),
        disable=not verbose,
    )
    return [reducer.reduce_noise(audio_file=path) for path in files_with_progress]


def _process_results(
    results: list[tuple[bool, tuple[str, str]]], verbose: bool
) -> tuple[dict, dict]:
    """
    Split the collected results into successes and errors.

    :param results: The results to process, each one is a tuple of an error flag and a tuple of the file name and
                    the target path (on success) or the error message (on failure).
    :param verbose: Verbosity. If True, the summary is logged.

    :returns: A tuple of two dictionaries keyed by the file name: the successes (target paths) and the errors (error
              messages).
    """
    if verbose:
        _LOGGER.info("Summarizing the results.")

    successes, errors = {}, {}
    for failed, details in results:
        # Route the details into the matching dictionary:
        bucket = errors if failed else successes
        bucket[details[0]] = details[1]

    if verbose:
        _LOGGER.info(f"Done ({len(successes)}/{len(successes) + len(errors)})\n")

    return successes, errors
 + requirements: + - librosa + - noisereduce + - deepfilternet + - torchaudio>=2.1.2 + code_origin: '' + base_image: mlrun/mlrun + filename: noise_reduction.py entry_points: reduce_noise: - has_kwargs: false - name: reduce_noise - has_varargs: false - doc: 'Reduce noise from audio file or directory containing audio files. - - The audio files must be in .wav format. - - The cleaned audio files will be saved in the target_directory. - - For information about the noise reduction algorithm see: - - https://github.com/timsainb/noisereduce - - Notice that the saved files are in wav format, even if the original files - are in other format.' parameters: - name: audio_source type: str @@ -52,78 +58,82 @@ spec: type: bool doc: Verbosity level. If True, display progress bar. default: true - lineno: 388 - clean_audio: + name: reduce_noise + doc: 'Reduce noise from audio file or directory containing audio files. + + The audio files must be in .wav format. + + The cleaned audio files will be saved in the target_directory. + + For information about the noise reduction algorithm see: + + https://github.com/timsainb/noisereduce + + Notice that the saved files are in wav format, even if the original files + are in other format.' 
has_kwargs: false - name: clean_audio has_varargs: false + lineno: 388 + clean_audio: outputs: - type: torch.Tensor - doc: '' parameters: - name: self - name: data type: Tensor - lineno: 276 - save_audio: + name: clean_audio + doc: '' has_kwargs: false - name: save_audio has_varargs: false - doc: '' + lineno: 276 + save_audio: parameters: - name: self - name: audio type: ndarray - name: target_path type: Path - lineno: 256 - load_audio: + name: save_audio + doc: '' has_kwargs: false - name: load_audio has_varargs: false + lineno: 256 + load_audio: outputs: - type: torch.Tensor - doc: '' parameters: - name: self - name: file type: str - lineno: 268 - update_to_wav_suffix: + name: load_audio + doc: '' has_kwargs: false - name: update_to_wav_suffix has_varargs: false - doc: '' + lineno: 268 + update_to_wav_suffix: parameters: - name: self - name: audio_file type: Path - lineno: 125 - remove_silence: + name: update_to_wav_suffix + doc: '' has_kwargs: false - name: remove_silence has_varargs: false + lineno: 125 + remove_silence: outputs: - doc: The audio without silence. - doc: Remove silence sections from the audio. parameters: - name: self - name: audio type: ndarray doc: The audio to remove silence from. + name: remove_silence + doc: Remove silence sections from the audio. + has_kwargs: false + has_varargs: false lineno: 134 reduce_noise_dfn: - has_kwargs: true - name: reduce_noise_dfn - has_varargs: false - doc: 'Reduce noise from audio files using DeepFilterNet. - - For more information about the noise reduction algorithm see: - - https://github.com/Rikorose/DeepFilterNet - - Notice that the saved files are in wav format, even if the original files - are in other format.' parameters: - name: audio_source type: str @@ -153,27 +163,18 @@ spec: type: bool doc: verbosity level. If True, display progress bar and logs. default: true + name: reduce_noise_dfn + doc: 'Reduce noise from audio files using DeepFilterNet. 
+ + For more information about the noise reduction algorithm see: + + https://github.com/Rikorose/DeepFilterNet + + Notice that the saved files are in wav format, even if the original files + are in other format.' + has_kwargs: true + has_varargs: false lineno: 322 - build: - code_origin: '' - base_image: mlrun/mlrun - requirements: - - librosa - - noisereduce - - deepfilternet - - torchaudio>=2.1.2 - functionSourceCode: import logging
from abc import ABCMeta, abstractmethod
from multiprocessing import Process, Queue
from pathlib import Path
from typing import List, Tuple, Type, Union

import librosa
import numpy as np
import torch
from scipy.io import wavfile
from tqdm import tqdm

#: The value to send into multiprocessing queues to stop the process:
# (Workers compare every received task against this mark, and put the same mark on the results queue when they
# are done, so the parent process can count the finished workers.)
_MULTIPROCESSING_STOP_MARK = "STOP"

# Get the global logger:
# Prefer the logger of an MLRun context named "noise_reduce". When the `mlrun` package is not installed, fall back
# to the root logger of the standard `logging` module so this module can still be used on its own.
try:
    import mlrun

    _LOGGER = mlrun.get_or_create_ctx("noise_reduce").logger
except ModuleNotFoundError:
    _LOGGER = logging.getLogger()


class ReduceNoiseBase(metaclass=ABCMeta):
    """
    Base class for noise reduction.
    This class is aimed to be inherited by specific noise reduction algorithms.
    You must implement the following methods:
    - clean_audio:  The method to clean the audio, where the noise reduction algorithm is implemented.
    - save_audio:   The method to save the audio to a file.
    - load_audio:   The method to load the audio from a file.

    After implementing the above methods, you can use the reduce_noise method to reduce noise from audio files.
    """

    def __init__(
        self,
        target_directory: Path,
        verbose: bool = True,
        silence_threshold: float = None,
    ):
        """
        Store the settings shared by all noise reduction algorithms.

        :param target_directory:  The directory to save the cleaned audio files in.
        :param verbose:           Whether to log the progress and the failures of each processed file.
        :param silence_threshold: The `top_db` value passed to `librosa.effects.split` when removing silence. If
                                  None, no silence removal is performed.
        """
        self.target_directory = Path(target_directory)
        self.verbose = verbose
        self.silence_threshold = silence_threshold

    def reduce_noise(self, audio_file: Path) -> Tuple[bool, Tuple[str, str]]:
        """
        Reduce noise from the given audio file.

        The file is loaded, cleaned, optionally stripped from silence and saved under the target directory. Any
        exception raised along the way is caught and reported in the returned value instead of being raised, so
        one bad file does not stop a whole batch.

        :param audio_file:  The audio file to reduce noise from.

        :returns: A tuple of:
         - a boolean indicating whether an error occurred
         - a tuple of:
            - audio file name
            - target path in case of success / error message in case of failure.
        """
        try:
            if self.verbose:
                _LOGGER.info(f"Reducing noise from {audio_file.name}.")

            # Load audio data:
            audio = self.load_audio(file=str(audio_file))

            # Perform noise reduction:
            reduced_noise = self.clean_audio(data=audio)

            # Remove silence from the audio if necessary:
            reduced_noise = self.remove_silence(audio=reduced_noise)

            # Prepare target path:
            target_path = self.update_to_wav_suffix(audio_file=audio_file)

            # Save file:
            self.save_audio(
                audio=reduced_noise,
                target_path=target_path,
            )

            if self.verbose:
                _LOGGER.info(f"Saved cleaned audio file to {target_path}.")

            return False, (audio_file.name, str(target_path))
        except Exception as exception:
            if self.verbose:
                _LOGGER.error(f"Failed to reduce noise from {audio_file.name}.")
                _LOGGER.error(f"Error: {exception}")
            # Collect the error:
            return True, (audio_file.name, str(exception))

    @abstractmethod
    def clean_audio(self, data) -> Union[np.ndarray, torch.Tensor]:
        """
        Clean the audio from noise. Here you should implement the noise reduction algorithm.

        :param data:    The audio data to clean, as returned from `load_audio`.

        :returns: The cleaned audio.
        """
        pass

    @abstractmethod
    def save_audio(self, audio: np.ndarray, target_path: Path):
        """
        Save the audio to a file.

        :param audio:       The audio to save.
        :param target_path: The target path to save the audio to.
        """
        pass

    @abstractmethod
    def load_audio(self, file: str) -> Union[np.ndarray, torch.Tensor]:
        """
        Load the audio from a file.

        The returned value is passed as-is to `clean_audio`, so it must be the audio data only (not a tuple with
        the sample rate).

        :param file:    The file to load the audio from.

        :returns: The audio data.
        """
        pass

    def update_to_wav_suffix(self, audio_file: Path):
        """
        Get the path to save the given audio file to: a `.wav` file inside the target directory.

        A non-wav suffix is appended to the file's stem (for example `speech.mp3` becomes `speech_mp3.wav`),
        presumably so files that differ only by their suffix do not overwrite each other.

        :param audio_file:  The original audio file.

        :returns: The target path of the cleaned audio file.
        """
        target_path = self.target_directory / audio_file.name
        if target_path.suffix != ".wav":
            old_suffix = target_path.suffix[1:]
            target_path = target_path.with_stem(target_path.stem + f"_{old_suffix}")
            return target_path.with_suffix(".wav")
        else:
            return target_path

    def remove_silence(
        self,
        audio: np.ndarray,
    ):
        """
        Remove silence sections from the audio.

        Both single channel audio of shape (samples,) and multichannel audio of shape (channels, samples) are
        supported.

        :param audio:   The audio to remove silence from.

        :returns: The audio without silence. If no silence threshold was set, the audio is returned untouched.
        """
        if self.silence_threshold is None:
            return audio

        # Get the indices of the non-silent frames:
        non_silent_indices = librosa.effects.split(
            y=audio,
            top_db=self.silence_threshold,
            frame_length=2048,
            hop_length=256,
        )

        # Get the non-silent audio:
        return self._concatenate_intervals(audio=audio, intervals=non_silent_indices)

    @staticmethod
    def _concatenate_intervals(audio, intervals):
        """
        Keep only the given sample intervals of the audio and join them along the samples (last) axis.

        Slicing the last axis (instead of indexing `audio[:, start:end]`) keeps multichannel behavior the same
        while also working for one-dimensional mono audio, which has no channels axis to index.

        :param audio:       The audio, where the last axis is the samples axis.
        :param intervals:   An iterable of (start, end) sample indices to keep.

        :returns: The concatenated intervals.
        """
        return np.concatenate(
            [audio[..., start:end] for start, end in intervals], axis=-1
        )


class ReduceNoise(ReduceNoiseBase):
    """
    Noise reduction implemented with the `noisereduce` package. Audio is loaded with `librosa` and saved with
    `scipy.io.wavfile`.
    """

    def __init__(
        self,
        target_directory: Path,
        verbose: bool = True,
        silence_threshold: float = None,
        sample_rate: int = 16000,
        duration: int = None,
        channel: int = None,
    ):
        """
        :param target_directory:  The directory to save the cleaned audio files in.
        :param verbose:           Whether to log the progress and the failures of each processed file.
        :param silence_threshold: The threshold for removing silence. If None, no silence removal is performed.
        :param sample_rate:       The sample rate to load the audio with (passed to `librosa.load`). It is
                                  overwritten by the sample rate `librosa.load` reports on every loaded file.
        :param duration:          The duration to load from each file (passed to `librosa.load`).
        :param channel:           The index of the channel to clean. If None, all channels are cleaned.
        """
        super().__init__(target_directory, verbose, silence_threshold)
        self.sample_rate = sample_rate
        self.duration = duration
        self.channel = channel

    def save_audio(self, audio: np.ndarray, target_path: Path):
        """
        Save the audio as a wav file.

        :param audio:       The audio to save, of shape (samples,) or (channels, samples).
        :param target_path: The target path to save the audio to.
        """
        # `wavfile.write` expects multichannel data as (samples, channels), so any audio that has a channels
        # axis is transposed. The dimensionality is checked (rather than `len(audio) > 1`) so a (1, samples)
        # array is transposed as well:
        if audio.ndim > 1:
            audio = audio.T

        wavfile.write(
            filename=target_path,
            rate=self.sample_rate,
            data=audio,
        )

    def load_audio(self, file: str) -> np.ndarray:
        """
        Load the audio from a file as 16-bit integers.

        :param file:    The file to load the audio from.

        :returns: The audio data. If a channel was selected, only that channel is returned.

        :raises IndexError: If the selected channel does not exist in the audio.
        """
        data, sr = librosa.load(
            path=file,
            sr=self.sample_rate,
            mono=False,  # keep channels separate
            duration=self.duration,
        )
        # set sample rate:
        self.sample_rate = int(sr)

        # convert to int with scaling for 16-bit integer, so the loudest sample maps to the int16 maximum. A
        # completely silent audio (peak of 0) is left unscaled to avoid a division by zero:
        peak = np.max(np.abs(data))
        if peak > 0:
            data *= 32767 / peak  # re-scaling
        data = data.astype(np.int16)  # change data type

        # select channel (a one-dimensional mono audio is treated as having the single channel 0):
        if self.channel is not None:
            data = np.atleast_2d(data)[self.channel]
        return data

    def clean_audio(self, data: np.ndarray) -> np.ndarray:
        """
        Reduce the noise of the given audio using `noisereduce`.

        :param data:    The audio data to clean. When it comes from `load_audio` and a channel was selected, it
                        is the one-dimensional data of that channel.

        :returns: The cleaned audio. For multichannel data with a selected channel, only that channel is cleaned
                  and it is written back (in place) into the given data.

        :raises ImportError: If the `noisereduce` package is not installed.
        """
        try:
            import noisereduce
        except ImportError as e:
            raise ImportError("Please install noisereduce package") from e

        # All channels are requested, or the channel was already selected while loading - clean everything given:
        if self.channel is None or data.ndim == 1:
            return noisereduce.reduce_noise(y=data, sr=self.sample_rate)

        # Multichannel data with a selected channel - clean only that channel and put it back in the data:
        data[self.channel] = noisereduce.reduce_noise(
            y=data[self.channel], sr=self.sample_rate
        )
        return data


class DFN(ReduceNoiseBase):
    """
    Noise reduction implemented with DeepFilterNet (the `df` package). The package is imported lazily inside each
    method, raising an `ImportError` with an installation hint when it is missing.
    """

    def __init__(
        self,
        target_directory: Path,
        verbose: bool = True,
        silence_threshold: float = None,
        pad: bool = True,
        atten_lim_db: int = None,
        **kwargs,
    ):
        """
        Store the settings and load the DeepFilterNet model.

        :param target_directory:  The directory to save the cleaned audio files in.
        :param verbose:           Whether to log the progress and the failures of each processed file.
        :param silence_threshold: The threshold for removing silence. If None, no silence removal is performed.
        :param pad:               Passed as `pad` to `df.enhance.enhance`.
        :param atten_lim_db:      Passed as `atten_lim_db` to `df.enhance.enhance`.
        :param kwargs:            Additional keyword arguments passed to `df.enhance.load_audio`.

        :raises ImportError: If the `df` package cannot be imported.
        """
        super().__init__(target_directory, verbose, silence_threshold)
        self.pad = pad
        self.atten_lim_db = atten_lim_db
        self.kwargs = kwargs

        # The model initializer is required from here on:
        try:
            from df.enhance import init_df
        except ImportError as import_error:
            raise ImportError("Please install deepfilternet packages") from import_error

        if self.verbose:
            _LOGGER.info("Loading DeepFilterNet2 model.")

        # Load the model and its state (the third returned value is not needed). The sample rate used for
        # loading and saving is taken from the model's state:
        self.model, self.df_state, _ = init_df()
        self.sample_rate = self.df_state.sr()

    def save_audio(self, audio: np.ndarray, target_path: Path):
        """
        Save the audio into the target directory using `df.enhance.save_audio`.

        :param audio:       The audio to save.
        :param target_path: The target path; only its file name is used, the directory is the target directory.
        """
        try:
            from df.enhance import save_audio as df_save_audio
        except ImportError as import_error:
            raise ImportError("Please install deepfilternet package") from import_error

        output_directory = str(self.target_directory)
        df_save_audio(
            file=target_path.name,
            audio=audio,
            sr=self.sample_rate,
            output_dir=output_directory,
        )

    def load_audio(self, file: str) -> torch.Tensor:
        """
        Load the audio using `df.enhance.load_audio` with the model's sample rate.

        :param file:    The file to load the audio from.

        :returns: The audio data (the second value returned by the loader is dropped).
        """
        try:
            from df.enhance import load_audio as df_load_audio
        except ImportError as import_error:
            raise ImportError("Please install deepfilternet package") from import_error

        loaded_audio, _ = df_load_audio(file=file, sr=self.sample_rate, **self.kwargs)
        return loaded_audio

    def clean_audio(self, data: torch.Tensor) -> torch.Tensor:
        """
        Clean the audio using `df.enhance.enhance` with the loaded model.

        :param data:    The audio data to clean.

        :returns: The enhanced audio.
        """
        try:
            from df.enhance import enhance as df_enhance
        except ImportError as import_error:
            raise ImportError("Please install deepfilternet package") from import_error

        enhanced_audio = df_enhance(
            model=self.model,
            df_state=self.df_state,
            audio=data,
            pad=self.pad,
            atten_lim_db=self.atten_lim_db,
        )
        return enhanced_audio


def _multiprocessing_complete_tasks(
    noise_reduce_type: Type[ReduceNoiseBase],
    noise_reduce_arguments: dict,
    tasks_queue: Queue,
    results_queue: Queue,
):
    """
    Complete the tasks in the given queue and put the results in the given results queue. The function will stop when
    the given tasks queue will receive the stop mark. It is aimed to be used with multiprocessing as a process.

    If the noise reducer fails to initialize, the worker keeps consuming its tasks and reports each of them as
    failed with the initialization error. Without that, the worker would exit before sending its stop mark and the
    process collecting the results would wait for it forever.

    :param noise_reduce_type:       The noise reduce type to use.
    :param noise_reduce_arguments:  The noisereduce initialization kwargs.
    :param tasks_queue:             A queue to get the tasks from.
    :param results_queue:           A queue to put the results in.
    """
    # Initialize the reduce noise object:
    try:
        noise_reducer = noise_reduce_type(**noise_reduce_arguments)
        initialization_error = None
    except Exception as exception:
        # Keep the message in a separate name, as the `exception` name is unbound once the `except` block ends:
        noise_reducer = None
        initialization_error = str(exception)
        _LOGGER.error(f"Failed to initialize the noise reducer: {exception}")

    # Start listening to the tasks queue:
    while True:
        # Get the audio_file:
        audio_file = tasks_queue.get()
        if audio_file == _MULTIPROCESSING_STOP_MARK:
            break
        audio_file = Path(audio_file)
        if noise_reducer is None:
            # Report the task as failed, in the same format `ReduceNoiseBase.reduce_noise` reports errors:
            results_queue.put((True, (audio_file.name, initialization_error)))
            continue
        # Apply noise reduction and collect the result:
        results_queue.put(noise_reducer.reduce_noise(audio_file=audio_file))

    # Mark the end of the tasks:
    results_queue.put(_MULTIPROCESSING_STOP_MARK)


def reduce_noise_dfn(
    audio_source: str,
    target_directory: str,
    pad: bool = True,
    atten_lim_db: int = None,
    silence_threshold: float = None,
    use_multiprocessing: int = 0,
    verbose: bool = True,
    **kwargs,
):
    """
    Reduce noise from audio files using DeepFilterNet.
    For more information about the noise reduction algorithm see:
    https://github.com/Rikorose/DeepFilterNet
    Notice that the saved files are in wav format, even if the original files are in other format.

    :param audio_source:        path to audio file or directory of audio files
    :param target_directory:    path to target directory to save cleaned audio files
    :param pad:                 whether to pad the audio file with zeros before cleaning
    :param atten_lim_db:        maximum attenuation in dB
    :param silence_threshold:   the threshold to remove silence from the audio, in dB. If None, no silence removal is
                                performed.
    :param use_multiprocessing: Number of processes to use for cleaning the audio files.
                                If 0, no multiprocessing is used.
    :param verbose:             verbosity level. If True, display progress bar and logs.
    :param kwargs:              additional arguments to pass to torchaudio.load(). For more information see:
                                https://pytorch.org/audio/stable/generated/torchaudio.load.html

    :returns: A tuple of two dictionaries keyed by the audio file name:
              - the successes, mapping to the path of the cleaned file
              - the errors, mapping to the error message
    """
    if verbose:
        _LOGGER.info("Reducing noise from audio files.")

    # create target directory:
    target_directory = _create_target_directory(target_directory)

    # get audio files:
    audio_files = _get_audio_files(audio_source)

    # Collect the `DFN` initialization arguments. The verbosity is forwarded as well, otherwise the noise reducer
    # falls back to its own default (verbose) and logs every file even when `verbose` is False:
    noise_reduce_arguments = {
        "target_directory": target_directory,
        "verbose": verbose,
        "pad": pad,
        "atten_lim_db": atten_lim_db,
        "silence_threshold": silence_threshold,
        **kwargs,
    }

    if use_multiprocessing:
        results = _parallel_run(
            noise_reduce_type=DFN,
            noise_reduce_arguments=noise_reduce_arguments,
            n_workers=use_multiprocessing,
            audio_files=audio_files,
            description="Noise-reduction",
            verbose=verbose,
        )
    else:
        results = _run(
            noise_reduce_type=DFN,
            noise_reduce_arguments=noise_reduce_arguments,
            audio_files=audio_files,
            description="Noise-reduction",
            verbose=verbose,
        )

    return _process_results(results, verbose)


def reduce_noise(
    audio_source: str,
    target_directory: str,
    sample_rate: int = 16000,
    duration: int = None,
    channel: int = None,
    silence_threshold: float = None,
    use_multiprocessing: int = 0,
    verbose: bool = True,
):
    """
    Reduce noise from audio file or directory containing audio files.
    The audio files must be in .wav format.
    The cleaned audio files will be saved in the target_directory.
    For information about the noise reduction algorithm see:
    https://github.com/timsainb/noisereduce
    Notice that the saved files are in wav format, even if the original files are in other format.

    :param audio_source:        path to audio file or directory containing audio files
    :param target_directory:    path to directory to save the cleaned audio files.
    :param sample_rate:         Number of samples in one second in the audio file.
                                Pass `None` to keep the original sample rate.
    :param duration:            Duration of the audio file to clean in seconds.
                                Pass `None` to keep the original duration.
    :param channel:             Channel to clean. Pass the number of the channel to clean.
                                To clean all channels pass None.
    :param silence_threshold:   The threshold to remove silence from the audio, in dB.
                                If None, no silence removal is performed.
    :param use_multiprocessing: Number of processes to use for cleaning the audio files.
                                If 0, no multiprocessing is used.
    :param verbose:             Verbosity level. If True, display progress bar and logs.

    :returns: A tuple of two dictionaries keyed by the audio file name:
              - the successes, mapping to the path of the cleaned file
              - the errors, mapping to the error message
    """
    if verbose:
        _LOGGER.info("Reducing noise from audio files.")

    # create target directory:
    target_directory = _create_target_directory(target_directory)

    # get audio files:
    audio_files = _get_audio_files(audio_source)

    # Collect the `ReduceNoise` initialization arguments. The verbosity is forwarded as well, otherwise the noise
    # reducer falls back to its own default (verbose) and logs every file even when `verbose` is False:
    noise_reduce_arguments = {
        "target_directory": target_directory,
        "verbose": verbose,
        "sample_rate": sample_rate,
        "duration": duration,
        "channel": channel,
        "silence_threshold": silence_threshold,
    }

    if use_multiprocessing:
        results = _parallel_run(
            noise_reduce_type=ReduceNoise,
            noise_reduce_arguments=noise_reduce_arguments,
            n_workers=use_multiprocessing,
            audio_files=audio_files,
            description="Noise-reduction",
            verbose=verbose,
        )
    else:
        results = _run(
            noise_reduce_type=ReduceNoise,
            noise_reduce_arguments=noise_reduce_arguments,
            audio_files=audio_files,
            description="Noise-reduction",
            verbose=verbose,
        )

    return _process_results(results, verbose)


def _create_target_directory(target_directory: str) -> str:
    """
    Make sure the target directory exists, creating it (with its parents) when the path does not exist yet.

    :param target_directory: The path of the directory to save the cleaned audio files in.

    :returns: The target directory path as a string.
    """
    directory = Path(target_directory)
    # An existing path is left untouched:
    if directory.exists():
        return str(directory)
    directory.mkdir(parents=True, exist_ok=True)
    return str(directory)


def _get_audio_files(audio_source: str):
    """
    Collect the audio files to process from the given source.

    :param audio_source: A path to a single file, or to a directory. From a directory, every direct entry whose
                         name matches the `*.*` glob pattern is taken.

    :returns: A list of the collected paths.

    :raises ValueError: If the source is neither a file nor a directory.
    """
    source_path = Path(audio_source)

    # A directory contributes all of its entries that have a dot in their name:
    if source_path.is_dir():
        return list(source_path.glob("*.*"))

    # A single file is wrapped in a list:
    if source_path.is_file():
        return [source_path]

    raise ValueError(
        f"audio_source must be a file or a directory, got {source_path}"
    )


def _parallel_run(
    noise_reduce_type: Type[ReduceNoiseBase],
    noise_reduce_arguments: dict,
    n_workers: int,
    audio_files: List[Path],
    description: str,
    verbose: bool,
) -> List[Tuple[bool, Tuple[str, str]]]:
    """
    Run multiple noise reduce workers with multiprocessing to reduce the noise of the provided audio files.

    :param noise_reduce_type:       The noise reduce type to use.
    :param noise_reduce_arguments:  The noisereduce initialization kwargs.
    :param n_workers:               The number of workers to use.
    :param audio_files:             The audio files to use.
    :param description:             The description to use for the progress bar.
    :param verbose:                 Verbosity.

    :returns: The collected results.
    """
    # Nothing to do without audio files. Returning here is required: with no files the number of workers would
    # be reduced to 0, no stop mark would ever be sent back and the results collection below would wait forever.
    if not audio_files:
        return []

    # Check the number of workers:
    if n_workers > len(audio_files):
        _LOGGER.warning(
            f"The number of workers ({n_workers}) is larger than the number of audio files ({len(audio_files)}). "
            f"Setting the number of workers to {len(audio_files)}."
        )
        n_workers = len(audio_files)

    # Initialize the multiprocessing queues:
    tasks_queue = Queue()
    results_queue = Queue()

    # Initialize the multiprocessing processes:
    task_completion_processes = [
        Process(
            target=_multiprocessing_complete_tasks,
            kwargs={
                "noise_reduce_type": noise_reduce_type,
                "noise_reduce_arguments": noise_reduce_arguments,
                "tasks_queue": tasks_queue,
                "results_queue": results_queue,
            },
        )
        for _ in range(n_workers)
    ]

    # Start the multiprocessing processes:
    for p in task_completion_processes:
        p.start()

    # Put the tasks in the queue:
    for audio_file in audio_files:
        tasks_queue.put(audio_file)

    # Put the stop marks in the queue (one per worker, so every worker stops):
    for _ in range(n_workers):
        tasks_queue.put(_MULTIPROCESSING_STOP_MARK)

    # Collect the results:
    results = []
    stop_marks_counter = 0
    with tqdm(
        desc=description,
        unit="file",
        total=len(audio_files),
        disable=not verbose,
    ) as progressbar:
        while True:
            # Get a result from the queue:
            result: Tuple[bool, Tuple[str, str]] = results_queue.get()
            if result == _MULTIPROCESSING_STOP_MARK:
                # A worker finished, stop once all of the workers did:
                stop_marks_counter += 1
                if stop_marks_counter == n_workers:
                    break
            else:
                # Collect the result:
                results.append(result)
                progressbar.update(1)

    # Wait for the processes to finish:
    for p in task_completion_processes:
        p.join()

    return results


def _run(
    noise_reduce_type: Type[ReduceNoiseBase],
    noise_reduce_arguments: dict,
    audio_files: List[Path],
    description: str,
    verbose: bool,
) -> List[Tuple[bool, Tuple[str, str]]]:
    """
    Reduce the noise of the given audio files one after the other, in the current process.

    :param noise_reduce_type:       The noise reduce type to use.
    :param noise_reduce_arguments:  The noisereduce initialization kwargs.
    :param audio_files:             The audio files to use.
    :param description:             The description to use for the progress bar.
    :param verbose:                 Verbosity. If False, the progress bar is disabled.

    :returns: The collected results, one per audio file and in the order of the given files.
    """
    # A single noise reducer serves all of the files:
    reducer = noise_reduce_type(**noise_reduce_arguments)

    # Wrap the files with a progress bar and collect the result of each one:
    files_with_progress = tqdm(
        audio_files,
        desc=description,
        unit="file",
        total=len(audio_files),
        disable=not verbose,
    )
    return [
        reducer.reduce_noise(audio_file=current_file)
        for current_file in files_with_progress
    ]


def _process_results(
    results: List[Tuple[bool, Tuple[str, str]]], verbose: bool
) -> Tuple[dict, dict]:
    """
    Split the collected results into successes and errors.

    :param results: The results to process. Each result is a tuple of an error flag and a (file name, value)
                    pair, where the value is the target path on success or the error message on failure.
    :param verbose: Verbosity. If True, the start and the summary of the processing are logged.

    :returns: A tuple of two dictionaries keyed by the file name: the successes and the errors.
    """
    if verbose:
        _LOGGER.info("Summarizing the results.")

    successes, errors = {}, {}
    for failed, (file_name, value) in results:
        # Route the entry to the dictionary that matches its error flag:
        destination = errors if failed else successes
        destination[file_name] = value

    if verbose:
        _LOGGER.info(f"Done ({len(successes)}/{len(successes) + len(errors)})\n")

    return successes, errors
 - origin_filename: '' - description: Reduce noise from audio files command: '' - image: '' + description: Reduce noise from audio files default_handler: reduce_noise - disable_auto_mount: false -metadata: - name: noise-reduction - tag: '' - categories: - - data-preparation - - audio -kind: job -verbose: false diff --git a/functions/src/noise_reduction/noise_reduction.py b/functions/src/noise_reduction/noise_reduction.py index f0fff5504..c9184922d 100644 --- a/functions/src/noise_reduction/noise_reduction.py +++ b/functions/src/noise_reduction/noise_reduction.py @@ -2,7 +2,6 @@ from abc import ABCMeta, abstractmethod from multiprocessing import Process, Queue from pathlib import Path -from typing import List, Tuple, Type, Union import librosa import numpy as np @@ -33,6 +32,7 @@ class ReduceNoiseBase(metaclass=ABCMeta): After implementing the above methods, you can use the reduce_noise method to reduce noise from audio files. """ + def __init__( self, target_directory: Path, @@ -43,7 +43,7 @@ def __init__( self.verbose = verbose self.silence_threshold = silence_threshold - def reduce_noise(self, audio_file: Path) -> Tuple[bool, Tuple[str, str]]: + def reduce_noise(self, audio_file: Path) -> tuple[bool, tuple[str, str]]: """ Reduce noise from the given audio file. @@ -89,7 +89,7 @@ def reduce_noise(self, audio_file: Path) -> Tuple[bool, Tuple[str, str]]: return True, (audio_file.name, str(exception)) @abstractmethod - def clean_audio(self, data) -> Union[np.ndarray, torch.Tensor]: + def clean_audio(self, data) -> np.ndarray | torch.Tensor: """ Clean the audio from noise. Here you should implement the noise reduction algorithm. @@ -110,7 +110,7 @@ def save_audio(self, audio: np.ndarray, target_path: Path): pass @abstractmethod - def load_audio(self, file: str) -> Tuple[Union[np.ndarray, torch.Tensor], int]: + def load_audio(self, file: str) -> tuple[np.ndarray | torch.Tensor, int]: """ Load the audio from a file. 
@@ -288,7 +288,7 @@ def clean_audio(self, data: torch.Tensor) -> torch.Tensor: def _multiprocessing_complete_tasks( - noise_reduce_type: Type[ReduceNoiseBase], + noise_reduce_type: type[ReduceNoiseBase], noise_reduce_arguments: dict, tasks_queue: Queue, results_queue: Queue, @@ -478,13 +478,13 @@ def _get_audio_files(audio_source: str): def _parallel_run( - noise_reduce_type: Type[ReduceNoiseBase], + noise_reduce_type: type[ReduceNoiseBase], noise_reduce_arguments: dict, n_workers: int, - audio_files: List[Path], + audio_files: list[Path], description: str, verbose: bool, -) -> List[Tuple[bool, Tuple[str, str]]]: +) -> list[tuple[bool, tuple[str, str]]]: """ Run multiple noise reduce workers with multiprocessing to complete the tasks that will be created on the provided files using the given task creator. @@ -547,7 +547,7 @@ def _parallel_run( ) as progressbar: while True: # Get a result from the queue: - result: Tuple[bool, Tuple[str, str]] = results_queue.get() + result: tuple[bool, tuple[str, str]] = results_queue.get() if result == _MULTIPROCESSING_STOP_MARK: stop_marks_counter += 1 if stop_marks_counter == n_workers: @@ -565,12 +565,12 @@ def _parallel_run( def _run( - noise_reduce_type: Type[ReduceNoiseBase], + noise_reduce_type: type[ReduceNoiseBase], noise_reduce_arguments: dict, - audio_files: List[Path], + audio_files: list[Path], description: str, verbose: bool, -) -> List[Tuple[bool, Tuple[str, str]]]: +) -> list[tuple[bool, tuple[str, str]]]: """ Run the noise reduce algorithm on the given audio files and collect the results. @@ -600,8 +600,8 @@ def _run( def _process_results( - results: List[Tuple[bool, Tuple[str, str]]], verbose: bool -) -> Tuple[dict, dict]: + results: list[tuple[bool, tuple[str, str]]], verbose: bool +) -> tuple[dict, dict]: """ Process the results of the tasks. 
diff --git a/functions/src/onnx_utils/function.yaml b/functions/src/onnx_utils/function.yaml index 023c034d3..c163f0e5a 100644 --- a/functions/src/onnx_utils/function.yaml +++ b/functions/src/onnx_utils/function.yaml @@ -1,17 +1,18 @@ -kind: job metadata: + tag: '' + name: onnx-utils categories: - utils - deep-learning - name: onnx-utils - tag: '' verbose: false +kind: job spec: + image: '' + disable_auto_mount: false build: - code_origin: '' - base_image: mlrun/mlrun origin_filename: '' - functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Dict, List, Tuple

import mlrun


class _ToONNXConversions:
    """
    A library class holding the framework specific ONNX conversion functions.

    Note: the doc strings of the conversion functions are printed to the user when "help" is passed as the
    framework kwargs of `to_onnx`, so they are part of the function's output.
    """

    @staticmethod
    def tf_keras_to_onnx(
        model_handler,
        onnx_model_name: str = None,
        optimize_model: bool = True,
        input_signature: List[Tuple[Tuple[int], str]] = None,
    ):
        """
        Convert a TF.Keras model to an ONNX model and log it back to MLRun as a new model object.

        :param model_handler:   An initialized TFKerasModelHandler with a loaded model to convert to ONNX.
        :param onnx_model_name: The name to use to log the converted ONNX model. If not given, the given `model_name`
                                will be used with an additional suffix `_onnx`. Defaulted to None.
        :param optimize_model:  Whether or not to optimize the ONNX model using 'onnxoptimizer' before saving the model.
                                Defaulted to True.
        :param input_signature: A list of the input layers shape and data type properties. Expected to receive a list
                                where each element is an input layer tuple. An input layer tuple is a tuple of:
                                [0] = Layer's shape, a tuple of integers.
                                [1] = Layer's data type, a mlrun.data_types.ValueType string.
                                If None, the input signature will be tried to be read from the model artifact. Defaulted
                                to None.
        """
        # The framework is imported lazily, so it is required only when a TF.Keras model is converted:
        import tensorflow as tf
        from mlrun.frameworks.tf_keras import TFKerasUtils

        if input_signature is not None:
            # Translate each (shape, value type) pair into a TensorFlow tensor spec:
            tensor_specs = []
            for layer_shape, layer_value_type in input_signature:
                layer_dtype = TFKerasUtils.convert_value_type_to_tf_dtype(
                    value_type=layer_value_type
                )
                tensor_specs.append(tf.TensorSpec(shape=layer_shape, dtype=layer_dtype))
            input_signature = tensor_specs
        else:
            # No signature was given, let the handler read the inputs from the model:
            try:
                model_handler.read_inputs_from_model()
            except Exception as error:
                raise mlrun.errors.MLRunRuntimeError(
                    f"Please provide the 'input_signature' parameter. The function tried reading the input layers "
                    f"information automatically but failed with the following error: {error}"
                )

        # Convert to ONNX:
        model_handler.to_onnx(
            model_name=onnx_model_name,
            input_signature=input_signature,
            optimize=optimize_model,
        )

    @staticmethod
    def pytorch_to_onnx(
        model_handler,
        onnx_model_name: str = None,
        optimize_model: bool = True,
        input_signature: List[Tuple[Tuple[int, ...], str]] = None,
        input_layers_names: List[str] = None,
        output_layers_names: List[str] = None,
        dynamic_axes: Dict[str, Dict[int, str]] = None,
        is_batched: bool = True,
    ):
        """
        Convert a PyTorch model to an ONNX model and log it back to MLRun as a new model object.

        :param model_handler:       An initialized PyTorchModelHandler with a loaded model to convert to ONNX.
        :param onnx_model_name:     The name to use to log the converted ONNX model. If not given, the given
                                    `model_name` will be used with an additional suffix `_onnx`. Defaulted to None.
        :param optimize_model:      Whether or not to optimize the ONNX model using 'onnxoptimizer' before saving the
                                    model. Defaulted to True.
        :param input_signature:     A list of the input layers shape and data type properties. Expected to receive a
                                    list where each element is an input layer tuple. An input layer tuple is a tuple of:
                                    [0] = Layer's shape, a tuple of integers.
                                    [1] = Layer's data type, a mlrun.data_types.ValueType string.
                                    If None, the input signature will be tried to be read from the model artifact.
                                    Defaulted to None.
        :param input_layers_names:  List of names to assign to the input nodes of the graph in order. All of the other
                                    parameters (inner layers) can be set as well by passing additional names in the
                                    list. The order is by the order of the parameters in the model. If None, the inputs
                                    will be read from the handler's inputs. If its also None, it is defaulted to:
                                    "input_0", "input_1", ...
        :param output_layers_names: List of names to assign to the output nodes of the graph in order. If None, the
                                    outputs will be read from the handler's outputs. If its also None, it is defaulted
                                    to: "output_0" (for multiple outputs, this parameter must be provided).
        :param dynamic_axes:        If part of the input / output shape is dynamic, like (batch_size, 3, 32, 32) you can
                                    specify it by giving a dynamic axis to the input / output layer by its name as
                                    follows: {
                                        "input layer name": {0: "batch_size"},
                                        "output layer name": {0: "batch_size"},
                                    }
                                    If provided, the 'is_batched' flag will be ignored. Defaulted to None.
        :param is_batched:          Whether to include a batch size as the first axis in every input and output layer.
                                    Defaulted to True. Will be ignored if 'dynamic_axes' is provided.
        """
        # The framework is imported lazily, so it is required only when a PyTorch model is converted:
        import torch
        from mlrun.frameworks.pytorch import PyTorchUtils

        # Build a zeros tensor per given (shape, value type) pair to be used as the input sample:
        if input_signature is not None:
            input_signature = tuple(
                torch.zeros(
                    size=layer_shape,
                    dtype=PyTorchUtils.convert_value_type_to_torch_dtype(
                        value_type=layer_value_type
                    ),
                )
                for layer_shape, layer_value_type in input_signature
            )

        # Convert to ONNX:
        model_handler.to_onnx(
            model_name=onnx_model_name,
            input_sample=input_signature,
            optimize=optimize_model,
            input_layers_names=input_layers_names,
            output_layers_names=output_layers_names,
            dynamic_axes=dynamic_axes,
            is_batched=is_batched,
        )


# Map for getting the conversion function according to the provided framework:
# The keys are compared against the `FRAMEWORK_NAME` of the loaded model handler. A framework that is missing from
# this map has no ONNX conversion available.
_CONVERSION_MAP = {
    "tensorflow.keras": _ToONNXConversions.tf_keras_to_onnx,
    "torch": _ToONNXConversions.pytorch_to_onnx,
}  # type: Dict[str, Callable]


def to_onnx(
    context: mlrun.MLClientCtx,
    model_path: str,
    load_model_kwargs: dict = None,
    onnx_model_name: str = None,
    optimize_model: bool = True,
    framework_kwargs: Dict[str, Any] = None,
):
    """
    Convert the given model to an ONNX model.

    The model is loaded using `AutoMLRun.load_model` and the conversion function is chosen according to the framework
    name of the returned model handler.

    :param context:           The MLRun function execution context
    :param model_path:        The model path store object.
    :param load_model_kwargs: Keyword arguments to pass to the `AutoMLRun.load_model` method.
    :param onnx_model_name:   The name to use to log the converted ONNX model. If not given, the given `model_name` will
                              be used with an additional suffix `_onnx`. Defaulted to None.
    :param optimize_model:    Whether to optimize the ONNX model using 'onnxoptimizer' before saving the model.
                              Defaulted to True.
    :param framework_kwargs:  Additional arguments each framework may require to convert to ONNX. To get the doc string
                              of the desired framework onnx conversion function, pass "help" - the doc string will be
                              printed and no conversion will run.

    :raise MLRunInvalidArgumentError: If the model's framework has no ONNX conversion or if a `TypeError` was raised
                                      while running the conversion (the original error is chained as the cause).
    """
    from mlrun.frameworks.auto_mlrun.auto_mlrun import AutoMLRun

    # Get a model handler of the required framework:
    load_model_kwargs = load_model_kwargs or {}
    model_handler = AutoMLRun.load_model(
        model_path=model_path, context=context, **load_model_kwargs
    )

    # Get the model's framework:
    framework = model_handler.FRAMEWORK_NAME

    # Use the conversion map to get the specific framework to onnx conversion:
    if framework not in _CONVERSION_MAP:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"The following framework: '{framework}', has no ONNX conversion."
        )
    conversion_function = _CONVERSION_MAP[framework]

    # Check if needed to print the function's doc string ("help" is passed):
    if framework_kwargs == "help":
        print(conversion_function.__doc__)
        return

    # Set the default empty framework kwargs if needed:
    if framework_kwargs is None:
        framework_kwargs = {}

    # Run the conversion:
    try:
        conversion_function(
            model_handler=model_handler,
            onnx_model_name=onnx_model_name,
            optimize_model=optimize_model,
            **framework_kwargs,
        )
    except TypeError as exception:
        # Chain the original error explicitly so its traceback is kept as the cause of the raised MLRun error:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"ERROR: A TypeError exception was raised during the conversion:\n{exception}. "
            f"Please read the {framework} framework conversion function doc string by passing 'help' in the "
            f"'framework_kwargs' dictionary parameter."
        ) from exception


def optimize(
    context: mlrun.MLClientCtx,
    model_path: str,
    handler_init_kwargs: dict = None,
    optimizations: List[str] = None,
    fixed_point: bool = False,
    optimized_model_name: str = None,
):
    """
    Run ONNX optimizations on the given ONNX model and log the result.

    :param context:              The MLRun function execution context.
    :param model_path:           Path to the ONNX model object.
    :param handler_init_kwargs:  Keyword arguments forwarded to the `ONNXModelHandler` initializer.
    :param optimizations:        The optimizations to apply, passed as is to the handler's `optimize` method. Pass
                                 "help" to only print the optimizations 'onnxoptimizer' offers. If None, all the
                                 optimizations will be used. Defaulted to None.
    :param fixed_point:          Optimize the weights using fixed point. Defaulted to False.
    :param optimized_model_name: A new name for the optimized model. If None, the handler keeps its model name, so
                                 the original model will be overridden. Defaulted to None.
    """
    # Imported here so the packages are required only when the function is actually used:
    import onnxoptimizer
    from mlrun.frameworks.onnx import ONNXModelHandler

    # A "help" request only lists the available optimization passes:
    if optimizations == "help":
        passes_listing = "\n* ".join(onnxoptimizer.get_available_passes())
        print(f"The available optimizations are:\n* {passes_listing}")
        return

    # Initialize a handler for the ONNX model and load it:
    init_kwargs = handler_init_kwargs if handler_init_kwargs else {}
    onnx_handler = ONNXModelHandler(
        model_path=model_path, context=context, **init_kwargs
    )
    onnx_handler.load()

    # Apply the requested optimizations:
    onnx_handler.optimize(optimizations=optimizations, fixed_point=fixed_point)

    # Rename the model (when asked to) right before logging it:
    if optimized_model_name is not None:
        onnx_handler.set_model_name(model_name=optimized_model_name)
    onnx_handler.log()
 + with_mlrun: false + functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections.abc import Callable
from typing import Any

import mlrun


class _ToONNXConversions:
    """
    An ONNX conversion functions library class. Each static method converts a model of a single framework.
    """

    @staticmethod
    def tf_keras_to_onnx(
        model_handler,
        onnx_model_name: str = None,
        optimize_model: bool = True,
        input_signature: list[tuple[tuple[int, ...], str]] = None,
    ):
        """
        Convert a TF.Keras model to an ONNX model and log it back to MLRun as a new model object.

        :param model_handler:   An initialized TFKerasModelHandler with a loaded model to convert to ONNX.
        :param onnx_model_name: The name to use to log the converted ONNX model. If not given, the given `model_name`
                                will be used with an additional suffix `_onnx`. Defaulted to None.
        :param optimize_model:  Whether to optimize the ONNX model using 'onnxoptimizer' before saving the model.
                                Defaulted to True.
        :param input_signature: A list of the input layers shape and data type properties. Expected to receive a list
                                where each element is an input layer tuple. An input layer tuple is a tuple of:
                                [0] = Layer's shape, a tuple of integers.
                                [1] = Layer's data type, a mlrun.data_types.ValueType string.
                                If None, the function will try to read the inputs from the model using the handler.
                                Defaulted to None.

        :raise MLRunRuntimeError: If 'input_signature' was not provided and reading the inputs from the model failed.
        """
        # Import the framework and handler:
        import tensorflow as tf
        from mlrun.frameworks.tf_keras import TFKerasUtils

        # Check the given 'input_signature' parameter:
        if input_signature is None:
            # Read the inputs from the model:
            try:
                model_handler.read_inputs_from_model()
            except Exception as error:
                # Chain the original error explicitly so its traceback is kept as the cause:
                raise mlrun.errors.MLRunRuntimeError(
                    f"Please provide the 'input_signature' parameter. The function tried reading the input layers "
                    f"information automatically but failed with the following error: {error}"
                ) from error
        else:
            # Parse the 'input_signature' parameter into TensorFlow tensor specs:
            input_signature = [
                tf.TensorSpec(
                    shape=shape,
                    dtype=TFKerasUtils.convert_value_type_to_tf_dtype(
                        value_type=value_type
                    ),
                )
                for (shape, value_type) in input_signature
            ]

        # Convert to ONNX:
        model_handler.to_onnx(
            model_name=onnx_model_name,
            input_signature=input_signature,
            optimize=optimize_model,
        )

    @staticmethod
    def pytorch_to_onnx(
        model_handler,
        onnx_model_name: str = None,
        optimize_model: bool = True,
        input_signature: list[tuple[tuple[int, ...], str]] = None,
        input_layers_names: list[str] = None,
        output_layers_names: list[str] = None,
        dynamic_axes: dict[str, dict[int, str]] = None,
        is_batched: bool = True,
    ):
        """
        Convert a PyTorch model to an ONNX model and log it back to MLRun as a new model object.

        :param model_handler:       An initialized PyTorchModelHandler with a loaded model to convert to ONNX.
        :param onnx_model_name:     The name to use to log the converted ONNX model. If not given, the given
                                    `model_name` will be used with an additional suffix `_onnx`. Defaulted to None.
        :param optimize_model:      Whether to optimize the ONNX model using 'onnxoptimizer' before saving the
                                    model. Defaulted to True.
        :param input_signature:     A list of the input layers shape and data type properties. Expected to receive a
                                    list where each element is an input layer tuple. An input layer tuple is a tuple of:
                                    [0] = Layer's shape, a tuple of integers.
                                    [1] = Layer's data type, a mlrun.data_types.ValueType string.
                                    If None, the input signature will be tried to be read from the model artifact.
                                    Defaulted to None.
        :param input_layers_names:  List of names to assign to the input nodes of the graph in order. All of the other
                                    parameters (inner layers) can be set as well by passing additional names in the
                                    list. The order is by the order of the parameters in the model. If None, the inputs
                                    will be read from the handler's inputs. If it's also None, it is defaulted to:
                                    "input_0", "input_1", ...
        :param output_layers_names: List of names to assign to the output nodes of the graph in order. If None, the
                                    outputs will be read from the handler's outputs. If it's also None, it is defaulted
                                    to: "output_0" (for multiple outputs, this parameter must be provided).
        :param dynamic_axes:        If part of the input / output shape is dynamic, like (batch_size, 3, 32, 32) you can
                                    specify it by giving a dynamic axis to the input / output layer by its name as
                                    follows: {
                                        "input layer name": {0: "batch_size"},
                                        "output layer name": {0: "batch_size"},
                                    }
                                    If provided, the 'is_batched' flag will be ignored. Defaulted to None.
        :param is_batched:          Whether to include a batch size as the first axis in every input and output layer.
                                    Defaulted to True. Will be ignored if 'dynamic_axes' is provided.
        """
        # Import the framework and handler:
        import torch
        from mlrun.frameworks.pytorch import PyTorchUtils

        # Parse the 'input_signature' parameter into a tuple of zero tensors, used as the input sample:
        if input_signature is not None:
            input_signature = tuple(
                torch.zeros(
                    size=shape,
                    dtype=PyTorchUtils.convert_value_type_to_torch_dtype(
                        value_type=value_type
                    ),
                )
                for (shape, value_type) in input_signature
            )

        # Convert to ONNX:
        model_handler.to_onnx(
            model_name=onnx_model_name,
            input_sample=input_signature,
            optimize=optimize_model,
            input_layers_names=input_layers_names,
            output_layers_names=output_layers_names,
            dynamic_axes=dynamic_axes,
            is_batched=is_batched,
        )


# Map for getting the conversion function according to the provided framework. The keys are compared against a model
# handler's `FRAMEWORK_NAME` and the values are the conversion functions (they return nothing):
_CONVERSION_MAP: dict[str, Callable[..., None]] = {
    "tensorflow.keras": _ToONNXConversions.tf_keras_to_onnx,
    "torch": _ToONNXConversions.pytorch_to_onnx,
}


def to_onnx(
    context: mlrun.MLClientCtx,
    model_path: str,
    load_model_kwargs: dict = None,
    onnx_model_name: str = None,
    optimize_model: bool = True,
    framework_kwargs: dict[str, Any] = None,
):
    """
    Convert the given model to an ONNX model.

    The model is loaded using `AutoMLRun.load_model` and the conversion function is chosen according to the framework
    name of the returned model handler.

    :param context:           The MLRun function execution context
    :param model_path:        The model path store object.
    :param load_model_kwargs: Keyword arguments to pass to the `AutoMLRun.load_model` method.
    :param onnx_model_name:   The name to use to log the converted ONNX model. If not given, the given `model_name` will
                              be used with an additional suffix `_onnx`. Defaulted to None.
    :param optimize_model:    Whether to optimize the ONNX model using 'onnxoptimizer' before saving the model.
                              Defaulted to True.
    :param framework_kwargs:  Additional arguments each framework may require to convert to ONNX. To get the doc string
                              of the desired framework onnx conversion function, pass "help" - the doc string will be
                              printed and no conversion will run.

    :raise MLRunInvalidArgumentError: If the model's framework has no ONNX conversion or if a `TypeError` was raised
                                      while running the conversion (the original error is chained as the cause).
    """
    from mlrun.frameworks.auto_mlrun.auto_mlrun import AutoMLRun

    # Get a model handler of the required framework:
    load_model_kwargs = load_model_kwargs or {}
    model_handler = AutoMLRun.load_model(
        model_path=model_path, context=context, **load_model_kwargs
    )

    # Get the model's framework:
    framework = model_handler.FRAMEWORK_NAME

    # Use the conversion map to get the specific framework to onnx conversion:
    if framework not in _CONVERSION_MAP:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"The following framework: '{framework}', has no ONNX conversion."
        )
    conversion_function = _CONVERSION_MAP[framework]

    # Check if needed to print the function's doc string ("help" is passed):
    if framework_kwargs == "help":
        print(conversion_function.__doc__)
        return

    # Set the default empty framework kwargs if needed:
    if framework_kwargs is None:
        framework_kwargs = {}

    # Run the conversion:
    try:
        conversion_function(
            model_handler=model_handler,
            onnx_model_name=onnx_model_name,
            optimize_model=optimize_model,
            **framework_kwargs,
        )
    except TypeError as exception:
        # Chain the original error explicitly so its traceback is kept as the cause of the raised MLRun error:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"ERROR: A TypeError exception was raised during the conversion:\n{exception}. "
            f"Please read the {framework} framework conversion function doc string by passing 'help' in the "
            f"'framework_kwargs' dictionary parameter."
        ) from exception


def optimize(
    context: mlrun.MLClientCtx,
    model_path: str,
    handler_init_kwargs: dict = None,
    optimizations: list[str] = None,
    fixed_point: bool = False,
    optimized_model_name: str = None,
):
    """
    Run ONNX optimizations on the given ONNX model and log the result.

    :param context:              The MLRun function execution context.
    :param model_path:           Path to the ONNX model object.
    :param handler_init_kwargs:  Keyword arguments forwarded to the `ONNXModelHandler` initializer.
    :param optimizations:        The optimizations to apply, passed as is to the handler's `optimize` method. Pass
                                 "help" to only print the optimizations 'onnxoptimizer' offers. If None, all the
                                 optimizations will be used. Defaulted to None.
    :param fixed_point:          Optimize the weights using fixed point. Defaulted to False.
    :param optimized_model_name: A new name for the optimized model. If None, the handler keeps its model name, so
                                 the original model will be overridden. Defaulted to None.
    """
    # Imported here so the packages are required only when the function is actually used:
    import onnxoptimizer
    from mlrun.frameworks.onnx import ONNXModelHandler

    # A "help" request only lists the available optimization passes:
    if optimizations == "help":
        passes_listing = "\n* ".join(onnxoptimizer.get_available_passes())
        print(f"The available optimizations are:\n* {passes_listing}")
        return

    # Initialize a handler for the ONNX model and load it:
    init_kwargs = handler_init_kwargs if handler_init_kwargs else {}
    onnx_handler = ONNXModelHandler(
        model_path=model_path, context=context, **init_kwargs
    )
    onnx_handler.load()

    # Apply the requested optimizations:
    onnx_handler.optimize(optimizations=optimizations, fixed_point=fixed_point)

    # Rename the model (when asked to) right before logging it:
    if optimized_model_name is not None:
        onnx_handler.set_model_name(model_name=optimized_model_name)
    onnx_handler.log()
 requirements: - tqdm~=4.67.1 - tensorflow~=2.19.0 @@ -24,17 +25,13 @@ spec: - onnxmltools~=1.13.0 - tf2onnx~=1.16.1 - plotly~=5.23 - with_mlrun: false + code_origin: '' auto_build: true - disable_auto_mount: false - description: ONNX intigration in MLRun, some utils functions for the ONNX framework, - optimizing and converting models from different framework to ONNX using MLRun. - image: '' + base_image: mlrun/mlrun + allow_empty_resources: true + filename: onnx_utils.py entry_points: tf_keras_to_onnx: - doc: Convert a TF.Keras model to an ONNX model and log it back to MLRun as a - new model object. - name: tf_keras_to_onnx parameters: - name: model_handler doc: An initialized TFKerasModelHandler with a loaded model to convert to @@ -51,20 +48,20 @@ spec: saving the model. Defaulted to True. default: true - name: input_signature - type: List[Tuple[Tuple[int], str]] + type: list[tuple[tuple[int], str]] doc: 'A list of the input layers shape and data type properties. Expected to receive a list where each element is an input layer tuple. An input layer tuple is a tuple of: [0] = Layer''s shape, a tuple of integers. [1] = Layer''s data type, a mlrun.data_types.ValueType string. If None, the input signature will be tried to be read from the model artifact. Defaulted to None.' default: null - has_varargs: false + name: tf_keras_to_onnx + doc: Convert a TF.Keras model to an ONNX model and log it back to MLRun as a + new model object. has_kwargs: false + has_varargs: false lineno: 26 pytorch_to_onnx: - doc: Convert a PyTorch model to an ONNX model and log it back to MLRun as a - new model object. - name: pytorch_to_onnx parameters: - name: model_handler doc: An initialized PyTorchModelHandler with a loaded model to convert to @@ -81,7 +78,7 @@ spec: saving the model. Defaulted to True. default: true - name: input_signature - type: List[Tuple[Tuple[int, ], str]] + type: list[tuple[tuple[int, ], str]] doc: 'A list of the input layers shape and data type properties. 
Expected to receive a list where each element is an input layer tuple. An input layer tuple is a tuple of: [0] = Layer''s shape, a tuple of integers. [1] = Layer''s @@ -89,7 +86,7 @@ spec: will be tried to be read from the model artifact. Defaulted to None.' default: null - name: input_layers_names - type: List[str] + type: list[str] doc: 'List of names to assign to the input nodes of the graph in order. All of the other parameters (inner layers) can be set as well by passing additional names in the list. The order is by the order of the parameters in the model. @@ -97,14 +94,14 @@ spec: None, it is defaulted to: "input_0", "input_1", ...' default: null - name: output_layers_names - type: List[str] + type: list[str] doc: 'List of names to assign to the output nodes of the graph in order. If None, the outputs will be read from the handler''s outputs. If its also None, it is defaulted to: "output_0" (for multiple outputs, this parameter must be provided).' default: null - name: dynamic_axes - type: Dict[str, Dict[int, str]] + type: dict[str, dict[int, str]] doc: 'If part of the input / output shape is dynamic, like (batch_size, 3, 32, 32) you can specify it by giving a dynamic axis to the input / output layer by its name as follows: { "input layer name": {0: "batch_size"}, "output @@ -116,12 +113,13 @@ spec: doc: Whether to include a batch size as the first axis in every input and output layer. Defaulted to True. Will be ignored if 'dynamic_axes' is provided. default: true - has_varargs: false + name: pytorch_to_onnx + doc: Convert a PyTorch model to an ONNX model and log it back to MLRun as a + new model object. has_kwargs: false + has_varargs: false lineno: 81 to_onnx: - doc: Convert the given model to an ONNX model. - name: to_onnx parameters: - name: context type: MLClientCtx @@ -145,17 +143,17 @@ spec: the model. Defaulted to True. 
default: true - name: framework_kwargs - type: Dict[str, Any] + type: dict[str, Any] doc: Additional arguments each framework may require to convert to ONNX. To get the doc string of the desired framework onnx conversion function, pass "help". default: null - has_varargs: false + name: to_onnx + doc: Convert the given model to an ONNX model. has_kwargs: false + has_varargs: false lineno: 160 optimize: - doc: Optimize the given ONNX model. - name: optimize parameters: - name: context type: MLClientCtx @@ -168,7 +166,7 @@ spec: doc: Keyword arguments to pass to the `ONNXModelHandler` init method preloading. default: null - name: optimizations - type: List[str] + type: list[str] doc: List of possible optimizations. To see what optimizations are available, pass "help". If None, all the optimizations will be used. Defaulted to None. default: null @@ -181,9 +179,12 @@ spec: doc: The name of the optimized model. If None, the original model will be overridden. Defaulted to None. default: null - has_varargs: false + name: optimize + doc: Optimize the given ONNX model. has_kwargs: false + has_varargs: false lineno: 224 - default_handler: to_onnx - allow_empty_resources: true command: '' + description: ONNX intigration in MLRun, some utils functions for the ONNX framework, + optimizing and converting models from different framework to ONNX using MLRun. + default_handler: to_onnx diff --git a/functions/src/onnx_utils/onnx_utils.py b/functions/src/onnx_utils/onnx_utils.py index c26e011be..ed6890b55 100644 --- a/functions/src/onnx_utils/onnx_utils.py +++ b/functions/src/onnx_utils/onnx_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -from typing import Any, Callable, Dict, List, Tuple +from typing import Any import mlrun @@ -27,7 +27,7 @@ def tf_keras_to_onnx( model_handler, onnx_model_name: str = None, optimize_model: bool = True, - input_signature: List[Tuple[Tuple[int], str]] = None, + input_signature: list[tuple[tuple[int], str]] = None, ): """ Convert a TF.Keras model to an ONNX model and log it back to MLRun as a new model object. @@ -82,10 +82,10 @@ def pytorch_to_onnx( model_handler, onnx_model_name: str = None, optimize_model: bool = True, - input_signature: List[Tuple[Tuple[int, ...], str]] = None, - input_layers_names: List[str] = None, - output_layers_names: List[str] = None, - dynamic_axes: Dict[str, Dict[int, str]] = None, + input_signature: list[tuple[tuple[int, ...], str]] = None, + input_layers_names: list[str] = None, + output_layers_names: list[str] = None, + dynamic_axes: dict[str, dict[int, str]] = None, is_batched: bool = True, ): """ @@ -163,7 +163,7 @@ def to_onnx( load_model_kwargs: dict = None, onnx_model_name: str = None, optimize_model: bool = True, - framework_kwargs: Dict[str, Any] = None, + framework_kwargs: dict[str, Any] = None, ): """ Convert the given model to an ONNX model. 
@@ -225,7 +225,7 @@ def optimize( context: mlrun.MLClientCtx, model_path: str, handler_init_kwargs: dict = None, - optimizations: List[str] = None, + optimizations: list[str] = None, fixed_point: bool = False, optimized_model_name: str = None, ): diff --git a/functions/src/open_archive/function.yaml b/functions/src/open_archive/function.yaml index bf78b5fcd..451279f43 100644 --- a/functions/src/open_archive/function.yaml +++ b/functions/src/open_archive/function.yaml @@ -1,20 +1,20 @@ -kind: job +metadata: + tag: '' + name: open-archive + categories: + - utils verbose: false +kind: job spec: - command: '' + image: mlrun/mlrun disable_auto_mount: false - default_handler: open_archive build: - functionSourceCode: IyBDb3B5cmlnaHQgMjAyNSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKCmltcG9ydCBvcwppbXBvcnQgemlwZmlsZQppbXBvcnQgdGFyZmlsZQoKZnJvbSBtbHJ1bi5leGVjdXRpb24gaW1wb3J0IE1MQ2xpZW50Q3R4CmZyb20gbWxydW4uZGF0YXN0b3JlIGltcG9ydCBEYXRhSXRlbQpmcm9tIG1scnVuLmFydGlmYWN0cy5iYXNlIGltcG9ydCBEaXJBcnRpZmFjdAoKZnJvbSB1cmxsaWIucGFyc2UgaW1wb3J0IHVybHBhcnNlCgoKZGVmIG9wZW5fYXJjaGl2ZSgKICAgICAgICBjb250ZXh0OiBNTENsaWVudEN0eCwKICAgICAgICBhcmNoaXZlX3VybDogRGF0YUl0ZW0sCiAgICAgICAgc3ViZGlyOiBzdHIgPSAiY29udGVudC8iLAogICAgICAgIGtleTogc3RyID0gImNvbnRlbnQiLAogICAgICAgIHRhcmdldF9wYXRoOiBzdHIgPSBOb25lLAopOgogICAgIiI
iT3BlbiBhIGZpbGUvb2JqZWN0IGFyY2hpdmUgaW50byBhIHRhcmdldCBkaXJlY3RvcnkuIEN1cnJlbnRseSwgc3VwcG9ydHMgemlwIGFuZCB0YXIuZ3ouCgogICAgOnBhcmFtIGNvbnRleHQ6ICAgICAgZnVuY3Rpb24gZXhlY3V0aW9uIGNvbnRleHQKICAgIDpwYXJhbSBhcmNoaXZlX3VybDogIHVybCBvZiBhcmNoaXZlIGZpbGUKICAgIDpwYXJhbSBzdWJkaXI6ICAgICAgIHBhdGggd2l0aGluIGFydGlmYWN0IHN0b3JlIHdoZXJlIGV4dHJhY3RlZCBmaWxlcyBhcmUgc3RvcmVkLCBkZWZhdWx0IGlzICIvY29udGVudCIKICAgIDpwYXJhbSBrZXk6ICAgICAgICAgIGtleSBvZiBhcmNoaXZlIGNvbnRlbnRzIGluIGFydGlmYWN0IHN0b3JlCiAgICA6cGFyYW0gdGFyZ2V0X3BhdGg6ICBmaWxlIHN5c3RlbSBwYXRoIHRvIHN0b3JlIGV4dHJhY3RlZCBmaWxlcwogICAgIiIiCgogICAgIyBSZXNvbHZlcyB0aGUgYXJjaGl2ZSBsb2NhbGx5CiAgICBhcmNoaXZlX3VybCA9IGFyY2hpdmVfdXJsLmxvY2FsKCkKICAgIHYzaW9fc3ViZGlyID0gTm9uZQogICAgIyBXaGVuIGN1c3RvbSBhcnRpZmFjdCBwYXRoIGlzIGRlZmluZWQKICAgIGlmIG5vdCB0YXJnZXRfcGF0aCBhbmQgY29udGV4dC5hcnRpZmFjdF9wYXRoOgogICAgICAgIHBhcnNlZF9zdWJkaXIgPSB1cmxwYXJzZShjb250ZXh0LmFydGlmYWN0X3BhdGgpCiAgICAgICAgaWYgcGFyc2VkX3N1YmRpci5zY2hlbWUgPT0gJ3MzJzoKICAgICAgICAgICAgc3ViZGlyID0gb3MucGF0aC5qb2luKGNvbnRleHQuYXJ0aWZhY3RfcGF0aCwgc3ViZGlyKQogICAgICAgIGVsaWYgcGFyc2VkX3N1YmRpci5zY2hlbWUgPT0gJ3YzaW8nOgogICAgICAgICAgICB2M2lvX3N1YmRpciA9IG9zLnBhdGguam9pbihjb250ZXh0LmFydGlmYWN0X3BhdGgsIHN1YmRpcikgICMgVXNpbmcgdjNpb19zdWJkaXIgZm9yIGxvZ2dpbmcKICAgICAgICAgICAgc3ViZGlyID0gJy92M2lvJyArIHBhcnNlZF9zdWJkaXIucGF0aCArICcvJyArIHN1YmRpcgogICAgICAgICAgICBjb250ZXh0LmxvZ2dlci5pbmZvKGYnVXNpbmcgdjNpbyBzY2hlbWUsIGV4dHJhY3RpbmcgdG8ge3N1YmRpcn0nKQogICAgICAgIGVsc2U6CiAgICAgICAgICAgIGNvbnRleHQubG9nZ2VyLmluZm8oZidVbnJlY29nbml6YWJsZSBzY2hlbWUsIGV4dHJhY3RpbmcgdG8ge3N1YmRpcn0nKQoKICAgICMgV2hlbiB3b3JraW5nIG9uIENFLCB0YXJnZXQgcGF0aCBtaWdodCBiZSBvbiBzMwogICAgaWYgJ3MzJyBpbiAodGFyZ2V0X3BhdGggb3Igc3ViZGlyKToKICAgICAgICBjb250ZXh0LmxvZ2dlci5pbmZvKGYnVXNpbmcgczMgc2NoZW1lLCBleHRyYWN0aW5nIHRvIHt0YXJnZXRfcGF0aCBvciBzdWJkaXJ9JykKCiAgICAgICAgaWYgYXJjaGl2ZV91cmwuZW5kc3dpdGgoImd6Iik6CiAgICAgICAgICAgIF9leHRyYWN0X2d6X2ZpbGUoYXJjaGl2ZV91cmw9YXJjaGl2ZV91cmwsIHN1YmRpcj1zdWJkaXIsIHRhcmdldF9wYXRoPXRhcmdldF9wYXRoLCBpbl9zMz1UcnV
lKQoKICAgICAgICBlbGlmIGFyY2hpdmVfdXJsLmVuZHN3aXRoKCJ6aXAiKToKICAgICAgICAgICAgX2V4dHJhY3RfemlwX2ZpbGUoYXJjaGl2ZV91cmw9YXJjaGl2ZV91cmwsIHN1YmRpcj1zdWJkaXIsIHRhcmdldF9wYXRoPXRhcmdldF9wYXRoLCBpbl9zMz1UcnVlKQogICAgICAgIGVsc2U6CiAgICAgICAgICAgIHJhaXNlIFZhbHVlRXJyb3IoZiJ1bnN1cHBvcnRlZCBhcmNoaXZlIHR5cGUgaW4ge2FyY2hpdmVfdXJsfSIpCiAgICBlbHNlOgogICAgICAgIGlmIGFyY2hpdmVfdXJsLmVuZHN3aXRoKCJneiIpOgogICAgICAgICAgICBfZXh0cmFjdF9nel9maWxlKGFyY2hpdmVfdXJsPWFyY2hpdmVfdXJsLCBzdWJkaXI9c3ViZGlyLCB0YXJnZXRfcGF0aD10YXJnZXRfcGF0aCkKICAgICAgICBlbGlmIGFyY2hpdmVfdXJsLmVuZHN3aXRoKCJ6aXAiKToKICAgICAgICAgICAgX2V4dHJhY3RfemlwX2ZpbGUoYXJjaGl2ZV91cmw9YXJjaGl2ZV91cmwsIHN1YmRpcj1zdWJkaXIsIHRhcmdldF9wYXRoPXRhcmdldF9wYXRoKQogICAgICAgIGVsc2U6CiAgICAgICAgICAgIHJhaXNlIFZhbHVlRXJyb3IoZiJ1bnN1cHBvcnRlZCBhcmNoaXZlIHR5cGUgaW4ge2FyY2hpdmVfdXJsfSIpCgogICAgaWYgdjNpb19zdWJkaXI6CiAgICAgICAgc3ViZGlyID0gdjNpb19zdWJkaXIKCiAgICBjb250ZXh0LmxvZ2dlci5pbmZvKGYnTG9nZ2luZyBhcnRpZmFjdCB0byB7KHRhcmdldF9wYXRoIG9yIHN1YmRpcil9JykKICAgIGNvbnRleHQubG9nX2FydGlmYWN0KERpckFydGlmYWN0KGtleT1rZXksIHRhcmdldF9wYXRoPSh0YXJnZXRfcGF0aCBvciBzdWJkaXIpKSkKCgpkZWYgX2V4dHJhY3RfZ3pfZmlsZShhcmNoaXZlX3VybDogc3RyLCB0YXJnZXRfcGF0aDogc3RyID0gTm9uZSwgc3ViZGlyOiBzdHIgPSAiY29udGVudC8iLCBpbl9zMzogYm9vbCA9IEZhbHNlKToKICAgIGlmIGluX3MzOgogICAgICAgIGNsaWVudCA9IF9pbml0X2JvdG8zX2NsaWVudCgpCiAgICAgICAgd2l0aCB0YXJmaWxlLm9wZW4oYXJjaGl2ZV91cmwsIG1vZGU9InJ8Z3oiKSBhcyByZWY6CiAgICAgICAgICAgIGZvciBtZW1iZXIgaW4gcmVmLmdldG1lbWJlcnMoKToKICAgICAgICAgICAgICAgIGRhdGEgPSByZWYuZXh0cmFjdGZpbGUobWVtYmVyPW1lbWJlcikucmVhZCgpCiAgICAgICAgICAgICAgICBjbGllbnQucHV0X29iamVjdChCb2R5PWRhdGEsIEJ1Y2tldD11cmxwYXJzZSh0YXJnZXRfcGF0aCBvciBzdWJkaXIpLm5ldGxvYywKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIEtleT1mJ3t1cmxwYXJzZSh0YXJnZXRfcGF0aCBvciBzdWJkaXIpLnBhdGhbMTpdfXttZW1iZXIubmFtZX0nKQogICAgZWxzZToKICAgICAgICBvcy5tYWtlZGlycyh0YXJnZXRfcGF0aCBvciBzdWJkaXIsIGV4aXN0X29rPVRydWUpCiAgICAgICAgd2l0aCB0YXJmaWxlLm9wZW4oYXJjaGl2ZV91cmwsIG1vZGU9InI6Z3oiKSBhcyByZWY6CiAgICAgICAgICAgIGZvciBlbnRyeSBpbiByZWY6CiA
gICAgICAgICAgICAgICAjIFZhbGlkYXRlIHRoYXQgdGhlcmUgaXMgbm8gcGF0aCB0cmF2ZXJzYWwgaW4gdGhlIGFyY2hpdmUKICAgICAgICAgICAgICAgIGlmIG9zLnBhdGguaXNhYnMoZW50cnkubmFtZSkgb3IgIi4uIiBpbiBlbnRyeS5uYW1lOgogICAgICAgICAgICAgICAgICAgIHJhaXNlIFZhbHVlRXJyb3IoZiJJbGxlZ2FsIHRhciBhcmNoaXZlIGVudHJ5OiB7ZW50cnkubmFtZX0iKQoKICAgICAgICAgICAgICAgIHJlZi5leHRyYWN0KGVudHJ5LCB0YXJnZXRfcGF0aCBvciBzdWJkaXIpCgoKZGVmIF9leHRyYWN0X3ppcF9maWxlKGFyY2hpdmVfdXJsLCB0YXJnZXRfcGF0aDogc3RyID0gTm9uZSwgc3ViZGlyOiBzdHIgPSAiY29udGVudC8iLCBpbl9zMzogYm9vbCA9IEZhbHNlKToKICAgIGlmIGluX3MzOgogICAgICAgIGNsaWVudCA9IF9pbml0X2JvdG8zX2NsaWVudCgpCiAgICAgICAgd2l0aCB6aXBmaWxlLlppcEZpbGUoYXJjaGl2ZV91cmwsICJyIikgYXMgcmVmOgogICAgICAgICAgICBmb3IgZmlsZW5hbWUgaW4gcmVmLm5hbWVsaXN0KCk6CiAgICAgICAgICAgICAgICBkYXRhID0gcmVmLnJlYWQoZmlsZW5hbWUpCiAgICAgICAgICAgICAgICBjbGllbnQucHV0X29iamVjdChCb2R5PWRhdGEsIEJ1Y2tldD11cmxwYXJzZSh0YXJnZXRfcGF0aCBvciBzdWJkaXIpLm5ldGxvYywKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIEtleT1mJ3t1cmxwYXJzZSh0YXJnZXRfcGF0aCBvciBzdWJkaXIpLnBhdGhbMTpdfXtmaWxlbmFtZX0nKQogICAgZWxzZToKICAgICAgICB3aXRoIHppcGZpbGUuWmlwRmlsZShhcmNoaXZlX3VybCwgInIiKSBhcyByZWY6CiAgICAgICAgICAgICMgVmFsaWRhdGUgdGhhdCB0aGVyZSBpcyBubyBwYXRoIHRyYXZlcnNhbCBpbiB0aGUgYXJjaGl2ZQogICAgICAgICAgICBmb3IgZW50cnkgaW4gcmVmLm5hbWVsaXN0KCk6CiAgICAgICAgICAgICAgICBpZiBvcy5wYXRoLmlzYWJzKGVudHJ5KSBvciAiLi4iIGluIGVudHJ5OgogICAgICAgICAgICAgICAgICAgIHJhaXNlIFZhbHVlRXJyb3IoZiJJbGxlZ2FsIHppcCBhcmNoaXZlIGVudHJ5OiB7ZW50cnl9IikKICAgICAgICAgICAgb3MubWFrZWRpcnModGFyZ2V0X3BhdGggb3Igc3ViZGlyLCBleGlzdF9vaz1UcnVlKQogICAgICAgICAgICByZWYuZXh0cmFjdGFsbCh0YXJnZXRfcGF0aCBvciBzdWJkaXIpCgoKZGVmIF9pbml0X2JvdG8zX2NsaWVudCgpOgogICAgaW1wb3J0IGJvdG8zCiAgICBpZiBvcy5lbnZpcm9uLmdldCgnUzNfRU5EUE9JTlRfVVJMJyk6CiAgICAgICAgY2xpZW50ID0gYm90bzMuY2xpZW50KCdzMycsIGVuZHBvaW50X3VybD1vcy5lbnZpcm9uLmdldCgnUzNfRU5EUE9JTlRfVVJMJykpCiAgICBlbHNlOgogICAgICAgIGNsaWVudCA9IGJvdG8zLmNsaWVudCgnczMnKQogICAgcmV0dXJuIGNsaWVudA== - code_origin: '' origin_filename: '' - description: Open a file/object archive into a target 
directory - image: mlrun/mlrun + functionSourceCode: IyBDb3B5cmlnaHQgMjAyNSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKCmltcG9ydCBvcwppbXBvcnQgdGFyZmlsZQppbXBvcnQgemlwZmlsZQpmcm9tIHVybGxpYi5wYXJzZSBpbXBvcnQgdXJscGFyc2UKCmZyb20gbWxydW4uYXJ0aWZhY3RzLmJhc2UgaW1wb3J0IERpckFydGlmYWN0CmZyb20gbWxydW4uZGF0YXN0b3JlIGltcG9ydCBEYXRhSXRlbQpmcm9tIG1scnVuLmV4ZWN1dGlvbiBpbXBvcnQgTUxDbGllbnRDdHgKCgpkZWYgb3Blbl9hcmNoaXZlKAogICAgY29udGV4dDogTUxDbGllbnRDdHgsCiAgICBhcmNoaXZlX3VybDogRGF0YUl0ZW0sCiAgICBzdWJkaXI6IHN0ciA9ICJjb250ZW50LyIsCiAgICBrZXk6IHN0ciA9ICJjb250ZW50IiwKICAgIHRhcmdldF9wYXRoOiBzdHIgPSBOb25lLAopOgogICAgIiIiT3BlbiBhIGZpbGUvb2JqZWN0IGFyY2hpdmUgaW50byBhIHRhcmdldCBkaXJlY3RvcnkuIEN1cnJlbnRseSwgc3VwcG9ydHMgemlwIGFuZCB0YXIuZ3ouCgogICAgOnBhcmFtIGNvbnRleHQ6ICAgICAgZnVuY3Rpb24gZXhlY3V0aW9uIGNvbnRleHQKICAgIDpwYXJhbSBhcmNoaXZlX3VybDogIHVybCBvZiBhcmNoaXZlIGZpbGUKICAgIDpwYXJhbSBzdWJkaXI6ICAgICAgIHBhdGggd2l0aGluIGFydGlmYWN0IHN0b3JlIHdoZXJlIGV4dHJhY3RlZCBmaWxlcyBhcmUgc3RvcmVkLCBkZWZhdWx0IGlzICIvY29udGVudCIKICAgIDpwYXJhbSBrZXk6ICAgICAgICAgIGtleSBvZiBhcmNoaXZlIGNvbnRlbnRzIGluIGFydGlmYWN0IHN0b3JlCiAgICA6cGFyYW0gdGFyZ2V0X3BhdGg6ICBmaWxlIHN5c3RlbSBwYXRoIHRvIHN0b3JlIGV4dHJhY3RlZCBmaWxlcwogICAgIiIiCgogICAgIyBSZXNvbHZlcyB0aGUgYXJjaGl2ZSBsb2NhbGx5CiAgICBhcmNoaXZlX3VybCA9IGFyY2hpdmVfdXJsLmxvY2FsKCkKICAgIHYzaW9fc3V
iZGlyID0gTm9uZQogICAgIyBXaGVuIGN1c3RvbSBhcnRpZmFjdCBwYXRoIGlzIGRlZmluZWQKICAgIGlmIG5vdCB0YXJnZXRfcGF0aCBhbmQgY29udGV4dC5hcnRpZmFjdF9wYXRoOgogICAgICAgIHBhcnNlZF9zdWJkaXIgPSB1cmxwYXJzZShjb250ZXh0LmFydGlmYWN0X3BhdGgpCiAgICAgICAgaWYgcGFyc2VkX3N1YmRpci5zY2hlbWUgPT0gInMzIjoKICAgICAgICAgICAgc3ViZGlyID0gb3MucGF0aC5qb2luKGNvbnRleHQuYXJ0aWZhY3RfcGF0aCwgc3ViZGlyKQogICAgICAgIGVsaWYgcGFyc2VkX3N1YmRpci5zY2hlbWUgPT0gInYzaW8iOgogICAgICAgICAgICB2M2lvX3N1YmRpciA9IG9zLnBhdGguam9pbigKICAgICAgICAgICAgICAgIGNvbnRleHQuYXJ0aWZhY3RfcGF0aCwgc3ViZGlyCiAgICAgICAgICAgICkgICMgVXNpbmcgdjNpb19zdWJkaXIgZm9yIGxvZ2dpbmcKICAgICAgICAgICAgc3ViZGlyID0gIi92M2lvIiArIHBhcnNlZF9zdWJkaXIucGF0aCArICIvIiArIHN1YmRpcgogICAgICAgICAgICBjb250ZXh0LmxvZ2dlci5pbmZvKGYiVXNpbmcgdjNpbyBzY2hlbWUsIGV4dHJhY3RpbmcgdG8ge3N1YmRpcn0iKQogICAgICAgIGVsc2U6CiAgICAgICAgICAgIGNvbnRleHQubG9nZ2VyLmluZm8oZiJVbnJlY29nbml6YWJsZSBzY2hlbWUsIGV4dHJhY3RpbmcgdG8ge3N1YmRpcn0iKQoKICAgICMgV2hlbiB3b3JraW5nIG9uIENFLCB0YXJnZXQgcGF0aCBtaWdodCBiZSBvbiBzMwogICAgaWYgInMzIiBpbiAodGFyZ2V0X3BhdGggb3Igc3ViZGlyKToKICAgICAgICBjb250ZXh0LmxvZ2dlci5pbmZvKGYiVXNpbmcgczMgc2NoZW1lLCBleHRyYWN0aW5nIHRvIHt0YXJnZXRfcGF0aCBvciBzdWJkaXJ9IikKCiAgICAgICAgaWYgYXJjaGl2ZV91cmwuZW5kc3dpdGgoImd6Iik6CiAgICAgICAgICAgIF9leHRyYWN0X2d6X2ZpbGUoCiAgICAgICAgICAgICAgICBhcmNoaXZlX3VybD1hcmNoaXZlX3VybCwKICAgICAgICAgICAgICAgIHN1YmRpcj1zdWJkaXIsCiAgICAgICAgICAgICAgICB0YXJnZXRfcGF0aD10YXJnZXRfcGF0aCwKICAgICAgICAgICAgICAgIGluX3MzPVRydWUsCiAgICAgICAgICAgICkKCiAgICAgICAgZWxpZiBhcmNoaXZlX3VybC5lbmRzd2l0aCgiemlwIik6CiAgICAgICAgICAgIF9leHRyYWN0X3ppcF9maWxlKAogICAgICAgICAgICAgICAgYXJjaGl2ZV91cmw9YXJjaGl2ZV91cmwsCiAgICAgICAgICAgICAgICBzdWJkaXI9c3ViZGlyLAogICAgICAgICAgICAgICAgdGFyZ2V0X3BhdGg9dGFyZ2V0X3BhdGgsCiAgICAgICAgICAgICAgICBpbl9zMz1UcnVlLAogICAgICAgICAgICApCiAgICAgICAgZWxzZToKICAgICAgICAgICAgcmFpc2UgVmFsdWVFcnJvcihmInVuc3VwcG9ydGVkIGFyY2hpdmUgdHlwZSBpbiB7YXJjaGl2ZV91cmx9IikKICAgIGVsc2U6CiAgICAgICAgaWYgYXJjaGl2ZV91cmwuZW5kc3dpdGgoImd6Iik6CiAgICAgICAgICAgIF9leHRyYWN0X2d6X2ZpbGUoCiAgICAgICAgICAgICA
gICBhcmNoaXZlX3VybD1hcmNoaXZlX3VybCwgc3ViZGlyPXN1YmRpciwgdGFyZ2V0X3BhdGg9dGFyZ2V0X3BhdGgKICAgICAgICAgICAgKQogICAgICAgIGVsaWYgYXJjaGl2ZV91cmwuZW5kc3dpdGgoInppcCIpOgogICAgICAgICAgICBfZXh0cmFjdF96aXBfZmlsZSgKICAgICAgICAgICAgICAgIGFyY2hpdmVfdXJsPWFyY2hpdmVfdXJsLCBzdWJkaXI9c3ViZGlyLCB0YXJnZXRfcGF0aD10YXJnZXRfcGF0aAogICAgICAgICAgICApCiAgICAgICAgZWxzZToKICAgICAgICAgICAgcmFpc2UgVmFsdWVFcnJvcihmInVuc3VwcG9ydGVkIGFyY2hpdmUgdHlwZSBpbiB7YXJjaGl2ZV91cmx9IikKCiAgICBpZiB2M2lvX3N1YmRpcjoKICAgICAgICBzdWJkaXIgPSB2M2lvX3N1YmRpcgoKICAgIGNvbnRleHQubG9nZ2VyLmluZm8oZiJMb2dnaW5nIGFydGlmYWN0IHRvIHsodGFyZ2V0X3BhdGggb3Igc3ViZGlyKX0iKQogICAgY29udGV4dC5sb2dfYXJ0aWZhY3QoRGlyQXJ0aWZhY3Qoa2V5PWtleSwgdGFyZ2V0X3BhdGg9KHRhcmdldF9wYXRoIG9yIHN1YmRpcikpKQoKCmRlZiBfZXh0cmFjdF9nel9maWxlKAogICAgYXJjaGl2ZV91cmw6IHN0ciwKICAgIHRhcmdldF9wYXRoOiBzdHIgPSBOb25lLAogICAgc3ViZGlyOiBzdHIgPSAiY29udGVudC8iLAogICAgaW5fczM6IGJvb2wgPSBGYWxzZSwKKToKICAgIGlmIGluX3MzOgogICAgICAgIGNsaWVudCA9IF9pbml0X2JvdG8zX2NsaWVudCgpCiAgICAgICAgd2l0aCB0YXJmaWxlLm9wZW4oYXJjaGl2ZV91cmwsIG1vZGU9InJ8Z3oiKSBhcyByZWY6CiAgICAgICAgICAgIGZvciBtZW1iZXIgaW4gcmVmLmdldG1lbWJlcnMoKToKICAgICAgICAgICAgICAgIGRhdGEgPSByZWYuZXh0cmFjdGZpbGUobWVtYmVyPW1lbWJlcikucmVhZCgpCiAgICAgICAgICAgICAgICBjbGllbnQucHV0X29iamVjdCgKICAgICAgICAgICAgICAgICAgICBCb2R5PWRhdGEsCiAgICAgICAgICAgICAgICAgICAgQnVja2V0PXVybHBhcnNlKHRhcmdldF9wYXRoIG9yIHN1YmRpcikubmV0bG9jLAogICAgICAgICAgICAgICAgICAgIEtleT1mInt1cmxwYXJzZSh0YXJnZXRfcGF0aCBvciBzdWJkaXIpLnBhdGhbMTpdfXttZW1iZXIubmFtZX0iLAogICAgICAgICAgICAgICAgKQogICAgZWxzZToKICAgICAgICBvcy5tYWtlZGlycyh0YXJnZXRfcGF0aCBvciBzdWJkaXIsIGV4aXN0X29rPVRydWUpCiAgICAgICAgd2l0aCB0YXJmaWxlLm9wZW4oYXJjaGl2ZV91cmwsIG1vZGU9InI6Z3oiKSBhcyByZWY6CiAgICAgICAgICAgIGZvciBlbnRyeSBpbiByZWY6CiAgICAgICAgICAgICAgICAjIFZhbGlkYXRlIHRoYXQgdGhlcmUgaXMgbm8gcGF0aCB0cmF2ZXJzYWwgaW4gdGhlIGFyY2hpdmUKICAgICAgICAgICAgICAgIGlmIG9zLnBhdGguaXNhYnMoZW50cnkubmFtZSkgb3IgIi4uIiBpbiBlbnRyeS5uYW1lOgogICAgICAgICAgICAgICAgICAgIHJhaXNlIFZhbHVlRXJyb3IoZiJJbGxlZ2FsIHRhciBhcmNoaXZlIGVudHJ5OiB7ZW5
0cnkubmFtZX0iKQoKICAgICAgICAgICAgICAgIHJlZi5leHRyYWN0KGVudHJ5LCB0YXJnZXRfcGF0aCBvciBzdWJkaXIpCgoKZGVmIF9leHRyYWN0X3ppcF9maWxlKAogICAgYXJjaGl2ZV91cmwsIHRhcmdldF9wYXRoOiBzdHIgPSBOb25lLCBzdWJkaXI6IHN0ciA9ICJjb250ZW50LyIsIGluX3MzOiBib29sID0gRmFsc2UKKToKICAgIGlmIGluX3MzOgogICAgICAgIGNsaWVudCA9IF9pbml0X2JvdG8zX2NsaWVudCgpCiAgICAgICAgd2l0aCB6aXBmaWxlLlppcEZpbGUoYXJjaGl2ZV91cmwsICJyIikgYXMgcmVmOgogICAgICAgICAgICBmb3IgZmlsZW5hbWUgaW4gcmVmLm5hbWVsaXN0KCk6CiAgICAgICAgICAgICAgICBkYXRhID0gcmVmLnJlYWQoZmlsZW5hbWUpCiAgICAgICAgICAgICAgICBjbGllbnQucHV0X29iamVjdCgKICAgICAgICAgICAgICAgICAgICBCb2R5PWRhdGEsCiAgICAgICAgICAgICAgICAgICAgQnVja2V0PXVybHBhcnNlKHRhcmdldF9wYXRoIG9yIHN1YmRpcikubmV0bG9jLAogICAgICAgICAgICAgICAgICAgIEtleT1mInt1cmxwYXJzZSh0YXJnZXRfcGF0aCBvciBzdWJkaXIpLnBhdGhbMTpdfXtmaWxlbmFtZX0iLAogICAgICAgICAgICAgICAgKQogICAgZWxzZToKICAgICAgICB3aXRoIHppcGZpbGUuWmlwRmlsZShhcmNoaXZlX3VybCwgInIiKSBhcyByZWY6CiAgICAgICAgICAgICMgVmFsaWRhdGUgdGhhdCB0aGVyZSBpcyBubyBwYXRoIHRyYXZlcnNhbCBpbiB0aGUgYXJjaGl2ZQogICAgICAgICAgICBmb3IgZW50cnkgaW4gcmVmLm5hbWVsaXN0KCk6CiAgICAgICAgICAgICAgICBpZiBvcy5wYXRoLmlzYWJzKGVudHJ5KSBvciAiLi4iIGluIGVudHJ5OgogICAgICAgICAgICAgICAgICAgIHJhaXNlIFZhbHVlRXJyb3IoZiJJbGxlZ2FsIHppcCBhcmNoaXZlIGVudHJ5OiB7ZW50cnl9IikKICAgICAgICAgICAgb3MubWFrZWRpcnModGFyZ2V0X3BhdGggb3Igc3ViZGlyLCBleGlzdF9vaz1UcnVlKQogICAgICAgICAgICByZWYuZXh0cmFjdGFsbCh0YXJnZXRfcGF0aCBvciBzdWJkaXIpCgoKZGVmIF9pbml0X2JvdG8zX2NsaWVudCgpOgogICAgaW1wb3J0IGJvdG8zCgogICAgIyBCYWNrd2FyZCBjb21wYXRpYmlsaXR5OiBTdXBwb3J0IGJvdGggUzNfRU5EUE9JTlRfVVJMIChkZXByZWNhdGVkKSBhbmQgQVdTX0VORFBPSU5UX1VSTF9TMwogICAgIyBUT0RPOiBSZW1vdmUgdGhpcyBpbiAxLjEyLjAKICAgIGVuZHBvaW50X3VybCA9IG9zLmVudmlyb24uZ2V0KCJBV1NfRU5EUE9JTlRfVVJMX1MzIikgb3Igb3MuZW52aXJvbi5nZXQoCiAgICAgICAgIlMzX0VORFBPSU5UX1VSTCIKICAgICkKCiAgICBpZiBlbmRwb2ludF91cmw6CiAgICAgICAgY2xpZW50ID0gYm90bzMuY2xpZW50KCJzMyIsIGVuZHBvaW50X3VybD1lbmRwb2ludF91cmwpCiAgICBlbHNlOgogICAgICAgIGNsaWVudCA9IGJvdG8zLmNsaWVudCgiczMiKQogICAgcmV0dXJuIGNsaWVudAo= + code_origin: '' + filename: 
open_archive.py entry_points: open_archive: - has_kwargs: false - lineno: 27 - name: open_archive parameters: - name: context type: MLClientCtx @@ -35,11 +35,12 @@ spec: type: str doc: file system path to store extracted files default: null + name: open_archive doc: Open a file/object archive into a target directory. Currently, supports zip and tar.gz. + has_kwargs: false has_varargs: false -metadata: - name: open-archive - categories: - - utils - tag: '' + lineno: 26 + command: '' + description: Open a file/object archive into a target directory + default_handler: open_archive diff --git a/functions/src/open_archive/item.yaml b/functions/src/open_archive/item.yaml index c40a62e4a..adcc4c69e 100644 --- a/functions/src/open_archive/item.yaml +++ b/functions/src/open_archive/item.yaml @@ -11,7 +11,7 @@ labels: author: Iguazio maintainers: [] marketplaceType: '' -mlrunVersion: 1.8.0-rc50 +mlrunVersion: 1.8.0 name: open-archive platformVersion: 3.5.0 spec: diff --git a/functions/src/open_archive/open_archive.py b/functions/src/open_archive/open_archive.py index 19d3c757b..225edb224 100644 --- a/functions/src/open_archive/open_archive.py +++ b/functions/src/open_archive/open_archive.py @@ -14,22 +14,21 @@ # import os -import zipfile import tarfile +import zipfile +from urllib.parse import urlparse -from mlrun.execution import MLClientCtx -from mlrun.datastore import DataItem from mlrun.artifacts.base import DirArtifact - -from urllib.parse import urlparse +from mlrun.datastore import DataItem +from mlrun.execution import MLClientCtx def open_archive( - context: MLClientCtx, - archive_url: DataItem, - subdir: str = "content/", - key: str = "content", - target_path: str = None, + context: MLClientCtx, + archive_url: DataItem, + subdir: str = "content/", + key: str = "content", + target_path: str = None, ): """Open a file/object archive into a target directory. Currently, supports zip and tar.gz. 
@@ -46,49 +45,73 @@ def open_archive( # When custom artifact path is defined if not target_path and context.artifact_path: parsed_subdir = urlparse(context.artifact_path) - if parsed_subdir.scheme == 's3': + if parsed_subdir.scheme == "s3": subdir = os.path.join(context.artifact_path, subdir) - elif parsed_subdir.scheme == 'v3io': - v3io_subdir = os.path.join(context.artifact_path, subdir) # Using v3io_subdir for logging - subdir = '/v3io' + parsed_subdir.path + '/' + subdir - context.logger.info(f'Using v3io scheme, extracting to {subdir}') + elif parsed_subdir.scheme == "v3io": + v3io_subdir = os.path.join( + context.artifact_path, subdir + ) # Using v3io_subdir for logging + subdir = "/v3io" + parsed_subdir.path + "/" + subdir + context.logger.info(f"Using v3io scheme, extracting to {subdir}") else: - context.logger.info(f'Unrecognizable scheme, extracting to {subdir}') + context.logger.info(f"Unrecognizable scheme, extracting to {subdir}") # When working on CE, target path might be on s3 - if 's3' in (target_path or subdir): - context.logger.info(f'Using s3 scheme, extracting to {target_path or subdir}') + if "s3" in (target_path or subdir): + context.logger.info(f"Using s3 scheme, extracting to {target_path or subdir}") if archive_url.endswith("gz"): - _extract_gz_file(archive_url=archive_url, subdir=subdir, target_path=target_path, in_s3=True) + _extract_gz_file( + archive_url=archive_url, + subdir=subdir, + target_path=target_path, + in_s3=True, + ) elif archive_url.endswith("zip"): - _extract_zip_file(archive_url=archive_url, subdir=subdir, target_path=target_path, in_s3=True) + _extract_zip_file( + archive_url=archive_url, + subdir=subdir, + target_path=target_path, + in_s3=True, + ) else: raise ValueError(f"unsupported archive type in {archive_url}") else: if archive_url.endswith("gz"): - _extract_gz_file(archive_url=archive_url, subdir=subdir, target_path=target_path) + _extract_gz_file( + archive_url=archive_url, subdir=subdir, target_path=target_path + 
) elif archive_url.endswith("zip"): - _extract_zip_file(archive_url=archive_url, subdir=subdir, target_path=target_path) + _extract_zip_file( + archive_url=archive_url, subdir=subdir, target_path=target_path + ) else: raise ValueError(f"unsupported archive type in {archive_url}") if v3io_subdir: subdir = v3io_subdir - context.logger.info(f'Logging artifact to {(target_path or subdir)}') + context.logger.info(f"Logging artifact to {(target_path or subdir)}") context.log_artifact(DirArtifact(key=key, target_path=(target_path or subdir))) -def _extract_gz_file(archive_url: str, target_path: str = None, subdir: str = "content/", in_s3: bool = False): +def _extract_gz_file( + archive_url: str, + target_path: str = None, + subdir: str = "content/", + in_s3: bool = False, +): if in_s3: client = _init_boto3_client() with tarfile.open(archive_url, mode="r|gz") as ref: for member in ref.getmembers(): data = ref.extractfile(member=member).read() - client.put_object(Body=data, Bucket=urlparse(target_path or subdir).netloc, - Key=f'{urlparse(target_path or subdir).path[1:]}{member.name}') + client.put_object( + Body=data, + Bucket=urlparse(target_path or subdir).netloc, + Key=f"{urlparse(target_path or subdir).path[1:]}{member.name}", + ) else: os.makedirs(target_path or subdir, exist_ok=True) with tarfile.open(archive_url, mode="r:gz") as ref: @@ -100,14 +123,19 @@ def _extract_gz_file(archive_url: str, target_path: str = None, subdir: str = "c ref.extract(entry, target_path or subdir) -def _extract_zip_file(archive_url, target_path: str = None, subdir: str = "content/", in_s3: bool = False): +def _extract_zip_file( + archive_url, target_path: str = None, subdir: str = "content/", in_s3: bool = False +): if in_s3: client = _init_boto3_client() with zipfile.ZipFile(archive_url, "r") as ref: for filename in ref.namelist(): data = ref.read(filename) - client.put_object(Body=data, Bucket=urlparse(target_path or subdir).netloc, - Key=f'{urlparse(target_path or 
subdir).path[1:]}{filename}') + client.put_object( + Body=data, + Bucket=urlparse(target_path or subdir).netloc, + Key=f"{urlparse(target_path or subdir).path[1:]}{filename}", + ) else: with zipfile.ZipFile(archive_url, "r") as ref: # Validate that there is no path traversal in the archive @@ -120,13 +148,15 @@ def _extract_zip_file(archive_url, target_path: str = None, subdir: str = "conte def _init_boto3_client(): import boto3 - + # Backward compatibility: Support both S3_ENDPOINT_URL (deprecated) and AWS_ENDPOINT_URL_S3 # TODO: Remove this in 1.12.0 - endpoint_url = os.environ.get('AWS_ENDPOINT_URL_S3') or os.environ.get('S3_ENDPOINT_URL') - + endpoint_url = os.environ.get("AWS_ENDPOINT_URL_S3") or os.environ.get( + "S3_ENDPOINT_URL" + ) + if endpoint_url: - client = boto3.client('s3', endpoint_url=endpoint_url) + client = boto3.client("s3", endpoint_url=endpoint_url) else: - client = boto3.client('s3') - return client \ No newline at end of file + client = boto3.client("s3") + return client diff --git a/functions/src/open_archive/test_open_archive.py b/functions/src/open_archive/test_open_archive.py index 507c7ecbc..29fcafc99 100644 --- a/functions/src/open_archive/test_open_archive.py +++ b/functions/src/open_archive/test_open_archive.py @@ -12,17 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -from pathlib import Path -import shutil import os +import shutil import tarfile -from mlrun import code_to_function, import_function +from pathlib import Path + import open_archive import pytest +from mlrun import code_to_function, import_function -ARTIFACTS_PATH = 'artifacts' -CONTENT_PATH = 'content/data/images' -ARCHIVE_URL = "https://s3.wasabisys.com/iguazio/data/cats-vs-dogs/cats-vs-dogs-labeling-demo.zip" +ARTIFACTS_PATH = "artifacts" +CONTENT_PATH = "content/data/images" +ARCHIVE_URL = ( + "https://s3.wasabisys.com/iguazio/data/cats-vs-dogs/cats-vs-dogs-labeling-demo.zip" +) def _delete_outputs(paths): @@ -32,27 +35,32 @@ def _delete_outputs(paths): def test_open_archive(): - fn = code_to_function(name='test_open_archive', - filename="open_archive.py", - handler="open_archive", - kind="local", - ) + fn = code_to_function( + name="test_open_archive", + filename="open_archive.py", + handler="open_archive", + kind="local", + ) fn.spec.command = "open_archive.py" - fn.run(inputs={'archive_url': ARCHIVE_URL}, - params={'key': 'test_archive', 'target_path': os.getcwd() + '/content/'}, - local=True) + fn.run( + inputs={"archive_url": ARCHIVE_URL}, + params={"key": "test_archive", "target_path": os.getcwd() + "/content/"}, + local=True, + ) assert Path(CONTENT_PATH).is_dir() - _delete_outputs({'artifacts', 'runs', 'schedules', 'content'}) + _delete_outputs({"artifacts", "runs", "schedules", "content"}) def test_open_archive_import_function(): fn = import_function("function.yaml") - run = fn.run(inputs={'archive_url': ARCHIVE_URL}, - params={'key': 'test_archive', 'target_path': os.getcwd() + '/content/'}, - local=True) - assert (run.status.artifact_uris["test_archive"]) - _delete_outputs({'artifacts', 'runs', 'schedules', 'content'}) + run = fn.run( + inputs={"archive_url": ARCHIVE_URL}, + params={"key": "test_archive", "target_path": os.getcwd() + "/content/"}, + local=True, + ) + assert run.status.artifact_uris["test_archive"] + _delete_outputs({"artifacts", 
"runs", "schedules", "content"}) def test_traversal_entry(): @@ -65,6 +73,8 @@ def test_traversal_entry(): tar.add("malicious.txt", arcname="../malicious.txt") with pytest.raises(ValueError): - open_archive._extract_gz_file("malicious.tar.gz", target_path=os.getcwd() + '/content/') + open_archive._extract_gz_file( + "malicious.tar.gz", target_path=os.getcwd() + "/content/" + ) os.remove("malicious.txt") - os.remove("malicious.tar.gz") \ No newline at end of file + os.remove("malicious.tar.gz") diff --git a/functions/src/pii_recognizer/function.yaml b/functions/src/pii_recognizer/function.yaml index e7d6c1241..d3bc1516e 100644 --- a/functions/src/pii_recognizer/function.yaml +++ b/functions/src/pii_recognizer/function.yaml @@ -1,42 +1,61 @@ +metadata: + tag: '' + name: pii-recognizer + categories: + - data-preparation + - NLP verbose: false +kind: job spec: - default_handler: recognize_pii + image: '' + disable_auto_mount: false + build: + origin_filename: '' + functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import pathlib
import tempfile
import warnings
from typing import List

import annotated_text.util as at_util
import mlrun
import nltk
import pandas as pd
import presidio_analyzer as pa
import presidio_anonymizer as pre_anoymizer
from presidio_anonymizer.entities import OperatorConfig
from tqdm import tqdm

try:
    import flair as fl
except ModuleNotFoundError:
    print("Flair is not installed")

# There is a conflict between Rust-based tokenizers' parallel processing
# and Python's fork operations during multiprocessing. To avoid this,
# tokenizer parallelism is switched off via the environment variable below.

os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Independently of the tokenizer setting, silence all warnings for the process.
warnings.filterwarnings("ignore")

# Module-level logger used by the code in this file.
logger = logging.getLogger("pii-recognizer")


# Add the constant classes of Models and Entities to govern the whole package
class Models:
    """Names of the recognizer back-ends a caller can select."""

    # Individual back-ends.
    PATTERN = "pattern"
    SPACY = "spacy"
    FLAIR = "flair"

    # Selects the spacy, flair and pattern recognizers together.
    WHOLE = "whole"


class Entities:
    """
    String names of the PII entity types referenced by this module.

    Every attribute holds a plain string equal to its own name, so the
    constants can be compared directly against entity names elsewhere.
    """

    CREDIT_CARD = "CREDIT_CARD"
    SSN = "SSN"
    PHONE = "PHONE"
    EMAIL = "EMAIL"
    LOCATION = "LOCATION"
    PERSON = "PERSON"
    NRP = "NRP"
    ORGANIZATION = "ORGANIZATION"
    DATE_TIME = "DATE_TIME"
    # Was previously ("GPE",) - a stray trailing comma turned it into a tuple.
    GPE = "GPE"
    MAC_ADDRESS = "MAC_ADDRESS"
    US_BANK_NUMBER = "US_BANK_NUMBER"
    IMEI = "IMEI"
    TITLE = "TITLE"
    LICENSE_PLATE = "LICENSE_PLATE"
    US_PASSPORT = "US_PASSPORT"
    CURRENCY = "CURRENCY"
    ROUTING_NUMBER = "ROUTING_NUMBER"
    US_ITIN = "US_ITIN"
    US_DRIVER_LICENSE = "US_DRIVER_LICENSE"
    AGE = "AGE"
    PASSWORD = "PASSWORD"
    SWIFT_CODE = "SWIFT_CODE"


class PatternRecognizerFactory:
    """
    Builds regex-based Presidio pattern recognizers.

    Each supported entity maps to the list of regex patterns used to detect
    it; more entities can be supported by adding entries to
    ``RECOGNIZABLE_ENTITIES``.
    """

    # Entity name -> regex patterns (with their confidence score) detecting it
    RECOGNIZABLE_ENTITIES = {
        "CREDIT_CARD": [
            pa.Pattern(
                name="CREDIT_CARD",
                regex=r"\b(?:\d[ -]*?){13,16}\b",
                score=0.5,
            )
        ],
        "SSN": [
            pa.Pattern(
                name="SSN",
                regex=r"\b\d{3}-?\d{2}-?\d{4}\b",
                score=0.5,
            )
        ],
        "PHONE": [
            pa.Pattern(
                name="PHONE",
                regex=r"\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}",
                score=0.5,
            )
        ],
        "EMAIL": [
            pa.Pattern(
                name="EMAIL",
                regex=r"\S+@\S+",
                score=0.5,
            )
        ],
    }

    @classmethod
    def _create_pattern_recognizer(cls):
        """
        Create one pattern recognizer per entity in ``RECOGNIZABLE_ENTITIES``.

        :param cls: PatternRecognizerFactory class

        :returns: List of pattern recognizers, in the mapping's order
        """
        recognizers = []
        for entity_name, entity_patterns in cls.RECOGNIZABLE_ENTITIES.items():
            recognizer = pa.PatternRecognizer(
                supported_entity=entity_name,
                patterns=entity_patterns,
            )
            recognizers.append(recognizer)
        return recognizers


class CustomSpacyRecognizer(pa.LocalRecognizer):
    """
    Custom Spacy Recognizer from Presidio Analyzer trained on Privy data.
    The privy data is generated using this https://github.com/pixie-io/pixie/tree/main/src/datagen/pii/privy
    It can be used to recognize custom entities. Since we want to use Presidio's Registries to generate
    AnalyzerEngine, it inherits from Presidio Analyzer's LocalRecognizer class.
    """

    # Entities to recognize
    RECOGNIZABLE_ENTITIES = {
        "LOCATION",
        "PERSON",
        "NRP",
        "ORGANIZATION",
        "DATE_TIME",
    }

    # Default explanation for this recognizer
    _DEFAULT_EXPLANATION = (
        "Identified as {} by Spacy's Named Entity Recognition (Privy-trained)"
    )

    # Label groups to check: (requested entity names, NER labels that satisfy them)
    _DEFAULT_CHECK_LABEL_GROUPS = [
        ({"LOCATION"}, {"LOC", "LOCATION", "STREET_ADDRESS", "COORDINATE"}),
        ({"PERSON"}, {"PER", "PERSON"}),
        ({"NRP"}, {"NORP", "NRP"}),
        ({"ORGANIZATION"}, {"ORG"}),
        ({"DATE_TIME"}, {"DATE_TIME"}),
    ]

    # pretrained model for this recognizer
    _DEFAULT_MODEL_LANGUAGES = {
        "en": "beki/en_spacy_pii_distilbert",
    }

    # NER label -> Presidio entity name ("NORP" matches the label used in
    # _DEFAULT_CHECK_LABEL_GROUPS; it was previously misspelled "NROP").
    _DEFAULT_PRESIDIO_EQUIVALENCES = {
        "PER": "PERSON",
        "LOC": "LOCATION",
        "ORG": "ORGANIZATION",
        "NORP": "NRP",
        "DATE_TIME": "DATE_TIME",
    }

    def __init__(
        self,
        supported_language: str = "en",
        supported_entities: list[str] = None,
        check_label_groups: list[tuple[set, set]] = None,
        context: list[str] = None,
        ner_strength: float = 1,
    ):
        """
        Initialize Spacy Recognizer.

        :param supported_language: Language to use, default is English
        :param supported_entities: Entities to use for recognition, defaults to RECOGNIZABLE_ENTITIES
        :param check_label_groups: Label groups to check for the entities, defaults to
                                   _DEFAULT_CHECK_LABEL_GROUPS
        :param context:            Context to use if any (accepted but not used by this initializer)
        :param ner_strength:       Default confidence for NER prediction
        """

        # Default confidence for NER prediction
        self.ner_strength = ner_strength

        self.check_label_groups = check_label_groups or self._DEFAULT_CHECK_LABEL_GROUPS
        supported_entities = supported_entities or self.RECOGNIZABLE_ENTITIES
        super().__init__(
            supported_entities=supported_entities,
            supported_language=supported_language,
        )

    # get the presidio explanation for the result

    def _build_spacy_explanation(
        self, original_score: float, explanation: str
    ) -> pa.AnalysisExplanation:
        """
        Create explanation for why this result was detected.

        :param original_score: Score given by this recognizer
        :param explanation:    Explanation string

        :returns: Presidio AnalysisExplanation object
        """
        explanation = pa.AnalysisExplanation(
            recognizer=self.__class__.__name__,
            original_score=original_score,
            textual_explanation=explanation,
        )
        return explanation

    # main method for the recognizer
    def analyze(self, text: str, entities: list[str], nlp_artifacts=None):  # noqa D102
        """
        Analyze text using Spacy.

        The detections are read from ``nlp_artifacts.entities``; ``text`` itself
        is not processed here.

        :param text:          Text to analyze
        :param entities:      Entities to analyze
        :param nlp_artifacts: NLP artifacts to use

        :returns: List of Presidio RecognizerResult objects (empty when no
                  nlp artifacts are provided)
        """
        results = []
        if not nlp_artifacts:
            logger.warning("Skipping SpaCy, nlp artifacts not provided...")
            return results

        ner_entities = nlp_artifacts.entities

        # recognize the supported entities
        for entity in entities:
            if entity not in self.supported_entities:
                continue
            for ent in ner_entities:
                if not self.__check_label(entity, ent.label_, self.check_label_groups):
                    continue

                # string of the explanation saying the entity is recognized by spacy
                textual_explanation = self._DEFAULT_EXPLANATION.format(ent.label_)
                explanation = self._build_spacy_explanation(
                    self.ner_strength, textual_explanation
                )

                # create the standard result with the entity, start, end, score, and explanation
                spacy_result = pa.RecognizerResult(
                    entity_type=entity,
                    start=ent.start_char,
                    end=ent.end_char,
                    score=self.ner_strength,
                    analysis_explanation=explanation,
                    recognition_metadata={
                        pa.RecognizerResult.RECOGNIZER_NAME_KEY: self.name
                    },
                )
                results.append(spacy_result)

        return results

    @staticmethod
    def __check_label(
        entity: str, label: str, check_label_groups: list[tuple[set, set]]
    ) -> bool:
        """
        Check if the label is in the label group.

        :param entity:             Entity to check
        :param label:              Label to check
        :param check_label_groups: Label groups to check, as (entity names, labels) pairs

        :returns: True if some group contains both the entity and the label, False otherwise
        """
        return any(
            entity in egrp and label in lgrp for egrp, lgrp in check_label_groups
        )


# Class to use Flair with Presidio as an external recognizer.
class FlairRecognizer(pa.EntityRecognizer):
    """
    Wrapper for a flair model, if needed to be used within Presidio Analyzer.
    This is to make sure the recognizer can be registered with Presidio registry.
    """

    RECOGNIZABLE_ENTITIES = {
        "LOCATION",
        "PERSON",
        "NRP",
        "GPE",
        "ORGANIZATION",
        "MAC_ADDRESS",
        "US_BANK_NUMBER",
        "IMEI",
        "TITLE",
        "LICENSE_PLATE",
        "US_PASSPORT",
        "CURRENCY",
        "ROUTING_NUMBER",
        "US_ITIN",
        "US_DRIVER_LICENSE",
        "AGE",
        "PASSWORD",
        "SWIFT_CODE",
    }

    # This is used to construct the explanation for the result

    _DEFAULT_EXPLANATION = "Identified as {} by Flair's Named Entity Recognition"

    # (requested entity names, Flair labels that satisfy them)
    # NOTE(review): "PASSWORD" is listed in RECOGNIZABLE_ENTITIES but has no
    # label group here, so it can never match - confirm whether that is intended.
    _DEFAULT_CHECK_LABEL_GROUPS = [
        ({"LOCATION"}, {"LOC", "LOCATION", "STREET_ADDRESS", "COORDINATE"}),
        ({"PERSON"}, {"PER", "PERSON"}),
        ({"NRP"}, {"NORP", "NRP"}),
        ({"GPE"}, {"GPE"}),
        ({"ORGANIZATION"}, {"ORG"}),
        ({"MAC_ADDRESS"}, {"MAC_ADDRESS"}),
        ({"US_BANK_NUMBER"}, {"US_BANK_NUMBER"}),
        ({"IMEI"}, {"IMEI"}),
        ({"TITLE"}, {"TITLE"}),
        ({"LICENSE_PLATE"}, {"LICENSE_PLATE"}),
        ({"US_PASSPORT"}, {"US_PASSPORT"}),
        ({"CURRENCY"}, {"CURRENCY"}),
        ({"ROUTING_NUMBER"}, {"ROUTING_NUMBER"}),
        ({"AGE"}, {"AGE"}),
        ({"SWIFT_CODE"}, {"SWIFT_CODE"}),
        ({"US_ITIN"}, {"US_ITIN"}),
        ({"US_DRIVER_LICENSE"}, {"US_DRIVER_LICENSE"}),
    ]

    _DEFAULT_MODEL_LANGUAGES = {
        "en": "beki/flair-pii-distilbert",
    }

    # Flair tag -> Presidio entity name. "NORP" matches the label used in
    # _DEFAULT_CHECK_LABEL_GROUPS (it was previously misspelled "NROP", so
    # NORP-tagged spans were never reported as "NRP").
    _DEFAULT_PRESIDIO_EQUIVALENCES = {
        "PER": "PERSON",
        "LOC": "LOCATION",
        "ORG": "ORGANIZATION",
        "NORP": "NRP",
        "URL": "URL",
        "US_ITIN": "US_ITIN",
        "US_PASSPORT": "US_PASSPORT",
        "IBAN_CODE": "IBAN_CODE",
        "IP_ADDRESS": "IP_ADDRESS",
        "EMAIL_ADDRESS": "EMAIL",
        "US_DRIVER_LICENSE": "US_DRIVER_LICENSE",
        "US_BANK_NUMBER": "US_BANK_NUMBER",
    }

    def __init__(
        self,
        supported_language: str = "en",
        supported_entities: list[str] = None,
        check_label_groups: list[tuple[set, set]] = None,
    ):
        """
        Initialize the FlairRecognizer and load the Flair sequence tagger.

        :param supported_language: Language to use; selects the model from _DEFAULT_MODEL_LANGUAGES
        :param supported_entities: Entities to use, defaults to RECOGNIZABLE_ENTITIES
        :param check_label_groups: Label groups to check, defaults to _DEFAULT_CHECK_LABEL_GROUPS
        """
        self.check_label_groups = check_label_groups or self._DEFAULT_CHECK_LABEL_GROUPS

        supported_entities = supported_entities or self.RECOGNIZABLE_ENTITIES
        self.model = fl.models.SequenceTagger.load(
            self._DEFAULT_MODEL_LANGUAGES.get(supported_language)
        )

        super().__init__(
            supported_entities=supported_entities,
            supported_language=supported_language,
            name="Flair Analytics",
        )

    # main method for the recognizer
    def analyze(
        self,
        text: str,
        entities: list[str],
        nlp_artifacts: pa.nlp_engine.NlpArtifacts = None,
    ) -> list[pa.RecognizerResult]:
        """
        Analyze text and return the results.

        :param text:          The text for analysis.
        :param entities:      The list of entities to recognize. When empty, all supported
                              entities are looked for.
        :param nlp_artifacts: Not used by this recognizer but needed for the interface.

        :returns: The list of Presidio RecognizerResult constructed from the recognized Flair detections.
        """

        results = []

        sentences = fl.data.Sentence(text)
        self.model.predict(sentences)
        # The predicted spans do not depend on the requested entity, fetch them once.
        spans = sentences.get_spans("ner")

        # If there are no specific list of entities, we will look for all of it.
        if not entities:
            entities = self.supported_entities

        # Go over the entities and check if they are in the supported entities list.
        for entity in entities:
            if entity not in self.supported_entities:
                continue

            # Go over the predicted spans and keep those whose label matches the entity.
            for ent in spans:
                label = ent.labels[0].value
                if not self.__check_label(entity, label, self.check_label_groups):
                    continue

                # Build the explanation for the result
                textual_explanation = self._DEFAULT_EXPLANATION.format(label)
                explanation = self._build_flair_explanation(
                    round(ent.score, 2), textual_explanation
                )

                flair_result = self._convert_to_recognizer_result(ent, explanation)

                results.append(flair_result)

        return results

    # The span type is written as a string so the class can still be defined
    # when the optional flair import failed and `fl` is not bound.
    def _convert_to_recognizer_result(
        self, entity: "fl.data.Span", explanation: pa.AnalysisExplanation
    ) -> pa.RecognizerResult:
        """
        Convert Flair result to Presidio RecognizerResult.

        :param entity:      Flair entity of Span
        :param explanation: Presidio AnalysisExplanation

        :returns: Presidio RecognizerResult
        """

        # Convert the entity type to Presidio entity type (unknown tags are kept as-is)
        entity_type = self._DEFAULT_PRESIDIO_EQUIVALENCES.get(entity.tag, entity.tag)

        # Convert the score to Presidio score
        flair_score = round(entity.score, 2)

        # Create the Presidio RecognizerResult from the Flair entity
        flair_results = pa.RecognizerResult(
            entity_type=entity_type,
            start=entity.start_position,
            end=entity.end_position,
            score=flair_score,
            analysis_explanation=explanation,
        )

        return flair_results

    def _build_flair_explanation(
        self, original_score: float, explanation: str
    ) -> pa.AnalysisExplanation:
        """
        Create explanation for why this result was detected.

        :param original_score: Score given by this recognizer
        :param explanation:    Explanation string

        :returns: Presidio AnalysisExplanation
        """

        # Create the Presidio AnalysisExplanation for the result
        explanation = pa.AnalysisExplanation(
            recognizer=self.__class__.__name__,
            original_score=original_score,
            textual_explanation=explanation,
        )
        return explanation

    # sanity check of the entity and label before recognition
    @staticmethod
    def __check_label(
        entity: str, label: str, check_label_groups: list[tuple[set, set]]
    ) -> bool:
        """
        Check if the label is in the label group.

        :param entity:             Entity to check
        :param label:              Label to check
        :param check_label_groups: Label groups to check, as (entity names, labels) pairs

        :returns: True if some group contains both the entity and the label, False otherwise
        """
        return any(
            entity in egrp and label in lgrp for egrp, lgrp in check_label_groups
        )


# get the analyzer engine based on the model
def _get_analyzer_engine(
    model: str = None, entities: list[str] = None
) -> pa.AnalyzerEngine:
    """
    Build a Presidio AnalyzerEngine whose registry holds the requested recognizers.

    When ``model`` is given it selects the recognizers directly. Otherwise the
    recognizers are chosen by intersecting ``entities`` with the entities each
    recognizer declares as recognizable.

    :param model:    The model to use. Can be "spacy", "flair", "pattern" or "whole".
    :param entities: The list of entities to use.

    :returns: pa.AnalyzerEngine

    :raises ValueError: If neither model nor entities is given, if model is not one of
                        the supported names, or if a requested entity is not supported
                        by the resulting engine.
    """
    known_models = (Models.SPACY, Models.FLAIR, Models.PATTERN, Models.WHOLE)

    # Decide which recognizers are needed
    if model:
        if model not in known_models:
            # An unknown name used to fall through to the "can not be None" error,
            # which pointed the user at the wrong problem.
            raise ValueError(
                f"Unknown model '{model}'. Supported models are: {list(known_models)}"
            )
        use_spacy = model in (Models.SPACY, Models.WHOLE)
        use_flair = model in (Models.FLAIR, Models.WHOLE)
        use_pattern = model in (Models.PATTERN, Models.WHOLE)
    elif entities:
        requested = set(entities)
        use_spacy = bool(requested & CustomSpacyRecognizer.RECOGNIZABLE_ENTITIES)
        use_flair = bool(requested & FlairRecognizer.RECOGNIZABLE_ENTITIES)
        use_pattern = bool(
            requested & set(PatternRecognizerFactory.RECOGNIZABLE_ENTITIES)
        )
    else:
        raise ValueError(
            "argument of model and entities can not be None at the same time"
        )

    # recognizer registry that can store multiple recognizers
    registry = pa.RecognizerRegistry()
    if use_spacy:
        registry.add_recognizer(CustomSpacyRecognizer())
    if use_flair:
        registry.add_recognizer(FlairRecognizer())
    if use_pattern:
        for recognizer in PatternRecognizerFactory()._create_pattern_recognizer():
            registry.add_recognizer(recognizer)

    analyzer = pa.AnalyzerEngine(
        registry=registry,
        supported_languages=["en"],
    )

    # Make sure every requested entity can actually be produced by the engine
    supported_entities = analyzer.get_supported_entities()
    not_supported_entities = [
        item for item in (entities or []) if item not in supported_entities
    ]
    if not_supported_entities:
        raise ValueError(
            f"The current model {model} doesn't support the following entities: {not_supported_entities}. "
            f"Supported entities are: {supported_entities}"
        )
    return analyzer


def _get_anonymizer_engine() -> pre_anoymizer.AnonymizerEngine:
    """
    Create a fresh Presidio AnonymizerEngine.

    :returns: The AnonymizerEngine.
    """
    engine = pre_anoymizer.AnonymizerEngine()
    return engine


def _anonymize(
    text: str,
    analyze_results: list[pa.RecognizerResult],
    entity_operator_map: dict = None,
    is_full_text: bool = True,
) -> str:
    """
    Anonymize identified input using Presidio Anonymizer.

    :param text:                The text for analysis.
    :param analyze_results:     The list of Presidio RecognizerResult constructed from analysis.
    :param entity_operator_map: The entity_operator_map is a dictionary that maps entity to operator name and operator params.
    :param is_full_text:        Whether to anonymize and return the whole text (True) or only the sentences
                                that contain at least one recognized entity (False).

    :returns: The anonymized text.
    """
    if not text:
        return ""

    anonymizer_engine = _get_anonymizer_engine()

    # Create OperatorConfig based on the entity_operator_map (None means no operators are passed)
    operators = None
    if entity_operator_map:
        operators = {
            entity: OperatorConfig(operator_name, operator_params)
            for entity, (operator_name, operator_params) in entity_operator_map.items()
        }

    if is_full_text:
        # Anonymize the entire text
        return anonymizer_engine.anonymize(
            text=text, analyzer_results=analyze_results, operators=operators
        ).text

    # Tokenize the text to sentences and keep only the ones that hold a pii entity
    anonymized_sentences = []
    search_from = 0
    for sentence in nltk.sent_tokenize(text):
        # Locate the sentence inside the original text. The tokenizer drops the
        # whitespace between sentences, so simply accumulating sentence lengths
        # would shift every offset after the first sentence.
        start_idx = text.find(sentence, search_from)
        if start_idx < 0:
            # Fall back to the running position if the sentence cannot be located
            start_idx = search_from
        end_idx = start_idx + len(sentence)
        search_from = end_idx

        # Get the entities that are in the sentence, with offsets relative to the sentence
        sentence_results = [
            pa.RecognizerResult(
                result.entity_type,
                start=result.start - start_idx,
                end=result.end - start_idx,
                score=result.score,
            )
            for result in analyze_results
            if result.start >= start_idx and result.end <= end_idx
        ]

        # If PII is detected
        if sentence_results:
            anonymized_sentence = anonymizer_engine.anonymize(
                text=sentence, analyzer_results=sentence_results, operators=operators
            ).text
            anonymized_sentences.append(anonymized_sentence)

    return " ".join(anonymized_sentences)


def _get_tokens(
    text: str, analyze_results: list[pa.RecognizerResult], is_full: bool = True
) -> list[str]:
    """
    Split the text into plain-text pieces and (entity text, entity type) tuples.

    :param text:            The text for analysis.
    :param analyze_results: The list of Presidio RecognizerResult constructed from analysis.
    :param is_full:         When True, return the flat token list. When False, return a list of
                            token slices, each ending right before a token that contains ".", "!"
                            or "?" and holding at least one entity tuple.

    :returns: The tokens.
    """

    # Walk over the results in order of their start offset
    ordered = sorted(analyze_results, key=lambda x: x.start)
    last_index = len(ordered) - 1

    tokens = []
    if ordered:
        # text that precedes the first entity
        tokens.append(text[: ordered[0].start])

    for idx, hit in enumerate(ordered):
        # entity text together with its entity type
        tokens.append((text[hit.start : hit.end], hit.entity_type))

        # text up to the next entity, or all the remaining text after the last one
        stop = ordered[idx + 1].start if idx != last_index else None
        tokens.append(text[hit.end : stop])

    if is_full:
        return tokens

    # Keep only the slices that carry an entity and are closed by a sentence mark
    sentence_marks = [".", "!", "?"]
    annotated_chunks = []
    chunk_start = 0
    for idx, token in enumerate(tokens):
        if not any(mark in token for mark in sentence_marks):
            continue
        pending = tokens[chunk_start:idx]
        if any(type(piece) is tuple for piece in pending):
            annotated_chunks.append(pending)
            chunk_start = idx
    return annotated_chunks


def _annotate(
    text: str, st_analyze_results: list[pa.RecognizerResult], is_full_html: bool = True
) -> list[str]:
    """
    Turn the text and its recognition results into tokens for annotation.

    This delegates to ``_get_tokens``, forwarding ``is_full_html`` as its ``is_full`` flag.

    :param text:               The text for analysis.
    :param st_analyze_results: The list of Presidio RecognizerResult constructed from analysis.
    :param is_full_html:       Whether generate full html or not.

    :returns: The list of tokens with the identified entities.

    """
    return _get_tokens(
        text=text, analyze_results=st_analyze_results, is_full=is_full_html
    )


def _process(
    text: str,
    model: pa.AnalyzerEngine,
    score_threshold: float,
    entities: list[str] = None,
    entities_operator_map: dict = None,
    is_full_text: bool = True,
) -> tuple[str, list]:
    """
    Analyze a text with the given engine and anonymize what was found.

    :param text:                  Text to process
    :param model:                 Analyzer engine to use for processing
    :param score_threshold:       The score threshold to use for recognition
    :param entities:              Entities to recognize
    :param entities_operator_map: The entity_operator_map is a dictionary that maps entity to operator name and operator params.
    :param is_full_text:          Whether to return the full text or just the annotated text

    :returns: A tuple of:

              * the anonymized text
              * the list of Presidio RecognizerResult constructed from analysis
    """

    # Recognize the entities; the decision process is requested as part of the results
    findings = model.analyze(
        text=text,
        language="en",
        entities=entities,
        score_threshold=score_threshold,
        return_decision_process=True,
    )

    # Anonymize the text based on the findings
    anonymized = _anonymize(text, findings, entities_operator_map, is_full_text)
    return anonymized, findings


def _get_single_html(
    text: str, results: list[pa.RecognizerResult], is_full_html: bool = True
):
    """
    Generate the html for a single txt file.

    :param text:         The text for analysis.
    :param results:      The list of Presidio RecognizerResult constructed from analysis.
    :param is_full_html: Whether generate full html or not.

    :returns: The html string for a single txt file.
    """
    # convert the results to tokens to generate the html
    tokens = _annotate(text, results, is_full_html)
    html = at_util.get_annotated_html(*tokens)

    # Turn newlines into explicit html line breaks. The replacement is done outside
    # the f-string: the previous form nested a plain '{backslash_char}n' literal in
    # the f-string expression, which is not interpolated, so nothing was replaced.
    html_with_breaks = html.replace("\n", "<br>")

    return f"<p>{html_with_breaks}</p>"


def _get_single_json(results: list[pa.RecognizerResult], is_full_report: bool = True):
    """
    Prepare the recognition results of a single txt file for the report.

    :param results:        The list of Presidio RecognizerResult constructed from analysis.
    :param is_full_report: Whether generate full json or not.

    :returns: The given results as is for a full report, otherwise a new list holding the same
              result objects with their ``analysis_explanation`` set to None.
    """
    if is_full_report:
        return results

    # Simplified report: the explanation is cleared on each result (in place)
    simplified = []
    for result in results:
        result.analysis_explanation = None
        simplified.append(result)
    return simplified


def _get_all_html(
    txt_content: dict,
    res_dict: dict,
    is_full_html: bool = True,
):
    """
    Generate the html for all txt files.

    :param txt_content:  The dictionary of txt file name and content.
    :param res_dict:     The dictionary of txt file name and the list of Presidio RecognizerResult constructed from analysis.
    :param is_full_html: Whether generate full html or not.

    :returns: The html string for all txt files.

    """
    # These are placeholder for the html string
    html_index = "<html><head><title>Highlighted Pii Entities</title></head><body><h1>Highlighted Pii Entities</h1><ul>"
    html_content = ""
    for txt_file, results in res_dict.items():
        txt = txt_content[txt_file]
        html_index += f"<li><a href='#{txt_file}'>{txt_file}</a></li>"
        html_content += f"<li><h2>{txt_file}</h2><p>{_get_single_html(txt, results, is_full_html)}</p></li>"
    html_index += "</ul>"
    html_res = f"{html_index}{html_content}</body></html>"

    return html_res


def _get_all_rpt(res_dict: dict, is_full_report: bool = True):
    """
    Generate the stats report for all txt files.

    :param res_dict:       The dictionary of txt file name and the list of Presidio RecognizerResult constructed from analysis.
    :param is_full_report: Whether generate full report or not.

    :returns: The stats report for all txt files.
    """
    # These are placeholder for the json report
    stats_dict = {}
    for txt_file, results in res_dict.items():
        new_stats = []
        for item in _get_single_json(results, is_full_report):
            if is_full_report:
                item.analysis_explanation = item.analysis_explanation.to_dict()
                new_stats.append(item.to_dict())
            else:
                tmp_dict = item.to_dict()
                tmp_dict.pop("analysis_explanation")
                tmp_dict.pop("recognition_metadata")
                new_stats.append(tmp_dict)
        stats_dict[txt_file] = new_stats
    return stats_dict


def recognize_pii(
    context: mlrun.MLClientCtx,
    input_path: str | pathlib.Path,
    html_key: str,
    score_threshold: float,
    output_directory: str = None,
    entities: list[
        str
    ] = None,  # List of entities to recognize, default is recognizing all
    entity_operator_map: dict = None,
    model: str = None,
    generate_json: bool = True,
    generate_html: bool = True,
    is_full_text: bool = True,
    is_full_html: bool = True,
    is_full_report: bool = True,
) -> tuple[str, pd.DataFrame, dict, dict] | tuple[str, pd.DataFrame, dict]:
    """
    Walk through the input path, recognize PII in text and store the anonymized text in the output path.
    Generate the html with different colors for each entity, json report of the explanation.

    :param context:              The MLRun context. this is needed for log the artifacts.
    :param input_path:           The input path of the text files needs to be analyzed.
    :param html_key:             The html key for the artifact.
    :param score_threshold:      The score threshold to mark the recognition as trusted.
    :param output_directory:     The output directory path to store the anonymized text.
                                 A temporary directory is created when it is not given.
    :param entities:             The list of entities to recognize.
    :param entity_operator_map:  The map of entity to operator (mask, redact, replace, keep, hash, and its params)
    :param model:                The model to use. Can be "spacy", "flair", "pattern" or "whole".
    :param generate_json:        Whether to generate the json report of the explanation.
    :param generate_html:        Whether to generate the html report of the explanation.
    :param is_full_text:         Whether to return the full text or only the masked text.
    :param is_full_html:         Whether to return the full html or just the annotated text
    :param is_full_report:       Whether to return the full report or just the score and start, end index

    :returns: A tuple of:

              * Path to the output directory
              * A dataframe of the processed files (original file name and anonymized file name)
              * A dictionary of errors files that were not processed
              * The json report of the explanation (only if generate_json is True)

    """

    # Resolve the output directory and make sure it exists
    target_dir = pathlib.Path(
        tempfile.mkdtemp() if output_directory is None else output_directory
    )
    if not target_dir.exists():
        target_dir.mkdir(parents=True, exist_ok=True)

    source_dir = pathlib.Path(input_path)
    processed_rows = []
    errors = {}
    res_dict = {}
    txt_content = {}

    # Load the model:
    analyzer = _get_analyzer_engine(model, entities)
    logger.info("Model loaded")

    # Go over the text files in the input path, analyze and anonymize them:
    for txt_file in tqdm(
        list(source_dir.glob("*.txt")),
        desc="Processing files",
        unit="file",
    ):
        file_key = str(txt_file)
        try:
            # Load the text and keep it for the html report
            text = txt_file.read_text()
            txt_content[file_key] = text

            # Recognize the pii entities and anonymize them
            anonymized_text, results = _process(
                text=text,
                model=analyzer,
                entities=entities,
                entities_operator_map=entity_operator_map,
                score_threshold=score_threshold,
                is_full_text=is_full_text,
            )
            res_dict[file_key] = results

            # Store the anonymized text under the same file name in the output directory
            output_file = target_dir / f"{txt_file.stem}.txt"
            output_file.parent.mkdir(parents=True, exist_ok=True)
            with open(output_file, "w") as f:
                f.write(anonymized_text)
            processed_rows.append([txt_file.name, output_file.name])
        except Exception as e:
            # Best effort: a failing file is recorded and the run goes on with the next one
            errors[file_key] = str(e)
            logger.error(f"Error processing {txt_file}: {e}")

    successes = pd.DataFrame(
        processed_rows,
        columns=["original_file", "anonymized_file"],
    )

    if generate_html:
        # Generate the html report and log it as an artifact
        html_res = _get_all_html(txt_content, res_dict, is_full_html)
        arti_html = mlrun.artifacts.Artifact(body=html_res, format="html", key=html_key)
        context.log_artifact(arti_html)

    if not generate_json:
        return str(target_dir), successes, errors

    # Generate the json report
    json_res = _get_all_rpt(res_dict, is_full_report)
    return str(target_dir), successes, errors, json_res
 + requirements: + - nltk + - pandas + - presidio-anonymizer + - presidio-analyzer + - torch + - flair@git+https://github.com/flairNLP/flair.git@d4ed67bf663e4066517f00397412510d90043653 + - st-annotated-text + - https://huggingface.co/beki/en_spacy_pii_distilbert/resolve/main/en_spacy_pii_distilbert-any-py3-none-any.whl + code_origin: '' + base_image: mlrun/mlrun + filename: pii_recognizer.py entry_points: analyze: - name: analyze outputs: - doc: The list of Presidio RecognizerResult constructed from the recognized Flair detections. - type: List[pa.RecognizerResult] - has_kwargs: false + type: list[pa.RecognizerResult] parameters: - name: self - name: text type: str doc: The text for analysis. - name: entities - type: List[str] + type: list[str] doc: The list of entities to recognize. - name: nlp_artifacts type: pa.nlp_engine.NlpArtifacts doc: Not used by this recognizer but needed for the interface. default: null - lineno: 381 + name: analyze doc: Analyze text and return the results. + has_kwargs: false has_varargs: false + lineno: 381 recognize_pii: - name: recognize_pii outputs: - doc: 'A tuple of:' - type: Union[Tuple[str, pd.DataFrame, dict, dict], Tuple[str, pd.DataFrame, - dict]] - has_kwargs: false + type: tuple[str, pd.DataFrame, dict, dict] | tuple[str, pd.DataFrame, dict] parameters: - name: context type: MLClientCtx doc: The MLRun context. this is needed for log the artifacts. - name: input_path - type: Union[str, Path] doc: The input path of the text files needs to be analyzed. - name: html_key type: str @@ -49,7 +68,7 @@ spec: doc: The output directory path to store the anonymized text. default: null - name: entities - type: List[str] + type: list[str] doc: The list of entities to recognize. 
default: null - name: entity_operator_map @@ -81,35 +100,15 @@ spec: type: bool doc: Whether to return the full report or just the score and start, end index default: true - lineno: 845 + name: recognize_pii doc: 'Walk through the input path, recognize PII in text and store the anonymized text in the output path. Generate the html with different colors for each entity, json report of the explanation.' + has_kwargs: false has_varargs: false - build: - base_image: mlrun/mlrun - requirements: - - nltk - - pandas - - presidio-anonymizer - - presidio-analyzer - - torch - - flair@git+https://github.com/flairNLP/flair.git@d4ed67bf663e4066517f00397412510d90043653 - - st-annotated-text - - https://huggingface.co/beki/en_spacy_pii_distilbert/resolve/main/en_spacy_pii_distilbert-any-py3-none-any.whl - functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import pathlib
import tempfile
import warnings
from typing import List, Set, Tuple, Union

import annotated_text.util as at_util
import mlrun
import nltk
import pandas as pd
import presidio_analyzer as pa
import presidio_anonymizer as pre_anoymizer
from presidio_anonymizer.entities import OperatorConfig
from tqdm import tqdm

try:
    import flair as fl
except ModuleNotFoundError:
    print("Flair is not installed")

# There is a conflict between Rust-based tokenizers' parallel processing
# and Python's fork operations during multiprocessing. To avoid it, tokenizer
# parallelism is disabled below. All Python warnings are silenced as well.

os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")

logger = logging.getLogger("pii-recognizer")


# Add the constant classes of Models and Entities to govern the whole package
class Models:
    """Names accepted as the ``model`` argument."""

    SPACY = "spacy"
    FLAIR = "flair"
    PATTERN = "pattern"
    WHOLE = "whole"


class Entities:
    """Names of the entities used across the package. Each constant equals its own name."""

    CREDIT_CARD = "CREDIT_CARD"
    SSN = "SSN"
    PHONE = "PHONE"
    EMAIL = "EMAIL"
    LOCATION = "LOCATION"
    PERSON = "PERSON"
    NRP = "NRP"
    ORGANIZATION = "ORGANIZATION"
    DATE_TIME = "DATE_TIME"
    # Plain string like every other constant (a stray trailing comma used to make it a tuple)
    GPE = "GPE"
    MAC_ADDRESS = "MAC_ADDRESS"
    US_BANK_NUMBER = "US_BANK_NUMBER"
    IMEI = "IMEI"
    TITLE = "TITLE"
    LICENSE_PLATE = "LICENSE_PLATE"
    US_PASSPORT = "US_PASSPORT"
    CURRENCY = "CURRENCY"
    ROUTING_NUMBER = "ROUTING_NUMBER"
    US_ITIN = "US_ITIN"
    US_DRIVER_LICENSE = "US_DRIVER_LICENSE"
    AGE = "AGE"
    PASSWORD = "PASSWORD"
    SWIFT_CODE = "SWIFT_CODE"


class PatternRecognizerFactory:
    """
    Factory of regex based Presidio pattern recognizers.

    Each recognizable entity maps to the list of regex patterns that detect it,
    so supporting a new entity only takes a new entry in the mapping.
    """

    # Entity name -> list of patterns used to recognize it
    RECOGNIZABLE_ENTITIES = {
        "CREDIT_CARD": [pa.Pattern("CREDIT_CARD", r"\b(?:\d[ -]*?){13,16}\b", 0.5)],
        "SSN": [pa.Pattern("SSN", r"\b\d{3}-?\d{2}-?\d{4}\b", 0.5)],
        "PHONE": [pa.Pattern("PHONE", r"\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}", 0.5)],
        "EMAIL": [pa.Pattern("EMAIL", r"\S+@\S+", 0.5)],
    }

    @classmethod
    def _create_pattern_recognizer(cls):
        """
        Create one pattern recognizer per recognizable entity.

        :param cls: PatternRecognizerFactory class

        :returns: List of pattern recognizers
        """
        recognizers = []
        for entity_name, entity_patterns in cls.RECOGNIZABLE_ENTITIES.items():
            recognizers.append(
                pa.PatternRecognizer(
                    supported_entity=entity_name, patterns=entity_patterns
                )
            )
        return recognizers


class CustomSpacyRecognizer(pa.LocalRecognizer):
    """
    Custom Spacy Recognizer from Presidio Analyzer trained on Privy data.
    The privy data is generated using this https://github.com/pixie-io/pixie/tree/main/src/datagen/pii/privy
    It can be used to recognize custom entities, Since we want to use Presidio's Registries to generate AnalyzerEngine,
    it inherits from Presidio Analyzer's LocalRecognizer class.
    """

    # Entities to recognize

    RECOGNIZABLE_ENTITIES = {
        "LOCATION",
        "PERSON",
        "NRP",
        "ORGANIZATION",
        "DATE_TIME",
    }

    # Default explanation for this recognizer

    _DEFAULT_EXPLANATION = (
        "Identified as {} by Spacy's Named Entity Recognition (Privy-trained)"
    )

    # Label groups to check: pairs of (requested entities, model labels accepted for them)

    _DEFAULT_CHECK_LABEL_GROUPS = [
        ({"LOCATION"}, {"LOC", "LOCATION", "STREET_ADDRESS", "COORDINATE"}),
        ({"PERSON"}, {"PER", "PERSON"}),
        ({"NRP"}, {"NORP", "NRP"}),
        ({"ORGANIZATION"}, {"ORG"}),
        ({"DATE_TIME"}, {"DATE_TIME"}),
    ]

    # pretrained model for this recognizer

    _DEFAULT_MODEL_LANGUAGES = {
        "en": "beki/en_spacy_pii_distilbert",
    }

    # Model label -> Presidio entity name. The "NORP" key matches the label spelling
    # used in the label groups above (it used to be misspelled as "NROP").

    _DEFAULT_PRESIDIO_EQUIVALENCES = {
        "PER": "PERSON",
        "LOC": "LOCATION",
        "ORG": "ORGANIZATION",
        "NORP": "NRP",
        "DATE_TIME": "DATE_TIME",
    }

    def __init__(
        self,
        supported_language: str = "en",
        supported_entities: List[str] = None,
        check_label_groups: List[Tuple[Set, Set]] = None,
        context: List[str] = None,
        ner_strength: float = 1,
    ):
        """
        Initialize Spacy Recognizer.

        :param supported_language: Language to use, default is English
        :param supported_entities: Entities to use for recognition, default is RECOGNIZABLE_ENTITIES
        :param check_label_groups: Label groups to check for the entities, as (entity set, label set) pairs
        :param context:            Context to use if any (accepted but not used by this initializer)
        :param ner_strength:       Default confidence for NER prediction
        """

        # Default confidence for NER prediction
        self.ner_strength = ner_strength

        self.check_label_groups = check_label_groups or self._DEFAULT_CHECK_LABEL_GROUPS
        supported_entities = supported_entities or self.RECOGNIZABLE_ENTITIES
        super().__init__(
            supported_entities=supported_entities,
            supported_language=supported_language,
        )

    # get the presidio explanation for the result

    def _build_spacy_explanation(
        self, original_score: float, explanation: str
    ) -> pa.AnalysisExplanation:
        """
        Create explanation for why this result was detected.

        :param original_score: Score given by this recognizer
        :param explanation:    Explanation string

        :returns: Presidio AnalysisExplanation object
        """
        explanation = pa.AnalysisExplanation(
            recognizer=self.__class__.__name__,
            original_score=original_score,
            textual_explanation=explanation,
        )
        return explanation

    # main method for the recognizer
    def analyze(self, text: str, entities: List[str], nlp_artifacts=None):  # noqa D102
        """
        Analyze text using Spacy.

        The entities are taken from the given NLP artifacts, the text itself is not
        processed here.

        :param text:          Text to analyze
        :param entities:      Entities to analyze
        :param nlp_artifacts: NLP artifacts to use

        :returns: List of Presidio RecognizerResult objects
        """
        results = []
        if not nlp_artifacts:
            logger.warning("Skipping SpaCy, nlp artifacts not provided...")
            return results

        ner_entities = nlp_artifacts.entities

        # recognize the supported entities
        for entity in entities:
            if entity not in self.supported_entities:
                continue
            for ent in ner_entities:
                if not self.__check_label(entity, ent.label_, self.check_label_groups):
                    continue

                # string of the explanation saying the entity is recognized by spacy
                textual_explanation = self._DEFAULT_EXPLANATION.format(ent.label_)
                explanation = self._build_spacy_explanation(
                    self.ner_strength, textual_explanation
                )

                # create the standard result with the entity, start, end, score, and explanation
                spacy_result = pa.RecognizerResult(
                    entity_type=entity,
                    start=ent.start_char,
                    end=ent.end_char,
                    score=self.ner_strength,
                    analysis_explanation=explanation,
                    recognition_metadata={
                        pa.RecognizerResult.RECOGNIZER_NAME_KEY: self.name
                    },
                )
                results.append(spacy_result)

        return results

    @staticmethod
    def __check_label(
        entity: str, label: str, check_label_groups: List[Tuple[Set, Set]]
    ) -> bool:
        """
        Check if the label is in the label group.

        :param entity:             Entity to check
        :param label:              Label to check
        :param check_label_groups: Label groups to check, as (entity set, label set) pairs

        :returns: True if the label is in the label group, False otherwise
        """
        return any(
            entity in egrp and label in lgrp for egrp, lgrp in check_label_groups
        )


# Class to use Flair with Presidio as an external recognizer.
class FlairRecognizer(pa.EntityRecognizer):
    """
    Wrapper for a flair model, if needed to be used within Presidio Analyzer.
    This is to make sure the recognizer can be registered with Presidio registry.
    """

    # Entities this recognizer declares as supported by default
    RECOGNIZABLE_ENTITIES = {
        "LOCATION",
        "PERSON",
        "NRP",
        "GPE",
        "ORGANIZATION",
        "MAC_ADDRESS",
        "US_BANK_NUMBER",
        "IMEI",
        "TITLE",
        "LICENSE_PLATE",
        "US_PASSPORT",
        "CURRENCY",
        "ROUTING_NUMBER",
        "US_ITIN",
        "US_DRIVER_LICENSE",
        "AGE",
        "PASSWORD",
        "SWIFT_CODE",
    }

    # This is used to construct the explanation for the result

    _DEFAULT_EXPLANATION = "Identified as {} by Flair's Named Entity Recognition"

    # Pairs of (requested entities, model labels accepted for them). Every entity in
    # RECOGNIZABLE_ENTITIES needs a pair here, otherwise it can never be matched;
    # PASSWORD used to be missing.
    _DEFAULT_CHECK_LABEL_GROUPS = [
        ({"LOCATION"}, {"LOC", "LOCATION", "STREET_ADDRESS", "COORDINATE"}),
        ({"PERSON"}, {"PER", "PERSON"}),
        ({"NRP"}, {"NORP", "NRP"}),
        ({"GPE"}, {"GPE"}),
        ({"ORGANIZATION"}, {"ORG"}),
        ({"MAC_ADDRESS"}, {"MAC_ADDRESS"}),
        ({"US_BANK_NUMBER"}, {"US_BANK_NUMBER"}),
        ({"IMEI"}, {"IMEI"}),
        ({"TITLE"}, {"TITLE"}),
        ({"LICENSE_PLATE"}, {"LICENSE_PLATE"}),
        ({"US_PASSPORT"}, {"US_PASSPORT"}),
        ({"CURRENCY"}, {"CURRENCY"}),
        ({"ROUTING_NUMBER"}, {"ROUTING_NUMBER"}),
        ({"AGE"}, {"AGE"}),
        ({"PASSWORD"}, {"PASSWORD"}),
        ({"SWIFT_CODE"}, {"SWIFT_CODE"}),
        ({"US_ITIN"}, {"US_ITIN"}),
        ({"US_DRIVER_LICENSE"}, {"US_DRIVER_LICENSE"}),
    ]

    # Name of the pretrained model to load for each supported language
    _DEFAULT_MODEL_LANGUAGES = {
        "en": "beki/flair-pii-distilbert",
    }

    # Model tag -> Presidio entity name. The "NORP" key matches the label spelling
    # used in the label groups above (it used to be misspelled as "NROP").
    _DEFAULT_PRESIDIO_EQUIVALENCES = {
        "PER": "PERSON",
        "LOC": "LOCATION",
        "ORG": "ORGANIZATION",
        "NORP": "NRP",
        "URL": "URL",
        "US_ITIN": "US_ITIN",
        "US_PASSPORT": "US_PASSPORT",
        "IBAN_CODE": "IBAN_CODE",
        "IP_ADDRESS": "IP_ADDRESS",
        "EMAIL_ADDRESS": "EMAIL",
        "US_DRIVER_LICENSE": "US_DRIVER_LICENSE",
        "US_BANK_NUMBER": "US_BANK_NUMBER",
    }

    def __init__(
        self,
        supported_language: str = "en",
        supported_entities: List[str] = None,
        check_label_groups: Tuple[Set, Set] = None,
    ):
        """
        Set up the recognizer and load the Flair model of the requested language.

        :param supported_language: Language to use
        :param supported_entities: Entities to use, default is RECOGNIZABLE_ENTITIES
        :param check_label_groups: Label groups to check, default is _DEFAULT_CHECK_LABEL_GROUPS

        """
        # Fall back to the class defaults for whatever was not given
        label_groups = check_label_groups or self._DEFAULT_CHECK_LABEL_GROUPS
        entity_names = supported_entities or self.RECOGNIZABLE_ENTITIES
        self.check_label_groups = label_groups

        # Load the sequence tagger registered for the language
        model_name = self._DEFAULT_MODEL_LANGUAGES.get(supported_language)
        self.model = fl.models.SequenceTagger.load(model_name)

        super().__init__(
            supported_entities=entity_names,
            supported_language=supported_language,
            name="Flair Analytics",
        )

    # main method for the recognizer
    def analyze(
        self,
        text: str,
        entities: List[str],
        nlp_artifacts: pa.nlp_engine.NlpArtifacts = None,
    ) -> List[pa.RecognizerResult]:
        """
        Run the Flair model over the text and collect the matching detections.

        :param text:          The text for analysis.
        :param entities:      The list of entities to recognize. All the supported entities are used when it is empty.
        :param nlp_artifacts: Not used by this recognizer but needed for the interface.

        :returns: The list of Presidio RecognizerResult constructed from the recognized Flair detections.
        """

        # Tag the text with the Flair model
        sentences = fl.data.Sentence(text)
        self.model.predict(sentences)

        # Without a specific list of entities, look for all the supported ones
        requested = entities if entities else self.supported_entities

        results = []
        for entity in requested:
            # Entities this recognizer does not support are ignored
            if entity not in self.supported_entities:
                continue

            for ent in sentences.get_spans("ner"):
                label = ent.labels[0].value
                if self.__check_label(entity, label, self.check_label_groups):
                    # Explain the detection and convert it to a Presidio result
                    explanation = self._build_flair_explanation(
                        round(ent.score, 2), self._DEFAULT_EXPLANATION.format(label)
                    )
                    results.append(
                        self._convert_to_recognizer_result(ent, explanation)
                    )

        return results

    def _convert_to_recognizer_result(
        self, entity: fl.data.Span, explanation: str
    ) -> pa.RecognizerResult:
        """
        Build a Presidio RecognizerResult out of a Flair span.

        :param entity:      The Flair span to convert.
        :param explanation: The Presidio AnalysisExplanation to attach to the result.

        :returns: The Presidio RecognizerResult.
        """
        # Translate the Flair tag to its Presidio name, tags without a translation are kept as they are:
        flair_tag = entity.tag
        presidio_type = self._DEFAULT_PRESIDIO_EQUIVALENCES.get(flair_tag, flair_tag)

        # The score is rounded to two digits, the positions are taken from the span as is:
        return pa.RecognizerResult(
            entity_type=presidio_type,
            start=entity.start_position,
            end=entity.end_position,
            score=round(entity.score, 2),
            analysis_explanation=explanation,
        )

    def _build_flair_explanation(
        self, original_score: float, explanation: str
    ) -> pa.AnalysisExplanation:
        """
        Wrap a score and a textual explanation in a Presidio AnalysisExplanation.

        :param original_score: Score given by this recognizer.
        :param explanation:    Explanation string.

        :returns: The Presidio AnalysisExplanation, attributed to this recognizer's class name.
        """
        recognizer_name = self.__class__.__name__
        return pa.AnalysisExplanation(
            recognizer=recognizer_name,
            original_score=original_score,
            textual_explanation=explanation,
        )

    # sanity check of the entity and label before recognition
    @staticmethod
    def __check_label(
        entity: str, label: str, check_label_groups: Tuple[Set, Set]
    ) -> bool:
        """
        Check whether an entity and a label belong to the same group.

        :param entity:             The requested entity.
        :param label:              The label reported by the model.
        :param check_label_groups: An iterable of (entities group, labels group) pairs.

        :returns: True if at least one pair holds the entity in its entities group and the label in its labels
                  group, False otherwise.
        """
        for entities_group, labels_group in check_label_groups:
            if entity in entities_group and label in labels_group:
                return True
        return False


# get the analyzer engine based on the model
def _get_analyzer_engine(
    model: str = None, entities: List[str] = None
) -> pa.AnalyzerEngine:
    """
    Build a Presidio AnalyzerEngine holding the recognizers of the requested model.

    When no model is given, the recognizers are picked by the requested entities: each recognizers family (spacy,
    flair and pattern) is added only if it can recognize at least one of them.

    :param model:    The model to use. Can be "spacy", "flair", "pattern" or "whole".
    :param entities: The list of entities to use.

    :returns: The configured pa.AnalyzerEngine.

    :raises ValueError: If the recognizers cannot be picked (no known model and, with no model, no entities), or
                        if some of the requested entities are not supported by the picked recognizers.
    """
    # recognizer registry that can store multiple recognizers
    registry = pa.RecognizerRegistry()

    # Decide which recognizers families are needed:
    if model == Models.SPACY:
        with_spacy, with_flair, with_pattern = True, False, False
    elif model == Models.FLAIR:
        with_spacy, with_flair, with_pattern = False, True, False
    elif model == Models.PATTERN:
        with_spacy, with_flair, with_pattern = False, False, True
    elif model == Models.WHOLE:
        with_spacy, with_flair, with_pattern = True, True, True
    elif not model and entities:
        requested = set(entities)
        with_spacy = bool(requested & CustomSpacyRecognizer.RECOGNIZABLE_ENTITIES)
        with_flair = bool(requested & FlairRecognizer.RECOGNIZABLE_ENTITIES)
        with_pattern = bool(
            requested & set(PatternRecognizerFactory.RECOGNIZABLE_ENTITIES.keys())
        )
    else:
        raise ValueError(
            "argument of model and entities can not be None at the same time"
        )

    # Register them, always in the order of spacy, flair and then the pattern recognizers:
    if with_spacy:
        registry.add_recognizer(CustomSpacyRecognizer())
    if with_flair:
        registry.add_recognizer(FlairRecognizer())
    if with_pattern:
        for recognizer in PatternRecognizerFactory()._create_pattern_recognizer():
            registry.add_recognizer(recognizer)

    analyzer = pa.AnalyzerEngine(
        registry=registry,
        supported_languages=["en"],
    )

    # Make sure every requested entity can be recognized by the registered recognizers:
    supported_entities = analyzer.get_supported_entities()
    not_supported_entities = (
        [item for item in entities if item not in supported_entities]
        if entities
        else []
    )
    if not_supported_entities:
        raise ValueError(
            f"The current model {model} doesn't support the following entities: {not_supported_entities}. "
            f"Supported entities are: {supported_entities}"
        )
    return analyzer


def _get_anonymizer_engine() -> pre_anoymizer.AnonymizerEngine:
    """
    Create a new Presidio AnonymizerEngine.

    :returns: A newly constructed AnonymizerEngine.
    """
    engine = pre_anoymizer.AnonymizerEngine()
    return engine


def _anonymize(
    text: str,
    analyze_results: List[pa.RecognizerResult],
    entity_operator_map: dict = None,
    is_full_text: bool = True,
) -> str:
    """
    Anonymize identified input using Presidio Anonymizer.

    :param text:                The text for analysis.
    :param analyze_results:     The list of Presidio RecognizerResult constructed from analysis. Their offsets are
                                relative to the full text.
    :param entity_operator_map: The entity_operator_map is a dictionary that maps entity to operator name and
                                operator params.
    :param is_full_text:        Whether to anonymize and return the full text. If False, the text is split into
                                sentences and only the sentences holding PII are returned, anonymized and joined
                                by a single space.

    :returns: The anonymized text. An empty string is returned for an empty text.
    """
    if not text:
        return ""

    anonymizer_engine = _get_anonymizer_engine()
    if not entity_operator_map:
        operators = None
    else:
        # Create OperatorConfig based on the entity_operator_map
        operators = {
            entity: OperatorConfig(operator_name, operator_params)
            for entity, (operator_name, operator_params) in entity_operator_map.items()
        }

    if is_full_text:
        # Anonymize the entire text
        return anonymizer_engine.anonymize(
            text=text, analyzer_results=analyze_results, operators=operators
        ).text

    # Tokenize the text to sentences
    sentences = nltk.sent_tokenize(text)
    anonymized_sentences = []
    search_from = 0

    # Find the sentences that have pii entities
    for sentence in sentences:
        # The tokenizer drops the whitespace between sentences, so the sentence's real position must be looked up
        # in the text. Accumulating the sentences lengths would drift further from the results offsets with every
        # sentence.
        start_idx = text.find(sentence, search_from)
        if start_idx < 0:
            # Not expected (sentences are slices of the text), fall back to the contiguous position:
            start_idx = search_from
        end_idx = start_idx + len(sentence)

        # Get the entities that are fully inside the sentence and shift their offsets to be relative to it.
        # Entities crossing a sentence boundary are left out.
        sentence_results = [
            pa.RecognizerResult(
                result.entity_type,
                start=result.start - start_idx,
                end=result.end - start_idx,
                score=result.score,
            )
            for result in analyze_results
            if result.start >= start_idx and result.end <= end_idx
        ]

        # If PII is detected
        if sentence_results:
            anonymized_sentence = anonymizer_engine.anonymize(
                text=sentence, analyzer_results=sentence_results, operators=operators
            ).text
            anonymized_sentences.append(anonymized_sentence)

        search_from = end_idx

    return " ".join(anonymized_sentences)


def _get_tokens(
    text: str, analyze_results: List[pa.RecognizerResult], is_full: bool = True
) -> List[str]:
    """
    Split the text into plain text pieces and (entity text, entity type) tuples.

    :param text:            The text for analysis.
    :param analyze_results: The list of Presidio RecognizerResult constructed from analysis.
    :param is_full:         Whether to return the full tokens, or only the groups of tokens that hold at least one
                            entity and are closed by a token with sentence ending punctuation.

    :returns: If `is_full` is True, a flat list alternating plain text (str) and entities (tuples), covering the whole
              text. When nothing was recognized it holds the text itself as a single token. If `is_full` is False,
              a list of token groups (each one a list).
    """
    # sort by start index so the text can be walked from left to right
    results = sorted(analyze_results, key=lambda x: x.start)

    if not results:
        # Nothing was recognized - keep the text as a single plain token instead of dropping it:
        tokens = [text]
    else:
        # The text before the first entity:
        tokens = [text[: results[0].start]]
        # Each entity is followed by the text up to the next entity (or up to the end of the text for the last one):
        following_starts = [res.start for res in results[1:]] + [None]
        for res, next_start in zip(results, following_starts):
            tokens.append((text[res.start : res.end], res.entity_type))
            tokens.append(text[res.end : next_start])

    if is_full:
        return tokens

    # get the tokens that only contains the entities that can form a sentence
    part_annontated_tokens = []
    last_end_sentence = 0
    for i, token in enumerate(tokens):
        has_sentence_end = any(item in token for item in [".", "!", "?"])
        # A group is kept only if it holds at least one entity. The closing token itself opens the next group.
        if has_sentence_end and any(
            isinstance(item, tuple) for item in tokens[last_end_sentence:i]
        ):
            part_annontated_tokens.append(tokens[last_end_sentence:i])
            last_end_sentence = i
    return part_annontated_tokens


def _annotate(
    text: str, st_analyze_results: List[pa.RecognizerResult], is_full_html: bool = True
) -> List[str]:
    """
    Turn the analysis results of a text into tokens that can be rendered as annotated html.

    :param text:               The text for analysis.
    :param st_analyze_results: The list of Presidio RecognizerResult constructed from analysis.
    :param is_full_html:       Whether to generate the tokens of the full html or not.

    :returns: The list of tokens with the identified entities.
    """
    annotated_tokens = _get_tokens(
        text=text, analyze_results=st_analyze_results, is_full=is_full_html
    )
    return annotated_tokens


def _process(
    text: str,
    model: pa.AnalyzerEngine,
    score_threshold: float,
    entities: List[str] = None,
    entities_operator_map: dict = None,
    is_full_text: bool = True,
) -> Tuple[str, list]:
    """
    Analyze a text with the given analyzer and anonymize what was recognized in it.

    :param text:                  Text to process
    :param model:                 The analyzer engine to use for processing
    :param score_threshold:       The score threshold to use for recognition
    :param entities:              Entities to recognize
    :param entities_operator_map: A dictionary that maps entity to operator name and operator params.
    :param is_full_text:          Whether to return the full text or just the annotated text

    :returns: A tuple of:

              * the anonymized text
              * the list of Presidio RecognizerResult constructed from analysis
    """
    # Recognize the pii entities (English only), keeping the decision process for the reports:
    detections = model.analyze(
        text=text,
        language="en",
        entities=entities,
        score_threshold=score_threshold,
        return_decision_process=True,
    )

    # Replace the recognized entities in the text:
    return (
        _anonymize(text, detections, entities_operator_map, is_full_text),
        detections,
    )


def _get_single_html(
    text: str, results: List[pa.RecognizerResult], is_full_html: bool = True
):
    """
    Generate the html for a single txt file.

    :param text:         The text for analysis.
    :param results:      The list of Presidio RecognizerResult constructed from analysis.
    :param is_full_html: Whether generate full html or not.

    :returns: The html string for a single txt file, wrapped in a paragraph tag, with its line breaks converted to
              `<br>` tags.
    """
    # convert the results to tokens to generate the html
    tokens = _annotate(text, results, is_full_html)
    html = at_util.get_annotated_html(*tokens)

    # Convert line breaks to <br> tags so they are rendered. Both the escaped two characters sequence (a backslash
    # followed by "n") and real newline characters are handled.
    # The replacement is done on its own line, outside of the f-string: a quoted string nested in an f-string
    # expression is a plain literal, so writing '{backslash_char}n' there is never interpolated and never matches.
    escaped_newline = "\\" + "n"
    html = html.replace(escaped_newline, "<br>").replace("\n", "<br>")

    html_str = f"<p>{html}</p>"

    return html_str


def _get_single_json(results: List[pa.RecognizerResult], is_full_report: bool = True):
    """
    Prepare the analysis results of a single txt file for the report.

    :param results:        The list of Presidio RecognizerResult constructed from analysis.
    :param is_full_report: Whether to keep the full results or not.

    :returns: For a full report, the given results list itself. Otherwise, a new list holding the same result
              objects after their `analysis_explanation` was set to None (the objects are modified in place).
    """
    # A full report uses the results exactly as they are:
    if is_full_report:
        return results

    # A partial report drops the explanation of every result:
    simplified = []
    for entry in results:
        entry.analysis_explanation = None
        simplified.append(entry)
    return simplified


def _get_all_html(
    txt_content: dict,
    res_dict: dict,
    is_full_html: bool = True,
):
    """
    Generate one html document for all the txt files: an index of the files followed by a section per file.

    :param txt_content:  The dictionary of txt file name and content.
    :param res_dict:     The dictionary of txt file name and the list of Presidio RecognizerResult constructed from
                         analysis. Only the files in this dictionary are included.
    :param is_full_html: Whether generate full html or not.

    :returns: The html string for all txt files.
    """
    # Collect an index entry and a content section per file:
    index_entries = []
    content_entries = []
    for txt_file, results in res_dict.items():
        txt = txt_content[txt_file]
        index_entries.append(f"<li><a href='#{txt_file}'>{txt_file}</a></li>")
        content_entries.append(
            f"<li><h2>{txt_file}</h2><p>{_get_single_html(txt, results, is_full_html)}</p></li>"
        )

    # Assemble the document - the header, the index list and then the sections:
    header = "<html><head><title>Highlighted Pii Entities</title></head><body><h1>Highlighted Pii Entities</h1><ul>"
    html_index = header + "".join(index_entries) + "</ul>"
    html_content = "".join(content_entries)
    return f"{html_index}{html_content}</body></html>"


def _get_all_rpt(res_dict: dict, is_full_report: bool = True):
    """
    Generate the stats report for all txt files.

    :param res_dict:       The dictionary of txt file name and the list of Presidio RecognizerResult constructed from
                           analysis.
    :param is_full_report: Whether generate full report or not.

    :returns: A dictionary of txt file name to the list of its results as dictionaries. In a full report the
              explanation of each result is converted to a dictionary as well, otherwise the
              "analysis_explanation" and "recognition_metadata" keys are removed.
    """

    def _as_full_dict(item):
        # The explanation is replaced on the result itself by its dictionary form before the conversion:
        item.analysis_explanation = item.analysis_explanation.to_dict()
        return item.to_dict()

    def _as_brief_dict(item):
        brief = item.to_dict()
        brief.pop("analysis_explanation")
        brief.pop("recognition_metadata")
        return brief

    to_report_dict = _as_full_dict if is_full_report else _as_brief_dict

    stats_dict = {}
    for txt_file, results in res_dict.items():
        stats_dict[txt_file] = [
            to_report_dict(item) for item in _get_single_json(results, is_full_report)
        ]
    return stats_dict


def recognize_pii(
    context: mlrun.MLClientCtx,
    input_path: Union[str, pathlib.Path],
    html_key: str,
    score_threshold: float,
    output_directory: str = None,
    entities: List[
        str
    ] = None,  # List of entities to recognize, default is recognizing all
    entity_operator_map: dict = None,
    model: str = None,
    generate_json: bool = True,
    generate_html: bool = True,
    is_full_text: bool = True,
    is_full_html: bool = True,
    is_full_report: bool = True,
) -> Union[Tuple[str, pd.DataFrame, dict, dict], Tuple[str, pd.DataFrame, dict]]:
    """
    Recognize and anonymize the PII in every "*.txt" file found directly under the input path. The anonymized texts
    are written to the output directory under the same file names. An html with the highlighted entities can be
    logged as an artifact and a json report of the recognitions can be returned.

    :param context:              The MLRun context. this is needed for log the artifacts.
    :param input_path:           The input path of the text files needs to be analyzed.
    :param html_key:             The html key for the artifact.
    :param score_threshold:      The score threshold to mark the recognition as trusted.
    :param output_directory:     The output directory path to store the anonymized text. A temporary directory is
                                 created if not given.
    :param entities:             The list of entities to recognize.
    :param entity_operator_map:  The map of entity to operator (mask, redact, replace, keep, hash, and its params)
    :param model:                The model to use. Can be "spacy", "flair", "pattern" or "whole".
    :param generate_json:        Whether to generate the json report of the explanation.
    :param generate_html:        Whether to generate the html report of the explanation.
    :param is_full_text:         Whether to return the full text or only the masked text.
    :param is_full_html:         Whether to return the full html or just the annotated text
    :param is_full_report:       Whether to return the full report or just the score and start, end index

    :returns: A tuple of:

              * Path to the output directory
              * A dataframe of the processed files: the original file name and the anonymized file name
              * A dictionary of errors files that were not processed
              * The json report of the explanation (only if generate_json is True)
    """
    # Resolve the output directory and make sure it exists:
    output_directory = pathlib.Path(
        tempfile.mkdtemp() if output_directory is None else output_directory
    )
    if not output_directory.exists():
        output_directory.mkdir(parents=True, exist_ok=True)

    txt_files_directory = pathlib.Path(input_path)
    successes, errors = [], {}
    res_dict, txt_content = {}, {}

    # Load the model:
    analyzer = _get_analyzer_engine(model, entities)
    logger.info("Model loaded")

    # Analyze and anonymize the text files one by one. A failing file is recorded and does not stop the run:
    txt_files = list(txt_files_directory.glob("*.txt"))
    for txt_file in tqdm(txt_files, desc="Processing files", unit="file"):
        file_key = str(txt_file)
        try:
            text = txt_file.read_text()
            txt_content[file_key] = text
            anonymized_text, results = _process(
                text=text,
                model=analyzer,
                score_threshold=score_threshold,
                entities=entities,
                entities_operator_map=entity_operator_map,
                is_full_text=is_full_text,
            )
            res_dict[file_key] = results
            # Store the anonymized text under the same file name in the output directory:
            output_file = output_directory / f"{txt_file.stem}.txt"
            output_file.parent.mkdir(parents=True, exist_ok=True)
            output_file.write_text(anonymized_text)
            successes.append([txt_file.name, output_file.name])
        except Exception as e:
            errors[file_key] = str(e)
            logger.error(f"Error processing {txt_file}: {e}")

    successes = pd.DataFrame(
        successes,
        columns=["original_file", "anonymized_file"],
    )

    # Log the html of the highlighted entities:
    if generate_html:
        html_res = _get_all_html(txt_content, res_dict, is_full_html)
        context.log_artifact(
            mlrun.artifacts.Artifact(body=html_res, format="html", key=html_key)
        )

    # The json report is appended to the outputs only if requested:
    outputs = (str(output_directory), successes, errors)
    if generate_json:
        outputs += (_get_all_rpt(res_dict, is_full_report),)
    return outputs
 - code_origin: '' - origin_filename: '' - description: This function is used to recognize PII in a directory of text files - image: '' + lineno: 845 command: '' - disable_auto_mount: false -kind: job -metadata: - name: pii-recognizer - tag: '' - categories: - - data-preparation - - NLP + description: This function is used to recognize PII in a directory of text files + default_handler: recognize_pii diff --git a/functions/src/pii_recognizer/pii_recognizer.py b/functions/src/pii_recognizer/pii_recognizer.py index 0acc55dcb..3a5366635 100644 --- a/functions/src/pii_recognizer/pii_recognizer.py +++ b/functions/src/pii_recognizer/pii_recognizer.py @@ -17,7 +17,7 @@ import pathlib import tempfile import warnings -from typing import List, Set, Tuple, Union +from typing import List import annotated_text.util as at_util import mlrun @@ -162,9 +162,9 @@ class CustomSpacyRecognizer(pa.LocalRecognizer): def __init__( self, supported_language: str = "en", - supported_entities: List[str] = None, - check_label_groups: Tuple[Set, Set] = None, - context: List[str] = None, + supported_entities: list[str] = None, + check_label_groups: tuple[set, set] = None, + context: list[str] = None, ner_strength: float = 1, ): """ @@ -258,7 +258,7 @@ def analyze(self, text: str, entities: List[str], nlp_artifacts=None): # noqa D @staticmethod def __check_label( - entity: str, label: str, check_label_groups: Tuple[Set, Set] + entity: str, label: str, check_label_groups: tuple[set, set] ) -> bool: """ Check if the label is in the label group. @@ -351,8 +351,8 @@ class FlairRecognizer(pa.EntityRecognizer): def __init__( self, supported_language: str = "en", - supported_entities: List[str] = None, - check_label_groups: Tuple[Set, Set] = None, + supported_entities: list[str] = None, + check_label_groups: tuple[set, set] = None, ): """ Initialize the FlairRecognizer. 
@@ -381,9 +381,9 @@ def __init__( def analyze( self, text: str, - entities: List[str], + entities: list[str], nlp_artifacts: pa.nlp_engine.NlpArtifacts = None, - ) -> List[pa.RecognizerResult]: + ) -> list[pa.RecognizerResult]: """ Analyze text and return the results. @@ -483,7 +483,7 @@ def _build_flair_explanation( # sanity check of the entity and label before recognition @staticmethod def __check_label( - entity: str, label: str, check_label_groups: Tuple[Set, Set] + entity: str, label: str, check_label_groups: tuple[set, set] ) -> bool: return any( entity in egrp and label in lgrp for egrp, lgrp in check_label_groups @@ -492,7 +492,7 @@ def __check_label( # get the analyzer engine based on the model def _get_analyzer_engine( - model: str = None, entities: List[str] = None + model: str = None, entities: list[str] = None ) -> pa.AnalyzerEngine: """ Return pa.AnalyzerEngine. @@ -542,7 +542,7 @@ def _get_analyzer_engine( registry.add_recognizer(recognizer) else: raise ValueError( - f"argument of model and entities can not be None at the same time" + "argument of model and entities can not be None at the same time" ) analyzer = pa.AnalyzerEngine( registry=registry, @@ -573,7 +573,7 @@ def _get_anonymizer_engine() -> pre_anoymizer.AnonymizerEngine: def _anonymize( text: str, - analyze_results: List[pa.RecognizerResult], + analyze_results: list[pa.RecognizerResult], entity_operator_map: dict = None, is_full_text: bool = True, ) -> str: @@ -640,8 +640,8 @@ def _anonymize( def _get_tokens( - text: str, analyze_results: List[pa.RecognizerResult], is_full: bool = True -) -> List[str]: + text: str, analyze_results: list[pa.RecognizerResult], is_full: bool = True +) -> list[str]: """ Get the full tokens or only contains the entities that can form a sentence. 
@@ -685,8 +685,8 @@ def _get_tokens( def _annotate( - text: str, st_analyze_results: List[pa.RecognizerResult], is_full_html: bool = True -) -> List[str]: + text: str, st_analyze_results: list[pa.RecognizerResult], is_full_html: bool = True +) -> list[str]: """ Annotate identified input using Presidio Anonymizer. @@ -704,10 +704,10 @@ def _process( text: str, model: pa.AnalyzerEngine, score_threshold: float, - entities: List[str] = None, + entities: list[str] = None, entities_operator_map: dict = None, is_full_text: bool = True, -) -> Tuple[str, list]: +) -> tuple[str, list]: """ Process the text of str using the model. @@ -743,7 +743,7 @@ def _process( def _get_single_html( - text: str, results: List[pa.RecognizerResult], is_full_html: bool = True + text: str, results: list[pa.RecognizerResult], is_full_html: bool = True ): """ Generate the html for a single txt file. @@ -766,7 +766,7 @@ def _get_single_html( return html_str -def _get_single_json(results: List[pa.RecognizerResult], is_full_report: bool = True): +def _get_single_json(results: list[pa.RecognizerResult], is_full_report: bool = True): """ Generate the json for a single txt file. @@ -844,11 +844,11 @@ def _get_all_rpt(res_dict: dict, is_full_report: bool = True): def recognize_pii( context: mlrun.MLClientCtx, - input_path: Union[str, pathlib.Path], + input_path: str | pathlib.Path, html_key: str, score_threshold: float, output_directory: str = None, - entities: List[ + entities: list[ str ] = None, # List of entities to recognize, default is recognizing all entity_operator_map: dict = None, @@ -858,7 +858,7 @@ def recognize_pii( is_full_text: bool = True, is_full_html: bool = True, is_full_report: bool = True, -) -> Union[Tuple[str, pd.DataFrame, dict, dict], Tuple[str, pd.DataFrame, dict]]: +) -> tuple[str, pd.DataFrame, dict, dict] | tuple[str, pd.DataFrame, dict]: """ Walk through the input path, recognize PII in text and store the anonymized text in the output path. 
Generate the html with different colors for each entity, json report of the explanation. diff --git a/functions/src/pii_recognizer/test_pii_recognizer.py b/functions/src/pii_recognizer/test_pii_recognizer.py index 81a16611f..080a5367a 100644 --- a/functions/src/pii_recognizer/test_pii_recognizer.py +++ b/functions/src/pii_recognizer/test_pii_recognizer.py @@ -13,16 +13,14 @@ # limitations under the License. # -import os -import pytest import random -from faker import Faker + import mlrun +import pytest +from faker import Faker from pii_recognizer import ( - _process, _get_analyzer_engine, - _anonymize, - _annotate, + _process, recognize_pii_parallel, ) diff --git a/functions/src/pyannote_audio/function.yaml b/functions/src/pyannote_audio/function.yaml index b4cd9ad93..78bfaf1a6 100644 --- a/functions/src/pyannote_audio/function.yaml +++ b/functions/src/pyannote_audio/function.yaml @@ -1,56 +1,58 @@ +metadata: + tag: '' + name: pyannote-audio + categories: + - deep-learning + - audio +verbose: false kind: job spec: - command: '' - disable_auto_mount: false image: '' + disable_auto_mount: false build: - code_origin: '' + origin_filename: '' + functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import heapq
import logging
import operator
import os
import pathlib
from functools import reduce, wraps
from typing import Any

import pandas as pd
import pyannote.audio
import pyannote.core
import torch
import torchaudio
from tqdm import tqdm

# Get the global logger. It starts as Python's root logger; `open_mpi_handler` rebinds it to the MLRun context's
# logger when an MLRun context is available:
_LOGGER = logging.getLogger()


def _check_mlrun_and_open_mpi() -> tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intracomm"]:
    """
    Look for an MLRun context and, if the run is an MLRun 'mpijob', for the OpenMPI communicator.

    :returns: A tuple of the MLRun context and the MPI world communicator. The communicator is None when the run is
              not an 'mpijob', and both are None when a module is missing outside of an 'mpijob' run (for example
              when MLRun is not installed).

    :raises ModuleNotFoundError: If the run is an 'mpijob' but `mpi4py` cannot be imported.
    """
    running_as_mpijob = False
    try:
        import mlrun

        context = mlrun.get_or_create_ctx(name="mlrun")
        running_as_mpijob = context.labels.get("kind", "job") == "mpijob"

        # A regular run needs only the context:
        if not running_as_mpijob:
            return context, None

        try:
            from mpi4py import MPI

            return context, MPI.COMM_WORLD
        except ModuleNotFoundError as mpi4py_not_found:
            context.logger.error(
                "To distribute the function using MLRun's 'mpijob' you need to have `mpi4py` package in your "
                "interpreter. Please run `pip install mpi4py` and make sure you have open-mpi."
            )
            raise mpi4py_not_found
    except ModuleNotFoundError as module_not_found:
        # A missing module is an error only for an 'mpijob' run, otherwise it is ignored:
        if running_as_mpijob:
            raise module_not_found
    return None, None


def open_mpi_handler(
    worker_inputs: list[str], root_worker_inputs: dict[str, Any] = None
):
    """
    Create a decorator that distributes a handler between the workers of an OpenMPI run.

    When there is no MPI communicator, or it holds a single worker, the handler is returned as is. Otherwise, the
    handler is wrapped so each worker (rank) gets its own chunk of every worker input, the root worker (rank 0)
    gets the root only arguments in addition, and the outputs of all the workers are gathered and merged at the
    root worker.

    The wrapped handler accepts keyword arguments only, and every worker input must be passed to it explicitly.
    The handler is expected to return a pair of dictionaries (the results and the errors).

    :param worker_inputs:      Names of the handler's arguments to split between the workers. A value may be None
                               (left untouched), a string (resolved to a list of files via `_get_audio_files`), a
                               list or a dataframe.
    :param root_worker_inputs: Keyword arguments to set only in the root worker's call.

    :returns: The decorator.
    """
    global _LOGGER

    # Check for MLRun and OpenMPI availability, using MLRun's logger if there is a context:
    context, comm = _check_mlrun_and_open_mpi()
    if context:
        _LOGGER = context.logger

    def decorator(handler):
        # Nothing to distribute without a communicator or with a single worker:
        if comm is None or comm.Get_size() == 1:
            return handler

        @wraps(handler)
        def wrapper(**kwargs):
            # Get the open mpi environment properties:
            workers_count = comm.Get_size()
            rank = comm.Get_rank()

            # Give this worker its chunk of each of the workers inputs:
            for input_name in worker_inputs:
                values = kwargs[input_name]
                if values is None:
                    continue
                if isinstance(values, str):
                    values = _get_audio_files(
                        data_path=pathlib.Path(values).absolute()
                    )
                total = len(values)
                if total < workers_count:
                    raise ValueError(
                        f"Cannot split the input '{input_name}' of length {total} to {workers_count} workers. "
                        f"Please reduce the amount of workers for this input."
                    )

                # Chunks are even, the last worker also takes what is left over:
                chunk_size = total // workers_count
                chunk_start = rank * chunk_size
                is_last_worker = rank + 1 >= workers_count
                chunk_end = total if is_last_worker else chunk_start + chunk_size
                context.logger.info(
                    f"Rank #{rank}: Processing input chunk of '{input_name}' "
                    f"from index {chunk_start} to {chunk_end}."
                )
                if isinstance(values, list):
                    values = values[chunk_start:chunk_end]
                elif isinstance(values, pd.DataFrame):
                    values = values.iloc[chunk_start:chunk_end, :]
                kwargs[input_name] = values

            # Set the root worker only arguments:
            if rank == 0 and root_worker_inputs:
                kwargs.update(root_worker_inputs)

            # Run the worker and send its output to the root rank (rank #0):
            gathered = comm.gather(handler(**kwargs), root=0)
            if rank != 0:
                return None

            # Join the outputs, first all the results dictionaries and then all the errors dictionaries:
            context.logger.info("Collecting data from workers to root worker.")
            diarization_dictionary = {}
            for worker_diarization, _ in gathered:
                diarization_dictionary |= worker_diarization
            errors_dictionary = {}
            for _, worker_errors in gathered:
                errors_dictionary |= worker_errors
            return diarization_dictionary, errors_dictionary

        return wrapper

    return decorator


@open_mpi_handler(worker_inputs=["data_path"], root_worker_inputs={"verbose": True})
def diarize(
    data_path: str | list[str],
    model_name: str = "pyannote/speaker-diarization-3.0",
    access_token: str = None,
    device: str = None,
    speakers_labels: list[str] = None,
    speaker_prefix: str = "speaker_",
    separate_by_channels: bool = False,
    minimum_speakers: int = None,
    maximum_speakers: int = None,
    verbose: bool = False,
) -> tuple[dict[str, list[tuple[float, float, str]]], dict[str, str]]:
    """
    Perform speech diarization on given audio files using pyannote-audio (https://github.com/pyannote/pyannote-audio).
    The end result is a dictionary with the file names as keys and their diarization as value. A diarization is a list
    of tuples: (start, end, speaker_label).

    To use the `pyannote.audio` models you must pass a Huggingface token and get access to the required models. The
    token can be passed in one of the following options:

    * Use the parameter `access_token`.
    * Set an environment variable named "HUGGING_FACE_HUB_TOKEN".
    * If using MLRun, you can pass it as a secret named "HUGGING_FACE_HUB_TOKEN".

    To get access to the models on Huggingface, visit their page. For example, to use the default diarization model set
    in this function ("pyannote/speaker-diarization-3.0"), you need access for these two models:

    * https://huggingface.co/pyannote/segmentation-3.0
    * https://huggingface.co/pyannote/speaker-diarization-3.0

    Note: To control the recognized speakers in the diarization output you can choose one of the following methods:

    * For a known speakers amount, you may set speaker labels via the `speakers_labels` parameter that will be used in
      the order of speaking in the audio (first person speaking be the first label in the list). In addition, you can do
      diarization per channel (setting the parameter `separate_by_channels` to True). Each label will be assigned to a
      specific channel by order (first label to channel 0, second label to channel 1 and so on). Notice, this will
      increase runtime.
    * For unknown speakers amount, you can set the `speaker_prefix` parameter to add a prefix for each speaker number.
      You can also help the diarization by setting the speakers range via the `minimum_speakers` and
      `maximum_speakers` parameters.

    :param data_path:            A directory of the audio files, a single file or a list of files to diarize.
    :param model_name:           One of the official diarization model names (referred as diarization pipelines) of
                                 `pyannote.audio` Huggingface page. Default: "pyannote/speaker-diarization-3.0".
    :param access_token:         An access token to pass for using the `pyannote.audio` models. If not provided, it
                                 will be looking for the environment variable "HUGGING_FACE_HUB_TOKEN". If MLRun is
                                 available, it will look for a secret "HUGGING_FACE_HUB_TOKEN".
    :param device:               Device to load the model. Can be one of {"cuda", "cpu"}. Default will prefer "cuda" if
                                 available.
    :param speakers_labels:      Labels to use for the recognized speakers. Default: the `speaker_prefix` followed by
                                 the speaker number ("speaker_0", "speaker_1", ...).
    :param separate_by_channels: If each speaker is speaking in a separate channel, you can diarize each channel and
                                 combine the result into a single diarization. Each label set in the `speakers_labels`
                                 parameter will be assigned to a specific channel by order.
    :param speaker_prefix:       A prefix to add for the speakers labels. This parameter is ignored if
                                 `speakers_labels` is not None. Default: "speaker_".
    :param minimum_speakers:     Set the minimum expected amount of speakers to be in the audio files. This parameter is
                                 ignored if `speakers_labels` is not None.
    :param maximum_speakers:     Set the maximum expected amount of speakers to be in the audio files. This parameter is
                                 ignored if `speakers_labels` is not None.
    :param verbose:              Whether to present logs of a progress bar and errors. Default: False (when distributed
                                 with OpenMPI, the root worker is set to True by the `open_mpi_handler` decorator).

    :returns: A tuple of:

              * Speech diarization dictionary.
              * A dictionary of errored files that were not diarized (file name to error message).

    :raises ValueError: If no Huggingface access token could be found.
    """
    global _LOGGER

    # Get the input audio files to diarize:
    if isinstance(data_path, str):
        audio_files = _get_audio_files(data_path=pathlib.Path(data_path).absolute())
    else:
        # Should be a list of files. Each entry is normalized to a `pathlib.Path` because the loop below reads
        # `audio_file.name` (a list of plain strings used to raise an `AttributeError` from the error handler):
        audio_files = [pathlib.Path(audio_file) for audio_file in data_path]

    # Get the Huggingface access token:
    access_token = _get_access_token(parameter=access_token)
    if access_token is None:
        raise ValueError(
            "A Huggingface access token must be provided to use `pyannote.audio` models. Access token can be passed "
            "via one of the following options:\n"
            "* Use the parameter `access_token`.\n"
            "* Set an environment variable named 'HUGGING_FACE_HUB_TOKEN'.\n"
            "* If using MLRun, you can pass it as a secret named 'HUGGING_FACE_HUB_TOKEN'."
        )

    # Load the diarization pipeline:
    pipeline = pyannote.audio.Pipeline.from_pretrained(
        checkpoint_path=model_name, use_auth_token=access_token
    )

    # Set the device (the pipeline is moved only when a non-cpu device is requested):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    if device != "cpu":
        pipeline.to(torch.device(device))

    # Prepare the diarizations and errors dictionaries to be returned:
    diarizations = {}
    errors = {}

    # Prepare the diarization keyword arguments (explicit labels fix the speakers amount, otherwise the optional
    # range is passed):
    diarize_kwargs = {}
    if speakers_labels:
        diarize_kwargs["num_speakers"] = len(speakers_labels)
    else:
        if minimum_speakers:
            diarize_kwargs["min_speakers"] = minimum_speakers
        if maximum_speakers:
            diarize_kwargs["max_speakers"] = maximum_speakers

    # Go over the audio files and diarize:
    for audio_file in tqdm(
        audio_files, desc="Diarizing", unit="file", disable=not verbose
    ):
        try:
            # Load audio file:
            audio, sample_rate = torchaudio.load(uri=audio_file, channels_first=True)
            # Get the diarization:
            diarizations[audio_file.name] = _diarize(
                audio=audio,
                sample_rate=sample_rate,
                pipeline=pipeline,
                speakers_labels=speakers_labels,
                separate_by_channels=separate_by_channels,
                speaker_prefix=speaker_prefix,
                diarize_kwargs=diarize_kwargs,
            )
        except Exception as exception:
            # Best effort per file - note the exception as error in the dictionary and move on to the next file:
            if verbose:
                _LOGGER.warning(f"Error in file: '{audio_file.name}'")
            errors[str(audio_file.name)] = str(exception)
            continue

    # Log the success ratio and return:
    if verbose:
        _LOGGER.info(f"Done ({len(diarizations)}/{len(audio_files)})\n")
    return diarizations, errors


def _get_audio_files(
    data_path: pathlib.Path,
) -> list[pathlib.Path]:
    """
    Collect the files to process out of the given path.

    :param data_path: A path of a directory or of a single file.

    :returns: The entries directly inside the directory that match the "*.*" pattern, or a one item list holding
              the file path itself.

    :raises ValueError: If the path is neither a directory nor a file.
    """
    # A directory yields all of its direct entries that have a dot in their name:
    if data_path.is_dir():
        return list(data_path.glob("*.*"))

    # A single file is wrapped in a list:
    if data_path.is_file():
        return [data_path]

    # Anything else cannot be handled:
    raise ValueError(
        f"Unrecognized data path. The parameter `data_path` must be either a directory path or a file path. "
        f"Given: {str(data_path)} "
    )


def _get_access_token(parameter: str) -> str:
    """
    Resolve the Huggingface access token by priority: the given parameter, the environment variable
    "HUGGING_FACE_HUB_TOKEN" and lastly an MLRun secret of the same name.

    :param parameter: The access token as passed by the user (may be None or empty).

    :returns: The first non-empty token found, or None when MLRun is not installed and no token was found before.
    """
    # The parameter wins, the environment variable is only read when the parameter is empty:
    token = parameter or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if token:
        return token

    # Lastly, try to look in the secrets set in MLRun (if MLRun is installed):
    try:
        import mlrun

        return mlrun.get_or_create_ctx(name="mlrun").get_secret(
            key="HUGGING_FACE_HUB_TOKEN"
        )
    except ModuleNotFoundError:
        return None


def _diarize(
    audio: torch.Tensor,
    sample_rate: int,
    pipeline: pyannote.audio.Pipeline,
    speakers_labels: list[str],
    separate_by_channels: bool,
    speaker_prefix: str,
    diarize_kwargs: dict,
) -> list[tuple[float, float, str]]:
    """
    Diarize a single loaded audio, either as a whole or channel by channel.

    :param audio:                The loaded audio tensor. Its first dimension is iterated as the channels.
    :param sample_rate:          The sample rate of the audio, passed to the pipeline along with the waveform.
    :param pipeline:             The diarization pipeline to infer through.
    :param speakers_labels:      Labels to use for the recognized speakers (by speaker index, or by channel index when
                                 separating by channels). May be None or empty.
    :param separate_by_channels: Whether to diarize each channel on its own (one speaker per channel is expected) and
                                 merge the results.
    :param speaker_prefix:       A prefix for the generated labels, used when `speakers_labels` is not given.
    :param diarize_kwargs:       Keyword arguments to pass to the pipeline call. Ignored when separating by channels
                                 (each channel is inferred with `num_speakers=1`).

    :returns: A list of (start, end, speaker_label) tuples.

    :raises ValueError: If fewer labels than recognized speakers (or than channels) were given.
    """
    # If there is no need for separation by channels, we diarize and return:
    if not separate_by_channels:
        # Diarize:
        diarization: pyannote.core.Annotation = pipeline(
            file={"waveform": audio, "sample_rate": sample_rate}, **diarize_kwargs
        )
        # Verify speakers labels (should not fail here as we set `num_speakers=len(speakers_labels)` when inferring
        # through the pipeline):
        if speakers_labels:
            given_speakers = len(speakers_labels)
            found_speakers = len(set(diarization.labels()))
            if given_speakers < found_speakers:
                raise ValueError(
                    f"Not enough `speakers_labels` were given. Got {given_speakers} labels but the diarization "
                    f"recognized {found_speakers} speakers."
                )
        # Return as a diarization list - a sorted list of tuples of start time, end time and a label (the default label
        # returned is "SPEAKER_i" so we take only the index out of it):
        return [
            (
                segment.start,
                segment.end,
                speakers_labels[int(label.split("_")[1])]
                if speakers_labels
                else f"{speaker_prefix}{int(label.split('_')[1])}",
            )
            for segment, track, label in diarization.itertracks(yield_label=True)
        ]

    # Resolve one label per channel. Previously a missing `speakers_labels` raised a `TypeError` and too few labels
    # raised a bare `IndexError`, so every file ended up in the errors dictionary with an unclear message:
    channels_amount = audio.shape[0]
    if speakers_labels:
        if len(speakers_labels) < channels_amount:
            raise ValueError(
                f"Not enough `speakers_labels` were given. Got {len(speakers_labels)} labels but the audio has "
                f"{channels_amount} channels."
            )
        channels_labels = list(speakers_labels[:channels_amount])
    else:
        # No labels were given, so each channel is labeled by the prefix and its channel index:
        channels_labels = [
            f"{speaker_prefix}{channel}" for channel in range(channels_amount)
        ]

    # Separate to channels and diarize (we expect only one speaker per channel):
    channel_diarizations = [
        _diarize(
            audio=audio[channel].unsqueeze(
                0
            ),  # Take channel and add a channel dimension to it.
            sample_rate=sample_rate,
            pipeline=pipeline,
            speakers_labels=[channel_label],  # Take the channel's label only.
            separate_by_channels=False,
            speaker_prefix=speaker_prefix,
            diarize_kwargs={"num_speakers": 1},  # Set to one speaker.
        )
        for channel, channel_label in enumerate(channels_labels)
    ]

    # Merge the channel diarizations into a single sorted list:
    return list(heapq.merge(*channel_diarizations))
 requirements: - pyannote.audio - pyannote.core - torchaudio - tqdm + code_origin: '' base_image: mlrun/mlrun-gpu - origin_filename: '' - functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import heapq
import logging
import operator
import os
import pathlib
from functools import reduce, wraps
from typing import Any, Dict, List, Tuple, Union

import pandas as pd
import pyannote.audio
import pyannote.core
import torch
import torchaudio
from tqdm import tqdm

# Get the global logger:
_LOGGER = logging.getLogger()


def _check_mlrun_and_open_mpi() -> Tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intracomm"]:
    """
    Look for MLRun and, when the run is labeled as an 'mpijob', for `mpi4py`.

    :returns: A tuple of the MLRun context and the MPI world communicator. The communicator is None when the run is
              not an 'mpijob' and both are None when MLRun is not installed.

    :raises ModuleNotFoundError: If the run is an 'mpijob' but `mpi4py` is not installed.
    """
    running_as_mpijob = False
    try:
        import mlrun

        ctx = mlrun.get_or_create_ctx(name="mlrun")
        running_as_mpijob = ctx.labels.get("kind", "job") == "mpijob"

        # A regular run needs only the context:
        if not running_as_mpijob:
            return ctx, None

        # An 'mpijob' run must have `mpi4py` available:
        try:
            from mpi4py import MPI
        except ModuleNotFoundError as missing_mpi4py:
            ctx.logger.error(
                "To distribute the function using MLRun's 'mpijob' you need to have `mpi4py` package in your "
                "interpreter. Please run `pip install mpi4py` and make sure you have open-mpi."
            )
            raise missing_mpi4py
        return ctx, MPI.COMM_WORLD
    except ModuleNotFoundError as missing_module:
        # A missing MLRun is tolerated, a missing `mpi4py` in an 'mpijob' is not:
        if running_as_mpijob:
            raise missing_module
    return None, None


def open_mpi_handler(
    worker_inputs: List[str], root_worker_inputs: Dict[str, Any] = None
):
    """
    A decorator factory for distributing a handler between OpenMPI workers. Each worker receives its own chunk of
    the `worker_inputs` arguments and the root worker (rank #0) collects and merges the outputs of all workers.
    When MLRun or OpenMPI are not available, or there is only one worker, the handler is returned as is.

    :param worker_inputs:      Names of the keyword arguments to split between the workers.
    :param root_worker_inputs: Keyword arguments to override in the root worker only.

    :returns: The decorator to wrap the handler with.
    """
    global _LOGGER

    # Check for MLRun and OpenMPI availability:
    context, comm = _check_mlrun_and_open_mpi()

    # When MLRun is available, its logger replaces the global one:
    if context:
        _LOGGER = context.logger

    def decorator(handler):
        # Nothing to distribute without a communicator or with a single worker:
        if comm is None or comm.Get_size() == 1:
            return handler

        @wraps(handler)
        def wrapper(**kwargs):
            # Get the open mpi environment properties:
            world_size = comm.Get_size()
            worker_rank = comm.Get_rank()
            is_last_worker = worker_rank + 1 >= world_size

            # Give the correct chunk of the workers inputs:
            for input_name in worker_inputs:
                value = kwargs[input_name]
                if value is None:
                    continue
                # A string is a path, it is expanded to the list of files it holds:
                if isinstance(value, str):
                    value = _get_audio_files(
                        data_path=pathlib.Path(value).absolute()
                    )
                total = len(value)
                if total < world_size:
                    raise ValueError(
                        f"Cannot split the input '{input_name}' of length {total} to {world_size} workers. "
                        f"Please reduce the amount of workers for this input."
                    )
                # Even chunks for all workers, the last worker takes the remainder as well:
                per_worker = total // world_size
                start = worker_rank * per_worker
                stop = total if is_last_worker else start + per_worker
                context.logger.info(
                    f"Rank #{worker_rank}: Processing input chunk of '{input_name}' "
                    f"from index {start} to {stop}."
                )
                if isinstance(value, list):
                    value = value[start:stop]
                elif isinstance(value, pd.DataFrame):
                    value = value.iloc[start:stop:, :]
                kwargs[input_name] = value

            # Set the root worker only arguments:
            if worker_rank == 0 and root_worker_inputs:
                kwargs.update(root_worker_inputs)

            # Run the worker and send its output to the root rank (rank #0):
            gathered = comm.gather(handler(**kwargs), root=0)
            if worker_rank != 0:
                return None

            # Join the outputs (each output is a tuple of a diarizations dictionary and an errors dictionary):
            context.logger.info("Collecting data from workers to root worker.")
            diarization_dictionary = reduce(
                operator.ior, [worker_output[0] for worker_output in gathered], {}
            )
            errors_dictionary = reduce(
                operator.ior, [worker_output[1] for worker_output in gathered], {}
            )
            return diarization_dictionary, errors_dictionary

        return wrapper

    return decorator


@open_mpi_handler(worker_inputs=["data_path"], root_worker_inputs={"verbose": True})
def diarize(
    data_path: Union[str, List[str]],
    model_name: str = "pyannote/speaker-diarization-3.0",
    access_token: str = None,
    device: str = None,
    speakers_labels: List[str] = None,
    speaker_prefix: str = "speaker_",
    separate_by_channels: bool = False,
    minimum_speakers: int = None,
    maximum_speakers: int = None,
    verbose: bool = False,
) -> Tuple[Dict[str, List[Tuple[float, float, str]]], Dict[str, str]]:
    """
    Perform speech diarization on given audio files using pyannote-audio (https://github.com/pyannote/pyannote-audio).
    The end result is a dictionary with the file names as keys and their diarization as value. A diarization is a list
    of tuples: (start, end, speaker_label).

    To use the `pyannote.audio` models you must pass a Huggingface token and get access to the required models. The
    token can be passed in one of the following options:

    * Use the parameter `access_token`.
    * Set an environment variable named "HUGGING_FACE_HUB_TOKEN".
    * If using MLRun, you can pass it as a secret named "HUGGING_FACE_HUB_TOKEN".

    To get access to the models on Huggingface, visit their page. For example, to use the default diarization model set
    in this function ("pyannote/speaker-diarization-3.0"), you need access for these two models:

    * https://huggingface.co/pyannote/segmentation-3.0
    * https://huggingface.co/pyannote/speaker-diarization-3.0

    Note: To control the recognized speakers in the diarization output you can choose one of the following methods:

    * For a known speakers amount, you may set speaker labels via the `speakers_labels` parameter that will be used in
      the order of speaking in the audio (first person speaking be the first label in the list). In addition, you can do
      diarization per channel (setting the parameter `separate_by_channels` to True). Each label will be assigned to a
      specific channel by order (first label to channel 0, second label to channel 1 and so on). Notice, this will
      increase runtime.
    * For unknown speakers amount, you can set the `speaker_prefix` parameter to add a prefix for each speaker number.
      You can also help the diarization by setting the speakers range via the `minimum_speakers` and
      `maximum_speakers` parameters.

    :param data_path:            A directory of the audio files, a single file or a list of files to diarize.
    :param model_name:           One of the official diarization model names (referred as diarization pipelines) of
                                 `pyannote.audio` Huggingface page. Default: "pyannote/speaker-diarization-3.0".
    :param access_token:         An access token to pass for using the `pyannote.audio` models. If not provided, it
                                 will be looking for the environment variable "HUGGING_FACE_HUB_TOKEN". If MLRun is
                                 available, it will look for a secret "HUGGING_FACE_HUB_TOKEN".
    :param device:               Device to load the model. Can be one of {"cuda", "cpu"}. Default will prefer "cuda" if
                                 available.
    :param speakers_labels:      Labels to use for the recognized speakers. Default: the `speaker_prefix` followed by
                                 the speaker number ("speaker_0", "speaker_1", ...).
    :param separate_by_channels: If each speaker is speaking in a separate channel, you can diarize each channel and
                                 combine the result into a single diarization. Each label set in the `speakers_labels`
                                 parameter will be assigned to a specific channel by order.
    :param speaker_prefix:       A prefix to add for the speakers labels. This parameter is ignored if
                                 `speakers_labels` is not None. Default: "speaker_".
    :param minimum_speakers:     Set the minimum expected amount of speakers to be in the audio files. This parameter is
                                 ignored if `speakers_labels` is not None.
    :param maximum_speakers:     Set the maximum expected amount of speakers to be in the audio files. This parameter is
                                 ignored if `speakers_labels` is not None.
    :param verbose:              Whether to present logs of a progress bar and errors. Default: False (when distributed
                                 with OpenMPI, the root worker is set to True by the `open_mpi_handler` decorator).

    :returns: A tuple of:

              * Speech diarization dictionary.
              * A dictionary of errored files that were not diarized (file name to error message).

    :raises ValueError: If no Huggingface access token could be found.
    """
    global _LOGGER

    # Get the input audio files to diarize:
    if isinstance(data_path, str):
        audio_files = _get_audio_files(data_path=pathlib.Path(data_path).absolute())
    else:
        # Should be a list of files. Each entry is normalized to a `pathlib.Path` because the loop below reads
        # `audio_file.name` (a list of plain strings used to raise an `AttributeError` from the error handler):
        audio_files = [pathlib.Path(audio_file) for audio_file in data_path]

    # Get the Huggingface access token:
    access_token = _get_access_token(parameter=access_token)
    if access_token is None:
        raise ValueError(
            "A Huggingface access token must be provided to use `pyannote.audio` models. Access token can be passed "
            "via one of the following options:\n"
            "* Use the parameter `access_token`.\n"
            "* Set an environment variable named 'HUGGING_FACE_HUB_TOKEN'.\n"
            "* If using MLRun, you can pass it as a secret named 'HUGGING_FACE_HUB_TOKEN'."
        )

    # Load the diarization pipeline:
    pipeline = pyannote.audio.Pipeline.from_pretrained(
        checkpoint_path=model_name, use_auth_token=access_token
    )

    # Set the device (the pipeline is moved only when a non-cpu device is requested):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    if device != "cpu":
        pipeline.to(torch.device(device))

    # Prepare the diarizations and errors dictionaries to be returned:
    diarizations = {}
    errors = {}

    # Prepare the diarization keyword arguments (explicit labels fix the speakers amount, otherwise the optional
    # range is passed):
    diarize_kwargs = {}
    if speakers_labels:
        diarize_kwargs["num_speakers"] = len(speakers_labels)
    else:
        if minimum_speakers:
            diarize_kwargs["min_speakers"] = minimum_speakers
        if maximum_speakers:
            diarize_kwargs["max_speakers"] = maximum_speakers

    # Go over the audio files and diarize:
    for audio_file in tqdm(
        audio_files, desc="Diarizing", unit="file", disable=not verbose
    ):
        try:
            # Load audio file:
            audio, sample_rate = torchaudio.load(uri=audio_file, channels_first=True)
            # Get the diarization:
            diarizations[audio_file.name] = _diarize(
                audio=audio,
                sample_rate=sample_rate,
                pipeline=pipeline,
                speakers_labels=speakers_labels,
                separate_by_channels=separate_by_channels,
                speaker_prefix=speaker_prefix,
                diarize_kwargs=diarize_kwargs,
            )
        except Exception as exception:
            # Best effort per file - note the exception as error in the dictionary and move on to the next file:
            if verbose:
                _LOGGER.warning(f"Error in file: '{audio_file.name}'")
            errors[str(audio_file.name)] = str(exception)
            continue

    # Log the success ratio and return:
    if verbose:
        _LOGGER.info(f"Done ({len(diarizations)}/{len(audio_files)})\n")
    return diarizations, errors


def _get_audio_files(
    data_path: pathlib.Path,
) -> List[pathlib.Path]:
    """
    Collect the files to process out of the given path.

    :param data_path: A path of a directory or of a single file.

    :returns: The entries directly inside the directory that match the "*.*" pattern, or a one item list holding
              the file path itself.

    :raises ValueError: If the path is neither a directory nor a file.
    """
    # A directory yields all of its direct entries that have a dot in their name:
    if data_path.is_dir():
        return list(data_path.glob("*.*"))

    # A single file is wrapped in a list:
    if data_path.is_file():
        return [data_path]

    # Anything else cannot be handled:
    raise ValueError(
        f"Unrecognized data path. The parameter `data_path` must be either a directory path or a file path. "
        f"Given: {str(data_path)} "
    )


def _get_access_token(parameter: str) -> str:
    """
    Resolve the Huggingface access token by priority: the given parameter, the environment variable
    "HUGGING_FACE_HUB_TOKEN" and lastly an MLRun secret of the same name.

    :param parameter: The access token as passed by the user (may be None or empty).

    :returns: The first non-empty token found, or None when MLRun is not installed and no token was found before.
    """
    # The parameter wins, the environment variable is only read when the parameter is empty:
    token = parameter or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if token:
        return token

    # Lastly, try to look in the secrets set in MLRun (if MLRun is installed):
    try:
        import mlrun

        return mlrun.get_or_create_ctx(name="mlrun").get_secret(
            key="HUGGING_FACE_HUB_TOKEN"
        )
    except ModuleNotFoundError:
        return None


def _diarize(
    audio: torch.Tensor,
    sample_rate: int,
    pipeline: pyannote.audio.Pipeline,
    speakers_labels: List[str],
    separate_by_channels: bool,
    speaker_prefix: str,
    diarize_kwargs: dict,
) -> List[Tuple[float, float, str]]:
    """
    Diarize a single loaded audio, either as a whole or channel by channel.

    :param audio:                The loaded audio tensor. Its first dimension is iterated as the channels.
    :param sample_rate:          The sample rate of the audio, passed to the pipeline along with the waveform.
    :param pipeline:             The diarization pipeline to infer through.
    :param speakers_labels:      Labels to use for the recognized speakers (by speaker index, or by channel index when
                                 separating by channels). May be None or empty.
    :param separate_by_channels: Whether to diarize each channel on its own (one speaker per channel is expected) and
                                 merge the results.
    :param speaker_prefix:       A prefix for the generated labels, used when `speakers_labels` is not given.
    :param diarize_kwargs:       Keyword arguments to pass to the pipeline call. Ignored when separating by channels
                                 (each channel is inferred with `num_speakers=1`).

    :returns: A list of (start, end, speaker_label) tuples.

    :raises ValueError: If fewer labels than recognized speakers (or than channels) were given.
    """
    # If there is no need for separation by channels, we diarize and return:
    if not separate_by_channels:
        # Diarize:
        diarization: pyannote.core.Annotation = pipeline(
            file={"waveform": audio, "sample_rate": sample_rate}, **diarize_kwargs
        )
        # Verify speakers labels (should not fail here as we set `num_speakers=len(speakers_labels)` when inferring
        # through the pipeline):
        if speakers_labels:
            given_speakers = len(speakers_labels)
            found_speakers = len(set(diarization.labels()))
            if given_speakers < found_speakers:
                raise ValueError(
                    f"Not enough `speakers_labels` were given. Got {given_speakers} labels but the diarization "
                    f"recognized {found_speakers} speakers."
                )
        # Return as a diarization list - a sorted list of tuples of start time, end time and a label (the default label
        # returned is "SPEAKER_i" so we take only the index out of it):
        return [
            (
                segment.start,
                segment.end,
                speakers_labels[int(label.split("_")[1])]
                if speakers_labels
                else f"{speaker_prefix}{int(label.split('_')[1])}",
            )
            for segment, track, label in diarization.itertracks(yield_label=True)
        ]

    # Resolve one label per channel. Previously a missing `speakers_labels` raised a `TypeError` and too few labels
    # raised a bare `IndexError`, so every file ended up in the errors dictionary with an unclear message:
    channels_amount = audio.shape[0]
    if speakers_labels:
        if len(speakers_labels) < channels_amount:
            raise ValueError(
                f"Not enough `speakers_labels` were given. Got {len(speakers_labels)} labels but the audio has "
                f"{channels_amount} channels."
            )
        channels_labels = list(speakers_labels[:channels_amount])
    else:
        # No labels were given, so each channel is labeled by the prefix and its channel index:
        channels_labels = [
            f"{speaker_prefix}{channel}" for channel in range(channels_amount)
        ]

    # Separate to channels and diarize (we expect only one speaker per channel):
    channel_diarizations = [
        _diarize(
            audio=audio[channel].unsqueeze(
                0
            ),  # Take channel and add a channel dimension to it.
            sample_rate=sample_rate,
            pipeline=pipeline,
            speakers_labels=[channel_label],  # Take the channel's label only.
            separate_by_channels=False,
            speaker_prefix=speaker_prefix,
            diarize_kwargs={"num_speakers": 1},  # Set to one speaker.
        )
        for channel, channel_label in enumerate(channels_labels)
    ]

    # Merge the channel diarizations into a single sorted list:
    return list(heapq.merge(*channel_diarizations))
 - default_handler: diarize + filename: pyannote_audio.py entry_points: open_mpi_handler: - name: open_mpi_handler - has_varargs: false - lineno: 61 parameters: - name: worker_inputs - type: List[str] + type: list[str] - name: root_worker_inputs - type: Dict[str, Any] + type: dict[str, Any] default: null - has_kwargs: false + name: open_mpi_handler doc: '' - decorator: - name: decorator + has_kwargs: false has_varargs: false - lineno: 73 + lineno: 61 + decorator: parameters: - name: handler - has_kwargs: false + name: decorator doc: '' + has_kwargs: false + has_varargs: false + lineno: 73 wrapper: name: wrapper + doc: '' + has_kwargs: true has_varargs: false lineno: 78 - has_kwargs: true - doc: '' diarize: - name: diarize - has_varargs: false - lineno: 139 outputs: - doc: 'A tuple of:' - type: Tuple[Dict[str, List[Tuple[float, float, str]]], Dict[str, str]] + type: tuple[dict[str, list[tuple[float, float, str]]], dict[str, str]] parameters: - name: data_path - type: Union[str, List[str]] doc: A directory of the audio files, a single file or a list of files to transcribe. - name: model_name type: str @@ -69,7 +71,7 @@ spec: prefer "cuda" if available. default: null - name: speakers_labels - type: List[str] + type: list[str] doc: 'Labels to use for the recognized speakers. Default: numeric labels (0, 1, ...).' default: null @@ -99,7 +101,7 @@ spec: type: bool doc: 'Whether to present logs of a progress bar and errors. Default: True.' default: false - has_kwargs: false + name: diarize doc: "Perform speech diarization on given audio files using pyannote-audio (https://github.com/pyannote/pyannote-audio).\n\ The end result is a dictionary with the file names as keys and their diarization\ \ as value. 
A diarization is a list\nof tuples: (start, end, speaker_label).\n\ @@ -123,11 +125,9 @@ spec: \ you can set the `speaker_prefix` parameter to add a prefix for each speaker\ \ number.\n You can also help the diarization by setting the speakers range\ \ via the `speakers_amount_range` parameter." + has_kwargs: false + has_varargs: false + lineno: 139 + command: '' description: pyannote's speech diarization of audio files -metadata: - name: pyannote-audio - tag: '' - categories: - - deep-learning - - audio -verbose: false + default_handler: diarize diff --git a/functions/src/pyannote_audio/pyannote_audio.py b/functions/src/pyannote_audio/pyannote_audio.py index 6271da6ae..bb097a750 100644 --- a/functions/src/pyannote_audio/pyannote_audio.py +++ b/functions/src/pyannote_audio/pyannote_audio.py @@ -18,7 +18,7 @@ import os import pathlib from functools import reduce, wraps -from typing import Any, Dict, List, Tuple, Union +from typing import Any import pandas as pd import pyannote.audio @@ -31,7 +31,7 @@ _LOGGER = logging.getLogger() -def _check_mlrun_and_open_mpi() -> Tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intracomm"]: +def _check_mlrun_and_open_mpi() -> tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intracomm"]: is_mpi = False try: import mlrun @@ -59,7 +59,7 @@ def _check_mlrun_and_open_mpi() -> Tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intrac def open_mpi_handler( - worker_inputs: List[str], root_worker_inputs: Dict[str, Any] = None + worker_inputs: list[str], root_worker_inputs: dict[str, Any] = None ): global _LOGGER @@ -137,17 +137,17 @@ def wrapper(**kwargs): @open_mpi_handler(worker_inputs=["data_path"], root_worker_inputs={"verbose": True}) def diarize( - data_path: Union[str, List[str]], + data_path: str | list[str], model_name: str = "pyannote/speaker-diarization-3.0", access_token: str = None, device: str = None, - speakers_labels: List[str] = None, + speakers_labels: list[str] = None, speaker_prefix: str = "speaker_", separate_by_channels: bool = False, 
minimum_speakers: int = None, maximum_speakers: int = None, verbose: bool = False, -) -> Tuple[Dict[str, List[Tuple[float, float, str]]], Dict[str, str]]: +) -> tuple[dict[str, list[tuple[float, float, str]]], dict[str, str]]: """ Perform speech diarization on given audio files using pyannote-audio (https://github.com/pyannote/pyannote-audio). The end result is a dictionary with the file names as keys and their diarization as value. A diarization is a list @@ -277,7 +277,7 @@ def diarize( def _get_audio_files( data_path: pathlib.Path, -) -> List[pathlib.Path]: +) -> list[pathlib.Path]: # Check if the path is of a directory or a file: if data_path.is_dir(): # Get all files inside the directory: @@ -320,11 +320,11 @@ def _diarize( audio: torch.Tensor, sample_rate: int, pipeline: pyannote.audio.Pipeline, - speakers_labels: List[str], + speakers_labels: list[str], separate_by_channels: bool, speaker_prefix: str, diarize_kwargs: dict, -) -> List[Tuple[float, float, str]]: +) -> list[tuple[float, float, str]]: # If there is no need for separation by channels, we diarize and return: if not separate_by_channels: # Diarize: diff --git a/functions/src/question_answering/function.yaml b/functions/src/question_answering/function.yaml index 21f741aa8..afcf893a2 100644 --- a/functions/src/question_answering/function.yaml +++ b/functions/src/question_answering/function.yaml @@ -1,83 +1,56 @@ metadata: - name: question-answering tag: '' + name: question-answering categories: - genai verbose: false kind: job spec: - command: '' - default_handler: answer_questions + image: '' + disable_auto_mount: false build: origin_filename: '' - base_image: mlrun/mlrun + functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import logging
import operator
import pathlib
from collections import Counter
from functools import reduce, wraps
from typing import Any

import pandas as pd
import transformers
from tqdm import tqdm

# Get the global logger. It starts as the root logger and is replaced with the MLRun context's logger by
# `_check_mlrun_and_open_mpi` when MLRun can be imported:
_LOGGER = logging.getLogger()


def _check_mlrun_and_open_mpi() -> tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intracomm"]:
    global _LOGGER

    is_mpi = False
    try:
        import mlrun

        context = mlrun.get_or_create_ctx(name="mlrun")
        _LOGGER = context.logger
        is_mpi = context.labels.get("kind", "job") == "mpijob"

        if is_mpi:
            try:
                from mpi4py import MPI

                return context, MPI.COMM_WORLD
            except ModuleNotFoundError as mpi4py_not_found:
                context.logger.error(
                    "To distribute the function using MLRun's 'mpijob' you need to have `mpi4py` package in your "
                    "interpreter. Please run `pip install mpi4py` and make sure you have open-mpi."
                )
                raise mpi4py_not_found
    except ModuleNotFoundError as module_not_found:
        if is_mpi:
            raise module_not_found
    return None, None


def open_mpi_handler(
    worker_inputs: list[str], root_worker_inputs: dict[str, Any] = None
):
    """
    Create a decorator for distributing a handler between OpenMPI workers (ranks).

    If no communicator is available or there is only one worker, the decorator returns the handler untouched.
    Otherwise, the handler is wrapped so each rank runs it on its own chunk of every input named in `worker_inputs`
    and the root rank (#0) gathers and joins the outputs. The wrapped handler accepts keyword arguments only and is
    expected to return a tuple of a dataframe and an errors dictionary.

    :param worker_inputs:      Names of the handler's keyword arguments to split between the workers. A string value
                               is treated as a path and replaced with the files collected by `_get_text_files`. Lists
                               and dataframes are sliced by index.
    :param root_worker_inputs: Keyword arguments to set (overriding given values) only on the root rank.

    :returns: The decorator to apply on the handler.
    """
    # Check for MLRun and OpenMPI availability:
    context, comm = _check_mlrun_and_open_mpi()

    def decorator(handler):
        # Nothing to distribute without a communicator or with a single worker:
        if comm is None or comm.Get_size() == 1:
            return handler

        @wraps(handler)
        def wrapper(**kwargs):
            # Get the open mpi environment properties:
            size = comm.Get_size()
            rank = comm.Get_rank()
            is_last_rank = rank + 1 >= size

            # Replace each worker input with this rank's chunk of it:
            for worker_input in worker_inputs:
                input_argument = kwargs[worker_input]
                if input_argument is None:
                    continue
                if isinstance(input_argument, str):
                    input_argument = _get_text_files(
                        data_path=pathlib.Path(input_argument).absolute()
                    )
                if len(input_argument) < size:
                    raise ValueError(
                        f"Cannot split the input '{worker_input}' of length {len(input_argument)} to {size} workers. "
                        f"Please reduce the amount of workers for this input."
                    )
                even_chunk_size = len(input_argument) // size
                chunk_start = rank * even_chunk_size
                # The last rank also takes the remainder of an uneven split:
                chunk_end = (
                    len(input_argument)
                    if is_last_rank
                    else chunk_start + even_chunk_size
                )
                context.logger.info(
                    f"Rank #{rank}: Processing input chunk of '{worker_input}' "
                    f"from index {chunk_start} to {chunk_end}."
                )
                if isinstance(input_argument, list):
                    input_argument = input_argument[chunk_start:chunk_end]
                elif isinstance(input_argument, pd.DataFrame):
                    input_argument = input_argument.iloc[chunk_start:chunk_end, :]
                kwargs[worker_input] = input_argument

            # Set the root worker only arguments:
            if rank == 0 and root_worker_inputs:
                kwargs.update(root_worker_inputs)

            # Run the worker and send its output to the root rank (rank #0):
            gathered = comm.gather(handler(**kwargs), root=0)
            if rank != 0:
                return None

            # Join the outputs on the root rank:
            context.logger.info("Collecting data from workers to root worker.")
            dataframe = pd.concat(
                objs=[worker_dataframe for worker_dataframe, _ in gathered], axis=0
            )
            errors_dictionary = {}
            for _, worker_errors in gathered:
                errors_dictionary |= worker_errors
            return dataframe, errors_dictionary

        return wrapper

    return decorator


@open_mpi_handler(worker_inputs=["data_path"], root_worker_inputs={"verbose": True})
def answer_questions(
    data_path: str | list[str],
    model_name: str,
    questions: list[str] | list[list[str]],
    device_map: str | dict = None,
    model_kwargs: dict = None,
    auto_gptq_exllama_max_input_length: int = None,
    tokenizer_name: str = None,
    tokenizer_kwargs: dict = None,
    text_wrapper: str | list[str] = "",
    questions_wrapper: str | list[str] = "",
    generation_config: dict | list[dict] = None,
    questions_config: dict | list[dict] = None,
    batch_size: int = 1,
    questions_columns: list[str] = None,
    verbose: bool = False,
) -> tuple[pd.DataFrame, dict]:
    """
    Answer questions with a context to the given text files contents by a pretrained LLM model. Each text file will have
    the following prompt built:

    start of `text_wrapper`
    <text file content>
    end of `text_wrapper`

    start of `questions_wrapper`
    1. <questions[0]>
    2. <questions[1]>
    ...
    n. <questions[n-1]>
    end of `questions_wrapper`

    :param data_path:                          A path to a directory of text files or a path to a text file to ask
                                               questions about. A list of text files paths is accepted as well.
    :param model_name:                         The pre-trained model name from the huggingface hub to use for asking
                                               questions.
    :param questions:                          The questions to ask.
                                               Either a list of questions (a single question group), or a list of
                                               lists of questions to ask per text file, divided by question groups.
                                               The groups can be determined by size (in order to avoid large inputs to
                                               the llm) or by questioning method (regular or poll like questioning).
    :param device_map:                         A map to use for loading the model on multiple devices.
    :param model_kwargs:                       Keyword arguments to pass for loading the model using HuggingFace's
                                               `transformers.AutoModelForCausalLM.from_pretrained` function.
    :param auto_gptq_exllama_max_input_length: For AutoGPTQ models to set and extend the model's input buffer size.
    :param tokenizer_name:                     The tokenizer name from the huggingface hub to use. If not given, the
                                               model name will be used.
    :param tokenizer_kwargs:                   Keyword arguments to pass for loading the tokenizer using HuggingFace's
                                               `transformers.AutoTokenizer.from_pretrained` function.
    :param text_wrapper:                       A wrapper for the file's text. Will be added at the start of the prompt.
                                               Must have a placeholder ('{}') for the text of the file. A single value
                                               is used for all question groups.
    :param questions_wrapper:                  A wrapper for the questions received. Will be added after the text
                                               wrapper in the prompt template. Must have a placeholder ('{}') for the
                                               questions. A single value is used for all question groups.
    :param generation_config:                  HuggingFace's `GenerationConfig` keyword arguments to pass to the
                                               `generate` method. A single dictionary is used for all question groups.
    :param questions_config:                   A dictionary or list of dictionaries containing specific ways to answer
                                               questions (using a poll for example), each dictionary in the list is for
                                               corresponding question group and determines the question asking method
                                               for said group. The "type" key selects the handler ("default" if not
                                               given) and the rest of the keys are passed to the handler's initializer.
                                               The given dictionaries are not modified.
    :param batch_size:                         Batch size for inference.
    :param questions_columns:                  Columns to use for the dataframe returned.
    :param verbose:                            Whether to present logs of a progress bar and errors. Default: False
                                               (the OpenMPI decorator sets it to True on the root worker).


    :returns: A tuple of:

              * A dataframe dataset of the questions answers.
              * A dictionary of errored files that were not inferred or were not answered properly.

    :raises ValueError: If no questions were given, if a per-group argument length does not match the amount of
                        question groups, if the questions columns amount does not match the questions amount or if a
                        questions configuration has an unsupported type.
    """
    # Set configs to empty dict if not given:
    if generation_config is None:
        generation_config = {}
    if questions_config is None:
        questions_config = {}

    # Get the input text files to question:
    if verbose:
        _LOGGER.info("Collecting text files.")
    if isinstance(data_path, (str, pathlib.Path)):
        text_files = _get_text_files(data_path=pathlib.Path(data_path).absolute())
    else:
        # A list of files was given (a chunk from the OpenMPI wrapper for example). Plain string paths are turned
        # into `pathlib.Path` objects as the files' `name` attribute is used later on:
        text_files = [pathlib.Path(text_file) for text_file in data_path]
    if verbose:
        _LOGGER.info(f"Collected {len(text_files)} text files.")

    # Get the prompt template:
    if verbose:
        _LOGGER.info("Creating prompt template.")

    # Organize the questions as a list of question groups (list of lists). A flat list of questions is a single
    # group, so it is wrapped as a whole rather than matched by length against the groups amount:
    if isinstance(questions, str):
        questions = [questions]
    if not questions:
        raise ValueError("Please include at least one question.")
    if isinstance(questions[0], str):
        questions = [questions]
    number_of_question_groups = len(questions)

    # Organize prompt parts at proper length
    text_wrapper = _to_group_list(
        argument_value=text_wrapper,
        argument_name="text_wrapper",
        length=number_of_question_groups,
    )
    questions_wrapper = _to_group_list(
        argument_value=questions_wrapper,
        argument_name="questions_wrapper",
        length=number_of_question_groups,
    )

    # Build a prompt template per question group:
    prompt_template = [
        _get_prompt_template(
            text_wrapper=group_text_wrapper,
            questions_wrapper=group_questions_wrapper,
            questions=group_questions,
        )
        for group_text_wrapper, group_questions_wrapper, group_questions in zip(
            text_wrapper, questions_wrapper, questions
        )
    ]
    if verbose:
        _LOGGER.info(f"Prompt template created:\n\n{prompt_template}\n")

    # Get the total amount of questions:
    questions_amount = sum(len(group_questions) for group_questions in questions)

    # Get the questions columns:
    questions_columns = questions_columns or [
        f"q{i}" for i in range(1, questions_amount + 1)
    ]

    # Check if we have the correct amount of questions columns:
    if len(questions_columns) != questions_amount:
        raise ValueError(
            f"The provided questions columns length ({len(questions_columns)}) "
            f"does not match the questions amount ({questions_amount})"
        )

    # Load the generation config:
    if verbose:
        _LOGGER.info("Loading generation configuration.")
    generation_config = [
        transformers.GenerationConfig(**(cfg or {}))
        for cfg in _to_group_list(
            argument_value=generation_config,
            argument_name="generation_config",
            length=number_of_question_groups,
        )
    ]
    if verbose:
        _LOGGER.info(f"Generation configuration loaded: {generation_config}")

    # Create a list of question handlers according to given configs. It is done before loading the model so a wrong
    # configuration fails fast:
    questions_config = _to_group_list(
        argument_value=questions_config,
        argument_name="questions_config",
        length=number_of_question_groups,
    )
    handlers = []
    for cfg in questions_config:
        # Work on a copy: a single configuration is shared between all groups and belongs to the caller, so popping
        # the type out of the original would leave only the first group with the requested handler:
        cfg = dict(cfg or {})
        question_type = cfg.pop("type", "default")
        handler_class = QUESTION_MAPPING.get(question_type)
        if handler_class is None:
            raise ValueError(
                f"Unsupported question type '{question_type}'. Supported types: {list(QUESTION_MAPPING)}"
            )
        handlers.append(handler_class(**cfg))

    # Load the model and tokenizer into a pipeline object:
    if verbose:
        _LOGGER.info(f"Loading model '{model_name}'.")
    generation_pipeline = _get_generation_pipeline(
        model_name=model_name,
        device_map=device_map,
        tokenizer_name=tokenizer_name or model_name,
        model_kwargs=model_kwargs or {},
        tokenizer_kwargs=tokenizer_kwargs or {},
        auto_gptq_exllama_max_input_length=auto_gptq_exllama_max_input_length,
        batch_size=batch_size,
    )
    if verbose:
        _LOGGER.info("Model loaded.")

    # Prepare the successes dataframe and errors dictionary to be returned:
    successes = []
    errors = {}

    # Split the files into batches (slicing takes care of a smaller last batch):
    file_batches = [
        text_files[i : i + batch_size] for i in range(0, len(text_files), batch_size)
    ]

    # Go over the batches of text files and question them:
    for file_batch in tqdm(
        file_batches,
        desc="Generating answers",
        unit=f"file (batch of {batch_size})",
        disable=not verbose,
    ):
        try:
            # One answers list per file in the batch. The last batch may hold less than `batch_size` files:
            total_answers = [[] for _ in file_batch]

            # Go over all question group per batch of documents
            for question_group, handler in enumerate(handlers):
                # Read batch (read the text from the text files):
                batched_input = _read_file_batch(
                    file_batch=file_batch,
                    prompt_template=prompt_template[question_group],
                )

                # Answer the questions with each question handler:
                batched_answers = handler.answer(
                    questions_amount=len(questions[question_group]),
                    batched_input=batched_input,
                    generation_pipeline=generation_pipeline,
                    generation_config=generation_config[question_group],
                )

                # Put the answers in the correct place in the total answers list according to the place in the
                # batch. A mismatch between the files and the answers amounts is noted as an error of the batch:
                for file_answers, group_answers in zip(
                    total_answers, batched_answers, strict=True
                ):
                    file_answers.extend(group_answers)

            # Collect the answers and attach the file name:
            successes.extend(
                [file.name, *answers]
                for file, answers in zip(file_batch, total_answers)
            )
        except Exception as exception:
            # Best effort: note the exception as error in the dictionary and move on to the next batch:
            batch_file_names = ", ".join([file.name for file in file_batch])
            if verbose:
                _LOGGER.warning(
                    f"Error in batch '{batch_file_names}': {str(exception)}"
                )
            errors[batch_file_names] = str(exception)

    # Create a data frame of answers by files:
    successes = pd.DataFrame(
        successes,
        columns=["text_file", *questions_columns],
    )

    # Print the head of the produced dataframe and return:
    if verbose:
        _LOGGER.info(
            f"Done ({successes.shape[0]}/{len(text_files)})\n"
            f"Answers summary:\n"
            f"{successes.head()}"
        )
    return successes, errors


def _get_text_files(
    data_path: pathlib.Path,
) -> list[pathlib.Path]:
    # Check if the path is of a directory or a file:
    if data_path.is_dir():
        # Get all files inside the directory:
        text_files = list(data_path.glob("*.*"))
    elif data_path.is_file():
        text_files = [data_path]
    else:
        raise ValueError(
            f"Unrecognized data path. The parameter `data_path` must be either a directory path or a file path. "
            f"Given: {str(data_path)} "
        )

    return text_files


def _get_prompt_template(
    text_wrapper: str,
    questions_wrapper: str,
    questions: list[str],
) -> str:
    # Validate and build the text wrapper:
    text_wrapper = text_wrapper or ("Given the following text:\n-----\n{}\n-----")
    if text_wrapper.count("{}") != 1:
        raise ValueError(
            "The `text_wrapper` must include one placeholder '{}' for the text of the file to be asked about."
        )

    # Validate and build the question wrapper:
    questions_wrapper = questions_wrapper or "Answer the questions:\n{}"
    if questions_wrapper.count("{}") != 1:
        raise ValueError(
            "The `questions_wrapper` must include one placeholder '{}' for the list of questions."
        )

    # Validate and parse the questions:
    if len(questions) == 0:
        raise ValueError("Please include at least one question.")
    questions = "\n".join(
        [f"{i}. {question}" for i, question in enumerate(questions, 1)]
    )

    # Construct the template:
    return f"{text_wrapper}\n{questions_wrapper.format(questions)}\n"


def _get_generation_pipeline(
    model_name: str,
    device_map: str | dict,
    tokenizer_name: str,
    model_kwargs: dict,
    tokenizer_kwargs: dict,
    auto_gptq_exllama_max_input_length: int = None,
    batch_size: int = 1,
):
    """
    Load the model and the tokenizer and wrap them in a HuggingFace "text-generation" pipeline.

    :param model_name:                         The model name to pass to `AutoModelForCausalLM.from_pretrained`.
    :param device_map:                         The device map to load the model with.
    :param tokenizer_name:                     The tokenizer name to pass to `AutoTokenizer.from_pretrained`.
    :param model_kwargs:                       Additional keyword arguments for loading the model.
    :param tokenizer_kwargs:                   Additional keyword arguments for loading the tokenizer.
    :param auto_gptq_exllama_max_input_length: If given (non-zero), the model is passed through AutoGPTQ's
                                               `exllama_set_max_input_length` with this value.
    :param batch_size:                         The batch size of the pipeline.

    :returns: The generation pipeline. Its tokenizer's pad token id is set to the model's EOS token id.
    """
    causal_lm = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, device_map=device_map, **model_kwargs
    )

    # `auto_gptq` is imported only on demand, as it is needed just for setting the exllama max input length (this
    # changes the model's context size):
    if auto_gptq_exllama_max_input_length:
        from auto_gptq import exllama_set_max_input_length

        causal_lm = exllama_set_max_input_length(
            model=causal_lm, max_input_length=auto_gptq_exllama_max_input_length
        )

    # Initialize the generation pipeline with the loaded tokenizer:
    generation_pipeline = transformers.pipeline(
        task="text-generation",
        model=causal_lm,
        tokenizer=transformers.AutoTokenizer.from_pretrained(
            tokenizer_name, **tokenizer_kwargs
        ),
        batch_size=batch_size,
    )
    generation_pipeline.tokenizer.pad_token_id = causal_lm.config.eos_token_id
    return generation_pipeline


def _read_file_batch(
    file_batch: list[pathlib.Path],
    prompt_template: str,
) -> list[str]:
    batch = []

    # Go over all files and read in usable format
    for file in file_batch:
        with open(file, encoding="utf-8") as fp:
            batch.append(prompt_template.format(fp.read()))
    return batch


def _to_group_list(argument_value: list, argument_name: str, length: int):
    # Check if is list, turn to list if not
    argument_value = (
        argument_value if isinstance(argument_value, list) else [argument_value]
    )
    list_len = len(argument_value)

    # If not a list, or is a list of len 1 we duplicate for correct length
    # If list in wrong length throw an error
    if list_len != length:
        if list_len == 1:
            return argument_value * length
        raise ValueError(
            f"The argument value of '{argument_name}' is not equal to the length of the given questions - {length}"
        )
    return argument_value


class QuestionHandler:
    """
    The handler of the default question type: regular question answering without any special handling. It also
    serves as the base class for the handlers of all other question types.
    """

    class ConfigKeys:
        pass

    def __init__(self):
        pass

    @staticmethod
    def _get_answers(generated_text: str, questions_amount: int) -> list[str]:
        """
        Split the generated text into the numbered answers ("1. ...", "2. ...", and so on).

        :param generated_text:   The text generated by the model.
        :param questions_amount: The amount of answers to look for.

        :returns: The answers, stripped from surrounding whitespace. The last answer takes the rest of the text.

        :raises ValueError: If one of the answers' numbers is missing from the generated text.
        """
        # Clear answer start (part before numbers):
        # TODO find better way to verify, for list of questions this is redundant for example
        if "1." not in generated_text:
            raise ValueError(
                f"Answer 1. is missing from the generated text: '{generated_text}'"
            )
        remaining_text = generated_text.split("1.", 1)[1]

        # Each answer ends where the number of the next answer appears:
        answers = []
        for next_number in range(2, questions_amount + 1):
            next_marker = f"{next_number}."
            if next_marker not in remaining_text:
                raise ValueError(
                    f"Answer {next_number}. is missing from the generated text: '{generated_text}'"
                )
            current_answer, remaining_text = remaining_text.split(next_marker, 1)
            answers.append(current_answer.strip())

        # The last answer is the rest of the text:
        if questions_amount > 0:
            answers.append(remaining_text.strip())

        return answers

    def _infer_questions(
        self,
        questions_amount: int,
        batched_input: list[str],
        generation_pipeline: transformers.Pipeline,
        generation_config: transformers.GenerationConfig,
    ) -> list[list[str]]:
        """
        Run the batch of prompts through the generation pipeline and parse the answers out of each generated text.

        :param questions_amount:    The amount of answers to extract from each generated text.
        :param batched_input:       The prompts to infer.
        :param generation_pipeline: The pipeline to infer through.
        :param generation_config:   The generation configuration to pass to the pipeline.

        :returns: A list of answers per prompt.
        """
        # Infer through the llm:
        batched_output = generation_pipeline(
            batched_input,
            generation_config=generation_config,
            eos_token_id=generation_pipeline.tokenizer.eos_token_id,
            return_full_text=False,
            num_return_sequences=1,
        )

        # A single sequence is returned per prompt, so its generated text is at index 0:
        return [
            self._get_answers(
                generated_text=sequences[0]["generated_text"],
                questions_amount=questions_amount,
            )
            for sequences in batched_output
        ]

    def answer(
        self,
        questions_amount: int,
        batched_input: list[str],
        generation_pipeline: transformers.Pipeline,
        generation_config: transformers.GenerationConfig,
    ) -> list[list[str]]:
        """
        Answer questions with a context to the given text files contents by a pretrained LLM model in given pipeline.
        """
        batched_answers = self._infer_questions(
            questions_amount=questions_amount,
            batched_input=batched_input,
            generation_pipeline=generation_pipeline,
            generation_config=generation_config,
        )
        return batched_answers


class PollQuestionHandler(QuestionHandler):
    """
    A class for handling questions answering for poll type questions.
    These type of question are answered by asking the same question multiple times
    and choosing the most common answer or the average answer.
    """

    class ConfigKeys:
        """
        Static class to hold all the possible poll question configurations options keys
        """

        #: The number of times to ask the same question.
        POLL_COUNT = "poll_count"

        #: The strategy to use for choosing the answer from the poll.
        POLL_STRATEGY = "poll_strategy"

    class Strategy(enum.Enum):
        """
        The strategies for choosing the answer from a poll. Each member's value is the name of the static method
        implementing it.
        """

        #: The most common answer strategy.
        MOST_COMMON = "most_common"

        #: The average answer strategy.
        AVERAGE = "average"

        @staticmethod
        def most_common(answers):
            """
            Calculate the most common answer for a given list of answers.
            """
            return Counter(answers).most_common(1)[0][0]

        @staticmethod
        def average(answers):
            """
            Calculate the average answer for a given list of answers, rounded to the closest integer.

            :raises ValueError: If no answers were given or if one of the answers is a non numeric text.
            """
            if not answers:
                raise ValueError("Cannot perform poll with average answer strategy without answers.")

            # The answers parsed out of the generated text are strings, so numeric texts are converted to numbers:
            try:
                numeric_values = [
                    float(answer) if isinstance(answer, str) else answer
                    for answer in answers
                ]
            except ValueError as conversion_error:
                raise ValueError(
                    "Cannot perform poll with average answer strategy of non numeric values,"
                    " please change the question to give numeric data, or choose 'most_common' as strategy."
                ) from conversion_error
            avg = sum(numeric_values) / len(numeric_values)

            # Round to the closest integer and return corresponding value
            return round(avg)

        def do(self, answers):
            """
            Perform the strategy.
            """
            return getattr(self, self.value)(answers)

    def __init__(self, poll_count: int = 5, poll_strategy: str = "most_common"):
        """
        :param poll_count:    The number of times to ask the same question. Must be at least 1.
        :param poll_strategy: The strategy to use for choosing the answer from the poll: "most_common" or "average".

        :raises ValueError: If the poll count is smaller than 1 or the poll strategy is unknown.
        """
        super().__init__()
        # Without at least one vote there is nothing to choose an answer from:
        if poll_count < 1:
            raise ValueError(f"The `poll_count` must be at least 1. Given: {poll_count}")
        self.poll_count = poll_count
        self.poll_strategy = self.Strategy(poll_strategy)

    def answer(
        self,
        questions_amount: int,
        batched_input: list[str],
        generation_pipeline: transformers.Pipeline,
        generation_config: transformers.GenerationConfig,
    ) -> list[list[str]]:
        """
        Answer questions with a context to the given text files contents by a pretrained LLM model in given pipeline.
        """
        return self._answer_poll_questions(
            questions_amount=questions_amount,
            batched_input=batched_input,
            generation_pipeline=generation_pipeline,
            generation_config=generation_config,
        )

    def _answer_poll_questions(
        self,
        questions_amount: int,
        batched_input: list[str],
        generation_pipeline: transformers.Pipeline,
        generation_config: transformers.GenerationConfig,
    ) -> list[list[str]]:
        """
        Ask the batch `poll_count` times and choose each question's answer according to the poll strategy.

        :returns: A list of chosen answers per prompt.
        """
        # Run the poll: each vote holds the answers of the entire batch:
        votes = [
            self._infer_questions(
                questions_amount=questions_amount,
                batched_input=batched_input,
                generation_pipeline=generation_pipeline,
                generation_config=generation_config,
            )
            for _ in range(self.poll_count)
        ]

        # Collect the answers according to the poll strategy
        # Average strategy works for numeric values only
        answers = []
        for batch_index in range(len(votes[0])):
            answers.append(
                [
                    self.poll_strategy.do(
                        [vote[batch_index][question_index] for vote in votes]
                    )
                    for question_index in range(questions_amount)
                ]
            )
        return answers


# Holds the names of the available question types (the keys of `QUESTION_MAPPING`):
class QuestionTypes:
    #: Regular question answering, handled by `QuestionHandler`.
    DEFAULT = "default"
    #: Poll like question answering, handled by `PollQuestionHandler`.
    POLL = "poll"


# Maps question types to their handlers. `answer_questions` reads the "type" key of each questions configuration
# (falling back to "default") to choose the handler class to initialize:
QUESTION_MAPPING = {
    QuestionTypes.DEFAULT: QuestionHandler,
    QuestionTypes.POLL: PollQuestionHandler,
}
 requirements: - transformers - torch - tqdm code_origin: '' - functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import logging
import operator
import pathlib
from collections import Counter
from functools import reduce, wraps
from typing import Any, Dict, List, Tuple, Union

import pandas as pd
import transformers
from tqdm import tqdm

# Get the global logger. It starts as the root logger and is replaced with the MLRun context's logger by
# `_check_mlrun_and_open_mpi` when MLRun can be imported:
_LOGGER = logging.getLogger()


def _check_mlrun_and_open_mpi() -> Tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intracomm"]:
    global _LOGGER

    is_mpi = False
    try:
        import mlrun

        context = mlrun.get_or_create_ctx(name="mlrun")
        _LOGGER = context.logger
        is_mpi = context.labels.get("kind", "job") == "mpijob"

        if is_mpi:
            try:
                from mpi4py import MPI

                return context, MPI.COMM_WORLD
            except ModuleNotFoundError as mpi4py_not_found:
                context.logger.error(
                    "To distribute the function using MLRun's 'mpijob' you need to have `mpi4py` package in your "
                    "interpreter. Please run `pip install mpi4py` and make sure you have open-mpi."
                )
                raise mpi4py_not_found
    except ModuleNotFoundError as module_not_found:
        if is_mpi:
            raise module_not_found
    return None, None


def open_mpi_handler(
    worker_inputs: List[str], root_worker_inputs: Dict[str, Any] = None
):
    """
    Create a decorator for distributing a handler between OpenMPI workers (ranks).

    If no communicator is available or there is only one worker, the decorator returns the handler untouched.
    Otherwise, the handler is wrapped so each rank runs it on its own chunk of every input named in `worker_inputs`
    and the root rank (#0) gathers and joins the outputs. The wrapped handler accepts keyword arguments only and is
    expected to return a tuple of a dataframe and an errors dictionary.

    :param worker_inputs:      Names of the handler's keyword arguments to split between the workers. A string value
                               is treated as a path and replaced with the files collected by `_get_text_files`. Lists
                               and dataframes are sliced by index.
    :param root_worker_inputs: Keyword arguments to set (overriding given values) only on the root rank.

    :returns: The decorator to apply on the handler.
    """
    # Check for MLRun and OpenMPI availability:
    context, comm = _check_mlrun_and_open_mpi()

    def decorator(handler):
        # Nothing to distribute without a communicator or with a single worker:
        if comm is None or comm.Get_size() == 1:
            return handler

        @wraps(handler)
        def wrapper(**kwargs):
            # Get the open mpi environment properties:
            size = comm.Get_size()
            rank = comm.Get_rank()
            is_last_rank = rank + 1 >= size

            # Replace each worker input with this rank's chunk of it:
            for worker_input in worker_inputs:
                input_argument = kwargs[worker_input]
                if input_argument is None:
                    continue
                if isinstance(input_argument, str):
                    input_argument = _get_text_files(
                        data_path=pathlib.Path(input_argument).absolute()
                    )
                if len(input_argument) < size:
                    raise ValueError(
                        f"Cannot split the input '{worker_input}' of length {len(input_argument)} to {size} workers. "
                        f"Please reduce the amount of workers for this input."
                    )
                even_chunk_size = len(input_argument) // size
                chunk_start = rank * even_chunk_size
                # The last rank also takes the remainder of an uneven split:
                chunk_end = (
                    len(input_argument)
                    if is_last_rank
                    else chunk_start + even_chunk_size
                )
                context.logger.info(
                    f"Rank #{rank}: Processing input chunk of '{worker_input}' "
                    f"from index {chunk_start} to {chunk_end}."
                )
                if isinstance(input_argument, list):
                    input_argument = input_argument[chunk_start:chunk_end]
                elif isinstance(input_argument, pd.DataFrame):
                    input_argument = input_argument.iloc[chunk_start:chunk_end, :]
                kwargs[worker_input] = input_argument

            # Set the root worker only arguments:
            if rank == 0 and root_worker_inputs:
                kwargs.update(root_worker_inputs)

            # Run the worker and send its output to the root rank (rank #0):
            gathered = comm.gather(handler(**kwargs), root=0)
            if rank != 0:
                return None

            # Join the outputs on the root rank:
            context.logger.info("Collecting data from workers to root worker.")
            dataframe = pd.concat(
                objs=[worker_dataframe for worker_dataframe, _ in gathered], axis=0
            )
            errors_dictionary = {}
            for _, worker_errors in gathered:
                errors_dictionary |= worker_errors
            return dataframe, errors_dictionary

        return wrapper

    return decorator


@open_mpi_handler(worker_inputs=["data_path"], root_worker_inputs={"verbose": True})
def answer_questions(
    data_path: Union[str, List[str]],
    model_name: str,
    questions: Union[List[str], List[List[str]]],
    device_map: Union[str, dict] = None,
    model_kwargs: dict = None,
    auto_gptq_exllama_max_input_length: int = None,
    tokenizer_name: str = None,
    tokenizer_kwargs: dict = None,
    text_wrapper: Union[str, List[str]] = "",
    questions_wrapper: Union[str, List[str]] = "",
    generation_config: Union[Dict, List[Dict]] = None,
    questions_config: Union[Dict, List[Dict]] = None,
    batch_size: int = 1,
    questions_columns: List[str] = None,
    verbose: bool = False,
) -> Tuple[pd.DataFrame, dict]:
    """
    Answer questions with a context to the given text files contents by a pretrained LLM model. Each text file will have
    the following prompt built:

    start of `text_wrapper`
    <text file content>
    end of `text_wrapper`

    start of `questions_wrapper`
    1. <questions[0]>
    2. <questions[1]>
    ...
    n. <questions[n-1]>
    end of `questions_wrapper`

    :param data_path:                          A path to a directory of text files or a path to a text file to ask
                                               questions about. A list of text file paths is accepted as well.
    :param model_name:                         The pre-trained model name from the huggingface hub to use for asking
                                               questions.
    :param questions:                          The questions to ask.
                                               Either a flat list of questions (a single question group), or a list
                                               of lists of questions to ask per text file, divided by question
                                               groups. The groups can be determined by size (in order to avoid large
                                               inputs to the llm) or by questioning method (regular or poll like
                                               questioning).
    :param device_map:                         A map to use for loading the model on multiple devices.
    :param model_kwargs:                       Keyword arguments to pass for loading the model using HuggingFace's
                                               `transformers.AutoModelForCausalLM.from_pretrained` function.
    :param auto_gptq_exllama_max_input_length: For AutoGPTQ models to set and extend the model's input buffer size.
    :param tokenizer_name:                     The tokenizer name from the huggingface hub to use. If not given, the
                                               model name will be used.
    :param tokenizer_kwargs:                   Keyword arguments to pass for loading the tokenizer using HuggingFace's
                                               `transformers.AutoTokenizer.from_pretrained` function.
    :param text_wrapper:                       A wrapper for the file's text. Will be added at the start of the prompt.
                                               Must have a placeholder ('{}') for the text of the file.
    :param questions_wrapper:                  A wrapper for the questions received. Will be added after the text
                                               wrapper in the prompt template. Must have a placeholder ('{}') for the
                                               questions.
    :param generation_config:                  HuggingFace's `GenerationConfig` keyword arguments to pass to the
                                               `generate` method.
    :param questions_config:                   A dictionary or list of dictionaries containing specific ways to answer
                                               questions (using a poll for example), each dictionary in the list is for
                                               corresponding question group and determines the question asking method
                                               for said group. The given dictionaries are not modified.
    :param batch_size:                         Batch size for inference.
    :param questions_columns:                  Columns to use for the dataframe returned.
    :param verbose:                            Whether to present logs of a progress bar and errors. Default: False
                                               (the `open_mpi_handler` decorator is configured to pass True to the
                                               root worker).


    :returns: A tuple of:

              * A dataframe dataset of the questions answers.
              * A dictionary of errored files that were not inferred or were not answered properly.

    :raises ValueError: If no questions were given, if a per-group argument does not match the amount of question
                        groups, if the questions columns do not match the questions amount or if a questions config
                        holds an unknown question type.
    """
    # Set configs to empty dict if not given:
    if generation_config is None:
        generation_config = {}
    if questions_config is None:
        questions_config = {}

    # Get the input text files to question:
    if verbose:
        _LOGGER.info("Collecting text files.")
    if isinstance(data_path, (str, pathlib.Path)):
        text_files = _get_text_files(data_path=pathlib.Path(data_path).absolute())
    else:
        # A list of file paths was given (for example the chunk the decorator hands to a worker). The file name is
        # read via `.name` later on, so make sure every item is a `pathlib.Path`:
        text_files = [pathlib.Path(text_file) for text_file in data_path]
    if verbose:
        _LOGGER.info(f"Collected {len(text_files)} text files.")

    # Get the prompt template:
    if verbose:
        _LOGGER.info("Creating prompt template.")

    # Organize the questions as a list of question groups (a list of lists). A single question or a flat list of
    # questions is treated as one group:
    if isinstance(questions, str):
        questions = [questions]
    if not questions:
        raise ValueError("Please include at least one question.")
    if isinstance(questions[0], str):
        questions = [questions]
    number_of_question_groups = len(questions)

    # Organize prompt parts at proper length
    text_wrapper = _to_group_list(
        argument_value=text_wrapper,
        argument_name="text_wrapper",
        length=number_of_question_groups,
    )
    questions_wrapper = _to_group_list(
        argument_value=questions_wrapper,
        argument_name="questions_wrapper",
        length=number_of_question_groups,
    )

    # Build a prompt template per question group:
    prompt_template = [
        _get_prompt_template(
            text_wrapper=group_text_wrapper,
            questions_wrapper=group_questions_wrapper,
            questions=group_questions,
        )
        for group_text_wrapper, group_questions_wrapper, group_questions in zip(
            text_wrapper, questions_wrapper, questions
        )
    ]
    if verbose:
        _LOGGER.info(f"Prompt template created:\n\n{prompt_template}\n")

    # Get the total amount of questions:
    questions_amount = sum(len(group_questions) for group_questions in questions)

    # Get the questions columns:
    questions_columns = questions_columns or [
        f"q{i}" for i in range(1, questions_amount + 1)
    ]

    # Check if we have the correct amount of questions columns:
    if len(questions_columns) != questions_amount:
        raise ValueError(
            f"The provided questions columns length ({len(questions_columns)}) "
            f"does not match the questions amount ({questions_amount})"
        )

    # Load the generation config:
    if verbose:
        _LOGGER.info("Loading generation configuration.")
    generation_config = [
        transformers.GenerationConfig(**(cfg or {}))
        for cfg in _to_group_list(
            argument_value=generation_config,
            argument_name="generation_config",
            length=number_of_question_groups,
        )
    ]
    if verbose:
        _LOGGER.info(f"Generation configuration loaded: {generation_config}")

    # Create a question handler per question group according to the given configs. This is done before loading the
    # model so a wrong configuration fails fast:
    handlers = []
    for cfg in _to_group_list(
        argument_value=questions_config,
        argument_name="questions_config",
        length=number_of_question_groups,
    ):
        # Work on a copy: a single config is shared (the same object) between all groups and belongs to the caller,
        # so popping the type from it directly would lose the type for every group after the first one:
        cfg = dict(cfg or {})
        question_type = cfg.pop("type", QuestionTypes.DEFAULT)
        handler_class = QUESTION_MAPPING.get(question_type)
        if handler_class is None:
            raise ValueError(
                f"Unknown question type '{question_type}'. "
                f"Supported types are: {', '.join(QUESTION_MAPPING)}"
            )
        handlers.append(handler_class(**cfg))

    # Load the model and tokenizer into a pipeline object:
    if verbose:
        _LOGGER.info(f"Loading model '{model_name}'.")
    generation_pipeline = _get_generation_pipeline(
        model_name=model_name,
        device_map=device_map,
        tokenizer_name=tokenizer_name or model_name,
        model_kwargs=model_kwargs or {},
        tokenizer_kwargs=tokenizer_kwargs or {},
        auto_gptq_exllama_max_input_length=auto_gptq_exllama_max_input_length,
        batch_size=batch_size,
    )
    if verbose:
        _LOGGER.info("Model loaded.")

    # Prepare the successes dataframe and errors dictionary to be returned:
    successes = []
    errors = {}

    # Split the files into batches (slicing past the end simply yields a shorter last batch):
    file_batches = [
        text_files[i : i + batch_size] for i in range(0, len(text_files), batch_size)
    ]

    # Go over the batches of text files and question them:
    for file_batch in tqdm(
        file_batches,
        desc="Generating answers",
        unit=f"file (batch of {batch_size})",
        disable=not verbose,
    ):
        try:
            # One answers list per file in the batch. The last batch may hold less than `batch_size` files, so the
            # actual batch is used for sizing:
            total_answers = [[] for _ in file_batch]

            # Go over all question group per batch of documents
            for question_group in range(number_of_question_groups):
                # Read batch (read the text from the text files):
                batched_input = _read_file_batch(
                    file_batch=file_batch,
                    prompt_template=prompt_template[question_group],
                )

                # Answer the questions with each question handler:
                batched_answers = handlers[question_group].answer(
                    questions_amount=len(questions[question_group]),
                    batched_input=batched_input,
                    generation_pipeline=generation_pipeline,
                    generation_config=generation_config[question_group],
                )

                # Put the answers in the correct place in the total answers list according to the place in the batch:
                for file_answers, group_answers in zip(total_answers, batched_answers):
                    file_answers.extend(group_answers)

            # Collect the answers and attach the file name:
            successes.extend(
                [
                    [file.name, *answers]
                    for file, answers in zip(file_batch, total_answers)
                ]
            )
        except Exception as exception:
            # Best effort per batch - note the exception as error in the dictionary and move on to the next batch:
            batch_file_names = ", ".join([file.name for file in file_batch])
            if verbose:
                _LOGGER.warning(
                    f"Error in batch '{batch_file_names}': {str(exception)}"
                )
            errors[batch_file_names] = str(exception)

    # Construct the answers dataframe:
    columns = [
        "text_file",
        *questions_columns,
    ]

    # Create a data frame of answers by files
    successes = pd.DataFrame(
        successes,
        columns=columns,
    )

    # Print the head of the produced dataframe and return:
    if verbose:
        _LOGGER.info(
            f"Done ({successes.shape[0]}/{len(text_files)})\n"
            f"Answers summary:\n"
            f"{successes.head()}"
        )
    return successes, errors


def _get_text_files(
    data_path: pathlib.Path,
) -> List[pathlib.Path]:
    """
    Collect the text files the given path points at.

    :param data_path: A path to a directory or to a single file.

    :returns: For a directory, the entries directly inside it that match the "*.*" glob pattern. For a file, a list
              holding only that file.

    :raises ValueError: If the path is neither an existing directory nor an existing file.
    """
    # Anything that is not a directory must be a single file:
    if not data_path.is_dir():
        if data_path.is_file():
            return [data_path]
        raise ValueError(
            f"Unrecognized data path. The parameter `data_path` must be either a directory path or a file path. "
            f"Given: {str(data_path)} "
        )

    # A directory - take the entries inside it (not recursive):
    return list(data_path.glob("*.*"))


def _get_prompt_template(
    text_wrapper: str,
    questions_wrapper: str,
    questions: List[str],
) -> str:
    """
    Build the prompt template of a single question group. The returned template keeps exactly one '{}' placeholder
    (the one of the text wrapper) to be filled later with the text of a file using `str.format`.

    :param text_wrapper:      A wrapper for the file's text with a single '{}' placeholder. An empty value falls back
                              to a default wrapper.
    :param questions_wrapper: A wrapper for the questions with a single '{}' placeholder. An empty value falls back to
                              a default wrapper.
    :param questions:         The questions to ask. They are numbered starting from 1, one question per line.

    :returns: The prompt template.

    :raises ValueError: If one of the wrappers does not include exactly one '{}' placeholder or if no questions were
                        given.
    """
    # Validate and build the text wrapper:
    text_wrapper = text_wrapper or "Given the following text:\n-----\n{}\n-----"
    if text_wrapper.count("{}") != 1:
        raise ValueError(
            "The `text_wrapper` must include one placeholder '{}' for the text of the file to be asked about."
        )

    # Validate and build the question wrapper:
    questions_wrapper = questions_wrapper or "Answer the questions:\n{}"
    if questions_wrapper.count("{}") != 1:
        raise ValueError(
            "The `questions_wrapper` must include one placeholder '{}' for the list of questions."
        )

    # Validate and parse the questions:
    if not questions:
        raise ValueError("Please include at least one question.")
    numbered_questions = "\n".join(
        f"{i}. {question}" for i, question in enumerate(questions, 1)
    )

    # The returned template is formatted again with the text of each file, so braces that are part of a question
    # must be escaped - otherwise they would be treated as extra placeholders and fail the formatting:
    numbered_questions = numbered_questions.replace("{", "{{").replace("}", "}}")

    # Construct the template:
    return f"{text_wrapper}\n{questions_wrapper.format(numbered_questions)}\n"


def _get_generation_pipeline(
    model_name: str,
    device_map: Union[str, dict],
    tokenizer_name: str,
    model_kwargs: dict,
    tokenizer_kwargs: dict,
    auto_gptq_exllama_max_input_length: int = None,
    batch_size: int = 1,
):
    """
    Load the model and the tokenizer and wrap them in a "text-generation" pipeline.

    :param model_name:                         The model name to load via `AutoModelForCausalLM.from_pretrained`.
    :param device_map:                         The device map to pass when loading the model.
    :param tokenizer_name:                     The tokenizer name to load via `AutoTokenizer.from_pretrained`.
    :param model_kwargs:                       Additional keyword arguments for loading the model.
    :param tokenizer_kwargs:                   Additional keyword arguments for loading the tokenizer.
    :param auto_gptq_exllama_max_input_length: When given (and not zero), the model is passed through
                                               `auto_gptq.exllama_set_max_input_length` with this value.
    :param batch_size:                         The batch size the pipeline is created with.

    :returns: The generation pipeline, with its tokenizer's pad token id set to the model's EOS token id.
    """
    # Start with the model itself:
    llm = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, device_map=device_map, **model_kwargs
    )

    # Extend the exllama input buffer only when asked to (`auto_gptq` is imported lazily, so it is required only in
    # that case):
    if auto_gptq_exllama_max_input_length:
        from auto_gptq import exllama_set_max_input_length

        llm = exllama_set_max_input_length(
            model=llm, max_input_length=auto_gptq_exllama_max_input_length
        )

    # Then the tokenizer:
    text_tokenizer = transformers.AutoTokenizer.from_pretrained(
        tokenizer_name, **tokenizer_kwargs
    )

    # Wrap both in a generation pipeline:
    generation_pipeline = transformers.pipeline(
        task="text-generation",
        model=llm,
        tokenizer=text_tokenizer,
        batch_size=batch_size,
    )

    # Use the model's EOS token for padding:
    generation_pipeline.tokenizer.pad_token_id = llm.config.eos_token_id
    return generation_pipeline


def _read_file_batch(
    file_batch: List[pathlib.Path],
    prompt_template: str,
) -> List[str]:
    """
    Read a batch of text files and turn each one into a prompt.

    :param file_batch:      The files to read (as UTF-8 text).
    :param prompt_template: The template to format with the content of each file.

    :returns: The prompts, in the same order as the given files.
    """
    prompts = []
    for text_file in file_batch:
        # Read the whole file first, then place its content inside the template:
        with open(text_file, "r", encoding="utf-8") as stream:
            content = stream.read()
        prompts.append(prompt_template.format(content))
    return prompts


def _to_group_list(argument_value: list, argument_name: str, length: int):
    """
    Make sure an argument holds one value per question group.

    :param argument_value: The value to organize. A value that is not a list is treated as a list of one item.
    :param argument_name:  The name of the argument, used in the error message.
    :param length:         The expected amount of items (the amount of question groups).

    :returns: The list itself when it is already at the expected length, or a list repeating the single item
              `length` times.

    :raises ValueError: If more than one item was given and the amount of items is not the expected length.
    """
    # Wrap a single value in a list:
    values = argument_value if isinstance(argument_value, list) else [argument_value]

    # Already one item per group:
    if len(values) == length:
        return values

    # A single item is duplicated for all the groups:
    if len(values) == 1:
        return values * length

    raise ValueError(
        f"The argument value of '{argument_name}' is not equal to the length of the given questions - {length}"
    )


class QuestionHandler:
    """
    The default question handler: the questions are asked once and the numbered answers are parsed out of the
    generated text. It also serves as the base class of the other question type handlers.
    """

    class ConfigKeys:
        pass

    def __init__(self):
        pass

    @staticmethod
    def _get_answers(generated_text: str, questions_amount: int) -> List[str]:
        """
        Split a generated text of numbered answers ("1. ... 2. ...") into a list of answers.

        :param generated_text:   The text generated by the model.
        :param questions_amount: The amount of answers to look for.

        :returns: The answers, stripped from surrounding whitespace. The last answer holds the rest of the text.

        :raises ValueError: If one of the expected answer numbers is missing from the text.
        """
        # Drop everything up to (and including) the first answer number:
        # TODO find better way to verify, for list of questions this is redundant for example
        _, first_marker, remainder = generated_text.partition("1.")
        if not first_marker:
            raise ValueError(
                f"Answer 1. is missing from the generated text: '{generated_text}'"
            )

        # Nothing to extract:
        if questions_amount < 1:
            return []

        # Each answer ends where the number of the next answer begins:
        collected = []
        for next_number in range(2, questions_amount + 1):
            current, next_marker, remainder = remainder.partition(f"{next_number}.")
            if not next_marker:
                raise ValueError(
                    f"Answer {next_number}. is missing from the generated text: '{generated_text}'"
                )
            collected.append(current.strip())

        # The last answer is whatever text is left:
        collected.append(remainder.strip())
        return collected

    def _infer_questions(
        self,
        questions_amount: int,
        batched_input: List[str],
        generation_pipeline: transformers.Pipeline,
        generation_config: transformers.GenerationConfig,
    ) -> List[List[str]]:
        """
        Run the batch of prompts through the generation pipeline and parse the answers of each prompt.

        :returns: A list of answers per prompt in the batch.
        """
        # Generate a single sequence per prompt, without the prompt itself in the output:
        raw_outputs = generation_pipeline(
            batched_input,
            generation_config=generation_config,
            eos_token_id=generation_pipeline.tokenizer.eos_token_id,
            return_full_text=False,
            num_return_sequences=1,
        )

        # Parse the answers out of the first generated sequence of each prompt:
        return [
            self._get_answers(
                generated_text=sequences[0]["generated_text"],
                questions_amount=questions_amount,
            )
            for sequences in raw_outputs
        ]

    def answer(
        self,
        questions_amount: int,
        batched_input: List[str],
        generation_pipeline: transformers.Pipeline,
        generation_config: transformers.GenerationConfig,
    ) -> List[List[str]]:
        """
        Answer questions with a context to the given text files contents by a pretrained LLM model in given pipeline.
        """
        answers = self._infer_questions(
            questions_amount=questions_amount,
            batched_input=batched_input,
            generation_pipeline=generation_pipeline,
            generation_config=generation_config,
        )
        return answers


class PollQuestionHandler(QuestionHandler):
    """
    A class for handling questions answering for poll type questions.
    These type of question are answered by asking the same question multiple times
    and choosing the most common answer or the average answer.
    """

    class ConfigKeys:
        """
        Static class to hold all the possible poll question configurations options keys
        """

        #: The number of times to ask the same question.
        POLL_COUNT = "poll_count"

        #: The strategy to use for choosing the answer from the poll.
        POLL_STRATEGY = "poll_strategy"

    class Strategy(enum.Enum):
        """
        The strategies for choosing the answer out of the poll votes. The value of each member is the name of the
        static method implementing it.
        """

        #: The most common answer strategy.
        MOST_COMMON = "most_common"

        #: The average answer strategy.
        AVERAGE = "average"

        @staticmethod
        def most_common(answers):
            """
            Calculate the most common answer for a given list of answers.
            """
            return Counter(answers).most_common(1)[0][0]

        @staticmethod
        def average(answers):
            """
            Calculate the average answer for a given list of answers. The answers may be numbers or numeric text
            (the answers parsed from a generated text are always strings).

            :returns: The average, rounded to the closest integer.

            :raises ValueError: If one of the answers is not numeric.
            """
            try:
                numeric_values = [float(answer) for answer in answers]

                # Round to the closest integer and return corresponding value
                return round(sum(numeric_values) / len(numeric_values))
            except (TypeError, ValueError, OverflowError) as error:
                # Non numeric text, or a value like "nan" / "inf" that cannot be rounded:
                raise ValueError(
                    "Cannot perform poll with average answer strategy of non numeric values,"
                    " please change the question to give numeric data, or choose 'most_common' as strategy."
                ) from error

        def do(self, answers):
            """
            Perform the strategy.
            """
            return getattr(self, self.value)(answers)

    def __init__(self, poll_count: int = 5, poll_strategy: str = "most_common"):
        """
        :param poll_count:    The number of times to ask the same question. Must be at least 1.
        :param poll_strategy: The strategy to use for choosing the answer from the poll, one of the `Strategy` values.

        :raises ValueError: If the poll count is smaller than 1 or the poll strategy is unknown.
        """
        super().__init__()
        # Without a single vote there is nothing to choose an answer from:
        if poll_count < 1:
            raise ValueError(
                f"The poll count must be at least 1, received: {poll_count}"
            )
        self.poll_count = poll_count
        self.poll_strategy = self.Strategy(poll_strategy)

    def answer(
        self,
        questions_amount: int,
        batched_input: List[str],
        generation_pipeline: transformers.Pipeline,
        generation_config: transformers.GenerationConfig,
    ) -> List[List[str]]:
        """
        Answer questions with a context to the given text files contents by a pretrained LLM model in given pipeline.
        """
        return self._answer_poll_questions(
            questions_amount=questions_amount,
            batched_input=batched_input,
            generation_pipeline=generation_pipeline,
            generation_config=generation_config,
        )

    def _answer_poll_questions(
        self,
        questions_amount: int,
        batched_input: List[str],
        generation_pipeline: transformers.Pipeline,
        generation_config: transformers.GenerationConfig,
    ) -> List[List[str]]:
        """
        Ask the questions `poll_count` times and choose the answer of each question by the poll strategy.

        :returns: A list of chosen answers per prompt in the batch.
        """
        # Run the poll - each vote holds the answers of all the prompts in the batch:
        votes = [
            self._infer_questions(
                questions_amount=questions_amount,
                batched_input=batched_input,
                generation_pipeline=generation_pipeline,
                generation_config=generation_config,
            )
            for _ in range(self.poll_count)
        ]

        # Collect the answers according to the poll strategy
        # Average strategy works for numeric values only
        answers = []
        for prompt_votes in zip(*votes):
            # `prompt_votes` holds the answers list each voter gave to the same prompt:
            answers.append(
                [
                    self.poll_strategy.do(
                        [voter_answers[question] for voter_answers in prompt_votes]
                    )
                    for question in range(questions_amount)
                ]
            )
        return answers


# Holds the names of the question handler types
class QuestionTypes:
    """
    The names of the supported question types, used as the keys of `QUESTION_MAPPING`.
    """

    # Mapped to `QuestionHandler`:
    DEFAULT = "default"
    # Mapped to `PollQuestionHandler`:
    POLL = "poll"


# Maps question types to their handlers. `answer_questions` reads the "type" key of each questions config
# (falling back to "default") and instantiates the handler class it finds here:
QUESTION_MAPPING = {
    QuestionTypes.DEFAULT: QuestionHandler,
    QuestionTypes.POLL: PollQuestionHandler,
}
 + base_image: mlrun/mlrun + filename: question_answering.py entry_points: open_mpi_handler: - name: open_mpi_handler - has_varargs: false - doc: '' - lineno: 58 parameters: - name: worker_inputs - type: List[str] + type: list[str] - name: root_worker_inputs - type: Dict[str, Any] + type: dict[str, Any] default: null + name: open_mpi_handler + doc: '' has_kwargs: false - decorator: - name: decorator has_varargs: false - doc: '' - lineno: 66 + lineno: 58 + decorator: parameters: - name: handler + name: decorator + doc: '' has_kwargs: false + has_varargs: false + lineno: 66 wrapper: name: wrapper - has_varargs: false doc: '' - lineno: 71 has_kwargs: true + has_varargs: false + lineno: 71 answer_questions: outputs: - doc: 'A tuple of:' - type: Tuple[pd.DataFrame, dict] - name: answer_questions - has_varargs: false - doc: 'Answer questions with a context to the given text files contents by a - pretrained LLM model. Each text file will have - - the following prompt built: - - - start of `text_wrapper` - - - - end of `text_wrapper` - - - start of `questions_wrapper` - - 1. - - 2. - - ... - - n. - - end of `questions_wrapper`' - lineno: 130 + type: tuple[pd.DataFrame, dict] parameters: - name: data_path - type: Union[str, List[str]] doc: A path to a directory of text files or a path to a text file to ask questions about. - name: model_name @@ -85,13 +58,11 @@ spec: doc: The pre-trained model name from the huggingface hub to use for asking questions. - name: questions - type: Union[List[str], List[List[str]]] doc: The questions to ask. A list of lists of questions to ask per text file, and devided by question groups, the groups can be dtermained by size (in order to avoid large inputs to the llm) or by questioning method (regular or poll like questioning). - name: device_map - type: Union[str, dict] doc: A map to use for loading the model on multiple devices. 
default: null - name: model_kwargs @@ -114,22 +85,18 @@ spec: `transformers.AutoTokenizer.from_pretrained` function. default: null - name: text_wrapper - type: Union[str, List[str]] doc: A wrapper for the file's text. Will be added at the start of the prompt. Must have a placeholder ('{}') for the text of the file. default: '' - name: questions_wrapper - type: Union[str, List[str]] doc: A wrapper for the questions received. Will be added after the text wrapper in the prompt template. Must have a placeholder ('{}') for the questions. default: '' - name: generation_config - type: Union[Dict, List[Dict]] doc: HuggingFace's `GenerationConfig` keyword arguments to pass to the `generate` method. default: null - name: questions_config - type: Union[Dict, List[Dict]] doc: A dictionary or list of dictionaries containing specific ways to answer questions (using a poll for example), each dictionary in the list is for corresponding question group and determines the question asking method for @@ -140,58 +107,85 @@ spec: doc: Batch size for inference. default: 1 - name: questions_columns - type: List[str] + type: list[str] doc: Columns to use for the dataframe returned. default: null - name: verbose type: bool doc: 'Whether to present logs of a progress bar and errors. Default: True.' default: false + name: answer_questions + doc: 'Answer questions with a context to the given text files contents by a + pretrained LLM model. Each text file will have + + the following prompt built: + + + start of `text_wrapper` + + + + end of `text_wrapper` + + + start of `questions_wrapper` + + 1. + + 2. + + ... + + n. + + end of `questions_wrapper`' has_kwargs: false + has_varargs: false + lineno: 130 answer: outputs: - - type: List[List[str]] - name: answer - has_varargs: false - doc: Answer questions with a context to the given text files contents by a pretrained - LLM model in given pipeline. 
- lineno: 674 + - type: list[list[str]] parameters: - name: self - name: questions_amount type: int - name: batched_input - type: List[str] + type: list[str] - name: generation_pipeline type: Pipeline - name: generation_config type: GenerationConfig + name: answer + doc: Answer questions with a context to the given text files contents by a pretrained + LLM model in given pipeline. has_kwargs: false - most_common: - name: most_common has_varargs: false - doc: Calculate the most common answer for a given list of answers. - lineno: 637 + lineno: 665 + most_common: parameters: - name: answers + name: most_common + doc: Calculate the most common answer for a given list of answers. has_kwargs: false - average: - name: average has_varargs: false - doc: Calculate the average answer for a given list of answers. - lineno: 646 + lineno: 629 + average: parameters: - name: answers + name: average + doc: Calculate the average answer for a given list of answers. has_kwargs: false - do: - name: do has_varargs: false - doc: Perform the strategy. - lineno: 662 + lineno: 638 + do: parameters: - name: self - name: answers + name: do + doc: Perform the strategy. 
has_kwargs: false - image: '' + has_varargs: false + lineno: 654 + command: '' description: GenAI approach of question answering on a given data - disable_auto_mount: false + default_handler: answer_questions diff --git a/functions/src/question_answering/question_answering.py b/functions/src/question_answering/question_answering.py index 2e4e96d03..0ad4bb015 100644 --- a/functions/src/question_answering/question_answering.py +++ b/functions/src/question_answering/question_answering.py @@ -17,7 +17,7 @@ import pathlib from collections import Counter from functools import reduce, wraps -from typing import Any, Dict, List, Tuple, Union +from typing import Any import pandas as pd import transformers @@ -27,7 +27,7 @@ _LOGGER = logging.getLogger() -def _check_mlrun_and_open_mpi() -> Tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intracomm"]: +def _check_mlrun_and_open_mpi() -> tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intracomm"]: global _LOGGER is_mpi = False @@ -56,7 +56,7 @@ def _check_mlrun_and_open_mpi() -> Tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intrac def open_mpi_handler( - worker_inputs: List[str], root_worker_inputs: Dict[str, Any] = None + worker_inputs: list[str], root_worker_inputs: dict[str, Any] = None ): global _LOGGER @@ -128,22 +128,22 @@ def wrapper(**kwargs): @open_mpi_handler(worker_inputs=["data_path"], root_worker_inputs={"verbose": True}) def answer_questions( - data_path: Union[str, List[str]], + data_path: str | list[str], model_name: str, - questions: Union[List[str], List[List[str]]], - device_map: Union[str, dict] = None, + questions: list[str] | list[list[str]], + device_map: str | dict = None, model_kwargs: dict = None, auto_gptq_exllama_max_input_length: int = None, tokenizer_name: str = None, tokenizer_kwargs: dict = None, - text_wrapper: Union[str, List[str]] = "", - questions_wrapper: Union[str, List[str]] = "", - generation_config: Union[Dict, List[Dict]] = None, - questions_config: Union[Dict, List[Dict]] = None, + text_wrapper: str | list[str] 
= "", + questions_wrapper: str | list[str] = "", + generation_config: dict | list[dict] = None, + questions_config: dict | list[dict] = None, batch_size: int = 1, - questions_columns: List[str] = None, + questions_columns: list[str] = None, verbose: bool = False, -) -> Tuple[pd.DataFrame, dict]: +) -> tuple[pd.DataFrame, dict]: """ Answer questions with a context to the given text files contents by a pretrained LLM model. Each text file will have the following prompt built: @@ -396,11 +396,9 @@ def answer_questions( def _get_text_files( data_path: pathlib.Path, -) -> List[pathlib.Path]: - +) -> list[pathlib.Path]: # Check if the path is of a directory or a file: if data_path.is_dir(): - # Get all files inside the directory: text_files = list(data_path.glob("*.*")) elif data_path.is_file(): @@ -417,20 +415,17 @@ def _get_text_files( def _get_prompt_template( text_wrapper: str, questions_wrapper: str, - questions: List[str], + questions: list[str], ) -> str: - # Validate and build the text wrapper: - text_wrapper = text_wrapper or ( - "Given the following text:\n" "-----\n" "{}\n" "-----" - ) + text_wrapper = text_wrapper or ("Given the following text:\n-----\n{}\n-----") if text_wrapper.count("{}") != 1: raise ValueError( "The `text_wrapper` must include one placeholder '{}' for the text of the file to be asked about." ) # Validate and build the question wrapper: - questions_wrapper = questions_wrapper or "Answer the questions:\n" "{}" + questions_wrapper = questions_wrapper or "Answer the questions:\n{}" if questions_wrapper.count("{}") != 1: raise ValueError( "The `questions_wrapper` must include one placeholder '{}' for the list of questions." 
@@ -449,7 +444,7 @@ def _get_prompt_template( def _get_generation_pipeline( model_name: str, - device_map: Union[str, dict], + device_map: str | dict, tokenizer_name: str, model_kwargs: dict, tokenizer_kwargs: dict, @@ -487,20 +482,19 @@ def _get_generation_pipeline( def _read_file_batch( - file_batch: List[pathlib.Path], + file_batch: list[pathlib.Path], prompt_template: str, -) -> List[str]: +) -> list[str]: batch = [] # Go over all files and read in usable format for file in file_batch: - with open(file, "r", encoding="utf-8") as fp: + with open(file, encoding="utf-8") as fp: batch.append(prompt_template.format(fp.read())) return batch def _to_group_list(argument_value: list, argument_name: str, length: int): - # Check if is list, turn to list if not argument_value = ( argument_value if isinstance(argument_value, list) else [argument_value] @@ -532,8 +526,7 @@ def __init__(self): pass @staticmethod - def _get_answers(generated_text: str, questions_amount: int) -> List[str]: - + def _get_answers(generated_text: str, questions_amount: int) -> list[str]: # Clear answer start (part before numbers): # TODO find better way to verify, for list of questions this is redundant for example if "1." 
not in generated_text: @@ -564,11 +557,10 @@ def _get_answers(generated_text: str, questions_amount: int) -> List[str]: def _infer_questions( self, questions_amount: int, - batched_input: List[str], + batched_input: list[str], generation_pipeline: transformers.Pipeline, generation_config: transformers.GenerationConfig, - ) -> List[List[str]]: - + ) -> list[list[str]]: # Infer through the llm: batched_output = generation_pipeline( batched_input, @@ -593,10 +585,10 @@ def _infer_questions( def answer( self, questions_amount: int, - batched_input: List[str], + batched_input: list[str], generation_pipeline: transformers.Pipeline, generation_config: transformers.GenerationConfig, - ) -> List[List[str]]: + ) -> list[list[str]]: """ Answer questions with a context to the given text files contents by a pretrained LLM model in given pipeline. """ @@ -665,8 +657,7 @@ def do(self, answers): """ return getattr(self, self.value)(answers) - def __init__( - self, poll_count: int = 5, poll_strategy: str = "most_common"): + def __init__(self, poll_count: int = 5, poll_strategy: str = "most_common"): super().__init__() self.poll_count = poll_count self.poll_strategy = self.Strategy(poll_strategy) @@ -674,10 +665,10 @@ def __init__( def answer( self, questions_amount: int, - batched_input: List[str], + batched_input: list[str], generation_pipeline: transformers.Pipeline, generation_config: transformers.GenerationConfig, - ) -> List[List[str]]: + ) -> list[list[str]]: """ Answer questions with a context to the given text files contents by a pretrained LLM model in given pipeline. 
""" @@ -691,10 +682,10 @@ def answer( def _answer_poll_questions( self, questions_amount: int, - batched_input: List[str], + batched_input: list[str], generation_pipeline: transformers.Pipeline, generation_config: transformers.GenerationConfig, - ) -> List[List[str]]: + ) -> list[list[str]]: votes = [] # Run the poll for each question diff --git a/functions/src/question_answering/test_question_answering.py b/functions/src/question_answering/test_question_answering.py index f35b4364e..41469ebe3 100644 --- a/functions/src/question_answering/test_question_answering.py +++ b/functions/src/question_answering/test_question_answering.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import tempfile + import mlrun import transformers -import tempfile APPLE_COLOR = "red" @@ -36,18 +37,15 @@ def test_question_answering(monkeypatch): input_path = "./data" artifact_path = tempfile.mkdtemp() project = mlrun.new_project("qa", context="./") - fn = project.set_function("question_answering.py", "answer_questions", kind="job", image="mlrun/mlrun") + fn = project.set_function( + "question_answering.py", "answer_questions", kind="job", image="mlrun/mlrun" + ) qa_run = fn.run( handler="answer_questions", params={ "model_name": "distilgpt2", "data_path": input_path, - "text_wrapper": ( - "Given the following sentence:\n" - "-----\n" - "{}\n" - "-----" - ), + "text_wrapper": ("Given the following sentence:\n-----\n{}\n-----"), "questions": [ "What is the color of the apple?", ], @@ -67,7 +65,7 @@ def test_question_answering(monkeypatch): "question_answering_errors: result", ], local=True, - artifact_path=artifact_path + artifact_path=artifact_path, ) qa_df = mlrun.get_dataitem( qa_run.status.artifacts[0]["spec"]["target_path"] diff --git a/functions/src/send_email/function.yaml b/functions/src/send_email/function.yaml index 1722fb586..00a0f2ad8 100644 --- a/functions/src/send_email/function.yaml +++ 
b/functions/src/send_email/function.yaml @@ -1,44 +1,35 @@ -kind: job metadata: - name: send-email tag: '' - hash: 5c4528084ea98992b77f65e29359bbcb4a0df8ab - project: '' - labels: - author: Iguazio + name: send-email categories: - utils +verbose: false +kind: job spec: - command: '' - args: [] image: mlrun/mlrun + disable_auto_mount: false build: - functionSourceCode: IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKIyBHZW5lcmF0ZWQgYnkgbnVjbGlvLmV4cG9ydC5OdWNsaW9FeHBvcnRlcgoKZnJvbSBtbHJ1bi5leGVjdXRpb24gaW1wb3J0IE1MQ2xpZW50Q3R4CmZyb20gdHlwaW5nIGltcG9ydCBMaXN0CgppbXBvcnQgc210cGxpYgpmcm9tIGVtYWlsLm1lc3NhZ2UgaW1wb3J0IEVtYWlsTWVzc2FnZQppbXBvcnQgb3MKCmltcG9ydCBtaW1ldHlwZXMKCgpkZWYgc2VuZF9lbWFpbCgKICAgIGNvbnRleHQ6IE1MQ2xpZW50Q3R4LAogICAgc2VuZGVyOiBzdHIsCiAgICB0bzogc3RyLAogICAgc3ViamVjdDogc3RyLAogICAgY29udGVudDogc3RyID0gIiIsCiAgICBzZXJ2ZXJfYWRkcjogc3RyID0gTm9uZSwKICAgIGF0dGFjaG1lbnRzOiBMaXN0W3N0cl0gPSBbXSwKKSAtPiBOb25lOgogICAgIiIiU2VuZCBhbiBlbWFpbC4KICAgIDpwYXJhbSBzZW5kZXI6IFNlbmRlciBlbWFpbCBhZGRyZXNzCiAgICA6cGFyYW0gY29udGV4dDogVGhlIGZ1bmN0aW9uIGNvbnRleHQKICAgIDpwYXJhbSB0bzogRW1haWwgYWRkcmVzcyBvZiBtYWlsIHJlY2lwaWVudAogICAgOnBhcmFtIHN1YmplY3Q6IEVtYWlsIHN1YmplY3QKICAgIDpwYXJhbSBjb250ZW50OiBPcHRpb25hbCBtYWlsIHRleHQKICAgIDpwYXJhbSBzZXJ2ZXJfYWRkcjogQWRkcmVzcyBvZiBTTVRQIHNlcnZlc
iB0byB1c2UuIFVzZSBmb3JtYXQgPGFkZHI+Ojxwb3J0PgogICAgOnBhcmFtIGF0dGFjaG1lbnRzOiBMaXN0IG9mIGF0dGFjaG1lbnRzIHRvIGFkZC4KICAgICIiIgoKICAgIGVtYWlsX3VzZXIgPSBjb250ZXh0LmdldF9zZWNyZXQoIlNNVFBfVVNFUiIpCiAgICBlbWFpbF9wYXNzID0gY29udGV4dC5nZXRfc2VjcmV0KCJTTVRQX1BBU1NXT1JEIikKICAgIGlmIGVtYWlsX3VzZXIgaXMgTm9uZSBvciBlbWFpbF9wYXNzIGlzIE5vbmU6CiAgICAgICAgY29udGV4dC5sb2dnZXIuZXJyb3IoIk1pc3Npbmcgc2VuZGVyIGVtYWlsIG9yIHBhc3N3b3JkIC0gY2Fubm90IHNlbmQgZW1haWwuIikKICAgICAgICByZXR1cm4KCiAgICBpZiBzZXJ2ZXJfYWRkciBpcyBOb25lOgogICAgICAgIGNvbnRleHQubG9nZ2VyLmVycm9yKCJTZXJ2ZXIgbm90IHNwZWNpZmllZCAtIGNhbm5vdCBzZW5kIGVtYWlsLiIpCiAgICAgICAgcmV0dXJuCgogICAgbXNnID0gRW1haWxNZXNzYWdlKCkKICAgIG1zZ1siRnJvbSJdID0gc2VuZGVyCiAgICBtc2dbIlN1YmplY3QiXSA9IHN1YmplY3QKICAgIG1zZ1siVG8iXSA9IHRvCiAgICBtc2cuc2V0X2NvbnRlbnQoY29udGVudCkKCiAgICBmb3IgZmlsZW5hbWUgaW4gYXR0YWNobWVudHM6CiAgICAgICAgY29udGV4dC5sb2dnZXIuaW5mbyhmIkxvb2tpbmcgYXQgYXR0YWNobWVudDoge2ZpbGVuYW1lfSIpCiAgICAgICAgaWYgbm90IG9zLnBhdGguaXNmaWxlKGZpbGVuYW1lKToKICAgICAgICAgICAgY29udGV4dC5sb2dnZXIud2FybmluZyhmIkZpbGVuYW1lIGRvZXMgbm90IGV4aXN0IHtmaWxlbmFtZX0iKQogICAgICAgICAgICBjb250aW51ZQogICAgICAgIGN0eXBlLCBlbmNvZGluZyA9IG1pbWV0eXBlcy5ndWVzc190eXBlKGZpbGVuYW1lKQogICAgICAgIGlmIGN0eXBlIGlzIE5vbmUgb3IgZW5jb2RpbmcgaXMgbm90IE5vbmU6CiAgICAgICAgICAgIGN0eXBlID0gImFwcGxpY2F0aW9uL29jdGV0LXN0cmVhbSIKICAgICAgICBtYWludHlwZSwgc3VidHlwZSA9IGN0eXBlLnNwbGl0KCIvIiwgMSkKICAgICAgICB3aXRoIG9wZW4oZmlsZW5hbWUsICJyYiIpIGFzIGZwOgogICAgICAgICAgICBtc2cuYWRkX2F0dGFjaG1lbnQoCiAgICAgICAgICAgICAgICBmcC5yZWFkKCksCiAgICAgICAgICAgICAgICBtYWludHlwZT1tYWludHlwZSwKICAgICAgICAgICAgICAgIHN1YnR5cGU9c3VidHlwZSwKICAgICAgICAgICAgICAgIGZpbGVuYW1lPW9zLnBhdGguYmFzZW5hbWUoZmlsZW5hbWUpLAogICAgICAgICAgICApCiAgICAgICAgICAgIGNvbnRleHQubG9nZ2VyLmluZm8oCiAgICAgICAgICAgICAgICBmIkFkZGVkIGF0dGFjaG1lbnQ6IEZpbGVuYW1lOiB7ZmlsZW5hbWV9LCBvZiBtaW1ldHlwZToge21haW50eXBlfSwge3N1YnR5cGV9IgogICAgICAgICAgICApCgogICAgdHJ5OgogICAgICAgIHMgPSBzbXRwbGliLlNNVFAoaG9zdD1zZXJ2ZXJfYWRkcikKICAgICAgICBzLnN0YXJ0dGxzKCkKICAgICAgICBzLmxvZ2luKGVtY
WlsX3VzZXIsIGVtYWlsX3Bhc3MpCiAgICAgICAgcy5zZW5kX21lc3NhZ2UobXNnKQogICAgICAgIGNvbnRleHQubG9nZ2VyLmluZm8oIkVtYWlsIHNlbnQgc3VjY2Vzc2Z1bGx5LiIpCiAgICBleGNlcHQgc210cGxpYi5TTVRQRXhjZXB0aW9uIGFzIGV4cDoKICAgICAgICBjb250ZXh0LmxvZ2dlci5lcnJvcihmIlNNVFAgZXhjZXB0aW9uIGNhdWdodCBpbiBTTVRQIGNvZGU6IHtleHB9IikKICAgIGV4Y2VwdCBDb25uZWN0aW9uRXJyb3IgYXMgY2U6CiAgICAgICAgY29udGV4dC5sb2dnZXIuZXJyb3IoZiJDb25uZWN0aW9uIGVycm9yIGNhdWdodCBpbiBTTVRQIGNvZGU6IHtjZX0iKQo= - commands: [] - code_origin: "" - origin_filename: "" - requirements: [] + origin_filename: '' + functionSourceCode: IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKIyBHZW5lcmF0ZWQgYnkgbnVjbGlvLmV4cG9ydC5OdWNsaW9FeHBvcnRlcgoKaW1wb3J0IG1pbWV0eXBlcwppbXBvcnQgb3MKaW1wb3J0IHNtdHBsaWIKZnJvbSBlbWFpbC5tZXNzYWdlIGltcG9ydCBFbWFpbE1lc3NhZ2UKCmZyb20gbWxydW4uZXhlY3V0aW9uIGltcG9ydCBNTENsaWVudEN0eAoKCmRlZiBzZW5kX2VtYWlsKAogICAgY29udGV4dDogTUxDbGllbnRDdHgsCiAgICBzZW5kZXI6IHN0ciwKICAgIHRvOiBzdHIsCiAgICBzdWJqZWN0OiBzdHIsCiAgICBjb250ZW50OiBzdHIgPSAiIiwKICAgIHNlcnZlcl9hZGRyOiBzdHIgPSBOb25lLAogICAgYXR0YWNobWVudHM6IGxpc3Rbc3RyXSA9IFtdLAopIC0+IE5vbmU6CiAgICAiIiJTZW5kIGFuIGVtYWlsLgogICAgOnBhcmFtIHNlbmRlcjogU2VuZGVyIGVtYWlsIGFkZHJlc3MKICAgIDpwYXJhbSBjb250ZXh0OiBUaGUgZnVuY3Rpb24gY29udGV4dAogICAgOnBhcmFtIHRvOiBFbWFpbCBhZGRyZXNzIG9mIG1haWwgcmVjaXBpZW50Ci
AgICA6cGFyYW0gc3ViamVjdDogRW1haWwgc3ViamVjdAogICAgOnBhcmFtIGNvbnRlbnQ6IE9wdGlvbmFsIG1haWwgdGV4dAogICAgOnBhcmFtIHNlcnZlcl9hZGRyOiBBZGRyZXNzIG9mIFNNVFAgc2VydmVyIHRvIHVzZS4gVXNlIGZvcm1hdCA8YWRkcj46PHBvcnQ+CiAgICA6cGFyYW0gYXR0YWNobWVudHM6IExpc3Qgb2YgYXR0YWNobWVudHMgdG8gYWRkLgogICAgIiIiCgogICAgZW1haWxfdXNlciA9IGNvbnRleHQuZ2V0X3NlY3JldCgiU01UUF9VU0VSIikKICAgIGVtYWlsX3Bhc3MgPSBjb250ZXh0LmdldF9zZWNyZXQoIlNNVFBfUEFTU1dPUkQiKQogICAgaWYgZW1haWxfdXNlciBpcyBOb25lIG9yIGVtYWlsX3Bhc3MgaXMgTm9uZToKICAgICAgICBjb250ZXh0LmxvZ2dlci5lcnJvcigiTWlzc2luZyBzZW5kZXIgZW1haWwgb3IgcGFzc3dvcmQgLSBjYW5ub3Qgc2VuZCBlbWFpbC4iKQogICAgICAgIHJldHVybgoKICAgIGlmIHNlcnZlcl9hZGRyIGlzIE5vbmU6CiAgICAgICAgY29udGV4dC5sb2dnZXIuZXJyb3IoIlNlcnZlciBub3Qgc3BlY2lmaWVkIC0gY2Fubm90IHNlbmQgZW1haWwuIikKICAgICAgICByZXR1cm4KCiAgICBtc2cgPSBFbWFpbE1lc3NhZ2UoKQogICAgbXNnWyJGcm9tIl0gPSBzZW5kZXIKICAgIG1zZ1siU3ViamVjdCJdID0gc3ViamVjdAogICAgbXNnWyJUbyJdID0gdG8KICAgIG1zZy5zZXRfY29udGVudChjb250ZW50KQoKICAgIGZvciBmaWxlbmFtZSBpbiBhdHRhY2htZW50czoKICAgICAgICBjb250ZXh0LmxvZ2dlci5pbmZvKGYiTG9va2luZyBhdCBhdHRhY2htZW50OiB7ZmlsZW5hbWV9IikKICAgICAgICBpZiBub3Qgb3MucGF0aC5pc2ZpbGUoZmlsZW5hbWUpOgogICAgICAgICAgICBjb250ZXh0LmxvZ2dlci53YXJuaW5nKGYiRmlsZW5hbWUgZG9lcyBub3QgZXhpc3Qge2ZpbGVuYW1lfSIpCiAgICAgICAgICAgIGNvbnRpbnVlCiAgICAgICAgY3R5cGUsIGVuY29kaW5nID0gbWltZXR5cGVzLmd1ZXNzX3R5cGUoZmlsZW5hbWUpCiAgICAgICAgaWYgY3R5cGUgaXMgTm9uZSBvciBlbmNvZGluZyBpcyBub3QgTm9uZToKICAgICAgICAgICAgY3R5cGUgPSAiYXBwbGljYXRpb24vb2N0ZXQtc3RyZWFtIgogICAgICAgIG1haW50eXBlLCBzdWJ0eXBlID0gY3R5cGUuc3BsaXQoIi8iLCAxKQogICAgICAgIHdpdGggb3BlbihmaWxlbmFtZSwgInJiIikgYXMgZnA6CiAgICAgICAgICAgIG1zZy5hZGRfYXR0YWNobWVudCgKICAgICAgICAgICAgICAgIGZwLnJlYWQoKSwKICAgICAgICAgICAgICAgIG1haW50eXBlPW1haW50eXBlLAogICAgICAgICAgICAgICAgc3VidHlwZT1zdWJ0eXBlLAogICAgICAgICAgICAgICAgZmlsZW5hbWU9b3MucGF0aC5iYXNlbmFtZShmaWxlbmFtZSksCiAgICAgICAgICAgICkKICAgICAgICAgICAgY29udGV4dC5sb2dnZXIuaW5mbygKICAgICAgICAgICAgICAgIGYiQWRkZWQgYXR0YWNobWVudDogRmlsZW5hbWU6IHtmaWxlbmFtZX0sIG9mIG1pbWV0eXBlOiB7bWFpbnR5cGV9LC
B7c3VidHlwZX0iCiAgICAgICAgICAgICkKCiAgICB0cnk6CiAgICAgICAgcyA9IHNtdHBsaWIuU01UUChob3N0PXNlcnZlcl9hZGRyKQogICAgICAgIHMuc3RhcnR0bHMoKQogICAgICAgIHMubG9naW4oZW1haWxfdXNlciwgZW1haWxfcGFzcykKICAgICAgICBzLnNlbmRfbWVzc2FnZShtc2cpCiAgICAgICAgY29udGV4dC5sb2dnZXIuaW5mbygiRW1haWwgc2VudCBzdWNjZXNzZnVsbHkuIikKICAgIGV4Y2VwdCBzbXRwbGliLlNNVFBFeGNlcHRpb24gYXMgZXhwOgogICAgICAgIGNvbnRleHQubG9nZ2VyLmVycm9yKGYiU01UUCBleGNlcHRpb24gY2F1Z2h0IGluIFNNVFAgY29kZToge2V4cH0iKQogICAgZXhjZXB0IENvbm5lY3Rpb25FcnJvciBhcyBjZToKICAgICAgICBjb250ZXh0LmxvZ2dlci5lcnJvcihmIkNvbm5lY3Rpb24gZXJyb3IgY2F1Z2h0IGluIFNNVFAgY29kZToge2NlfSIpCg== + code_origin: '' + filename: send_email.py entry_points: send_email: - name: send_email - doc: Send an email. + outputs: + - type: None parameters: - name: context type: MLClientCtx doc: The function context - default: '' - name: sender type: str doc: Sender email address - default: '' - name: to type: str doc: Email address of mail recipient - default: '' - name: subject type: str doc: Email subject - default: '' - name: content type: str doc: Optional mail text @@ -48,20 +39,14 @@ spec: doc: Address of SMTP server to use. Use format : default: null - name: attachments - type: List[str] + type: list[str] doc: List of attachments to add. default: [] - outputs: - - default: '' - lineno: 27 + name: send_email + doc: Send an email. 
+ has_kwargs: false + has_varargs: false + lineno: 25 + command: '' description: Send Email messages through SMTP server default_handler: send_email - disable_auto_mount: false - clone_target_dir: '' - env: [] - priority_class_name: '' - preemption_mode: prevent - affinity: null - tolerations: null - security_context: {} -verbose: false diff --git a/functions/src/send_email/send_email.py b/functions/src/send_email/send_email.py index 0dd9f7d0f..f6ab688ae 100644 --- a/functions/src/send_email/send_email.py +++ b/functions/src/send_email/send_email.py @@ -14,14 +14,12 @@ # # Generated by nuclio.export.NuclioExporter -from mlrun.execution import MLClientCtx -from typing import List - +import mimetypes +import os import smtplib from email.message import EmailMessage -import os -import mimetypes +from mlrun.execution import MLClientCtx def send_email( @@ -31,7 +29,7 @@ def send_email( subject: str, content: str = "", server_addr: str = None, - attachments: List[str] = [], + attachments: list[str] = [], ) -> None: """Send an email. :param sender: Sender email address diff --git a/functions/src/silero_vad/function.yaml b/functions/src/silero_vad/function.yaml index fd637f1c0..1d7b53d34 100644 --- a/functions/src/silero_vad/function.yaml +++ b/functions/src/silero_vad/function.yaml @@ -1,76 +1,74 @@ metadata: tag: '' + name: silero-vad categories: - deep-learning - audio - name: silero-vad verbose: false +kind: job spec: - description: Silero VAD (Voice Activity Detection) functions. + image: '' + disable_auto_mount: false build: - code_origin: '' - base_image: mlrun/mlrun + origin_filename: '' + functionSourceCode: # Copyright 2024 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from multiprocessing import Process, Queue
from pathlib import Path
from types import FunctionType

import torch
import torchaudio
from tqdm import tqdm


class BaseTask:
    """
    A base class for a task to complete after VAD.
    """

    def __init__(self, audio_file: Path):
        """
        Initialize the base task.

        :param audio_file: The audio file assigned to the task.
        """
        # Store the audio file:
        self._audio_file = audio_file

        # Prepare the result:
        self._result = None

    @property
    def audio_file(self) -> Path:
        """
        Get the audio file of the task.

        :returns: The audio file of the task.
        """
        return self._audio_file

    def do_task(
        self, speech_timestamps: list[dict[str, int]] | list[list[dict[str, int]]]
    ):
        """
        Do the task on the given speech timestamps. The base task will simply save the speech timestamps as the result.

        :param speech_timestamps: The speech timestamps to do the task on as outputted from the VAD.
        """
        self._result = speech_timestamps

    def get_result(self) -> tuple[str, list]:
        """
        Get the result of the task. A tuple of the audio file name and the result.

        :returns: The result of the task.
        """
        return self._audio_file.name, self._result

    def to_tuple(self) -> tuple[str, dict]:
        """
        Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue).

        :returns: The converted task.
        """
        return self.__class__.__name__, {"audio_file": self._audio_file}


class SpeechDiarizationTask(BaseTask):
    """
    A speech diarization task. The task will diarize the VAD speech timestamps into speakers.
    """

    def __init__(self, audio_file: Path, speaker_labels: list[str] | None = None):
        """
        Initialize the speech diarization task.

        :param audio_file:     The audio file assigned to the task.
        :param speaker_labels: The speaker labels to use for the diarization, matched to the channels by their
                               order. If not given (or empty), the speakers will be named "speaker_0", "speaker_1",
                               etc.
        """
        super().__init__(audio_file=audio_file)
        self._speaker_labels = speaker_labels

    def do_task(self, speech_timestamps: list[list[dict[str, int]]]):
        """
        Do the task on the given speech timestamps. The task will diarize the VAD speech timestamps into speakers.

        The result is stored as a list of `(start, end, speaker_label)` tuples sorted in ascending order (start
        first, then end, then label).

        :param speech_timestamps: The speech timestamps per channel to do the task on as outputted from the VAD.
        """
        # Get the speaker labels (set default if not given):
        speaker_labels = self._speaker_labels or [
            f"speaker_{i}" for i in range(len(speech_timestamps))
        ]

        # Diarize - organize the speech timestamps into a single list of speakers and sort it by start time. Labels
        # and channels are paired by position using `zip`, so pairing stops at the shorter of the two:
        speech_diarization = [
            (speech_timestamp["start"], speech_timestamp["end"], speaker_label)
            for speaker_label, channel_speech_timestamps in zip(
                speaker_labels, speech_timestamps
            )
            for speech_timestamp in channel_speech_timestamps
        ]
        speech_diarization.sort()
        self._result = speech_diarization

    def to_tuple(self) -> tuple[str, dict]:
        """
        Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue).

        :returns: The converted task - the class name and the initialization keyword arguments, including the
                  speaker labels.
        """
        task_class, task_kwargs = super().to_tuple()
        return task_class, {**task_kwargs, "speaker_labels": self._speaker_labels}


class TaskCreator:
    """
    A factory for the tasks that run after the VAD. It creates new tasks for audio files and rebuilds tasks out
    of their tuple form.
    """

    #: A lookup from a task class name to the task class itself, used in `from_tuple`:
    _MAP = {
        task_class.__name__: task_class
        for task_class in (BaseTask, SpeechDiarizationTask)
    }

    def __init__(self, task_type: type[BaseTask], task_kwargs: dict = None):
        """
        Set up the creator with the kind of tasks it should produce.

        :param task_type:   The task type - a `BaseTask` subclass.
        :param task_kwargs: Additional keyword arguments to pass to every task that will be created.
        """
        # Fall back to no extra keyword arguments when none (or an empty value) were given:
        self._task_kwargs = task_kwargs if task_kwargs else {}
        self._task_type = task_type

    def create_task(self, audio_file: Path) -> BaseTask:
        """
        Produce a new task of the configured type for the given audio file.

        :param audio_file: The audio file to assign to the task.

        :returns: The created task.
        """
        task_cls = self._task_type
        return task_cls(audio_file=audio_file, **self._task_kwargs)

    @classmethod
    def from_tuple(cls, task_tuple: tuple[str, dict]) -> BaseTask:
        """
        Rebuild a task out of its tuple form - the task class name and the task initialization keyword arguments.

        :param task_tuple: The task tuple to create the task from.

        :returns: The created task.
        """
        class_name, init_kwargs = task_tuple
        task_cls = cls._MAP[class_name]
        return task_cls(**init_kwargs)


class VoiceActivityDetector:
    """
    A voice activity detection wrapper for the silero VAD model - https://github.com/snakers4/silero-vad.
    """

    def __init__(
        self,
        # Model loading kwargs:
        use_onnx: bool = True,
        force_onnx_cpu: bool = True,
        # Detection kwargs:
        threshold: float = 0.5,
        sampling_rate: int = 16_000,
        min_speech_duration_ms: int = 250,
        max_speech_duration_s: float = float("inf"),
        min_silence_duration_ms: int = 100,
        window_size_samples: int = 512,
        speech_pad_ms: int = 30,
        return_seconds: bool = False,
        per_channel: bool = False,
    ):
        """
        Initialize the voice activity detector.

        :param use_onnx:                Whether to use ONNX for inference. Default is True.
        :param force_onnx_cpu:          Whether to force ONNX to use CPU for inference. Default is True.
        :param threshold:               Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
                                        probabilities ABOVE this value are considered as SPEECH. It is better to tune
                                        this parameter for each dataset separately, but "lazy" 0.5 is pretty good for
                                        most datasets.
        :param sampling_rate:           Currently, silero VAD models support 8000 and 16000 sample rates.
        :param min_speech_duration_ms:  Final speech chunks shorter min_speech_duration_ms are thrown out.
        :param max_speech_duration_s:   Maximum duration of speech chunks in seconds. Chunks longer than
                                        `max_speech_duration_s` will be split at the timestamp of the last silence that
                                        lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise,
                                        they will be split aggressively just before max_speech_duration_s.
        :param min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms before
                                        separating it.
        :param window_size_samples:     Audio chunks of window_size_samples size are fed to the silero VAD model.
                                        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000
                                        sample rate and 256, 512, 768 samples for 8000 sample rate. Values other than
                                        these may affect model performance!
        :param speech_pad_ms:           Final speech chunks are padded by speech_pad_ms each side.
        :param return_seconds:          Whether return timestamps in seconds. False means to return timestamps in
                                        samples (default - False).
        :param per_channel:             Whether to return timestamps per channel (default - False). This will run VAD
                                        on each channel separately and return a list of timestamps per channel.
        """
        # Store configurations:
        self._use_onnx = use_onnx
        self._force_onnx_cpu = force_onnx_cpu
        self._threshold = threshold
        self._sampling_rate = sampling_rate
        self._min_speech_duration_ms = min_speech_duration_ms
        self._max_speech_duration_s = max_speech_duration_s
        self._min_silence_duration_ms = min_silence_duration_ms
        self._window_size_samples = window_size_samples
        self._speech_pad_ms = speech_pad_ms
        self._return_seconds = return_seconds
        self._per_channel = per_channel

        # Prepare the model variables
        self._model: torch.Module = None
        self._get_speech_timestamps: FunctionType = None

    def load(self, force_reload: bool = True):
        """
        Load the VAD model.

        :param force_reload: Whether to force reload the model even if it was already loaded. Default is True.
        """
        model, utils = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
            force_reload=force_reload,
            onnx=self._use_onnx,
            force_onnx_cpu=self._force_onnx_cpu,
        )
        self._model = model
        (
            self._get_speech_timestamps,
            _,  # save_audio,
            _,  # read_audio,
            _,  # VADIterator,
            _,  # collect_chunks
        ) = utils

    def detect_voice(
        self,
        audio_file: Path,
    ) -> list[dict[str, int]] | list[list[dict[str, int]]]:
        """
        Infer the audio through the VAD model and return the speech timestamps.

        :param audio_file: The audio file to infer.

        :returns: The speech timestamps in the audio. A list of timestamps where each timestamp is a dictionary with the
                 following keys:

                 * "start": The start sample index of the speech in the audio.
                 * "end":   The end sample index of the speech in the audio.

                 If `per_channel` is True, a list of timestamps per channel will be returned.
        """
        # Cast to a numpy array:
        audio = self._read_audio(audio_file)

        # Detect speech:
        if not self._per_channel:
            return self._get_speech_timestamps(
                audio,
                self._model,
                threshold=self._threshold,
                min_speech_duration_ms=self._min_speech_duration_ms,
                max_speech_duration_s=self._max_speech_duration_s,
                min_silence_duration_ms=self._min_silence_duration_ms,
                speech_pad_ms=self._speech_pad_ms,
                sampling_rate=self._sampling_rate,
                window_size_samples=self._window_size_samples,
                return_seconds=self._return_seconds,
            )

        # Per channel:
        speech_timestamps = []
        for channel in audio:
            speech_timestamps.append(
                self._get_speech_timestamps(
                    channel,
                    self._model,
                    threshold=self._threshold,
                    min_speech_duration_ms=self._min_speech_duration_ms,
                    max_speech_duration_s=self._max_speech_duration_s,
                    min_silence_duration_ms=self._min_silence_duration_ms,
                    speech_pad_ms=self._speech_pad_ms,
                    sampling_rate=self._sampling_rate,
                    window_size_samples=self._window_size_samples,
                    return_seconds=self._return_seconds,
                )
            )

        return speech_timestamps

    def _read_audio(
        self,
        path: Path,
    ) -> torch.Tensor:
        """
        Read the audio from the given path and return it as a tensor.

        :param path: The path to the audio file.

        :returns: The audio as a tensor.
        """
        # Read the audio:
        audio, sampling_rate = torchaudio.load(str(path))

        # Check if the audio is stereo and if so, convert it to mono (only if not per channel):
        if audio.size(0) > 1 and not self._per_channel:
            audio = audio.mean(dim=0, keepdim=True)

        # Resample the audio if needed:
        if sampling_rate != self._sampling_rate:
            transform = torchaudio.transforms.Resample(
                orig_freq=sampling_rate, new_freq=self._sampling_rate
            )
            audio = transform(audio)

        # Return the audio (squeeze if not per channel):
        return audio if self._per_channel else audio.squeeze(0)


#: The value to send into multiprocessing queues to stop the process. A worker stops reading its tasks queue once
#: it gets this mark, and then puts the same mark in its results queue to signal that it is done:
_MULTIPROCESSING_STOP_MARK = "STOP"


def _multiprocessing_complete_tasks(
    vad_init_kwargs: dict, tasks_queue: Queue, results_queue: Queue
):
    """
    Complete the tasks in the given queue and put the results in the given results queue. The function will stop when
    the given tasks queue will receive the stop mark. It is aimed to be used with multiprocessing as a process.

    For every task taken from the tasks queue, a single `(is_error, payload)` tuple is put in the results queue:

    * On success: `(False, <task result>)` where the task result is the value returned from the task's
      `get_result`.
    * On failure: `(True, (<audio file name>, <error message>))`.

    Once the stop mark is received, it is put in the results queue to mark the end of the results.

    :param vad_init_kwargs: The VAD initialization kwargs.
    :param tasks_queue:     A queue to get the tasks from.
    :param results_queue:   A queue to put the results in.
    """
    # Initialize and load the VAD:
    vad = VoiceActivityDetector(**vad_init_kwargs)
    vad.load(force_reload=False)

    # Start listening to the tasks queue:
    while True:
        # Get the task (in its tuple form) or the stop mark:
        task_tuple: tuple[str, dict] | str = tasks_queue.get()
        if task_tuple == _MULTIPROCESSING_STOP_MARK:
            break
        try:
            # Create the task:
            task = TaskCreator.from_tuple(task_tuple=task_tuple)
            # Run the file through the VAD:
            speech_timestamps = vad.detect_voice(audio_file=task.audio_file)
            # Complete the task:
            task.do_task(speech_timestamps=speech_timestamps)
            # Build the result:
            result = (False, task.get_result())
        except Exception as exception:
            # Read the file name from the raw task tuple and not from the task object, as the task object does not
            # exist if its reconstruction is what failed (otherwise the error handling itself would raise and the
            # process would die without reporting the error or the stop mark):
            try:
                file_name = Path(task_tuple[1]["audio_file"]).name
            except (TypeError, KeyError, IndexError):
                # The task tuple is malformed, report it as is:
                file_name = str(task_tuple)
            # Build the error:
            result = (True, (file_name, str(exception)))
        # Collect the result / error:
        results_queue.put(result)

    # Mark the end of the tasks:
    results_queue.put(_MULTIPROCESSING_STOP_MARK)


# Get the global logger - use the logger of the MLRun context named "silero_vad" when the `mlrun` package can be
# imported, otherwise fall back to Python's root logger:
try:
    import mlrun

    _LOGGER = mlrun.get_or_create_ctx("silero_vad").logger
except ModuleNotFoundError:
    # The `mlrun` package is not installed, use the standard library logging:
    _LOGGER = logging.getLogger()


def detect_voice(
    # Input kwargs:
    data_path: str | Path | list[str | Path],
    # Model loading kwargs:
    use_onnx: bool = True,
    force_onnx_cpu: bool = True,
    # Detection kwargs:
    threshold: float = 0.5,
    sampling_rate: int = 16_000,
    min_speech_duration_ms: int = 250,
    max_speech_duration_s: float = float("inf"),
    min_silence_duration_ms: int = 100,
    window_size_samples: int = 512,
    speech_pad_ms: int = 30,
    return_seconds: bool = False,
    per_channel: bool = False,
    # Other kwargs:
    use_multiprocessing: int = 0,
    verbose: bool = False,
):
    """
    Run voice activity detection (VAD) on the given audio files with the silero VAD model -
    https://github.com/snakers4/silero-vad.

    The returned value is a tuple of two dictionaries: the successes and the errors. The successes dictionary maps
    each file name to its list of VAD timestamps and the errors dictionary maps the name of each failed file to its
    error message.

    A successes dictionary for example::

        {
            "file_1.wav": [
                {"start": 0, "end": 16000},
                {"start": 16000, "end": 32000},
                {"start": 32000, "end": 48000},
                ...
            ],
            "file_2.wav": [
                {"start": 0, "end": 16000},
                {"start": 16000, "end": 32000},
                {"start": 32000, "end": 48000},
                ...
            ],
            ...
        }


    :param data_path:               Where to take the audio files from: a path to a single file, a path to a directory
                                    or a list of paths.
    :param use_onnx:                Whether to run the inference with ONNX. Default is True.
    :param force_onnx_cpu:          Whether to force ONNX to run the inference on CPU. Default is True.
    :param threshold:               Speech threshold. Silero VAD outputs a speech probability for each audio chunk and
                                    probabilities ABOVE this value are considered as SPEECH. Tuning it per dataset is
                                    better, but the "lazy" 0.5 is pretty good for most datasets.
    :param sampling_rate:           Currently, silero VAD models support 8000 and 16000 sample rates.
    :param min_speech_duration_ms:  Final speech chunks shorter than min_speech_duration_ms are thrown out.
    :param max_speech_duration_s:   Maximum duration of speech chunks in seconds. Chunks longer than
                                    `max_speech_duration_s` will be split at the timestamp of the last silence that
                                    lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will
                                    be split aggressively just before max_speech_duration_s.
    :param min_silence_duration_ms: At the end of each speech chunk, wait for min_silence_duration_ms before
                                    separating it.
    :param window_size_samples:     Audio chunks of window_size_samples size are fed to the silero VAD model.

                                    WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000
                                    sample rate and 256, 512, 768 samples for 8000 sample rate. Values other than
                                    these may affect model performance!
    :param speech_pad_ms:           Final speech chunks are padded by speech_pad_ms on each side.
    :param return_seconds:          Whether to return the timestamps in seconds. False means the timestamps are
                                    returned in samples (default - False).
    :param per_channel:             Whether to return timestamps per channel (default - False). This will run the VAD
                                    on each channel separately and return a list of timestamps per channel.
    :param use_multiprocessing:     The number of workers to use for multiprocessing. If 0, no multiprocessing will
                                    be used. Default is 0.
    :param verbose:                 Verbosity.

    :returns: A tuple of the successes dictionary and the errors dictionary.
    """
    # Collect the audio files to run the VAD on:
    if verbose:
        _LOGGER.info("Collecting audio files.")
    audio_files = _get_audio_files(data_path=data_path)
    if verbose:
        _LOGGER.info(f"Collected {len(audio_files)} audio files.")

    # Gather the arguments shared by both runners - the VAD configuration and a creator of plain VAD tasks:
    run_kwargs = dict(
        audio_files=audio_files,
        description="Detecting voice",
        vad_init_kwargs=dict(
            use_onnx=use_onnx,
            force_onnx_cpu=force_onnx_cpu,
            threshold=threshold,
            sampling_rate=sampling_rate,
            min_speech_duration_ms=min_speech_duration_ms,
            max_speech_duration_s=max_speech_duration_s,
            min_silence_duration_ms=min_silence_duration_ms,
            window_size_samples=window_size_samples,
            speech_pad_ms=speech_pad_ms,
            return_seconds=return_seconds,
            per_channel=per_channel,
        ),
        task_creator=TaskCreator(task_type=BaseTask),
        verbose=verbose,
    )

    # Run the detection, with worker processes if requested:
    if use_multiprocessing:
        results = _parallel_run(n_workers=use_multiprocessing, **run_kwargs)
    else:
        results = _run(**run_kwargs)

    # Split the collected results into successes and errors:
    return _process_results(results=results, verbose=verbose)


def diarize(
    # Input / Output kwargs:
    data_path: str | Path | list[str | Path],
    # Model loading kwargs:
    use_onnx: bool = True,
    force_onnx_cpu: bool = True,
    # Detection kwargs:
    threshold: float = 0.5,
    sampling_rate: int = 16_000,
    min_speech_duration_ms: int = 250,
    max_speech_duration_s: float = float("inf"),
    min_silence_duration_ms: int = 100,
    window_size_samples: int = 512,
    speech_pad_ms: int = 30,
    # Diarization kwargs:
    speaker_labels: list[str] = None,
    # Other kwargs:
    use_multiprocessing: int = 0,
    verbose: bool = False,
):
    """
    Run speech diarization on the given audio files with the silero VAD model -
    https://github.com/snakers4/silero-vad. The diarization is done per channel, meaning each channel in the audio is
    treated as a different speaker. The VAD is always configured here to work per channel and to return the timestamps
    in seconds.

    The returned value is a tuple of two dictionaries: the successes and the errors. The successes dictionary maps
    each file name to its diarization - a list of (start, end, speaker_label) tuples - and the errors dictionary maps
    the name of each failed file to its error message.

    A successes dictionary for example::

        {
            "file_1.wav": [
                (0.0, 1.0, "speaker_0"),
                (1.0, 2.0, "speaker_1"),
                (2.0, 3.0, "speaker_0"),
                ...
            ],
            "file_2.wav": [
                (0.0, 1.0, "speaker_0"),
                (1.0, 2.0, "speaker_1"),
                (2.0, 3.0, "speaker_0"),
                ...
            ],
            ...
        }


    :param data_path:               Where to take the audio files to diarize from: a path to a single file, a path to
                                    a directory or a list of paths.
    :param use_onnx:                Whether to run the inference with ONNX. Default is True.
    :param force_onnx_cpu:          Whether to force ONNX to run the inference on CPU. Default is True.
    :param threshold:               Speech threshold. Silero VAD outputs a speech probability for each audio chunk and
                                    probabilities ABOVE this value are considered as SPEECH. Tuning it per dataset is
                                    better, but the "lazy" 0.5 is pretty good for most datasets.
    :param sampling_rate:           Currently, silero VAD models support 8000 and 16000 sample rates.
    :param min_speech_duration_ms:  Final speech chunks shorter than min_speech_duration_ms are thrown out.
    :param max_speech_duration_s:   Maximum duration of speech chunks in seconds. Chunks longer than
                                    `max_speech_duration_s` will be split at the timestamp of the last silence that
                                    lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will
                                    be split aggressively just before max_speech_duration_s.
    :param min_silence_duration_ms: At the end of each speech chunk, wait for min_silence_duration_ms before
                                    separating it.
    :param window_size_samples:     Audio chunks of window_size_samples size are fed to the silero VAD model.

                                    WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000
                                    sample rate and 256, 512, 768 samples for 8000 sample rate. Values other than
                                    these may affect model performance!
    :param speech_pad_ms:           Final speech chunks are padded by speech_pad_ms on each side.
    :param speaker_labels:          The speaker labels to use for the diarization. If not given, the speakers will be
                                    named "speaker_0", "speaker_1", etc.
    :param use_multiprocessing:     The number of workers to use for multiprocessing. If 0, no multiprocessing will
                                    be used. Default is 0.
    :param verbose:                 Verbosity.

    :returns: A tuple of the successes dictionary and the errors dictionary.
    """
    # Collect the audio files to diarize:
    if verbose:
        _LOGGER.info("Collecting audio files.")
    audio_files = _get_audio_files(data_path=data_path)
    if verbose:
        _LOGGER.info(f"Collected {len(audio_files)} audio files.")

    # Gather the arguments shared by both runners - the VAD configuration (per channel, in seconds) and a creator of
    # diarization tasks:
    run_kwargs = dict(
        audio_files=audio_files,
        description="Diarizing",
        vad_init_kwargs=dict(
            use_onnx=use_onnx,
            force_onnx_cpu=force_onnx_cpu,
            threshold=threshold,
            sampling_rate=sampling_rate,
            min_speech_duration_ms=min_speech_duration_ms,
            max_speech_duration_s=max_speech_duration_s,
            min_silence_duration_ms=min_silence_duration_ms,
            window_size_samples=window_size_samples,
            speech_pad_ms=speech_pad_ms,
            return_seconds=True,
            per_channel=True,
        ),
        task_creator=TaskCreator(
            task_type=SpeechDiarizationTask,
            task_kwargs={"speaker_labels": speaker_labels},
        ),
        verbose=verbose,
    )

    # Run the diarization, with worker processes if requested:
    if use_multiprocessing:
        results = _parallel_run(n_workers=use_multiprocessing, **run_kwargs)
    else:
        results = _run(**run_kwargs)

    # Split the collected results into successes and errors:
    return _process_results(results=results, verbose=verbose)


def _get_audio_files(
    data_path: Path | str | list,
) -> list[Path]:
    """
    Get the audio files from the data path. If a path to a directory is given, all files in the directory will be
    collected.

    :param data_path: The data path to collect the audio files from.

    :returns: The audio files list.
    """
    # Check if given a list of paths:
    if isinstance(data_path, list):
        audio_files = []
        for path in data_path:
            audio_files.extend(_get_audio_files(data_path=path))
        return audio_files

    # Check if given a single string path to cast it to a `pathlib.Path`:
    if isinstance(data_path, str):
        data_path = Path(data_path).absolute()

    # Check if the path is of a directory or a file:
    if data_path.is_dir():
        # Get all files inside the directory:
        audio_files = list(data_path.glob("*.*"))
    elif data_path.is_file():
        audio_files = [data_path]
    else:
        raise ValueError(
            f"Unrecognized data path. The parameter `data_path` must be a valid path to either a directory path or a "
            f"file. Given: {str(data_path)} "
        )

    return audio_files


def _run(
    audio_files: list[Path],
    description: str,
    vad_init_kwargs: dict,
    task_creator: TaskCreator,
    verbose: bool,
) -> list[tuple[bool, tuple[str, list]]]:
    """
    Load a single VAD in the current process and use it to complete, file by file, the tasks the given task creator
    creates for the provided files. A failure in one file is collected as an error and does not stop the run.

    :param audio_files:     The audio files to use.
    :param description:     The description to use for the progress bar.
    :param vad_init_kwargs: The VAD initialization keyword arguments.
    :param task_creator:    The task creator to use to create the tasks.
    :param verbose:         Verbosity.

    :returns: The collected results - a `(False, task_result)` tuple per successful file and a
              `(True, (file_name, error_message))` tuple per failed file.
    """
    # Initialize the VAD and load its model:
    vad = VoiceActivityDetector(**vad_init_kwargs)
    if verbose:
        _LOGGER.info("Loading the VAD model.")
    vad.load()
    if verbose:
        _LOGGER.info("VAD model loaded.")

    def _complete_task(audio_file: Path) -> tuple:
        # Create the task of a single file, run the file through the VAD and complete the task with its output:
        try:
            task = task_creator.create_task(audio_file=audio_file)
            speech_timestamps = vad.detect_voice(audio_file=audio_file)
            task.do_task(speech_timestamps=speech_timestamps)
            return False, task.get_result()
        except Exception as exception:
            return True, (audio_file.name, str(exception))

    # Go over the files with a progress bar (shown only when verbose) and collect the result or error of each:
    progressbar = tqdm(
        audio_files,
        desc=description,
        unit="file",
        total=len(audio_files),
        disable=not verbose,
    )
    return [_complete_task(audio_file) for audio_file in progressbar]


def _parallel_run(
    n_workers: int,
    audio_files: list[Path],
    description: str,
    vad_init_kwargs: dict,
    task_creator: TaskCreator,
    verbose: bool,
) -> list[tuple[bool, tuple[str, list]]]:
    """
    Run multiple VAD workers with multiprocessing to complete the tasks that will be created on the provided files using
    the given task creator.

    :param n_workers:       The number of workers to use. If it is larger than the number of audio files, it is
                            reduced to the number of audio files.
    :param audio_files:     The audio files to use.
    :param description:     The description to use for the progress bar.
    :param vad_init_kwargs: The VAD initialization keyword arguments.
    :param task_creator:    The task creator to use to create the tasks.
    :param verbose:         Verbosity.

    :returns: The collected results.

    :raises ValueError: If the number of workers is smaller than 1.
    """
    # With no files there is nothing to do. It must be handled before starting any worker: without files the number of
    # workers would be reduced to 0, so no stop mark would ever be sent and the results collection below would wait
    # forever:
    if not audio_files:
        return []

    # Without at least one worker no stop mark will ever be sent as well:
    if n_workers < 1:
        raise ValueError(
            f"The number of workers must be a positive integer. Given: {n_workers}"
        )

    # Load the VAD (download once, and it will be loaded then per process later on):
    if verbose:
        _LOGGER.info("Loading the VAD model.")
    vad = VoiceActivityDetector(**vad_init_kwargs)
    vad.load()
    if verbose:
        _LOGGER.info("VAD model loaded.")

    # Check the number of workers:
    if n_workers > len(audio_files):
        _LOGGER.warning(
            f"The number of workers ({n_workers}) is larger than the number of audio files ({len(audio_files)}). "
            f"Setting the number of workers to {len(audio_files)}."
        )
        n_workers = len(audio_files)

    # Initialize the multiprocessing queues:
    tasks_queue = Queue()
    results_queue = Queue()

    # Initialize the multiprocessing processes:
    task_completion_processes = [
        Process(
            target=_multiprocessing_complete_tasks,
            kwargs={
                "vad_init_kwargs": vad_init_kwargs,
                "tasks_queue": tasks_queue,
                "results_queue": results_queue,
            },
        )
        for _ in range(n_workers)
    ]

    # Start the multiprocessing processes:
    for p in task_completion_processes:
        p.start()

    # Put the tasks in the queue (as tuples, the workers reconstruct the tasks from them):
    for audio_file in audio_files:
        tasks_queue.put(task_creator.create_task(audio_file=audio_file).to_tuple())

    # Put the stop marks in the queue (one per worker, as each worker stops on the first mark it gets):
    for _ in range(n_workers):
        tasks_queue.put(_MULTIPROCESSING_STOP_MARK)

    # Collect the results until every worker reported it is done with a stop mark of its own:
    results = []
    stop_marks_counter = 0
    with tqdm(
        desc=description,
        unit="file",
        total=len(audio_files),
        disable=not verbose,
    ) as progressbar:
        while stop_marks_counter < n_workers:
            # Get a result from the queue:
            result: tuple[bool, tuple[str, list]] = results_queue.get()
            if result == _MULTIPROCESSING_STOP_MARK:
                stop_marks_counter += 1
                continue
            # Collect the result:
            results.append(result)
            progressbar.update(1)

    # Wait for the processes to finish:
    for p in task_completion_processes:
        p.join()

    return results


def _process_results(
    results: list[tuple[bool, tuple[str, list]]], verbose: bool
) -> tuple[dict, dict]:
    """
    Split the collected task results into successes and errors.

    :param results: The results to process. Each result is a tuple of an error flag and a pair whose first item is the
                    key (the file name) and whose second item is the value (the task result or the error message).
    :param verbose: Verbosity.

    :returns: The processed results as a tuple of successes and errors dictionaries.
    """
    if verbose:
        _LOGGER.info("Summarizing the results.")

    # Keep a dictionary per error flag value and route each result to the matching one:
    outcomes = {False: {}, True: {}}
    for is_error, result in results:
        outcomes[bool(is_error)][result[0]] = result[1]
    successes, errors = outcomes[False], outcomes[True]

    if verbose:
        _LOGGER.info(f"Done ({len(successes)}/{len(successes) + len(errors)})\n")

    return successes, errors
 requirements: - torch - torchaudio - tqdm - onnxruntime - functionSourceCode: # Copyright 2024 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from multiprocessing import Process, Queue
from pathlib import Path
from types import FunctionType
from typing import Dict, List, Tuple, Type, Union

import torch
import torchaudio
from tqdm import tqdm


class BaseTask:
    """
    The basic task to complete after the VAD of a single audio file. It keeps the VAD output untouched as its result
    and serves as the base class of the other tasks.
    """

    def __init__(self, audio_file: Path):
        """
        Initialize the base task.

        :param audio_file: The audio file assigned to the task.
        """
        # The result stays `None` until `do_task` is called:
        self._result = None
        self._audio_file = audio_file

    @property
    def audio_file(self) -> Path:
        """
        Get the audio file of the task.

        :returns: The audio file of the task.
        """
        return self._audio_file

    def do_task(
        self, speech_timestamps: Union[List[Dict[str, int]], List[List[Dict[str, int]]]]
    ):
        """
        Do the task on the given speech timestamps. The base task stores them as they are as its result.

        :param speech_timestamps: The speech timestamps to do the task on as outputted from the VAD.
        """
        self._result = speech_timestamps

    def get_result(self) -> Tuple[str, list]:
        """
        Get the result of the task - a tuple of the audio file name and the result.

        :returns: The result of the task.
        """
        file_name = self._audio_file.name
        return file_name, self._result

    def to_tuple(self) -> Tuple[str, dict]:
        """
        Convert the task to a tuple of its class name and its initialization keyword arguments, so it can be
        reconstructed later (used for multiprocessing to pass in queue).

        :returns: The converted task.
        """
        init_kwargs = {"audio_file": self._audio_file}
        return type(self).__name__, init_kwargs


class SpeechDiarizationTask(BaseTask):
    """
    A speech diarization task. It turns the per channel VAD speech timestamps into a single list of speech segments
    labeled by speaker, where each channel is a speaker.
    """

    def __init__(self, audio_file: Path, speaker_labels: List[str]):
        """
        Initialize the speech diarization task.

        :param audio_file:     The audio file assigned to the task.
        :param speaker_labels: The speaker labels to use for the diarization. If not given, the speakers will be named
                               "speaker_0", "speaker_1", etc.
        """
        super().__init__(audio_file=audio_file)
        self._speaker_labels = speaker_labels

    def do_task(self, speech_timestamps: List[List[Dict[str, int]]]):
        """
        Do the task on the given speech timestamps. The result is a list of (start, end, speaker_label) tuples sorted
        by the tuples' natural order, meaning by start time first.

        :param speech_timestamps: The speech timestamps per channel to do the task on as outputted from the VAD.
        """
        # Use the given labels, or default to a numbered label per channel when none were given:
        labels = self._speaker_labels
        if not labels:
            labels = [f"speaker_{i}" for i in range(len(speech_timestamps))]

        # Pair each channel with its label and flatten all the channels into a single list of segments:
        segments = []
        for label, channel_timestamps in zip(labels, speech_timestamps):
            for timestamp in channel_timestamps:
                segments.append((timestamp["start"], timestamp["end"], label))

        self._result = sorted(segments)

    def to_tuple(self) -> Tuple[str, dict]:
        """
        Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue). The speaker
        labels are added to the keyword arguments of the base task.

        :returns: The converted task.
        """
        class_name, base_kwargs = super().to_tuple()
        init_kwargs = dict(base_kwargs)
        init_kwargs["speaker_labels"] = self._speaker_labels
        return class_name, init_kwargs


class TaskCreator:
    """
    A factory of the tasks to run after the VAD. It creates tasks of a configured type per audio file and can rebuild
    a task out of its tuple form.
    """

    #: A map from task class name to task class to use in `from_tuple`:
    _MAP = {
        task_class.__name__: task_class
        for task_class in (BaseTask, SpeechDiarizationTask)
    }

    def __init__(self, task_type: Type[BaseTask], task_kwargs: dict = None):
        """
        Initialize the task creator.

        :param task_type:   The task type - a `BaseTask` subclass.
        :param task_kwargs: Additional keyword arguments to pass to the to be created tasks.
        """
        self._task_type = task_type
        self._task_kwargs = task_kwargs if task_kwargs else {}

    def create_task(self, audio_file: Path) -> BaseTask:
        """
        Create a task of the configured type for the given audio file.

        :param audio_file: The audio file to assign to the task.

        :returns: The created task.
        """
        task_class = self._task_type
        return task_class(audio_file=audio_file, **self._task_kwargs)

    @classmethod
    def from_tuple(cls, task_tuple: Tuple[str, dict]) -> BaseTask:
        """
        Create a task from a tuple of the task class name and the task initialization keyword arguments (the output of
        a task's `to_tuple` method).

        :param task_tuple: The task tuple to create the task from.

        :returns: The created task.
        """
        class_name, init_kwargs = task_tuple
        task_class = cls._MAP[class_name]
        return task_class(**init_kwargs)


class VoiceActivityDetector:
    """
    A voice activity detection wrapper for the silero VAD model - https://github.com/snakers4/silero-vad.

    The model must be loaded with `load` before `detect_voice` can be used.
    """

    def __init__(
        self,
        # Model loading kwargs:
        use_onnx: bool = True,
        force_onnx_cpu: bool = True,
        # Detection kwargs:
        threshold: float = 0.5,
        sampling_rate: int = 16_000,
        min_speech_duration_ms: int = 250,
        max_speech_duration_s: float = float("inf"),
        min_silence_duration_ms: int = 100,
        window_size_samples: int = 512,
        speech_pad_ms: int = 30,
        return_seconds: bool = False,
        per_channel: bool = False,
    ):
        """
        Initialize the voice activity detector.

        :param use_onnx:                Whether to use ONNX for inference. Default is True.
        :param force_onnx_cpu:          Whether to force ONNX to use CPU for inference. Default is True.
        :param threshold:               Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
                                        probabilities ABOVE this value are considered as SPEECH. It is better to tune
                                        this parameter for each dataset separately, but "lazy" 0.5 is pretty good for
                                        most datasets.
        :param sampling_rate:           Currently, silero VAD models support 8000 and 16000 sample rates.
        :param min_speech_duration_ms:  Final speech chunks shorter than min_speech_duration_ms are thrown out.
        :param max_speech_duration_s:   Maximum duration of speech chunks in seconds. Chunks longer than
                                        `max_speech_duration_s` will be split at the timestamp of the last silence that
                                        lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise,
                                        they will be split aggressively just before max_speech_duration_s.
        :param min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms before
                                        separating it.
        :param window_size_samples:     Audio chunks of window_size_samples size are fed to the silero VAD model.
                                        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000
                                        sample rate and 256, 512, 768 samples for 8000 sample rate. Values other than
                                        these may affect model performance!
        :param speech_pad_ms:           Final speech chunks are padded by speech_pad_ms each side.
        :param return_seconds:          Whether return timestamps in seconds. False means to return timestamps in
                                        samples (default - False).
        :param per_channel:             Whether to return timestamps per channel (default - False). This will run VAD
                                        on each channel separately and return a list of timestamps per channel.
        """
        # Store configurations:
        self._use_onnx = use_onnx
        self._force_onnx_cpu = force_onnx_cpu
        self._threshold = threshold
        self._sampling_rate = sampling_rate
        self._min_speech_duration_ms = min_speech_duration_ms
        self._max_speech_duration_s = max_speech_duration_s
        self._min_silence_duration_ms = min_silence_duration_ms
        self._window_size_samples = window_size_samples
        self._speech_pad_ms = speech_pad_ms
        self._return_seconds = return_seconds
        self._per_channel = per_channel

        # Prepare the model variables (set by `load`):
        self._model: torch.nn.Module = None
        self._get_speech_timestamps: FunctionType = None

    def load(self, force_reload: bool = True):
        """
        Load the VAD model and its speech timestamps function from torch hub.

        :param force_reload: Passed to `torch.hub.load` as its `force_reload` argument. Default is True.
        """
        model, utils = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
            force_reload=force_reload,
            onnx=self._use_onnx,
            force_onnx_cpu=self._force_onnx_cpu,
        )
        self._model = model
        # Only the first utility is needed. The rest (save_audio, read_audio, VADIterator, collect_chunks) are not
        # used by this class:
        self._get_speech_timestamps, *_ = utils

    def detect_voice(
        self,
        audio_file: Path,
    ) -> Union[List[Dict[str, int]], List[List[Dict[str, int]]]]:
        """
        Infer the audio through the VAD model and return the speech timestamps.

        :param audio_file: The audio file to infer.

        :returns: The speech timestamps in the audio. A list of timestamps where each timestamp is a dictionary with the
                 following keys:

                 * "start": The start of the speech in the audio.
                 * "end":   The end of the speech in the audio.

                 If `per_channel` is True, a list of timestamps per channel will be returned.

        :raises RuntimeError: If the model was not loaded yet.
        """
        # Fail with a clear message instead of the "'NoneType' object is not callable" error calling an unloaded model
        # would raise:
        if self._get_speech_timestamps is None:
            raise RuntimeError(
                "The VAD model is not loaded. Call `load` before calling `detect_voice`."
            )

        # Read the audio as a tensor:
        audio = self._read_audio(audio_file)

        # Detect speech on the whole audio:
        if not self._per_channel:
            return self._infer_speech_timestamps(audio)

        # Per channel (iterating the audio yields its channels one by one):
        return [self._infer_speech_timestamps(channel) for channel in audio]

    def _infer_speech_timestamps(self, audio) -> List[Dict[str, int]]:
        """
        Run the loaded speech timestamps function on the given audio with the detection configuration of this VAD.

        :param audio: The audio to infer (a whole audio or a single channel of it).

        :returns: The speech timestamps as returned from the speech timestamps function.
        """
        return self._get_speech_timestamps(
            audio,
            self._model,
            threshold=self._threshold,
            min_speech_duration_ms=self._min_speech_duration_ms,
            max_speech_duration_s=self._max_speech_duration_s,
            min_silence_duration_ms=self._min_silence_duration_ms,
            speech_pad_ms=self._speech_pad_ms,
            sampling_rate=self._sampling_rate,
            window_size_samples=self._window_size_samples,
            return_seconds=self._return_seconds,
        )

    def _read_audio(
        self,
        path: Path,
    ) -> torch.Tensor:
        """
        Read the audio from the given path and return it as a tensor. The audio is resampled to the configured sampling
        rate if it was recorded in a different one. If not working per channel, a multichannel audio is averaged into
        a single channel and the channels dimension is squeezed.

        :param path: The path to the audio file.

        :returns: The audio as a tensor.
        """
        # Read the audio:
        audio, sampling_rate = torchaudio.load(str(path))

        # Check if the audio is stereo and if so, convert it to mono (only if not per channel):
        if audio.size(0) > 1 and not self._per_channel:
            audio = audio.mean(dim=0, keepdim=True)

        # Resample the audio if needed:
        if sampling_rate != self._sampling_rate:
            transform = torchaudio.transforms.Resample(
                orig_freq=sampling_rate, new_freq=self._sampling_rate
            )
            audio = transform(audio)

        # Return the audio (squeeze if not per channel):
        return audio if self._per_channel else audio.squeeze(0)


#: The sentinel value to send into multiprocessing queues to stop the process (a worker process stops listening to its
#: tasks queue once it gets it, and then puts it in its results queue to mark it is done):
_MULTIPROCESSING_STOP_MARK = "STOP"


def _multiprocessing_complete_tasks(
    vad_init_kwargs: dict, tasks_queue: Queue, results_queue: Queue
):
    """
    Complete the tasks in the given queue and put the results in the given results queue. The function will stop when
    the given tasks queue will receive the stop mark. It is aimed to be used with multiprocessing as a process.

    Every task yields exactly one item in the results queue: `(False, task_result)` on success or
    `(True, (identifier, error_message))` on failure, where the identifier is the audio file name (or the string form
    of the task tuple if the task could not even be reconstructed). Once the stop mark is read from the tasks queue,
    the stop mark is put in the results queue to signal that this worker is done.

    :param vad_init_kwargs: The VAD initialization kwargs.
    :param tasks_queue:     A queue to get the tasks from.
    :param results_queue:   A queue to put the results in.
    """
    # Initialize and load the VAD (without forcing a reload of the model):
    vad = VoiceActivityDetector(**vad_init_kwargs)
    vad.load(force_reload=False)

    # Start listening to the tasks queue:
    while True:
        # Get the next task tuple (or the stop mark):
        task_tuple: Tuple[str, dict] = tasks_queue.get()
        if task_tuple == _MULTIPROCESSING_STOP_MARK:
            break
        # Stays `None` until the task is reconstructed, so the error handler below knows whether a task object
        # exists. Reading `.audio_file` off the raw tuple would raise inside the handler, kill this worker before it
        # sends its stop mark and leave the parent process waiting forever on the results queue.
        task = None
        try:
            # Create the task:
            task = TaskCreator.from_tuple(task_tuple=task_tuple)
            # Run the file through the VAD:
            speech_timestamps = vad.detect_voice(audio_file=task.audio_file)
            # Complete the task:
            task.do_task(speech_timestamps=speech_timestamps)
            # Build the result:
            result = (False, task.get_result())
        except Exception as exception:
            # Build the error (best effort per file - a failure must not stop the worker):
            failed_on = (
                task.audio_file.name if task is not None else str(task_tuple)
            )
            result = (True, (failed_on, str(exception)))
        # Collect the result / error:
        results_queue.put(result)

    # Mark the end of the tasks:
    results_queue.put(_MULTIPROCESSING_STOP_MARK)


# Get the global logger: use the logger of the MLRun context named "silero_vad" when the `mlrun` package can be
# imported, otherwise fall back to the root logger of Python's `logging` module:
try:
    import mlrun

    _LOGGER = mlrun.get_or_create_ctx("silero_vad").logger
except ModuleNotFoundError:
    # `mlrun` is not installed:
    _LOGGER = logging.getLogger()


def detect_voice(
    # Input kwargs:
    data_path: Union[str, Path, List[Union[str, Path]]],
    # Model loading kwargs:
    use_onnx: bool = True,
    force_onnx_cpu: bool = True,
    # Detection kwargs:
    threshold: float = 0.5,
    sampling_rate: int = 16_000,
    min_speech_duration_ms: int = 250,
    max_speech_duration_s: float = float("inf"),
    min_silence_duration_ms: int = 100,
    window_size_samples: int = 512,
    speech_pad_ms: int = 30,
    return_seconds: bool = False,
    per_channel: bool = False,
    # Other kwargs:
    use_multiprocessing: int = 0,
    verbose: bool = False,
) -> Tuple[dict, dict]:
    """
    Perform voice activity detection on given audio files using the silero VAD model -
    https://github.com/snakers4/silero-vad. The end result is a dictionary with the file names as keys and their
    VAD timestamps dictionaries as value, returned together with a dictionary of the files that failed.

    For example::

        {
            "file_1.wav": [
                {"start": 0, "end": 16000},
                {"start": 16000, "end": 32000},
                {"start": 32000, "end": 48000},
                ...
            ],
            "file_2.wav": [
                {"start": 0, "end": 16000},
                {"start": 16000, "end": 32000},
                {"start": 32000, "end": 48000},
                ...
            ],
            ...
        }


    :param data_path:               The path to the audio files to run the VAD on. Can be a path to a single file, a
                                    path to a directory or a list of paths to files.
    :param use_onnx:                Whether to use ONNX for inference. Default is True.
    :param force_onnx_cpu:          Whether to force ONNX to use CPU for inference. Default is True.
    :param threshold:               Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
                                    probabilities ABOVE this value are considered as SPEECH. It is better to tune
                                    this parameter for each dataset separately, but "lazy" 0.5 is pretty good for
                                    most datasets.
    :param sampling_rate:           Currently, silero VAD models support 8000 and 16000 sample rates.
    :param min_speech_duration_ms:  Final speech chunks shorter than min_speech_duration_ms are thrown out.
    :param max_speech_duration_s:   Maximum duration of speech chunks in seconds. Chunks longer than
                                    `max_speech_duration_s` will be split at the timestamp of the last silence that
                                    lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will
                                    be split aggressively just before max_speech_duration_s.
    :param min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms before separating
                                    it.
    :param window_size_samples:     Audio chunks of window_size_samples size are fed to the silero VAD model.

                                    WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000
                                    sample rate and 256, 512, 768 samples for 8000 sample rate. Values other than
                                    these may affect model performance!
    :param speech_pad_ms:           Final speech chunks are padded by speech_pad_ms each side.
    :param return_seconds:          Whether return timestamps in seconds. False means to return timestamps in samples
                                    (default - False).
    :param per_channel:             Whether to return timestamps per channel (default - False). This will run VAD on
                                    each channel separately and return a list of timestamps per channel.
    :param use_multiprocessing:     The number of workers to use for multiprocessing. If 0, no multiprocessing will
                                    be used. Default is 0.
    :param verbose:                 Verbosity.

    :returns: A tuple of two dictionaries: the successes (file name to its VAD timestamps) and the errors (file name
              to the error message of the exception that was raised while processing it).
    """
    # Get the input audio files to run the VAD on:
    if verbose:
        _LOGGER.info("Collecting audio files.")
    audio_files = _get_audio_files(data_path=data_path)
    if verbose:
        _LOGGER.info(f"Collected {len(audio_files)} audio files.")

    # Gather the VAD initialization kwargs (a plain dictionary, as it is also handed to the worker processes):
    vad_init_kwargs = {
        "use_onnx": use_onnx,
        "force_onnx_cpu": force_onnx_cpu,
        "threshold": threshold,
        "sampling_rate": sampling_rate,
        "min_speech_duration_ms": min_speech_duration_ms,
        "max_speech_duration_s": max_speech_duration_s,
        "min_silence_duration_ms": min_silence_duration_ms,
        "window_size_samples": window_size_samples,
        "speech_pad_ms": speech_pad_ms,
        "return_seconds": return_seconds,
        "per_channel": per_channel,
    }

    # Create the task creator (the base task keeps the speech timestamps as they are):
    task_creator = TaskCreator(task_type=BaseTask)

    # Run the detection, in worker processes if requested:
    run_kwargs = {
        "audio_files": audio_files,
        "description": "Detecting voice",
        "vad_init_kwargs": vad_init_kwargs,
        "task_creator": task_creator,
        "verbose": verbose,
    }
    if use_multiprocessing:
        results = _parallel_run(n_workers=use_multiprocessing, **run_kwargs)
    else:
        results = _run(**run_kwargs)

    # Process the results:
    return _process_results(results=results, verbose=verbose)


def diarize(
    # Input / Output kwargs:
    data_path: Union[str, Path, List[Union[str, Path]]],
    # Model loading kwargs:
    use_onnx: bool = True,
    force_onnx_cpu: bool = True,
    # Detection kwargs:
    threshold: float = 0.5,
    sampling_rate: int = 16_000,
    min_speech_duration_ms: int = 250,
    max_speech_duration_s: float = float("inf"),
    min_silence_duration_ms: int = 100,
    window_size_samples: int = 512,
    speech_pad_ms: int = 30,
    # Diarization kwargs:
    speaker_labels: List[str] = None,
    # Other kwargs:
    use_multiprocessing: int = 0,
    verbose: bool = False,
) -> Tuple[dict, dict]:
    """
    Perform speech diarization on given audio files using the silero VAD model - https://github.com/snakers4/silero-vad.
    The speech diarization is performed per channel so that each channel in the audio belong to a different speaker. The
    end result is a dictionary with the file names as keys and their diarization as value, returned together with a
    dictionary of the files that failed. A diarization is a list of tuples: (start, end, speaker_label).

    For example::

        {
            "file_1.wav": [
                (0.0, 1.0, "speaker_0"),
                (1.0, 2.0, "speaker_1"),
                (2.0, 3.0, "speaker_0"),
                ...
            ],
            "file_2.wav": [
                (0.0, 1.0, "speaker_0"),
                (1.0, 2.0, "speaker_1"),
                (2.0, 3.0, "speaker_0"),
                ...
            ],
            ...
        }


    :param data_path:               The path to the audio files to diarize. Can be a path to a single file, a path to a
                                    directory or a list of paths to files.
    :param use_onnx:                Whether to use ONNX for inference. Default is True.
    :param force_onnx_cpu:          Whether to force ONNX to use CPU for inference. Default is True.
    :param threshold:               Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
                                    probabilities ABOVE this value are considered as SPEECH. It is better to tune
                                    this parameter for each dataset separately, but "lazy" 0.5 is pretty good for
                                    most datasets.
    :param sampling_rate:           Currently, silero VAD models support 8000 and 16000 sample rates.
    :param min_speech_duration_ms:  Final speech chunks shorter than min_speech_duration_ms are thrown out.
    :param max_speech_duration_s:   Maximum duration of speech chunks in seconds. Chunks longer than
                                    `max_speech_duration_s` will be split at the timestamp of the last silence that
                                    lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will
                                    be split aggressively just before max_speech_duration_s.
    :param min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms before separating
                                    it.
    :param window_size_samples:     Audio chunks of window_size_samples size are fed to the silero VAD model.

                                    WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000
                                    sample rate and 256, 512, 768 samples for 8000 sample rate. Values other than
                                    these may affect model performance!
    :param speech_pad_ms:           Final speech chunks are padded by speech_pad_ms each side.
    :param speaker_labels:          The speaker labels to use for the diarization. If not given, the speakers will be
                                    named "speaker_0", "speaker_1", etc.
    :param use_multiprocessing:     The number of workers to use for multiprocessing. If 0, no multiprocessing will
                                    be used. Default is 0.
    :param verbose:                 Verbosity.

    :returns: A tuple of two dictionaries: the successes (file name to its diarization) and the errors (file name to
              the error message of the exception that was raised while processing it).
    """
    # Get the input audio files to diarize:
    if verbose:
        _LOGGER.info("Collecting audio files.")
    audio_files = _get_audio_files(data_path=data_path)
    if verbose:
        _LOGGER.info(f"Collected {len(audio_files)} audio files.")

    # Gather the VAD initialization kwargs (a plain dictionary, as it is also handed to the worker processes). The
    # diarization always asks the VAD for timestamps in seconds and per channel (each channel is a speaker):
    vad_init_kwargs = {
        "use_onnx": use_onnx,
        "force_onnx_cpu": force_onnx_cpu,
        "threshold": threshold,
        "sampling_rate": sampling_rate,
        "min_speech_duration_ms": min_speech_duration_ms,
        "max_speech_duration_s": max_speech_duration_s,
        "min_silence_duration_ms": min_silence_duration_ms,
        "window_size_samples": window_size_samples,
        "speech_pad_ms": speech_pad_ms,
        "return_seconds": True,
        "per_channel": True,
    }

    # Create the task creator (each task turns the per channel timestamps into a diarization):
    task_creator = TaskCreator(
        task_type=SpeechDiarizationTask, task_kwargs={"speaker_labels": speaker_labels}
    )

    # Run the diarization, in worker processes if requested:
    run_kwargs = {
        "audio_files": audio_files,
        "description": "Diarizing",
        "vad_init_kwargs": vad_init_kwargs,
        "task_creator": task_creator,
        "verbose": verbose,
    }
    if use_multiprocessing:
        results = _parallel_run(n_workers=use_multiprocessing, **run_kwargs)
    else:
        results = _run(**run_kwargs)

    # Process the results:
    return _process_results(results=results, verbose=verbose)


def _get_audio_files(
    data_path: Union[Path, str, list],
) -> List[Path]:
    """
    Get the audio files from the data path. If a path to a directory is given, the files directly inside it whose
    name contains a dot (an extension) will be collected, sorted by path so the order is the same on every run.

    :param data_path: The data path to collect the audio files from: a file path, a directory path or a list (or
                      tuple) of such paths.

    :returns: The audio files list.

    :raises ValueError: If a given path is neither an existing directory nor an existing file.
    """
    # Check if given a collection of paths (each one is resolved on its own, keeping the given order):
    if isinstance(data_path, (list, tuple)):
        audio_files = []
        for path in data_path:
            audio_files.extend(_get_audio_files(data_path=path))
        return audio_files

    # Check if given a single string path to cast it to a `pathlib.Path`:
    if isinstance(data_path, str):
        data_path = Path(data_path).absolute()

    # Check if the path is of a directory or a file:
    if data_path.is_dir():
        # Get the files inside the directory. The glob pattern also matches sub-directories whose name contains a
        # dot, so these are filtered out - they are not audio files and would only end up as errors later on:
        return sorted(
            entry for entry in data_path.glob("*.*") if entry.is_file()
        )
    if data_path.is_file():
        return [data_path]

    raise ValueError(
        f"Unrecognized data path. The parameter `data_path` must be a valid path to either a directory path or a "
        f"file. Given: {str(data_path)} "
    )


def _run(
    audio_files: List[Path],
    description: str,
    vad_init_kwargs: dict,
    task_creator: TaskCreator,
    verbose: bool,
) -> List[Tuple[bool, Tuple[str, list]]]:
    """
    Load a VAD and use it to complete the tasks that will be created on the provided files using the given task creator.
    The files are processed one after the other in the current process.

    :param audio_files:     The audio files to use.
    :param description:     The description to use for the progress bar.
    :param vad_init_kwargs: The VAD initialization keyword arguments.
    :param task_creator:    The task creator to use to create the tasks.
    :param verbose:         Verbosity. When off, the progress bar is disabled as well.

    :returns: The collected results, one per audio file: `(False, task_result)` for a success or
              `(True, (file_name, error_message))` for a failure.
    """
    # Load the VAD:
    vad = VoiceActivityDetector(**vad_init_kwargs)
    if verbose:
        _LOGGER.info("Loading the VAD model.")
    vad.load()
    if verbose:
        _LOGGER.info("VAD model loaded.")

    # Run the VAD on the audio files and collect the results:
    results = []
    for audio_file in tqdm(
        audio_files,
        desc=description,
        unit="file",
        total=len(audio_files),
        disable=not verbose,
    ):
        try:
            # Create the task:
            task = task_creator.create_task(audio_file=audio_file)
            # Run the file through the VAD:
            speech_timestamps = vad.detect_voice(audio_file=audio_file)
            # Complete the task:
            task.do_task(speech_timestamps=speech_timestamps)
            # Collect the result:
            results.append((False, task.get_result()))
        except Exception as exception:
            # Collect the error. Deliberately broad: a failure in one file is recorded and must not stop the run on
            # the rest of the files:
            results.append((True, (audio_file.name, str(exception))))

    return results


def _parallel_run(
    n_workers: int,
    audio_files: List[Path],
    description: str,
    vad_init_kwargs: dict,
    task_creator: TaskCreator,
    verbose: bool,
) -> List[Tuple[bool, Tuple[str, list]]]:
    """
    Run multiple VAD workers with multiprocessing to complete the tasks that will be created on the provided files using
    the given task creator.

    :param n_workers:       The number of workers to use. Must be positive. It is reduced to the number of audio files
                            if it is larger than it.
    :param audio_files:     The audio files to use.
    :param description:     The description to use for the progress bar.
    :param vad_init_kwargs: The VAD initialization keyword arguments.
    :param task_creator:    The task creator to use to create the tasks.
    :param verbose:         Verbosity. When off, the progress bar is disabled as well.

    :returns: The collected results, in the order the workers completed them: `(False, task_result)` for a success or
              `(True, (identifier, error_message))` for a failure.

    :raises ValueError: If the number of workers is not positive.
    """
    # Nothing to do without audio files. Without this guard the number of workers would be reduced to 0, no process
    # would be started and the results loop below would wait forever for stop marks that are never sent:
    if not audio_files:
        return []

    # A non-positive number of workers starts no process as well, so the same endless wait would happen:
    if n_workers < 1:
        raise ValueError(
            f"The number of workers must be a positive integer. Given: {n_workers}"
        )

    # Load the VAD (download once, and it will be loaded then per process later on):
    if verbose:
        _LOGGER.info("Loading the VAD model.")
    vad = VoiceActivityDetector(**vad_init_kwargs)
    vad.load()
    if verbose:
        _LOGGER.info("VAD model loaded.")

    # Check the number of workers:
    if n_workers > len(audio_files):
        _LOGGER.warning(
            f"The number of workers ({n_workers}) is larger than the number of audio files ({len(audio_files)}). "
            f"Setting the number of workers to {len(audio_files)}."
        )
        n_workers = len(audio_files)

    # Initialize the multiprocessing queues:
    tasks_queue = Queue()
    results_queue = Queue()

    # Initialize the multiprocessing processes:
    task_completion_processes = [
        Process(
            target=_multiprocessing_complete_tasks,
            kwargs={
                "vad_init_kwargs": vad_init_kwargs,
                "tasks_queue": tasks_queue,
                "results_queue": results_queue,
            },
        )
        for _ in range(n_workers)
    ]

    # Start the multiprocessing processes:
    for p in task_completion_processes:
        p.start()

    # Put the tasks in the queue (as tuples, so they can be passed between the processes):
    for audio_file in audio_files:
        tasks_queue.put(task_creator.create_task(audio_file=audio_file).to_tuple())

    # Put the stop marks in the queue (one per worker, after all the tasks):
    for _ in range(n_workers):
        tasks_queue.put(_MULTIPROCESSING_STOP_MARK)

    # Collect the results until every worker sent back its stop mark:
    results = []
    stop_marks_counter = 0
    with tqdm(
        desc=description,
        unit="file",
        total=len(audio_files),
        disable=not verbose,
    ) as progressbar:
        while stop_marks_counter < n_workers:
            # Get a result from the queue:
            result: Tuple[bool, Tuple[str, list]] = results_queue.get()
            if result == _MULTIPROCESSING_STOP_MARK:
                stop_marks_counter += 1
            else:
                # Collect the result:
                results.append(result)
                progressbar.update(1)

    # Wait for the processes to finish:
    for p in task_completion_processes:
        p.join()

    return results


def _process_results(
    results: List[Tuple[bool, Tuple[str, list]]], verbose: bool
) -> Tuple[dict, dict]:
    """
    Split the collected task results into successes and errors.

    :param results: The results to process. Each item is a tuple of an error flag and a pair whose first element is
                    used as the key (the file name) and its second element as the value.
    :param verbose: Verbosity.

    :returns: A tuple of two dictionaries: the successes and the errors.
    """
    if verbose:
        _LOGGER.info("Summarizing the results.")

    successes, errors = {}, {}
    for failed, outcome in results:
        # Route the outcome to the dictionary matching its error flag:
        target = errors if failed else successes
        target[outcome[0]] = outcome[1]

    if verbose:
        _LOGGER.info(f"Done ({len(successes)}/{len(successes) + len(errors)})\n")

    return successes, errors
 - origin_filename: '' - image: '' - command: '' + code_origin: '' + base_image: mlrun/mlrun + filename: silero_vad.py entry_points: audio_file: - doc: Get the audio file of the task. - lineno: 43 - has_varargs: false outputs: - doc: The audio file of the task. type: Path parameters: - name: self - has_kwargs: false name: audio_file - do_task: - doc: Do the task on the given speech timestamps. The task will diarize the VAD - speech timestamps into speakers. - lineno: 94 + doc: Get the audio file of the task. + has_kwargs: false has_varargs: false + lineno: 42 + do_task: parameters: - name: self - name: speech_timestamps - type: List[List[Dict[str, int]]] + type: list[list[dict[str, int]]] doc: The speech timestamps per channel to do the task on as outputted from the VAD. - has_kwargs: false name: do_task - get_result: - doc: Get the result of the task. A tuple of the audio file name and the result. - lineno: 61 + doc: Do the task on the given speech timestamps. The task will diarize the VAD + speech timestamps into speakers. + has_kwargs: false has_varargs: false + lineno: 93 + get_result: outputs: - doc: The result of the task. - type: Tuple[str, list] + type: tuple[str, list] parameters: - name: self - has_kwargs: false name: get_result - to_tuple: - doc: Convert the task to a tuple to reconstruct it later (used for multiprocessing - to pass in queue). - lineno: 116 + doc: Get the result of the task. A tuple of the audio file name and the result. + has_kwargs: false has_varargs: false + lineno: 60 + to_tuple: outputs: - doc: The converted task. - type: Tuple[str, dict] + type: tuple[str, dict] parameters: - name: self - has_kwargs: false name: to_tuple - create_task: - doc: Create a task with the given audio file. - lineno: 146 + doc: Convert the task to a tuple to reconstruct it later (used for multiprocessing + to pass in queue). + has_kwargs: false has_varargs: false + lineno: 115 + create_task: outputs: - doc: The created task. 
type: BaseTask @@ -79,26 +77,26 @@ spec: - name: audio_file type: Path doc: The audio file to assign to the task. - has_kwargs: false name: create_task - from_tuple: - doc: Create a task from a tuple of the audio file name and the task kwargs. - lineno: 157 + doc: Create a task with the given audio file. + has_kwargs: false has_varargs: false + lineno: 145 + from_tuple: outputs: - doc: The created task. type: BaseTask parameters: - name: cls - name: task_tuple - type: Tuple[str, dict] + type: tuple[str, dict] doc: The task tuple to create the task from. - has_kwargs: false name: from_tuple - load: - doc: Load the VAD model. - lineno: 234 + doc: Create a task from a tuple of the audio file name and the task kwargs. + has_kwargs: false has_varargs: false + lineno: 156 + load: parameters: - name: self - name: force_reload @@ -106,24 +104,14 @@ spec: doc: Whether to force reload the model even if it was already loaded. Default is True. default: true - has_kwargs: false name: load - detect_voice: - doc: "Perform voice activity detection on given audio files using the silero\ - \ VAD model -\nhttps://github.com/snakers4/silero-vad. The end result is a\ - \ dictionary with the file names as keys and their\nVAD timestamps dictionaries\ - \ as value.\n\nFor example::\n\n {\n \"file_1.wav\": [\n \ - \ {\"start\": 0, \"end\": 16000},\n {\"start\": 16000, \"end\"\ - : 32000},\n {\"start\": 32000, \"end\": 48000},\n ...\n\ - \ ],\n \"file_2.wav\": [\n {\"start\": 0, \"end\"\ - : 16000},\n {\"start\": 16000, \"end\": 32000},\n {\"\ - start\": 32000, \"end\": 48000},\n ...\n ],\n ...\n\ - \ }" - lineno: 393 + doc: Load the VAD model. + has_kwargs: false has_varargs: false + lineno: 233 + detect_voice: parameters: - name: data_path - type: Union[str, Path, List[Union[str, Path]]] doc: The path to the audio files to diarize. Can be a path to a single file, a path to a directory or a list of paths to files. - name: use_onnx @@ -188,25 +176,23 @@ spec: type: bool doc: Verbosity. 
default: false - has_kwargs: false name: detect_voice - diarize: - doc: "Perform speech diarization on given audio files using the silero VAD model\ - \ - https://github.com/snakers4/silero-vad.\nThe speech diarization is performed\ - \ per channel so that each channel in the audio belong to a different speaker.\ - \ The\nend result is a dictionary with the file names as keys and their diarization\ - \ as value. A diarization is a list\nof tuples: (start, end, speaker_label).\n\ - \nFor example::\n\n {\n \"file_1.wav\": [\n (0.0, 1.0,\ - \ \"speaker_0\"),\n (1.0, 2.0, \"speaker_1\"),\n (2.0,\ - \ 3.0, \"speaker_0\"),\n ...\n ],\n \"file_2.wav\"\ - : [\n (0.0, 1.0, \"speaker_0\"),\n (1.0, 2.0, \"speaker_1\"\ - ),\n (2.0, 3.0, \"speaker_0\"),\n ...\n ],\n\ - \ ...\n }" - lineno: 517 + doc: "Perform voice activity detection on given audio files using the silero\ + \ VAD model -\nhttps://github.com/snakers4/silero-vad. The end result is a\ + \ dictionary with the file names as keys and their\nVAD timestamps dictionaries\ + \ as value.\n\nFor example::\n\n {\n \"file_1.wav\": [\n \ + \ {\"start\": 0, \"end\": 16000},\n {\"start\": 16000, \"end\"\ + : 32000},\n {\"start\": 32000, \"end\": 48000},\n ...\n\ + \ ],\n \"file_2.wav\": [\n {\"start\": 0, \"end\"\ + : 16000},\n {\"start\": 16000, \"end\": 32000},\n {\"\ + start\": 32000, \"end\": 48000},\n ...\n ],\n ...\n\ + \ }" + has_kwargs: false has_varargs: false + lineno: 392 + diarize: parameters: - name: data_path - type: Union[str, Path, List[Union[str, Path]]] doc: The path to the audio files to diarize. Can be a path to a single file, a path to a directory or a list of paths to files. - name: use_onnx @@ -253,7 +239,7 @@ spec: doc: Final speech chunks are padded by speech_pad_ms each side. default: 30 - name: speaker_labels - type: List[str] + type: list[str] doc: The speaker labels to use for the diarization. If not given, the speakers will be named "speaker_0", "speaker_1", etc. 
default: null @@ -266,8 +252,21 @@ spec: type: bool doc: Verbosity. default: false - has_kwargs: false name: diarize - disable_auto_mount: false + doc: "Perform speech diarization on given audio files using the silero VAD model\ + \ - https://github.com/snakers4/silero-vad.\nThe speech diarization is performed\ + \ per channel so that each channel in the audio belong to a different speaker.\ + \ The\nend result is a dictionary with the file names as keys and their diarization\ + \ as value. A diarization is a list\nof tuples: (start, end, speaker_label).\n\ + \nFor example::\n\n {\n \"file_1.wav\": [\n (0.0, 1.0,\ + \ \"speaker_0\"),\n (1.0, 2.0, \"speaker_1\"),\n (2.0,\ + \ 3.0, \"speaker_0\"),\n ...\n ],\n \"file_2.wav\"\ + : [\n (0.0, 1.0, \"speaker_0\"),\n (1.0, 2.0, \"speaker_1\"\ + ),\n (2.0, 3.0, \"speaker_0\"),\n ...\n ],\n\ + \ ...\n }" + has_kwargs: false + has_varargs: false + lineno: 516 + command: '' + description: Silero VAD (Voice Activity Detection) functions. default_handler: detect_voice -kind: job diff --git a/functions/src/silero_vad/silero_vad.py b/functions/src/silero_vad/silero_vad.py index a477d4ecf..877f49972 100644 --- a/functions/src/silero_vad/silero_vad.py +++ b/functions/src/silero_vad/silero_vad.py @@ -15,7 +15,6 @@ from multiprocessing import Process, Queue from pathlib import Path from types import FunctionType -from typing import Dict, List, Tuple, Type, Union import torch import torchaudio @@ -49,7 +48,7 @@ def audio_file(self) -> Path: return self._audio_file def do_task( - self, speech_timestamps: Union[List[Dict[str, int]], List[List[Dict[str, int]]]] + self, speech_timestamps: list[dict[str, int]] | list[list[dict[str, int]]] ): """ Do the task on the given speech timestamps. The base task will simply save the speech timestamps as the result. @@ -58,7 +57,7 @@ def do_task( """ self._result = speech_timestamps - def get_result(self) -> Tuple[str, list]: + def get_result(self) -> tuple[str, list]: """ Get the result of the task. 
A tuple of the audio file name and the result. @@ -66,7 +65,7 @@ def get_result(self) -> Tuple[str, list]: """ return self._audio_file.name, self._result - def to_tuple(self) -> Tuple[str, dict]: + def to_tuple(self) -> tuple[str, dict]: """ Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue). @@ -80,7 +79,7 @@ class SpeechDiarizationTask(BaseTask): A speech diarization task. The task will diarize the VAD speech timestamps into speakers. """ - def __init__(self, audio_file: Path, speaker_labels: List[str]): + def __init__(self, audio_file: Path, speaker_labels: list[str]): """ Initialize the speech diarization task. @@ -91,7 +90,7 @@ def __init__(self, audio_file: Path, speaker_labels: List[str]): super().__init__(audio_file=audio_file) self._speaker_labels = speaker_labels - def do_task(self, speech_timestamps: List[List[Dict[str, int]]]): + def do_task(self, speech_timestamps: list[list[dict[str, int]]]): """ Do the task on the given speech timestamps. The task will diarize the VAD speech timestamps into speakers. @@ -113,7 +112,7 @@ def do_task(self, speech_timestamps: List[List[Dict[str, int]]]): speech_diarization.sort() self._result = speech_diarization - def to_tuple(self) -> Tuple[str, dict]: + def to_tuple(self) -> tuple[str, dict]: """ Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue). @@ -134,7 +133,7 @@ class TaskCreator: SpeechDiarizationTask.__name__: SpeechDiarizationTask, } - def __init__(self, task_type: Type[BaseTask], task_kwargs: dict = None): + def __init__(self, task_type: type[BaseTask], task_kwargs: dict = None): """ Initialize the task creator. :param task_type: The task type - a `BaseTask` subclass. 
@@ -154,7 +153,7 @@ def create_task(self, audio_file: Path) -> BaseTask: return self._task_type(audio_file=audio_file, **self._task_kwargs) @classmethod - def from_tuple(cls, task_tuple: Tuple[str, dict]) -> BaseTask: + def from_tuple(cls, task_tuple: tuple[str, dict]) -> BaseTask: """ Create a task from a tuple of the audio file name and the task kwargs. @@ -256,7 +255,7 @@ def load(self, force_reload: bool = True): def detect_voice( self, audio_file: Path, - ) -> Union[List[Dict[str, int]], List[List[Dict[str, int]]]]: + ) -> list[dict[str, int]] | list[list[dict[str, int]]]: """ Infer the audio through the VAD model and return the speech timestamps. @@ -359,7 +358,7 @@ def _multiprocessing_complete_tasks( # Start listening to the tasks queue: while True: # Get the task: - task: Tuple[str, dict] = tasks_queue.get() + task: tuple[str, dict] = tasks_queue.get() if task == _MULTIPROCESSING_STOP_MARK: break try: @@ -392,7 +391,7 @@ def _multiprocessing_complete_tasks( def detect_voice( # Input kwargs: - data_path: Union[str, Path, List[Union[str, Path]]], + data_path: str | Path | list[str | Path], # Model loading kwargs: use_onnx: bool = True, force_onnx_cpu: bool = True, @@ -516,7 +515,7 @@ def detect_voice( def diarize( # Input / Output kwargs: - data_path: Union[str, Path, List[Union[str, Path]]], + data_path: str | Path | list[str | Path], # Model loading kwargs: use_onnx: bool = True, force_onnx_cpu: bool = True, @@ -529,7 +528,7 @@ def diarize( window_size_samples: int = 512, speech_pad_ms: int = 30, # Diarization kwargs: - speaker_labels: List[str] = None, + speaker_labels: list[str] = None, # Other kwargs: use_multiprocessing: int = 0, verbose: bool = False, @@ -640,8 +639,8 @@ def diarize( def _get_audio_files( - data_path: Union[Path, str, list], -) -> List[Path]: + data_path: Path | str | list, +) -> list[Path]: """ Get the audio files from the data path. If a path to a directory is given, all files in the directory will be collected. 
@@ -677,12 +676,12 @@ def _get_audio_files( def _run( - audio_files: List[Path], + audio_files: list[Path], description: str, vad_init_kwargs: dict, task_creator: TaskCreator, verbose: bool, -) -> List[Tuple[bool, Tuple[str, list]]]: +) -> list[tuple[bool, tuple[str, list]]]: """ Load a VAD and use it to complete the tasks that will be created on the provided files using the given task creator. @@ -697,7 +696,7 @@ def _run( # Load the VAD: vad = VoiceActivityDetector(**vad_init_kwargs) if verbose: - _LOGGER.info(f"Loading the VAD model.") + _LOGGER.info("Loading the VAD model.") vad.load() if verbose: _LOGGER.info("VAD model loaded.") @@ -729,12 +728,12 @@ def _run( def _parallel_run( n_workers: int, - audio_files: List[Path], + audio_files: list[Path], description: str, vad_init_kwargs: dict, task_creator: TaskCreator, verbose: bool, -) -> List[Tuple[bool, Tuple[str, list]]]: +) -> list[tuple[bool, tuple[str, list]]]: """ Run multiple VAD workers with multiprocessing to complete the tasks that will be created on the provided files using the given task creator. @@ -750,7 +749,7 @@ def _parallel_run( """ # Load the VAD (download once, and it will be loaded then per process later on): if verbose: - _LOGGER.info(f"Loading the VAD model.") + _LOGGER.info("Loading the VAD model.") vad = VoiceActivityDetector(**vad_init_kwargs) vad.load() if verbose: @@ -804,7 +803,7 @@ def _parallel_run( ) as progressbar: while True: # Get a result from the queue: - result: Tuple[bool, Tuple[str, list]] = results_queue.get() + result: tuple[bool, tuple[str, list]] = results_queue.get() if result == _MULTIPROCESSING_STOP_MARK: stop_marks_counter += 1 if stop_marks_counter == n_workers: @@ -822,8 +821,8 @@ def _parallel_run( def _process_results( - results: List[Tuple[bool, Tuple[str, list]]], verbose: bool -) -> Tuple[dict, dict]: + results: list[tuple[bool, tuple[str, list]]], verbose: bool +) -> tuple[dict, dict]: """ Process the results of the tasks. 
diff --git a/functions/src/sklearn_classifier/function.yaml b/functions/src/sklearn_classifier/function.yaml index 205df697d..80b257214 100644 --- a/functions/src/sklearn_classifier/function.yaml +++ b/functions/src/sklearn_classifier/function.yaml @@ -1,10 +1,23 @@ +metadata: + tag: '' + name: sklearn-classifier + categories: + - machine-learning + - model-training +verbose: false +kind: job spec: image: mlrun/mlrun - description: train any classifier using scikit-learn's API - default_handler: train_model + disable_auto_mount: false + build: + origin_filename: '' + functionSourceCode: IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKIyBHZW5lcmF0ZWQgYnkgbnVjbGlvLmV4cG9ydC5OdWNsaW9FeHBvcnRlcgoKaW1wb3J0IHdhcm5pbmdzCgp3YXJuaW5ncy5zaW1wbGVmaWx0ZXIoYWN0aW9uPSJpZ25vcmUiLCBjYXRlZ29yeT1GdXR1cmVXYXJuaW5nKQoKCmltcG9ydCBwYW5kYXMgYXMgcGQKZnJvbSBjbG91ZHBpY2tsZSBpbXBvcnQgZHVtcHMKZnJvbSBtbHJ1bi5kYXRhc3RvcmUgaW1wb3J0IERhdGFJdGVtCmZyb20gbWxydW4uZXhlY3V0aW9uIGltcG9ydCBNTENsaWVudEN0eApmcm9tIG1scnVuLm1sdXRpbHMuZGF0YSBpbXBvcnQgZ2V0X3NhbXBsZSwgZ2V0X3NwbGl0cwpmcm9tIG1scnVuLm1sdXRpbHMubW9kZWxzIGltcG9ydCBldmFsX21vZGVsX3YyLCBnZW5fc2tsZWFybl9tb2RlbApmcm9tIG1scnVuLnV0aWxzLmhlbHBlcnMgaW1wb3J0IGNyZWF0ZV9jbGFzcwoKCmRlZiB0cmFpbl9tb2RlbCgKICAgIGNvbnRleHQ6IE1MQ2xpZW50Q3R4LAogICAgbW9kZWxfcGtnX2NsYX
NzOiBzdHIsCiAgICBkYXRhc2V0OiBEYXRhSXRlbSwKICAgIGxhYmVsX2NvbHVtbjogc3RyID0gImxhYmVscyIsCiAgICBlbmNvZGVfY29sczogbGlzdFtzdHJdID0gW10sCiAgICBzYW1wbGU6IGludCA9IC0xLAogICAgdGVzdF9zaXplOiBmbG9hdCA9IDAuMzAsCiAgICB0cmFpbl92YWxfc3BsaXQ6IGZsb2F0ID0gMC43MCwKICAgIHRlc3Rfc2V0X2tleTogc3RyID0gInRlc3Rfc2V0IiwKICAgIG1vZGVsX2V2YWx1YXRvcj1Ob25lLAogICAgbW9kZWxzX2Rlc3Q6IHN0ciA9ICIiLAogICAgcGxvdHNfZGVzdDogc3RyID0gInBsb3RzIiwKICAgIGZpbGVfZXh0OiBzdHIgPSAicGFycXVldCIsCiAgICBtb2RlbF9wa2dfZmlsZTogc3RyID0gIiIsCiAgICByYW5kb21fc3RhdGU6IGludCA9IDEsCikgLT4gTm9uZToKICAgICIiInRyYWluIGEgY2xhc3NpZmllcgoKICAgIEFuIG9wdGlvbmFsIGN1dG9tIG1vZGVsIGV2YWx1YXRvciBjYW4gYmUgc3VwcGxpZWQgdGhhdCBzaG91bGQgaGF2ZSB0aGUgc2lnbmF0dXJlOgogICAgYG15X2N1c3RvbV9ldmFsdWF0b3IoY29udGV4dCwgeHZhbGlkLCB5dmFsaWQsIG1vZGVsKWAgYW5kIHJldHVybiBhIGRpY3Rpb25hcnkgb2YKICAgIHNjYWxhciAicmVzdWx0cyIsIGEgInBsb3RzIiBrZXlzIHdpdGggYSBsaXN0IG9mIFBsb3RBcnRpZmFjdHMsIGFuZAogICAgYW5kICJ0YWJsZXMiIGtleSBjb250YWluaW5nIGEgcmV0dXJuZWQgbGlzdCBvZiBUYWJsZUFydGlmYWN0cy4KCiAgICA6cGFyYW0gY29udGV4dDogICAgICAgICAgIHRoZSBmdW5jdGlvbiBjb250ZXh0CiAgICA6cGFyYW0gbW9kZWxfcGtnX2NsYXNzOiAgIHRoZSBtb2RlbCB0byB0cmFpbiwgZS5nLCAic2tsZWFybi5uZXVyYWxfbmV0d29ya3MuTUxQQ2xhc3NpZmllciIsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIG9yIGpzb24gbW9kZWwgY29uZmlnCiAgICA6cGFyYW0gZGF0YXNldDogICAgICAgICAgICgiZGF0YSIpIG5hbWUgb2YgcmF3IGRhdGEgZmlsZQogICAgOnBhcmFtIGxhYmVsX2NvbHVtbjogICAgICBncm91bmQtdHJ1dGggKHkpIGxhYmVscwogICAgOnBhcmFtIGVuY29kZV9jb2xzOiAgICAgICBkaWN0aW9uYXJ5IG9mIG5hbWVzIGFuZCBwcmVmaXhlcyBmb3IgY29sdW1ucyB0aGF0IGFyZQogICAgICAgICAgICAgICAgICAgICAgICAgICAgICB0byBob3QgYmUgZW5jb2RlZC4KICAgIDpwYXJhbSBzYW1wbGU6ICAgICAgICAgICAgU2VsZWN0cyB0aGUgZmlyc3QgbiByb3dzLCBvciBzZWxlY3QgYSBzYW1wbGUKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgc3RhcnRpbmcgZnJvbSB0aGUgZmlyc3QuIElmIG5lZ2F0aXZlIDwtMSwgc2VsZWN0CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGEgcmFuZG9tIHNhbXBsZQogICAgOnBhcmFtIHRlc3Rfc2l6ZTogICAgICAgICAoMC4wNSkgdGVzdCBzZXQgc2l6ZQogICAgOnBhcmFtIHRyYWluX3ZhbF9zcGxpdDogICAoMC43NSkgT25jZSB0aGUgdGVzdCBzZXQgaGFzIGJlZW
4gcmVtb3ZlZCB0aGUKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgdHJhaW5pbmcgc2V0IGdldHMgdGhpcyBwcm9wb3J0aW9uLgogICAgOnBhcmFtIHRlc3Rfc2V0X2tleTogICAgICBrZXkgb2YgaGVsZCBvdXQgZGF0YSBpbiBhcnRpZmFjdCBzdG9yZQogICAgOnBhcmFtIG1vZGVsX2V2YWx1YXRvcjogICAoTm9uZSkgYSBjdXN0b20gbW9kZWwgZXZhbHVhdG9yIGNhbiBiZSBzcGVjaWZpZWQKICAgIDpwYXJhbSBtb2RlbHNfZGVzdDogICAgICAgKCIiKSBtb2RlbHMgc3ViZm9sZGVyIG9uIGFydGlmYWN0IHBhdGgKICAgIDpwYXJhbSBwbG90c19kZXN0OiAgICAgICAgcGxvdCBzdWJmb2xkZXIgb24gYXJ0aWZhY3QgcGF0aAogICAgOnBhcmFtIGZpbGVfZXh0OiAgICAgICAgICAoInBhcnF1ZXQiKSBmb3JtYXQgZm9yIHRlc3Rfc2V0X2tleSBob2xkIG91dCBkYXRhCiAgICA6cGFyYW0gcmFuZG9tX3N0YXRlOiAgICAgICgxKSBza2xlYXJuIHJuZyBzZWVkCgogICAgIiIiCiAgICBtb2RlbHNfZGVzdCA9IG1vZGVsc19kZXN0IG9yICJtb2RlbCIKCiAgICByYXcsIGxhYmVscywgaGVhZGVyID0gZ2V0X3NhbXBsZShkYXRhc2V0LCBzYW1wbGUsIGxhYmVsX2NvbHVtbikKCiAgICBpZiBlbmNvZGVfY29sczoKICAgICAgICByYXcgPSBwZC5nZXRfZHVtbWllcygKICAgICAgICAgICAgcmF3LAogICAgICAgICAgICBjb2x1bW5zPWxpc3QoZW5jb2RlX2NvbHMua2V5cygpKSwKICAgICAgICAgICAgcHJlZml4PWxpc3QoZW5jb2RlX2NvbHMudmFsdWVzKCkpLAogICAgICAgICAgICBkcm9wX2ZpcnN0PVRydWUsCiAgICAgICAgKQoKICAgICh4dHJhaW4sIHl0cmFpbiksICh4dmFsaWQsIHl2YWxpZCksICh4dGVzdCwgeXRlc3QpID0gZ2V0X3NwbGl0cygKICAgICAgICByYXcsIGxhYmVscywgMywgdGVzdF9zaXplLCAxIC0gdHJhaW5fdmFsX3NwbGl0LCByYW5kb21fc3RhdGUKICAgICkKCiAgICB0ZXN0X3NldCA9IHBkLmNvbmNhdChbeHRlc3QsIHl0ZXN0LnRvX2ZyYW1lKCldLCBheGlzPTEpCiAgICBjb250ZXh0LmxvZ19kYXRhc2V0KAogICAgICAgIHRlc3Rfc2V0X2tleSwKICAgICAgICBkZj10ZXN0X3NldCwKICAgICAgICBmb3JtYXQ9ZmlsZV9leHQsCiAgICAgICAgaW5kZXg9RmFsc2UsCiAgICAgICAgbGFiZWxzPXsiZGF0YS10eXBlIjogImhlbGQtb3V0In0sCiAgICAgICAgYXJ0aWZhY3RfcGF0aD1jb250ZXh0LmFydGlmYWN0X3N1YnBhdGgoImRhdGEiKSwKICAgICkKCiAgICBtb2RlbF9jb25maWcgPSBnZW5fc2tsZWFybl9tb2RlbChtb2RlbF9wa2dfY2xhc3MsIGNvbnRleHQucGFyYW1ldGVycy5pdGVtcygpKQoKICAgIG1vZGVsX2NvbmZpZ1siRklUIl0udXBkYXRlKHsiWCI6IHh0cmFpbiwgInkiOiB5dHJhaW4udmFsdWVzfSkKCiAgICBDbGFzc2lmaWVyQ2xhc3MgPSBjcmVhdGVfY2xhc3MobW9kZWxfY29uZmlnWyJNRVRBIl1bImNsYXNzIl0pCgogICAgbW9kZWwgPSBDbGFzc2lmaWVyQ2xhc3MoKiptb2RlbF9jb25maWdbIkNMQVNTIl
0pCgogICAgbW9kZWwuZml0KCoqbW9kZWxfY29uZmlnWyJGSVQiXSkKCiAgICBhcnRpZmFjdF9wYXRoID0gY29udGV4dC5hcnRpZmFjdF9zdWJwYXRoKG1vZGVsc19kZXN0KQogICAgcGxvdHNfcGF0aCA9IGNvbnRleHQuYXJ0aWZhY3Rfc3VicGF0aChtb2RlbHNfZGVzdCwgcGxvdHNfZGVzdCkKICAgIGlmIG1vZGVsX2V2YWx1YXRvcjoKICAgICAgICBldmFsX21ldHJpY3MgPSBtb2RlbF9ldmFsdWF0b3IoCiAgICAgICAgICAgIGNvbnRleHQsIHh2YWxpZCwgeXZhbGlkLCBtb2RlbCwgcGxvdHNfYXJ0aWZhY3RfcGF0aD1wbG90c19wYXRoCiAgICAgICAgKQogICAgZWxzZToKICAgICAgICBldmFsX21ldHJpY3MgPSBldmFsX21vZGVsX3YyKAogICAgICAgICAgICBjb250ZXh0LCB4dmFsaWQsIHl2YWxpZCwgbW9kZWwsIHBsb3RzX2FydGlmYWN0X3BhdGg9cGxvdHNfcGF0aAogICAgICAgICkKCiAgICBrd2FyZ3MgPSB7InRyYWluaW5nX3NldCI6IHRlc3Rfc2V0LCAibGFiZWxfY29sdW1uIjogbGFiZWxfY29sdW1ufQogICAgc3BsaXQgPSBtb2RlbF9wa2dfY2xhc3MucnNwbGl0KCIuIiwgMSkKICAgIGlmIHNwbGl0IGFuZCBsZW4oc3BsaXQpID09IDI6CiAgICAgICAga3dhcmdzWyJhbGdvcml0aG0iXSA9IHNwbGl0WzFdCgogICAgaWYgZGF0YXNldC5tZXRhIGFuZCBkYXRhc2V0Lm1ldGEua2luZCA9PSAiRmVhdHVyZVZlY3RvciI6CiAgICAgICAga3dhcmdzWyJmZWF0dXJlX3ZlY3RvciJdID0gZGF0YXNldC5tZXRhLnVyaQoKICAgIGNvbnRleHQuc2V0X2xhYmVsKCJjbGFzcyIsIG1vZGVsX3BrZ19jbGFzcykKICAgIGNvbnRleHQubG9nX21vZGVsKAogICAgICAgICJtb2RlbCIsCiAgICAgICAgYm9keT1kdW1wcyhtb2RlbCksCiAgICAgICAgYXJ0aWZhY3RfcGF0aD1hcnRpZmFjdF9wYXRoLAogICAgICAgIGV4dHJhX2RhdGE9ZXZhbF9tZXRyaWNzLAogICAgICAgIG1vZGVsX2ZpbGU9Im1vZGVsLnBrbCIsCiAgICAgICAgbWV0cmljcz1jb250ZXh0LnJlc3VsdHMsCiAgICAgICAgbGFiZWxzPXsiY2xhc3MiOiBtb2RlbF9wa2dfY2xhc3N9LAogICAgICAgIGZyYW1ld29yaz0ic2tsZWFybiIsCiAgICAgICAgKiprd2FyZ3MsCiAgICApCg== + code_origin: '' + filename: sklearn_classifier.py entry_points: train_model: - has_varargs: false + outputs: + - type: None parameters: - name: context type: MLClientCtx @@ -21,14 +34,14 @@ spec: doc: ground-truth (y) labels default: labels - name: encode_cols - type: List[str] + type: list[str] doc: dictionary of names and prefixes for columns that are to hot be encoded. default: [] - name: sample type: int doc: Selects the first n rows, or select a sample starting from the first. 
If negative <-1, select a random sample - default: + default: - name: test_size type: float doc: (0.05) test set size @@ -76,21 +89,9 @@ spec: scalar "results", a "plots" keys with a list of PlotArtifacts, and and "tables" key containing a returned list of TableArtifacts.' - outputs: - - type: None - lineno: 32 has_kwargs: false - disable_auto_mount: false - build: - functionSourceCode: IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKIyBHZW5lcmF0ZWQgYnkgbnVjbGlvLmV4cG9ydC5OdWNsaW9FeHBvcnRlcgoKaW1wb3J0IHdhcm5pbmdzCgp3YXJuaW5ncy5zaW1wbGVmaWx0ZXIoYWN0aW9uPSJpZ25vcmUiLCBjYXRlZ29yeT1GdXR1cmVXYXJuaW5nKQoKCmZyb20gY2xvdWRwaWNrbGUgaW1wb3J0IGR1bXBzCmltcG9ydCBwYW5kYXMgYXMgcGQKZnJvbSB0eXBpbmcgaW1wb3J0IExpc3QKZnJvbSBtbHJ1bi5leGVjdXRpb24gaW1wb3J0IE1MQ2xpZW50Q3R4CmZyb20gbWxydW4uZGF0YXN0b3JlIGltcG9ydCBEYXRhSXRlbQpmcm9tIG1scnVuLm1sdXRpbHMuZGF0YSBpbXBvcnQgZ2V0X3NhbXBsZSwgZ2V0X3NwbGl0cwpmcm9tIG1scnVuLm1sdXRpbHMubW9kZWxzIGltcG9ydCBnZW5fc2tsZWFybl9tb2RlbCwgZXZhbF9tb2RlbF92Mgpmcm9tIG1scnVuLnV0aWxzLmhlbHBlcnMgaW1wb3J0IGNyZWF0ZV9jbGFzcwoKCmRlZiB0cmFpbl9tb2RlbCgKICAgIGNvbnRleHQ6IE1MQ2xpZW50Q3R4LAogICAgbW9kZWxfcGtnX2NsYXNzOiBzdHIsCiAgICBkYXRhc2V0OiBEYXRhSXRlbSwKICAgIGxhYmVsX2NvbHVtbjogc3RyID0gImxhYmVscyIsCiAgICBlbmNvZGVfY29sczogTGlzdFtzdHJdID0gW10sCiAgICBzYW1wbGU6IGludCA9IC0xLAogICAgdGVzdF
9zaXplOiBmbG9hdCA9IDAuMzAsCiAgICB0cmFpbl92YWxfc3BsaXQ6IGZsb2F0ID0gMC43MCwKICAgIHRlc3Rfc2V0X2tleTogc3RyID0gInRlc3Rfc2V0IiwKICAgIG1vZGVsX2V2YWx1YXRvcj1Ob25lLAogICAgbW9kZWxzX2Rlc3Q6IHN0ciA9ICIiLAogICAgcGxvdHNfZGVzdDogc3RyID0gInBsb3RzIiwKICAgIGZpbGVfZXh0OiBzdHIgPSAicGFycXVldCIsCiAgICBtb2RlbF9wa2dfZmlsZTogc3RyID0gIiIsCiAgICByYW5kb21fc3RhdGU6IGludCA9IDEsCikgLT4gTm9uZToKICAgICIiInRyYWluIGEgY2xhc3NpZmllcgoKICAgIEFuIG9wdGlvbmFsIGN1dG9tIG1vZGVsIGV2YWx1YXRvciBjYW4gYmUgc3VwcGxpZWQgdGhhdCBzaG91bGQgaGF2ZSB0aGUgc2lnbmF0dXJlOgogICAgYG15X2N1c3RvbV9ldmFsdWF0b3IoY29udGV4dCwgeHZhbGlkLCB5dmFsaWQsIG1vZGVsKWAgYW5kIHJldHVybiBhIGRpY3Rpb25hcnkgb2YKICAgIHNjYWxhciAicmVzdWx0cyIsIGEgInBsb3RzIiBrZXlzIHdpdGggYSBsaXN0IG9mIFBsb3RBcnRpZmFjdHMsIGFuZAogICAgYW5kICJ0YWJsZXMiIGtleSBjb250YWluaW5nIGEgcmV0dXJuZWQgbGlzdCBvZiBUYWJsZUFydGlmYWN0cy4KCiAgICA6cGFyYW0gY29udGV4dDogICAgICAgICAgIHRoZSBmdW5jdGlvbiBjb250ZXh0CiAgICA6cGFyYW0gbW9kZWxfcGtnX2NsYXNzOiAgIHRoZSBtb2RlbCB0byB0cmFpbiwgZS5nLCAic2tsZWFybi5uZXVyYWxfbmV0d29ya3MuTUxQQ2xhc3NpZmllciIsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIG9yIGpzb24gbW9kZWwgY29uZmlnCiAgICA6cGFyYW0gZGF0YXNldDogICAgICAgICAgICgiZGF0YSIpIG5hbWUgb2YgcmF3IGRhdGEgZmlsZQogICAgOnBhcmFtIGxhYmVsX2NvbHVtbjogICAgICBncm91bmQtdHJ1dGggKHkpIGxhYmVscwogICAgOnBhcmFtIGVuY29kZV9jb2xzOiAgICAgICBkaWN0aW9uYXJ5IG9mIG5hbWVzIGFuZCBwcmVmaXhlcyBmb3IgY29sdW1ucyB0aGF0IGFyZQogICAgICAgICAgICAgICAgICAgICAgICAgICAgICB0byBob3QgYmUgZW5jb2RlZC4KICAgIDpwYXJhbSBzYW1wbGU6ICAgICAgICAgICAgU2VsZWN0cyB0aGUgZmlyc3QgbiByb3dzLCBvciBzZWxlY3QgYSBzYW1wbGUKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgc3RhcnRpbmcgZnJvbSB0aGUgZmlyc3QuIElmIG5lZ2F0aXZlIDwtMSwgc2VsZWN0CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGEgcmFuZG9tIHNhbXBsZQogICAgOnBhcmFtIHRlc3Rfc2l6ZTogICAgICAgICAoMC4wNSkgdGVzdCBzZXQgc2l6ZQogICAgOnBhcmFtIHRyYWluX3ZhbF9zcGxpdDogICAoMC43NSkgT25jZSB0aGUgdGVzdCBzZXQgaGFzIGJlZW4gcmVtb3ZlZCB0aGUKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgdHJhaW5pbmcgc2V0IGdldHMgdGhpcyBwcm9wb3J0aW9uLgogICAgOnBhcmFtIHRlc3Rfc2V0X2tleTogICAgICBrZXkgb2YgaGVsZCBvdXQgZGF0YS
BpbiBhcnRpZmFjdCBzdG9yZQogICAgOnBhcmFtIG1vZGVsX2V2YWx1YXRvcjogICAoTm9uZSkgYSBjdXN0b20gbW9kZWwgZXZhbHVhdG9yIGNhbiBiZSBzcGVjaWZpZWQKICAgIDpwYXJhbSBtb2RlbHNfZGVzdDogICAgICAgKCIiKSBtb2RlbHMgc3ViZm9sZGVyIG9uIGFydGlmYWN0IHBhdGgKICAgIDpwYXJhbSBwbG90c19kZXN0OiAgICAgICAgcGxvdCBzdWJmb2xkZXIgb24gYXJ0aWZhY3QgcGF0aAogICAgOnBhcmFtIGZpbGVfZXh0OiAgICAgICAgICAoInBhcnF1ZXQiKSBmb3JtYXQgZm9yIHRlc3Rfc2V0X2tleSBob2xkIG91dCBkYXRhCiAgICA6cGFyYW0gcmFuZG9tX3N0YXRlOiAgICAgICgxKSBza2xlYXJuIHJuZyBzZWVkCgogICAgIiIiCiAgICBtb2RlbHNfZGVzdCA9IG1vZGVsc19kZXN0IG9yICJtb2RlbCIKCiAgICByYXcsIGxhYmVscywgaGVhZGVyID0gZ2V0X3NhbXBsZShkYXRhc2V0LCBzYW1wbGUsIGxhYmVsX2NvbHVtbikKCiAgICBpZiBlbmNvZGVfY29sczoKICAgICAgICByYXcgPSBwZC5nZXRfZHVtbWllcygKICAgICAgICAgICAgcmF3LAogICAgICAgICAgICBjb2x1bW5zPWxpc3QoZW5jb2RlX2NvbHMua2V5cygpKSwKICAgICAgICAgICAgcHJlZml4PWxpc3QoZW5jb2RlX2NvbHMudmFsdWVzKCkpLAogICAgICAgICAgICBkcm9wX2ZpcnN0PVRydWUsCiAgICAgICAgKQoKICAgICh4dHJhaW4sIHl0cmFpbiksICh4dmFsaWQsIHl2YWxpZCksICh4dGVzdCwgeXRlc3QpID0gZ2V0X3NwbGl0cygKICAgICAgICByYXcsIGxhYmVscywgMywgdGVzdF9zaXplLCAxIC0gdHJhaW5fdmFsX3NwbGl0LCByYW5kb21fc3RhdGUKICAgICkKCiAgICB0ZXN0X3NldCA9IHBkLmNvbmNhdChbeHRlc3QsIHl0ZXN0LnRvX2ZyYW1lKCldLCBheGlzPTEpCiAgICBjb250ZXh0LmxvZ19kYXRhc2V0KAogICAgICAgIHRlc3Rfc2V0X2tleSwKICAgICAgICBkZj10ZXN0X3NldCwKICAgICAgICBmb3JtYXQ9ZmlsZV9leHQsCiAgICAgICAgaW5kZXg9RmFsc2UsCiAgICAgICAgbGFiZWxzPXsiZGF0YS10eXBlIjogImhlbGQtb3V0In0sCiAgICAgICAgYXJ0aWZhY3RfcGF0aD1jb250ZXh0LmFydGlmYWN0X3N1YnBhdGgoImRhdGEiKSwKICAgICkKCiAgICBtb2RlbF9jb25maWcgPSBnZW5fc2tsZWFybl9tb2RlbChtb2RlbF9wa2dfY2xhc3MsIGNvbnRleHQucGFyYW1ldGVycy5pdGVtcygpKQoKICAgIG1vZGVsX2NvbmZpZ1siRklUIl0udXBkYXRlKHsiWCI6IHh0cmFpbiwgInkiOiB5dHJhaW4udmFsdWVzfSkKCiAgICBDbGFzc2lmaWVyQ2xhc3MgPSBjcmVhdGVfY2xhc3MobW9kZWxfY29uZmlnWyJNRVRBIl1bImNsYXNzIl0pCgogICAgbW9kZWwgPSBDbGFzc2lmaWVyQ2xhc3MoKiptb2RlbF9jb25maWdbIkNMQVNTIl0pCgogICAgbW9kZWwuZml0KCoqbW9kZWxfY29uZmlnWyJGSVQiXSkKCiAgICBhcnRpZmFjdF9wYXRoID0gY29udGV4dC5hcnRpZmFjdF9zdWJwYXRoKG1vZGVsc19kZXN0KQogICAgcGxvdHNfcGF0aCA9IGNvbnRleHQuYXJ0aW
ZhY3Rfc3VicGF0aChtb2RlbHNfZGVzdCwgcGxvdHNfZGVzdCkKICAgIGlmIG1vZGVsX2V2YWx1YXRvcjoKICAgICAgICBldmFsX21ldHJpY3MgPSBtb2RlbF9ldmFsdWF0b3IoCiAgICAgICAgICAgIGNvbnRleHQsIHh2YWxpZCwgeXZhbGlkLCBtb2RlbCwgcGxvdHNfYXJ0aWZhY3RfcGF0aD1wbG90c19wYXRoCiAgICAgICAgKQogICAgZWxzZToKICAgICAgICBldmFsX21ldHJpY3MgPSBldmFsX21vZGVsX3YyKAogICAgICAgICAgICBjb250ZXh0LCB4dmFsaWQsIHl2YWxpZCwgbW9kZWwsIHBsb3RzX2FydGlmYWN0X3BhdGg9cGxvdHNfcGF0aAogICAgICAgICkKCiAgICBrd2FyZ3MgPSB7InRyYWluaW5nX3NldCI6IHRlc3Rfc2V0LCAibGFiZWxfY29sdW1uIjogbGFiZWxfY29sdW1ufQogICAgc3BsaXQgPSBtb2RlbF9wa2dfY2xhc3MucnNwbGl0KCIuIiwgMSkKICAgIGlmIHNwbGl0IGFuZCBsZW4oc3BsaXQpID09IDI6CiAgICAgICAga3dhcmdzWyJhbGdvcml0aG0iXSA9IHNwbGl0WzFdCgogICAgaWYgZGF0YXNldC5tZXRhIGFuZCBkYXRhc2V0Lm1ldGEua2luZCA9PSAiRmVhdHVyZVZlY3RvciI6CiAgICAgICAga3dhcmdzWyJmZWF0dXJlX3ZlY3RvciJdID0gZGF0YXNldC5tZXRhLnVyaQoKICAgIGNvbnRleHQuc2V0X2xhYmVsKCJjbGFzcyIsIG1vZGVsX3BrZ19jbGFzcykKICAgIGNvbnRleHQubG9nX21vZGVsKAogICAgICAgICJtb2RlbCIsCiAgICAgICAgYm9keT1kdW1wcyhtb2RlbCksCiAgICAgICAgYXJ0aWZhY3RfcGF0aD1hcnRpZmFjdF9wYXRoLAogICAgICAgIGV4dHJhX2RhdGE9ZXZhbF9tZXRyaWNzLAogICAgICAgIG1vZGVsX2ZpbGU9Im1vZGVsLnBrbCIsCiAgICAgICAgbWV0cmljcz1jb250ZXh0LnJlc3VsdHMsCiAgICAgICAgbGFiZWxzPXsiY2xhc3MiOiBtb2RlbF9wa2dfY2xhc3N9LAogICAgICAgIGZyYW1ld29yaz0ic2tsZWFybiIsCiAgICAgICAgKiprd2FyZ3MKICAgICkK - origin_filename: '' - code_origin: '' + has_varargs: false + lineno: 31 command: '' -metadata: - tag: '' - name: sklearn-classifier - categories: - - machine-learning - - model-training -verbose: false -kind: job + description: train any classifier using scikit-learn's API + default_handler: train_model diff --git a/functions/src/sklearn_classifier/sklearn_classifier.py b/functions/src/sklearn_classifier/sklearn_classifier.py index 1a73d4045..daca4e4ad 100644 --- a/functions/src/sklearn_classifier/sklearn_classifier.py +++ b/functions/src/sklearn_classifier/sklearn_classifier.py @@ -19,13 +19,12 @@ warnings.simplefilter(action="ignore", category=FutureWarning) -from cloudpickle import dumps import 
pandas as pd -from typing import List -from mlrun.execution import MLClientCtx +from cloudpickle import dumps from mlrun.datastore import DataItem +from mlrun.execution import MLClientCtx from mlrun.mlutils.data import get_sample, get_splits -from mlrun.mlutils.models import gen_sklearn_model, eval_model_v2 +from mlrun.mlutils.models import eval_model_v2, gen_sklearn_model from mlrun.utils.helpers import create_class @@ -34,7 +33,7 @@ def train_model( model_pkg_class: str, dataset: DataItem, label_column: str = "labels", - encode_cols: List[str] = [], + encode_cols: list[str] = [], sample: int = -1, test_size: float = 0.30, train_val_split: float = 0.70, @@ -139,5 +138,5 @@ def train_model( metrics=context.results, labels={"class": model_pkg_class}, framework="sklearn", - **kwargs + **kwargs, ) diff --git a/functions/src/sklearn_classifier/test_sklearn_classifier.py b/functions/src/sklearn_classifier/test_sklearn_classifier.py index 5c29e85b3..2aa314b3d 100644 --- a/functions/src/sklearn_classifier/test_sklearn_classifier.py +++ b/functions/src/sklearn_classifier/test_sklearn_classifier.py @@ -12,22 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import mlrun import os import pickle + +import mlrun import pandas as pd def generate_data(): - fn = mlrun.import_function('../gen_class_data/function.yaml') - run = fn.run(params={'key': 'classifier-data', - 'n_samples': 10_000, - 'm_features': 5, - 'k_classes': 2, - 'header': None, - 'weight': [0.5, 0.5], - 'sk_params': {'n_informative': 2}, - 'file_ext': 'csv'}, local=True, artifact_path="./artifacts") + fn = mlrun.import_function("../gen_class_data/function.yaml") + run = fn.run( + params={ + "key": "classifier-data", + "n_samples": 10_000, + "m_features": 5, + "k_classes": 2, + "header": None, + "weight": [0.5, 0.5], + "sk_params": {"n_informative": 2}, + "file_ext": "csv", + }, + local=True, + artifact_path="./artifacts", + ) return run @@ -35,23 +42,31 @@ def test_import_sklearn_classifier(): acquire_run = generate_data() fn = mlrun.import_function("function.yaml") # define model - params = {"model_pkg_class": "sklearn.ensemble.RandomForestClassifier", - "label_column": "labels"} + params = { + "model_pkg_class": "sklearn.ensemble.RandomForestClassifier", + "label_column": "labels", + } - train_run = fn.run(params=params, - inputs={"dataset": acquire_run.status.artifacts[0]['spec']['target_path']}, - local=True, - artifact_path="./") + train_run = fn.run( + params=params, + inputs={"dataset": acquire_run.status.artifacts[0]["spec"]["target_path"]}, + local=True, + artifact_path="./", + ) for artifact in train_run.status.artifacts: - if artifact['kind'] == 'model': - assert os.path.exists(artifact['spec']['target_path']), 'Could not find model dir' + if artifact["kind"] == "model": + assert os.path.exists(artifact["spec"]["target_path"]), ( + "Could not find model dir" + ) break - assert os.path.exists(train_run.status.artifacts[0]['spec']['target_path']) - model = pickle.load(open(artifact['spec']['target_path'] + artifact['spec']['model_file'], 'rb')) - df = pd.read_csv(acquire_run.status.artifacts[0]['spec']['target_path']) - x = df.drop(['labels'], 
axis=1).iloc[0:1] - y_true = df['labels'][0] + assert os.path.exists(train_run.status.artifacts[0]["spec"]["target_path"]) + model = pickle.load( + open(artifact["spec"]["target_path"] + artifact["spec"]["model_file"], "rb") + ) + df = pd.read_csv(acquire_run.status.artifacts[0]["spec"]["target_path"]) + x = df.drop(["labels"], axis=1).iloc[0:1] + y_true = df["labels"][0] y_pred = model.predict_proba(x).argmax() assert y_pred == y_true, "Failed to predict correctly" diff --git a/functions/src/sklearn_classifier_dask/function.yaml b/functions/src/sklearn_classifier_dask/function.yaml index 46f733886..e202a6c2d 100644 --- a/functions/src/sklearn_classifier_dask/function.yaml +++ b/functions/src/sklearn_classifier_dask/function.yaml @@ -1,42 +1,34 @@ -kind: job metadata: - name: sklearn-classifier-dask tag: '' - hash: e542038fbb84f790b7144b529665f36d70d80906 - project: '' - labels: - author: Iguazio - framework: sklearn + name: sklearn-classifier-dask categories: - machine-learning - model-training +verbose: false +kind: job spec: - command: '' - args: [] image: mlrun/ml-models + disable_auto_mount: false build: - functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Generated by nuclio.export.NuclioExporter

import mlrun

import warnings

warnings.filterwarnings("ignore")

import joblib
import numpy as np
import pandas as pd
from cloudpickle import dumps

from dask import dataframe as dd
from dask.delayed import delayed
from dask_ml import model_selection
from dask_ml.preprocessing import StandardScaler, LabelEncoder

from mlrun.artifacts import PlotArtifact
from mlrun.mlutils.models import gen_sklearn_model
from mlrun.utils.helpers import create_class

import matplotlib.pyplot as plt
from yellowbrick.classifier import ROCAUC, ClassificationReport, ConfusionMatrix
from yellowbrick.model_selection import FeatureImportances


def train_model(
    context: mlrun.MLClientCtx,
    dataset: mlrun.DataItem,
    model_pkg_class: str,
    label_column: str = "label",
    train_validation_size: float = 0.75,
    sample: float = 1.0,
    models_dest: str = "models",
    test_set_key: str = "test_set",
    plots_dest: str = "plots",
    dask_function: str = None,
    dask_client=None,
    file_ext: str = "parquet",
    random_state: int = 42,
) -> None:
    """
    Train a sklearn classifier with Dask.

    The dataset is read as a Dask dataframe, reduced to its numeric columns,
    sampled, split into train / test sets and standard-scaled. The model is
    fitted under the joblib "dask" backend, evaluated with yellowbrick
    visualizers, and the model, scaler, label encoder and held-out test set
    are logged to the context.

    :param context:                 Function context.
    :param dataset:                 Raw data file.
    :param model_pkg_class:         Model to train, e.g, "sklearn.ensemble.RandomForestClassifier",
                                    or json model config.
    :param label_column:            (label) Ground-truth y labels.
    :param train_validation_size:   (0.75) Train validation set proportion out of the full dataset.
    :param sample:                  (1.0) Fraction of the dataset rows to sample (passed as ``frac``);
                                    the sampled rows are re-indexed.
    :param models_dest:             (models) Models subfolder on artifact path.
    :param test_set_key:            (test_set) Mlrun db key of held out data in artifact store.
    :param plots_dest:              (plots) Plot subfolder on artifact path. Accepted for interface
                                    compatibility; not referenced by this function.
    :param dask_function:           dask function url (db://..)
    :param dask_client:             dask client object
    :param file_ext:                (parquet) format for test_set_key hold out data
    :param random_state:            (42) sklearn seed

    :raises ValueError: If neither ``dask_function`` nor ``dask_client`` is given, or if the
                        numeric columns of the dataset contain missing values.
    """
    # NOTE(review): `client` is never referenced below. Presumably obtaining it
    # is what connects the following dask work to the cluster - confirm before
    # removing the assignment.
    if dask_function:
        client = mlrun.import_function(dask_function).client
    elif dask_client:
        client = dask_client
    else:
        raise ValueError("dask client was not provided")

    context.logger.info("Read Data")
    df = dataset.as_df(df_module=dd)

    context.logger.info("Prep Data")
    # Only numeric columns are kept; every other column is dropped.
    numerics = ["int16", "int32", "int64", "float16", "float32", "float64"]
    df = df.select_dtypes(include=numerics)

    if df.isna().any().any().compute():
        raise ValueError("NA values found")

    df_header = df.columns

    df = df.sample(frac=sample).reset_index(drop=True)
    encoder = LabelEncoder()
    encoder = encoder.fit(df[label_column])
    X = df.drop(label_column, axis=1).to_dask_array(lengths=True)
    y = encoder.transform(df[label_column])

    # Names of the columns that make up X, in X's column order.
    feature_names = df_header.delete(df_header.get_loc(label_column))

    classes = df[label_column].drop_duplicates()  # no unique values in dask
    classes = [str(i) for i in classes]

    context.logger.info("Split and Train")
    X_train, X_test, y_train, y_test = model_selection.train_test_split(
        X, y, train_size=train_validation_size, random_state=random_state
    )

    # The scaler is fitted on the training split only and then applied to both.
    scaler = StandardScaler()
    scaler = scaler.fit(X_train)
    X_train_transformed = scaler.transform(X_train)
    X_test_transformed = scaler.transform(X_test)

    model_config = gen_sklearn_model(model_pkg_class, context.parameters.items())

    model_config["FIT"].update({"X": X_train_transformed, "y": y_train})

    ClassifierClass = create_class(model_config["META"]["class"])

    model = ClassifierClass(**model_config["CLASS"])

    with joblib.parallel_backend("dask"):
        model = model.fit(**model_config["FIT"])

    context.logger.info("Evaluate")
    extra_data_dict = {}
    for report in (ROCAUC, ClassificationReport, ConfusionMatrix):
        report_name = report.__name__
        # Reset matplotlib state so each report is drawn on a fresh figure.
        plt.cla()
        plt.clf()
        plt.close()

        viz = report(model, classes=classes, per_class=True, is_fitted=True)
        viz.fit(X_train_transformed, y_train)  # Fit the training data to the visualizer
        viz.score(
            X_test_transformed, y_test.compute()
        )  # Evaluate the model on the test data

        plot = context.log_artifact(
            PlotArtifact(report_name, body=viz.fig, title=report_name), db_key=False
        )
        # Key by the plain report name (not the class repr) so the keys are
        # consistent with the "FeatureImportances" entry added below.
        extra_data_dict[report_name] = plot

        if report_name == "ROCAUC":
            context.log_results(
                {"micro": viz.roc_auc.get("micro"), "macro": viz.roc_auc.get("macro")}
            )

        elif report_name == "ClassificationReport":
            # One result per (score, class) pair, named "<score>-<class>".
            for score_name, class_scores in viz.scores_.items():
                for score_class, score_value in class_scores.items():
                    context.log_results({f"{score_name}-{score_class}": score_value})

    viz = FeatureImportances(
        model,
        classes=classes,
        per_class=True,
        is_fitted=True,
        labels=feature_names,
    )
    viz.fit(X_train_transformed, y_train)
    viz.score(X_test_transformed, y_test)

    plot = context.log_artifact(
        PlotArtifact("FeatureImportances", body=viz.fig, title="FeatureImportances"),
        db_key=False,
    )
    extra_data_dict["FeatureImportances"] = plot

    plt.cla()
    plt.clf()
    plt.close()

    context.logger.info("Log artifacts")
    artifact_path = context.artifact_subpath(models_dest)

    context.set_label("class", model_pkg_class)

    context.log_model(
        "model",
        body=dumps(model),
        artifact_path=artifact_path,
        model_file="model.pkl",
        extra_data=extra_data_dict,
        metrics=context.results,
        labels={"class": model_pkg_class},
    )

    context.log_artifact(
        "standard_scaler",
        body=dumps(scaler),
        artifact_path=artifact_path,
    )

    context.log_artifact(
        "label_encoder",
        body=dumps(encoder),
        artifact_path=artifact_path,
    )

    # The stacked array is [X_test | y_test]: feature columns first and the
    # (encoder-transformed) label last, so the column names must follow that
    # order rather than the original header order.
    df_to_save = delayed(np.column_stack)((X_test, y_test)).compute()
    context.log_dataset(
        test_set_key,
        df=pd.DataFrame(
            df_to_save, columns=[*feature_names, label_column]
        ),  # improve log dataset ability
        format=file_ext,
        index=False,
        labels={"data-type": "held-out"},
        artifact_path=context.artifact_subpath("data"),
    )

    context.logger.info("Done!")
 - commands: [] - code_origin: https://github.com/guy1992l/functions.git#75359393bff0aaf27fb04c00d5d0037a1d1e32db:/Users/guyl/Projects/functions/sklearn_classifier_dask/sklearn_classifier_dask.py - origin_filename: /Users/guyl/Projects/functions/sklearn_classifier_dask/sklearn_classifier_dask.py + origin_filename: '' + functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Generated by nuclio.export.NuclioExporter

import warnings

import mlrun

warnings.filterwarnings("ignore")

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from cloudpickle import dumps
from dask import dataframe as dd
from dask.delayed import delayed
from dask_ml import model_selection
from dask_ml.preprocessing import LabelEncoder, StandardScaler
from mlrun.artifacts import PlotArtifact
from mlrun.mlutils.models import gen_sklearn_model
from mlrun.utils.helpers import create_class
from yellowbrick.classifier import ROCAUC, ClassificationReport, ConfusionMatrix
from yellowbrick.model_selection import FeatureImportances


def train_model(
    context: mlrun.MLClientCtx,
    dataset: mlrun.DataItem,
    model_pkg_class: str,
    label_column: str = "label",
    train_validation_size: float = 0.75,
    sample: float = 1.0,
    models_dest: str = "models",
    test_set_key: str = "test_set",
    plots_dest: str = "plots",
    dask_function: str = None,
    dask_client=None,
    file_ext: str = "parquet",
    random_state: int = 42,
) -> None:
    """
    Train a sklearn classifier with Dask

    The dataset is read as a dask dataframe, restricted to its numeric columns,
    sampled, label-encoded and split into train/test sets. The features are
    standard-scaled (the scaler is fitted on the train split only), the model is
    fitted inside a joblib "dask" parallel backend, and yellowbrick reports are
    logged as plot artifacts. Finally the model, the scaler, the label encoder
    and the held-out test set are logged.

    :param context:                 Function context.
    :param dataset:                 Raw data file.
    :param model_pkg_class:         Model to train, e.g, "sklearn.ensemble.RandomForestClassifier",
                                    or json model config.
    :param label_column:            (label) Ground-truth y labels.
    :param train_validation_size:   (0.75) Train validation set proportion out of the full dataset.
    :param sample:                  (1.0) Fraction of the dataset rows to sample (passed as ``frac``),
                                    rows are randomized as default.
    :param models_dest:             (models) Models subfolder on artifact path.
    :param test_set_key:            (test_set) Mlrun db key of held out data in artifact store.
    :param plots_dest:              (plots) Plot subfolder on artifact path. Accepted for interface
                                    compatibility; the function body does not reference it.
    :param dask_function:           dask function url (db://..)
    :param dask_client:             dask client object
    :param file_ext:                (parquet) format for test_set_key hold out data
    :param random_state:            (42) sklearn seed

    :raises ValueError: If neither `dask_function` nor `dask_client` is provided.
    :raises Exception:  If the numeric part of the dataset contains NA values.
    """
    # `client` is not referenced again below.
    # NOTE(review): presumably obtaining it is what connects this run to the
    # Dask cluster - confirm before removing the assignment.
    if dask_function:
        client = mlrun.import_function(dask_function).client
    elif dask_client:
        client = dask_client
    else:
        raise ValueError("dask client was not provided")

    context.logger.info("Read Data")
    df = dataset.as_df(df_module=dd)

    context.logger.info("Prep Data")
    # Only columns of these dtypes are kept; the label column must be one of
    # them, otherwise the `df[label_column]` lookups below cannot find it.
    numerics = ["int16", "int32", "int64", "float16", "float32", "float64"]
    df = df.select_dtypes(include=numerics)

    if df.isna().any().any().compute():
        raise Exception("NAs valus found")

    df_header = df.columns
    # Feature names in the same order as the columns of X (label removed).
    feature_names = df_header.delete(df_header.get_loc(label_column))

    df = df.sample(frac=sample).reset_index(drop=True)
    encoder = LabelEncoder()
    encoder = encoder.fit(df[label_column])
    X = df.drop(label_column, axis=1).to_dask_array(lengths=True)
    y = encoder.transform(df[label_column])

    classes = df[label_column].drop_duplicates()  # no unique values in dask
    classes = [str(i) for i in classes]

    context.logger.info("Split and Train")
    X_train, X_test, y_train, y_test = model_selection.train_test_split(
        X, y, train_size=train_validation_size, random_state=random_state
    )

    scaler = StandardScaler()
    scaler = scaler.fit(X_train)
    X_train_transformed = scaler.transform(X_train)
    X_test_transformed = scaler.transform(X_test)

    model_config = gen_sklearn_model(model_pkg_class, context.parameters.items())

    model_config["FIT"].update({"X": X_train_transformed, "y": y_train})

    ClassifierClass = create_class(model_config["META"]["class"])

    model = ClassifierClass(**model_config["CLASS"])

    with joblib.parallel_backend("dask"):
        model = model.fit(**model_config["FIT"])

    context.logger.info("Evaluate")
    extra_data_dict = {}
    # Materialize the test labels once instead of once per report.
    y_test_values = y_test.compute()
    for report in (ROCAUC, ClassificationReport, ConfusionMatrix):
        report_name = str(report.__name__)
        # Reset matplotlib state so each report is drawn on a clean figure.
        plt.cla()
        plt.clf()
        plt.close()

        viz = report(model, classes=classes, per_class=True, is_fitted=True)
        viz.fit(X_train_transformed, y_train)  # Fit the training data to the visualizer
        viz.score(X_test_transformed, y_test_values)  # Evaluate the model on the test data

        plot = context.log_artifact(
            PlotArtifact(report_name, body=viz.fig, title=report_name), db_key=False
        )
        # Key the plot by the report's name (same as the artifact name) rather
        # than by the class repr ("<class '...'>").
        extra_data_dict[report_name] = plot

        if report_name == "ROCAUC":
            context.log_results(
                {"micro": viz.roc_auc.get("micro"), "macro": viz.roc_auc.get("macro")}
            )

        elif report_name == "ClassificationReport":
            # One result per "<score name>-<class>" pair.
            for score_name, class_scores in viz.scores_.items():
                for score_class, score_value in class_scores.items():
                    context.log_results({score_name + "-" + score_class: score_value})

    viz = FeatureImportances(
        model,
        classes=classes,
        per_class=True,
        is_fitted=True,
        labels=feature_names,
    )
    viz.fit(X_train_transformed, y_train)
    viz.score(X_test_transformed, y_test)

    plot = context.log_artifact(
        PlotArtifact("FeatureImportances", body=viz.fig, title="FeatureImportances"),
        db_key=False,
    )
    extra_data_dict["FeatureImportances"] = plot

    plt.cla()
    plt.clf()
    plt.close()

    context.logger.info("Log artifacts")
    artifact_path = context.artifact_subpath(models_dest)

    context.set_label("class", model_pkg_class)

    context.log_model(
        "model",
        body=dumps(model),
        artifact_path=artifact_path,
        model_file="model.pkl",
        extra_data=extra_data_dict,
        metrics=context.results,
        labels={"class": model_pkg_class},
    )

    context.log_artifact(
        "standard_scaler",
        body=dumps(scaler),
        artifact_path=artifact_path,
    )

    context.log_artifact(
        "label_encoder",
        body=dumps(encoder),
        artifact_path=artifact_path,
    )

    # The held-out set is the unscaled test features followed by the encoded
    # label as the LAST column, so the header must be the feature names plus the
    # label name - not the original column order, which mislabels the columns
    # whenever the label is not the last column of the input.
    df_to_save = delayed(np.column_stack)((X_test, y_test)).compute()
    test_set_columns = list(feature_names) + [label_column]
    context.log_dataset(
        test_set_key,
        df=pd.DataFrame(df_to_save, columns=test_set_columns),  # improve log dataset ability
        format=file_ext,
        index=False,
        labels={"data-type": "held-out"},
        artifact_path=context.artifact_subpath("data"),
    )

    context.logger.info("Done!")
 + code_origin: '' + filename: sklearn_classifier_dask.py entry_points: train_model: - name: train_model - doc: Train a sklearn classifier with Dask + outputs: + - type: None parameters: - name: context type: MLClientCtx doc: Function context. - default: '' - name: dataset type: DataItem doc: Raw data file. - default: '' - name: model_pkg_class type: str doc: Model to train, e.g, "sklearn.ensemble.RandomForestClassifier", or json model config. - default: '' - name: label_column type: str doc: (label) Ground-truth y labels. @@ -77,16 +69,11 @@ spec: type: int doc: (42) sklearn seed default: 42 - outputs: - - default: '' - lineno: 42 + name: train_model + doc: Train a sklearn classifier with Dask + has_kwargs: false + has_varargs: false + lineno: 39 + command: '' description: train any classifier using scikit-learn's API over Dask default_handler: train_model - disable_auto_mount: false - env: [] - priority_class_name: '' - preemption_mode: prevent - affinity: null - tolerations: null - security_context: {} -verbose: false diff --git a/functions/src/sklearn_classifier_dask/sklearn_classifier_dask.py b/functions/src/sklearn_classifier_dask/sklearn_classifier_dask.py index 39ec34716..73042ca45 100644 --- a/functions/src/sklearn_classifier_dask/sklearn_classifier_dask.py +++ b/functions/src/sklearn_classifier_dask/sklearn_classifier_dask.py @@ -14,27 +14,24 @@ # # Generated by nuclio.export.NuclioExporter -import mlrun - import warnings +import mlrun + warnings.filterwarnings("ignore") import joblib +import matplotlib.pyplot as plt import numpy as np import pandas as pd from cloudpickle import dumps - from dask import dataframe as dd from dask.delayed import delayed from dask_ml import model_selection -from dask_ml.preprocessing import StandardScaler, LabelEncoder - +from dask_ml.preprocessing import LabelEncoder, StandardScaler from mlrun.artifacts import PlotArtifact from mlrun.mlutils.models import gen_sklearn_model from mlrun.utils.helpers import create_class - 
-import matplotlib.pyplot as plt from yellowbrick.classifier import ROCAUC, ClassificationReport, ConfusionMatrix from yellowbrick.model_selection import FeatureImportances @@ -54,7 +51,6 @@ def train_model( file_ext: str = "parquet", random_state: int = 42, ) -> None: - """ Train a sklearn classifier with Dask @@ -149,12 +145,11 @@ def train_model( elif report_name == "ClassificationReport": for score_name in viz.scores_: for score_class in viz.scores_[score_name]: - context.log_results( { - score_name - + "-" - + score_class: viz.scores_[score_name].get(score_class) + score_name + "-" + score_class: viz.scores_[score_name].get( + score_class + ) } ) diff --git a/functions/src/structured_data_generator/function.yaml b/functions/src/structured_data_generator/function.yaml index 4e8a35626..e473c87f5 100644 --- a/functions/src/structured_data_generator/function.yaml +++ b/functions/src/structured_data_generator/function.yaml @@ -1,21 +1,27 @@ +metadata: + tag: '' + name: structured-data-generator + categories: + - data-generation + - genai +verbose: false +kind: job spec: + image: '' + disable_auto_mount: false build: origin_filename: '' + functionSourceCode: 
IyBDb3B5cmlnaHQgMjAyMyBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMAojCiMgVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQojIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuICJBUyBJUyIgQkFTSVMsCiMgV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuCiMgU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAojIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLgppbXBvcnQgYXN0CmltcG9ydCBvcwoKaW1wb3J0IHRxZG0KZnJvbSBsYW5nY2hhaW4uY2hhdF9tb2RlbHMgaW1wb3J0IENoYXRPcGVuQUkKCgpkZWYgX3NldF9vcGVuYWlfc2VjcmV0cygpIC0+IGJvb2w6CiAgICBrZXkgPSAiT1BFTkFJX0FQSV9LRVkiCiAgICBiYXNlID0gIk9QRU5BSV9BUElfQkFTRSIKICAgICMgQ2hlY2sgaWYgdGhlIGtleSBpcyBhbHJlYWR5IGluIHRoZSBlbnZpcm9ubWVudCB2YXJpYWJsZXM6CiAgICBpZiBrZXkgaW4gb3MuZW52aXJvbiBhbmQgYmFzZSBpbiBvcy5lbnZpcm9uOgogICAgICAgIHJldHVybiBUcnVlCiAgICAjIENoZWNrIGlmIG1scnVuIGlzIGluc3RhbGxlZDoKICAgIHRyeToKICAgICAgICBpbXBvcnQgbWxydW4KICAgIGV4Y2VwdCBNb2R1bGVOb3RGb3VuZEVycm9yOgogICAgICAgIHJhaXNlIE9TRXJyb3IoCiAgICAgICAgICAgIGYiT25lIG9yIG1vcmUgb2YgdGhlIE9wZW5BSSByZXF1aXJlZCBlbnZpcm9ubWVudCB2YXJpYWJsZXMgKCd7a2V5fScsICd7YmFzZX0nKSBhcmUgbWlzc2luZy4iCiAgICAgICAgICAgIGYiUGxlYXNlIHNldCB0aGVtIGFzIGVudmlyb25tZW50IHZhcmlhYmxlcyBvciBpbnN0YWxsIG1scnVuIChgcGlwIGluc3RhbGwgbWxydW5gKSIKICAgICAgICAgICAgZiJhbmQgc2V0IHRoZW0gYXMgcHJvamVjdCBzZWNyZXRzIHVzaW5nIGBwcm9qZWN5LnNldF9zZWNyZXRzYC4iCiAgICAgICAgKQoKICAgICMgQ2hlY2sgaWYgdGhlIGtleSBpcyBpbiB0aGUgc2VjcmV0czoKICAgIGNvbnRleHQgPSBtbHJ1bi5nZXRfb3JfY3JlYXRlX2N0eChuYW1lPSJjb250ZXh0IikKICAgIG9wZW5haV9rZXkgPSBjb250ZXh0LmdldF9zZWNyZXQoa2V5KQogICAgb3BlbmFpX2Jhc2UgPSBjb250ZXh0LmdldF9zZWNyZXQoYmFzZSkKCiAgICAjIElmIHRoZSBrZXkgaXMgbm90IGluIHRoZSBzZWNyZXRz
LCByZXR1cm4gRmFsc2U6CiAgICBpZiBub3Qgb3BlbmFpX2tleToKICAgICAgICByYWlzZSBPU0Vycm9yKAogICAgICAgICAgICBmIkNvdWxkIG5vdCBmaW5kIE9wZW5BSSBBUEkga2V5IGluIHRoZSBlbnZpcm9ubWVudCB2YXJpYWJsZXMgb3Igc2VjcmV0cywiCiAgICAgICAgICAgIGYiIHBsZWFzZSBzZXQgaXQgYXM6IHtrZXl9LiIKICAgICAgICApCiAgICBpZiBub3Qgb3BlbmFpX2Jhc2U6CiAgICAgICAgcmFpc2UgT1NFcnJvcigKICAgICAgICAgICAgZiJDb3VsZCBub3QgZmluZCBPcGVuQUkgQVBJIGJhc2UgaW4gdGhlIGVudmlyb25tZW50IHZhcmlhYmxlcyBvciBzZWNyZXRzLCIKICAgICAgICAgICAgZiIgcGxlYXNlIHNldCBpdCBhczoge2Jhc2V9LiIKICAgICAgICApCiAgICAjIElmIHRoZSBrZXkgaXMgaW4gdGhlIHNlY3JldHMsIHNldCBpdCBpbiB0aGUgZW52aXJvbm1lbnQgdmFyaWFibGVzIGFuZCByZXR1cm4gVHJ1ZToKICAgIG9zLmVudmlyb25ba2V5XSA9IG9wZW5haV9rZXkKICAgIG9zLmVudmlyb25bYmFzZV0gPSBvcGVuYWlfYmFzZQogICAgcmV0dXJuIFRydWUKCgpkZWYgZ2VuZXJhdGVfZGF0YSgKICAgIGZpZWxkczogbGlzdCwKICAgIGFtb3VudDogaW50ID0gMTAsCiAgICBtb2RlbF9uYW1lOiBzdHIgPSAiZ3B0LTMuNS10dXJibyIsCiAgICBsYW5ndWFnZTogc3RyID0gImVuIiwKICAgIGNodW5rX3NpemU6IGludCA9IDUwLAopIC0+IGxpc3Q6CiAgICAiIiIKICAgIFN0cnVjdHVyZWQgZGF0YSBvZiBlbGVtZW50cyBhY2NvcmRpbmcgdG8gdGhlIGdpdmVuIHBhcmFtZXRlcnMuCiAgICBUaGUgZGF0YSBjYW4gYmUgbGF0ZXIgbG9nZ2VkIGFzIGEgc3RydWN0dXJlZCBmaWxlIHdpdGggTUxSdW4ncyBgcmV0dXJuc2AgcGFyYW1ldGVyLgoKICAgIDpwYXJhbSBmaWVsZHM6IEEgbGlzdCBvZiBmaWVsZHMgdG8gcmFuZG9tbHkgZ2VuZXJhdGUuCiAgICA6cGFyYW0gYW1vdW50OiBUaGUgbnVtYmVyIG9mIHZhcmlhbnRzIHRvIGdlbmVyYXRlLgogICAgOnBhcmFtIG1vZGVsX25hbWU6IFRoZSBuYW1lIG9mIHRoZSBtb2RlbCB0byB1c2UgZm9yIGNvbnZlcnNhdGlvbiBnZW5lcmF0aW9uLgogICAgICAgICAgICAgICAgICAgICAgIFlvdSBzaG91bGQgY2hvb3NlIG9uZSBvZiBHUFQtNCBvciBHUFQtMy41IGZyb20gdGhlIGxpc3QgaGVyZTogaHR0cHM6Ly9wbGF0Zm9ybS5vcGVuYWkuY29tL2RvY3MvbW9kZWxzLgogICAgICAgICAgICAgICAgICAgICAgIERlZmF1bHQ6ICdncHQtMy41LXR1cmJvJy4KICAgIDpwYXJhbSBsYW5ndWFnZTogVGhlIGxhbmd1YWdlIHRvIHVzZSBmb3IgdGhlIGdlbmVyYXRlZCBjb252ZXJzYXRpb24gdGV4dC4KICAgIDpwYXJhbSBjaHVua19zaXplOiBOdW1iZXIgb2Ygc2FtcGxlcyBnZW5lcmF0ZWQgYXQgZWFjaCBHUFQgcXVlcnkuCiAgICAiIiIKICAgIGluc3RydWN0aW9ucyA9ICIiCiAgICBmb3IgZmllbGQgaW4gZmllbGRzOgogICAgICAgICMgU3BsaXQgdGhlIGZpZWxkIHRvIGtleSBhbmQgaW5zdHJ1
Y3Rpb246CiAgICAgICAgaWYgIjoiIGluIGZpZWxkOgogICAgICAgICAgICBrZXksIGluc3RydWN0aW9uID0gZmllbGQuc3BsaXQoIjoiLCAxKQogICAgICAgIGVsc2U6CiAgICAgICAgICAgIGtleSwgaW5zdHJ1Y3Rpb24gPSBmaWVsZCwgIm5vIHNwZWNpYWwgaW5zdHJ1Y3Rpb24iCiAgICAgICAgIyBSZXBsYWNlIHNwYWNlcyB3aXRoIHVuZGVyc2NvcmVzIGZvciB0aGUga2V5IHRvIGJlIHVzZWQgYXMgYSBqc29uIGtleToKICAgICAgICBrZXkgPSBrZXkuc3RyaXAoKS5yZXBsYWNlKCIgIiwgIl8iKQogICAgICAgIGluc3RydWN0aW9ucyArPSBmIioge2tleX06IHtpbnN0cnVjdGlvbn1cbiIKCiAgICAjIENyZWF0ZSB0aGUgcHJvbXB0IHN0cnVjdHVyZToKICAgIHByb21wdF9zdHJ1Y3R1cmUgPSAoCiAgICAgICAgZiJnZW5lcmF0ZSB0aGUgZm9sbG93aW5nIHZhbHVlcyB7YW1vdW50fSB0aW1lcyByYW5kb21seSwgaW4gYW4gb3JkZXIgdGhhdCBjcmVhdGVzIGEganNvbiB0YWJsZS5cbiIKICAgICAgICBmIlVzZSB0aGUgZm9sbG93aW5nIGtleXMgYW5kIGluc3RydWN0aW9ucyAoZXhhbXBsZTogJ2tleTogaW5zdHJ1Y3Rpb24gb3Igbm8gc3BlY2lhbCBpbnN0cnVjdGlvbicpOiAiCiAgICAgICAgZiJ7aW5zdHJ1Y3Rpb25zfS5cbiIKICAgICAgICBmIlBsZWFzZSBnZW5lcmF0ZSB0aGUgdmFsdWVzIGluIHtsYW5ndWFnZX0gbGFuZ3VhZ2UuIFxuIgogICAgICAgIGYiTWFrZSBzdXJlIHRoZSBuYW1lcyBvZiB0aGUga2V5cyBhcmUgdGhlIHNhbWUgYXMgdGhlIGdpdmVuIGZpZWxkIG5hbWUuXG4iCiAgICAgICAgZiJQbGVhc2UgcmV0dXJuIG9ubHkgdGhlIGpzb24gZm9ybWF0IHdpdGhvdXQgYW55IGludHJvZHVjdGlvbiBhbmQgZW5kaW5nIgogICAgKQoKICAgICMgU2V0IHRoZSBPcGVuQUkgc2VjcmV0czoKICAgIF9zZXRfb3BlbmFpX3NlY3JldHMoKQoKICAgICMgTG9hZCB0aGUgT3BlbkFJIG1vZGVsIHVzaW5nIGxhbmdjaGFpbjoKICAgIGxsbSA9IENoYXRPcGVuQUkobW9kZWw9bW9kZWxfbmFtZSkKCiAgICAjIFN0YXJ0IGdlbmVyYXRpbmcgZGF0YToKICAgIGRhdGEgPSBbXQogICAgZm9yIF8gaW4gdHFkbS50cWRtKHJhbmdlKChhbW91bnQgLy8gY2h1bmtfc2l6ZSkgKyAxKSwgZGVzYz0iR2VuZXJhdGluZyIpOgogICAgICAgICMgV2UgdHJ5IHRvIGdlbmVyYXRlIHRoZSBkYXRhIDMgdGltZXMsIGlmIHdlIGZhaWwgd2UgcmFpc2UgYW4gZXJyb3I6CiAgICAgICAgZm9yIHRyeW91dCBpbiByYW5nZSgzKToKICAgICAgICAgICAgIyBJZiB0aGUgYW1vdW50IHdhbnRlZCBpcyBiaWdnZXIgdGhhbiB0aGUgY2h1bmsgc2l6ZSwgd2UgZ2VuZXJhdGUgYSBjaHVuayBvZiBkYXRhIGluIHRoZSBzaXplIG9mIHRoZSBjaHVuawogICAgICAgICAgICAjIGFuZCBkZWNyZWFzZSB0aGUgYW1vdW50IGJ5IHRoZSBjaHVuayBzaXplLgogICAgICAgICAgICAjIG90aGVyd2lzZSB3ZSBnZW5lcmF0ZSBhIGNodW5rIG9mIGRhdGEgaW4gdGhlIHNpemUgb2YgdGhlIGFt
b3VudDoKICAgICAgICAgICAgaWYgYW1vdW50ID4gY2h1bmtfc2l6ZToKICAgICAgICAgICAgICAgIGN1cnJlbnRfY2h1bmtfc2l6ZSA9IGNodW5rX3NpemUKICAgICAgICAgICAgICAgIGFtb3VudCAtPSBjaHVua19zaXplCiAgICAgICAgICAgIGVsc2U6CiAgICAgICAgICAgICAgICBjdXJyZW50X2NodW5rX3NpemUgPSBhbW91bnQKCiAgICAgICAgICAgICMgQ3JlYXRlIHRoZSBwcm9tcHQ6CiAgICAgICAgICAgIHByb21wdCA9IHByb21wdF9zdHJ1Y3R1cmUuZm9ybWF0KAogICAgICAgICAgICAgICAgYW1vdW50PWN1cnJlbnRfY2h1bmtfc2l6ZSwKICAgICAgICAgICAgKQoKICAgICAgICAgICAgIyBHZW5lcmF0ZSBhIGNodW5rIG9mIGRhdGE6CiAgICAgICAgICAgIGNodW5rX2RhdGEgPSBsbG0ucHJlZGljdCh0ZXh0PXByb21wdCkKCiAgICAgICAgICAgICMgVmFsaWRhdGUgdGhlIHJlc3BvbnNlIGZvciBjb3JyZWN0IHB5dGhvbiBgbGlzdGAgc3RydWN0dXJlCiAgICAgICAgICAgIGNodW5rX2RhdGEgPSBjaHVua19kYXRhW2NodW5rX2RhdGEuZmluZCgiWyIpIDogY2h1bmtfZGF0YS5yZmluZCgiXSIpICsgMV0KICAgICAgICAgICAgaWYgY2h1bmtfZGF0YS5jb3VudCgiWyIpICE9IGNodW5rX2RhdGEuY291bnQoIl0iKToKICAgICAgICAgICAgICAgIHByaW50KAogICAgICAgICAgICAgICAgICAgICJGYWlsZWQgdG8gZ2V0IHByb3BlciBqc29uIGZvcm1hdCBmcm9tIG1vZGVsLCBudW1iZXIgb2YgJ1snIGRvZXNuJ3QgbWF0Y2ggbnVtYmVyIG9mICddJy4iCiAgICAgICAgICAgICAgICApCiAgICAgICAgICAgICAgICBjb250aW51ZQogICAgICAgICAgICBjaHVua19kYXRhID0gYXN0LmxpdGVyYWxfZXZhbChjaHVua19kYXRhKQogICAgICAgICAgICBkYXRhICs9IGNodW5rX2RhdGEKICAgICAgICAgICAgYnJlYWsKICAgICAgICBpZiB0cnlvdXQgPT0gMzoKICAgICAgICAgICAgcmFpc2UgUnVudGltZUVycm9yKAogICAgICAgICAgICAgICAgZiJDb3VsZCBub3QgZ2VuZXJhdGUgYSBwcm9wZXIganNvbiBmb3JtYXQgZm9yIHRoZSBnaXZlbiBmaWVsZHMsIHVzaW5nIGdpdmVuIG1vZGVsOiB7bW9kZWxfbmFtZX0uIgogICAgICAgICAgICAgICAgZiIgSGludDogR3B0LTQgd29ya3MgYmVzdCBmb3IgbW9zdCBzY2VuYXJpb3MuIgogICAgICAgICAgICApCiAgICByZXR1cm4gZGF0YQo= requirements: - langchain - tqdm code_origin: '' - functionSourceCode: 
IyBDb3B5cmlnaHQgMjAyMyBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMAojCiMgVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQojIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuICJBUyBJUyIgQkFTSVMsCiMgV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuCiMgU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAojIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLgppbXBvcnQgYXN0CmltcG9ydCBvcwoKaW1wb3J0IHRxZG0KZnJvbSBsYW5nY2hhaW4uY2hhdF9tb2RlbHMgaW1wb3J0IENoYXRPcGVuQUkKCgpkZWYgX3NldF9vcGVuYWlfc2VjcmV0cygpIC0+IGJvb2w6CiAgICBrZXkgPSAiT1BFTkFJX0FQSV9LRVkiCiAgICBiYXNlID0gIk9QRU5BSV9BUElfQkFTRSIKICAgICMgQ2hlY2sgaWYgdGhlIGtleSBpcyBhbHJlYWR5IGluIHRoZSBlbnZpcm9ubWVudCB2YXJpYWJsZXM6CiAgICBpZiBrZXkgaW4gb3MuZW52aXJvbiBhbmQgYmFzZSBpbiBvcy5lbnZpcm9uOgogICAgICAgIHJldHVybiBUcnVlCiAgICAjIENoZWNrIGlmIG1scnVuIGlzIGluc3RhbGxlZDoKICAgIHRyeToKICAgICAgICBpbXBvcnQgbWxydW4KICAgIGV4Y2VwdCBNb2R1bGVOb3RGb3VuZEVycm9yOgogICAgICAgIHJhaXNlIEVudmlyb25tZW50RXJyb3IoCiAgICAgICAgICAgIGYiT25lIG9yIG1vcmUgb2YgdGhlIE9wZW5BSSByZXF1aXJlZCBlbnZpcm9ubWVudCB2YXJpYWJsZXMgKCd7a2V5fScsICd7YmFzZX0nKSBhcmUgbWlzc2luZy4iCiAgICAgICAgICAgIGYiUGxlYXNlIHNldCB0aGVtIGFzIGVudmlyb25tZW50IHZhcmlhYmxlcyBvciBpbnN0YWxsIG1scnVuIChgcGlwIGluc3RhbGwgbWxydW5gKSIKICAgICAgICAgICAgZiJhbmQgc2V0IHRoZW0gYXMgcHJvamVjdCBzZWNyZXRzIHVzaW5nIGBwcm9qZWN5LnNldF9zZWNyZXRzYC4iCiAgICAgICAgKQoKICAgICMgQ2hlY2sgaWYgdGhlIGtleSBpcyBpbiB0aGUgc2VjcmV0czoKICAgIGNvbnRleHQgPSBtbHJ1bi5nZXRfb3JfY3JlYXRlX2N0eChuYW1lPSJjb250ZXh0IikKICAgIG9wZW5haV9rZXkgPSBjb250ZXh0LmdldF9zZWNyZXQoa2V5KQogICAgb3BlbmFpX2Jhc2UgPSBjb250ZXh0LmdldF9zZWNyZXQoYmFzZSkKCiAgICAjIElmIHRoZSBrZXkgaXMgbm90IGluIHRo
ZSBzZWNyZXRzLCByZXR1cm4gRmFsc2U6CiAgICBpZiBub3Qgb3BlbmFpX2tleToKICAgICAgICByYWlzZSBFbnZpcm9ubWVudEVycm9yKAogICAgICAgICAgICBmIkNvdWxkIG5vdCBmaW5kIE9wZW5BSSBBUEkga2V5IGluIHRoZSBlbnZpcm9ubWVudCB2YXJpYWJsZXMgb3Igc2VjcmV0cywiCiAgICAgICAgICAgIGYiIHBsZWFzZSBzZXQgaXQgYXM6IHtrZXl9LiIKICAgICAgICApCiAgICBpZiBub3Qgb3BlbmFpX2Jhc2U6CiAgICAgICAgcmFpc2UgRW52aXJvbm1lbnRFcnJvcigKICAgICAgICAgICAgZiJDb3VsZCBub3QgZmluZCBPcGVuQUkgQVBJIGJhc2UgaW4gdGhlIGVudmlyb25tZW50IHZhcmlhYmxlcyBvciBzZWNyZXRzLCIKICAgICAgICAgICAgZiIgcGxlYXNlIHNldCBpdCBhczoge2Jhc2V9LiIKICAgICAgICApCiAgICAjIElmIHRoZSBrZXkgaXMgaW4gdGhlIHNlY3JldHMsIHNldCBpdCBpbiB0aGUgZW52aXJvbm1lbnQgdmFyaWFibGVzIGFuZCByZXR1cm4gVHJ1ZToKICAgIG9zLmVudmlyb25ba2V5XSA9IG9wZW5haV9rZXkKICAgIG9zLmVudmlyb25bYmFzZV0gPSBvcGVuYWlfYmFzZQogICAgcmV0dXJuIFRydWUKCgpkZWYgZ2VuZXJhdGVfZGF0YSgKICAgIGZpZWxkczogbGlzdCwKICAgIGFtb3VudDogaW50ID0gMTAsCiAgICBtb2RlbF9uYW1lOiBzdHIgPSAiZ3B0LTMuNS10dXJibyIsCiAgICBsYW5ndWFnZTogc3RyID0gImVuIiwKICAgIGNodW5rX3NpemU6IGludCA9IDUwLAopIC0+IGxpc3Q6CiAgICAiIiIKICAgIFN0cnVjdHVyZWQgZGF0YSBvZiBlbGVtZW50cyBhY2NvcmRpbmcgdG8gdGhlIGdpdmVuIHBhcmFtZXRlcnMuCiAgICBUaGUgZGF0YSBjYW4gYmUgbGF0ZXIgbG9nZ2VkIGFzIGEgc3RydWN0dXJlZCBmaWxlIHdpdGggTUxSdW4ncyBgcmV0dXJuc2AgcGFyYW1ldGVyLgoKICAgIDpwYXJhbSBmaWVsZHM6IEEgbGlzdCBvZiBmaWVsZHMgdG8gcmFuZG9tbHkgZ2VuZXJhdGUuCiAgICA6cGFyYW0gYW1vdW50OiBUaGUgbnVtYmVyIG9mIHZhcmlhbnRzIHRvIGdlbmVyYXRlLgogICAgOnBhcmFtIG1vZGVsX25hbWU6IFRoZSBuYW1lIG9mIHRoZSBtb2RlbCB0byB1c2UgZm9yIGNvbnZlcnNhdGlvbiBnZW5lcmF0aW9uLgogICAgICAgICAgICAgICAgICAgICAgIFlvdSBzaG91bGQgY2hvb3NlIG9uZSBvZiBHUFQtNCBvciBHUFQtMy41IGZyb20gdGhlIGxpc3QgaGVyZTogaHR0cHM6Ly9wbGF0Zm9ybS5vcGVuYWkuY29tL2RvY3MvbW9kZWxzLgogICAgICAgICAgICAgICAgICAgICAgIERlZmF1bHQ6ICdncHQtMy41LXR1cmJvJy4KICAgIDpwYXJhbSBsYW5ndWFnZTogVGhlIGxhbmd1YWdlIHRvIHVzZSBmb3IgdGhlIGdlbmVyYXRlZCBjb252ZXJzYXRpb24gdGV4dC4KICAgIDpwYXJhbSBjaHVua19zaXplOiBOdW1iZXIgb2Ygc2FtcGxlcyBnZW5lcmF0ZWQgYXQgZWFjaCBHUFQgcXVlcnkuCiAgICAiIiIKICAgIGluc3RydWN0aW9ucyA9ICIiCiAgICBmb3IgZmllbGQgaW4gZmllbGRzOgogICAgICAgICMgU3BsaXQg
dGhlIGZpZWxkIHRvIGtleSBhbmQgaW5zdHJ1Y3Rpb246CiAgICAgICAgaWYgIjoiIGluIGZpZWxkOgogICAgICAgICAgICBrZXksIGluc3RydWN0aW9uID0gZmllbGQuc3BsaXQoIjoiLCAxKQogICAgICAgIGVsc2U6CiAgICAgICAgICAgIGtleSwgaW5zdHJ1Y3Rpb24gPSBmaWVsZCwgIm5vIHNwZWNpYWwgaW5zdHJ1Y3Rpb24iCiAgICAgICAgIyBSZXBsYWNlIHNwYWNlcyB3aXRoIHVuZGVyc2NvcmVzIGZvciB0aGUga2V5IHRvIGJlIHVzZWQgYXMgYSBqc29uIGtleToKICAgICAgICBrZXkgPSBrZXkuc3RyaXAoKS5yZXBsYWNlKCIgIiwgIl8iKQogICAgICAgIGluc3RydWN0aW9ucyArPSBmIioge2tleX06IHtpbnN0cnVjdGlvbn1cbiIKCiAgICAjIENyZWF0ZSB0aGUgcHJvbXB0IHN0cnVjdHVyZToKICAgIHByb21wdF9zdHJ1Y3R1cmUgPSAoCiAgICAgICAgZiJnZW5lcmF0ZSB0aGUgZm9sbG93aW5nIHZhbHVlcyB7YW1vdW50fSB0aW1lcyByYW5kb21seSwgaW4gYW4gb3JkZXIgdGhhdCBjcmVhdGVzIGEganNvbiB0YWJsZS5cbiIKICAgICAgICBmIlVzZSB0aGUgZm9sbG93aW5nIGtleXMgYW5kIGluc3RydWN0aW9ucyAoZXhhbXBsZTogJ2tleTogaW5zdHJ1Y3Rpb24gb3Igbm8gc3BlY2lhbCBpbnN0cnVjdGlvbicpOiAiCiAgICAgICAgZiJ7aW5zdHJ1Y3Rpb25zfS5cbiIKICAgICAgICBmIlBsZWFzZSBnZW5lcmF0ZSB0aGUgdmFsdWVzIGluIHtsYW5ndWFnZX0gbGFuZ3VhZ2UuIFxuIgogICAgICAgIGYiTWFrZSBzdXJlIHRoZSBuYW1lcyBvZiB0aGUga2V5cyBhcmUgdGhlIHNhbWUgYXMgdGhlIGdpdmVuIGZpZWxkIG5hbWUuXG4iCiAgICAgICAgZiJQbGVhc2UgcmV0dXJuIG9ubHkgdGhlIGpzb24gZm9ybWF0IHdpdGhvdXQgYW55IGludHJvZHVjdGlvbiBhbmQgZW5kaW5nIgogICAgKQoKICAgICMgU2V0IHRoZSBPcGVuQUkgc2VjcmV0czoKICAgIF9zZXRfb3BlbmFpX3NlY3JldHMoKQoKICAgICMgTG9hZCB0aGUgT3BlbkFJIG1vZGVsIHVzaW5nIGxhbmdjaGFpbjoKICAgIGxsbSA9IENoYXRPcGVuQUkobW9kZWw9bW9kZWxfbmFtZSkKCiAgICAjIFN0YXJ0IGdlbmVyYXRpbmcgZGF0YToKICAgIGRhdGEgPSBbXQogICAgZm9yIF8gaW4gdHFkbS50cWRtKHJhbmdlKChhbW91bnQgLy8gY2h1bmtfc2l6ZSkgKyAxKSwgZGVzYz0iR2VuZXJhdGluZyIpOgogICAgICAgICMgV2UgdHJ5IHRvIGdlbmVyYXRlIHRoZSBkYXRhIDMgdGltZXMsIGlmIHdlIGZhaWwgd2UgcmFpc2UgYW4gZXJyb3I6CiAgICAgICAgZm9yIHRyeW91dCBpbiByYW5nZSgzKToKICAgICAgICAgICAgIyBJZiB0aGUgYW1vdW50IHdhbnRlZCBpcyBiaWdnZXIgdGhhbiB0aGUgY2h1bmsgc2l6ZSwgd2UgZ2VuZXJhdGUgYSBjaHVuayBvZiBkYXRhIGluIHRoZSBzaXplIG9mIHRoZSBjaHVuawogICAgICAgICAgICAjIGFuZCBkZWNyZWFzZSB0aGUgYW1vdW50IGJ5IHRoZSBjaHVuayBzaXplLgogICAgICAgICAgICAjIG90aGVyd2lzZSB3ZSBnZW5lcmF0ZSBhIGNodW5rIG9m
IGRhdGEgaW4gdGhlIHNpemUgb2YgdGhlIGFtb3VudDoKICAgICAgICAgICAgaWYgYW1vdW50ID4gY2h1bmtfc2l6ZToKICAgICAgICAgICAgICAgIGN1cnJlbnRfY2h1bmtfc2l6ZSA9IGNodW5rX3NpemUKICAgICAgICAgICAgICAgIGFtb3VudCAtPSBjaHVua19zaXplCiAgICAgICAgICAgIGVsc2U6CiAgICAgICAgICAgICAgICBjdXJyZW50X2NodW5rX3NpemUgPSBhbW91bnQKCiAgICAgICAgICAgICMgQ3JlYXRlIHRoZSBwcm9tcHQ6CiAgICAgICAgICAgIHByb21wdCA9IHByb21wdF9zdHJ1Y3R1cmUuZm9ybWF0KAogICAgICAgICAgICAgICAgYW1vdW50PWN1cnJlbnRfY2h1bmtfc2l6ZSwKICAgICAgICAgICAgKQoKICAgICAgICAgICAgIyBHZW5lcmF0ZSBhIGNodW5rIG9mIGRhdGE6CiAgICAgICAgICAgIGNodW5rX2RhdGEgPSBsbG0ucHJlZGljdCh0ZXh0PXByb21wdCkKCiAgICAgICAgICAgICMgVmFsaWRhdGUgdGhlIHJlc3BvbnNlIGZvciBjb3JyZWN0IHB5dGhvbiBgbGlzdGAgc3RydWN0dXJlCiAgICAgICAgICAgIGNodW5rX2RhdGEgPSBjaHVua19kYXRhW2NodW5rX2RhdGEuZmluZCgiWyIpIDogY2h1bmtfZGF0YS5yZmluZCgiXSIpICsgMV0KICAgICAgICAgICAgaWYgY2h1bmtfZGF0YS5jb3VudCgiWyIpICE9IGNodW5rX2RhdGEuY291bnQoIl0iKToKICAgICAgICAgICAgICAgIHByaW50KAogICAgICAgICAgICAgICAgICAgICJGYWlsZWQgdG8gZ2V0IHByb3BlciBqc29uIGZvcm1hdCBmcm9tIG1vZGVsLCBudW1iZXIgb2YgJ1snIGRvZXNuJ3QgbWF0Y2ggbnVtYmVyIG9mICddJy4iCiAgICAgICAgICAgICAgICApCiAgICAgICAgICAgICAgICBjb250aW51ZQogICAgICAgICAgICBjaHVua19kYXRhID0gYXN0LmxpdGVyYWxfZXZhbChjaHVua19kYXRhKQogICAgICAgICAgICBkYXRhICs9IGNodW5rX2RhdGEKICAgICAgICAgICAgYnJlYWsKICAgICAgICBpZiB0cnlvdXQgPT0gMzoKICAgICAgICAgICAgcmFpc2UgUnVudGltZUVycm9yKAogICAgICAgICAgICAgICAgZiJDb3VsZCBub3QgZ2VuZXJhdGUgYSBwcm9wZXIganNvbiBmb3JtYXQgZm9yIHRoZSBnaXZlbiBmaWVsZHMsIHVzaW5nIGdpdmVuIG1vZGVsOiB7bW9kZWxfbmFtZX0uIgogICAgICAgICAgICAgICAgZiIgSGludDogR3B0LTQgd29ya3MgYmVzdCBmb3IgbW9zdCBzY2VuYXJpb3MuIgogICAgICAgICAgICApCiAgICByZXR1cm4gZGF0YQo= base_image: mlrun/mlrun + filename: structured_data_generator.py entry_points: generate_data: - has_varargs: false - name: generate_data - has_kwargs: false - doc: 'Structured data of elements according to the given parameters. - - The data can be later logged as a structured file with MLRun''s `returns` - parameter.' 
+ outputs: + - type: list parameters: - name: fields type: list @@ -38,19 +44,14 @@ spec: type: int doc: Number of samples generated at each GPT query. default: 50 - outputs: - - type: list + name: generate_data + doc: 'Structured data of elements according to the given parameters. + + The data can be later logged as a structured file with MLRun''s `returns` + parameter.' + has_kwargs: false + has_varargs: false lineno: 59 command: '' description: GenAI approach of generating structured data according to a given schema default_handler: generate_data - disable_auto_mount: false - image: '' -metadata: - name: structured-data-generator - tag: '' - categories: - - data-generation - - genai -verbose: false -kind: job diff --git a/functions/src/structured_data_generator/structured_data_generator.py b/functions/src/structured_data_generator/structured_data_generator.py index 34fa36d49..d817ef274 100644 --- a/functions/src/structured_data_generator/structured_data_generator.py +++ b/functions/src/structured_data_generator/structured_data_generator.py @@ -28,7 +28,7 @@ def _set_openai_secrets() -> bool: try: import mlrun except ModuleNotFoundError: - raise EnvironmentError( + raise OSError( f"One or more of the OpenAI required environment variables ('{key}', '{base}') are missing." f"Please set them as environment variables or install mlrun (`pip install mlrun`)" f"and set them as project secrets using `projecy.set_secrets`." @@ -41,12 +41,12 @@ def _set_openai_secrets() -> bool: # If the key is not in the secrets, return False: if not openai_key: - raise EnvironmentError( + raise OSError( f"Could not find OpenAI API key in the environment variables or secrets," f" please set it as: {key}." ) if not openai_base: - raise EnvironmentError( + raise OSError( f"Could not find OpenAI API base in the environment variables or secrets," f" please set it as: {base}." 
) diff --git a/functions/src/structured_data_generator/test_structured_data_generator.py b/functions/src/structured_data_generator/test_structured_data_generator.py index 3a7a7aa57..b1ddaba8a 100644 --- a/functions/src/structured_data_generator/test_structured_data_generator.py +++ b/functions/src/structured_data_generator/test_structured_data_generator.py @@ -1,4 +1,5 @@ import os + import mlrun import pytest @@ -8,11 +9,13 @@ def test_structured_data_generator(): # Create mlrun project project = mlrun.get_or_create_project("structured-data-generator-test") - #Set secrets + # Set secrets # project.set_secrets({"OPENAI_API_KEY": "", "OPENAI_API_BASE": ""}) # Import the function from the yaml file, once it's in the hub we can import from there - data_generation = project.set_function(func="structured_data_generator.py", name="structured_data_generator") + data_generation = project.set_function( + func="structured_data_generator.py", name="structured_data_generator" + ) # Run the imported function with desired file/s and params data_generation_run = data_generation.run( @@ -26,7 +29,7 @@ def test_structured_data_generator(): "last_name", "phone_number: at least 9 digits long", "email", - "client_id: at least 8 digits long, only numbers" + "client_id: at least 8 digits long, only numbers", ], }, returns=[ @@ -34,4 +37,4 @@ def test_structured_data_generator(): ], local=True, ) - assert data_generation_run.outputs["clients"] \ No newline at end of file + assert data_generation_run.outputs["clients"] diff --git a/functions/src/test_classifier/function.yaml b/functions/src/test_classifier/function.yaml index f35446b51..33b625c80 100644 --- a/functions/src/test_classifier/function.yaml +++ b/functions/src/test_classifier/function.yaml @@ -1,49 +1,35 @@ -kind: job metadata: - name: test-classifier tag: '' - hash: b4d447a2328975e90a0dbc7a28f82009924cc157 - project: '' - labels: - author: Iguazio - framework: sklearn + name: test-classifier categories: - machine-learning - 
model-testing +verbose: false +kind: job spec: - command: '' - args: [] image: mlrun/mlrun - env: [] - default_handler: test_classifier + disable_auto_mount: false + build: + origin_filename: '' + functionSourceCode: IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKIyBHZW5lcmF0ZWQgYnkgbnVjbGlvLmV4cG9ydC5OdWNsaW9FeHBvcnRlcgoKaW1wb3J0IHdhcm5pbmdzCgp3YXJuaW5ncy5maWx0ZXJ3YXJuaW5ncygiaWdub3JlIikKCmltcG9ydCBwYW5kYXMgYXMgcGQKZnJvbSBjbG91ZHBpY2tsZSBpbXBvcnQgbG9hZApmcm9tIG1scnVuLmFydGlmYWN0cyBpbXBvcnQgZ2V0X21vZGVsLCB1cGRhdGVfbW9kZWwKZnJvbSBtbHJ1bi5kYXRhc3RvcmUgaW1wb3J0IERhdGFJdGVtCmZyb20gbWxydW4ubWx1dGlscy5tb2RlbHMgaW1wb3J0IGV2YWxfbW9kZWxfdjIKCgpkZWYgdGVzdF9jbGFzc2lmaWVyKAogICAgY29udGV4dCwKICAgIG1vZGVsc19wYXRoOiBEYXRhSXRlbSwKICAgIHRlc3Rfc2V0OiBEYXRhSXRlbSwKICAgIGxhYmVsX2NvbHVtbjogc3RyLAogICAgc2NvcmVfbWV0aG9kOiBzdHIgPSAibWljcm8iLAogICAgcGxvdHNfZGVzdDogc3RyID0gIiIsCiAgICBtb2RlbF9ldmFsdWF0b3I9Tm9uZSwKICAgIGRlZmF1bHRfbW9kZWw6IHN0ciA9ICJtb2RlbC5wa2wiLAogICAgcHJlZGljdGlvbnNfY29sdW1uOiBzdHIgPSAieXNjb3JlIiwKICAgIG1vZGVsX3VwZGF0ZT1UcnVlLAopIC0+IE5vbmU6CiAgICAiIiJUZXN0IG9uZSBvciBtb3JlIGNsYXNzaWZpZXIgbW9kZWxzIGFnYWluc3QgaGVsZC1vdXQgZGF0YXNldAoKICAgIFVzaW5nIGhlbGQtb3V0IHRlc3QgZmVhdHVyZXMsIGV2YWx1YXRlcyB0aGUgcGVmb3JtYW5jZSBvZiB0aGUgZXN0aW1hdGVkIG1vZGVsCgogICAgQ2FuIGJlIHBhcnQgb2YgYSBrdWJ
lZmxvdyBwaXBlbGluZSBhcyBhIHRlc3Qgc3RlcCB0aGF0IGlzIHJ1biBwb3N0IEVEQSBhbmQKICAgIHRyYWluaW5nL3ZhbGlkYXRpb24gY3ljbGVzCgogICAgOnBhcmFtIGNvbnRleHQ6ICAgICAgICAgICAgdGhlIGZ1bmN0aW9uIGNvbnRleHQKICAgIDpwYXJhbSBtb2RlbHNfcGF0aDogICAgICAgIGFydGlmYWN0IG1vZGVscyByZXByZXNlbnRpbmcgYSBmaWxlIG9yIGEgZm9sZGVyCiAgICA6cGFyYW0gdGVzdF9zZXQ6ICAgICAgICAgICB0ZXN0IGZlYXR1cmVzIGFuZCBsYWJlbHMKICAgIDpwYXJhbSBsYWJlbF9jb2x1bW46ICAgICAgIGNvbHVtbiBuYW1lIGZvciBncm91bmQgdHJ1dGggbGFiZWxzCiAgICA6cGFyYW0gc2NvcmVfbWV0aG9kOiAgICAgICBmb3IgbXVsdGljbGFzcyBjbGFzc2lmaWNhdGlvbgogICAgOnBhcmFtIHBsb3RzX2Rlc3Q6ICAgICAgICAgZGlyIGZvciB0ZXN0IHBsb3RzCiAgICA6cGFyYW0gbW9kZWxfZXZhbHVhdG9yOiAgICBOT1QgSU1QTEVNRU5URUQ6IHNwZWNpZmljIG1ldGhvZCB0byBnZW5lcmF0ZSBldmFsLCBwYXNzZWQgaW4gYXMgc3RyaW5nCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBvciBhdmFpbGFibGUgaW4gdGhpcyBmb2xkZXIKICAgIDpwYXJhbSBwcmVkaWN0aW9uc19jb2x1bW46IGNvbHVtbiBuYW1lIGZvciB0aGUgcHJlZGljdGlvbnMgY29sdW1uIG9uIHRoZSByZXN1bHRlZCBhcnRpZmFjdAogICAgOnBhcmFtIG1vZGVsX3VwZGF0ZTogICAgICAgKFRydWUpIHVwZGF0ZSBtb2RlbCwgd2hlbiBydW5uaW5nIGFzIHN0YW5kIGFsb25lIG5vIG5lZWQgaW4gdXBkYXRlCiAgICAiIiIKICAgIHh0ZXN0ID0gdGVzdF9zZXQuYXNfZGYoKQogICAgeXRlc3QgPSB4dGVzdC5wb3AobGFiZWxfY29sdW1uKQoKICAgIHRyeToKICAgICAgICBtb2RlbF9maWxlLCBtb2RlbF9vYmosIF8gPSBnZXRfbW9kZWwobW9kZWxzX3BhdGgsIHN1ZmZpeD0iLnBrbCIpCiAgICAgICAgbW9kZWxfb2JqID0gbG9hZChvcGVuKG1vZGVsX2ZpbGUsICJyYiIpKQogICAgZXhjZXB0IEV4Y2VwdGlvbjoKICAgICAgICByYWlzZSBFeGNlcHRpb24oIm1vZGVsIGxvY2F0aW9uIGxpa2VseSBzcGVjaWZpZWQiKQoKICAgIGV4dHJhX2RhdGEgPSBldmFsX21vZGVsX3YyKGNvbnRleHQsIHh0ZXN0LCB5dGVzdC52YWx1ZXMsIG1vZGVsX29iaikKICAgIGlmIG1vZGVsX29iaiBhbmQgbW9kZWxfdXBkYXRlID09IFRydWU6CiAgICAgICAgdXBkYXRlX21vZGVsKAogICAgICAgICAgICBtb2RlbHNfcGF0aCwKICAgICAgICAgICAgZXh0cmFfZGF0YT1leHRyYV9kYXRhLAogICAgICAgICAgICBtZXRyaWNzPWNvbnRleHQucmVzdWx0cywKICAgICAgICAgICAga2V5X3ByZWZpeD0idmFsaWRhdGlvbi0iLAogICAgICAgICkKCiAgICB5X2hhdCA9IG1vZGVsX29iai5wcmVkaWN0KHh0ZXN0KQogICAgaWYgeV9oYXQubmRpbSA9PSAxIG9yIHlfaGF0LnNoYXBlWzFdID09IDE6CiAgICAgICAgc2NvcmVfbmFtZXMgPSBbcHJlZGljdGlvbnNfY29sdW1
uXQogICAgZWxzZToKICAgICAgICBzY29yZV9uYW1lcyA9IFtmIntwcmVkaWN0aW9uc19jb2x1bW59XyIgKyBzdHIoeCkgZm9yIHggaW4gcmFuZ2UoeV9oYXQuc2hhcGVbMV0pXQoKICAgIGRmID0gcGQuY29uY2F0KFt4dGVzdCwgeXRlc3QsIHBkLkRhdGFGcmFtZSh5X2hhdCwgY29sdW1ucz1zY29yZV9uYW1lcyldLCBheGlzPTEpCiAgICBjb250ZXh0LmxvZ19kYXRhc2V0KCJ0ZXN0X3NldF9wcmVkcyIsIGRmPWRmLCBmb3JtYXQ9InBhcnF1ZXQiLCBpbmRleD1GYWxzZSkK + code_origin: '' + filename: test_classifier.py entry_points: test_classifier: - name: test_classifier - doc: 'Test one or more classifier models against held-out dataset - - - Using held-out test features, evaluates the peformance of the estimated model - - - Can be part of a kubeflow pipeline as a test step that is run post EDA and - - training/validation cycles' + outputs: + - type: None parameters: - name: context doc: the function context - default: '' - name: models_path type: DataItem doc: artifact models representing a file or a folder - default: '' - name: test_set type: DataItem doc: test features and labels - default: '' - name: label_column type: str doc: column name for ground truth labels - default: '' - name: score_method type: str doc: for multiclass classification @@ -66,13 +52,19 @@ spec: - name: model_update doc: (True) update model, when running as stand alone no need in update default: true - outputs: - - default: '' - lineno: 17 + name: test_classifier + doc: 'Test one or more classifier models against held-out dataset + + + Using held-out test features, evaluates the peformance of the estimated model + + + Can be part of a kubeflow pipeline as a test step that is run post EDA and + + training/validation cycles' + has_kwargs: false + has_varargs: false + lineno: 28 + command: '' description: test a classifier using held-out or new data - build: - functionSourceCode: 
IyBHZW5lcmF0ZWQgYnkgbnVjbGlvLmV4cG9ydC5OdWNsaW9FeHBvcnRlcgoKaW1wb3J0IHdhcm5pbmdzCgp3YXJuaW5ncy5maWx0ZXJ3YXJuaW5ncygiaWdub3JlIikKCmltcG9ydCBvcwppbXBvcnQgcGFuZGFzIGFzIHBkCgpmcm9tIG1scnVuLmRhdGFzdG9yZSBpbXBvcnQgRGF0YUl0ZW0KZnJvbSBtbHJ1bi5hcnRpZmFjdHMgaW1wb3J0IGdldF9tb2RlbCwgdXBkYXRlX21vZGVsCmZyb20gbWxydW4ubWx1dGlscy5tb2RlbHMgaW1wb3J0IGV2YWxfbW9kZWxfdjIKZnJvbSBjbG91ZHBpY2tsZSBpbXBvcnQgbG9hZApmcm9tIHVybGxpYi5yZXF1ZXN0IGltcG9ydCB1cmxvcGVuCgoKZGVmIHRlc3RfY2xhc3NpZmllcigKICAgIGNvbnRleHQsCiAgICBtb2RlbHNfcGF0aDogRGF0YUl0ZW0sCiAgICB0ZXN0X3NldDogRGF0YUl0ZW0sCiAgICBsYWJlbF9jb2x1bW46IHN0ciwKICAgIHNjb3JlX21ldGhvZDogc3RyID0gIm1pY3JvIiwKICAgIHBsb3RzX2Rlc3Q6IHN0ciA9ICIiLAogICAgbW9kZWxfZXZhbHVhdG9yPU5vbmUsCiAgICBkZWZhdWx0X21vZGVsOiBzdHIgPSAibW9kZWwucGtsIiwKICAgIHByZWRpY3Rpb25zX2NvbHVtbjogc3RyID0gInlzY29yZSIsCiAgICBtb2RlbF91cGRhdGU9VHJ1ZSwKKSAtPiBOb25lOgogICAgIiIiVGVzdCBvbmUgb3IgbW9yZSBjbGFzc2lmaWVyIG1vZGVscyBhZ2FpbnN0IGhlbGQtb3V0IGRhdGFzZXQKCiAgICBVc2luZyBoZWxkLW91dCB0ZXN0IGZlYXR1cmVzLCBldmFsdWF0ZXMgdGhlIHBlZm9ybWFuY2Ugb2YgdGhlIGVzdGltYXRlZCBtb2RlbAoKICAgIENhbiBiZSBwYXJ0IG9mIGEga3ViZWZsb3cgcGlwZWxpbmUgYXMgYSB0ZXN0IHN0ZXAgdGhhdCBpcyBydW4gcG9zdCBFREEgYW5kCiAgICB0cmFpbmluZy92YWxpZGF0aW9uIGN5Y2xlcwoKICAgIDpwYXJhbSBjb250ZXh0OiAgICAgICAgICAgIHRoZSBmdW5jdGlvbiBjb250ZXh0CiAgICA6cGFyYW0gbW9kZWxzX3BhdGg6ICAgICAgICBhcnRpZmFjdCBtb2RlbHMgcmVwcmVzZW50aW5nIGEgZmlsZSBvciBhIGZvbGRlcgogICAgOnBhcmFtIHRlc3Rfc2V0OiAgICAgICAgICAgdGVzdCBmZWF0dXJlcyBhbmQgbGFiZWxzCiAgICA6cGFyYW0gbGFiZWxfY29sdW1uOiAgICAgICBjb2x1bW4gbmFtZSBmb3IgZ3JvdW5kIHRydXRoIGxhYmVscwogICAgOnBhcmFtIHNjb3JlX21ldGhvZDogICAgICAgZm9yIG11bHRpY2xhc3MgY2xhc3NpZmljYXRpb24KICAgIDpwYXJhbSBwbG90c19kZXN0OiAgICAgICAgIGRpciBmb3IgdGVzdCBwbG90cwogICAgOnBhcmFtIG1vZGVsX2V2YWx1YXRvcjogICAgTk9UIElNUExFTUVOVEVEOiBzcGVjaWZpYyBtZXRob2QgdG8gZ2VuZXJhdGUgZXZhbCwgcGFzc2VkIGluIGFzIHN0cmluZwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgb3IgYXZhaWxhYmxlIGluIHRoaXMgZm9sZGVyCiAgICA6cGFyYW0gcHJlZGljdGlvbnNfY29sdW1uOiBjb2x1bW4gbmFtZSBmb3IgdGhlIHByZWRpY3Rpb25zIGNvbHVtbiBvbiB0aGUg
cmVzdWx0ZWQgYXJ0aWZhY3QKICAgIDpwYXJhbSBtb2RlbF91cGRhdGU6ICAgICAgIChUcnVlKSB1cGRhdGUgbW9kZWwsIHdoZW4gcnVubmluZyBhcyBzdGFuZCBhbG9uZSBubyBuZWVkIGluIHVwZGF0ZQogICAgIiIiCiAgICB4dGVzdCA9IHRlc3Rfc2V0LmFzX2RmKCkKICAgIHl0ZXN0ID0geHRlc3QucG9wKGxhYmVsX2NvbHVtbikKCiAgICB0cnk6CiAgICAgICAgbW9kZWxfZmlsZSwgbW9kZWxfb2JqLCBfID0gZ2V0X21vZGVsKG1vZGVsc19wYXRoLCBzdWZmaXg9Ii5wa2wiKQogICAgICAgIG1vZGVsX29iaiA9IGxvYWQob3Blbihtb2RlbF9maWxlLCAicmIiKSkKICAgIGV4Y2VwdCBFeGNlcHRpb24gYXMgYToKICAgICAgICByYWlzZSBFeGNlcHRpb24oIm1vZGVsIGxvY2F0aW9uIGxpa2VseSBzcGVjaWZpZWQiKQoKICAgIGV4dHJhX2RhdGEgPSBldmFsX21vZGVsX3YyKGNvbnRleHQsIHh0ZXN0LCB5dGVzdC52YWx1ZXMsIG1vZGVsX29iaikKICAgIGlmIG1vZGVsX29iaiBhbmQgbW9kZWxfdXBkYXRlID09IFRydWU6CiAgICAgICAgdXBkYXRlX21vZGVsKAogICAgICAgICAgICBtb2RlbHNfcGF0aCwKICAgICAgICAgICAgZXh0cmFfZGF0YT1leHRyYV9kYXRhLAogICAgICAgICAgICBtZXRyaWNzPWNvbnRleHQucmVzdWx0cywKICAgICAgICAgICAga2V5X3ByZWZpeD0idmFsaWRhdGlvbi0iLAogICAgICAgICkKCiAgICB5X2hhdCA9IG1vZGVsX29iai5wcmVkaWN0KHh0ZXN0KQogICAgaWYgeV9oYXQubmRpbSA9PSAxIG9yIHlfaGF0LnNoYXBlWzFdID09IDE6CiAgICAgICAgc2NvcmVfbmFtZXMgPSBbcHJlZGljdGlvbnNfY29sdW1uXQogICAgZWxzZToKICAgICAgICBzY29yZV9uYW1lcyA9IFtmIntwcmVkaWN0aW9uc19jb2x1bW59XyIgKyBzdHIoeCkgZm9yIHggaW4gcmFuZ2UoeV9oYXQuc2hhcGVbMV0pXQoKICAgIGRmID0gcGQuY29uY2F0KFt4dGVzdCwgeXRlc3QsIHBkLkRhdGFGcmFtZSh5X2hhdCwgY29sdW1ucz1zY29yZV9uYW1lcyldLCBheGlzPTEpCiAgICBjb250ZXh0LmxvZ19kYXRhc2V0KCJ0ZXN0X3NldF9wcmVkcyIsIGRmPWRmLCBmb3JtYXQ9InBhcnF1ZXQiLCBpbmRleD1GYWxzZSkK - commands: [] - code_origin: https://github.com/daniels290813/functions.git#55a79c32be5d233cc11efcf40cd3edbe309bfdef:/home/kali/functions/test_classifier/test_classifier.py - affinity: null -verbose: false + default_handler: test_classifier diff --git a/functions/src/test_classifier/test_classifier.py b/functions/src/test_classifier/test_classifier.py index 322ecefc5..c11a6d99e 100644 --- a/functions/src/test_classifier/test_classifier.py +++ b/functions/src/test_classifier/test_classifier.py @@ -18,14 +18,11 @@ warnings.filterwarnings("ignore") 
-import os import pandas as pd - -from mlrun.datastore import DataItem +from cloudpickle import load from mlrun.artifacts import get_model, update_model +from mlrun.datastore import DataItem from mlrun.mlutils.models import eval_model_v2 -from cloudpickle import load -from urllib.request import urlopen def test_classifier( @@ -64,7 +61,7 @@ def test_classifier( try: model_file, model_obj, _ = get_model(models_path, suffix=".pkl") model_obj = load(open(model_file, "rb")) - except Exception as a: + except Exception: raise Exception("model location likely specified") extra_data = eval_model_v2(context, xtest, ytest.values, model_obj) diff --git a/functions/src/text_to_audio_generator/function.yaml b/functions/src/text_to_audio_generator/function.yaml index 8edbde74f..1a4c2fc72 100644 --- a/functions/src/text_to_audio_generator/function.yaml +++ b/functions/src/text_to_audio_generator/function.yaml @@ -1,21 +1,40 @@ +metadata: + tag: '' + name: text-to-audio-generator + categories: + - data-generation + - audio +verbose: false +kind: job spec: - default_handler: generate_multi_speakers_audio + image: '' disable_auto_mount: false + build: + origin_filename: '' + functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import io
import logging
import os
import pathlib
import random
import tempfile
from abc import ABC, abstractmethod

import numpy as np
import pandas as pd
import torch
import torchaudio
import tqdm

# Get the global logger:
_LOGGER = logging.getLogger()

OPENAI_API_KEY = "OPENAI_API_KEY"
OPENAI_BASE_URL = "OPENAI_API_BASE"
SAMPLE_RATE = 24000


def generate_multi_speakers_audio(
    data_path: str,
    speakers: list[str] | dict[str, int],
    available_voices: list[str],
    engine: str = "openai",
    output_directory: str = None,
    use_gpu: bool | None = None,
    use_small_models: bool | None = None,
    offload_cpu: bool | None = None,
    model: str | None = None,
    speed: float | None = None,
    sample_rate: int = 16000,
    file_format: str = "wav",
    verbose: bool = True,
    bits_per_sample: int | None = None,
) -> tuple[str, pd.DataFrame, dict]:
    """
    Generate audio files from text files.

    Each line of a text file is expected to be in the format "<speaker>: <words>". Lines without the ": " separator
    are skipped. A failure in one file does not stop the run - it is recorded in the returned errors dictionary.

    :param data_path:           Path to the text file or directory containing the text files to generate audio from.
    :param speakers:            List / Dict of speakers to generate audio for.
                                If a list is given, all the speakers are generated into a single channel.
                                If dictionary, the keys will be the speakers and the values will be the channels
                                (each speaker gets its own channel, ordered by the channel value).
    :param available_voices:    List of available voices to use for the generation.
                        See here for the available voices for bark engine:
                        https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c
                        See here for the available voices for openai engine:
                        https://beta.openai.com/docs/api-reference/speech
    :param engine:              The engine to use for the generation. Select either "bark" or "openai". Default is "openai".
    :param output_directory:    Path to the directory to save the generated audio files to.
    :param use_gpu:             Whether to use the GPU for the generation. Supported only in "bark" engine.
    :param use_small_models:    Whether to use the small models for the generation. Supported only in "bark" engine.
    :param offload_cpu:         To reduce the memory footprint, the models can be offloaded to the CPU after loading.
                                Supported only in "bark" engine.
    :param model:               Which model to use for the generation. Supported only in "openai" engine.
                                Default is "tts-1".
    :param speed:               The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is the default.
    :param sample_rate:         The sampling rate of the generated audio.
    :param file_format:         The format of the generated audio files.
    :param verbose:             Whether to print the progress of the generation.
    :param bits_per_sample:     Changes the bit depth for the supported formats.
                                Supported only in "wav" or "flac" formats.

    :returns:                   A tuple of:
                                - The output directory path.
                                - The generated audio files dataframe.
                                - The errors' dictionary.
    """

    global _LOGGER
    _LOGGER = _get_logger()
    # Get the input text files to turn to audio:
    data_path = pathlib.Path(data_path).absolute()
    text_files = _get_text_files(data_path=data_path)

    # Prepare the speech engine:
    engine = _get_engine(
        engine=engine,
        use_gpu=use_gpu,
        use_small_models=use_small_models,
        offload_cpu=offload_cpu,
        model=model,
        file_format=file_format,
        speed=speed,
    )

    # Check for per channel generation:
    if isinstance(speakers, dict):
        speaker_per_channel = True
        # Sort the given speakers by channels:
        speakers = {
            speaker: channel
            for speaker, channel in sorted(speakers.items(), key=lambda item: item[1])
        }
    else:
        speaker_per_channel = False

    # Prepare the resampling module (from the engines' sample rate to the requested one):
    resampler = torchaudio.transforms.Resample(
        orig_freq=SAMPLE_RATE, new_freq=sample_rate, dtype=torch.float32
    )

    # Prepare the gap between each speaker (half a second of silence):
    gap_between_speakers = np.zeros(int(0.5 * SAMPLE_RATE))

    # Prepare the successes dataframe and errors dictionary to be returned:
    successes = []
    errors = {}

    # Create the output directory:
    if output_directory is None:
        output_directory = tempfile.mkdtemp()
    output_directory = pathlib.Path(output_directory)
    if not output_directory.exists():
        output_directory.mkdir(exist_ok=True, parents=True)

    # Start generating audio:
    # Go over the text files and generate their audio:
    for text_file in tqdm.tqdm(
        text_files, desc="Generating", unit="file", disable=not verbose
    ):
        try:
            # Randomize voices for each speaker (a voice is not reused between speakers of the same file):
            chosen_voices = {}
            available_voices_copy = available_voices.copy()
            for speaker in speakers:
                voice = random.choice(available_voices_copy)
                chosen_voices[speaker] = voice
                available_voices_copy.remove(voice)
            # Read text:
            with open(text_file) as fp:
                text = fp.read()
            # Prepare a holder for all the generated pieces (if per channel each speaker will have its own):
            audio_pieces = (
                {speaker: [] for speaker in speakers}
                if speaker_per_channel
                else {"all": []}
            )

            # Generate audio per line:
            for line in text.splitlines():
                # Validate line is in correct speaker format:

                if ": " not in line:
                    if verbose:
                        _LOGGER.warning(f"Skipping line: {line}")
                    continue
                # Split line to speaker and his words:
                current_speaker, sentences = line.split(": ", 1)
                # Validate speaker is known:
                if current_speaker not in speakers:
                    raise ValueError(
                        f"Unknown speaker: {current_speaker}. Given speakers are: {speakers}"
                    )
                for sentence in _split_line(line=sentences):
                    # Generate words audio:
                    audio = engine._generate_audio(
                        text=sentence,
                        voice=chosen_voices[current_speaker],
                    )

                    if speaker_per_channel:
                        # All the other channels are padded with silence to keep the channels aligned:
                        silence = np.zeros_like(audio)
                        for speaker in audio_pieces.keys():
                            if speaker == current_speaker:
                                audio_pieces[speaker] += [audio, gap_between_speakers]
                            else:
                                audio_pieces[speaker] += [silence, gap_between_speakers]
                    else:
                        audio_pieces["all"] += [audio, gap_between_speakers]

            # Construct a single audio array from all the pieces and channels. The holder's own entries are used
            # (and not `speakers`) because in the single channel mode its only key is "all" - indexing it by the
            # speakers' names raised a `KeyError` for every file. In the per channel mode the holder keeps the
            # speakers' (channel sorted) order, so the channels order is unchanged:
            audio = np.vstack(
                [np.concatenate(pieces) for pieces in audio_pieces.values()]
            ).astype(dtype=np.float32)
            # Resample:
            audio = torch.from_numpy(audio)
            audio = resampler(audio)
            # Save to audio file:
            audio_file = output_directory / f"{text_file.stem}.{file_format}"

            torchaudio.save(
                uri=str(audio_file),
                src=audio,
                sample_rate=sample_rate,
                format=file_format,
                bits_per_sample=bits_per_sample,
            )

            # Collect to the successes:
            successes.append([text_file.name, audio_file.name])
        except Exception as exception:
            # Note the exception as error in the dictionary (best effort - continue to the next file):
            if verbose:
                _LOGGER.warning(f"Error in file: '{text_file.name}'")
            print(exception)
            errors[text_file.name] = str(exception)

    # Construct the generated audio files dataframe:
    successes = pd.DataFrame(
        successes,
        columns=["text_file", "audio_file"],
    )

    # Print the head of the produced dataframe and return:
    if verbose:
        _LOGGER.info(
            f"Done ({successes.shape[0]}/{len(text_files)})\n"
            f"Translations summary:\n"
            f"{successes.head()}"
        )
    return str(output_directory), successes, errors


class SpeechEngine(ABC):
    """
    Common interface of the text-to-speech engines in this module.

    A concrete engine must implement `_generate_audio`.
    """

    @abstractmethod
    def _generate_audio(self, text: str, voice: str) -> np.ndarray:
        """
        Generate the audio of the given text in the given voice.

        :param text:  The text to turn into speech.
        :param voice: The voice to speak the text in.

        :returns: The generated audio.
        """


class BarkEngine(SpeechEngine):
    """
    A speech engine that generates audio using the `bark` library. The library is imported on construction, so it is
    required only when this engine is used.
    """

    def __init__(
        self,
        use_gpu: bool = True,
        use_small_models: bool = False,
        offload_cpu: bool = False,
    ):
        """
        Import `bark` and preload its models.

        :param use_gpu:          Passed as the `*_use_gpu` option of all the preloaded models.
        :param use_small_models: Passed as the `*_use_small` option of the text, coarse and fine models.
        :param offload_cpu:      Passed as the `force_reload` option of the models preloading.

        :raises ImportError: If the `bark` library cannot be imported.
        """
        try:
            bark_module = importlib.import_module("bark")
        except ImportError:
            raise ImportError(
                "The 'bark' library is required for the BarkEngine. Please install it using 'pip install bark-ai'."
            )
        self.bark = bark_module

        # Collect the preloading options - the same device and size are used for all the models:
        preload_options = {}
        for stage in ("text", "coarse", "fine"):
            preload_options[f"{stage}_use_gpu"] = use_gpu
            preload_options[f"{stage}_use_small"] = use_small_models
        preload_options["codec_use_gpu"] = use_gpu
        # NOTE(review): `offload_cpu` is forwarded as `force_reload` - confirm this is the intended mapping.
        preload_options["force_reload"] = offload_cpu
        self.bark.preload_models(**preload_options)

    def _generate_audio(self, text: str, voice: str) -> np.ndarray:
        """
        Generate the audio of the given text, using the given voice as bark's history prompt.

        :param text:  The text to turn into speech.
        :param voice: The voice (bark history prompt) to use.

        :returns: The audio as returned from bark.
        """
        return self.bark.generate_audio(text, history_prompt=voice, silent=True)


class OpenAIEngine(SpeechEngine):
    """
    A speech engine that generates audio using the OpenAI speech API. The `openai` and `pydub` libraries are imported
    on construction, so they are required only when this engine is used.
    """

    def __init__(
        self, model: str = "tts-1", file_format: str = "wav", speed: float = 1.0
    ):
        """
        Import the required libraries, resolve the credentials and create the OpenAI client.

        The API key and base URL are read from the environment variables. If one of them is not set, both are
        read from the MLRun context secrets instead.

        :param model:       The model to use for the generation.
        :param file_format: The response format to request from the API.
        :param speed:       The speed of the generated audio.

        :raises ImportError: If `openai` or `pydub` cannot be imported.
        :raises OSError:     If a credential is missing from the environment and `mlrun` is not installed.
        """
        try:
            self.openai = importlib.import_module("openai")
            self.pydub = importlib.import_module("pydub")
        except ImportError as error:
            raise ImportError(
                "The 'openai' and 'pydub' libraries are required for the OpenAIEngine. Please install them using 'pip install openai pydub'."
            ) from error

        api_key = os.getenv(OPENAI_API_KEY)
        base_url = os.getenv(OPENAI_BASE_URL)
        # Check if the key is already in the environment variables:
        if not api_key or not base_url:
            try:
                import mlrun

                context = mlrun.get_or_create_ctx(name="context")
                # Check if the key is in the secrets:
                api_key = context.get_secret(OPENAI_API_KEY)
                base_url = context.get_secret(OPENAI_BASE_URL)
            except ModuleNotFoundError as error:
                # The message parts are concatenated, so each part must end with a space:
                raise OSError(
                    f"One or more of the OpenAI required environment variables ('{OPENAI_API_KEY}', '{OPENAI_BASE_URL}') are missing. "
                    f"Please set them as environment variables or install mlrun (`pip install mlrun`) "
                    f"and set them as project secrets using `project.set_secrets`."
                ) from error

        self.client = self.openai.OpenAI(api_key=api_key, base_url=base_url)
        self.model = model
        self.file_format = file_format
        self.speed = speed

    def _generate_audio(self, text: str, voice: str) -> np.ndarray:
        """
        Generate the audio of the given text in the given voice using the OpenAI speech API.

        :param text:  The text to turn into speech.
        :param voice: The voice to use.

        :returns: The audio as a float numpy array.
        """
        # Generate words audio:
        audio = self.client.audio.speech.create(
            model=self.model,
            input=text,
            voice=voice,
            response_format=self.file_format,
            speed=self.speed,
        )
        audio = audio.content
        audio = self._bytes_to_np_array(audio=audio)
        return audio

    def _bytes_to_np_array(self, audio: bytes) -> np.ndarray:
        """
        Convert the audio bytes returned from the API into a normalized float array.

        :param audio: The audio bytes.

        :returns: The audio samples, divided by the maximal integer sample value.
        """
        if self.file_format == "mp3":
            audio_segment = self.pydub.AudioSegment.from_mp3(io.BytesIO(audio))

            # Convert to raw PCM audio data
            samples = audio_segment.get_array_of_samples()

            # Convert to numpy array
            audio_array = np.array(samples)

            # Normalize to float between -1 and 1
            return audio_array.astype(np.float32) / np.iinfo(samples.typecode).max
        else:
            # NOTE(review): every other format is read as raw 16-bit samples, so any container header in the
            # response (e.g. of a "wav" file) would be read as samples as well - confirm this is acceptable.
            return np.frombuffer(audio, dtype=np.int16) / 32768.0


def _get_engine(engine: str, file_format: str, **kwargs) -> SpeechEngine:
    """
    Create the speech engine of the given name.

    :param engine:      The engine's name, either "bark" or "openai".
    :param file_format: The audio format, passed only to the "openai" engine.
    :param kwargs:      Additional engine options. Options with a `None` value are not passed, so the engine's own
                        defaults are used for them.

    :returns: The initialized speech engine.

    :raises ValueError: If the engine's name is not recognized.
    """
    # Keep only the options that were actually set:
    engine_options = {
        name: option for name, option in kwargs.items() if option is not None
    }

    if engine not in ("bark", "openai"):
        raise ValueError(
            f"Unrecognized engine. The parameter `engine` must be either 'bark' or 'openai'. Given: {engine}"
        )
    if engine == "bark":
        return BarkEngine(**engine_options)
    return OpenAIEngine(file_format=file_format, **engine_options)


def _get_text_files(
    data_path: pathlib.Path,
) -> list[pathlib.Path]:
    """
    Collect the text files to generate audio from.

    :param data_path: A path to a single file or to a directory. For a directory, its direct entries matching the
                      "*.*" pattern are collected.

    :returns: The list of collected paths.

    :raises ValueError: If the path is neither a directory nor a file.
    """
    # A directory yields all of its matching entries:
    if data_path.is_dir():
        return list(data_path.glob("*.*"))

    # A file is returned as a single item list:
    if data_path.is_file():
        return [data_path]

    raise ValueError(
        f"Unrecognized data path. The parameter `data_path` must be either a directory path or a file path. "
        f"Given: {str(data_path)} "
    )


def _split_line(line: str, max_length: int = 250) -> list[str]:
    """
    Split a long line into smaller pieces made of whole sentences.

    A line shorter than `max_length` is returned as is. Otherwise, the line is split on "." into sentences (each
    stripped and suffixed with a "."), and consecutive sentences are joined with a space as long as their
    accumulated length does not exceed `max_length`. A single sentence longer than `max_length` is kept whole.

    :param line:       The line to split.
    :param max_length: The length threshold of a piece.

    :returns: The list of pieces, in their original order.
    """
    if len(line) < max_length:
        return [line]

    sentences = [
        f"{sentence.strip()}." for sentence in line.split(".") if sentence.strip()
    ]
    # A long line made only of dots and whitespace has no sentences to group. Return it as is (like a short
    # line) instead of failing on an empty list:
    if not sentences:
        return [line]

    splits = []
    current_length = len(sentences[0])
    split = sentences[0]
    for sentence in sentences[1:]:
        if current_length + len(sentence) > max_length:
            # The current piece is full, start a new one from this sentence:
            splits.append(split)
            split = sentence
            current_length = len(sentence)
        else:
            current_length += len(sentence)
            split += " " + sentence
    if split:
        splits.append(split)

    return splits


def _get_logger():
    """
    Get the logger to use: the MLRun context's logger when `mlrun` is installed, otherwise the module's logger.

    :returns: The logger.
    """
    try:
        import mlrun

        # MLRun is available, use its context's logger:
        return mlrun.get_or_create_ctx(name="mlrun").logger
    except ModuleNotFoundError:
        # Fall back to the module level logger:
        return _LOGGER
 + requirements: + - torchaudio + - pydub + code_origin: '' + base_image: mlrun/mlrun + filename: text_to_audio_generator.py entry_points: generate_multi_speakers_audio: - lineno: 38 + outputs: + - doc: 'A tuple of: - The output directory path. - The generated audio files + dataframe. - The errors'' dictionary.' + type: tuple[str, pd.DataFrame, dict] parameters: - name: data_path type: str doc: Path to the text file or directory containing the text files to generate audio from. - name: speakers - type: Union[List[str], Dict[str, int]] doc: List / Dict of speakers to generate audio for. If a list is given, the speakers will be assigned to channels in the order given. If dictionary, the keys will be the speakers and the values will be the channels. - name: available_voices - type: List[str] + type: list[str] doc: 'List of available voices to use for the generation. See here for the available voices for bark engine: https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c See here for the available voices for openai engine: https://beta.openai.com/docs/api-reference/speech' @@ -29,26 +48,21 @@ spec: doc: Path to the directory to save the generated audio files to. default: null - name: use_gpu - type: Optional[bool] doc: Whether to use the GPU for the generation. Supported only in "bark" engine. default: null - name: use_small_models - type: Optional[bool] doc: Whether to use the small models for the generation. Supported only in "bark" engine. default: null - name: offload_cpu - type: Optional[bool] doc: To reduce the memory footprint, the models can be offloaded to the CPU after loading. Supported only in "bark" engine. default: null - name: model - type: Optional[str] doc: Which model to use for the generation. Supported only in "openai" engine. Default is "tts-1". default: null - name: speed - type: Optional[float] doc: The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is the default. 
default: null @@ -65,34 +79,14 @@ spec: doc: Whether to print the progress of the generation. default: true - name: bits_per_sample - type: Optional[int] doc: Changes the bit depth for the supported formats. Supported only in "wav" or "flac" formats. default: null name: generate_multi_speakers_audio + doc: Generate audio files from text files. has_kwargs: false has_varargs: false - outputs: - - doc: 'A tuple of: - The output directory path. - The generated audio files - dataframe. - The errors'' dictionary.' - type: Tuple[str, pd.DataFrame, dict] - doc: Generate audio files from text files. + lineno: 37 command: '' - image: '' description: Generate audio file from text using different speakers - build: - requirements: - - torchaudio - - pydub - base_image: mlrun/mlrun - code_origin: '' - origin_filename: '' - functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import io
import logging
import os
import pathlib
import random
import tempfile
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
import torchaudio
import tqdm

# Get the global logger:
_LOGGER = logging.getLogger()

OPENAI_API_KEY = "OPENAI_API_KEY"
OPENAI_BASE_URL = "OPENAI_API_BASE"
SAMPLE_RATE = 24000


def generate_multi_speakers_audio(
    data_path: str,
    speakers: Union[List[str], Dict[str, int]],
    available_voices: List[str],
    engine: str = "openai",
    output_directory: str = None,
    use_gpu: Optional[bool] = None,
    use_small_models: Optional[bool] = None,
    offload_cpu: Optional[bool] = None,
    model: Optional[str] = None,
    speed: Optional[float] = None,
    sample_rate: int = 16000,
    file_format: str = "wav",
    verbose: bool = True,
    bits_per_sample: Optional[int] = None,
) -> Tuple[str, pd.DataFrame, dict]:
    """
    Generate audio files from text files.

    Each line of a text file is expected to be in the format "<speaker>: <words>". Lines without the ": " separator
    are skipped. A failure in one file does not stop the run - it is recorded in the returned errors dictionary.

    :param data_path:           Path to the text file or directory containing the text files to generate audio from.
    :param speakers:            List / Dict of speakers to generate audio for.
                                If a list is given, all the speakers are generated into a single channel.
                                If dictionary, the keys will be the speakers and the values will be the channels
                                (each speaker gets its own channel, ordered by the channel value).
    :param available_voices:    List of available voices to use for the generation.
                        See here for the available voices for bark engine:
                        https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c
                        See here for the available voices for openai engine:
                        https://beta.openai.com/docs/api-reference/speech
    :param engine:              The engine to use for the generation. Select either "bark" or "openai". Default is "openai".
    :param output_directory:    Path to the directory to save the generated audio files to.
    :param use_gpu:             Whether to use the GPU for the generation. Supported only in "bark" engine.
    :param use_small_models:    Whether to use the small models for the generation. Supported only in "bark" engine.
    :param offload_cpu:         To reduce the memory footprint, the models can be offloaded to the CPU after loading.
                                Supported only in "bark" engine.
    :param model:               Which model to use for the generation. Supported only in "openai" engine.
                                Default is "tts-1".
    :param speed:               The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is the default.
    :param sample_rate:         The sampling rate of the generated audio.
    :param file_format:         The format of the generated audio files.
    :param verbose:             Whether to print the progress of the generation.
    :param bits_per_sample:     Changes the bit depth for the supported formats.
                                Supported only in "wav" or "flac" formats.

    :returns:                   A tuple of:
                                - The output directory path.
                                - The generated audio files dataframe.
                                - The errors' dictionary.
    """

    global _LOGGER
    _LOGGER = _get_logger()
    # Get the input text files to turn to audio:
    data_path = pathlib.Path(data_path).absolute()
    text_files = _get_text_files(data_path=data_path)

    # Prepare the speech engine:
    engine = _get_engine(
        engine=engine,
        use_gpu=use_gpu,
        use_small_models=use_small_models,
        offload_cpu=offload_cpu,
        model=model,
        file_format=file_format,
        speed=speed,
    )

    # Check for per channel generation:
    if isinstance(speakers, dict):
        speaker_per_channel = True
        # Sort the given speakers by channels:
        speakers = {
            speaker: channel
            for speaker, channel in sorted(speakers.items(), key=lambda item: item[1])
        }
    else:
        speaker_per_channel = False

    # Prepare the resampling module (from the engines' sample rate to the requested one):
    resampler = torchaudio.transforms.Resample(
        orig_freq=SAMPLE_RATE, new_freq=sample_rate, dtype=torch.float32
    )

    # Prepare the gap between each speaker (half a second of silence):
    gap_between_speakers = np.zeros(int(0.5 * SAMPLE_RATE))

    # Prepare the successes dataframe and errors dictionary to be returned:
    successes = []
    errors = {}

    # Create the output directory:
    if output_directory is None:
        output_directory = tempfile.mkdtemp()
    output_directory = pathlib.Path(output_directory)
    if not output_directory.exists():
        output_directory.mkdir(exist_ok=True, parents=True)

    # Start generating audio:
    # Go over the text files and generate their audio:
    for text_file in tqdm.tqdm(
        text_files, desc="Generating", unit="file", disable=not verbose
    ):

        try:
            # Randomize voices for each speaker (a voice is not reused between speakers of the same file):
            chosen_voices = {}
            available_voices_copy = available_voices.copy()
            for speaker in speakers:
                voice = random.choice(available_voices_copy)
                chosen_voices[speaker] = voice
                available_voices_copy.remove(voice)
            # Read text:
            with open(text_file, "r") as fp:
                text = fp.read()
            # Prepare a holder for all the generated pieces (if per channel each speaker will have its own):
            audio_pieces = (
                {speaker: [] for speaker in speakers}
                if speaker_per_channel
                else {"all": []}
            )

            # Generate audio per line:
            for line in text.splitlines():
                # Validate line is in correct speaker format:

                if ": " not in line:
                    if verbose:
                        _LOGGER.warning(f"Skipping line: {line}")
                    continue
                # Split line to speaker and his words:
                current_speaker, sentences = line.split(": ", 1)
                # Validate speaker is known:
                if current_speaker not in speakers:
                    raise ValueError(
                        f"Unknown speaker: {current_speaker}. Given speakers are: {speakers}"
                    )
                for sentence in _split_line(line=sentences):
                    # Generate words audio:
                    audio = engine._generate_audio(
                        text=sentence,
                        voice=chosen_voices[current_speaker],
                    )

                    if speaker_per_channel:
                        # All the other channels are padded with silence to keep the channels aligned:
                        silence = np.zeros_like(audio)
                        for speaker in audio_pieces.keys():
                            if speaker == current_speaker:
                                audio_pieces[speaker] += [audio, gap_between_speakers]
                            else:
                                audio_pieces[speaker] += [silence, gap_between_speakers]
                    else:
                        audio_pieces["all"] += [audio, gap_between_speakers]

            # Construct a single audio array from all the pieces and channels. The holder's own entries are used
            # (and not `speakers`) because in the single channel mode its only key is "all" - indexing it by the
            # speakers' names raised a `KeyError` for every file. In the per channel mode the holder keeps the
            # speakers' (channel sorted) order, so the channels order is unchanged:
            audio = np.vstack(
                [np.concatenate(pieces) for pieces in audio_pieces.values()]
            ).astype(dtype=np.float32)
            # Resample:
            audio = torch.from_numpy(audio)
            audio = resampler(audio)
            # Save to audio file:
            audio_file = output_directory / f"{text_file.stem}.{file_format}"

            torchaudio.save(
                uri=str(audio_file),
                src=audio,
                sample_rate=sample_rate,
                format=file_format,
                bits_per_sample=bits_per_sample,
            )

            # Collect to the successes:
            successes.append([text_file.name, audio_file.name])
        except Exception as exception:
            # Note the exception as error in the dictionary (best effort - continue to the next file):
            if verbose:
                _LOGGER.warning(f"Error in file: '{text_file.name}'")
            print(exception)
            errors[text_file.name] = str(exception)

    # Construct the generated audio files dataframe:
    successes = pd.DataFrame(
        successes,
        columns=["text_file", "audio_file"],
    )

    # Print the head of the produced dataframe and return:
    if verbose:
        _LOGGER.info(
            f"Done ({successes.shape[0]}/{len(text_files)})\n"
            f"Translations summary:\n"
            f"{successes.head()}"
        )
    return str(output_directory), successes, errors


class SpeechEngine(ABC):
    """
    Abstract base class of the text-to-speech engines used in this module.

    A concrete engine only has to implement `_generate_audio`.
    """

    @abstractmethod
    def _generate_audio(self, text: str, voice: str) -> np.ndarray:
        """
        Generate the audio of the given text spoken by the given voice.

        :param text:  The text to turn into speech.
        :param voice: The engine specific voice identifier to speak with.

        :returns: The generated audio as a numpy array.
        """


class BarkEngine(SpeechEngine):
    """
    Speech engine backed by the `bark` library. The library is imported lazily on construction, so it is only
    required when this engine is actually used.
    """

    def __init__(self, use_gpu: bool = True, use_small_models: bool = False, offload_cpu: bool = False):
        """
        Import `bark` and preload its models.

        :param use_gpu:          Forwarded as the `*_use_gpu` flag of the text, coarse, fine and codec models.
        :param use_small_models: Forwarded as the `*_use_small` flag of the text, coarse and fine models.
        :param offload_cpu:      Forwarded to `preload_models` as its `force_reload` argument.

        :raises ImportError: If the `bark` library cannot be imported.
        """
        try:
            bark_module = importlib.import_module("bark")
        except ImportError:
            raise ImportError(
                "The 'bark' library is required for the BarkEngine. Please install it using 'pip install bark-ai'."
            )
        self.bark = bark_module

        # The text, coarse and fine models share the same pair of flags:
        preload_options = {}
        for stage in ("text", "coarse", "fine"):
            preload_options[f"{stage}_use_gpu"] = use_gpu
            preload_options[f"{stage}_use_small"] = use_small_models

        # NOTE(review): `offload_cpu` is passed as `force_reload` - confirm this mapping is the intended one.
        self.bark.preload_models(
            codec_use_gpu=use_gpu,
            force_reload=offload_cpu,
            **preload_options,
        )

    def _generate_audio(self, text: str, voice: str) -> np.ndarray:
        """
        Generate the audio of the given text using bark.

        :param text:  The text to turn into speech.
        :param voice: The voice to use, passed to bark as the `history_prompt`.

        :returns: The value returned from `bark.generate_audio`.
        """
        return self.bark.generate_audio(text, history_prompt=voice, silent=True)


class OpenAIEngine(SpeechEngine):
    """
    Speech engine backed by the OpenAI text-to-speech API. The `openai` and `pydub` libraries are imported lazily
    on construction, so they are only required when this engine is actually used.
    """

    def __init__(self, model: str = "tts-1", file_format: str = "wav", speed: float = 1.0):
        """
        Import the required libraries, resolve the credentials and create the OpenAI client.

        The API key and base URL are read from the environment variables named by `OPENAI_API_KEY` and
        `OPENAI_BASE_URL`. If one of them is missing, both are read from the MLRun context secrets instead.

        :param model:       The model name passed to the speech creation request.
        :param file_format: The response format requested from the API. "mp3" is decoded using pydub, any other
                            format is decoded as raw 16-bit PCM samples.
        :param speed:       The speech speed passed to the speech creation request.

        :raises ImportError:      If the `openai` or `pydub` libraries cannot be imported.
        :raises EnvironmentError: If a credential is missing from the environment and mlrun is not installed.
        """
        try:
            self.openai = importlib.import_module("openai")
            self.pydub = importlib.import_module("pydub")
        except ImportError as import_error:
            raise ImportError(
                "The 'openai' and 'pydub' libraries are required for the OpenAIEngine. Please install them using 'pip install openai pydub'."
            ) from import_error

        api_key = os.getenv(OPENAI_API_KEY)
        base_url = os.getenv(OPENAI_BASE_URL)
        # Check if the key is already in the environment variables:
        if not api_key or not base_url:
            try:
                import mlrun

                context = mlrun.get_or_create_ctx(name="context")
                # Check if the key is in the secrets:
                api_key = context.get_secret(OPENAI_API_KEY)
                base_url = context.get_secret(OPENAI_BASE_URL)
            except ModuleNotFoundError as module_error:
                # The message parts are concatenated, so each one must end with a separating space:
                raise EnvironmentError(
                    f"One or more of the OpenAI required environment variables ('{OPENAI_API_KEY}', '{OPENAI_BASE_URL}') are missing. "
                    f"Please set them as environment variables or install mlrun (`pip install mlrun`) "
                    f"and set them as project secrets using `project.set_secrets`."
                ) from module_error

        self.client = self.openai.OpenAI(api_key=api_key, base_url=base_url)
        self.model = model
        self.file_format = file_format
        self.speed = speed

    def _generate_audio(self, text: str, voice: str) -> np.ndarray:
        """
        Generate the audio of the given text using the OpenAI speech API.

        :param text:  The text to turn into speech.
        :param voice: The voice name passed to the speech creation request.

        :returns: The decoded audio samples as a float numpy array.
        """
        # Generate words audio:
        response = self.client.audio.speech.create(
            model=self.model,
            input=text,
            voice=voice,
            response_format=self.file_format,
            speed=self.speed,
        )
        return self._bytes_to_np_array(audio=response.content)

    def _bytes_to_np_array(self, audio: bytes) -> np.ndarray:
        """
        Decode the audio bytes returned from the API into normalized float samples.

        :param audio: The raw bytes of the API response.

        :returns: The samples, scaled by the maximum (mp3) or the magnitude (otherwise) of the integer sample type.
        """
        if self.file_format == "mp3":
            audio_segment = self.pydub.AudioSegment.from_mp3(io.BytesIO(audio))

            # Convert to raw PCM audio data
            samples = audio_segment.get_array_of_samples()

            # Convert to numpy array
            audio_array = np.array(samples)

            # Normalize to float between -1 and 1
            return audio_array.astype(np.float32) / np.iinfo(samples.typecode).max

        # A WAV response is a RIFF container, decoding it as is would turn its header into noise samples.
        # Assumes the payload holds 16-bit PCM samples - TODO confirm for formats other than "wav".
        return np.frombuffer(self._extract_wav_pcm(audio), dtype=np.int16) / 32768.0

    @staticmethod
    def _extract_wav_pcm(audio: bytes) -> bytes:
        """
        Get the content of the "data" chunk of a RIFF/WAVE buffer.

        :param audio: The bytes to inspect.

        :returns: The bytes of the "data" chunk. If the buffer is not a RIFF/WAVE container or it has no "data"
                  chunk, the buffer is returned unchanged.
        """
        if audio[:4] != b"RIFF" or audio[8:12] != b"WAVE":
            return audio

        # Walk the chunks (4 bytes id, 4 bytes little endian size, payload padded to an even size):
        offset = 12
        while offset + 8 <= len(audio):
            chunk_id = audio[offset : offset + 4]
            chunk_size = int.from_bytes(audio[offset + 4 : offset + 8], "little")
            offset += 8
            if chunk_id == b"data":
                # Slicing clamps to the buffer end, so a placeholder size written by a streaming encoder is safe:
                return audio[offset : offset + chunk_size]
            offset += chunk_size + (chunk_size & 1)

        return audio


def _get_engine(engine: str, file_format: str, **kwargs) -> SpeechEngine:
    # eliminate the None values:
    kwargs = {key: value for key, value in kwargs.items() if value is not None}

    if engine == "bark":
        return BarkEngine(**kwargs)
    elif engine == "openai":
        return OpenAIEngine(file_format=file_format, **kwargs)
    else:
        raise ValueError(
            f"Unrecognized engine. The parameter `engine` must be either 'bark' or 'openai'. Given: {engine}"
        )

def _get_text_files(
    data_path: pathlib.Path,
) -> List[pathlib.Path]:
    # Check if the path is of a directory or a file:
    if data_path.is_dir():
        # Get all files inside the directory:
        text_files = list(data_path.glob("*.*"))
    elif data_path.is_file():
        text_files = [data_path]
    else:
        raise ValueError(
            f"Unrecognized data path. The parameter `data_path` must be either a directory path or a file path. "
            f"Given: {str(data_path)} "
        )

    return text_files


def _split_line(line: str, max_length: int = 250) -> List[str]:
    if len(line) < max_length:
        return [line]

    sentences = [
        f"{sentence.strip()}." for sentence in line.split(".") if sentence.strip()
    ]

    splits = []
    current_length = len(sentences[0])
    split = sentences[0]
    for sentence in sentences[1:]:
        if current_length + len(sentence) > max_length:
            splits.append(split)
            split = sentence
            current_length = len(sentence)
        else:
            current_length += len(sentence)
            split += " " + sentence
    if split:
        splits.append(split)

    return splits


def _get_logger():
    """
    Get the logger to use in this module.

    :returns: The logger of the MLRun context named "mlrun" if mlrun can be imported, otherwise the module level
              `_LOGGER`.
    """
    try:
        import mlrun

        # MLRun is available, use the logger of its context:
        return mlrun.get_or_create_ctx(name="mlrun").logger
    except ModuleNotFoundError:
        # No MLRun, fall back to the module's own logger:
        return _LOGGER
 -metadata: - categories: - - data-generation - - audio - tag: '' - name: text-to-audio-generator -kind: job -verbose: false + default_handler: generate_multi_speakers_audio diff --git a/functions/src/text_to_audio_generator/test_text_to_audio_generator.py b/functions/src/text_to_audio_generator/test_text_to_audio_generator.py index fb8db3198..c8695cb03 100644 --- a/functions/src/text_to_audio_generator/test_text_to_audio_generator.py +++ b/functions/src/text_to_audio_generator/test_text_to_audio_generator.py @@ -86,4 +86,4 @@ def test_generate_multi_speakers_audio_openai(file_format, bits_per_sample): ) assert function_run.error == "" for key in ["audio_files", "audio_files_dataframe", "text_to_speech_errors"]: - assert key in function_run.outputs and function_run.outputs[key] is not None \ No newline at end of file + assert key in function_run.outputs and function_run.outputs[key] is not None diff --git a/functions/src/text_to_audio_generator/text_to_audio_generator.py b/functions/src/text_to_audio_generator/text_to_audio_generator.py index e03b827ff..4c2de03e3 100644 --- a/functions/src/text_to_audio_generator/text_to_audio_generator.py +++ b/functions/src/text_to_audio_generator/text_to_audio_generator.py @@ -19,7 +19,6 @@ import random import tempfile from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -37,20 +36,20 @@ def generate_multi_speakers_audio( data_path: str, - speakers: Union[List[str], Dict[str, int]], - available_voices: List[str], + speakers: list[str] | dict[str, int], + available_voices: list[str], engine: str = "openai", output_directory: str = None, - use_gpu: Optional[bool] = None, - use_small_models: Optional[bool] = None, - offload_cpu: Optional[bool] = None, - model: Optional[str] = None, - speed: Optional[float] = None, + use_gpu: bool | None = None, + use_small_models: bool | None = None, + offload_cpu: bool | None = None, + model: str | None = None, + 
speed: float | None = None, sample_rate: int = 16000, file_format: str = "wav", verbose: bool = True, - bits_per_sample: Optional[int] = None, -) -> Tuple[str, pd.DataFrame, dict]: + bits_per_sample: int | None = None, +) -> tuple[str, pd.DataFrame, dict]: """ Generate audio files from text files. @@ -90,7 +89,6 @@ def generate_multi_speakers_audio( data_path = pathlib.Path(data_path).absolute() text_files = _get_text_files(data_path=data_path) - # Prepare the speech engine: engine = _get_engine( engine=engine, @@ -99,7 +97,7 @@ def generate_multi_speakers_audio( offload_cpu=offload_cpu, model=model, file_format=file_format, - speed=speed + speed=speed, ) # Check for per channel generation: @@ -137,7 +135,6 @@ def generate_multi_speakers_audio( for text_file in tqdm.tqdm( text_files, desc="Generating", unit="file", disable=not verbose ): - try: # Randomize voices for each speaker: chosen_voices = {} @@ -147,7 +144,7 @@ def generate_multi_speakers_audio( chosen_voices[speaker] = voice available_voices_copy.remove(voice) # Read text: - with open(text_file, "r") as fp: + with open(text_file) as fp: text = fp.read() # Prepare a holder for all the generated pieces (if per channel each speaker will have its own): audio_pieces = ( @@ -238,7 +235,12 @@ def _generate_audio(self, text: str, voice: str) -> np.ndarray: class BarkEngine(SpeechEngine): - def __init__(self, use_gpu: bool = True, use_small_models: bool = False, offload_cpu: bool = False): + def __init__( + self, + use_gpu: bool = True, + use_small_models: bool = False, + offload_cpu: bool = False, + ): try: self.bark = importlib.import_module("bark") except ImportError: @@ -268,7 +270,9 @@ def _generate_audio(self, text: str, voice: str) -> np.ndarray: class OpenAIEngine(SpeechEngine): - def __init__(self, model: str = "tts-1", file_format: str = "wav", speed: float = 1.0): + def __init__( + self, model: str = "tts-1", file_format: str = "wav", speed: float = 1.0 + ): try: self.openai = 
importlib.import_module("openai") self.pydub = importlib.import_module("pydub") @@ -289,7 +293,7 @@ def __init__(self, model: str = "tts-1", file_format: str = "wav", speed: float api_key = context.get_secret(OPENAI_API_KEY) base_url = context.get_secret(OPENAI_BASE_URL) except ModuleNotFoundError: - raise EnvironmentError( + raise OSError( f"One or more of the OpenAI required environment variables ('{OPENAI_API_KEY}', '{OPENAI_BASE_URL}') are missing." f"Please set them as environment variables or install mlrun (`pip install mlrun`)" f"and set them as project secrets using `project.set_secrets`." @@ -342,9 +346,10 @@ def _get_engine(engine: str, file_format: str, **kwargs) -> SpeechEngine: f"Unrecognized engine. The parameter `engine` must be either 'bark' or 'openai'. Given: {engine}" ) + def _get_text_files( data_path: pathlib.Path, -) -> List[pathlib.Path]: +) -> list[pathlib.Path]: # Check if the path is of a directory or a file: if data_path.is_dir(): # Get all files inside the directory: @@ -360,7 +365,7 @@ def _get_text_files( return text_files -def _split_line(line: str, max_length: int = 250) -> List[str]: +def _split_line(line: str, max_length: int = 250) -> list[str]: if len(line) < max_length: return [line] diff --git a/functions/src/tf2_serving/function.yaml b/functions/src/tf2_serving/function.yaml index 17cf2fbb9..bb2fb852f 100644 --- a/functions/src/tf2_serving/function.yaml +++ b/functions/src/tf2_serving/function.yaml @@ -1,52 +1,32 @@ -kind: remote metadata: + tag: '' name: tf2-serving - hash: 134293b94996e74275d90546f8d4ef96198af679 - project: '' - labels: - author: Iguazio categories: - model-serving - machine-learning +verbose: false +kind: remote spec: + image: mlrun/mlrun + disable_auto_mount: false + build: + origin_filename: '' + functionSourceCode: 
IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKIyBHZW5lcmF0ZWQgYnkgbnVjbGlvLmV4cG9ydC5OdWNsaW9FeHBvcnRlcgoKaW1wb3J0IHdhcm5pbmdzCgp3YXJuaW5ncy5zaW1wbGVmaWx0ZXIoYWN0aW9uPSJpZ25vcmUiLCBjYXRlZ29yeT1GdXR1cmVXYXJuaW5nKQoKaW1wb3J0IGpzb24KZnJvbSBvcyBpbXBvcnQgZW52aXJvbgoKaW1wb3J0IG1scnVuCmltcG9ydCBudW1weSBhcyBucApmcm9tIFBJTCBpbXBvcnQgSW1hZ2UKZnJvbSB0ZW5zb3JmbG93LmtlcmFzLm1vZGVscyBpbXBvcnQgbG9hZF9tb2RlbApmcm9tIHRlbnNvcmZsb3cua2VyYXMucHJlcHJvY2Vzc2luZyBpbXBvcnQgaW1hZ2UKCgpjbGFzcyBURk1vZGVsKG1scnVuLnJ1bnRpbWVzLk1MTW9kZWxTZXJ2ZXIpOgogICAgZGVmIF9faW5pdF9fKHNlbGYsIG5hbWU6IHN0ciwgbW9kZWxfZGlyOiBzdHIpOgogICAgICAgIHN1cGVyKCkuX19pbml0X18obmFtZSwgbW9kZWxfZGlyKQoKICAgICAgICBzZWxmLklNQUdFX1dJRFRIID0gaW50KGVudmlyb24uZ2V0KCJJTUFHRV9XSURUSCIsICIxMjgiKSkKICAgICAgICBzZWxmLklNQUdFX0hFSUdIVCA9IGludChlbnZpcm9uLmdldCgiSU1BR0VfSEVJR0hUIiwgIjEyOCIpKQoKICAgICAgICB0cnk6CiAgICAgICAgICAgIHdpdGggb3BlbihlbnZpcm9uWyJjbGFzc2VzX21hcCJdKSBhcyBmOgogICAgICAgICAgICAgICAgc2VsZi5jbGFzc2VzID0ganNvbi5sb2FkKGYpCiAgICAgICAgZXhjZXB0IEV4Y2VwdGlvbiBhcyBlOgogICAgICAgICAgICBzZWxmLmNsYXNzZXMgPSBOb25lCiAgICAgICAgICAgIHByaW50KGYiY291bGQgbm90IGxvYWQgY2xhc3NlcyBtYXA6IHtlfSIpCgogICAgZGVmIGxvYWQoc2VsZik6CiAgICAgICAgbW9kZWxfZmlsZSwgZXh0cmFfZGF0YSA9IHNlbGYuZ2V0X21vZGVsKCIuaDUiKQogICAgICAgIHNlbGYubW9kZWwgPSBsb2FkX21vZGVsKG1vZGVs
X2ZpbGUpCgogICAgZGVmIHByZXByb2Nlc3Moc2VsZiwgYm9keSk6CiAgICAgICAgdHJ5OgogICAgICAgICAgICBvdXRwdXQgPSB7Imluc3RhbmNlcyI6IFtdfQogICAgICAgICAgICBpbnN0YW5jZXMgPSBib2R5LmdldCgiaW5zdGFuY2VzIiwgW10pCiAgICAgICAgICAgIGZvciBieXRlX2ltYWdlIGluIGluc3RhbmNlczoKICAgICAgICAgICAgICAgIGltZyA9IEltYWdlLm9wZW4oYnl0ZV9pbWFnZSkKICAgICAgICAgICAgICAgIGltZyA9IGltZy5yZXNpemUoKHNlbGYuSU1BR0VfV0lEVEgsIHNlbGYuSU1BR0VfSEVJR0hUKSkKCiAgICAgICAgICAgICAgICB4ID0gaW1hZ2UuaW1nX3RvX2FycmF5KGltZykKICAgICAgICAgICAgICAgIHggPSBucC5leHBhbmRfZGltcyh4LCBheGlzPTApCiAgICAgICAgICAgICAgICBvdXRwdXRbImluc3RhbmNlcyJdLmFwcGVuZCh4KQoKICAgICAgICAgICAgb3V0cHV0WyJpbnN0YW5jZXMiXSA9IFtucC52c3RhY2sob3V0cHV0WyJpbnN0YW5jZXMiXSldCiAgICAgICAgICAgIHJldHVybiBvdXRwdXQKICAgICAgICBleGNlcHQ6CiAgICAgICAgICAgIHJhaXNlIEV4Y2VwdGlvbihmInJlY2VpdmVkOiB7Ym9keX0iKQoKICAgIGRlZiBwcmVkaWN0KHNlbGYsIGRhdGEpOgogICAgICAgIGltYWdlcyA9IGRhdGEuZ2V0KCJpbnN0YW5jZXMiLCBbXSkKCiAgICAgICAgcHJlZGljdGVkX3Byb2JhYmlsaXR5ID0gc2VsZi5tb2RlbC5wcmVkaWN0KGltYWdlcykKCiAgICAgICAgcmV0dXJuIHByZWRpY3RlZF9wcm9iYWJpbGl0eQoKICAgIGRlZiBwb3N0cHJvY2VzcyhzZWxmLCBwcmVkaWN0ZWRfcHJvYmFiaWxpdHkpOgogICAgICAgIGlmIHNlbGYuY2xhc3NlczoKICAgICAgICAgICAgcHJlZGljdGVkX2NsYXNzZXMgPSBucC5hcm91bmQocHJlZGljdGVkX3Byb2JhYmlsaXR5LCAxKS50b2xpc3QoKVswXQogICAgICAgICAgICBwcmVkaWN0ZWRfcHJvYmFiaWxpdGllcyA9IHByZWRpY3RlZF9wcm9iYWJpbGl0eS50b2xpc3QoKVswXQogICAgICAgICAgICByZXR1cm4gewogICAgICAgICAgICAgICAgInByZWRpY3Rpb24iOiBbCiAgICAgICAgICAgICAgICAgICAgc2VsZi5jbGFzc2VzW3N0cihpbnQoY2xzKSldIGZvciBjbHMgaW4gcHJlZGljdGVkX2NsYXNzZXMKICAgICAgICAgICAgICAgIF0sCiAgICAgICAgICAgICAgICBmIntzZWxmLmNsYXNzZXNbJzEnXX0tcHJvYmFiaWxpdHkiOiBwcmVkaWN0ZWRfcHJvYmFiaWxpdGllcywKICAgICAgICAgICAgfQogICAgICAgIGVsc2U6CiAgICAgICAgICAgIHJldHVybiBwcmVkaWN0ZWRfcHJvYmFiaWxpdHkudG9saXN0KClbMF0KCmZyb20gbWxydW4ucnVudGltZXMgaW1wb3J0IG51Y2xpb19pbml0X2hvb2sKZGVmIGluaXRfY29udGV4dChjb250ZXh0KToKICAgIG51Y2xpb19pbml0X2hvb2soY29udGV4dCwgZ2xvYmFscygpLCAnc2VydmluZycpCgpkZWYgaGFuZGxlcihjb250ZXh0LCBldmVudCk6CiAgICByZXR1cm4gY29udGV4dC5tbHJ1bl9oYW5kbGVyKGNvbnRleHQsIGV2
ZW50KQo= + requirements: + - requests + - pillow + - tensorflow>=2.1 + code_origin: '' + filename: tf2_serving.py + min_replicas: 1 command: '' - args: [] - image: '' - description: tf2 image classification server + default_handler: '' + source: '' max_replicas: 4 + base_image_pull: false + description: tf2 image classification server + function_kind: serving + function_handler: tf2-serving-nuclio:handler env: - - name: MODEL_CLASS - value: TFModel - - name: ENABLE_EXPLAINER - value: 'False' - config: - spec.triggers.http: - kind: http - maxWorkers: 8 - attributes: - ingresses: {} - annotations: {} - base_spec: - apiVersion: nuclio.io/v1 - kind: nuclio:serving - metadata: - annotations: - nuclio.io/generated_by: function generated from 01-09-2020 - labels: {} - name: tf2-serving - spec: - build: - baseImage: mlrun/mlrun - commands: - - pip install tensorflow>=2.1 - - pip install requests pillow - functionSourceCode: IyBHZW5lcmF0ZWQgYnkgbnVjbGlvLmV4cG9ydC5OdWNsaW9FeHBvcnRlcgoKCgppbXBvcnQgd2FybmluZ3MKd2FybmluZ3Muc2ltcGxlZmlsdGVyKGFjdGlvbj0iaWdub3JlIiwgY2F0ZWdvcnk9RnV0dXJlV2FybmluZykKCmltcG9ydCBqc29uCmltcG9ydCBudW1weSBhcyBucAppbXBvcnQgcmVxdWVzdHMKZnJvbSB0ZW5zb3JmbG93IGltcG9ydCBrZXJhcwpmcm9tIHRlbnNvcmZsb3cua2VyYXMubW9kZWxzIGltcG9ydCBsb2FkX21vZGVsCmZyb20gdGVuc29yZmxvdy5rZXJhcy5wcmVwcm9jZXNzaW5nIGltcG9ydCBpbWFnZQpmcm9tIHRlbnNvcmZsb3cua2VyYXMucHJlcHJvY2Vzc2luZy5pbWFnZSBpbXBvcnQgbG9hZF9pbWcKZnJvbSBvcyBpbXBvcnQgZW52aXJvbiwgcGF0aApmcm9tIFBJTCBpbXBvcnQgSW1hZ2UKZnJvbSBpbyBpbXBvcnQgQnl0ZXNJTwpmcm9tIHVybGxpYi5yZXF1ZXN0IGltcG9ydCB1cmxvcGVuCmltcG9ydCBtbHJ1bgoKY2xhc3MgVEZNb2RlbChtbHJ1bi5ydW50aW1lcy5NTE1vZGVsU2VydmVyKToKICAgIGRlZiBfX2luaXRfXyhzZWxmLCBuYW1lOiBzdHIsIG1vZGVsX2Rpcjogc3RyKToKICAgICAgICBzdXBlcigpLl9faW5pdF9fKG5hbWUsIG1vZGVsX2RpcikKCiAgICAgICAgc2VsZi5JTUFHRV9XSURUSCA9IGludChlbnZpcm9uLmdldCgnSU1BR0VfV0lEVEgnLCAnMTI4JykpCiAgICAgICAgc2VsZi5JTUFHRV9IRUlHSFQgPSBpbnQoZW52aXJvbi5nZXQoJ0lNQUdFX0hFSUdIVCcsICcxMjgnKSkKICAgICAgICAKICAgICAgICB0cnk6CiAgICAgICAgICAgIHdpdGggb3BlbihlbnZpcm
9uWydjbGFzc2VzX21hcCddLCAncicpIGFzIGY6CiAgICAgICAgICAgICAgICBzZWxmLmNsYXNzZXMgPSBqc29uLmxvYWQoZikKICAgICAgICBleGNlcHQ6CiAgICAgICAgICAgIHNlbGYuY2xhc3NlcyA9IE5vbmUKICAgICAgICAKICAgIGRlZiBsb2FkKHNlbGYpOgogICAgICAgIG1vZGVsX2ZpbGUsIGV4dHJhX2RhdGEgPSBzZWxmLmdldF9tb2RlbCgnLmg1JykKICAgICAgICBzZWxmLm1vZGVsID0gbG9hZF9tb2RlbChtb2RlbF9maWxlKQogICAgICAgIAogICAgZGVmIHByZXByb2Nlc3Moc2VsZiwgYm9keSk6CiAgICAgICAgdHJ5OgogICAgICAgICAgICBvdXRwdXQgPSB7J2luc3RhbmNlcyc6IFtdfQogICAgICAgICAgICBpbnN0YW5jZXMgPSBib2R5LmdldCgnaW5zdGFuY2VzJywgW10pCiAgICAgICAgICAgIGZvciBieXRlX2ltYWdlIGluIGluc3RhbmNlczoKICAgICAgICAgICAgICAgIGltZyA9IEltYWdlLm9wZW4oYnl0ZV9pbWFnZSkKICAgICAgICAgICAgICAgIGltZyA9IGltZy5yZXNpemUoKHNlbGYuSU1BR0VfV0lEVEgsIHNlbGYuSU1BR0VfSEVJR0hUKSkKCiAgICAgICAgICAgICAgICB4ID0gaW1hZ2UuaW1nX3RvX2FycmF5KGltZykKICAgICAgICAgICAgICAgIHggPSBucC5leHBhbmRfZGltcyh4LCBheGlzPTApCiAgICAgICAgICAgICAgICBvdXRwdXRbJ2luc3RhbmNlcyddLmFwcGVuZCh4KQogICAgICAgICAgICAKICAgICAgICAgICAgb3V0cHV0WydpbnN0YW5jZXMnXSA9IFtucC52c3RhY2sob3V0cHV0WydpbnN0YW5jZXMnXSldCiAgICAgICAgICAgIHJldHVybiBvdXRwdXQKICAgICAgICBleGNlcHQ6CiAgICAgICAgICAgIHJhaXNlIEV4Y2VwdGlvbihmJ3JlY2VpdmVkOiB7Ym9keX0nKQogICAgICAgICAgICAKCiAgICBkZWYgcHJlZGljdChzZWxmLCBkYXRhKToKICAgICAgICBpbWFnZXMgPSBkYXRhLmdldCgnaW5zdGFuY2VzJywgW10pCgogICAgICAgIHByZWRpY3RlZF9wcm9iYWJpbGl0eSA9IHNlbGYubW9kZWwucHJlZGljdChpbWFnZXMpCgogICAgICAgIHJldHVybiBwcmVkaWN0ZWRfcHJvYmFiaWxpdHkKICAgICAgICAKICAgIGRlZiBwb3N0cHJvY2VzcyhzZWxmLCBwcmVkaWN0ZWRfcHJvYmFiaWxpdHkpOgogICAgICAgIGlmIHNlbGYuY2xhc3NlczoKICAgICAgICAgICAgcHJlZGljdGVkX2NsYXNzZXMgPSBucC5hcm91bmQocHJlZGljdGVkX3Byb2JhYmlsaXR5LCAxKS50b2xpc3QoKVswXQogICAgICAgICAgICBwcmVkaWN0ZWRfcHJvYmFiaWxpdGllcyA9IHByZWRpY3RlZF9wcm9iYWJpbGl0eS50b2xpc3QoKVswXQogICAgICAgICAgICByZXR1cm4gewogICAgICAgICAgICAgICAgJ3ByZWRpY3Rpb24nOiBbc2VsZi5jbGFzc2VzW3N0cihpbnQoY2xzKSldIGZvciBjbHMgaW4gcHJlZGljdGVkX2NsYXNzZXNdLCAKICAgICAgICAgICAgICAgIGYne3NlbGYuY2xhc3Nlc1siMSJdfS1wcm9iYWJpbGl0eSc6IHByZWRpY3RlZF9wcm9iYWJpbGl0aWVzCiAgICAgICAgICAgIH0KICAgICAgICBlbHNlOgogICAgIC
AgICAgICByZXR1cm4gcHJlZGljdGVkX3Byb2JhYmlsaXR5LnRvbGlzdCgpWzBdCgoKZnJvbSBtbHJ1bi5ydW50aW1lcyBpbXBvcnQgbnVjbGlvX2luaXRfaG9vawpkZWYgaW5pdF9jb250ZXh0KGNvbnRleHQpOgogICAgbnVjbGlvX2luaXRfaG9vayhjb250ZXh0LCBnbG9iYWxzKCksICdzZXJ2aW5nJykKCmRlZiBoYW5kbGVyKGNvbnRleHQsIGV2ZW50KToKICAgIHJldHVybiBjb250ZXh0Lm1scnVuX2hhbmRsZXIoY29udGV4dCwgZXZlbnQpCg== - noBaseImagesPull: true - env: - - name: MODEL_CLASS - value: TF2Model - handler: tf2_serving:handler - runtime: python:3.9 - volumes: [] - source: '' - function_kind: serving \ No newline at end of file + - name: MLRUN_HTTPDB__NUCLIO__EXPLICIT_ACK + value: enabled diff --git a/functions/src/tf2_serving/tf2_serving.py b/functions/src/tf2_serving/tf2_serving.py index 57380fbfa..820c4fae9 100644 --- a/functions/src/tf2_serving/tf2_serving.py +++ b/functions/src/tf2_serving/tf2_serving.py @@ -19,17 +19,13 @@ warnings.simplefilter(action="ignore", category=FutureWarning) import json +from os import environ + +import mlrun import numpy as np -import requests -from tensorflow import keras +from PIL import Image from tensorflow.keras.models import load_model from tensorflow.keras.preprocessing import image -from tensorflow.keras.preprocessing.image import load_img -from os import environ, path -from PIL import Image -from io import BytesIO -from urllib.request import urlopen -import mlrun class TFModel(mlrun.runtimes.MLModelServer): @@ -40,10 +36,11 @@ def __init__(self, name: str, model_dir: str): self.IMAGE_HEIGHT = int(environ.get("IMAGE_HEIGHT", "128")) try: - with open(environ["classes_map"], "r") as f: + with open(environ["classes_map"]) as f: self.classes = json.load(f) - except: + except Exception as e: self.classes = None + print(f"could not load classes map: {e}") def load(self): model_file, extra_data = self.get_model(".h5") @@ -81,7 +78,7 @@ def postprocess(self, predicted_probability): "prediction": [ self.classes[str(int(cls))] for cls in predicted_classes ], - f'{self.classes["1"]}-probability': predicted_probabilities, + 
f"{self.classes['1']}-probability": predicted_probabilities, } else: return predicted_probability.tolist()[0] diff --git a/functions/src/transcribe/test_transcribe.py b/functions/src/transcribe/test_transcribe.py index f70b3856d..4e80580df 100644 --- a/functions/src/transcribe/test_transcribe.py +++ b/functions/src/transcribe/test_transcribe.py @@ -20,7 +20,6 @@ import mlrun import pytest - expected_outputs = [ "This is a speech to text test.", "In the heart of the stadium, " @@ -30,7 +29,6 @@ "as the game writes its unpredictable story on the field of destiny.", ] models = [ - "openai/whisper-tiny", ] @@ -42,7 +40,9 @@ def test_transcribe(model_name: str, audio_path: str): # Setting variables and importing function: artifact_path = tempfile.mkdtemp() project = mlrun.get_or_create_project("test") - transcribe_function = project.set_function("transcribe.py", "transcribe", kind="job", image="mlrun/mlrun") + transcribe_function = project.set_function( + "transcribe.py", "transcribe", kind="job", image="mlrun/mlrun" + ) # transcribe_function = mlrun.import_function("function.yaml") temp_dir = tempfile.mkdtemp() @@ -80,7 +80,7 @@ def test_transcribe(model_name: str, audio_path: str): # Check that the transcribed text was approximately (90%) generated from audio: for text_file, expected in zip(text_files, expected_outputs): - with open(os.path.join(temp_dir, text_file), "r") as f: + with open(os.path.join(temp_dir, text_file)) as f: output = f.readlines()[0] ratio = SequenceMatcher(None, expected, output).ratio() assert ratio >= 0.9 diff --git a/functions/src/transcribe/transcribe.py b/functions/src/transcribe/transcribe.py index 9cabcb1e8..7f30563cb 100644 --- a/functions/src/transcribe/transcribe.py +++ b/functions/src/transcribe/transcribe.py @@ -15,10 +15,11 @@ import operator import os import tempfile +from collections.abc import Generator from functools import reduce, wraps from multiprocessing import Process, Queue from pathlib import Path -from typing import Any, 
Dict, Generator, List, Literal, NamedTuple, Tuple, Union +from typing import Any, Literal, NamedTuple import pandas as pd import torch @@ -38,7 +39,7 @@ class BaseTask: """ def __init__( - self, audio_file: Path, transcription_output: Union[dict, str], text_file: Path + self, audio_file: Path, transcription_output: dict | str, text_file: Path ): """ Initialize the task. @@ -75,7 +76,7 @@ def is_failed(self) -> bool: """ return self._error is not None - def get_result(self) -> Tuple[str, str]: + def get_result(self) -> tuple[str, str]: """ Get the result of the task. If the task failed, the error will be returned, otherwise, the result will be the text file name. @@ -86,7 +87,7 @@ def get_result(self) -> Tuple[str, str]: return self._audio_file.name, self._error return self._audio_file.name, self._text_file.name - def to_tuple(self) -> Tuple[str, dict]: + def to_tuple(self) -> tuple[str, dict]: """ Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue). @@ -147,7 +148,7 @@ def __init__( audio_file: Path, transcription_output: dict, text_file: Path, - speech_diarization: List[Tuple[float, float, str]], + speech_diarization: list[tuple[float, float, str]], ): """ Initialize the task. @@ -163,10 +164,10 @@ def __init__( text_file=text_file, ) self._speech_diarization = speech_diarization - self._segments: List[SpeechDiarizationTask._DiarizationSegment] = None + self._segments: list[SpeechDiarizationTask._DiarizationSegment] = None self._last_chosen_index = 0 - def to_tuple(self) -> Tuple[str, dict]: + def to_tuple(self) -> tuple[str, dict]: """ Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue). 
@@ -334,10 +335,10 @@ def __init__(self, audio_file: Path, text_file: Path): super().__init__( audio_file=audio_file, transcription_output={}, text_file=text_file ) - self._transcription_output_channels: List[Tuple[str, dict]] = [] + self._transcription_output_channels: list[tuple[str, dict]] = [] @property - def transcription_output_channels(self) -> List[Tuple[str, dict]]: + def transcription_output_channels(self) -> list[tuple[str, dict]]: """ Get the transcription output channels. @@ -355,7 +356,7 @@ def do_task(self): return super().do_task() - def to_tuple(self) -> Tuple[str, dict]: + def to_tuple(self) -> tuple[str, dict]: """ Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue). @@ -412,7 +413,7 @@ class BatchProcessor: associated methods. """ - def __init__(self, audio_files: List[Path], output_directory: Path): + def __init__(self, audio_files: list[Path], output_directory: Path): """ Initialize the batch processor. @@ -425,10 +426,10 @@ def __init__(self, audio_files: List[Path], output_directory: Path): # Prepare the batching variables: self._current_file_index = 0 - self._tasks: List[BaseTask] = [] - self._results: List[Tuple[bool, Tuple[str, str]]] = [] + self._tasks: list[BaseTask] = [] + self._results: list[tuple[bool, tuple[str, str]]] = [] - def process_batch(self, batch: List[Union[dict, str]]): + def process_batch(self, batch: list[dict | str]): """ Process a batch of transcriptions. Tasks related to the given batch will be created and stored in the batch processor. @@ -450,7 +451,7 @@ def process_batch(self, batch: List[Union[dict, str]]): ] ) - def get_tasks(self) -> List[BaseTask]: + def get_tasks(self) -> list[BaseTask]: """ Get the tasks to perform. 
@@ -468,7 +469,7 @@ def do_tasks(self): task.do_task() self._results.append((task.is_failed(), task.get_result())) - def get_results(self) -> List[Tuple[bool, Tuple[str, str]]]: + def get_results(self) -> list[tuple[bool, tuple[str, str]]]: """ Get the results of the tasks. The stored results are then cleared. @@ -478,7 +479,7 @@ def get_results(self) -> List[Tuple[bool, Tuple[str, str]]]: self._results = [] return results - def _get_current_files(self, batch_size: int) -> List[Path]: + def _get_current_files(self, batch_size: int) -> list[Path]: """ Get the current files to process. @@ -504,7 +505,7 @@ class SpeechDiarizationBatchProcessor(BatchProcessor): """ def __init__( - self, audio_files: List[Path], output_directory: Path, speech_diarization: dict + self, audio_files: list[Path], output_directory: Path, speech_diarization: dict ): """ Initialize the batch processor. @@ -517,7 +518,7 @@ def __init__( self._speech_diarization = speech_diarization self._audio_files = audio_files - def process_batch(self, batch: List[dict]): + def process_batch(self, batch: list[dict]): """ Process a batch of transcriptions. Tasks related to the given batch will be created and stored in the batch processor. @@ -550,10 +551,10 @@ class PerChannelSpeechDiarizationBatchProcessor(BatchProcessor): def __init__( self, - audio_files: List[Path], + audio_files: list[Path], output_directory: Path, n_channels: int, - speakers: List[str], + speakers: list[str], ): """ Initialize the batch processor. @@ -572,7 +573,7 @@ def __init__( # Prepare a channel buffer to store the channels until the current task created is fully covered: self._task_in_process: SpeechDiarizationPerChannelTask = None - def process_batch(self, batch: List[dict]): + def process_batch(self, batch: list[dict]): """ Process a batch of transcriptions. Tasks related to the given batch will be created and stored in the batch processor. 
@@ -627,50 +628,50 @@ def __init__( batch_size: int = 2, spoken_language: str = None, translate_to_english: bool = False, - return_timestamps: Union[bool, Literal["word"]] = False, + return_timestamps: bool | Literal["word"] = False, per_channel_transcription: int = 0, ): """ - Initialize the transcriber. - - :param model_name: The model name to use. Should be a model from the OpenAI's Whisper models for - best results (for example "tiny", "base", "large", etc.). - :param device: The device to use for inference. If not given, will use GPU if available. - :param use_flash_attention_2: Whether to use the Flash Attention 2 implementation. It can be used only with - one of the following GPUs: Nvidia H series and Nvidia A series. T4 support - will be available soon. - - Note: If both `use_flash_attention_2` and - `use_better_transformers` are `None`, the optimization will be chosen - automatically according to the available resources. - - :param use_better_transformers: Whether to use the Better Transformers library to further optimize the model. - Should be used for all use cases that do not support flash attention 2. - - Note: If both `use_flash_attention_2` and `use_better_transformers` are - `None`, the optimization will be chosen automatically according to the - available resources. - :param assistant_model: The assistant model name to use for inference. Notice that the optimizations - (flash attention 2 and better transformers) will be applied for the assistant - as well. Should be a model from Huggingface's distil-whisper (see here for - more information: https://github.com/huggingface/distil-whisper). - :param max_new_tokens: The maximum number of new tokens to generate. This is used to limit the - generation length. Default is 128 tokens. - :param chunk_length_s: The audio chunk to split the audio to (in seconds). Default is 30 seconds. - :param batch_size: The batch size to use for inference. Default is 2. 
- :param spoken_language: Aim whisper to know what language is spoken. If None, it will try to detect it - for each chunk. - :param translate_to_english: Whether to translate the transcriptions to English. Default is False. - :param return_timestamps: Whether to return the timestamps of the words. If "word", will return the - timestamps of each word. If True will return the timestamps of each chunk. - Default is False. Aimed to be used for speech diarization. - :param per_channel_transcription: Whether to do per channel transcription. If needed to run per channel - transcription, pass the number of channels expected for each audio file here. - 0 means regular transcription (merge channels). - - Note: If `per_channel_transcription` is not 0, `batch_size` must be treated to - be the number of channels and not audio files. Aimed to be used for per - channel speech diarization. + Initialize the transcriber. + + :param model_name: The model name to use. Should be a model from the OpenAI's Whisper models for + best results (for example "tiny", "base", "large", etc.). + :param device: The device to use for inference. If not given, will use GPU if available. + :param use_flash_attention_2: Whether to use the Flash Attention 2 implementation. It can be used only with + one of the following GPUs: Nvidia H series and Nvidia A series. T4 support + will be available soon. + + Note: If both `use_flash_attention_2` and + `use_better_transformers` are `None`, the optimization will be chosen + automatically according to the available resources. + + :param use_better_transformers: Whether to use the Better Transformers library to further optimize the model. + Should be used for all use cases that do not support flash attention 2. + + Note: If both `use_flash_attention_2` and `use_better_transformers` are + `None`, the optimization will be chosen automatically according to the + available resources. + :param assistant_model: The assistant model name to use for inference. 
Notice that the optimizations + (flash attention 2 and better transformers) will be applied for the assistant + as well. Should be a model from Huggingface's distil-whisper (see here for + more information: https://github.com/huggingface/distil-whisper). + :param max_new_tokens: The maximum number of new tokens to generate. This is used to limit the + generation length. Default is 128 tokens. + :param chunk_length_s: The audio chunk to split the audio to (in seconds). Default is 30 seconds. + :param batch_size: The batch size to use for inference. Default is 2. + :param spoken_language: Aim whisper to know what language is spoken. If None, it will try to detect it + for each chunk. + :param translate_to_english: Whether to translate the transcriptions to English. Default is False. + :param return_timestamps: Whether to return the timestamps of the words. If "word", will return the + timestamps of each word. If True will return the timestamps of each chunk. + Default is False. Aimed to be used for speech diarization. + :param per_channel_transcription: Whether to do per channel transcription. If needed to run per channel + transcription, pass the number of channels expected for each audio file here. + 0 means regular transcription (merge channels). + + Note: If `per_channel_transcription` is not 0, `batch_size` must be treated to + be the number of channels and not audio files. Aimed to be used for per + channel speech diarization. """ # Store loading parameters: self._model_name = model_name @@ -781,11 +782,11 @@ def load(self): def transcribe( self, - audio_files: List[Path], + audio_files: list[Path], batch_processor: BatchProcessor = None, batches_queue: Queue = None, verbose: bool = False, - ) -> Union[List[List[dict]], None]: + ) -> list[list[dict]] | None: """ Transcribe the given audio files. The transcriptions will be sent to a queue or a batch processor for further processing like writing to text files. 
If no queue or batch processor is given, the transcriptions outputs from @@ -799,9 +800,10 @@ def transcribe( :returns: The transcriptions outputs from the pipeline if no queue or batch processor is given, otherwise, `None`. """ + # Wrap the audio files with a function to iterate over them via a generator (save memory and runtime with # Huggingface's pipelines as they preload each input while inference is running): - def audio_iterator() -> Generator[Union[dict, str], None, None]: + def audio_iterator() -> Generator[dict | str, None, None]: if self._per_channel_transcription: for audio_file in audio_files: audio, sampling_rate = torchaudio.load(str(audio_file)) @@ -813,7 +815,7 @@ def audio_iterator() -> Generator[Union[dict, str], None, None]: yield str(audio_file) # Create a batch iterator: - def batch_iterator() -> Generator[List[Union[dict, str]], None, None]: + def batch_iterator() -> Generator[list[dict | str], None, None]: batch = [] for audio in audio_iterator(): batch.append(audio) @@ -899,7 +901,7 @@ def _multiprocessing_process_batches( """ while True: # Get the batch: - batch: List[dict] = batches_queue.get() + batch: list[dict] = batches_queue.get() if batch == _MULTIPROCESSING_STOP_MARK: break @@ -955,7 +957,7 @@ def _multiprocessing_complete_tasks(tasks_queue: Queue, results_queue: Queue): def open_mpi_handler( - worker_inputs: List[str], root_worker_inputs: Dict[str, Any] = None + worker_inputs: list[str], root_worker_inputs: dict[str, Any] = None ): global _LOGGER @@ -1056,7 +1058,7 @@ def wrapper(**kwargs): if comm.recv(source=0): files = [] for file in os.listdir(output_directory): - with open(output_directory / file, "r") as f: + with open(output_directory / file) as f: files.append((file, f.read())) comm.send(files, dest=0) return None @@ -1066,7 +1068,7 @@ def wrapper(**kwargs): return decorator -def _check_mlrun_and_open_mpi() -> Tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intracomm"]: +def _check_mlrun_and_open_mpi() -> tuple["mlrun.MLClientCtx", 
"mpi4py.MPI.Intracomm"]: is_mpi = False try: import mlrun @@ -1096,7 +1098,7 @@ def _check_mlrun_and_open_mpi() -> Tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intrac @open_mpi_handler(worker_inputs=["data_path"], root_worker_inputs={"verbose": True}) def transcribe( # Input / Output kwargs: - data_path: Union[str, Path, List[Union[str, Path]]], + data_path: str | Path | list[str | Path], output_directory: str = None, # Model loading kwargs: model_name: str = "openai/whisper-tiny", @@ -1111,11 +1113,11 @@ def transcribe( spoken_language: str = None, translate_to_english: bool = False, # Diarization kwargs: - speech_diarization: Dict[str, List[Tuple[float, float, str]]] = None, + speech_diarization: dict[str, list[tuple[float, float, str]]] = None, speech_diarize_per_channel: int = None, - speaker_labels: List[str] = None, + speaker_labels: list[str] = None, # Other kwargs: - use_multiprocessing: Union[bool, int] = False, + use_multiprocessing: bool | int = False, verbose: bool = False, ): """ @@ -1314,8 +1316,8 @@ def transcribe( def _get_audio_files( - data_path: Union[Path, str, list], -) -> List[Path]: + data_path: Path | str | list, +) -> list[Path]: """ Get the audio files to transcribe. If a path to a directory is given, all files in the directory will be collected. @@ -1350,11 +1352,11 @@ def _get_audio_files( def _run( - audio_files: List[Path], + audio_files: list[Path], batch_processor: BatchProcessor, transcriber: Transcriber, verbose: bool, -) -> List[Tuple[bool, Tuple[str, str]]]: +) -> list[tuple[bool, tuple[str, str]]]: """ Run the transcription without multiprocessing. 
@@ -1367,7 +1369,7 @@ def _run( """ # Load the transcription pipeline: if verbose: - _LOGGER.info(f"Loading the transcription pipeline.") + _LOGGER.info("Loading the transcription pipeline.") transcriber.load() if verbose: _LOGGER.info("Transcription pipeline loaded.") @@ -1385,7 +1387,7 @@ def _run( def _parallel_run( n_workers: int, - audio_files: List[Path], + audio_files: list[Path], batch_processor: BatchProcessor, transcriber: Transcriber, verbose: bool, @@ -1431,7 +1433,7 @@ def _parallel_run( # Load the transcription pipeline: if verbose: - _LOGGER.info(f"Loading the transcription pipeline.") + _LOGGER.info("Loading the transcription pipeline.") transcriber.load() if verbose: _LOGGER.info("Transcription pipeline loaded.") @@ -1446,7 +1448,7 @@ def _parallel_run( stop_marks_counter = 0 while True: # Get a result from the queue: - result: Tuple[bool, Tuple[str, str]] = results_queue.get() + result: tuple[bool, tuple[str, str]] = results_queue.get() if result == _MULTIPROCESSING_STOP_MARK: stop_marks_counter += 1 if stop_marks_counter == n_workers: @@ -1461,4 +1463,4 @@ def _parallel_run( for p in task_completion_processes: p.join() - return results \ No newline at end of file + return results diff --git a/functions/src/translate/function.yaml b/functions/src/translate/function.yaml index eb1ffd345..bda404af3 100644 --- a/functions/src/translate/function.yaml +++ b/functions/src/translate/function.yaml @@ -1,43 +1,58 @@ +metadata: + tag: '' + name: translate + categories: + - genai + - NLP verbose: false +kind: job spec: - description: Translate text files from one language to another - filename: /Users/Daniel_Perez/PycharmProjects/functions/functions/src/translate/translate.py - command: '' + image: '' + disable_auto_mount: false + build: + origin_filename: '' + functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import operator
import pathlib
from functools import reduce, wraps
from typing import Any

import pandas as pd
import transformers
from tqdm import tqdm

# Get the global logger. It starts as Python's root logger and is rebound by `open_mpi_handler` to the MLRun context
# logger when an MLRun context is available:
_LOGGER = logging.getLogger()


def _check_mlrun_and_open_mpi() -> tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intracomm"]:
    """
    Detect whether the code runs with MLRun and, if so, whether the run is an OpenMPI job.

    :returns: A tuple of:

              * The MLRun context, or `None` if a `ModuleNotFoundError` was raised while getting it (for example
                when `mlrun` is not installed).
              * The MPI world communicator if the context's "kind" label is "mpijob", otherwise `None`.

    :raises ModuleNotFoundError: If the run is labeled as an "mpijob" but `mpi4py` cannot be imported.
    """
    mpi_requested = False
    try:
        import mlrun

        ctx = mlrun.get_or_create_ctx(name="mlrun")
        mpi_requested = ctx.labels.get("kind", "job") == "mpijob"
        if not mpi_requested:
            # A regular (non distributed) run - there is no communicator to return:
            return ctx, None

        try:
            from mpi4py import MPI

            return ctx, MPI.COMM_WORLD
        except ModuleNotFoundError:
            ctx.logger.error(
                "To distribute the function using MLRun's 'mpijob' you need to have `mpi4py` package in your "
                "interpreter. Please run `pip install mpi4py` and make sure you have open-mpi."
            )
            raise
    except ModuleNotFoundError:
        # A missing module is tolerated only as long as the run was not identified as an mpijob:
        if mpi_requested:
            raise
    return None, None


def open_mpi_handler(
    worker_inputs: list[str], root_worker_inputs: dict[str, Any] = None
):
    """
    Create a decorator that splits a handler's inputs between OpenMPI workers and joins the workers outputs at the
    root worker (rank #0). If no MPI communicator is available, or there is only one worker, the handler is returned
    as is.

    The wrapped handler must be called with keyword arguments only and is expected to return a tuple of
    (output directory, dataframe, errors dictionary).

    :param worker_inputs:      Names of the handler's keyword arguments to split between the workers. A string or
                               a path value is first expanded to the files it points at.
    :param root_worker_inputs: Keyword arguments to set (override) in the root worker only.

    :returns: The decorator to apply on the handler.
    """
    global _LOGGER

    # Look for MLRun and OpenMPI. When MLRun is available, its logger replaces the global one:
    context, comm = _check_mlrun_and_open_mpi()
    if context:
        _LOGGER = context.logger

    def decorator(handler):
        # Nothing to distribute without a communicator or with a single worker:
        if comm is None or comm.Get_size() == 1:
            return handler

        @wraps(handler)
        def wrapper(**kwargs):
            # Read the open mpi environment properties:
            workers_count = comm.Get_size()
            rank = comm.Get_rank()
            is_last_worker = rank + 1 >= workers_count

            # Replace each worker input with the chunk that belongs to this rank:
            for worker_input in worker_inputs:
                value = kwargs[worker_input]
                if value is None:
                    continue
                if isinstance(value, (str, pathlib.Path)):
                    value = _get_text_files(data_path=pathlib.Path(value).absolute())
                items_count = len(value)
                if items_count < workers_count:
                    raise ValueError(
                        f"Cannot split the input '{worker_input}' of length {items_count} to {workers_count} workers. "
                        f"Please reduce the amount of workers for this input."
                    )
                # All workers get an even chunk, the last one takes the remainder as well:
                even_chunk_size = items_count // workers_count
                chunk_start = rank * even_chunk_size
                chunk_end = (
                    items_count if is_last_worker else chunk_start + even_chunk_size
                )
                context.logger.info(
                    f"Rank #{rank}: Processing input chunk of '{worker_input}' "
                    f"from index {chunk_start} to {chunk_end}."
                )
                if isinstance(value, list):
                    value = value[chunk_start:chunk_end]
                elif isinstance(value, pd.DataFrame):
                    value = value.iloc[chunk_start:chunk_end, :]
                kwargs[worker_input] = value

            # Arguments meant for the root worker only:
            if root_worker_inputs and rank == 0:
                kwargs.update(root_worker_inputs)

            # Run the worker and gather all outputs at the root rank (rank #0):
            gathered = comm.gather(handler(**kwargs), root=0)
            if rank != 0:
                return None

            # Join the outputs:
            context.logger.info("Collecting data from workers to root worker.")
            output_directory = gathered[0][0]
            dataframe = pd.concat(objs=[df for _, df, _ in gathered], axis=0)
            errors_dictionary = {}
            for _, _, worker_errors in gathered:
                errors_dictionary |= worker_errors
            return output_directory, dataframe, errors_dictionary

        return wrapper

    return decorator


@open_mpi_handler(worker_inputs=["data_path"], root_worker_inputs={"verbose": True})
def translate(
    data_path: str | list[str] | pathlib.Path,
    output_directory: str,
    model_name: str = None,
    source_language: str = None,
    target_language: str = None,
    device: str = None,
    model_kwargs: dict = None,
    batch_size: int = 1,
    translation_kwargs: dict = None,
    verbose: bool = False,
) -> tuple[str, pd.DataFrame, dict]:
    """
    Translate text files using a transformer model from Huggingface's hub according to the source and target languages
    given (or using the directly provided model name). The end result is a directory of translated text files and a
    dataframe containing the following columns:

    * text_file - The text file name.
    * translation_file - The translation text file name in the output directory.

    :param data_path:          A directory of text files or a single file (as a string or a `pathlib.Path`) or a list
                               of files to translate.
    :param output_directory:   Directory where the translated files will be saved.
    :param model_name:         The name of a model to load. If None, the model name is constructed using the source and
                               target languages parameters.
    :param source_language:    The source language code (e.g., 'en' for English).
    :param target_language:    The target language code (e.g., 'en' for English).
    :param model_kwargs:       Keyword arguments to pass regarding the loading of the model in HuggingFace's `pipeline`
                               function.
    :param device:             The device index for transformers. Default will prefer cuda if available.
    :param batch_size:         The batch size to use in translation. The files are translated one by one, but the
                               sentences can be batched.
    :param translation_kwargs: Additional keyword arguments to pass to a `transformers.TranslationPipeline` when doing
                               the translation inference. Notice the batch size here is being added automatically.
    :param verbose:            Whether to present logs of a progress bar and errors. Default: False.

    :returns: A tuple of:

              * Path to the output directory.
              * A dataframe dataset of the translated file names.
              * A dictionary of errored files that were not translated.
    """
    # Get the input text files to translate. A single path (a string or a `pathlib.Path`) is expanded to the files it
    # points at. A list is taken as is, only making sure each of its items is a `pathlib.Path` because the file name
    # and stem are read from it later on:
    if verbose:
        _LOGGER.info("Collecting text files.")
    if isinstance(data_path, (str, pathlib.Path)):
        text_files = _get_text_files(data_path=pathlib.Path(data_path).absolute())
    else:
        text_files = [pathlib.Path(text_file) for text_file in data_path]
    if verbose:
        _LOGGER.info(f"Collected {len(text_files)} text files.")

    # Get the translation pipeline (a batch size of 1 is passed as `None` to keep the pipeline's default behavior):
    if verbose:
        _LOGGER.info(f"Loading model - using device '{device}'.")
    translation_pipeline, model_name = _get_translation_pipeline(
        model_name=model_name,
        source_language=source_language,
        target_language=target_language,
        device=device,
        model_kwargs=model_kwargs,
        batch_size=batch_size if batch_size != 1 else None,
    )
    if verbose:
        _LOGGER.info(f"Model '{model_name}' was loaded successfully.")

    # Prepare the successes dataframe and errors dictionary to be returned:
    successes = []
    errors = {}

    # Create the output directory:
    output_directory = pathlib.Path(output_directory)
    output_directory.mkdir(parents=True, exist_ok=True)

    # Prepare the translation keyword arguments:
    translation_kwargs = translation_kwargs or {}

    # Go over the text files and translate:
    for text_file in tqdm(
        text_files, desc="Translating", unit="file", disable=not verbose
    ):
        try:
            # Translate:
            translation = _translate(
                text_file=text_file,
                translation_pipeline=translation_pipeline,
                translation_kwargs=translation_kwargs,
            )
            # Write the translation to file:
            translation_file = _save_to_file(
                translation=translation,
                file_name=text_file.stem,
                output_directory=output_directory,
            )
            # Note as a success in the list:
            successes.append(
                [
                    text_file.name,
                    translation_file.name,
                ]
            )
        except Exception as exception:
            # A failure in one file should not stop the others - note the exception as error in the dictionary:
            if verbose:
                _LOGGER.warning(f"Error in file: '{text_file.name}'")
            errors[str(text_file.name)] = str(exception)
            continue

    # Construct the translations dataframe:
    columns = [
        "text_file",
        "translation_file",
    ]
    successes = pd.DataFrame(
        successes,
        columns=columns,
    )

    # Print the head of the produced dataframe and return:
    if verbose:
        _LOGGER.info(
            f"Done ({successes.shape[0]}/{len(text_files)})\n"
            f"Translations summary:\n"
            f"{successes.head()}"
        )
    return str(output_directory), successes, errors


def _get_text_files(
    data_path: pathlib.Path,
) -> list[pathlib.Path]:
    """
    Collect the text files a path points at.

    :param data_path: A path to a single file or to a directory. A directory is searched non-recursively for
                      entries matching the "*.*" pattern.

    :returns: A list with the single file given, or the sorted list of the files found in the directory.

    :raises ValueError: If the path is neither an existing directory nor an existing file.
    """
    if data_path.is_dir():
        # Keep files only (a sub-directory like "v1.0" matches the pattern as well) and sort them, so the order is
        # deterministic: under OpenMPI each worker lists the directory on its own and takes its chunk by index, so
        # all workers must see the files in the same order.
        return sorted(path for path in data_path.glob("*.*") if path.is_file())
    if data_path.is_file():
        return [data_path]
    raise ValueError(
        f"Unrecognized data path. The parameter `data_path` must be either a directory path or a file path. "
        f"Given: {str(data_path)} "
    )


def _get_translation_pipeline(
    model_name: str = None,
    source_language: str = None,
    target_language: str = None,
    device: str = None,
    model_kwargs: dict = None,
    batch_size: int = None,
) -> tuple[transformers.Pipeline, str]:
    """
    Load a Huggingface translation pipeline.

    :param model_name:      The model to load. If None, a "Helsinki-NLP/opus-mt-<source>-<target>" model name is
                            constructed from the languages.
    :param source_language: The source language code. Required when `model_name` is None.
    :param target_language: The target language code. Required when `model_name` is None.
    :param device:          Passed as is to `transformers.pipeline`.
    :param model_kwargs:    Passed as is to `transformers.pipeline`.
    :param batch_size:      Passed as is to `transformers.pipeline`.

    :returns: A tuple of the loaded pipeline and the name of the model that was loaded.

    :raises ValueError: If there is no model name and a language is missing, or if the model name constructed from the
                        languages is not a valid model identifier.
    """
    # An explicit model name always wins. Without one, both languages are needed to compose the model name:
    if model_name is None:
        if source_language is None or target_language is None:
            raise ValueError(
                "No model name were given and missing source and / or target languages. In order to translate you must "
                "pass a `model_name` or both `source_language` and `target_language`."
            )
        model_name = f"Helsinki-NLP/opus-mt-{source_language}-{target_language}"

    # Initialize the translation pipeline:
    pipeline_arguments = {
        "task": "translation",
        "model": model_name,
        "tokenizer": model_name,
        "device": device,
        "model_kwargs": model_kwargs,
        "batch_size": batch_size,
    }
    try:
        translation_pipeline = transformers.pipeline(**pipeline_arguments)
    except OSError as load_exception:
        unknown_model = (
            "is not a valid model identifier listed on 'https://huggingface.co/models'"
            in str(load_exception)
        )
        # Any other loading error (or an unknown model when no source language was given) is raised untouched:
        if not (unknown_model and source_language):
            raise
        raise ValueError(
            f"The model '{model_name}' is not a valid model identifier. "
            f"The parameters `source_language` and `target_language` are used to construct a Helsinki model for "
            f"text to text generation, but the model created from the given languages does not exist. "
            f"You may check language identifiers at "
            f"https://developers.google.com/admin-sdk/directory/v1/languages, and if the error was not fixed, one "
            f"or more language code might be with 3 letters and needs to be found online. "
            f"Remember, you can always choose a model directly from the Huggingface hub by using the `model_name` "
            f"parameter."
        ) from load_exception

    return translation_pipeline, model_name


def _translate(
    text_file: pathlib.Path,
    translation_pipeline: transformers.Pipeline,
    translation_kwargs: dict,
) -> str:
    """
    Translate a text file sentence by sentence and restore its lines structure.

    The text is split to lines and each line is split to sentences by dots. Each piece is sent to the pipeline with
    a dot appended to it, and the translated pieces are joined back with a newline after the last piece of every
    line except the last one.

    :param text_file:            The text file to translate.
    :param translation_pipeline: The pipeline to call with the list of sentences. It is expected to return, per
                                 sentence, a dictionary holding the translation under the "translation_text" key.
    :param translation_kwargs:   Additional keyword arguments to pass to the pipeline call.

    :returns: The translated text.
    """
    # Read the text from file:
    with open(text_file) as fp:
        text = fp.read()

    # Split to paragraphs (lines) and each paragraph to sentences. Notice we add a dot not only to restore the
    # sentence structure but to avoid empty strings as they will ruin the translation:
    sentences = []
    newlines_indexes = set()
    for paragraph in text.split("\n"):
        sentences.extend(f"{sentence}." for sentence in paragraph.split("."))
        # Remember the index of the paragraph's last sentence, a newline should follow it. Index 0 is a valid value
        # here (a first line without a dot), hence the indexes are looked up by membership and not by truthiness:
        newlines_indexes.add(len(sentences) - 1)
    # The last paragraph is not followed by a newline:
    newlines_indexes.discard(len(sentences) - 1)

    # Translate the sentences:
    translations = translation_pipeline(sentences, **translation_kwargs)

    # Restructure the full text from the sentences:
    translated_text = []
    for i, translation in enumerate(translations):
        translated_sentence = translation["translation_text"]
        # A lone dot means it was an empty sentence before:
        if translated_sentence == ".":
            translated_sentence = ""
        if i in newlines_indexes:
            translated_sentence += "\n"
        translated_text.append(translated_sentence)

    return "".join(translated_text)


def _save_to_file(
    translation: str, file_name: str, output_directory: pathlib.Path
) -> pathlib.Path:
    """
    Write a translation to a new text file in the output directory.

    :param translation:      The translated text to write.
    :param file_name:        The file name (without a suffix) to use. If "<file_name>.txt" already exists, a running
                             number starting at 2 is added ("<file_name>_2.txt", "<file_name>_3.txt", ...).
    :param output_directory: The directory to write the file in. It is created if missing.

    :returns: The path of the written file.
    """
    # Look for the first file name that is not taken:
    target = output_directory / f"{file_name}.txt"
    next_number = 2
    while target.exists():
        target = output_directory / f"{file_name}_{next_number}.txt"
        next_number += 1

    # Make sure all directories are created and write the file:
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(translation)

    return target
 + requirements: + - transformers + - sentencepiece + - torch>=2.6 + - tqdm + code_origin: '' + base_image: mlrun/mlrun + filename: translate.py entry_points: open_mpi_handler: - lineno: 56 parameters: - name: worker_inputs - type: List[str] + type: list[str] - name: root_worker_inputs - type: Dict[str, Any] + type: dict[str, Any] default: null + name: open_mpi_handler doc: '' has_kwargs: false has_varargs: false - name: open_mpi_handler + lineno: 56 decorator: - lineno: 68 parameters: - name: handler + name: decorator doc: '' has_kwargs: false has_varargs: false - name: decorator + lineno: 68 wrapper: - lineno: 73 + name: wrapper doc: '' has_kwargs: true has_varargs: false - name: wrapper + lineno: 73 translate: outputs: - doc: 'A tuple of:' - type: Tuple[str, pd.DataFrame, dict] - lineno: 135 + type: tuple[str, pd.DataFrame, dict] parameters: - name: data_path - type: Union[str, List[str], Path] doc: A directory of text files or a single file or a list of files to translate. - name: output_directory type: str @@ -79,6 +94,7 @@ spec: type: bool doc: 'Whether to present logs of a progress bar and errors. Default: True.' default: false + name: translate doc: 'Translate text files using a transformer model from Huggingface''s hub according to the source and target languages @@ -93,24 +109,7 @@ spec: * translation_file - The translation text file name in the output directory.' has_kwargs: false has_varargs: false - name: translate - disable_auto_mount: false - image: '' + lineno: 135 + command: '' + description: Translate text files from one language to another default_handler: translate - build: - functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import operator
import pathlib
from functools import reduce, wraps
from typing import Any, Dict, List, Tuple, Union

import pandas as pd
import transformers
from tqdm import tqdm

# Get the global logger. It starts as Python's root logger and is rebound by `open_mpi_handler` to the MLRun context
# logger when an MLRun context is available:
_LOGGER = logging.getLogger()


def _check_mlrun_and_open_mpi() -> Tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intracomm"]:
    """
    Check for MLRun and OpenMPI availability.

    :returns: A tuple of the MLRun context and the MPI world communicator. The communicator is `None` unless the
              context's "kind" label is "mpijob". Both are `None` if a `ModuleNotFoundError` was raised while getting
              the context (for example when `mlrun` is not installed).

    :raises ModuleNotFoundError: If the run is labeled as an "mpijob" but `mpi4py` cannot be imported.
    """
    running_as_mpijob = False
    try:
        import mlrun

        mlrun_context = mlrun.get_or_create_ctx(name="mlrun")
        kind = mlrun_context.labels.get("kind", "job")
        running_as_mpijob = kind == "mpijob"

        if running_as_mpijob:
            try:
                from mpi4py import MPI

                return mlrun_context, MPI.COMM_WORLD
            except ModuleNotFoundError:
                mlrun_context.logger.error(
                    "To distribute the function using MLRun's 'mpijob' you need to have `mpi4py` package in your "
                    "interpreter. Please run `pip install mpi4py` and make sure you have open-mpi."
                )
                raise
        return mlrun_context, None
    except ModuleNotFoundError:
        # Only an mpijob must fail on a missing module, otherwise fall back to no context and no communicator:
        if running_as_mpijob:
            raise
    return None, None


def open_mpi_handler(
    worker_inputs: List[str], root_worker_inputs: Dict[str, Any] = None
):
    """
    Create a decorator for running a handler distributed with OpenMPI: each worker gets its own chunk of the worker
    inputs and the root worker (rank #0) collects all outputs. If no MPI communicator is available, or there is only
    one worker, the handler is returned as is.

    The wrapped handler must be called with keyword arguments only and is expected to return a tuple of
    (output directory, dataframe, errors dictionary).

    :param worker_inputs:      Names of the handler's keyword arguments to split between the workers. A string or
                               a path value is first expanded to the files it points at.
    :param root_worker_inputs: Keyword arguments to set (override) in the root worker only.

    :returns: The decorator to apply on the handler.
    """
    global _LOGGER

    # Check for MLRun and OpenMPI availability, using MLRun's logger as the global one when available:
    context, comm = _check_mlrun_and_open_mpi()
    if context:
        _LOGGER = context.logger

    def take_rank_chunk(name, value, rank, size):
        # Return the part of the input `value` that the worker of the given rank should process.
        if isinstance(value, (str, pathlib.Path)):
            value = _get_text_files(data_path=pathlib.Path(value).absolute())
        if len(value) < size:
            raise ValueError(
                f"Cannot split the input '{name}' of length {len(value)} to {size} workers. "
                f"Please reduce the amount of workers for this input."
            )
        # Chunks are even, except for the last worker that reads up to the end of the input:
        even_chunk_size = len(value) // size
        chunk_start = rank * even_chunk_size
        chunk_end = len(value)
        if rank + 1 < size:
            chunk_end = (rank + 1) * even_chunk_size
        context.logger.info(
            f"Rank #{rank}: Processing input chunk of '{name}' "
            f"from index {chunk_start} to {chunk_end}."
        )
        if isinstance(value, list):
            return value[chunk_start:chunk_end]
        if isinstance(value, pd.DataFrame):
            return value.iloc[chunk_start:chunk_end, :]
        return value

    def decorator(handler):
        if comm is None or comm.Get_size() == 1:
            return handler

        @wraps(handler)
        def wrapper(**kwargs):
            # Get the open mpi environment properties:
            size = comm.Get_size()
            rank = comm.Get_rank()

            # Give the correct chunk of the workers inputs (inputs that were passed as `None` are left untouched):
            for name in worker_inputs:
                if kwargs[name] is not None:
                    kwargs[name] = take_rank_chunk(name, kwargs[name], rank, size)

            # Set the root worker only arguments:
            if rank == 0 and root_worker_inputs:
                kwargs.update(root_worker_inputs)

            # Run the worker and send the output to the root rank (rank #0):
            output = comm.gather(handler(**kwargs), root=0)
            if rank == 0:
                # Join the outputs:
                context.logger.info("Collecting data from workers to root worker.")
                output_directory = output[0][0]
                dataframes = [df for _, df, _ in output]
                errors = [err for _, _, err in output]
                return (
                    output_directory,
                    pd.concat(objs=dataframes, axis=0),
                    reduce(operator.ior, errors, {}),
                )
            return None

        return wrapper

    return decorator


@open_mpi_handler(worker_inputs=["data_path"], root_worker_inputs={"verbose": True})
def translate(
    data_path: Union[str, List[str], pathlib.Path],
    output_directory: str,
    model_name: str = None,
    source_language: str = None,
    target_language: str = None,
    device: str = None,
    model_kwargs: dict = None,
    batch_size: int = 1,
    translation_kwargs: dict = None,
    verbose: bool = False,
) -> Tuple[str, pd.DataFrame, dict]:
    """
    Translate text files using a transformer model from Huggingface's hub according to the source and target languages
    given (or using the directly provided model name). The end result is a directory of translated text files and a
    dataframe containing the following columns:

    * text_file - The text file name.
    * translation_file - The translation text file name in the output directory.

    :param data_path:          A directory of text files or a single file or a list of files to translate. Paths may
                               be given as strings or as `pathlib.Path` objects.
    :param output_directory:   Directory where the translated files will be saved. It is created if missing.
    :param model_name:         The name of a model to load. If None, the model name is constructed using the source and
                               target languages parameters.
    :param source_language:    The source language code (e.g., 'en' for English).
    :param target_language:    The target language code (e.g., 'en' for English).
    :param device:             The device index for transformers. It is passed as is to the pipeline construction.
    :param model_kwargs:       Keyword arguments to pass regarding the loading of the model in HuggingFace's `pipeline`
                               function.
    :param batch_size:         The batch size to use in the translation pipeline. The files are translated one by one,
                               but the sentences of each file can be batched. A value of 1 leaves the pipeline's batch
                               size unset.
    :param translation_kwargs: Additional keyword arguments to pass to a `transformers.TranslationPipeline` when doing
                               the translation inference.
    :param verbose:            Whether to present logs of a progress bar and errors. Default: False (the
                               `open_mpi_handler` decoration is configured to pass `verbose=True` to the root worker).

    :returns: A tuple of:

              * Path to the output directory.
              * A dataframe dataset of the translated file names.
              * A dictionary of errored files that were not translated (file name -> error message).
    """
    # Get the input text files to translate. Every entry is normalized to a `pathlib.Path` because the loop below
    # relies on the `name` and `stem` attributes (plain strings used to crash there):
    if verbose:
        _LOGGER.info("Collecting text files.")
    if isinstance(data_path, (str, pathlib.Path)):
        text_files = _get_text_files(data_path=pathlib.Path(data_path).absolute())
    else:
        text_files = [pathlib.Path(text_file) for text_file in data_path]
    if verbose:
        _LOGGER.info(f"Collected {len(text_files)} text files.")

    # Get the translation pipeline:
    if verbose:
        _LOGGER.info(f"Loading model - using device '{device}'.")
    translation_pipeline, model_name = _get_translation_pipeline(
        model_name=model_name,
        source_language=source_language,
        target_language=target_language,
        device=device,
        model_kwargs=model_kwargs,
        batch_size=batch_size if batch_size != 1 else None,
    )
    if verbose:
        _LOGGER.info(f"Model '{model_name}' was loaded successfully.")

    # Prepare the successes list and errors dictionary to be returned:
    successes = []
    errors = {}

    # Create the output directory:
    output_directory = pathlib.Path(output_directory)
    output_directory.mkdir(parents=True, exist_ok=True)

    # Prepare the translation keyword arguments:
    translation_kwargs = translation_kwargs or {}

    # Go over the text files and translate them one by one:
    for text_file in tqdm(
        text_files, desc="Translating", unit="file", disable=not verbose
    ):
        try:
            # Translate:
            translation = _translate(
                text_file=text_file,
                translation_pipeline=translation_pipeline,
                translation_kwargs=translation_kwargs,
            )
            # Write the translation to file:
            translation_file = _save_to_file(
                translation=translation,
                file_name=text_file.stem,
                output_directory=output_directory,
            )
            # Note as a success in the list:
            successes.append(
                [
                    text_file.name,
                    translation_file.name,
                ]
            )
        except Exception as exception:
            # Best effort per file - note the exception as an error in the dictionary and move on to the next file:
            if verbose:
                _LOGGER.warning(f"Error in file: '{text_file.name}'")
            errors[str(text_file.name)] = str(exception)
            continue

    # Construct the translations dataframe:
    columns = [
        "text_file",
        "translation_file",
    ]
    successes = pd.DataFrame(
        successes,
        columns=columns,
    )

    # Print the head of the produced dataframe and return:
    if verbose:
        _LOGGER.info(
            f"Done ({successes.shape[0]}/{len(text_files)})\n"
            f"Translations summary:\n"
            f"{successes.head()}"
        )
    return str(output_directory), successes, errors


def _get_text_files(
    data_path: pathlib.Path,
) -> List[pathlib.Path]:
    # Check if the path is of a directory or a file:
    if data_path.is_dir():
        # Get all files inside the directory:
        text_files = list(data_path.glob("*.*"))
    elif data_path.is_file():
        text_files = [data_path]
    else:
        raise ValueError(
            f"Unrecognized data path. The parameter `data_path` must be either a directory path or a file path. "
            f"Given: {str(data_path)} "
        )

    return text_files


def _get_translation_pipeline(
    model_name: str = None,
    source_language: str = None,
    target_language: str = None,
    device: str = None,
    model_kwargs: dict = None,
    batch_size: int = None,
) -> Tuple[transformers.Pipeline, str]:
    """
    Create the translation pipeline and resolve the name of the model it uses.

    :param model_name:      The model to load. When None, the name is built from the source and target languages as
                            "Helsinki-NLP/opus-mt-<source_language>-<target_language>".
    :param source_language: The source language code, used only for building the model name.
    :param target_language: The target language code, used only for building the model name.
    :param device:          Passed as is to `transformers.pipeline`.
    :param model_kwargs:    Passed as is to `transformers.pipeline`.
    :param batch_size:      Passed as is to `transformers.pipeline`.

    :returns: A tuple of the initialized pipeline and the model name that was used.

    :raises ValueError: If no model name was given and a language is missing, or if loading failed because the model
                        identifier is not valid while a source language was given.
    :raises OSError:    Any other loading failure raised by `transformers.pipeline` is raised as is.
    """
    # An explicit model name always wins. Without it, both languages are required in order to build one:
    if model_name is None:
        if source_language is None or target_language is None:
            raise ValueError(
                "No model name were given and missing source and / or target languages. In order to translate you must "
                "pass a `model_name` or both `source_language` and `target_language`."
            )
        model_name = f"Helsinki-NLP/opus-mt-{source_language}-{target_language}"

    # Initialize the translation pipeline (the same name is used for both the model and its tokenizer):
    pipeline_arguments = {
        "task": "translation",
        "model": model_name,
        "tokenizer": model_name,
        "device": device,
        "model_kwargs": model_kwargs,
        "batch_size": batch_size,
    }
    try:
        translation_pipeline = transformers.pipeline(**pipeline_arguments)
    except OSError as load_exception:
        # The text looked for in the loading error in order to recognize a model name that does not exist:
        invalid_identifier_message = (
            "is not a valid model identifier listed on 'https://huggingface.co/models'"
        )
        is_invalid_identifier = invalid_identifier_message in str(load_exception)
        if is_invalid_identifier and source_language:
            # Replace the loading error with guidance about the language codes:
            raise ValueError(
                f"The model '{model_name}' is not a valid model identifier. "
                "The parameters `source_language` and `target_language` are used to construct a Helsinki model for "
                "text to text generation, but the model created from the given languages does not exist. "
                "You may check language identifiers at "
                "https://developers.google.com/admin-sdk/directory/v1/languages, and if the error was not fixed, one "
                "or more language code might be with 3 letters and needs to be found online. "
                "Remember, you can always choose a model directly from the Huggingface hub by using the `model_name` "
                "parameter."
            ) from load_exception
        raise load_exception

    return translation_pipeline, model_name


def _translate(
    text_file: pathlib.Path,
    translation_pipeline: transformers.Pipeline,
    translation_kwargs: dict,
) -> str:
    # Read the text from file:
    with open(text_file, "r") as fp:
        text = fp.read()

    # Split to paragraphs and each paragraph to sentences:
    paragraphs = [paragraph.split(".") for paragraph in text.split("\n")]

    # Discover the newline indexes to restore the file to its structure post translation:
    newlines_indexes = []
    for paragraph in paragraphs[:-1]:
        if len(newlines_indexes) == 0:
            newlines_indexes.append(len(paragraph) - 1)
        else:
            newlines_indexes.append(newlines_indexes[-1] + len(paragraph))

    # Prepare the batches (each sentence from the paragraphs). Notice we add a dot not only to restore the sentence
    # structure but to ignore empty strings as it will ruin the translation:
    sentences = [f"{line}." for paragraph in paragraphs for line in paragraph]

    # Translate the sentences:
    translations = translation_pipeline(sentences, **translation_kwargs)

    # Restructure the full text from the sentences:
    translated_text = []
    newline_index = newlines_indexes.pop(0) if newlines_indexes else None
    for i, translation in enumerate(translations):
        # Get the translation:
        text = translation["translation_text"]
        # Validate if it was an empty sentence before:
        if text == ".":
            text = ""
        # Check if needed to insert a newline:
        if newline_index and newline_index == i:
            text += "\n"
            newline_index = newlines_indexes.pop(0) if newlines_indexes else None
        # Collect it:
        translated_text.append(text)
    translated_text = "".join(translated_text)

    return translated_text


def _save_to_file(
    translation: str, file_name: str, output_directory: pathlib.Path
) -> pathlib.Path:
    # Prepare the file full path (checking for no duplications):
    translation_file = output_directory / f"{file_name}.txt"
    i = 1
    while translation_file.exists():
        i += 1
        translation_file = output_directory / f"{file_name}_{i}.txt"

    # Make sure all directories are created:
    translation_file.parent.mkdir(exist_ok=True, parents=True)

    # Write to file:
    with open(translation_file, "w") as fp:
        fp.write(translation)

    return translation_file
 - origin_filename: '' - base_image: mlrun/mlrun - requirements: - - transformers - - sentencepiece - - torch>=2.6 - - tqdm - code_origin: '' -kind: job -metadata: - tag: '' - categories: - - genai - - NLP - name: translate diff --git a/functions/src/translate/item.yaml b/functions/src/translate/item.yaml index 68f176ac2..24424748b 100644 --- a/functions/src/translate/item.yaml +++ b/functions/src/translate/item.yaml @@ -12,7 +12,7 @@ labels: author: Iguazio maintainers: [] marketplaceType: '' -mlrunVersion: 1.10.0-rc41 +mlrunVersion: 1.10.0 name: translate platformVersion: 3.5.3 spec: diff --git a/functions/src/translate/test_translate.py b/functions/src/translate/test_translate.py index a22dc899a..e56572546 100644 --- a/functions/src/translate/test_translate.py +++ b/functions/src/translate/test_translate.py @@ -19,7 +19,9 @@ def test_translate(): project = mlrun.new_project("test-translate") - translate_fn = project.set_function("translate.py", "translate", image="mlrun/mlrun") + translate_fn = project.set_function( + "translate.py", "translate", image="mlrun/mlrun" + ) input_text = "Ali her gece bir kitap okur." expected_translation = "Ali reads a book every night." 
@@ -48,4 +50,3 @@ def test_translate(): assert translate_run.status.state == "completed" with open(os.path.join(test_dir, "test_tr.txt")) as f: assert f.read() == expected_translation - diff --git a/functions/src/translate/translate.py b/functions/src/translate/translate.py index 360fa6203..a5e05f2d2 100644 --- a/functions/src/translate/translate.py +++ b/functions/src/translate/translate.py @@ -16,7 +16,7 @@ import operator import pathlib from functools import reduce, wraps -from typing import Any, Dict, List, Tuple, Union +from typing import Any import pandas as pd import transformers @@ -26,7 +26,7 @@ _LOGGER = logging.getLogger() -def _check_mlrun_and_open_mpi() -> Tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intracomm"]: +def _check_mlrun_and_open_mpi() -> tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intracomm"]: is_mpi = False try: import mlrun @@ -54,7 +54,7 @@ def _check_mlrun_and_open_mpi() -> Tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intrac def open_mpi_handler( - worker_inputs: List[str], root_worker_inputs: Dict[str, Any] = None + worker_inputs: list[str], root_worker_inputs: dict[str, Any] = None ): global _LOGGER @@ -133,7 +133,7 @@ def wrapper(**kwargs): @open_mpi_handler(worker_inputs=["data_path"], root_worker_inputs={"verbose": True}) def translate( - data_path: Union[str, List[str], pathlib.Path], + data_path: str | list[str] | pathlib.Path, output_directory: str, model_name: str = None, source_language: str = None, @@ -143,7 +143,7 @@ def translate( batch_size: int = 1, translation_kwargs: dict = None, verbose: bool = False, -) -> Tuple[str, pd.DataFrame, dict]: +) -> tuple[str, pd.DataFrame, dict]: """ Translate text files using a transformer model from Huggingface's hub according to the source and target languages given (or using the directly provided model name). 
The end result is a directory of translated text files and a @@ -264,7 +264,7 @@ def translate( def _get_text_files( data_path: pathlib.Path, -) -> List[pathlib.Path]: +) -> list[pathlib.Path]: # Check if the path is of a directory or a file: if data_path.is_dir(): # Get all files inside the directory: @@ -287,7 +287,7 @@ def _get_translation_pipeline( device: str = None, model_kwargs: dict = None, batch_size: int = None, -) -> Tuple[transformers.Pipeline, str]: +) -> tuple[transformers.Pipeline, str]: # Construct the model name - if model name is provided (not None) then we take it, otherwise we check both source # and target were provided to construct the model name: if model_name is None and (source_language is None or target_language is None): @@ -335,7 +335,7 @@ def _translate( translation_kwargs: dict, ) -> str: # Read the text from file: - with open(text_file, "r") as fp: + with open(text_file) as fp: text = fp.read() # Split to paragraphs and each paragraph to sentences: diff --git a/functions/src/v2_model_server/function.yaml b/functions/src/v2_model_server/function.yaml index 5ecfec9ba..4a2b6dd81 100644 --- a/functions/src/v2_model_server/function.yaml +++ b/functions/src/v2_model_server/function.yaml @@ -1,87 +1,29 @@ -kind: serving metadata: - name: v2-model-server tag: '' - hash: ad85919d3b9cf2acae43a3434ba56e01b005755e - project: '' - labels: - author: Iguazio - framework: sklearn + name: v2-model-server categories: - model-serving - machine-learning +verbose: false +kind: serving spec: - command: '' - args: [] image: mlrun/mlrun - entry_points: - load: - name: load - doc: load and initialize the model and/or other elements - parameters: - - name: self - default: '' - outputs: - - default: '' - lineno: 16 - predict: - name: predict - doc: Generate model predictions from sample. 
- parameters: - - name: self - default: '' - - name: body - type: dict - default: '' - outputs: - - default: '' - type: List - lineno: 21 - init_context: - name: init_context - doc: '' - parameters: - - name: context - default: '' - outputs: - - default: '' - lineno: 39 - handler: - name: handler - doc: '' - parameters: - - name: context - default: '' - - name: event - default: '' - outputs: - - default: '' - lineno: 42 - description: generic sklearn model server + disable_auto_mount: false + build: + origin_filename: '' + functionSourceCode: IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKIyBHZW5lcmF0ZWQgYnkgbnVjbGlvLmV4cG9ydC5OdWNsaW9FeHBvcnRlcgoKaW1wb3J0IHdhcm5pbmdzCgppbXBvcnQgbWxydW4KaW1wb3J0IG51bXB5IGFzIG5wCmZyb20gY2xvdWRwaWNrbGUgaW1wb3J0IGxvYWQKCndhcm5pbmdzLmZpbHRlcndhcm5pbmdzKCJpZ25vcmUiKQoKCmNsYXNzIENsYXNzaWZpZXJNb2RlbChtbHJ1bi5zZXJ2aW5nLlYyTW9kZWxTZXJ2ZXIpOgogICAgZGVmIGxvYWQoc2VsZik6CiAgICAgICAgIiIibG9hZCBhbmQgaW5pdGlhbGl6ZSB0aGUgbW9kZWwgYW5kL29yIG90aGVyIGVsZW1lbnRzIiIiCiAgICAgICAgbW9kZWxfZmlsZSwgZXh0cmFfZGF0YSA9IHNlbGYuZ2V0X21vZGVsKCIucGtsIikKICAgICAgICBzZWxmLm1vZGVsID0gbG9hZChvcGVuKG1vZGVsX2ZpbGUsICJyYiIpKQoKICAgIGRlZiBwcmVkaWN0KHNlbGYsIGJvZHk6IGRpY3QpIC0+IGxpc3Q6CiAgICAgICAgIiIiR2VuZXJhdGUgbW9kZWwgcHJlZGljdGlvbnMgZnJvbSBzYW1wbGUuIiIiCiAgICAgICAgZmVhdHM
gPSBucC5hc2FycmF5KGJvZHlbImlucHV0cyJdKQogICAgICAgIHJlc3VsdDogbnAubmRhcnJheSA9IHNlbGYubW9kZWwucHJlZGljdChmZWF0cykKICAgICAgICByZXR1cm4gcmVzdWx0LnRvbGlzdCgpCgpmcm9tIG1scnVuLnJ1bnRpbWVzIGltcG9ydCBudWNsaW9faW5pdF9ob29rCmRlZiBpbml0X2NvbnRleHQoY29udGV4dCk6CiAgICBudWNsaW9faW5pdF9ob29rKGNvbnRleHQsIGdsb2JhbHMoKSwgJ3NlcnZpbmdfdjInKQoKZGVmIGhhbmRsZXIoY29udGV4dCwgZXZlbnQpOgogICAgcmV0dXJuIGNvbnRleHQubWxydW5faGFuZGxlcihjb250ZXh0LCBldmVudCkK + code_origin: '' + filename: v2_model_server.py + default_class: ClassifierModel min_replicas: 1 - max_replicas: 4 - env: [] - base_spec: - apiVersion: nuclio.io/v1 - kind: Function - metadata: - name: v2-model-server - labels: {} - annotations: - nuclio.io/generated_by: function generated from /home/michaell/projects/functions/v2_model_server/v2_model_server.py - spec: - runtime: python:3.9 - handler: v2_model_server:handler - env: [] - volumes: [] - build: - commands: [] - noBaseImagesPull: true - functionSourceCode: IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKIyBHZW5lcmF0ZWQgYnkgbnVjbGlvLmV4cG9ydC5OdWNsaW9FeHBvcnRlcgoKaW1wb3J0IG1scnVuCgpmcm9tIGNsb3VkcGlja2xlIGltcG9ydCBsb2FkCmZyb20gdHlwaW5nIGltcG9ydCBMaXN0CmZyb20gc2tsZWFybi5kYXRhc2V0cyBpbXBvcnQgbG9hZF9pcmlzCmltcG9ydCBudW1weSBhcyBucAoKaW1wb3J0IHdhcm5pbmdzCgp3YXJuaW5ncy5maWx0ZXJ3YXJuaW5n
cygiaWdub3JlIikKCgpjbGFzcyBDbGFzc2lmaWVyTW9kZWwobWxydW4uc2VydmluZy5WMk1vZGVsU2VydmVyKToKICAgIGRlZiBsb2FkKHNlbGYpOgogICAgICAgICIiImxvYWQgYW5kIGluaXRpYWxpemUgdGhlIG1vZGVsIGFuZC9vciBvdGhlciBlbGVtZW50cyIiIgogICAgICAgIG1vZGVsX2ZpbGUsIGV4dHJhX2RhdGEgPSBzZWxmLmdldF9tb2RlbCgiLnBrbCIpCiAgICAgICAgc2VsZi5tb2RlbCA9IGxvYWQob3Blbihtb2RlbF9maWxlLCAicmIiKSkKCiAgICBkZWYgcHJlZGljdChzZWxmLCBib2R5OiBkaWN0KSAtPiBMaXN0OgogICAgICAgICIiIkdlbmVyYXRlIG1vZGVsIHByZWRpY3Rpb25zIGZyb20gc2FtcGxlLiIiIgogICAgICAgIGZlYXRzID0gbnAuYXNhcnJheShib2R5WyJpbnB1dHMiXSkKICAgICAgICByZXN1bHQ6IG5wLm5kYXJyYXkgPSBzZWxmLm1vZGVsLnByZWRpY3QoZmVhdHMpCiAgICAgICAgcmV0dXJuIHJlc3VsdC50b2xpc3QoKQpmcm9tIG1scnVuLnJ1bnRpbWVzIGltcG9ydCBudWNsaW9faW5pdF9ob29rCmRlZiBpbml0X2NvbnRleHQoY29udGV4dCk6CiAgICBudWNsaW9faW5pdF9ob29rKGNvbnRleHQsIGdsb2JhbHMoKSwgJ3NlcnZpbmdfdjInKQoKZGVmIGhhbmRsZXIoY29udGV4dCwgZXZlbnQpOgogICAgcmV0dXJuIGNvbnRleHQubWxydW5faGFuZGxlcihjb250ZXh0LCBldmVudCkK + command: '' + default_handler: '' source: '' + max_replicas: 4 + base_image_pull: false + description: generic sklearn model server function_kind: serving_v2 - default_class: ClassifierModel - build: - commands: [] - code_origin: https://github.com/Michaelliv/functions.git#0e79859b0adccb92a9b65b02d438ed3dfa3e785f:/home/michaell/projects/functions/v2_model_server/v2_model_server.py -verbose: false + function_handler: v2-model-server-nuclio:handler + env: + - name: MLRUN_HTTPDB__NUCLIO__EXPLICIT_ACK + value: enabled diff --git a/functions/src/v2_model_server/v2_model_server.py b/functions/src/v2_model_server/v2_model_server.py index 572f1680d..d2d54793d 100644 --- a/functions/src/v2_model_server/v2_model_server.py +++ b/functions/src/v2_model_server/v2_model_server.py @@ -14,14 +14,11 @@ # # Generated by nuclio.export.NuclioExporter -import mlrun +import warnings -from cloudpickle import load -from typing import List -from sklearn.datasets import load_iris +import mlrun import numpy as np - -import warnings +from cloudpickle import load 
warnings.filterwarnings("ignore") @@ -32,7 +29,7 @@ def load(self): model_file, extra_data = self.get_model(".pkl") self.model = load(open(model_file, "rb")) - def predict(self, body: dict) -> List: + def predict(self, body: dict) -> list: """Generate model predictions from sample.""" feats = np.asarray(body["inputs"]) result: np.ndarray = self.model.predict(feats) diff --git a/functions/src/v2_model_tester/function.yaml b/functions/src/v2_model_tester/function.yaml index c9562b097..c70ec5e49 100644 --- a/functions/src/v2_model_tester/function.yaml +++ b/functions/src/v2_model_tester/function.yaml @@ -1,35 +1,29 @@ -kind: job metadata: - name: v2-model-tester tag: '' - hash: 72d3f664ff2aa870109e44f52f975bda2ac13682 - project: '' - labels: - author: Iguazio + name: v2-model-tester categories: - model-testing - machine-learning +verbose: false +kind: job spec: - command: '' - args: [] image: mlrun/mlrun - env: [] - default_handler: model_server_tester + disable_auto_mount: false + build: + origin_filename: '' + functionSourceCode: 
IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKIyBHZW5lcmF0ZWQgYnkgbnVjbGlvLmV4cG9ydC5OdWNsaW9FeHBvcnRlcgoKaW1wb3J0IGpzb24KZnJvbSBkYXRldGltZSBpbXBvcnQgZGF0ZXRpbWUKCmltcG9ydCBudW1weSBhcyBucAppbXBvcnQgcmVxdWVzdHMKZnJvbSBtbHJ1bi5hcnRpZmFjdHMgaW1wb3J0IENoYXJ0QXJ0aWZhY3QKZnJvbSBtbHJ1bi5kYXRhc3RvcmUgaW1wb3J0IERhdGFJdGVtCgoKZGVmIG1vZGVsX3NlcnZlcl90ZXN0ZXIoCiAgICBjb250ZXh0LAogICAgdGFibGU6IERhdGFJdGVtLAogICAgYWRkcjogc3RyLAogICAgbGFiZWxfY29sdW1uOiBzdHIgPSAibGFiZWwiLAogICAgbW9kZWw6IHN0ciA9ICIiLAogICAgbWF0Y2hfZXJyOiBib29sID0gRmFsc2UsCiAgICByb3dzOiBpbnQgPSAyMCwKKToKICAgICIiIlRlc3QgYSBtb2RlbCBzZXJ2ZXIKCiAgICA6cGFyYW0gdGFibGU6ICAgICAgICAgY3N2L3BhcnF1ZXQgdGFibGUgd2l0aCB0ZXN0IGRhdGEKICAgIDpwYXJhbSBhZGRyOiAgICAgICAgICBmdW5jdGlvbiBhZGRyZXNzL3VybAogICAgOnBhcmFtIGxhYmVsX2NvbHVtbjogIG5hbWUgb2YgdGhlIGxhYmVsIGNvbHVtbiBpbiB0YWJsZQogICAgOnBhcmFtIG1vZGVsOiAgICAgICAgIHRlc3RlZCBtb2RlbCBuYW1lCiAgICA6cGFyYW0gbWF0Y2hfZXJyOiAgICAgcmFpc2UgZXJyb3Igb24gdmFsaWRhdGlvbiAocmVxdWlyZSBwcm9wZXIgdGVzdCBzZXQpCiAgICA6cGFyYW0gcm93czogICAgICAgICAgbnVtYmVyIG9mIHJvd3MgdG8gdXNlIGZyb20gdGVzdCBzZXQKICAgICIiIgoKICAgIHRhYmxlID0gdGFibGUuYXNfZGYoKQoKICAgIHlfbGlzdCA9IHRhYmxlLnBvcChsYWJlbF9jb2x1bW4pLnZhbHVlcy50b2xpc3QoKQogICAgY29udGV4dC5sb2dnZXIuaW5mbyhmInRlc3Rpbmcgd2l0aCBkYXRhc2V0IGFnYWluc3Qge2FkZHJ9LCBtb2RlbDoge21vZGVs
fSIpCiAgICBpZiByb3dzIGFuZCByb3dzIDwgdGFibGUuc2hhcGVbMF06CiAgICAgICAgdGFibGUgPSB0YWJsZS5zYW1wbGUocm93cykKCiAgICBjb3VudCA9IGVycl9jb3VudCA9IG1hdGNoID0gMAogICAgdGltZXMgPSBbXQogICAgZm9yIHgsIHkgaW4gemlwKHRhYmxlLnZhbHVlcywgeV9saXN0KToKICAgICAgICBjb3VudCArPSAxCiAgICAgICAgZXZlbnRfZGF0YSA9IGpzb24uZHVtcHMoeyJpbnB1dHMiOiBbeC50b2xpc3QoKV19KQogICAgICAgIGhhZF9lcnIgPSBGYWxzZQogICAgICAgIHRyeToKICAgICAgICAgICAgc3RhcnQgPSBkYXRldGltZS5ub3coKQogICAgICAgICAgICByZXNwID0gcmVxdWVzdHMucHV0KGYie2FkZHJ9L3YyL21vZGVscy97bW9kZWx9L2luZmVyIiwganNvbj1ldmVudF9kYXRhKQogICAgICAgICAgICBpZiBub3QgcmVzcC5vazoKICAgICAgICAgICAgICAgIGNvbnRleHQubG9nZ2VyLmVycm9yKGYiYmFkIGZ1bmN0aW9uIHJlc3AhIVxue3Jlc3AudGV4dH0iKQogICAgICAgICAgICAgICAgZXJyX2NvdW50ICs9IDEKICAgICAgICAgICAgICAgIGNvbnRpbnVlCiAgICAgICAgICAgIHRpbWVzLmFwcGVuZCgoZGF0ZXRpbWUubm93KCkgLSBzdGFydCkubWljcm9zZWNvbmRzKQoKICAgICAgICBleGNlcHQgT1NFcnJvciBhcyBlcnI6CiAgICAgICAgICAgIGNvbnRleHQubG9nZ2VyLmVycm9yKGYiZXJyb3IgaW4gcmVxdWVzdCwgZGF0YTp7ZXZlbnRfZGF0YX0sIGVycm9yOiB7ZXJyfSIpCiAgICAgICAgICAgIGVycl9jb3VudCArPSAxCiAgICAgICAgICAgIGNvbnRpbnVlCgogICAgICAgIHJlc3BfZGF0YSA9IHJlc3AuanNvbigpCiAgICAgICAgcHJpbnQocmVzcF9kYXRhKQogICAgICAgIHlfcmVzcCA9IHJlc3BfZGF0YVsib3V0cHV0cyJdWzBdCiAgICAgICAgaWYgeSA9PSB5X3Jlc3A6CiAgICAgICAgICAgIG1hdGNoICs9IDEKCiAgICBjb250ZXh0LmxvZ19yZXN1bHQoInRvdGFsX3Rlc3RzIiwgY291bnQpCiAgICBjb250ZXh0LmxvZ19yZXN1bHQoImVycm9ycyIsIGVycl9jb3VudCkKICAgIGNvbnRleHQubG9nX3Jlc3VsdCgibWF0Y2giLCBtYXRjaCkKICAgIGlmIGNvdW50IC0gZXJyX2NvdW50ID4gMDoKICAgICAgICB0aW1lc19hcnIgPSBucC5hcnJheSh0aW1lcykKICAgICAgICBjb250ZXh0LmxvZ19yZXN1bHQoImF2Z19sYXRlbmN5IiwgaW50KG5wLm1lYW4odGltZXNfYXJyKSkpCiAgICAgICAgY29udGV4dC5sb2dfcmVzdWx0KCJtaW5fbGF0ZW5jeSIsIGludChucC5hbWluKHRpbWVzX2FycikpKQogICAgICAgIGNvbnRleHQubG9nX3Jlc3VsdCgibWF4X2xhdGVuY3kiLCBpbnQobnAuYW1heCh0aW1lc19hcnIpKSkKCiAgICAgICAgY2hhcnQgPSBDaGFydEFydGlmYWN0KCJsYXRlbmN5IiwgaGVhZGVyPVsiVGVzdCIsICJMYXRlbmN5IChtaWNyb3NlYykiXSkKICAgICAgICBmb3IgaSBpbiByYW5nZShsZW4odGltZXMpKToKICAgICAgICAgICAgY2hhcnQuYWRkX3JvdyhbaSArIDEsIGludCh0aW1lc1tpXSld
KQogICAgICAgIGNvbnRleHQubG9nX2FydGlmYWN0KGNoYXJ0KQoKICAgIGNvbnRleHQubG9nZ2VyLmluZm8oCiAgICAgICAgZiJydW4ge2NvdW50fSB0ZXN0cywge2Vycl9jb3VudH0gZXJyb3JzIGFuZCB7bWF0Y2h9IG1hdGNoIGV4cGVjdGVkIHZhbHVlIgogICAgKQoKICAgIGlmIGVycl9jb3VudDoKICAgICAgICByYWlzZSBWYWx1ZUVycm9yKGYiZmFpbGVkIG9uIHtlcnJfY291bnR9IHRlc3RzIG9mIHtjb3VudH0iKQoKICAgIGlmIG1hdGNoX2VyciBhbmQgbWF0Y2ggIT0gY291bnQ6CiAgICAgICAgcmFpc2UgVmFsdWVFcnJvcihmIm9ubHkge21hdGNofSByZXN1bHRzIG1hdGNoIG91dCBvZiB7Y291bnR9IikK + code_origin: '' + filename: v2_model_tester.py entry_points: model_server_tester: - name: model_server_tester - doc: Test a model server parameters: - name: context - default: '' - name: table type: DataItem doc: csv/parquet table with test data - default: '' - name: addr type: str doc: function address/url - default: '' - name: label_column type: str doc: name of the label column in table @@ -46,13 +40,11 @@ spec: type: int doc: number of rows to use from test set default: 20 - outputs: - - default: '' - lineno: 13 + name: model_server_tester + doc: Test a model server + has_kwargs: false + has_varargs: false + lineno: 26 + command: '' description: test v2 model servers - build: - functionSourceCode: 
IyBHZW5lcmF0ZWQgYnkgbnVjbGlvLmV4cG9ydC5OdWNsaW9FeHBvcnRlcgoKaW1wb3J0IG9zCmltcG9ydCBwYW5kYXMgYXMgcGQKaW1wb3J0IHJlcXVlc3RzCmltcG9ydCBqc29uCmltcG9ydCBudW1weSBhcyBucApmcm9tIGRhdGV0aW1lIGltcG9ydCBkYXRldGltZQpmcm9tIG1scnVuLmRhdGFzdG9yZSBpbXBvcnQgRGF0YUl0ZW0KZnJvbSBtbHJ1bi5hcnRpZmFjdHMgaW1wb3J0IENoYXJ0QXJ0aWZhY3QKCgpkZWYgbW9kZWxfc2VydmVyX3Rlc3RlcigKICAgIGNvbnRleHQsCiAgICB0YWJsZTogRGF0YUl0ZW0sCiAgICBhZGRyOiBzdHIsCiAgICBsYWJlbF9jb2x1bW46IHN0ciA9ICJsYWJlbCIsCiAgICBtb2RlbDogc3RyID0gIiIsCiAgICBtYXRjaF9lcnI6IGJvb2wgPSBGYWxzZSwKICAgIHJvd3M6IGludCA9IDIwLAopOgogICAgIiIiVGVzdCBhIG1vZGVsIHNlcnZlcgoKICAgIDpwYXJhbSB0YWJsZTogICAgICAgICBjc3YvcGFycXVldCB0YWJsZSB3aXRoIHRlc3QgZGF0YQogICAgOnBhcmFtIGFkZHI6ICAgICAgICAgIGZ1bmN0aW9uIGFkZHJlc3MvdXJsCiAgICA6cGFyYW0gbGFiZWxfY29sdW1uOiAgbmFtZSBvZiB0aGUgbGFiZWwgY29sdW1uIGluIHRhYmxlCiAgICA6cGFyYW0gbW9kZWw6ICAgICAgICAgdGVzdGVkIG1vZGVsIG5hbWUKICAgIDpwYXJhbSBtYXRjaF9lcnI6ICAgICByYWlzZSBlcnJvciBvbiB2YWxpZGF0aW9uIChyZXF1aXJlIHByb3BlciB0ZXN0IHNldCkKICAgIDpwYXJhbSByb3dzOiAgICAgICAgICBudW1iZXIgb2Ygcm93cyB0byB1c2UgZnJvbSB0ZXN0IHNldAogICAgIiIiCgogICAgdGFibGUgPSB0YWJsZS5hc19kZigpCgogICAgeV9saXN0ID0gdGFibGUucG9wKGxhYmVsX2NvbHVtbikudmFsdWVzLnRvbGlzdCgpCiAgICBjb250ZXh0LmxvZ2dlci5pbmZvKGYidGVzdGluZyB3aXRoIGRhdGFzZXQgYWdhaW5zdCB7YWRkcn0sIG1vZGVsOiB7bW9kZWx9IikKICAgIGlmIHJvd3MgYW5kIHJvd3MgPCB0YWJsZS5zaGFwZVswXToKICAgICAgICB0YWJsZSA9IHRhYmxlLnNhbXBsZShyb3dzKQoKICAgIGNvdW50ID0gZXJyX2NvdW50ID0gbWF0Y2ggPSAwCiAgICB0aW1lcyA9IFtdCiAgICBmb3IgeCwgeSBpbiB6aXAodGFibGUudmFsdWVzLCB5X2xpc3QpOgogICAgICAgIGNvdW50ICs9IDEKICAgICAgICBldmVudF9kYXRhID0ganNvbi5kdW1wcyh7ImlucHV0cyI6IFt4LnRvbGlzdCgpXX0pCiAgICAgICAgaGFkX2VyciA9IEZhbHNlCiAgICAgICAgdHJ5OgogICAgICAgICAgICBzdGFydCA9IGRhdGV0aW1lLm5vdygpCiAgICAgICAgICAgIHJlc3AgPSByZXF1ZXN0cy5wdXQoZiJ7YWRkcn0vdjIvbW9kZWxzL3ttb2RlbH0vaW5mZXIiLCBqc29uPWV2ZW50X2RhdGEpCiAgICAgICAgICAgIGlmIG5vdCByZXNwLm9rOgogICAgICAgICAgICAgICAgY29udGV4dC5sb2dnZXIuZXJyb3IoZiJiYWQgZnVuY3Rpb24gcmVzcCEhXG57cmVzcC50ZXh0fSIpCiAgICAgICAgICAgICAgICBlcnJfY291bnQgKz0gMQogICAgICAg
ICAgICAgICAgY29udGludWUKICAgICAgICAgICAgdGltZXMuYXBwZW5kKChkYXRldGltZS5ub3coKSAtIHN0YXJ0KS5taWNyb3NlY29uZHMpCgogICAgICAgIGV4Y2VwdCBPU0Vycm9yIGFzIGVycjoKICAgICAgICAgICAgY29udGV4dC5sb2dnZXIuZXJyb3IoZiJlcnJvciBpbiByZXF1ZXN0LCBkYXRhOntldmVudF9kYXRhfSwgZXJyb3I6IHtlcnJ9IikKICAgICAgICAgICAgZXJyX2NvdW50ICs9IDEKICAgICAgICAgICAgY29udGludWUKCiAgICAgICAgcmVzcF9kYXRhID0gcmVzcC5qc29uKCkKICAgICAgICBwcmludChyZXNwX2RhdGEpCiAgICAgICAgeV9yZXNwID0gcmVzcF9kYXRhWyJvdXRwdXRzIl1bMF0KICAgICAgICBpZiB5ID09IHlfcmVzcDoKICAgICAgICAgICAgbWF0Y2ggKz0gMQoKICAgIGNvbnRleHQubG9nX3Jlc3VsdCgidG90YWxfdGVzdHMiLCBjb3VudCkKICAgIGNvbnRleHQubG9nX3Jlc3VsdCgiZXJyb3JzIiwgZXJyX2NvdW50KQogICAgY29udGV4dC5sb2dfcmVzdWx0KCJtYXRjaCIsIG1hdGNoKQogICAgaWYgY291bnQgLSBlcnJfY291bnQgPiAwOgogICAgICAgIHRpbWVzX2FyciA9IG5wLmFycmF5KHRpbWVzKQogICAgICAgIGNvbnRleHQubG9nX3Jlc3VsdCgiYXZnX2xhdGVuY3kiLCBpbnQobnAubWVhbih0aW1lc19hcnIpKSkKICAgICAgICBjb250ZXh0LmxvZ19yZXN1bHQoIm1pbl9sYXRlbmN5IiwgaW50KG5wLmFtaW4odGltZXNfYXJyKSkpCiAgICAgICAgY29udGV4dC5sb2dfcmVzdWx0KCJtYXhfbGF0ZW5jeSIsIGludChucC5hbWF4KHRpbWVzX2FycikpKQoKICAgICAgICBjaGFydCA9IENoYXJ0QXJ0aWZhY3QoImxhdGVuY3kiLCBoZWFkZXI9WyJUZXN0IiwgIkxhdGVuY3kgKG1pY3Jvc2VjKSJdKQogICAgICAgIGZvciBpIGluIHJhbmdlKGxlbih0aW1lcykpOgogICAgICAgICAgICBjaGFydC5hZGRfcm93KFtpICsgMSwgaW50KHRpbWVzW2ldKV0pCiAgICAgICAgY29udGV4dC5sb2dfYXJ0aWZhY3QoY2hhcnQpCgogICAgY29udGV4dC5sb2dnZXIuaW5mbygKICAgICAgICBmInJ1biB7Y291bnR9IHRlc3RzLCB7ZXJyX2NvdW50fSBlcnJvcnMgYW5kIHttYXRjaH0gbWF0Y2ggZXhwZWN0ZWQgdmFsdWUiCiAgICApCgogICAgaWYgZXJyX2NvdW50OgogICAgICAgIHJhaXNlIFZhbHVlRXJyb3IoZiJmYWlsZWQgb24ge2Vycl9jb3VudH0gdGVzdHMgb2Yge2NvdW50fSIpCgogICAgaWYgbWF0Y2hfZXJyIGFuZCBtYXRjaCAhPSBjb3VudDoKICAgICAgICByYWlzZSBWYWx1ZUVycm9yKGYib25seSB7bWF0Y2h9IHJlc3VsdHMgbWF0Y2ggb3V0IG9mIHtjb3VudH0iKQo= - commands: [] - code_origin: https://github.com/daniels290813/functions.git#55a79c32be5d233cc11efcf40cd3edbe309bfdef:/home/kali/functions/v2_model_tester/v2_model_tester.py - affinity: null -verbose: false + default_handler: model_server_tester diff --git 
a/functions/src/v2_model_tester/v2_model_tester.py b/functions/src/v2_model_tester/v2_model_tester.py index 74590acdc..3d41ad37b 100644 --- a/functions/src/v2_model_tester/v2_model_tester.py +++ b/functions/src/v2_model_tester/v2_model_tester.py @@ -14,14 +14,13 @@ # # Generated by nuclio.export.NuclioExporter -import os -import pandas as pd -import requests import json -import numpy as np from datetime import datetime -from mlrun.datastore import DataItem + +import numpy as np +import requests from mlrun.artifacts import ChartArtifact +from mlrun.datastore import DataItem def model_server_tester( diff --git a/modules/src/agent_deployer/agent_deployer.py b/modules/src/agent_deployer/agent_deployer.py index 9af0dd632..9a4ab415a 100644 --- a/modules/src/agent_deployer/agent_deployer.py +++ b/modules/src/agent_deployer/agent_deployer.py @@ -12,18 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional import os import mlrun.errors -from mlrun import get_current_project, code_to_function, mlconf -from mlrun.runtimes import ServingRuntime -from mlrun.serving import ModelRunnerStep +from mlrun import code_to_function, get_current_project, mlconf from mlrun.datastore.datastore_profile import ( - DatastoreProfileV3io, DatastoreProfileKafkaStream, DatastoreProfileTDEngine, + DatastoreProfileV3io, ) +from mlrun.runtimes import ServingRuntime +from mlrun.serving import ModelRunnerStep from mlrun.utils import logger @@ -33,10 +32,10 @@ def __init__( agent_name: str, model_class_name: str, function: str, - result_path: Optional[str] = None, - inputs_path: Optional[str] = None, - outputs: Optional[list[str]] = None, - requirements: Optional[list[str]] = None, + result_path: str | None = None, + inputs_path: str | None = None, + outputs: list[str] | None = None, + requirements: list[str] | None = None, image: str = "mlrun/mlrun", set_model_monitoring: bool = False, **model_params, diff --git 
a/modules/src/agent_deployer/test_agent_deployer.py b/modules/src/agent_deployer/test_agent_deployer.py index 0bb3adc8b..7700bfeea 100644 --- a/modules/src/agent_deployer/test_agent_deployer.py +++ b/modules/src/agent_deployer/test_agent_deployer.py @@ -14,14 +14,13 @@ # import unittest -from unittest.mock import patch, MagicMock -from agent_deployer import AgentDeployer -import mlrun.errors +from unittest.mock import MagicMock, patch +import mlrun.errors +from agent_deployer import AgentDeployer class TestAgentDeployer(unittest.TestCase): - def setUp(self): # Common parameters for a minimal AgentDeployer instance self.deployer_params = { @@ -33,7 +32,9 @@ def setUp(self): # --- Test Cases for Properties --- - @patch('agent_deployer.get_current_project') # Patch the import in the *module* you are testing + @patch( + "agent_deployer.get_current_project" + ) # Patch the import in the *module* you are testing def test_project_property_returns_project(self, mock_get_current_project): """Test that the project property returns the project if it exists.""" mock_proj = MagicMock() @@ -42,13 +43,13 @@ def test_project_property_returns_project(self, mock_get_current_project): self.assertEqual(self.deployer.project, mock_proj) mock_get_current_project.assert_called_once_with(silent=True) - @patch('agent_deployer.get_current_project', return_value=None) + @patch("agent_deployer.get_current_project", return_value=None) def test_project_name_raises_error_if_no_project(self, mock_get_current_project): """Test that project_name raises an error when no project is found.""" with self.assertRaises(mlrun.errors.MLRunInvalidArgumentError): _ = self.deployer.project_name - @patch('agent_deployer.get_current_project') + @patch("agent_deployer.get_current_project") def test_project_name_returns_name(self, mock_get_current_project): """Test that project_name correctly retrieves the name from the project metadata.""" mock_proj = MagicMock() @@ -57,15 +58,18 @@ def 
test_project_name_returns_name(self, mock_get_current_project): self.assertEqual(self.deployer.project_name, "test-project-name") - - @patch('agent_deployer.AgentDeployer.project', new_callable=unittest.mock.PropertyMock) + @patch( + "agent_deployer.AgentDeployer.project", new_callable=unittest.mock.PropertyMock + ) def test_configure_model_monitoring_handles_conflict_error(self, mock_project_prop): """Test that the method handles expected exceptions during enable_model_monitoring.""" mock_project = MagicMock() # Simulate an expected error that should be caught and passed over - mock_project.enable_model_monitoring.side_effect = mlrun.errors.MLRunConflictError("Already deployed") + mock_project.enable_model_monitoring.side_effect = ( + mlrun.errors.MLRunConflictError("Already deployed") + ) mock_project_prop.return_value = mock_project # This should run without raising an uncaught exception self.deployer.configure_model_monitoring() - mock_project.enable_model_monitoring.assert_called_once() \ No newline at end of file + mock_project.enable_model_monitoring.assert_called_once() diff --git a/modules/src/count_events/count_events.py b/modules/src/count_events/count_events.py index 1c6d97621..4f04366ac 100644 --- a/modules/src/count_events/count_events.py +++ b/modules/src/count_events/count_events.py @@ -13,21 +13,22 @@ # limitations under the License. # +import mlrun.model_monitoring.applications.context as mm_context from mlrun.model_monitoring.applications import ( - ModelMonitoringApplicationBase, ModelMonitoringApplicationMetric, + ModelMonitoringApplicationBase, + ModelMonitoringApplicationMetric, ) -import mlrun.model_monitoring.applications.context as mm_context class CountApp(ModelMonitoringApplicationBase): """ Model Monitoring Application that counts the number of events in the given time window. 
""" + def do_tracking( - self, - monitoring_context: mm_context.MonitoringApplicationContext + self, monitoring_context: mm_context.MonitoringApplicationContext ) -> ModelMonitoringApplicationMetric: - """" + """ " he do_tracking method implementation for the CountApp class. It counts the number of events in the sample data-frame and logs the count. @@ -47,4 +48,4 @@ def do_tracking( return ModelMonitoringApplicationMetric( name="count", value=count, - ) \ No newline at end of file + ) diff --git a/modules/src/count_events/item.yaml b/modules/src/count_events/item.yaml index 049651ddb..723ebc4a9 100644 --- a/modules/src/count_events/item.yaml +++ b/modules/src/count_events/item.yaml @@ -7,7 +7,7 @@ generationDate: 2025-09-16:12-25 hidden: false labels: author: Iguazio -mlrunVersion: 1.10.0-rc41 +mlrunVersion: 1.10.0 name: count_events spec: filename: count_events.py diff --git a/modules/src/count_events/test_count_events.py b/modules/src/count_events/test_count_events.py index 66a94c932..fc3e76a4e 100644 --- a/modules/src/count_events/test_count_events.py +++ b/modules/src/count_events/test_count_events.py @@ -14,15 +14,15 @@ # -from mlrun.model_monitoring.applications import ModelMonitoringApplicationMetric -import mlrun.model_monitoring.applications.context as mm_context - -from count_events import CountApp - -from unittest.mock import Mock from datetime import datetime +from unittest.mock import Mock + +import mlrun.model_monitoring.applications.context as mm_context import pandas as pd import pytest +from count_events import CountApp +from mlrun.model_monitoring.applications import ModelMonitoringApplicationMetric + class TestCountApp: """Test suite for CountApp class.""" @@ -30,6 +30,7 @@ class TestCountApp: def setup_method(self): """Set up test fixtures before each test method.""" self.count_app = CountApp() + @staticmethod def _create_mock_monitoring_context(sample_df, model_endpoint_name="test-model"): """Helper method to create a mock monitoring 
context.""" @@ -53,7 +54,6 @@ def _create_mock_monitoring_context(sample_df, model_endpoint_name="test-model") return mock_context - @pytest.mark.parametrize("df_size", [0, 1, 10, 100, 1000]) def test_do_tracking_with_various_dataframe_sizes(self, df_size): """Test do_tracking with various dataframe sizes using parametrized test.""" @@ -72,4 +72,3 @@ def test_do_tracking_with_various_dataframe_sizes(self, df_size): assert isinstance(result, ModelMonitoringApplicationMetric) assert result.value == df_size assert result.name == "count" - diff --git a/modules/src/evidently_iris/evidently_iris.py b/modules/src/evidently_iris/evidently_iris.py index e7a9f3ef9..375c1d3f8 100644 --- a/modules/src/evidently_iris/evidently_iris.py +++ b/modules/src/evidently_iris/evidently_iris.py @@ -14,18 +14,8 @@ from typing import Optional -import pandas as pd -from sklearn.datasets import load_iris - import mlrun.model_monitoring.applications.context as mm_context -from mlrun.common.schemas.model_monitoring.constants import ( - ResultKindApp, - ResultStatusApp, -) -from mlrun.feature_store.api import norm_column_name -from mlrun.model_monitoring.applications import ModelMonitoringApplicationResult -from mlrun.model_monitoring.applications.evidently import EvidentlyModelMonitoringApplicationBase - +import pandas as pd from evidently.core.report import Report, Snapshot from evidently.metrics import DatasetMissingValueCount, ValueDrift from evidently.presets import DataDriftPreset, DataSummaryPreset @@ -33,6 +23,16 @@ STR_UUID, OrgID, ) +from mlrun.common.schemas.model_monitoring.constants import ( + ResultKindApp, + ResultStatusApp, +) +from mlrun.feature_store.api import norm_column_name +from mlrun.model_monitoring.applications import ModelMonitoringApplicationResult +from mlrun.model_monitoring.applications.evidently import ( + EvidentlyModelMonitoringApplicationBase, +) +from sklearn.datasets import load_iris _PROJECT_NAME = "Iris Monitoring" _PROJECT_DESCRIPTION = "Test project 
using iris dataset" @@ -43,12 +43,13 @@ class EvidentlyIrisMonitoringApp(EvidentlyModelMonitoringApplicationBase): This model monitoring application is a simple example of integrating MLRun with Evidently for data monitoring, which you can adapt to fit your own project needs or use as a reference implementation. """ + NAME = "Evidently-App-Test" def __init__( self, evidently_project_id: Optional["STR_UUID"] = None, - evidently_workspace_path: Optional[str] = None, + evidently_workspace_path: str | None = None, cloud_workspace: bool = False, evidently_organization_id: Optional["OrgID"] = None, ) -> None: diff --git a/modules/src/evidently_iris/item.yaml b/modules/src/evidently_iris/item.yaml index 42c5c10cb..f8aa203fa 100644 --- a/modules/src/evidently_iris/item.yaml +++ b/modules/src/evidently_iris/item.yaml @@ -8,7 +8,7 @@ generationDate: 2025-11-09:12-25 hidden: false labels: author: Iguazio -mlrunVersion: 1.10.0-rc41 +mlrunVersion: 1.10.0 name: evidently_iris spec: filename: evidently_iris.py diff --git a/modules/src/evidently_iris/test_evidently_iris.py b/modules/src/evidently_iris/test_evidently_iris.py index 6488768fd..a9d12d75a 100644 --- a/modules/src/evidently_iris/test_evidently_iris.py +++ b/modules/src/evidently_iris/test_evidently_iris.py @@ -20,14 +20,12 @@ import pytest import semver - +from evidently_iris import EvidentlyIrisMonitoringApp from mlrun.errors import MLRunIncompatibleVersionError from mlrun.model_monitoring.applications.evidently.base import ( _check_evidently_version, ) -from evidently_iris import EvidentlyIrisMonitoringApp - @pytest.mark.parametrize( ("cur", "ref", "expectation"), diff --git a/modules/src/histogram_data_drift/histogram_data_drift.py b/modules/src/histogram_data_drift/histogram_data_drift.py index b8cdcf299..59df3df06 100644 --- a/modules/src/histogram_data_drift/histogram_data_drift.py +++ b/modules/src/histogram_data_drift/histogram_data_drift.py @@ -13,16 +13,14 @@ # limitations under the License. 
from dataclasses import dataclass -from typing import Final, Optional, Protocol, Union, cast - -import numpy as np -from pandas import DataFrame, Series +from typing import Final, Protocol, cast import mlrun.artifacts import mlrun.common.model_monitoring.helpers import mlrun.model_monitoring.applications.context as mm_context import mlrun.model_monitoring.applications.results as mm_results import mlrun.model_monitoring.features_drift_table as mm_drift_table +import numpy as np from mlrun.common.schemas.model_monitoring.constants import ( ResultKindApp, ResultStatusApp, @@ -37,6 +35,7 @@ KullbackLeiblerDivergence, TotalVarianceDistance, ) +from pandas import DataFrame, Series class InvalidMetricValueError(ValueError): @@ -134,7 +133,7 @@ class HistogramDataDriftApplication(ModelMonitoringApplicationBase): def __init__( self, - value_classifier: Optional[ValueClassifier] = None, + value_classifier: ValueClassifier | None = None, produce_json_artifact: bool = False, produce_plotly_artifact: bool = False, ) -> None: @@ -145,9 +144,9 @@ def __init__( :param produce_plotly_artifact: Whether to produce the Plotly artifact or not, ``False`` by default. 
""" self._value_classifier = value_classifier or DataDriftClassifier() - assert self._REQUIRED_METRICS <= set( - self.metrics - ), "TVD and Hellinger distance are required for the general data drift result" + assert self._REQUIRED_METRICS <= set(self.metrics), ( + "TVD and Hellinger distance are required for the general data drift result" + ) self._produce_json_artifact = produce_json_artifact self._produce_plotly_artifact = produce_plotly_artifact @@ -349,11 +348,9 @@ def _log_drift_artifacts( def do_tracking( self, monitoring_context: mm_context.MonitoringApplicationContext ) -> list[ - Union[ - mm_results.ModelMonitoringApplicationResult, - mm_results.ModelMonitoringApplicationMetric, - mm_results._ModelMonitoringApplicationStats, - ] + mm_results.ModelMonitoringApplicationResult + | mm_results.ModelMonitoringApplicationMetric + | mm_results._ModelMonitoringApplicationStats ]: """ Calculate and return the data drift metrics, averaged over the features. diff --git a/modules/src/histogram_data_drift/item.yaml b/modules/src/histogram_data_drift/item.yaml index f516ae071..83d0f0c99 100644 --- a/modules/src/histogram_data_drift/item.yaml +++ b/modules/src/histogram_data_drift/item.yaml @@ -8,7 +8,7 @@ generationDate: 2025-11-06:12-25 hidden: false labels: author: Iguazio -mlrunVersion: 1.10.0-rc41 +mlrunVersion: 1.10.0 name: histogram_data_drift spec: filename: histogram_data_drift.py diff --git a/modules/src/histogram_data_drift/test_histogram_data_drift.py b/modules/src/histogram_data_drift/test_histogram_data_drift.py index 018edaa86..c731e2c9b 100644 --- a/modules/src/histogram_data_drift/test_histogram_data_drift.py +++ b/modules/src/histogram_data_drift/test_histogram_data_drift.py @@ -16,25 +16,24 @@ from pathlib import Path from unittest.mock import Mock -import pandas as pd -import pytest -from hypothesis import given -from hypothesis import strategies as st - import mlrun.common.model_monitoring.helpers import mlrun.model_monitoring.applications import 
mlrun.model_monitoring.applications.context as mm_context import mlrun.utils -from mlrun.common.schemas.model_monitoring.constants import ( - ResultKindApp, - ResultStatusApp, -) +import pandas as pd +import pytest from histogram_data_drift import ( DataDriftClassifier, HistogramDataDriftApplication, InvalidMetricValueError, InvalidThresholdValueError, ) +from hypothesis import given +from hypothesis import strategies as st +from mlrun.common.schemas.model_monitoring.constants import ( + ResultKindApp, + ResultStatusApp, +) assets_folder = Path(__file__).parent / "assets" @@ -99,9 +98,9 @@ def classifier() -> DataDriftClassifier: def test_status( classifier: DataDriftClassifier, value: float, expected_status: ResultStatusApp ) -> None: - assert ( - classifier.value_to_status(value) == expected_status - ), "The status is different than expected" + assert classifier.value_to_status(value) == expected_status, ( + "The status is different than expected" + ) class TestApplication: @@ -205,15 +204,15 @@ def test( res, mlrun.model_monitoring.applications.ModelMonitoringApplicationResult, ): - assert ( - res.kind == ResultKindApp.data_drift - ), "The kind should be data drift" - assert ( - res.name == "general_drift" - ), "The result name should be general_drift" - assert ( - res.status == ResultStatusApp.potential_detection - ), "Expected potential detection in the general drift" + assert res.kind == ResultKindApp.data_drift, ( + "The kind should be data drift" + ) + assert res.name == "general_drift", ( + "The result name should be general_drift" + ) + assert res.status == ResultStatusApp.potential_detection, ( + "Expected potential detection in the general drift" + ) elif isinstance( res, mlrun.model_monitoring.applications.ModelMonitoringApplicationMetric, @@ -274,6 +273,6 @@ def test_compute_metrics_per_feature( assert set(metrics_per_feature.columns) == { metric.NAME for metric in application.metrics }, "Different metrics than expected" - assert 
set(metrics_per_feature.index) == set( - feature_stats.columns - ), "The features are different than expected" + assert set(metrics_per_feature.index) == set(feature_stats.columns), ( + "The features are different than expected" + ) diff --git a/modules/src/openai_proxy_app/openai_proxy_app.py b/modules/src/openai_proxy_app/openai_proxy_app.py index 65bfbf7c9..9132d9bff 100644 --- a/modules/src/openai_proxy_app/openai_proxy_app.py +++ b/modules/src/openai_proxy_app/openai_proxy_app.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # -#This module acts as a lightweight gateway to OpenAI-compatible APIs. -#You can send chat prompts, create embeddings, or get model responses without worrying about authentication or endpoint differences. -#It simplifies access so you can test, analyze, or integrate AI features directly into your projects or notebooks with minimal setup. +# This module acts as a lightweight gateway to OpenAI-compatible APIs. +# You can send chat prompts, create embeddings, or get model responses without worrying about authentication or endpoint differences. +# It simplifies access so you can test, analyze, or integrate AI features directly into your projects or notebooks with minimal setup. 
BASE64 = "IyBvcGVuYWlfcHJveHkvb3BlbmFpLnB5CgppbXBvcnQgb3MKaW1wb3J0IGpzb24KZnJvbSB1cmxsaWIucGFyc2UgaW1wb3J0IHVybGpvaW4KZnJvbSB0eXBpbmcgaW1wb3J0IEFueSwgRGljdCwgTGlzdCwgT3B0aW9uYWwKCmltcG9ydCByZXF1ZXN0cwpmcm9tIGZhc3RhcGkgaW1wb3J0IEZhc3RBUEksIFJlcXVlc3QsIFJlc3BvbnNlLCBCb2R5CgphcHAgPSBGYXN0QVBJKAogICAgdGl0bGU9Ik9wZW5BSSBQcm94eSBBcHAiLAogICAgZGVzY3JpcHRpb249IkxvY2FsIEZhc3RBUEkgcHJveHkgZm9yIE9wZW5BSSBzdHlsZSBlbmRwb2ludHMiLAogICAgdmVyc2lvbj0iMS4wLjAiLAopCgpPUEVOQUlfQkFTRV9VUkwgPSBvcy5nZXRlbnYoIk9QRU5BSV9CQVNFX1VSTCIsICJodHRwczovL2FwaS5vcGVuYWkuY29tIikucnN0cmlwKCIvIikKT1BFTkFJX0FQSV9LRVkgPSBvcy5nZXRlbnYoIk9QRU5BSV9BUElfS0VZIiwgIiIpCk9QRU5BSV9ERUZBVUxUX01PREVMID0gb3MuZ2V0ZW52KCJPUEVOQUlfREVGQVVMVF9NT0RFTCIsICJncHQtNG8tbWluaSIpCgoKZGVmIGJ1aWxkX2hlYWRlcnMoaW5jb21pbmc6IGRpY3QpIC0+IGRpY3Q6CiAgICBoZWFkZXJzID0ge30KICAgIGF1dGggPSBpbmNvbWluZy5nZXQoImF1dGhvcml6YXRpb24iKSBvciBpbmNvbWluZy5nZXQoIkF1dGhvcml6YXRpb24iKQogICAgaWYgYXV0aDoKICAgICAgICBoZWFkZXJzWyJBdXRob3JpemF0aW9uIl0gPSBhdXRoCiAgICBlbGlmIE9QRU5BSV9BUElfS0VZOgogICAgICAgIGhlYWRlcnNbIkF1dGhvcml6YXRpb24iXSA9IGYiQmVhcmVyIHtPUEVOQUlfQVBJX0tFWX0iCiAgICBjdHlwZSA9IGluY29taW5nLmdldCgiY29udGVudC10eXBlIikgb3IgaW5jb21pbmcuZ2V0KCJDb250ZW50LVR5cGUiKSBvciAiYXBwbGljYXRpb24vanNvbiIKICAgIGhlYWRlcnNbIkNvbnRlbnQtVHlwZSJdID0gY3R5cGUKICAgIHJldHVybiBoZWFkZXJzCgoKZGVmIGJ1aWxkX3RhcmdldChwYXRoOiBzdHIpIC0+IHN0cjoKICAgIGJhc2UgPSBPUEVOQUlfQkFTRV9VUkwKICAgIGlmIGJhc2UuZW5kc3dpdGgoIi92MSIpIG9yIGJhc2UuZW5kc3dpdGgoIi92MS8iKToKICAgICAgICBiYXNlID0gYmFzZVs6LTNdIGlmIGJhc2UuZW5kc3dpdGgoIi92MSIpIGVsc2UgYmFzZVs6LTRdCiAgICByZXR1cm4gdXJsam9pbihiYXNlICsgIi8iLCBwYXRoLmxzdHJpcCgiLyIpKQoKCmRlZiBmb3J3YXJkX2pzb24ocGF0aDogc3RyLCBib2R5OiBkaWN0LCBoZWFkZXJzOiBkaWN0LCBxdWVyeTogZGljdCk6CiAgICB0YXJnZXQgPSBidWlsZF90YXJnZXQocGF0aCkKICAgIHJlc3AgPSByZXF1ZXN0cy5wb3N0KAogICAgICAgIHRhcmdldCwKICAgICAgICBoZWFkZXJzPWhlYWRlcnMsCiAgICAgICAgcGFyYW1zPXF1ZXJ5LAogICAgICAgIGpzb249Ym9keSwKICAgICAgICB0aW1lb3V0PTYwLAogICAgKQogICAgcmV0dXJuIHJlc3AKCkBhcHAuZ2V0KCIvIikKZGVmIGhlYWx0aCgpOgogICAgcmV0dXJuIHsic3
RhdHVzIjogIm9rIn0KCgojIHJlbGF4ZWQgY2hhdCBlbmRwb2ludCwgYWNjZXB0cyBhbnkgSlNPTiB0aGF0IGluY2x1ZGVzIG1lc3NhZ2VzCkBhcHAucG9zdCgiL3YxL2NoYXQvY29tcGxldGlvbnMiKQphc3luYyBkZWYgY2hhdF9jb21wbGV0aW9ucygKICAgIHJlcXVlc3Q6IFJlcXVlc3QsCiAgICBwYXlsb2FkOiBEaWN0W3N0ciwgQW55XSA9IEJvZHkoLi4uKSwKKToKICAgIGlmICJtZXNzYWdlcyIgbm90IGluIHBheWxvYWQgb3Igbm90IGlzaW5zdGFuY2UocGF5bG9hZFsibWVzc2FnZXMiXSwgbGlzdCk6CiAgICAgICAgcmV0dXJuIFJlc3BvbnNlKAogICAgICAgICAgICBjb250ZW50PWpzb24uZHVtcHMoeyJlcnJvciI6ICJtZXNzYWdlcyBtdXN0IGJlIGEgbGlzdCBvZiBjaGF0IG1lc3NhZ2VzIn0pLAogICAgICAgICAgICBzdGF0dXNfY29kZT00MDAsCiAgICAgICAgICAgIG1lZGlhX3R5cGU9ImFwcGxpY2F0aW9uL2pzb24iLAogICAgICAgICkKCiAgICBpZiAibW9kZWwiIG5vdCBpbiBwYXlsb2FkIG9yIHBheWxvYWRbIm1vZGVsIl0gaXMgTm9uZToKICAgICAgICBwYXlsb2FkWyJtb2RlbCJdID0gT1BFTkFJX0RFRkFVTFRfTU9ERUwKCiAgICBoZWFkZXJzID0gYnVpbGRfaGVhZGVycyhkaWN0KHJlcXVlc3QuaGVhZGVycykpCiAgICByZXNwID0gZm9yd2FyZF9qc29uKCIvdjEvY2hhdC9jb21wbGV0aW9ucyIsIHBheWxvYWQsIGhlYWRlcnMsIGRpY3QocmVxdWVzdC5xdWVyeV9wYXJhbXMpKQogICAgcmV0dXJuIFJlc3BvbnNlKAogICAgICAgIGNvbnRlbnQ9cmVzcC5jb250ZW50LAogICAgICAgIHN0YXR1c19jb2RlPXJlc3Auc3RhdHVzX2NvZGUsCiAgICAgICAgbWVkaWFfdHlwZT1yZXNwLmhlYWRlcnMuZ2V0KCJDb250ZW50LVR5cGUiLCAiYXBwbGljYXRpb24vanNvbiIpLAogICAgKQoKCkBhcHAucG9zdCgiL3YxL2VtYmVkZGluZ3MiKQphc3luYyBkZWYgZW1iZWRkaW5ncygKICAgIHJlcXVlc3Q6IFJlcXVlc3QsCiAgICBwYXlsb2FkOiBEaWN0W3N0ciwgQW55XSA9IEJvZHkoLi4uKSwKKToKICAgIGlmICJtb2RlbCIgbm90IGluIHBheWxvYWQgb3Igbm90IHBheWxvYWRbIm1vZGVsIl06CiAgICAgICAgcGF5bG9hZFsibW9kZWwiXSA9ICJ0ZXh0LWVtYmVkZGluZy0zLXNtYWxsIgogICAgaGVhZGVycyA9IGJ1aWxkX2hlYWRlcnMoZGljdChyZXF1ZXN0LmhlYWRlcnMpKQogICAgcmVzcCA9IGZvcndhcmRfanNvbigiL3YxL2VtYmVkZGluZ3MiLCBwYXlsb2FkLCBoZWFkZXJzLCBkaWN0KHJlcXVlc3QucXVlcnlfcGFyYW1zKSkKICAgIHJldHVybiBSZXNwb25zZSgKICAgICAgICBjb250ZW50PXJlc3AuY29udGVudCwKICAgICAgICBzdGF0dXNfY29kZT1yZXNwLnN0YXR1c19jb2RlLAogICAgICAgIG1lZGlhX3R5cGU9cmVzcC5oZWFkZXJzLmdldCgiQ29udGVudC1UeXBlIiwgImFwcGxpY2F0aW9uL2pzb24iKSwKICAgICkKCgpAYXBwLnBvc3QoIi92MS9yZXNwb25zZXMiKQphc3luYyBkZWYgcmVzcG9uc2VzX2FwaSgKICAgIHJlcXVlc3
Q6IFJlcXVlc3QsCiAgICBwYXlsb2FkOiBEaWN0W3N0ciwgQW55XSA9IEJvZHkoLi4uKSwKKToKICAgIGlmICJtb2RlbCIgbm90IGluIHBheWxvYWQgb3IgcGF5bG9hZFsibW9kZWwiXSBpcyBOb25lOgogICAgICAgIHBheWxvYWRbIm1vZGVsIl0gPSBPUEVOQUlfREVGQVVMVF9NT0RFTAogICAgaGVhZGVycyA9IGJ1aWxkX2hlYWRlcnMoZGljdChyZXF1ZXN0LmhlYWRlcnMpKQogICAgcmVzcCA9IGZvcndhcmRfanNvbigiL3YxL3Jlc3BvbnNlcyIsIHBheWxvYWQsIGhlYWRlcnMsIGRpY3QocmVxdWVzdC5xdWVyeV9wYXJhbXMpKQogICAgcmV0dXJuIFJlc3BvbnNlKAogICAgICAgIGNvbnRlbnQ9cmVzcC5jb250ZW50LAogICAgICAgIHN0YXR1c19jb2RlPXJlc3Auc3RhdHVzX2NvZGUsCiAgICAgICAgbWVkaWFfdHlwZT1yZXNwLmhlYWRlcnMuZ2V0KCJDb250ZW50LVR5cGUiLCAiYXBwbGljYXRpb24vanNvbiIpLAogICAgKQoKCiMgLS0tLS0tLS0tLS0tLS0tLSBjbGllbnQgLS0tLS0tLS0tLS0tLS0tLQpjbGFzcyBPcGVuQUlQcm94eUNsaWVudDoKICAgICIiIgogICAgU2ltcGxlIGNsaWVudCBmb3IgdGhlIGxvY2FsIHByb3h5LgogICAgRGVmYXVsdCBiYXNlIHVybCBpcyBodHRwOi8vbG9jYWxob3N0OjgwMDAKICAgIElmIGFwaV9rZXkgaXMgbm90IHByb3ZpZGVkLCBpdCB1c2VzIE9QRU5BSV9BUElfS0VZIGZyb20gZW52aXJvbm1lbnQuCiAgICAiIiIKCiAgICBkZWYgX19pbml0X18oc2VsZiwgYmFzZV91cmw6IHN0ciA9ICJodHRwOi8vbG9jYWxob3N0OjgwMDAiLCBhcGlfa2V5OiBPcHRpb25hbFtzdHJdID0gTm9uZSk6CiAgICAgICAgc2VsZi5iYXNlX3VybCA9IGJhc2VfdXJsLnJzdHJpcCgiLyIpCiAgICAgICAgc2VsZi5hcGlfa2V5ID0gYXBpX2tleQoKICAgIGRlZiBfaGVhZGVycyhzZWxmKSAtPiBEaWN0W3N0ciwgc3RyXToKICAgICAgICBoZWFkZXJzID0geyJDb250ZW50LVR5cGUiOiAiYXBwbGljYXRpb24vanNvbiJ9CiAgICAgICAga2V5ID0gc2VsZi5hcGlfa2V5IG9yIG9zLmdldGVudigiT1BFTkFJX0FQSV9LRVkiLCAiIikKICAgICAgICBpZiBrZXk6CiAgICAgICAgICAgIGhlYWRlcnNbIkF1dGhvcml6YXRpb24iXSA9IGYiQmVhcmVyIHtrZXl9IgogICAgICAgIHJldHVybiBoZWFkZXJzCgogICAgZGVmIGNoYXQoc2VsZiwgbWVzc2FnZXM6IExpc3RbRGljdFtzdHIsIHN0cl1dLCBtb2RlbDogT3B0aW9uYWxbc3RyXSA9IE5vbmUpIC0+IERpY3Rbc3RyLCBBbnldOgogICAgICAgIGJvZHk6IERpY3Rbc3RyLCBBbnldID0geyJtZXNzYWdlcyI6IG1lc3NhZ2VzfQogICAgICAgIGlmIG1vZGVsOgogICAgICAgICAgICBib2R5WyJtb2RlbCJdID0gbW9kZWwKICAgICAgICByZXNwID0gcmVxdWVzdHMucG9zdCgKICAgICAgICAgICAgZiJ7c2VsZi5iYXNlX3VybH0vdjEvY2hhdC9jb21wbGV0aW9ucyIsCiAgICAgICAgICAgIGhlYWRlcnM9c2VsZi5faGVhZGVycygpLAogICAgICAgICAgICBqc29uPWJvZHksCiAgICAgICAgICAgIH
RpbWVvdXQ9NjAsCiAgICAgICAgKQogICAgICAgIHJlc3AucmFpc2VfZm9yX3N0YXR1cygpCiAgICAgICAgcmV0dXJuIHJlc3AuanNvbigpCgogICAgZGVmIGVtYmVkZGluZ3Moc2VsZiwgdGV4dDogQW55LCBtb2RlbDogT3B0aW9uYWxbc3RyXSA9IE5vbmUpIC0+IERpY3Rbc3RyLCBBbnldOgogICAgICAgIGJvZHk6IERpY3Rbc3RyLCBBbnldID0geyJpbnB1dCI6IHRleHR9CiAgICAgICAgaWYgbW9kZWw6CiAgICAgICAgICAgIGJvZHlbIm1vZGVsIl0gPSBtb2RlbAogICAgICAgIHJlc3AgPSByZXF1ZXN0cy5wb3N0KAogICAgICAgICAgICBmIntzZWxmLmJhc2VfdXJsfS92MS9lbWJlZGRpbmdzIiwKICAgICAgICAgICAgaGVhZGVycz1zZWxmLl9oZWFkZXJzKCksCiAgICAgICAgICAgIGpzb249Ym9keSwKICAgICAgICAgICAgdGltZW91dD02MCwKICAgICAgICApCiAgICAgICAgcmVzcC5yYWlzZV9mb3Jfc3RhdHVzKCkKICAgICAgICByZXR1cm4gcmVzcC5qc29uKCkKCiAgICBkZWYgcmVzcG9uc2VzKHNlbGYsIGlucHV0X3RleHQ6IEFueSwgbW9kZWw6IE9wdGlvbmFsW3N0cl0gPSBOb25lKSAtPiBEaWN0W3N0ciwgQW55XToKICAgICAgICBib2R5OiBEaWN0W3N0ciwgQW55XSA9IHsiaW5wdXQiOiBpbnB1dF90ZXh0fQogICAgICAgIGlmIG1vZGVsOgogICAgICAgICAgICBib2R5WyJtb2RlbCJdID0gbW9kZWwKICAgICAgICByZXNwID0gcmVxdWVzdHMucG9zdCgKICAgICAgICAgICAgZiJ7c2VsZi5iYXNlX3VybH0vdjEvcmVzcG9uc2VzIiwKICAgICAgICAgICAgaGVhZGVycz1zZWxmLl9oZWFkZXJzKCksCiAgICAgICAgICAgIGpzb249Ym9keSwKICAgICAgICAgICAgdGltZW91dD02MCwKICAgICAgICApCiAgICAgICAgcmVzcC5yYWlzZV9mb3Jfc3RhdHVzKCkKICAgICAgICByZXR1cm4gcmVzcC5qc29uKCkKCgojIG9wdGlvbmFsIHF1aWNrIHNlbGYgdGVzdCB3aGVuIHJ1bm5pbmcgdGhpcyBmaWxlIGRpcmVjdGx5CmlmIF9fbmFtZV9fID09ICJfX21haW5fXyI6CiAgICAjIHN0YXJ0IHRoZSBzZXJ2ZXIgaW4gYW5vdGhlciB0ZXJtaW5hbCBmaXJzdDoKICAgICMgdXZpY29ybiBvcGVuYWlfcHJveHkub3BlbmFpOmFwcCAtLWhvc3QgMC4wLjAuMCAtLXBvcnQgODAwMCAtLXJlbG9hZAogICAgYyA9IE9wZW5BSVByb3h5Q2xpZW50KCkKICAgIHRyeToKICAgICAgICBwcmludCgiSGVhbHRoOiIsIHJlcXVlc3RzLmdldChmIntjLmJhc2VfdXJsfS8iKS5qc29uKCkpCiAgICBleGNlcHQgRXhjZXB0aW9uIGFzIGU6CiAgICAgICAgcHJpbnQoIlNlcnZlciBub3QgcnVubmluZzoiLCBlKQo=" -CMD = r''' +CMD = r""" set -e python - <<'PY' import os, base64, pathlib @@ -34,23 +34,24 @@ --bind 0.0.0.0:8000 \ --worker-class uvicorn.workers.UvicornWorker \ --log-level info -'''.strip() +""".strip() + + class OpenAIModule: - def __init__(self,project): + def 
__init__(self, project): self.project = project - self.openai_proxy_app = self.project.set_function(name="openai",kind="application",image="python:3.11") - self.openai_proxy_app.with_requirements([ + self.openai_proxy_app = self.project.set_function( + name="openai", kind="application", image="python:3.11" + ) + self.openai_proxy_app.with_requirements( + [ "fastapi==0.124.0", "uvicorn[standard]==0.38.0", "gunicorn==23.0.0", "requests=2.32.5", - ]) - self.openai_proxy_app.set_env("BASE64",BASE64) + ] + ) + self.openai_proxy_app.set_env("BASE64", BASE64) self.openai_proxy_app.set_internal_application_port(8000) self.openai_proxy_app.spec.command = "/bin/sh" self.openai_proxy_app.spec.args = ["-c", CMD] - - - - - diff --git a/modules/src/openai_proxy_app/test_openai_proxy_app.py b/modules/src/openai_proxy_app/test_openai_proxy_app.py index 79fbc726a..957222325 100644 --- a/modules/src/openai_proxy_app/test_openai_proxy_app.py +++ b/modules/src/openai_proxy_app/test_openai_proxy_app.py @@ -13,8 +13,9 @@ # limitations under the License. # -from openai_proxy_app import OpenAIModule import mlrun +from openai_proxy_app import OpenAIModule + class TestOpenAIProxyApp: """Test suite for TestOpenAIProxyApp class.""" @@ -26,6 +27,7 @@ def setup_method(self): def test_openai_proxy_app(self): """Test do_tracking with various dataframe sizes using parametrized test.""" - assert type(self.TestOpenAIProxyApp.openai_proxy_app) == mlrun.runtimes.nuclio.application.application.ApplicationRuntime - - + assert ( + type(self.TestOpenAIProxyApp.openai_proxy_app) + == mlrun.runtimes.nuclio.application.application.ApplicationRuntime + ) diff --git a/modules/src/vllm_module/test_vllm_module.py b/modules/src/vllm_module/test_vllm_module.py index 3a5f422ae..4de1be16a 100644 --- a/modules/src/vllm_module/test_vllm_module.py +++ b/modules/src/vllm_module/test_vllm_module.py @@ -13,8 +13,8 @@ # limitations under the License. 
# -from vllm_module import VLLMModule import mlrun +from vllm_module import VLLMModule class TestVllmModule: @@ -30,6 +30,7 @@ def setup_method(self): ) def test_vllm_module(self): - assert ( - type(self.TestVllmModule.vllm_app) == mlrun.runtimes.nuclio.application.application.ApplicationRuntime + assert isinstance( + self.TestVllmModule.vllm_app, + mlrun.runtimes.nuclio.application.application.ApplicationRuntime, ) diff --git a/modules/src/vllm_module/vllm_module.py b/modules/src/vllm_module/vllm_module.py index 50bc9f038..39ecff28f 100644 --- a/modules/src/vllm_module/vllm_module.py +++ b/modules/src/vllm_module/vllm_module.py @@ -12,20 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # -#This module acts as a lightweight gateway to OpenAI-compatible APIs. -#You can send chat prompts, create embeddings, or get model responses without worrying about authentication or endpoint differences. -#It simplifies access so you can test, analyze, or integrate AI features directly into your projects or notebooks with minimal setup. +# This module acts as a lightweight gateway to OpenAI-compatible APIs. +# You can send chat prompts, create embeddings, or get model responses without worrying about authentication or endpoint differences. +# It simplifies access so you can test, analyze, or integrate AI features directly into your projects or notebooks with minimal setup. -from typing import Dict, Optional, List - class VLLMModule: """ VLLMModule - + This module provides a lightweight wrapper for deploying a vLLM (OpenAI-compatible) large language model server as an MLRun application runtime. 
- + The VLLMModule is responsible for: - Creating an MLRun application runtime based on a vLLM container image - Configuring GPU resources, memory limits, and Kubernetes node selection @@ -34,35 +32,33 @@ class VLLMModule: - Automatically configuring shared memory (/dev/shm) when using multiple GPUs - Exposing an OpenAI-compatible API (e.g. /v1/chat/completions) for inference - Providing a simple Python interface for deployment and invocation from Jupyter notebooks - + The module is designed to be used in Jupyter notebooks and MLRun pipelines, allowing users to deploy and test large language models on Kubernetes with minimal configuration. """ def __init__( - self, - project: str, - *, - node_selector: Optional[Dict[str, str]] = None, - name: str = "vllm", - image: str = "vllm/vllm-openai:latest", - model: str = "Qwen/Qwen2.5-Omni-3B", - gpus: int = 1, - mem: str = "10G", - port: int = 8000, - dtype: str = "auto", - uvicorn_log_level: str = "info", - max_tokens: int = 500, + self, + project: str, + *, + node_selector: dict[str, str] | None = None, + name: str = "vllm", + image: str = "vllm/vllm-openai:latest", + model: str = "Qwen/Qwen2.5-Omni-3B", + gpus: int = 1, + mem: str = "10G", + port: int = 8000, + dtype: str = "auto", + uvicorn_log_level: str = "info", + max_tokens: int = 500, ): if gpus < 1: raise ValueError("gpus must be >= 1") - - if node_selector is None: node_selector = {"alpha.eksctl.io/nodegroup-name": "added-gpu"} - + if not isinstance(max_tokens, int): raise TypeError("max_tokens must be an integer") @@ -94,7 +90,7 @@ def __init__( self.vllm_app.set_internal_application_port(self.port) - args: List[str] = [ + args: list[str] = [ "serve", self.model, "--dtype", @@ -110,10 +106,12 @@ def __init__( args += ["--tensor-parallel-size", str(gpus)] # For more than one GPU you should create a share volume for the multiple GPUs - self.vllm_app.spec.volumes = [{"name": "dshm", "emptyDir": {"medium": "Memory"}}] - self.vllm_app.spec.volume_mounts = [{"name": 
"dshm", "mountPath": "/dev/shm"}] - - + self.vllm_app.spec.volumes = [ + {"name": "dshm", "emptyDir": {"medium": "Memory"}} + ] + self.vllm_app.spec.volume_mounts = [ + {"name": "dshm", "mountPath": "/dev/shm"} + ] self.vllm_app.spec.command = "vllm" self.vllm_app.spec.args = args @@ -124,8 +122,9 @@ def __init__( def get_runtime(self): return self.vllm_app - def add_args(self, extra_args: List[str]): - if not isinstance(extra_args, list) or not all(isinstance(x, str) for x in extra_args): + def add_args(self, extra_args: list[str]): + if not isinstance(extra_args, list) or not all( + isinstance(x, str) for x in extra_args + ): raise ValueError("extra_args must be a list of strings") self.vllm_app.spec.args += extra_args - diff --git a/pyproject.toml b/pyproject.toml index d7813821d..869e3356b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "mlrun-hub" version = "0.1.0" description = "MLRun Hub - centralized location for open source contributions of mlrun hub components" readme = "README.md" -requires-python = ">=3.11" +requires-python = ">=3.10.18" license = { file = "LICENSE" } authors = [ { name = "MLRun Team" } @@ -33,8 +33,12 @@ dependencies = [ mlrun-functions = "cli.cli:cli" [tool.ruff] -target-version = "py311" +target-version = "py310" required-version = ">=0.8.0" +exclude = [ + "**/*.ipynb", +] + [tool.ruff.lint] extend-select = [ diff --git a/steps/src/verify_schema/test_verify_schema.py b/steps/src/verify_schema/test_verify_schema.py index 5a7e08b53..bebb0a5b4 100644 --- a/steps/src/verify_schema/test_verify_schema.py +++ b/steps/src/verify_schema/test_verify_schema.py @@ -15,25 +15,19 @@ from verify_schema import VerifySchema + class TestVerifySchema: def test_verify_schema(self): schema = ["id", "name", "active"] verifier = VerifySchema(schema=schema, allow_unexpected_keys=False) # Test with valid event - event = { - "id": 1, - "name": "Test Event", - "active": True - } + event = {"id": 1, "name": "Test Event", "active": True} 
result = verifier.do(event) assert result == event # Test with missing key - event_missing_key = { - "id": 1, - "name": "Test Event" - } + event_missing_key = {"id": 1, "name": "Test Event"} try: verifier.do(event_missing_key) except KeyError as e: @@ -44,7 +38,7 @@ def test_verify_schema(self): "id": 1, "name": "Test Event", "active": True, - "extra": "unexpected" + "extra": "unexpected", } try: verifier.do(event_unexpected_key) @@ -56,11 +50,6 @@ def test_verify_schema_allow_unexpected(self): verifier = VerifySchema(schema=schema, allow_unexpected_keys=True) # Test with valid event and unexpected key - event = { - "id": 1, - "name": "Test Event", - "active": True, - "extra": "unexpected" - } + event = {"id": 1, "name": "Test Event", "active": True, "extra": "unexpected"} result = verifier.do(event) - assert result == event \ No newline at end of file + assert result == event diff --git a/steps/src/verify_schema/verify_schema.py b/steps/src/verify_schema/verify_schema.py index 80a379560..81cc46353 100644 --- a/steps/src/verify_schema/verify_schema.py +++ b/steps/src/verify_schema/verify_schema.py @@ -13,6 +13,7 @@ # limitations under the License. 
# + class VerifySchema: """ This step validates that an event dictionary contains exactly the keys defined in the schema, @@ -27,7 +28,9 @@ def do(self, event: dict): # Check if all keys in the expected schema are present in the event missing = set(self.schema) - set(event) if missing: - raise KeyError(f"Schema verification failed: missing keys {missing} in event: {event}") + raise KeyError( + f"Schema verification failed: missing keys {missing} in event: {event}" + ) if self.allow_unexpected_keys: return event @@ -35,6 +38,8 @@ def do(self, event: dict): # Check if there are any unexpected keys in the event unexpected = set(event) - set(self.schema) if unexpected: - raise KeyError(f"Schema verification failed: unexpected keys {unexpected} in event: {event}") + raise KeyError( + f"Schema verification failed: unexpected keys {unexpected} in event: {event}" + ) return event