diff --git a/functions/src/auto_trainer/auto_trainer.py b/functions/src/auto_trainer/auto_trainer.py index 7b4764700..4e53e5b7e 100755 --- a/functions/src/auto_trainer/auto_trainer.py +++ b/functions/src/auto_trainer/auto_trainer.py @@ -67,30 +67,14 @@ def _get_dataframe( Classification tasks. :param drop_columns: str/int or a list of strings/ints that represent the column names/indices to drop. """ - store_uri_prefix, _ = mlrun.datastore.parse_store_uri(dataset.artifact_url) - - # Getting the dataset: - if mlrun.utils.StorePrefix.FeatureVector == store_uri_prefix: - label_columns = label_columns or dataset.meta.status.label_column - context.logger.info(f"label columns: {label_columns}") - # FeatureVector case: - try: - fv = mlrun.datastore.get_store_resource(dataset.artifact_url) - dataset = fv.get_offline_features(drop_columns=drop_columns).to_dataframe() - except AttributeError: - # Leave here for backwards compatibility - dataset = fs.get_offline_features( - dataset.meta.uri, drop_columns=drop_columns - ).to_dataframe() - - elif not label_columns: - context.logger.info( - "label_columns not provided, mandatory when dataset is not a FeatureVector" - ) - raise ValueError - - elif isinstance(dataset, (list, dict)): + # Check if dataset is list/dict first (before trying to access artifact_url) + if isinstance(dataset, (list, dict)): # list/dict case: + if not label_columns: + context.logger.info( + "label_columns not provided, mandatory when dataset is not a FeatureVector" + ) + raise ValueError dataset = pd.DataFrame(dataset) # Checking if drop_columns provided by integer type: if drop_columns: @@ -103,17 +87,38 @@ def _get_dataframe( ) raise ValueError dataset.drop(drop_columns, axis=1, inplace=True) - else: - # simple URL case: - dataset = dataset.as_df() - if drop_columns: - if all(col in dataset for col in drop_columns): - dataset = dataset.drop(drop_columns, axis=1) - else: + # Dataset is a DataItem with artifact_url (URI or FeatureVector) + store_uri_prefix, _ = mlrun.datastore.parse_store_uri(dataset.artifact_url) + + # Getting the dataset: + if mlrun.utils.StorePrefix.FeatureVector == store_uri_prefix: + label_columns = label_columns or dataset.meta.status.label_column + context.logger.info(f"label columns: {label_columns}") + # FeatureVector case: + try: + fv = mlrun.datastore.get_store_resource(dataset.artifact_url) + dataset = fv.get_offline_features(drop_columns=drop_columns).to_dataframe() + except AttributeError: + # Leave here for backwards compatibility + dataset = fs.get_offline_features( + dataset.meta.uri, drop_columns=drop_columns + ).to_dataframe() + else: + # simple URL case: + if not label_columns: context.logger.info( - "not all of the columns to drop in the dataset, drop columns process skipped" + "label_columns not provided, mandatory when dataset is not a FeatureVector" ) + raise ValueError + dataset = dataset.as_df() + if drop_columns: + if all(col in dataset for col in drop_columns): + dataset = dataset.drop(drop_columns, axis=1) + else: + context.logger.info( + "not all of the columns to drop in the dataset, drop columns process skipped" + ) return dataset, label_columns diff --git a/functions/src/auto_trainer/function.yaml b/functions/src/auto_trainer/function.yaml index 0920b1033..50a36e750 100644 --- a/functions/src/auto_trainer/function.yaml +++ b/functions/src/auto_trainer/function.yaml @@ -1,22 +1,19 @@ -metadata: - categories: - - machine-learning - - model-training - tag: '' - name: auto-trainer +verbose: false +kind: job spec: - image: mlrun/mlrun - 
build: - origin_filename: '' - functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import mlrun
import mlrun.datastore
import mlrun.utils
import pandas as pd
from mlrun import feature_store as fs
from mlrun.datastore import DataItem
from mlrun.execution import MLClientCtx
from mlrun.frameworks.auto_mlrun import AutoMLRun
from mlrun.utils.helpers import create_class, create_function
from sklearn.model_selection import train_test_split

PathType = Union[str, Path]


class KWArgsPrefixes:
    MODEL_CLASS = "CLASS_"
    FIT = "FIT_"
    TRAIN = "TRAIN_"


def _get_sub_dict_by_prefix(src: Dict, prefix_key: str) -> Dict[str, Any]:
    """
    Collect all the keys from the given dict that starts with the given prefix and creates a new dictionary with these
    keys.

    :param src:         The source dict to extract the values from.
    :param prefix_key:  Only keys with this prefix will be returned. The keys in the result dict will be without this
                        prefix.
    """
    return {
        key.replace(prefix_key, ""): val
        for key, val in src.items()
        if key.startswith(prefix_key)
    }


def _get_dataframe(
    context: MLClientCtx,
    dataset: DataItem,
    label_columns: Optional[Union[str, List[str]]] = None,
    drop_columns: Union[str, List[str], int, List[int]] = None,
) -> Tuple[pd.DataFrame, Optional[Union[str, List[str]]]]:
    """
    Getting the DataFrame of the dataset and drop the columns accordingly.

    :param context:         MLRun context.
    :param dataset:         The dataset to train the model on.
                            Can be either a list of lists, dict, URI or a FeatureVector.
    :param label_columns:   The target label(s) of the column(s) in the dataset. for Regression or
                            Classification tasks.
    :param drop_columns:    str/int or a list of strings/ints that represent the column names/indices to drop.
    """
    store_uri_prefix, _ = mlrun.datastore.parse_store_uri(dataset.artifact_url)

    # Getting the dataset:
    if mlrun.utils.StorePrefix.FeatureVector == store_uri_prefix:
        label_columns = label_columns or dataset.meta.status.label_column
        context.logger.info(f"label columns: {label_columns}")
        # FeatureVector case:
        try:
            fv = mlrun.datastore.get_store_resource(dataset.artifact_url)
            dataset = fv.get_offline_features(drop_columns=drop_columns).to_dataframe()
        except AttributeError:
            # Leave here for backwards compatibility
            dataset = fs.get_offline_features(
                dataset.meta.uri, drop_columns=drop_columns
            ).to_dataframe()

    elif not label_columns:
        context.logger.info(
            "label_columns not provided, mandatory when dataset is not a FeatureVector"
        )
        raise ValueError

    elif isinstance(dataset, (list, dict)):
        # list/dict case:
        dataset = pd.DataFrame(dataset)
        # Checking if drop_columns provided by integer type:
        if drop_columns:
            if isinstance(drop_columns, str) or (
                isinstance(drop_columns, list)
                and any(isinstance(col, str) for col in drop_columns)
            ):
                context.logger.error(
                    "drop_columns must be an integer/list of integers if not provided with a URI/FeatureVector dataset"
                )
                raise ValueError
            dataset.drop(drop_columns, axis=1, inplace=True)

    else:
        # simple URL case:
        dataset = dataset.as_df()
        if drop_columns:
            if all(col in dataset for col in drop_columns):
                dataset = dataset.drop(drop_columns, axis=1)
            else:
                context.logger.info(
                    "not all of the columns to drop in the dataset, drop columns process skipped"
                )

    return dataset, label_columns


def train(
    context: MLClientCtx,
    dataset: DataItem,
    model_class: str,
    label_columns: Optional[Union[str, List[str]]] = None,
    drop_columns: List[str] = None,
    model_name: str = "model",
    tag: str = "",
    sample_set: DataItem = None,
    test_set: DataItem = None,
    train_test_split_size: float = None,
    random_state: int = None,
    labels: dict = None,
    **kwargs,
):
    """
    Training a model with the given dataset.

    example::

        import mlrun
        project = mlrun.get_or_create_project("my-project")
        project.set_function("hub://auto_trainer", "train")
        trainer_run = project.run(
            name="train",
            handler="train",
            inputs={"dataset": "./path/to/dataset.csv"},
            params={
                "model_class": "sklearn.linear_model.LogisticRegression",
                "label_columns": "label",
                "drop_columns": "id",
                "model_name": "my-model",
                "tag": "v1.0.0",
                "sample_set": "./path/to/sample_set.csv",
                "test_set": "./path/to/test_set.csv",
                "CLASS_solver": "liblinear",
            },
        )

    :param context:                 MLRun context
    :param dataset:                 The dataset to train the model on. Can be either a URI or a FeatureVector
    :param model_class:             The class of the model, e.g. `sklearn.linear_model.LogisticRegression`
    :param label_columns:           The target label(s) of the column(s) in the dataset. for Regression or
                                    Classification tasks. Mandatory when dataset is not a FeatureVector.
    :param drop_columns:            str or a list of strings that represent the columns to drop
    :param model_name:              The model's name to use for storing the model artifact, default to 'model'
    :param tag:                     The model's tag to log with
    :param sample_set:              A sample set of inputs for the model for logging its stats along the model in favour
                                    of model monitoring. Can be either a URI or a FeatureVector
    :param test_set:                The test set to train the model with.
    :param train_test_split_size:   if test_set was provided then this argument is ignored.
                                    Should be between 0.0 and 1.0 and represent the proportion of the dataset to include
                                    in the test split. The size of the Training set is set to the complement of this
                                    value. Default = 0.2
    :param random_state:            Relevant only when using train_test_split_size.
                                    A random state seed to shuffle the data. For more information, see:
                                    https://scikit-learn.org/stable/glossary.html#term-random_state
                                    Notice that here we only pass integer values.
    :param labels:                  Labels to log with the model
    :param kwargs:                  Here you can pass keyword arguments with prefixes,
                                    that will be parsed and passed to the relevant function, by the following prefixes:
                                    - `CLASS_` - for the model class arguments
                                    - `FIT_` - for the `fit` function arguments
                                    - `TRAIN_` - for the `train` function (in xgb or lgbm train function - future)

    """
    # Validate inputs:
    # Check if exactly one of them is supplied:
    if test_set is None:
        if train_test_split_size is None:
            context.logger.info(
                "test_set or train_test_split_size are not provided, setting train_test_split_size to 0.2"
            )
            train_test_split_size = 0.2

    elif train_test_split_size:
        context.logger.info(
            "test_set provided, ignoring given train_test_split_size value"
        )
        train_test_split_size = None

    # Get DataFrame by URL or by FeatureVector:
    dataset, label_columns = _get_dataframe(
        context=context,
        dataset=dataset,
        label_columns=label_columns,
        drop_columns=drop_columns,
    )

    # Getting the sample set:
    if sample_set is None:
        context.logger.info(
            f"Sample set not given, using the whole training set as the sample set"
        )
        sample_set = dataset
    else:
        sample_set, _ = _get_dataframe(
            context=context,
            dataset=sample_set,
            label_columns=label_columns,
            drop_columns=drop_columns,
        )

    # Parsing kwargs:
    # TODO: Use in xgb or lgbm train function.
    train_kwargs = _get_sub_dict_by_prefix(src=kwargs, prefix_key=KWArgsPrefixes.TRAIN)
    fit_kwargs = _get_sub_dict_by_prefix(src=kwargs, prefix_key=KWArgsPrefixes.FIT)
    model_class_kwargs = _get_sub_dict_by_prefix(
        src=kwargs, prefix_key=KWArgsPrefixes.MODEL_CLASS
    )

    # Check if model or function:
    if hasattr(model_class, "train"):
        # TODO: Need to call: model(), afterwards to start the train function.
        # model = create_function(f"{model_class}.train")
        raise NotImplementedError
    else:
        # Creating model instance:
        model = create_class(model_class)(**model_class_kwargs)

    x = dataset.drop(label_columns, axis=1)
    y = dataset[label_columns]
    if train_test_split_size:
        x_train, x_test, y_train, y_test = train_test_split(
            x, y, test_size=train_test_split_size, random_state=random_state
        )
    else:
        x_train, y_train = x, y

        test_set = test_set.as_df()
        if drop_columns:
            test_set = dataset.drop(drop_columns, axis=1)

        x_test, y_test = test_set.drop(label_columns, axis=1), test_set[label_columns]

    AutoMLRun.apply_mlrun(
        model=model,
        model_name=model_name,
        context=context,
        tag=tag,
        sample_set=sample_set,
        y_columns=label_columns,
        test_set=test_set,
        x_test=x_test,
        y_test=y_test,
        artifacts=context.artifacts,
        labels=labels,
    )
    context.logger.info(f"training '{model_name}'")
    model.fit(x_train, y_train, **fit_kwargs)


def evaluate(
    context: MLClientCtx,
    model: str,
    dataset: mlrun.DataItem,
    drop_columns: List[str] = None,
    label_columns: Optional[Union[str, List[str]]] = None,
    **kwargs,
):
    """
    Evaluating a model. Artifacts generated by the MLHandler.

    :param context:                 MLRun context.
    :param model:                   The model Store path.
    :param dataset:                 The dataset to evaluate the model on. Can be either a URI or a FeatureVector.
    :param drop_columns:            str or a list of strings that represent the columns to drop.
    :param label_columns:           The target label(s) of the column(s) in the dataset. for Regression or
                                    Classification tasks. Mandatory when dataset is not a FeatureVector.
    :param kwargs:                  Here you can pass keyword arguments to the predict function
                                    (PREDICT_ prefix is not required).
    """
    # Get dataset by URL or by FeatureVector:
    dataset, label_columns = _get_dataframe(
        context=context,
        dataset=dataset,
        label_columns=label_columns,
        drop_columns=drop_columns,
    )

    # Parsing label_columns:
    parsed_label_columns = []
    if label_columns:
        label_columns = (
            label_columns if isinstance(label_columns, list) else [label_columns]
        )
        for lc in label_columns:
            if fs.common.feature_separator in lc:
                feature_set_name, label_name, alias = fs.common.parse_feature_string(lc)
                parsed_label_columns.append(alias or label_name)
        if parsed_label_columns:
            label_columns = parsed_label_columns

    x = dataset.drop(label_columns, axis=1)
    y = dataset[label_columns]

    # Loading the model and predicting:
    model_handler = AutoMLRun.load_model(
        model_path=model, context=context, model_name="model_LinearRegression"
    )
    AutoMLRun.apply_mlrun(model_handler.model, y_test=y, model_path=model)

    context.logger.info(f"evaluating '{model_handler.model_name}'")
    model_handler.model.predict(x, **kwargs)


def predict(
    context: MLClientCtx,
    model: str,
    dataset: mlrun.DataItem,
    drop_columns: Union[str, List[str], int, List[int]] = None,
    label_columns: Optional[Union[str, List[str]]] = None,
    result_set: Optional[str] = None,
    **kwargs,
):
    """
    Predicting dataset by a model.

    :param context:                 MLRun context.
    :param model:                   The model Store path.
    :param dataset:                 The dataset to predict the model on. Can be either a URI, a FeatureVector or a
                                    sample in a shape of a list/dict.
                                    When passing a sample, pass the dataset as a field in `params` instead of `inputs`.
    :param drop_columns:            str/int or a list of strings/ints that represent the column names/indices to drop.
                                    When the dataset is a list/dict this parameter should be represented by integers.
    :param label_columns:           The target label(s) of the column(s) in the dataset. for Regression or
                                    Classification tasks. Mandatory when dataset is not a FeatureVector.
    :param result_set:              The db key to set name of the prediction result and the filename.
                                    Default to 'prediction'.
    :param kwargs:                  Here you can pass keyword arguments to the predict function
                                    (PREDICT_ prefix is not required).
    """
    # Get dataset by URL or by FeatureVector:
    dataset, label_columns = _get_dataframe(
        context=context,
        dataset=dataset,
        label_columns=label_columns,
        drop_columns=drop_columns,
    )

    # loading the model, and getting the model handler:
    model_handler = AutoMLRun.load_model(model_path=model, context=context)

    # Dropping label columns if necessary:
    if not label_columns:
        label_columns = []
    elif isinstance(label_columns, str):
        label_columns = [label_columns]

    # Predicting:
    context.logger.info(f"making prediction by '{model_handler.model_name}'")
    y_pred = model_handler.model.predict(dataset, **kwargs)

    # Preparing and validating label columns for the dataframe of the prediction result:
    num_predicted = 1 if len(y_pred.shape) == 1 else y_pred.shape[1]

    if num_predicted > len(label_columns):
        if num_predicted == 1:
            label_columns = ["predicted labels"]
        else:
            label_columns.extend(
                [
                    f"predicted_label_{i + 1 + len(label_columns)}"
                    for i in range(num_predicted - len(label_columns))
                ]
            )
    elif num_predicted < len(label_columns):
        context.logger.error(
            f"number of predicted labels: {num_predicted} is smaller than number of label columns: {len(label_columns)}"
        )
        raise ValueError

    artifact_name = result_set or "prediction"
    labels_inside_df = set(label_columns) & set(dataset.columns.tolist())
    if labels_inside_df:
        context.logger.error(
            f"The labels: {labels_inside_df} are already existed in the dataframe"
        )
        raise ValueError
    pred_df = pd.concat([dataset, pd.DataFrame(y_pred, columns=label_columns)], axis=1)
    context.log_dataset(artifact_name, pred_df, db_key=result_set)
 - code_origin: '' - description: Automatic train, evaluate and predict functions for the ML frameworks - - Scikit-Learn, XGBoost and LightGBM. - disable_auto_mount: false - default_handler: train entry_points: train: - lineno: 121 + doc: "Training a model with the given dataset.\n\nexample::\n\n import mlrun\n\ + \ project = mlrun.get_or_create_project(\"my-project\")\n project.set_function(\"\ + hub://auto_trainer\", \"train\")\n trainer_run = project.run(\n \ + \ name=\"train\",\n handler=\"train\",\n inputs={\"dataset\"\ + : \"./path/to/dataset.csv\"},\n params={\n \"model_class\"\ + : \"sklearn.linear_model.LogisticRegression\",\n \"label_columns\"\ + : \"label\",\n \"drop_columns\": \"id\",\n \"model_name\"\ + : \"my-model\",\n \"tag\": \"v1.0.0\",\n \"sample_set\"\ + : \"./path/to/sample_set.csv\",\n \"test_set\": \"./path/to/test_set.csv\"\ + ,\n \"CLASS_solver\": \"liblinear\",\n },\n )" + has_kwargs: true parameters: - name: context type: MLClientCtx @@ -70,21 +67,12 @@ spec: type: dict doc: Labels to log with the model default: null - has_varargs: false + lineno: 126 name: train - has_kwargs: true - doc: "Training a model with the given dataset.\n\nexample::\n\n import mlrun\n\ - \ project = mlrun.get_or_create_project(\"my-project\")\n project.set_function(\"\ - hub://auto_trainer\", \"train\")\n trainer_run = project.run(\n \ - \ name=\"train\",\n handler=\"train\",\n inputs={\"dataset\"\ - : \"./path/to/dataset.csv\"},\n params={\n \"model_class\"\ - : \"sklearn.linear_model.LogisticRegression\",\n \"label_columns\"\ - : \"label\",\n \"drop_columns\": \"id\",\n \"model_name\"\ - : \"my-model\",\n \"tag\": \"v1.0.0\",\n \"sample_set\"\ - : \"./path/to/sample_set.csv\",\n \"test_set\": \"./path/to/test_set.csv\"\ - ,\n \"CLASS_solver\": \"liblinear\",\n },\n )" + has_varargs: false evaluate: - lineno: 273 + doc: Evaluating a model. Artifacts generated by the MLHandler. + has_kwargs: true parameters: - name: context type: MLClientCtx @@ -104,12 +92,12 @@ spec: doc: The target label(s) of the column(s) in the dataset. for Regression or Classification tasks. Mandatory when dataset is not a FeatureVector. default: null - has_varargs: false + lineno: 278 name: evaluate - has_kwargs: true - doc: Evaluating a model. Artifacts generated by the MLHandler. + has_varargs: false predict: - lineno: 327 + doc: Predicting dataset by a model. + has_kwargs: true parameters: - name: context type: MLClientCtx @@ -138,10 +126,22 @@ spec: doc: The db key to set name of the prediction result and the filename. Default to 'prediction'. default: null - has_varargs: false + lineno: 332 name: predict - has_kwargs: true - doc: Predicting dataset by a model. + has_varargs: false + build: + code_origin: '' + origin_filename: '' + functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import mlrun
import mlrun.datastore
import mlrun.utils
import pandas as pd
from mlrun import feature_store as fs
from mlrun.datastore import DataItem
from mlrun.execution import MLClientCtx
from mlrun.frameworks.auto_mlrun import AutoMLRun
from mlrun.utils.helpers import create_class, create_function
from sklearn.model_selection import train_test_split

PathType = Union[str, Path]


class KWArgsPrefixes:
    MODEL_CLASS = "CLASS_"
    FIT = "FIT_"
    TRAIN = "TRAIN_"


def _get_sub_dict_by_prefix(src: Dict, prefix_key: str) -> Dict[str, Any]:
    """
    Collect all the keys from the given dict that start with the given prefix and create a new dictionary with these&#xA;
    keys.&#xA;

    :param src:         The source dict to extract the values from.
    :param prefix_key:  Only keys with this prefix will be returned. The keys in the result dict will be without this
                        prefix.
    """
    return {
        key.replace(prefix_key, ""): val
        for key, val in src.items()
        if key.startswith(prefix_key)
    }


def _get_dataframe(
    context: MLClientCtx,
    dataset: DataItem,
    label_columns: Optional[Union[str, List[str]]] = None,
    drop_columns: Union[str, List[str], int, List[int]] = None,
) -> Tuple[pd.DataFrame, Optional[Union[str, List[str]]]]:
    """
    Getting the DataFrame of the dataset and dropping the requested columns.&#xA;

    :param context:         MLRun context.
    :param dataset:         The dataset to train the model on.
                            Can be either a list of lists, dict, URI or a FeatureVector.
    :param label_columns:   The target label(s) of the column(s) in the dataset, for Regression or&#xA;
                            Classification tasks.
    :param drop_columns:    str/int or a list of strings/ints that represent the column names/indices to drop.
    """
    # Check if dataset is list/dict first (before trying to access artifact_url)
    if isinstance(dataset, (list, dict)):
        # list/dict case:
        if not label_columns:&#xA;
            context.logger.error(&#xA;
                "label_columns not provided, mandatory when dataset is not a FeatureVector"&#xA;
            )&#xA;
            raise ValueError("label_columns is mandatory when dataset is not a FeatureVector")&#xA;
        dataset = pd.DataFrame(dataset)
        # Checking if drop_columns provided by integer type:
        if drop_columns:
            if isinstance(drop_columns, str) or (
                isinstance(drop_columns, list)
                and any(isinstance(col, str) for col in drop_columns)
            ):
                context.logger.error(&#xA;
                    "drop_columns must be an integer or a list of integers when the dataset is a list/dict"&#xA;
                )&#xA;
                raise ValueError("drop_columns must be integer indices when the dataset is a list/dict")&#xA;
            dataset.drop(drop_columns, axis=1, inplace=True)
    else:
        # Dataset is a DataItem with artifact_url (URI or FeatureVector)
        store_uri_prefix, _ = mlrun.datastore.parse_store_uri(dataset.artifact_url)

        # Getting the dataset:
        if mlrun.utils.StorePrefix.FeatureVector == store_uri_prefix:
            label_columns = label_columns or dataset.meta.status.label_column
            context.logger.info(f"label columns: {label_columns}")
            # FeatureVector case:
            try:
                fv = mlrun.datastore.get_store_resource(dataset.artifact_url)
                dataset = fv.get_offline_features(drop_columns=drop_columns).to_dataframe()
            except AttributeError:
                # Leave here for backwards compatibility
                dataset = fs.get_offline_features(
                    dataset.meta.uri, drop_columns=drop_columns
                ).to_dataframe()
        else:
            # simple URL case:
            if not label_columns:&#xA;
                context.logger.error(&#xA;
                    "label_columns not provided, mandatory when dataset is not a FeatureVector"&#xA;
                )&#xA;
                raise ValueError("label_columns is mandatory when dataset is not a FeatureVector")&#xA;
            dataset = dataset.as_df()
            if drop_columns:&#xA;
                # Wrap a single column name in a list so the membership check&#xA;
                # below tests column names rather than characters:&#xA;
                if isinstance(drop_columns, str):&#xA;
                    drop_columns = [drop_columns]&#xA;
                if all(col in dataset for col in drop_columns):&#xA;
                    dataset = dataset.drop(drop_columns, axis=1)&#xA;
                else:&#xA;
                    context.logger.info(&#xA;
                        "not all of the requested columns exist in the dataset, skipping the column drop"&#xA;
                    )&#xA;

    return dataset, label_columns
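&#xA;
# A rough sketch of the three dataset shapes _get_dataframe accepts (the&#xA;
# values and store URIs below are illustrative placeholders):&#xA;
#&#xA;
#     _get_dataframe(context, dataset=[[1, 2], [3, 4]], label_columns=[1])&#xA;
#     _get_dataframe(context, dataset=&lt;DataItem of "store://feature-vectors/proj/vec"&gt;)&#xA;
#     _get_dataframe(context, dataset=&lt;DataItem of "v3io:///data/train.csv"&gt;,&#xA;
#                    label_columns="label", drop_columns=["id"])&#xA;
#&#xA;
# Only the FeatureVector case may omit label_columns, since the label can be&#xA;
# read from the vector's metadata.&#xA;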


def train(
    context: MLClientCtx,
    dataset: DataItem,
    model_class: str,
    label_columns: Optional[Union[str, List[str]]] = None,
    drop_columns: List[str] = None,
    model_name: str = "model",
    tag: str = "",
    sample_set: DataItem = None,
    test_set: DataItem = None,
    train_test_split_size: float = None,
    random_state: int = None,
    labels: dict = None,
    **kwargs,
):
    """
    Training a model with the given dataset.

    example::

        import mlrun
        project = mlrun.get_or_create_project("my-project")
        project.set_function("hub://auto_trainer", "train")
        trainer_run = project.run_function(&#xA;
            "train",&#xA;
            handler="train",&#xA;
            inputs={"dataset": "./path/to/dataset.csv"},
            params={
                "model_class": "sklearn.linear_model.LogisticRegression",
                "label_columns": "label",
                "drop_columns": "id",
                "model_name": "my-model",
                "tag": "v1.0.0",
                "sample_set": "./path/to/sample_set.csv",
                "test_set": "./path/to/test_set.csv",
                "CLASS_solver": "liblinear",
            },
        )

    :param context:                 MLRun context
    :param dataset:                 The dataset to train the model on. Can be either a URI or a FeatureVector
    :param model_class:             The class of the model, e.g. `sklearn.linear_model.LogisticRegression`
    :param label_columns:           The target label(s) of the column(s) in the dataset, for Regression or&#xA;
                                    Classification tasks. Mandatory when dataset is not a FeatureVector.
    :param drop_columns:            str or a list of strings that represent the columns to drop
    :param model_name:              The model's name to use for storing the model artifact, default to 'model'
    :param tag:                     The model's tag to log with
    :param sample_set:              A sample set of inputs for the model, used to log its stats along with the model&#xA;
                                    in support of model monitoring. Can be either a URI or a FeatureVector&#xA;
    :param test_set:                The test set to evaluate the model with.&#xA;
    :param train_test_split_size:   If test_set was provided then this argument is ignored.&#xA;
                                    Should be between 0.0 and 1.0 and represent the proportion of the dataset to include
                                    in the test split. The size of the Training set is set to the complement of this
                                    value. Default = 0.2
    :param random_state:            Relevant only when using train_test_split_size.
                                    A random state seed to shuffle the data. For more information, see:
                                    https://scikit-learn.org/stable/glossary.html#term-random_state
                                    Notice that here we only pass integer values.
    :param labels:                  Labels to log with the model
    :param kwargs:                  Here you can pass keyword arguments with prefixes,
                                    that will be parsed and passed to the relevant function, by the following prefixes:
                                    - `CLASS_` - for the model class arguments
                                    - `FIT_` - for the `fit` function arguments
                                    - `TRAIN_` - for the `train` function (in xgb or lgbm train function - future)

    """
    # Validate inputs:
    # Resolve the precedence between test_set and train_test_split_size:&#xA;
    if test_set is None:
        if train_test_split_size is None:
            context.logger.info(&#xA;
                "neither test_set nor train_test_split_size was provided, setting train_test_split_size to 0.2"&#xA;
            )&#xA;
            train_test_split_size = 0.2

    elif train_test_split_size:
        context.logger.info(
            "test_set provided, ignoring given train_test_split_size value"
        )
        train_test_split_size = None
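&#xA;
    # Resolution of the two arguments, as implemented above:&#xA;
    #   test_set=None,  train_test_split_size=None -> split defaults to 0.2&#xA;
    #   test_set=None,  train_test_split_size=0.3  -> a 0.7/0.3 split is used&#xA;
    #   test_set given, train_test_split_size=None -> the given test_set is used&#xA;
    #   test_set given, train_test_split_size=0.3  -> the split size is ignored&#xA;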

    # Get DataFrame by URL or by FeatureVector:
    dataset, label_columns = _get_dataframe(
        context=context,
        dataset=dataset,
        label_columns=label_columns,
        drop_columns=drop_columns,
    )

    # Getting the sample set:
    if sample_set is None:
        context.logger.info(&#xA;
            "Sample set not given, using the whole training set as the sample set"&#xA;
        )&#xA;
        sample_set = dataset
    else:
        sample_set, _ = _get_dataframe(
            context=context,
            dataset=sample_set,
            label_columns=label_columns,
            drop_columns=drop_columns,
        )

    # Parsing kwargs:
    # TODO: Use in xgb or lgbm train function.
    train_kwargs = _get_sub_dict_by_prefix(src=kwargs, prefix_key=KWArgsPrefixes.TRAIN)
    fit_kwargs = _get_sub_dict_by_prefix(src=kwargs, prefix_key=KWArgsPrefixes.FIT)
    model_class_kwargs = _get_sub_dict_by_prefix(
        src=kwargs, prefix_key=KWArgsPrefixes.MODEL_CLASS
    )

    # Check if model or function:
    if hasattr(model_class, "train"):
        # TODO: Need to call: model(), afterwards to start the train function.
        # model = create_function(f"{model_class}.train")
        raise NotImplementedError
    else:
        # Creating model instance:
        model = create_class(model_class)(**model_class_kwargs)

    x = dataset.drop(label_columns, axis=1)
    y = dataset[label_columns]
    if train_test_split_size:
        x_train, x_test, y_train, y_test = train_test_split(
            x, y, test_size=train_test_split_size, random_state=random_state
        )
    else:
        x_train, y_train = x, y

        test_set = test_set.as_df()
        if drop_columns:
            test_set = test_set.drop(drop_columns, axis=1)&#xA;

        x_test, y_test = test_set.drop(label_columns, axis=1), test_set[label_columns]

    AutoMLRun.apply_mlrun(
        model=model,
        model_name=model_name,
        context=context,
        tag=tag,
        sample_set=sample_set,
        y_columns=label_columns,
        test_set=test_set,
        x_test=x_test,
        y_test=y_test,
        artifacts=context.artifacts,
        labels=labels,
    )
    context.logger.info(f"training '{model_name}'")
    model.fit(x_train, y_train, **fit_kwargs)


def evaluate(
    context: MLClientCtx,
    model: str,
    dataset: mlrun.DataItem,
    drop_columns: List[str] = None,
    label_columns: Optional[Union[str, List[str]]] = None,
    **kwargs,
):
    """
    Evaluating a model. Artifacts are generated by the MLHandler.&#xA;

    :param context:                 MLRun context.
    :param model:                   The model Store path.
    :param dataset:                 The dataset to evaluate the model on. Can be either a URI or a FeatureVector.
    :param drop_columns:            str or a list of strings that represent the columns to drop.
    :param label_columns:           The target label(s) of the column(s) in the dataset, for Regression or&#xA;
                                    Classification tasks. Mandatory when dataset is not a FeatureVector.
    :param kwargs:                  Here you can pass keyword arguments to the predict function
                                    (PREDICT_ prefix is not required).
    """
    # Get dataset by URL or by FeatureVector:
    dataset, label_columns = _get_dataframe(
        context=context,
        dataset=dataset,
        label_columns=label_columns,
        drop_columns=drop_columns,
    )

    # Parsing label_columns:
    parsed_label_columns = []
    if label_columns:
        label_columns = (
            label_columns if isinstance(label_columns, list) else [label_columns]
        )
        for lc in label_columns:
            if fs.common.feature_separator in lc:
                feature_set_name, label_name, alias = fs.common.parse_feature_string(lc)
                parsed_label_columns.append(alias or label_name)
        if parsed_label_columns:
            label_columns = parsed_label_columns

    x = dataset.drop(label_columns, axis=1)
    y = dataset[label_columns]

    # Loading the model and predicting:
    model_handler = AutoMLRun.load_model(
        model_path=model, context=context, model_name="model_LinearRegression"
    )
    AutoMLRun.apply_mlrun(model_handler.model, y_test=y, model_path=model)

    context.logger.info(f"evaluating '{model_handler.model_name}'")
    model_handler.model.predict(x, **kwargs)
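&#xA;
# A hedged usage sketch for evaluate, mirroring the train example above (the&#xA;
# paths, project object and run outputs are placeholders):&#xA;
#&#xA;
#     project.set_function("hub://auto_trainer", "auto_trainer")&#xA;
#     evaluate_run = project.run_function(&#xA;
#         "auto_trainer",&#xA;
#         handler="evaluate",&#xA;
#         inputs={"dataset": "./path/to/test_set.csv"},&#xA;
#         params={&#xA;
#             "model": trainer_run.outputs["model"],&#xA;
#             "label_columns": "label",&#xA;
#         },&#xA;
#     )&#xA;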


def predict(
    context: MLClientCtx,
    model: str,
    dataset: mlrun.DataItem,
    drop_columns: Union[str, List[str], int, List[int]] = None,
    label_columns: Optional[Union[str, List[str]]] = None,
    result_set: Optional[str] = None,
    **kwargs,
):
    """
    Predicting dataset by a model.

    :param context:                 MLRun context.
    :param model:                   The model Store path.
    :param dataset:                 The dataset to predict the model on. Can be either a URI, a FeatureVector or a
                                    sample in the shape of a list/dict.&#xA;
                                    When passing a sample, pass the dataset as a field in `params` instead of `inputs`.
    :param drop_columns:            str/int or a list of strings/ints that represent the column names/indices to drop.
                                    When the dataset is a list/dict this parameter should be represented by integers.
    :param label_columns:           The target label(s) of the column(s) in the dataset, for Regression or&#xA;
                                    Classification tasks. Mandatory when dataset is not a FeatureVector.
    :param result_set:              The db key to use for the name of the prediction result and for the filename.&#xA;
                                    Defaults to 'prediction'.&#xA;
    :param kwargs:                  Here you can pass keyword arguments to the predict function
                                    (PREDICT_ prefix is not required).
    """
    # Get dataset by URL or by FeatureVector:
    dataset, label_columns = _get_dataframe(
        context=context,
        dataset=dataset,
        label_columns=label_columns,
        drop_columns=drop_columns,
    )

    # loading the model, and getting the model handler:
    model_handler = AutoMLRun.load_model(model_path=model, context=context)

    # Fix feature names for models that require them (e.g., XGBoost)
    # When dataset comes from a list, pandas assigns default integer column names
    # but some models expect specific feature names they were trained with
    if hasattr(model_handler.model, "feature_names_in_"):&#xA;
        expected_features = model_handler.model.feature_names_in_
        if len(dataset.columns) == len(expected_features):
            # Only rename if the number of columns matches
            # This handles the case where a list was converted to DataFrame with default column names
            if not all(col == feat for col, feat in zip(dataset.columns, expected_features)):
                context.logger.info(&#xA;
                    "Renaming dataset columns to match the model's expected feature names"&#xA;
                )&#xA;
                dataset.columns = expected_features
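&#xA;
    # For example (illustrative): a list sample becomes a DataFrame with the&#xA;
    # default integer columns [0, 1, ...], while an estimator fitted on a named&#xA;
    # DataFrame exposes feature_names_in_ such as ["sepal_length", ...]; the&#xA;
    # rename above aligns the two so predict() accepts the input.&#xA;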

    # Dropping label columns if necessary:
    if not label_columns:
        label_columns = []
    elif isinstance(label_columns, str):
        label_columns = [label_columns]

    # Predicting:
    context.logger.info(f"making prediction by '{model_handler.model_name}'")
    y_pred = model_handler.model.predict(dataset, **kwargs)

    # Preparing and validating label columns for the dataframe of the prediction result:
    num_predicted = 1 if len(y_pred.shape) == 1 else y_pred.shape[1]

    if num_predicted > len(label_columns):
        if num_predicted == 1:
            label_columns = ["predicted labels"]
        else:
            label_columns.extend(
                [
                    f"predicted_label_{i + 1 + len(label_columns)}"
                    for i in range(num_predicted - len(label_columns))
                ]
            )
    elif num_predicted < len(label_columns):
        context.logger.error(
            f"number of predicted labels: {num_predicted} is smaller than number of label columns: {len(label_columns)}"
        )
        raise ValueError("fewer predicted labels than provided label columns")&#xA;
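&#xA;
    # e.g. (illustrative): y_pred of shape (n, 3) with label_columns=["y"] is&#xA;
    # padded above to ["y", "predicted_label_2", "predicted_label_3"].&#xA;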

    artifact_name = result_set or "prediction"
    labels_inside_df = set(label_columns) & set(dataset.columns.tolist())
    if labels_inside_df:
        context.logger.error(&#xA;
            f"The labels {labels_inside_df} already exist in the dataframe"&#xA;
        )&#xA;
        raise ValueError(f"labels {labels_inside_df} already exist in the dataframe")&#xA;
    pred_df = pd.concat([dataset, pd.DataFrame(y_pred, columns=label_columns)], axis=1)
    context.log_dataset(artifact_name, pred_df, db_key=result_set)
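&#xA;
# A hedged usage sketch for predict with an in-memory sample (values are&#xA;
# placeholders; per the docstring, a list/dict sample must be passed through&#xA;
# params rather than inputs):&#xA;
#&#xA;
#     predict_run = project.run_function(&#xA;
#         "auto_trainer",&#xA;
#         handler="predict",&#xA;
#         params={&#xA;
#             "model": trainer_run.outputs["model"],&#xA;
#             "dataset": [[5.1, 3.5, 1.4, 0.2]],&#xA;
#             "label_columns": "label",&#xA;
#         },&#xA;
#     )&#xA;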
 command: '' -kind: job -verbose: false + default_handler: train + image: mlrun/mlrun + disable_auto_mount: false + description: Automatic train, evaluate and predict functions for the ML frameworks + - Scikit-Learn, XGBoost and LightGBM. +metadata: + categories: + - machine-learning + - model-training + tag: '' + name: auto-trainer diff --git a/functions/src/auto_trainer/item.yaml b/functions/src/auto_trainer/item.yaml index ba33f6a08..d397a79d6 100755 --- a/functions/src/auto_trainer/item.yaml +++ b/functions/src/auto_trainer/item.yaml @@ -13,7 +13,7 @@ labels: author: Iguazio maintainers: [] marketplaceType: '' -mlrunVersion: 1.7.0 +mlrunVersion: 1.10.0 name: auto_trainer platformVersion: 3.5.0 spec: @@ -23,4 +23,4 @@ spec: kind: job requirements: [] url: '' -version: 1.8.0 +version: 1.9.0 diff --git a/functions/src/auto_trainer/requirements.txt b/functions/src/auto_trainer/requirements.txt index b14a0293c..b23f9b9dd 100644 --- a/functions/src/auto_trainer/requirements.txt +++ b/functions/src/auto_trainer/requirements.txt @@ -1,4 +1,5 @@ pandas -scikit-learn<1.4.0 +scikit-learn~=1.5.2 +lightgbm xgboost<2.0.0 plotly diff --git a/functions/src/auto_trainer/test_auto_trainer.py b/functions/src/auto_trainer/test_auto_trainer.py index 9a1ff554c..ac95109f8 100644 --- a/functions/src/auto_trainer/test_auto_trainer.py +++ b/functions/src/auto_trainer/test_auto_trainer.py @@ -29,6 +29,9 @@ ("sklearn.linear_model.LinearRegression", "regression"), ("sklearn.ensemble.RandomForestClassifier", "classification"), ("xgboost.XGBRegressor", "regression"), + ("xgboost.XGBClassifier", "classification"), + ("lightgbm.LGBMRegressor", "regression"), + ("lightgbm.LGBMClassifier", "classification") ] REQUIRED_ENV_VARS = [ @@ -78,11 +81,15 @@ def _assert_train_handler(train_run): @pytest.mark.parametrize("model", MODELS) +@pytest.mark.skipif( + condition=not _validate_environment_variables(), + reason="Project's environment variables are not set", +) def test_train(model: Tuple[str, str]): dataset, label_columns = _get_dataset(model[1]) is_test_passed = True - project = mlrun.new_project("auto-trainer-test", context="./") + project = mlrun.get_or_create_project("auto-trainer-test", context="./") fn = project.set_function("function.yaml", "train", kind="job", image="mlrun/mlrun") train_run = None @@ -119,7 +126,7 @@ def test_train_evaluate(model: Tuple[str, str]): dataset, label_columns = _get_dataset(model[1]) is_test_passed = True # Importing function: - project = mlrun.new_project("auto-trainer-test", context="./") + project = mlrun.get_or_create_project("auto-trainer-test", context="./") fn = project.set_function("function.yaml", "train", kind="job", image="mlrun/mlrun") temp_dir = tempfile.mkdtemp() @@ -172,7 +179,7 @@ def test_train_predict(model: Tuple[str, str]): df = pd.read_csv(dataset) sample = df.head().drop("labels", axis=1).values.tolist() # Importing function: - project = mlrun.new_project("auto-trainer-test", context="./") + project = mlrun.get_or_create_project("auto-trainer-test", context="./") fn = project.set_function("function.yaml", "train", kind="job", image="mlrun/mlrun") temp_dir = tempfile.mkdtemp() diff --git a/functions/src/describe/function.yaml b/functions/src/describe/function.yaml index a11461774..1c254c3c4 100644 --- a/functions/src/describe/function.yaml +++ b/functions/src/describe/function.yaml @@ -1,9 +1,44 @@ +metadata: + tag: '' + categories: + - data-analysis + name: describe +verbose: false +kind: job spec: + command: '' + image: mlrun/mlrun + description: describe and 
visualizes dataset stats + disable_auto_mount: false + default_handler: analyze entry_points: analyze: + doc: 'The function will output the following artifacts per + + column within the data frame (based on data types) + + If the data has more than 500,000 sample we + + sample randomly 500,000 samples: + + + describe csv + + histograms + + scatter-2d + + violin chart + + correlation-matrix chart + + correlation-matrix csv + + imbalance pie chart + + imbalance-weights-vec csv' + has_kwargs: false has_varargs: false - outputs: - - type: None parameters: - name: context type: MLClientCtx @@ -45,46 +80,11 @@ spec: - name: dask_client doc: Dask client object default: null - doc: 'The function will output the following artifacts per - - column within the data frame (based on data types) - - If the data has more than 500,000 sample we - - sample randomly 500,000 samples: - - - describe csv - - histograms - - scatter-2d - - violin chart - - correlation-matrix chart - - correlation-matrix csv - - imbalance pie chart - - imbalance-weights-vec csv' - has_kwargs: false + outputs: + - type: None name: analyze lineno: 46 - image: mlrun/mlrun - command: '' build: - functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Generated by nuclio.export.NuclioExporter

import warnings
from typing import Union

import mlrun
import numpy as np

warnings.simplefilter(action="ignore", category=FutureWarning)

import mlrun.feature_store as fstore
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff
import plotly.graph_objects as go
from mlrun.artifacts import (
    Artifact,
    DatasetArtifact,
    PlotlyArtifact,
    TableArtifact,
    update_dataset_meta,
)
from mlrun.datastore import DataItem
from mlrun.execution import MLClientCtx
from mlrun.feature_store import FeatureSet
from plotly.subplots import make_subplots

pd.set_option("display.float_format", lambda x: "%.2f" % x)
MAX_SIZE_OF_DF = 500000


def analyze(
    context: MLClientCtx,
    name: str = "dataset",
    table: Union[FeatureSet, DataItem] = None,
    label_column: str = None,
    plots_dest: str = "plots",
    random_state: int = 1,
    problem_type: str = "classification",
    dask_key: str = "dask_key",
    dask_function: str = None,
    dask_client=None,
) -> None:
    """
    The function will output the following artifacts per
    column within the data frame (based on data types)
    If the data has more than 500,000 sample we
    sample randomly 500,000 samples:

    describe csv
    histograms
    scatter-2d
    violin chart
    correlation-matrix chart
    correlation-matrix csv
    imbalance pie chart
    imbalance-weights-vec csv

    :param context:                 The function context
    :param name:                    Key of dataset to database ("dataset" for default)
    :param table:                   MLRun input pointing to pandas dataframe (csv/parquet file path) or FeatureSet
                                    as param
    :param label_column:            Ground truth column label
    :param plots_dest:              Destination folder of summary plots (relative to artifact_path)
                                    ("plots" for default)
    :param random_state:            When the table has more than 500,000 samples, we sample randomly 500,000 samples
    :param problem_type             The type of the ML problem the data facing - regression, classification or None
                                    (classification for default)
    :param dask_key:                Key of dataframe in dask client "datasets" attribute
    :param dask_function:           Dask function url (db://..)
    :param dask_client:             Dask client object
    """
    data_item, featureset, creat, update = False, False, False, False
    get_from_table = True
    if dask_function or dask_client:
        data_item, creat = True, True
        if dask_function:
            client = mlrun.import_function(dask_function).client
        elif dask_client:
            client = dask_client
        else:
            raise ValueError("dask client was not provided")

        if dask_key in client.datasets:
            df = client.get_dataset(dask_key)
            data_item, creat, get_from_table = True, True, False
        elif table:
            get_from_table = True
        else:
            context.logger.info(
                f"only these datasets are available {client.datasets} in client {client}"
            )
            raise Exception("dataset not found on dask cluster")

    if get_from_table:
        if type(table) == DataItem:
            if table.meta is None:
                data_item, creat, update = True, True, False
            elif table.meta.kind == "dataset":
                data_item, creat, update = True, False, True
            elif table.meta.kind == "FeatureVector":
                data_item, creat, update = True, False, False
            elif table.meta.kind == "FeatureSet":
                featureset, creat, update = True, False, False

        if data_item:
            df = table.as_df()
        elif featureset:
            project_name, set_name = (
                table._path.split("/")[2],
                table._path.split("/")[4],
            )
            feature_set = fstore.get_feature_set(
                f"store://feature-sets/{project_name}/{set_name}"
            )
            df = feature_set.to_dataframe()
        else:
            context.logger.error(f"Wrong table type.")
            return

    if df.size > MAX_SIZE_OF_DF:
        df = df.sample(n=int(MAX_SIZE_OF_DF / df.shape[1]), random_state=random_state)
    extra_data = {}

    if label_column not in df.columns:
        label_column = None

    extra_data["describe csv"] = context.log_artifact(
        TableArtifact("describe-csv", df=df.describe()),
        local_path=f"{plots_dest}/describe.csv",
    )

    try:
        _create_histogram_mat_artifact(
            context, df, extra_data, label_column, plots_dest
        )
    except Exception as e:
        context.logger.warn(f"Failed to create histogram matrix artifact due to: {e}")
    try:
        _create_features_histogram_artifacts(
            context, df, extra_data, label_column, plots_dest, problem_type
        )
    except Exception as e:
        context.logger.warn(f"Failed to create pairplot histograms due to: {e}")
    try:
        _create_features_2d_scatter_artifacts(
            context, df, extra_data, label_column, plots_dest, problem_type
        )
    except Exception as e:
        context.logger.warn(f"Failed to create pairplot 2d_scatter due to: {e}")
    try:
        _create_violin_artifact(context, df, extra_data, plots_dest)
    except Exception as e:
        context.logger.warn(f"Failed to create violin distribution plots due to: {e}")
    try:
        _create_imbalance_artifact(
            context, df, extra_data, label_column, plots_dest, problem_type
        )
    except Exception as e:
        context.logger.warn(f"Failed to create class imbalance plot due to: {e}")
    try:
        _create_corr_artifact(context, df, extra_data, label_column, plots_dest)
    except Exception as e:
        context.logger.warn(f"Failed to create features correlation plot due to: {e}")

    if not data_item:
        return

    artifact = table.artifact_url
    if creat:  # dataset not stored
        artifact = DatasetArtifact(
            key="dataset", stats=True, df=df, extra_data=extra_data
        )
        artifact = context.log_artifact(artifact, db_key=name)
        context.logger.info(f"The data set is logged to the project under {name} name")

    if update:
        update_dataset_meta(artifact, extra_data=extra_data)
        context.logger.info(f"The data set named {name} is updated")

    # TODO : 3-D plot on on selected features.
    # TODO : Reintegration plot on on selected features.
    # TODO : PCA plot (with options)


def _create_histogram_mat_artifact(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
):
    """
    Create and log a histogram matrix artifact
    """
    context.log_artifact(
        item=Artifact(
            key="hist",
            body=b"<b> Deprecated, see the artifacts scatter-2d "
            b"and histograms instead<b>",
        ),
        local_path=f"{plots_dest}/hist.html",
    )


def _create_features_histogram_artifacts(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
    problem_type: str,
):
    """
    Create and log a histogram artifact for each feature
    """

    figs = dict()
    first_feature_name = ""
    if label_column is not None and problem_type == "classification":
        all_labels = df[label_column].unique()
    visible = True
    for column_name in df.columns:
        if column_name == label_column:
            continue

        if label_column is not None and problem_type == "classification":
            for label in all_labels:
                sub_fig = go.Histogram(
                    histfunc="count",
                    x=df.loc[df[label_column] == label][column_name],
                    name=str(label),
                    visible=visible,
                )
                figs[f"{column_name}@?@{label}"] = sub_fig
        else:
            sub_fig = go.Histogram(histfunc="count", x=df[column_name], visible=visible)
            figs[f"{column_name}@?@{1}"] = sub_fig
        if visible:
            first_feature_name = column_name
        visible = False

    fig = go.Figure()
    for k in figs.keys():
        fig.add_trace(figs[k])

    fig.update_layout(
        updatemenus=[
            {
                "buttons": [
                    {
                        "label": column_name,
                        "method": "update",
                        "args": [
                            {
                                "visible": [
                                    key.split("@?@")[0] == column_name
                                    for key in figs.keys()
                                ],
                                "xaxis": {
                                    "range": [
                                        min(df[column_name]),
                                        max(df[column_name]),
                                    ]
                                },
                            },
                            {"title": f"<i><b>Histogram of {column_name}</b></i>"},
                        ],
                    }
                    for column_name in df.columns
                    if column_name != label_column
                ],
                "direction": "down",
                "pad": {"r": 10, "t": 10},
                "showactive": True,
                "x": 0.25,
                "xanchor": "left",
                "y": 1.1,
                "yanchor": "top",
            }
        ],
        annotations=[
            dict(
                text="Select Feature Name ",
                showarrow=False,
                x=0,
                y=1.05,
                yref="paper",
                xref="paper",
                align="left",
                xanchor="left",
                yanchor="top",
                font={
                    "color": "blue",
                },
            )
        ],
    )

    fig.update_layout(
        width=600,
        height=400,
        autosize=False,
        margin=dict(t=100, b=0, l=0, r=0),
        template="plotly_white",
    )

    fig.update_layout(title_text=f"<i><b>Histograms of {first_feature_name}</b></i>")
    extra_data[f"histograms"] = context.log_artifact(
        PlotlyArtifact(key=f"histograms", figure=fig),
        local_path=f"{plots_dest}/histograms.html",
    )


def _create_features_2d_scatter_artifacts(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
    problem_type: str,
):
    """
    Create and log a scatter-2d artifact for each couple of features
    """
    features = [
        column_name for column_name in df.columns if column_name != label_column
    ]
    max_feature_len = float(max(len(elem) for elem in features))
    if label_column is not None:
        labels = sorted(df[label_column].unique())
    else:
        labels = [None]
    fig = go.Figure()
    if label_column is not None and problem_type == "classification":
        for l in labels:
            fig.add_trace(
                go.Scatter(
                    x=df.loc[df[label_column] == l][features[0]],
                    y=df.loc[df[label_column] == l][features[0]],
                    mode="markers",
                    visible=True,
                    showlegend=True,
                    name=str(l),
                )
            )
    elif label_column is None:
        fig.add_trace(
            go.Scatter(
                x=df[features[0]],
                y=df[features[0]],
                mode="markers",
                visible=True,
            )
        )
    elif problem_type == "regression":
        fig.add_trace(
            go.Scatter(
                x=df[features[0]],
                y=df[features[0]],
                mode="markers",
                marker=dict(
                    color=df[label_column], colorscale="Viridis", showscale=True
                ),
                visible=True,
            )
        )

    x_buttons = []
    y_buttons = []

    for ncol in features:
        if problem_type == "classification" and label_column is not None:
            x_buttons.append(
                dict(
                    method="update",
                    label=ncol,
                    args=[
                        {"x": [df.loc[df[label_column] == l][ncol] for l in labels]},
                        np.arange(len(labels)).tolist(),
                    ],
                )
            )

            y_buttons.append(
                dict(
                    method="update",
                    label=ncol,
                    args=[
                        {"y": [df.loc[df[label_column] == l][ncol] for l in labels]},
                        np.arange(len(labels)).tolist(),
                    ],
                )
            )
        else:
            x_buttons.append(
                dict(method="update", label=ncol, args=[{"x": [df[ncol]]}])
            )

            y_buttons.append(
                dict(method="update", label=ncol, args=[{"y": [df[ncol]]}])
            )

    # Pass buttons to the updatemenus argument
    fig.update_layout(
        updatemenus=[
            dict(buttons=x_buttons, direction="up", x=0.5, y=-0.1),
            dict(buttons=y_buttons, direction="down", x=-max_feature_len / 100, y=0.5),
        ]
    )

    fig.update_layout(
        width=600,
        height=400,
        autosize=False,
        margin=dict(t=100, b=0, l=0, r=0),
        template="plotly_white",
    )

    fig.update_layout(title_text=f"<i><b>Scatter-2d</b></i>")
    extra_data[f"scatter-2d"] = context.log_artifact(
        PlotlyArtifact(key=f"scatter-2d", figure=fig),
        local_path=f"{plots_dest}/scatter-2d.html",
    )


def _create_violin_artifact(
    context: MLClientCtx, df: pd.DataFrame, extra_data: dict, plots_dest: str
):
    """
    Create and log a violin artifact
    """
    cols = 5
    rows = (df.shape[1] // cols) + 1
    fig = make_subplots(rows=rows, cols=cols)

    plot_num = 0

    for column_name in df.columns:
        column_data = df[column_name]
        violin = go.Violin(
            x=[column_name] * column_data.shape[0],
            y=column_data,
            name=column_name,
        )

        fig.add_trace(
            violin,
            row=(plot_num // cols) + 1,
            col=(plot_num % cols) + 1,
        )

        plot_num += 1

    fig["layout"].update(
        height=(rows + 1) * 200,
        width=(cols + 1) * 200,
        title="<i><b>Violin Plots</b></i>",
    )

    fig.update_layout(showlegend=False)
    extra_data["violin"] = context.log_artifact(
        PlotlyArtifact(key="violin", figure=fig),
        local_path=f"{plots_dest}/violin.html",
    )


def _create_imbalance_artifact(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
    problem_type: str,
):
    """
    Create and log an imbalance class artifact (csv + plot)
    """
    if label_column:
        if problem_type == "classification":
            values_column = "count"
            labels_count = df[label_column].value_counts().sort_index()
            df_labels_count = pd.DataFrame(labels_count)
            df_labels_count[label_column] = labels_count.index
            df_labels_count.rename(columns={"": values_column}, inplace=True)
            df_labels_count[values_column] = df_labels_count[values_column] / sum(
                df_labels_count[values_column]
            )
            fig = px.pie(df_labels_count, names=label_column, values=values_column)
        else:
            fig = px.histogram(
                histfunc="count",
                x=df[label_column],
            )
            hist = np.histogram(df[label_column])
            df_labels_count = pd.DataFrame(
                {"min_val": hist[1], "count": hist[0].tolist() + [0]}
            )
        fig.update_layout(title_text="<i><b>Labels Imbalance</b></i>")
        extra_data["imbalance"] = context.log_artifact(
            PlotlyArtifact(key="imbalance", figure=fig),
            local_path=f"{plots_dest}/imbalance.html",
        )
        extra_data["imbalance-csv"] = context.log_artifact(
            TableArtifact("imbalance-weights-vec", df=df_labels_count),
            local_path=f"{plots_dest}/imbalance-weights-vec.csv",
        )


def _create_corr_artifact(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
):
    """
    Create and log an correlation-matrix artifact (csv + plot)
    """
    if label_column is not None:
        df = df.drop([label_column], axis=1)
    tblcorr = df.corr(numeric_only=True)
    extra_data["correlation-matrix-csv"] = context.log_artifact(
        TableArtifact("correlation-matrix-csv", df=tblcorr, visible=True),
        local_path=f"{plots_dest}/correlation-matrix.csv",
    )

    z = tblcorr.values.tolist()
    z_text = [["{:.2f}".format(y) for y in x] for x in z]
    fig = ff.create_annotated_heatmap(
        z,
        x=list(tblcorr.columns),
        y=list(tblcorr.columns),
        annotation_text=z_text,
        colorscale="agsunset",
    )
    fig["layout"]["yaxis"]["autorange"] = "reversed"  # l -> r
    fig.update_layout(title_text="<i><b>Correlation matrix</b></i>")
    fig["data"][0]["showscale"] = True

    extra_data["correlation"] = context.log_artifact(
        PlotlyArtifact(key="correlation", figure=fig),
        local_path=f"{plots_dest}/correlation.html",
    )
-    code_origin: ''
-    origin_filename: ''
-  description: describe and visualizes dataset stats
-  disable_auto_mount: false
-  default_handler: analyze
-verbose: false
-metadata:
-  tag: ''
-  name: describe
-  categories:
-  - data-analysis
-kind: job
+    code_origin: ''
+    functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Generated by nuclio.export.NuclioExporter

import warnings
from typing import Union

import mlrun
import numpy as np

warnings.simplefilter(action="ignore", category=FutureWarning)

import mlrun.feature_store as fstore
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff
import plotly.graph_objects as go
from mlrun.artifacts import (
    Artifact,
    DatasetArtifact,
    PlotlyArtifact,
    TableArtifact,
    update_dataset_meta,
)
from mlrun.datastore import DataItem
from mlrun.execution import MLClientCtx
from mlrun.feature_store import FeatureSet
from plotly.subplots import make_subplots

pd.set_option("display.float_format", lambda x: "%.2f" % x)
MAX_SIZE_OF_DF = 500000


def analyze(
    context: MLClientCtx,
    name: str = "dataset",
    table: Union[FeatureSet, DataItem] = None,
    label_column: str = None,
    plots_dest: str = "plots",
    random_state: int = 1,
    problem_type: str = "classification",
    dask_key: str = "dask_key",
    dask_function: str = None,
    dask_client=None,
) -> None:
    """
    The function will output the following artifacts per
    column within the data frame (based on data types)
    If the data has more than 500,000 samples, we
    randomly sample 500,000 samples:

    describe csv
    histograms
    scatter-2d
    violin chart
    correlation-matrix chart
    correlation-matrix csv
    imbalance pie chart
    imbalance-weights-vec csv

    :param context:                 The function context
    :param name:                    Key of dataset to database ("dataset" for default)
    :param table:                   MLRun input pointing to a pandas dataframe (csv/parquet file path) or a
                                    FeatureSet passed as a parameter
    :param label_column:            Ground truth column label
    :param plots_dest:              Destination folder of summary plots (relative to artifact_path)
                                    ("plots" for default)
    :param random_state:            Random seed used when the table has more than 500,000 samples and is
                                    randomly sampled down to 500,000 samples
    :param problem_type:            The type of ML problem the data is facing - regression, classification or None
                                    ("classification" by default)
    :param dask_key:                Key of dataframe in dask client "datasets" attribute
    :param dask_function:           Dask function url (db://..)
    :param dask_client:             Dask client object
    """
    data_item, featureset, create, update = False, False, False, False
    get_from_table = True
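    # When a dask function/client is given, prefer fetching a dataframe that was
    # already published on the dask cluster; otherwise fall back to reading `table`.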
    if dask_function or dask_client:
        data_item, create = True, True
        if dask_function:
            client = mlrun.import_function(dask_function).client
        elif dask_client:
            client = dask_client
        else:
            raise ValueError("dask client was not provided")

        if dask_key in client.datasets:
            df = client.get_dataset(dask_key)
            data_item, create, get_from_table = True, True, False
        elif table:
            get_from_table = True
        else:
            context.logger.info(
                f"dataset {dask_key} was not found; only these datasets are available in client {client}: {client.datasets}"
            )
            raise Exception("dataset not found on dask cluster")

    if get_from_table:
        if isinstance(table, DataItem):
            if table.meta is None:
                data_item, create, update = True, True, False
            elif table.meta.kind == "dataset":
                data_item, create, update = True, False, True
            elif table.meta.kind == "FeatureVector":
                data_item, create, update = True, False, False
            elif table.meta.kind == "FeatureSet":
                featureset, create, update = True, False, False

        if data_item:
            df = table.as_df()
        elif featureset:
            project_name, set_name = (
                table._path.split("/")[2],
                table._path.split("/")[4],
            )
            feature_set = fstore.get_feature_set(
                f"store://feature-sets/{project_name}/{set_name}"
            )
            df = feature_set.to_dataframe()
        else:
            context.logger.error("Wrong table type.")
            return

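    # Note: df.size counts cells (rows * columns), so the sampling below keeps
    # roughly MAX_SIZE_OF_DF cells by drawing MAX_SIZE_OF_DF / n_columns rows.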
    if df.size > MAX_SIZE_OF_DF:
        df = df.sample(n=int(MAX_SIZE_OF_DF / df.shape[1]), random_state=random_state)
    extra_data = {}

    if label_column not in df.columns:
        label_column = None

    extra_data["describe csv"] = context.log_artifact(
        TableArtifact("describe-csv", df=df.describe()),
        local_path=f"{plots_dest}/describe.csv",
    )

    try:
        _create_histogram_mat_artifact(
            context, df, extra_data, label_column, plots_dest
        )
    except Exception as e:
        context.logger.warn(f"Failed to create histogram matrix artifact due to: {e}")
    try:
        _create_features_histogram_artifacts(
            context, df, extra_data, label_column, plots_dest, problem_type
        )
    except Exception as e:
        context.logger.warn(f"Failed to create pairplot histograms due to: {e}")
    try:
        _create_features_2d_scatter_artifacts(
            context, df, extra_data, label_column, plots_dest, problem_type
        )
    except Exception as e:
        context.logger.warn(f"Failed to create pairplot 2d_scatter due to: {e}")
    try:
        _create_violin_artifact(context, df, extra_data, plots_dest)
    except Exception as e:
        context.logger.warn(f"Failed to create violin distribution plots due to: {e}")
    try:
        _create_imbalance_artifact(
            context, df, extra_data, label_column, plots_dest, problem_type
        )
    except Exception as e:
        context.logger.warn(f"Failed to create class imbalance plot due to: {e}")
    try:
        _create_corr_artifact(context, df, extra_data, label_column, plots_dest)
    except Exception as e:
        context.logger.warn(f"Failed to create features correlation plot due to: {e}")

    if not data_item:
        return

    artifact = table.artifact_url
    if create:  # dataset not stored
        artifact = DatasetArtifact(
            key="dataset", stats=True, df=df, extra_data=extra_data
        )
        artifact = context.log_artifact(artifact, db_key=name)
        context.logger.info(f"The dataset is logged to the project under the name '{name}'")

    if update:
        update_dataset_meta(artifact, extra_data=extra_data)
        context.logger.info(f"The dataset named '{name}' was updated")

    # TODO : 3-D plot on selected features.
    # TODO : Reintegration plot on selected features.
    # TODO : PCA plot (with options)

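# Illustrative usage (a minimal sketch, not part of the function code; it assumes
# this file is saved as "describe.py" and that "data.csv" exists locally):
#
#   import mlrun
#   describe_fn = mlrun.code_to_function(
#       name="describe", kind="job", filename="describe.py", handler="analyze"
#   )
#   describe_fn.run(
#       inputs={"table": "data.csv"},
#       params={"label_column": "label", "problem_type": "classification"},
#       local=True,
#   )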

def _create_histogram_mat_artifact(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
):
    """
    Create and log a histogram matrix artifact
    """
    context.log_artifact(
        item=Artifact(
            key="hist",
            body=b"<b>Deprecated, see the artifacts scatter-2d "
            b"and histograms instead</b>",
        ),
        local_path=f"{plots_dest}/hist.html",
    )


def _create_features_histogram_artifacts(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
    problem_type: str,
):
    """
    Create and log a histogram artifact for each feature
    """

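    # One histogram trace is built per feature (and per class label for
    # classification). Traces are keyed "<feature>@?@<label>" so the dropdown
    # below can match each button's feature name against the key prefix.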
    figs = dict()
    first_feature_name = ""
    if label_column is not None and problem_type == "classification":
        all_labels = df[label_column].unique()
    visible = True
    for column_name in df.columns:
        if column_name == label_column:
            continue

        if label_column is not None and problem_type == "classification":
            for label in all_labels:
                sub_fig = go.Histogram(
                    histfunc="count",
                    x=df.loc[df[label_column] == label][column_name],
                    name=str(label),
                    visible=visible,
                )
                figs[f"{column_name}@?@{label}"] = sub_fig
        else:
            sub_fig = go.Histogram(histfunc="count", x=df[column_name], visible=visible)
            figs[f"{column_name}@?@{1}"] = sub_fig
        if visible:
            first_feature_name = column_name
        visible = False

    fig = go.Figure()
    for k in figs.keys():
        fig.add_trace(figs[k])

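    # Each dropdown button carries a boolean visibility mask with one entry per
    # trace: True only for traces whose key prefix equals the selected feature.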
    fig.update_layout(
        updatemenus=[
            {
                "buttons": [
                    {
                        "label": column_name,
                        "method": "update",
                        "args": [
                            {
                                "visible": [
                                    key.split("@?@")[0] == column_name
                                    for key in figs.keys()
                                ],
                                "xaxis": {
                                    "range": [
                                        min(df[column_name]),
                                        max(df[column_name]),
                                    ]
                                },
                            },
                            {"title": f"<i><b>Histogram of {column_name}</b></i>"},
                        ],
                    }
                    for column_name in df.columns
                    if column_name != label_column
                ],
                "direction": "down",
                "pad": {"r": 10, "t": 10},
                "showactive": True,
                "x": 0.25,
                "xanchor": "left",
                "y": 1.1,
                "yanchor": "top",
            }
        ],
        annotations=[
            dict(
                text="Select Feature Name ",
                showarrow=False,
                x=0,
                y=1.05,
                yref="paper",
                xref="paper",
                align="left",
                xanchor="left",
                yanchor="top",
                font={
                    "color": "blue",
                },
            )
        ],
    )

    fig.update_layout(
        width=600,
        height=400,
        autosize=False,
        margin=dict(t=100, b=0, l=0, r=0),
        template="plotly_white",
    )

    fig.update_layout(title_text=f"<i><b>Histograms of {first_feature_name}</b></i>")
    extra_data["histograms"] = context.log_artifact(
        PlotlyArtifact(key="histograms", figure=fig),
        local_path=f"{plots_dest}/histograms.html",
    )


def _create_features_2d_scatter_artifacts(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
    problem_type: str,
):
    """
    Create and log a scatter-2d artifact for any pair of features (selectable via dropdown menus)
    """
    features = [
        column_name for column_name in df.columns if column_name != label_column
    ]
    max_feature_len = float(max(len(elem) for elem in features))
    if label_column is not None:
        labels = sorted(df[label_column].unique())
    else:
        labels = [None]
    fig = go.Figure()
    if label_column is not None and problem_type == "classification":
        for l in labels:
            fig.add_trace(
                go.Scatter(
                    x=df.loc[df[label_column] == l][features[0]],
                    y=df.loc[df[label_column] == l][features[0]],
                    mode="markers",
                    visible=True,
                    showlegend=True,
                    name=str(l),
                )
            )
    elif label_column is None:
        fig.add_trace(
            go.Scatter(
                x=df[features[0]],
                y=df[features[0]],
                mode="markers",
                visible=True,
            )
        )
    elif problem_type == "regression":
        fig.add_trace(
            go.Scatter(
                x=df[features[0]],
                y=df[features[0]],
                mode="markers",
                marker=dict(
                    color=df[label_column], colorscale="Viridis", showscale=True
                ),
                visible=True,
            )
        )

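    # Two dropdown menus update the x and y data of the existing traces
    # independently; for classification there is one trace per class, so every
    # button supplies one column per label.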
    x_buttons = []
    y_buttons = []

    for ncol in features:
        if problem_type == "classification" and label_column is not None:
            x_buttons.append(
                dict(
                    method="update",
                    label=ncol,
                    args=[
                        {"x": [df.loc[df[label_column] == l][ncol] for l in labels]},
                        np.arange(len(labels)).tolist(),
                    ],
                )
            )

            y_buttons.append(
                dict(
                    method="update",
                    label=ncol,
                    args=[
                        {"y": [df.loc[df[label_column] == l][ncol] for l in labels]},
                        np.arange(len(labels)).tolist(),
                    ],
                )
            )
        else:
            x_buttons.append(
                dict(method="update", label=ncol, args=[{"x": [df[ncol]]}])
            )

            y_buttons.append(
                dict(method="update", label=ncol, args=[{"y": [df[ncol]]}])
            )

    # Pass buttons to the updatemenus argument
    fig.update_layout(
        updatemenus=[
            dict(buttons=x_buttons, direction="up", x=0.5, y=-0.1),
            dict(buttons=y_buttons, direction="down", x=-max_feature_len / 100, y=0.5),
        ]
    )

    fig.update_layout(
        width=600,
        height=400,
        autosize=False,
        margin=dict(t=100, b=0, l=0, r=0),
        template="plotly_white",
    )

    fig.update_layout(title_text="<i><b>Scatter-2d</b></i>")
    extra_data["scatter-2d"] = context.log_artifact(
        PlotlyArtifact(key="scatter-2d", figure=fig),
        local_path=f"{plots_dest}/scatter-2d.html",
    )


def _create_violin_artifact(
    context: MLClientCtx, df: pd.DataFrame, extra_data: dict, plots_dest: str
):
    """
    Create and log a violin artifact
    """
    cols = 5
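    # One violin per dataframe column on a fixed 5-column grid; "// cols + 1"
    # below is a simple ceiling so there are always enough rows.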
    rows = (df.shape[1] // cols) + 1
    fig = make_subplots(rows=rows, cols=cols)

    plot_num = 0

    for column_name in df.columns:
        column_data = df[column_name]
        violin = go.Violin(
            x=[column_name] * column_data.shape[0],
            y=column_data,
            name=column_name,
        )

        fig.add_trace(
            violin,
            row=(plot_num // cols) + 1,
            col=(plot_num % cols) + 1,
        )

        plot_num += 1

    fig["layout"].update(
        height=(rows + 1) * 200,
        width=(cols + 1) * 200,
        title="<i><b>Violin Plots</b></i>",
    )

    fig.update_layout(showlegend=False)
    extra_data["violin"] = context.log_artifact(
        PlotlyArtifact(key="violin", figure=fig),
        local_path=f"{plots_dest}/violin.html",
    )


def _create_imbalance_artifact(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
    problem_type: str,
):
    """
    Create and log a class-imbalance artifact (csv + plot)
    """
    if label_column:
        if problem_type == "classification":
            values_column = "count"
            labels_count = df[label_column].value_counts().sort_index()
            df_labels_count = pd.DataFrame(labels_count)
            df_labels_count[label_column] = labels_count.index
            df_labels_count.rename(columns={"": values_column}, inplace=True)
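            # Normalize the counts to fractions so the logged CSV doubles as a
            # class-weight vector (hence the key "imbalance-weights-vec").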
            df_labels_count[values_column] = df_labels_count[values_column] / sum(
                df_labels_count[values_column]
            )
            fig = px.pie(df_labels_count, names=label_column, values=values_column)
        else:
            fig = px.histogram(
                histfunc="count",
                x=df[label_column],
            )
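            # np.histogram returns (counts, bin_edges); bin_edges has one more
            # element than counts, hence the trailing zero appended below.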
            hist = np.histogram(df[label_column])
            df_labels_count = pd.DataFrame(
                {"min_val": hist[1], "count": hist[0].tolist() + [0]}
            )
        fig.update_layout(title_text="<i><b>Labels Imbalance</b></i>")
        extra_data["imbalance"] = context.log_artifact(
            PlotlyArtifact(key="imbalance", figure=fig),
            local_path=f"{plots_dest}/imbalance.html",
        )
        extra_data["imbalance-csv"] = context.log_artifact(
            TableArtifact("imbalance-weights-vec", df=df_labels_count),
            local_path=f"{plots_dest}/imbalance-weights-vec.csv",
        )


def _create_corr_artifact(
    context: MLClientCtx,
    df: pd.DataFrame,
    extra_data: dict,
    label_column: str,
    plots_dest: str,
):
    """
    Create and log a correlation-matrix artifact (csv + plot)
    """
    if label_column is not None:
        df = df.drop([label_column], axis=1)
    tblcorr = df.corr(numeric_only=True)
    extra_data["correlation-matrix-csv"] = context.log_artifact(
        TableArtifact("correlation-matrix-csv", df=tblcorr, visible=True),
        local_path=f"{plots_dest}/correlation-matrix.csv",
    )

    z = tblcorr.values.tolist()
    z_text = [["{:.2f}".format(y) for y in x] for x in z]
    fig = ff.create_annotated_heatmap(
        z,
        x=list(tblcorr.columns),
        y=list(tblcorr.columns),
        annotation_text=z_text,
        colorscale="agsunset",
    )
    fig["layout"]["yaxis"]["autorange"] = "reversed"  # flip the y-axis so the matrix reads top-to-bottom
    fig.update_layout(title_text="<i><b>Correlation matrix</b></i>")
    fig["data"][0]["showscale"] = True

    extra_data["correlation"] = context.log_artifact(
        PlotlyArtifact(key="correlation", figure=fig),
        local_path=f"{plots_dest}/correlation.html",
    )
diff --git a/functions/src/describe/item.yaml b/functions/src/describe/item.yaml
index da26f1501..a1aa47372 100644
--- a/functions/src/describe/item.yaml
+++ b/functions/src/describe/item.yaml
@@ -11,7 +11,7 @@ labels:
   author: Iguazio
 maintainers: []
 marketplaceType: ''
-mlrunVersion: 1.7.0
+mlrunVersion: 1.10.0
 name: describe
 platformVersion: 3.5.3
 spec:
@@ -21,4 +21,4 @@ spec:
   kind: job
   requirements: []
   url: ''
-version: 1.4.0
+version: 1.5.0
diff --git a/functions/src/describe/requirements.txt b/functions/src/describe/requirements.txt
index 15492b176..ac445e6d6 100644
--- a/functions/src/describe/requirements.txt
+++ b/functions/src/describe/requirements.txt
@@ -1,4 +1,4 @@
-scikit-learn~=1.0.2
+scikit-learn~=1.5.2
 plotly~=5.23
 pytest~=7.0.1
 matplotlib~=3.5.1
diff --git a/functions/src/gen_class_data/function.yaml b/functions/src/gen_class_data/function.yaml
index 1769bec07..fa802964e 100644
--- a/functions/src/gen_class_data/function.yaml
+++ b/functions/src/gen_class_data/function.yaml
@@ -1,13 +1,15 @@
 metadata:
-  categories:
-  - data-generation
   tag: ''
   name: gen-class-data
+  categories:
+  - data-generation
+verbose: false
 spec:
   description: Create a binary classification sample dataset and save.
-  default_handler: gen_class_data
   entry_points:
     gen_class_data:
+      lineno: 22
+      has_varargs: false
       has_kwargs: false
       parameters:
       - name: context
@@ -48,7 +50,6 @@ spec:
       - name: sk_params
         doc: additional parameters for `sklearn.datasets.make_classification`
         default: {}
-      lineno: 22
       doc: 'Create a binary classification sample dataset and save.

         If no filename is given it will default to:
@@ -59,14 +60,13 @@ spec:
         Additional scikit-learn parameters can be set using **sk_params, please see
         https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html
         for more details.'
-      has_varargs: false
       name: gen_class_data
-  command: ''
-  disable_auto_mount: false
-  image: mlrun/mlrun
   build:
     origin_filename: ''
     functionSourceCode: IyBDb3B5cmlnaHQgMjAxOSBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgICAgaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wCiMKIyBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlCiMgZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gIkFTIElTIiBCQVNJUywKIyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KIyBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kCiMgbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuCiMKaW1wb3J0IHBhbmRhcyBhcyBwZApmcm9tIHR5cGluZyBpbXBvcnQgT3B0aW9uYWwsIExpc3QKZnJvbSBza2xlYXJuLmRhdGFzZXRzIGltcG9ydCBtYWtlX2NsYXNzaWZpY2F0aW9uCgpmcm9tIG1scnVuLmV4ZWN1dGlvbiBpbXBvcnQgTUxDbGllbnRDdHgKCgpkZWYgZ2VuX2NsYXNzX2RhdGEoCiAgICAgICAgY29udGV4dDogTUxDbGllbnRDdHgsCiAgICAgICAgbl9zYW1wbGVzOiBpbnQsCiAgICAgICAgbV9mZWF0dXJlczogaW50LAogICAgICAgIGtfY2xhc3NlczogaW50LAogICAgICAgIGhlYWRlcjogT3B0aW9uYWxbTGlzdFtzdHJdXSwKICAgICAgICBsYWJlbF9jb2x1bW46IE9wdGlvbmFsW3N0cl0gPSAibGFiZWxzIiwKICAgICAgICB3ZWlnaHQ6IGZsb2F0ID0gMC41LAogICAgICAgIHJhbmRvbV9zdGF0ZTogaW50ID0gMSwKICAgICAgICBrZXk6IHN0ciA9ICJjbGFzc2lmaWVyLWRhdGEiLAogICAgICAgIGZpbGVfZXh0OiBzdHIgPSAicGFycXVldCIsCiAgICAgICAgc2tfcGFyYW1zPXt9Cik6CiAgICAiIiJDcmVhdGUgYSBiaW5hcnkgY2xhc3NpZmljYXRpb24gc2FtcGxlIGRhdGFzZXQgYW5kIHNhdmUuCiAgICBJZiBubyBmaWxlbmFtZSBpcyBnaXZlbiBpdCB3aWxsIGRlZmF1bHQgdG86CiAgICAic2ltZGF0YS17bl9zYW1wbGVzfVh7bV9mZWF0dXJlc30ucGFycXVldCIuCgogICAgQWRkaXRpb25hbCBzY2lraXQtbGVhcm4gcGFyYW1ldGVycyBjYW4gYmUgc2V0IHVzaW5nICoqc2tfcGFyYW1zLCBwbGVhc2Ugc2VlIGh0dHBzOi8vc2Npa2l0LWxlYXJuLm9yZy9zdGFibGUvbW9kdWxlcy9nZW5lcmF0ZWQvc2tsZWFybi5kYXRhc2V0cy5tYWtlX2NsYXNzaWZpY2F0aW9uLmh0bWwgZm9yIG1vcmUgZGV0YWlscy4KCiAgICA6cGFyYW0gY29udGV4dDogICAgICAgZnVuY3Rpb24gY29udGV4dAogICAgOnBhcmFtIG5fc2FtcGxlczogICAgIG51bWJlciBvZiByb3dzL3NhbXBsZXMKICAgIDpwYXJhbSBtX2ZlYXR1cmVzOiAgICBudW1iZXIgb2YgY29scy9mZWF0dXJlcwogICAgOnBhcmFtIGtfY2xhc3NlczogICAgIG51bWJlciBvZiBjbGFzc2VzCiAgICA6cGFyYW0gaGVhZGVyOiAgICAgICAgaGVhZGVyIGZvciBmZWF0dXJlcyBhcnJheQogICAgOnBhcmFtIGxhYmVsX2NvbHVtbjogIGNvbHVtbiBuYW1lIG9mIGdyb3VuZC10cnV0aCBzZXJpZXMKICAgIDpwYXJhbSB3ZWlnaHQ6ICAgICAgICBmcmFjdGlvbiBvZiBzYW1wbGUgbmVnYXRpdmUgdmFsdWUgKGdyb3VuZC10cnV0aD0wKQogICAgOnBhcmFtIHJhbmRvbV9zdGF0ZTogIHJuZyBzZWVkIChzZWUgaHR0cHM6Ly9zY2lraXQtbGVhcm4ub3JnL3N0YWJsZS9nbG9zc2FyeS5odG1sI3Rlcm0tcmFuZG9tLXN0YXRlKQogICAgOnBhcmFtIGtleTogICAgICAgICAgIGtleSBvZiBkYXRhIGluIGFydGlmYWN0IHN0b3JlCiAgICA6cGFyYW0gZmlsZV9leHQ6ICAgICAgKHBxdCkgZXh0ZW5zaW9uIGZvciBwYXJxdWV0IGZpbGUKICAgIDpwYXJhbSBza19wYXJhbXM6ICAgICBhZGRpdGlvbmFsIHBhcmFtZXRlcnMgZm9yIGBza2xlYXJuLmRhdGFzZXRzLm1ha2VfY2xhc3NpZmljYXRpb25gCiAgICAiIiIKICAgIGZlYXR1cmVzLCBsYWJlbHMgPSBtYWtlX2NsYXNzaWZpY2F0aW9uKAogICAgICAgIG5fc2FtcGxlcz1uX3NhbXBsZXMsCiAgICAgICAgbl9mZWF0dXJlcz1tX2ZlYXR1cmVzLAogICAgICAgIHdlaWdodHM9d2VpZ2h0LAogICAgICAgIG5fY2xhc3Nlcz1rX2NsYXNzZXMsCiAgICAgICAgcmFuZG9tX3N0YXRlPXJhbmRvbV9zdGF0ZSwKICAgICAgICAqKnNrX3BhcmFtcykKCiAgICAjIG1ha2UgZGF0YWZyYW1lcywgYWRkIGNvbHVtbiBuYW1lcywgY29uY2F0ZW5hdGUgKFgsIHkpCiAgICBYID0gcGQuRGF0YUZyYW1lKGZlYXR1cmVzKQogICAgaWYgbm90IGhlYWRlcjoKICAgICAgICBYLmNvbHVtbnMgPSBbImZlYXRfIiArIHN0cih4KSBmb3IgeCBpbiByYW5nZShtX2ZlYXR1cmVzKV0KICAgIGVsc2U6CiAgICAgICAgWC5jb2x1bW5zID0gaGVhZGVyCgogICAgeSA9IHBkLkRhdGFGcmFtZShsYWJlbHMsIGNvbHVtbnM9W2xhYmVsX2NvbHVtbl0pCiAgICBkYXRhID0gcGQuY29uY2F0KFtYLCB5XSwgYXhpcz0xKQoKICAgIGNvbnRleHQubG9nX2RhdGFzZXQoa2V5LCBkZj1kYXRhLCBmb3JtYXQ9ZmlsZV9leHQsIGluZGV4PUZhbHNlKQo=
     code_origin: ''
+  command: ''
+  image: mlrun/mlrun
+  default_handler: gen_class_data
+  disable_auto_mount: false
 kind: job
-verbose: false
diff --git a/functions/src/gen_class_data/item.yaml b/functions/src/gen_class_data/item.yaml
index 30f5cd21c..082b00305 100644
--- a/functions/src/gen_class_data/item.yaml
+++ b/functions/src/gen_class_data/item.yaml
@@ -11,7 +11,7 @@ labels:
   author: Iguazio
 maintainers: []
 marketplaceType: ''
-mlrunVersion: 1.7.0
+mlrunVersion: 1.10.0
 name: gen_class_data
 platformVersion: 3.5.3
 spec:
@@ -21,4 +21,4 @@ spec:
   kind: job
   requirements: []
   url: ''
-version: 1.3.0
+version: 1.4.0
diff --git a/functions/src/gen_class_data/requirements.txt b/functions/src/gen_class_data/requirements.txt
index d7dbe376b..e265290f6 100644
--- a/functions/src/gen_class_data/requirements.txt
+++ b/functions/src/gen_class_data/requirements.txt
@@ -1,2 +1,2 @@
 pandas
-scikit-learn==1.0.2
\ No newline at end of file
+scikit-learn~=1.5.2
\ No newline at end of file
diff --git a/functions/src/gen_class_data/test_gen_class_data.py b/functions/src/gen_class_data/test_gen_class_data.py
index e06eeb16b..990075dec 100644
--- a/functions/src/gen_class_data/test_gen_class_data.py
+++ b/functions/src/gen_class_data/test_gen_class_data.py
@@ -36,4 +36,7 @@ def test_gen_class_data():
         local=True,
         artifact_path="./artifacts",
     )
-    assert os.path.isfile(run.status.artifacts[0]['spec']['target_path']), 'dataset is not available'
+    # In local mode, artifacts are in function-name/iteration subdirectory
+    # Default key is "classifier-data" (can be overridden in params)
+    dataset_path = "./artifacts/test-gen-class-data-gen-class-data/0/classifier-data.csv"
+    assert os.path.isfile(dataset_path), f'dataset is not available at {dataset_path}'
diff --git a/functions/src/onnx_utils/function.yaml b/functions/src/onnx_utils/function.yaml
index 05a0f0bc2..091002cdc 100644
--- a/functions/src/onnx_utils/function.yaml
+++ b/functions/src/onnx_utils/function.yaml
@@ -1,39 +1,13 @@
-kind: job
 metadata:
+  name: onnx-utils
+  tag: ''
   categories:
   - utilities
   - deep-learning
-  name: onnx-utils
-  tag: ''
-verbose: false
+kind: job
 spec:
-  build:
-    code_origin: ''
-    base_image: mlrun/mlrun
-    origin_filename: ''
-    functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Dict, List, Tuple

import mlrun


class _ToONNXConversions:
    """
    An ONNX conversion functions library class.
    """

    @staticmethod
    def tf_keras_to_onnx(
        model_handler,
        onnx_model_name: str = None,
        optimize_model: bool = True,
        input_signature: List[Tuple[Tuple[int], str]] = None,
    ):
        """
        Convert a TF.Keras model to an ONNX model and log it back to MLRun as a new model object.

        :param model_handler:   An initialized TFKerasModelHandler with a loaded model to convert to ONNX.
        :param onnx_model_name: The name to use to log the converted ONNX model. If not given, the given `model_name`
                                will be used with an additional suffix `_onnx`. Defaulted to None.
        :param optimize_model:  Whether or not to optimize the ONNX model using 'onnxoptimizer' before saving the model.
                                Defaulted to True.
        :param input_signature: A list of the input layers shape and data type properties. Expected to receive a list
                                where each element is an input layer tuple. An input layer tuple is a tuple of:
                                [0] = Layer's shape, a tuple of integers.
                                [1] = Layer's data type, a mlrun.data_types.ValueType string.
                                If None, the input signature will be tried to be read from the model artifact. Defaulted
                                to None.
        """
        # Import the framework and handler:
        import tensorflow as tf
        from mlrun.frameworks.tf_keras import TFKerasUtils

        # Check the given 'input_signature' parameter:
        if input_signature is None:
            # Read the inputs from the model:
            try:
                model_handler.read_inputs_from_model()
            except Exception as error:
                raise mlrun.errors.MLRunRuntimeError(
                    f"Please provide the 'input_signature' parameter. The function tried reading the input layers "
                    f"information automatically but failed with the following error: {error}"
                )
        else:
            # Parse the 'input_signature' parameter:
            input_signature = [
                tf.TensorSpec(
                    shape=shape,
                    dtype=TFKerasUtils.convert_value_type_to_tf_dtype(
                        value_type=value_type
                    ),
                )
                for (shape, value_type) in input_signature
            ]

        # Convert to ONNX:
        model_handler.to_onnx(
            model_name=onnx_model_name,
            input_signature=input_signature,
            optimize=optimize_model,
        )

    @staticmethod
    def pytorch_to_onnx(
        model_handler,
        onnx_model_name: str = None,
        optimize_model: bool = True,
        input_signature: List[Tuple[Tuple[int, ...], str]] = None,
        input_layers_names: List[str] = None,
        output_layers_names: List[str] = None,
        dynamic_axes: Dict[str, Dict[int, str]] = None,
        is_batched: bool = True,
    ):
        """
        Convert a PyTorch model to an ONNX model and log it back to MLRun as a new model object.

        :param model_handler:       An initialized PyTorchModelHandler with a loaded model to convert to ONNX.
        :param onnx_model_name:     The name to use to log the converted ONNX model. If not given, the given
                                    `model_name` will be used with an additional suffix `_onnx`. Defaulted to None.
        :param optimize_model:      Whether or not to optimize the ONNX model using 'onnxoptimizer' before saving the
                                    model. Defaulted to True.
        :param input_signature:     A list of the input layers shape and data type properties. Expected to receive a
                                    list where each element is an input layer tuple. An input layer tuple is a tuple of:
                                    [0] = Layer's shape, a tuple of integers.
                                    [1] = Layer's data type, a mlrun.data_types.ValueType string.
                                    If None, the input signature will be tried to be read from the model artifact.
                                    Defaulted to None.
        :param input_layers_names:  List of names to assign to the input nodes of the graph in order. All of the other
                                    parameters (inner layers) can be set as well by passing additional names in the
                                    list. The order is by the order of the parameters in the model. If None, the inputs
                                    will be read from the handler's inputs. If its also None, it is defaulted to:
                                    "input_0", "input_1", ...
        :param output_layers_names: List of names to assign to the output nodes of the graph in order. If None, the
                                    outputs will be read from the handler's outputs. If its also None, it is defaulted
                                    to: "output_0" (for multiple outputs, this parameter must be provided).
        :param dynamic_axes:        If part of the input / output shape is dynamic, like (batch_size, 3, 32, 32) you can
                                    specify it by giving a dynamic axis to the input / output layer by its name as
                                    follows: {
                                        "input layer name": {0: "batch_size"},
                                        "output layer name": {0: "batch_size"},
                                    }
                                    If provided, the 'is_batched' flag will be ignored. Defaulted to None.
        :param is_batched:          Whether to include a batch size as the first axis in every input and output layer.
                                    Defaulted to True. Will be ignored if 'dynamic_axes' is provided.
        """
        # Import the framework and handler:
        import torch
        from mlrun.frameworks.pytorch import PyTorchUtils

        # Parse the 'input_signature' parameter:
        if input_signature is not None:
            input_signature = tuple(
                [
                    torch.zeros(
                        size=shape,
                        dtype=PyTorchUtils.convert_value_type_to_torch_dtype(
                            value_type=value_type
                        ),
                    )
                    for (shape, value_type) in input_signature
                ]
            )

        # Convert to ONNX:
        model_handler.to_onnx(
            model_name=onnx_model_name,
            input_sample=input_signature,
            optimize=optimize_model,
            input_layers_names=input_layers_names,
            output_layers_names=output_layers_names,
            dynamic_axes=dynamic_axes,
            is_batched=is_batched,
        )


# Map for getting the conversion function according to the provided framework:
_CONVERSION_MAP = {
    "tensorflow.keras": _ToONNXConversions.tf_keras_to_onnx,
    "torch": _ToONNXConversions.pytorch_to_onnx,
}  # type: Dict[str, Callable]


def to_onnx(
    context: mlrun.MLClientCtx,
    model_path: str,
    load_model_kwargs: dict = None,
    onnx_model_name: str = None,
    optimize_model: bool = True,
    framework_kwargs: Dict[str, Any] = None,
):
    """
    Convert the given model to an ONNX model.

    :param context:           The MLRun function execution context
    :param model_path:        The model path store object.
    :param load_model_kwargs: Keyword arguments to pass to the `AutoMLRun.load_model` method.
    :param onnx_model_name:   The name to use to log the converted ONNX model. If not given, the given `model_name` will
                              be used with an additional suffix `_onnx`. Defaulted to None.
    :param optimize_model:    Whether to optimize the ONNX model using 'onnxoptimizer' before saving the model.
                              Defaulted to True.
    :param framework_kwargs:  Additional arguments each framework may require to convert to ONNX. To get the doc string
                              of the desired framework onnx conversion function, pass "help".
    """
    from mlrun.frameworks.auto_mlrun.auto_mlrun import AutoMLRun

    # Get a model handler of the required framework:
    load_model_kwargs = load_model_kwargs or {}
    model_handler = AutoMLRun.load_model(
        model_path=model_path, context=context, **load_model_kwargs
    )

    # Get the model's framework:
    framework = model_handler.FRAMEWORK_NAME

    # Use the conversion map to get the specific framework to onnx conversion:
    if framework not in _CONVERSION_MAP:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"The following framework: '{framework}', has no ONNX conversion."
        )
    conversion_function = _CONVERSION_MAP[framework]

    # Check if needed to print the function's doc string ("help" is passed):
    if framework_kwargs == "help":
        print(conversion_function.__doc__)
        return

    # Set the default empty framework kwargs if needed:
    if framework_kwargs is None:
        framework_kwargs = {}

    # Run the conversion:
    try:
        conversion_function(
            model_handler=model_handler,
            onnx_model_name=onnx_model_name,
            optimize_model=optimize_model,
            **framework_kwargs,
        )
    except TypeError as exception:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"ERROR: A TypeError exception was raised during the conversion:\n{exception}. "
            f"Please read the {framework} framework conversion function doc string by passing 'help' in the "
            f"'framework_kwargs' dictionary parameter."
        )


def optimize(
    context: mlrun.MLClientCtx,
    model_path: str,
    handler_init_kwargs: dict = None,
    optimizations: List[str] = None,
    fixed_point: bool = False,
    optimized_model_name: str = None,
):
    """
    Optimize the given ONNX model.

    :param context:              The MLRun function execution context.
    :param model_path:           Path to the ONNX model object.
    :param handler_init_kwargs:  Keyword arguments to pass to the `ONNXModelHandler` init method preloading.
    :param optimizations:        List of possible optimizations. To see what optimizations are available, pass "help".
                                 If None, all the optimizations will be used. Defaulted to None.
    :param fixed_point:          Optimize the weights using fixed point. Defaulted to False.
    :param optimized_model_name: The name of the optimized model. If None, the original model will be overridden.
                                 Defaulted to None.
    """
    # Import the model handler:
    import onnxoptimizer
    from mlrun.frameworks.onnx import ONNXModelHandler

    # Check if needed to print the available optimizations ("help" is passed):
    if optimizations == "help":
        available_passes = "\n* ".join(onnxoptimizer.get_available_passes())
        print(f"The available optimizations are:\n* {available_passes}")
        return

    # Create the model handler:
    handler_init_kwargs = handler_init_kwargs or {}
    model_handler = ONNXModelHandler(
        model_path=model_path, context=context, **handler_init_kwargs
    )

    # Load the ONNX model:
    model_handler.load()

    # Optimize the model using the given configurations:
    model_handler.optimize(optimizations=optimizations, fixed_point=fixed_point)

    # Rename if needed:
    if optimized_model_name is not None:
        model_handler.set_model_name(model_name=optimized_model_name)

    # Log the optimized model:
    model_handler.log()
-    requirements:
-    - tqdm~=4.67.1
-    - tensorflow~=2.19.0
-    - tf_keras~=2.19.0
-    - torch~=2.6.0
-    - torchvision~=0.21.0
-    - onnx~=1.17.0
-    - onnxruntime~=1.19.2
-    - onnxoptimizer~=0.3.13
-    - onnxmltools~=1.13.0
-    - tf2onnx~=1.16.1
-    - plotly~=5.23
-    with_mlrun: false
-    auto_build: true
-  disable_auto_mount: false
-  description: ONNX intigration in MLRun, some utils functions for the ONNX framework,
-    optimizing and converting models from different framework to ONNX using MLRun.
-  image: ''
   entry_points:
     tf_keras_to_onnx:
-      doc: Convert a TF.Keras model to an ONNX model and log it back to MLRun as a
-        new model object.
       name: tf_keras_to_onnx
       parameters:
       - name: model_handler
@@ -58,12 +32,12 @@ spec:
           data type, a mlrun.data_types.ValueType string. If None, the input signature
           will be tried to be read from the model artifact. Defaulted to None.'
         default: null
+      doc: Convert a TF.Keras model to an ONNX model and log it back to MLRun as a
+        new model object.
+      lineno: 26
       has_varargs: false
       has_kwargs: false
-      lineno: 26
     pytorch_to_onnx:
-      doc: Convert a PyTorch model to an ONNX model and log it back to MLRun as a
-        new model object.
       name: pytorch_to_onnx
       parameters:
       - name: model_handler
@@ -116,11 +90,12 @@ spec:
          doc: Whether to include a batch size as the first axis in every input and
            output layer. Defaulted to True. Will be ignored if 'dynamic_axes' is provided.
         default: true
+      doc: Convert a PyTorch model to an ONNX model and log it back to MLRun as a
+        new model object.
+      lineno: 81
       has_varargs: false
       has_kwargs: false
-      lineno: 81
     to_onnx:
-      doc: Convert the given model to an ONNX model.
       name: to_onnx
       parameters:
      - name: context
@@ -150,11 +125,11 @@ spec:
          get the doc string of the desired framework onnx conversion function, pass
          "help".
        default: null
+      doc: Convert the given model to an ONNX model.
+      lineno: 160
       has_varargs: false
       has_kwargs: false
-      lineno: 160
     optimize:
-      doc: Optimize the given ONNX model.
       name: optimize
       parameters:
      - name: context
@@ -181,9 +156,34 @@ spec:
          doc: The name of the optimized model. If None, the original model will be
            overridden. Defaulted to None.
        default: null
+      doc: Optimize the given ONNX model.
+      lineno: 224
       has_varargs: false
       has_kwargs: false
-      lineno: 224
+  image: ''
   default_handler: to_onnx
   allow_empty_resources: true
   command: ''
+  disable_auto_mount: false
+  description: ONNX integration in MLRun, some utility functions for the ONNX framework,
+    optimizing and converting models from different frameworks to ONNX using MLRun.
+  build:
+    functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Dict, List, Tuple

import mlrun


class _ToONNXConversions:
    """
    An ONNX conversion functions library class.
    """

    @staticmethod
    def tf_keras_to_onnx(
        model_handler,
        onnx_model_name: str = None,
        optimize_model: bool = True,
        input_signature: List[Tuple[Tuple[int], str]] = None,
    ):
        """
        Convert a TF.Keras model to an ONNX model and log it back to MLRun as a new model object.

        :param model_handler:   An initialized TFKerasModelHandler with a loaded model to convert to ONNX.
        :param onnx_model_name: The name to use to log the converted ONNX model. If not given, the given `model_name`
                                will be used with an additional suffix `_onnx`. Defaulted to None.
        :param optimize_model:  Whether or not to optimize the ONNX model using 'onnxoptimizer' before saving the model.
                                Defaulted to True.
        :param input_signature: A list of the input layers shape and data type properties. Expected to receive a list
                                where each element is an input layer tuple. An input layer tuple is a tuple of:
                                [0] = Layer's shape, a tuple of integers.
                                [1] = Layer's data type, a mlrun.data_types.ValueType string.
                                If None, the function will try to read the input signature from the model artifact.
                                Defaulted to None.
        """
        # Import the framework and handler:
        import tensorflow as tf
        from mlrun.frameworks.tf_keras import TFKerasUtils

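        # Example `input_signature` (hypothetical values): [((1, 28, 28, 1), "float")]
        # is parsed below into [tf.TensorSpec(shape=(1, 28, 28, 1), dtype=...)].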
        # Check the given 'input_signature' parameter:
        if input_signature is None:
            # Read the inputs from the model:
            try:
                model_handler.read_inputs_from_model()
            except Exception as error:
                raise mlrun.errors.MLRunRuntimeError(
                    f"Please provide the 'input_signature' parameter. The function tried reading the input layers "
                    f"information automatically but failed with the following error: {error}"
                )
        else:
            # Parse the 'input_signature' parameter:
            input_signature = [
                tf.TensorSpec(
                    shape=shape,
                    dtype=TFKerasUtils.convert_value_type_to_tf_dtype(
                        value_type=value_type
                    ),
                )
                for (shape, value_type) in input_signature
            ]

        # Convert to ONNX:
        model_handler.to_onnx(
            model_name=onnx_model_name,
            input_signature=input_signature,
            optimize=optimize_model,
        )

    @staticmethod
    def pytorch_to_onnx(
        model_handler,
        onnx_model_name: str = None,
        optimize_model: bool = True,
        input_signature: List[Tuple[Tuple[int, ...], str]] = None,
        input_layers_names: List[str] = None,
        output_layers_names: List[str] = None,
        dynamic_axes: Dict[str, Dict[int, str]] = None,
        is_batched: bool = True,
    ):
        """
        Convert a PyTorch model to an ONNX model and log it back to MLRun as a new model object.

        :param model_handler:       An initialized PyTorchModelHandler with a loaded model to convert to ONNX.
        :param onnx_model_name:     The name to use to log the converted ONNX model. If not given, the given
                                    `model_name` will be used with an additional suffix `_onnx`. Defaulted to None.
        :param optimize_model:      Whether or not to optimize the ONNX model using 'onnxoptimizer' before saving the
                                    model. Defaulted to True.
        :param input_signature:     A list of the input layers shape and data type properties. Expected to receive a
                                    list where each element is an input layer tuple. An input layer tuple is a tuple of:
                                    [0] = Layer's shape, a tuple of integers.
                                    [1] = Layer's data type, a mlrun.data_types.ValueType string.
                                    If None, the function will try to read the input signature from the model
                                    artifact. Defaulted to None.
        :param input_layers_names:  List of names to assign to the input nodes of the graph in order. All of the other
                                    parameters (inner layers) can be set as well by passing additional names in the
                                    list. The order is by the order of the parameters in the model. If None, the inputs
                                    will be read from the handler's inputs. If it's also None, it defaults to:
                                    "input_0", "input_1", ...
        :param output_layers_names: List of names to assign to the output nodes of the graph in order. If None, the
                                    outputs will be read from the handler's outputs. If it's also None, it defaults
                                    to: "output_0" (for multiple outputs, this parameter must be provided).
        :param dynamic_axes:        If part of the input / output shape is dynamic, like (batch_size, 3, 32, 32), you can
                                    specify it by giving a dynamic axis to the input / output layer by its name as
                                    follows: {
                                        "input layer name": {0: "batch_size"},
                                        "output layer name": {0: "batch_size"},
                                    }
                                    If provided, the 'is_batched' flag will be ignored. Defaulted to None.
        :param is_batched:          Whether to include a batch size as the first axis in every input and output layer.
                                    Defaulted to True. Will be ignored if 'dynamic_axes' is provided.
        """
        # Import the framework and handler:
        import torch
        from mlrun.frameworks.pytorch import PyTorchUtils

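        # Example `input_signature` (hypothetical values): [((1, 3, 32, 32), "float")]
        # becomes a tuple of zero tensors used as the ONNX export input sample.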
        # Parse the 'input_signature' parameter:
        if input_signature is not None:
            input_signature = tuple(
                [
                    torch.zeros(
                        size=shape,
                        dtype=PyTorchUtils.convert_value_type_to_torch_dtype(
                            value_type=value_type
                        ),
                    )
                    for (shape, value_type) in input_signature
                ]
            )

        # Convert to ONNX:
        model_handler.to_onnx(
            model_name=onnx_model_name,
            input_sample=input_signature,
            optimize=optimize_model,
            input_layers_names=input_layers_names,
            output_layers_names=output_layers_names,
            dynamic_axes=dynamic_axes,
            is_batched=is_batched,
        )


# Map for getting the conversion function according to the provided framework:
_CONVERSION_MAP = {
    "tensorflow.keras": _ToONNXConversions.tf_keras_to_onnx,
    "torch": _ToONNXConversions.pytorch_to_onnx,
}  # type: Dict[str, Callable]
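# To support another framework, register it here; the key must match the
# FRAMEWORK_NAME of the corresponding MLRun model handler (see `to_onnx` below).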


def to_onnx(
    context: mlrun.MLClientCtx,
    model_path: str,
    load_model_kwargs: dict = None,
    onnx_model_name: str = None,
    optimize_model: bool = True,
    framework_kwargs: Dict[str, Any] = None,
):
    """
    Convert the given model to an ONNX model.

    :param context:           The MLRun function execution context
    :param model_path:        The store path / URI of the model object.
    :param load_model_kwargs: Keyword arguments to pass to the `AutoMLRun.load_model` method.
    :param onnx_model_name:   The name to use to log the converted ONNX model. If not given, the given `model_name` will
                              be used with an additional suffix `_onnx`. Defaulted to None.
    :param optimize_model:    Whether to optimize the ONNX model using 'onnxoptimizer' before saving the model.
                              Defaulted to True.
    :param framework_kwargs:  Additional arguments each framework may require for the ONNX conversion. To print the
                              doc string of the relevant framework's conversion function, pass "help".
    """
    from mlrun.frameworks.auto_mlrun.auto_mlrun import AutoMLRun

    # Get a model handler of the required framework:
    load_model_kwargs = load_model_kwargs or {}
    model_handler = AutoMLRun.load_model(
        model_path=model_path, context=context, **load_model_kwargs
    )

    # Get the model's framework:
    framework = model_handler.FRAMEWORK_NAME

    # Use the conversion map to get the framework-specific ONNX conversion function:
    if framework not in _CONVERSION_MAP:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"The following framework: '{framework}', has no ONNX conversion."
        )
    conversion_function = _CONVERSION_MAP[framework]

    # Print the conversion function's doc string if "help" was passed:
    if framework_kwargs == "help":
        print(conversion_function.__doc__)
        return

    # Set the default empty framework kwargs if needed:
    if framework_kwargs is None:
        framework_kwargs = {}

    # Run the conversion:
    try:
        conversion_function(
            model_handler=model_handler,
            onnx_model_name=onnx_model_name,
            optimize_model=optimize_model,
            **framework_kwargs,
        )
    except TypeError as exception:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"ERROR: A TypeError exception was raised during the conversion:\n{exception}. "
            f"Please read the {framework} framework conversion function doc string by passing 'help' in the "
            f"'framework_kwargs' dictionary parameter."
        )
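
# A minimal usage sketch of running 'to_onnx' through MLRun (the project, model names
# and input shape here are illustrative, not part of this module):
#
#     import mlrun
#
#     fn = mlrun.import_function("hub://onnx_utils")
#     fn.run(
#         handler="to_onnx",
#         params={
#             "model_path": "store://models/my-project/my_model",
#             "onnx_model_name": "my_model_onnx",
#             # PyTorch models need an input signature for the export tracing:
#             "framework_kwargs": {"input_signature": [((1, 3, 224, 224), "float32")]},
#         },
#         local=True,
#     )
#
# Passing framework_kwargs="help" instead prints the doc string of the framework's
# conversion function and returns without converting.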


def optimize(
    context: mlrun.MLClientCtx,
    model_path: str,
    handler_init_kwargs: dict = None,
    optimizations: List[str] = None,
    fixed_point: bool = False,
    optimized_model_name: str = None,
):
    """
    Optimize the given ONNX model.

    :param context:              The MLRun function execution context.
    :param model_path:           Path to the ONNX model object.
    :param handler_init_kwargs:  Keyword arguments to pass to the `ONNXModelHandler` init method before loading.
    :param optimizations:        List of optimizations to apply. To see the available optimizations, pass "help".
                                 If None, all of the available optimizations will be used. Defaulted to None.
    :param fixed_point:          Optimize the weights using fixed point. Defaulted to False.
    :param optimized_model_name: The name of the optimized model. If None, the original model will be overwritten.
                                 Defaulted to None.
    """
    # Import the optimizer and the model handler:
    import onnxoptimizer
    from mlrun.frameworks.onnx import ONNXModelHandler

    # Print the available optimizations if "help" was passed:
    if optimizations == "help":
        available_passes = "\n* ".join(onnxoptimizer.get_available_passes())
        print(f"The available optimizations are:\n* {available_passes}")
        return

    # Create the model handler:
    handler_init_kwargs = handler_init_kwargs or {}
    model_handler = ONNXModelHandler(
        model_path=model_path, context=context, **handler_init_kwargs
    )

    # Load the ONNX model:
    model_handler.load()

    # Optimize the model using the given configurations:
    model_handler.optimize(optimizations=optimizations, fixed_point=fixed_point)

    # Rename if needed:
    if optimized_model_name is not None:
        model_handler.set_model_name(model_name=optimized_model_name)

    # Log the optimized model:
    model_handler.log()
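
# A minimal usage sketch of running 'optimize' (the model path and names are
# illustrative, not part of this module):
#
#     import mlrun
#
#     fn = mlrun.import_function("hub://onnx_utils")
#     fn.run(
#         handler="optimize",
#         params={
#             "model_path": "store://models/my-project/my_model_onnx",
#             "handler_init_kwargs": {"model_name": "my_model_onnx"},
#             "optimized_model_name": "my_model_onnx_optimized",
#         },
#         local=True,
#     )
#
# Passing optimizations="help" instead prints the available onnxoptimizer passes.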
 + base_image: mlrun/mlrun + with_mlrun: false + auto_build: true + requirements: + - tqdm~=4.67.1 + - tensorflow~=2.19.0 + - tf_keras~=2.19.0 + - torch~=2.8.0 + - torchvision~=0.23.0 + - onnx~=1.17.0 + - onnxruntime~=1.19.2 + - onnxoptimizer~=0.3.13 + - onnxmltools~=1.13.0 + - tf2onnx~=1.16.1 + - plotly~=5.23 + origin_filename: '' + code_origin: '' +verbose: false diff --git a/functions/src/onnx_utils/item.yaml b/functions/src/onnx_utils/item.yaml index 803bd2599..5f129389f 100644 --- a/functions/src/onnx_utils/item.yaml +++ b/functions/src/onnx_utils/item.yaml @@ -13,7 +13,7 @@ labels: author: Iguazio maintainers: [] marketplaceType: '' -mlrunVersion: 1.7.2 +mlrunVersion: 1.10.0 name: onnx_utils platformVersion: 3.5.0 spec: @@ -30,8 +30,8 @@ spec: - tqdm~=4.67.1 - tensorflow~=2.19.0 - tf_keras~=2.19.0 - - torch~=2.6.0 - - torchvision~=0.21.0 + - torch~=2.8.0 + - torchvision~=0.23.0 - onnx~=1.17.0 - onnxruntime~=1.19.2 - onnxoptimizer~=0.3.13 @@ -39,4 +39,4 @@ spec: - tf2onnx~=1.16.1 - plotly~=5.23 url: '' -version: 1.3.0 +version: 1.4.0 diff --git a/functions/src/onnx_utils/onnx_utils.ipynb b/functions/src/onnx_utils/onnx_utils.ipynb index 78203a45d..14c810fab 100644 --- a/functions/src/onnx_utils/onnx_utils.ipynb +++ b/functions/src/onnx_utils/onnx_utils.ipynb @@ -77,9 +77,9 @@ "source": [ "### 1.2. Demo\n", "\n", - "We will use the `TF.Keras` framework, a `MobileNetV2` as our model and we will convert it to ONNX using the `to_onnx` handler.\n", + "We will use the `PyTorch` framework, a `MobileNetV2` as our model and we will convert it to ONNX using the `to_onnx` handler.\n", "\n", - "1.2.1. First we will set a temporary artifact path for our model to be saved in and choose the models names:" + "1.2.1. First we will set the artifact path for our model to be saved in and choose the models names:" ] }, { @@ -87,16 +87,21 @@ "metadata": { "pycharm": { "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2026-02-10T14:13:28.256582Z", + "start_time": "2026-02-10T14:13:28.250886Z" } }, "source": [ "import os\n", - "os.environ[\"TF_USE_LEGACY_KERAS\"] = \"true\"\n", - "from tempfile import TemporaryDirectory\n", + "import tempfile\n", + "# Use a temporary directory for model artifacts (safe cleanup):\n", + "ARTIFACT_PATH = tempfile.mkdtemp()\n", + "os.environ[\"MLRUN_ARTIFACT_PATH\"] = ARTIFACT_PATH\n", "\n", - "# Create a temporary directory for the model artifact:\n", - "ARTIFACT_PATH = TemporaryDirectory().name\n", - "os.makedirs(ARTIFACT_PATH)\n", + "# Project name:\n", + "PROJECT_NAME = \"onnx-utils\"\n", "\n", "# Choose our model's name:\n", "MODEL_NAME = \"mobilenetv2\"\n", @@ -108,7 +113,7 @@ "OPTIMIZED_ONNX_MODEL_NAME = \"optimized_onnx_mobilenetv2\"" ], "outputs": [], - "execution_count": null + "execution_count": 1 }, { "cell_type": "markdown", @@ -118,87 +123,88 @@ } }, "source": [ - "1.2.2. Download the model from `keras.applications` and log it with MLRun's `TFKerasModelHandler`:" + "1.2.2. 
Download the model from `torchvision.models` and log it with MLRun's `PyTorchModelHandler`:" ] }, { - "cell_type": "code", "metadata": { - "pycharm": { - "name": "#%%\n" + "ExecuteTime": { + "end_time": "2026-02-10T14:00:15.032590Z", + "start_time": "2026-02-10T14:00:15.031196Z" } }, - "source": [ - "# mlrun: start-code" - ], + "cell_type": "code", + "source": "# mlrun: start-code", "outputs": [], - "execution_count": null + "execution_count": 8 }, { + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-10T14:14:00.992001Z", + "start_time": "2026-02-10T14:13:33.115438Z" + } + }, "cell_type": "code", - "metadata": {}, "source": [ - "from tensorflow import keras\n", + "import torchvision\n", "\n", "import mlrun\n", - "import mlrun.frameworks.tf_keras as mlrun_tf_keras\n", + "from mlrun.frameworks.pytorch import PyTorchModelHandler\n", "\n", "\n", "def get_model(context: mlrun.MLClientCtx, model_name: str):\n", " # Download the MobileNetV2 model:\n", - " model = keras.applications.mobilenet_v2.MobileNetV2()\n", + " model = torchvision.models.mobilenet_v2()\n", "\n", " # Initialize a model handler for logging the model:\n", - " model_handler = mlrun_tf_keras.TFKerasModelHandler(\n", + " model_handler = PyTorchModelHandler(\n", " model_name=model_name,\n", " model=model,\n", - " context=context\n", + " model_class=\"mobilenet_v2\",\n", + " modules_map={\"torchvision.models\": \"mobilenet_v2\"},\n", + " context=context,\n", " )\n", "\n", " # Log the model:\n", " model_handler.log()" ], "outputs": [], - "execution_count": null + "execution_count": 2 }, { - "cell_type": "code", "metadata": { - "pycharm": { - "name": "#%%\n" + "ExecuteTime": { + "end_time": "2026-02-10T14:00:15.040221Z", + "start_time": "2026-02-10T14:00:15.038886Z" } }, - "source": [ - "# mlrun: end-code" - ], + "cell_type": "code", + "source": "# mlrun: end-code", "outputs": [], - "execution_count": null - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "1.2.3. 
Create the function using MLRun's `code_to_function` and run it:" - ] + "execution_count": 10 }, { "cell_type": "code", "metadata": { "pycharm": { "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2026-02-10T14:14:34.429194Z", + "start_time": "2026-02-10T14:14:07.906087Z" } }, "source": [ "import mlrun\n", "\n", + "# Create or get the MLRun project:\n", + "project = mlrun.get_or_create_project(PROJECT_NAME, context=\"./\")\n", "\n", "# Create the function parsing this notebook's code using 'code_to_function':\n", "get_model_function = mlrun.code_to_function(\n", " name=\"get_mobilenetv2\",\n", + " project=PROJECT_NAME,\n", " kind=\"job\",\n", " image=\"mlrun/ml-models\"\n", ")\n", @@ -206,15 +212,267 @@ "# Run the function to log the model:\n", "get_model_run = get_model_function.run(\n", " handler=\"get_model\",\n", - " artifact_path=ARTIFACT_PATH,\n", + " output_path=ARTIFACT_PATH,\n", " params={\n", " \"model_name\": MODEL_NAME\n", " },\n", " local=True\n", ")" ], - "outputs": [], - "execution_count": null + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> 2026-02-10 16:14:24,932 [info] Created and saved project: {\"context\":\"./\",\"from_template\":null,\"name\":\"onnx-utils\",\"overwrite\":false,\"save\":true}\n", + "> 2026-02-10 16:14:24,933 [info] Project created successfully: {\"project_name\":\"onnx-utils\",\"stored_in_db\":true}\n", + "> 2026-02-10 16:14:31,659 [info] Storing function: {\"db\":null,\"name\":\"get-mobilenetv2-get-model\",\"uid\":\"7b9d1b54375b44e191d73685a382c910\"}\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
projectuiditerstartendstatekindnamelabelsinputsparametersresultsartifact_uris
onnx-utils0Feb 10 14:14:32NaTcompletedrunget-mobilenetv2-get-model
v3io_user=omerm
kind=local
owner=omerm
host=M-KCX16N69X3
model_name=mobilenetv2
mobilenetv2_modules_map.json=store://artifacts/onnx-utils/#0@7b9d1b54375b44e191d73685a382c910
model=store://models/onnx-utils/mobilenetv2#0@7b9d1b54375b44e191d73685a382c910^e0393bc5b070fd55cc57cecb94160ce412498e0f
\n", + "
\n", + "
\n", + "
\n", + " Title\n", + " ×\n", + "
\n", + " \n", + "
\n", + "
\n" + ] + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ], + "text/html": [ + " > to track results use the .show() or .logs() methods or click here to open in UI" + ] + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> 2026-02-10 16:14:34,427 [info] Run execution finished: {\"name\":\"get-mobilenetv2-get-model\",\"status\":\"completed\"}\n" + ] + } + ], + "execution_count": 3 }, { "cell_type": "markdown", @@ -228,33 +486,271 @@ "metadata": { "pycharm": { "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2026-02-10T14:14:53.863947Z", + "start_time": "2026-02-10T14:14:48.088349Z" } }, - "source": [ - "# Import the ONNX function from the marketplace:\n", - "onnx_utils_function = mlrun.import_function(\"hub://onnx_utils\")\n", - "\n", - "# Run the function to convert our model to ONNX:\n", - "to_onnx_run = onnx_utils_function.run(\n", - " handler=\"to_onnx\",\n", - " artifact_path=ARTIFACT_PATH,\n", - " params={\n", - " \"model_name\": MODEL_NAME,\n", - " \"model_path\": get_model_run.outputs[MODEL_NAME], # <- Take the logged model from the previous function.\n", - " \"onnx_model_name\": ONNX_MODEL_NAME,\n", - " \"optimize_model\": False # <- For optimizing it later in the demo, we mark the flag as False\n", - " },\n", - " local=True\n", - ")" + "source": "# Import the ONNX function from the marketplace:\nonnx_utils_function = mlrun.import_function(\"hub://onnx_utils\", project=PROJECT_NAME)\n\n# Construct the model path from the run directory structure:\nmodel_path = os.path.join(ARTIFACT_PATH, \"get-mobilenetv2-get-model\", \"0\", \"model\")\nmodules_map_path = os.path.join(ARTIFACT_PATH, \"get-mobilenetv2-get-model\", \"0\", \"mobilenetv2_modules_map.json.json\")\n\n# Run the function to convert our model to ONNX:\nto_onnx_run = onnx_utils_function.run(\n handler=\"to_onnx\",\n output_path=ARTIFACT_PATH,\n params={\n \"model_name\": MODEL_NAME,\n \"model_path\": model_path,\n \"load_model_kwargs\": {\n \"model_name\": MODEL_NAME,\n \"model_class\": \"mobilenet_v2\",\n \"modules_map\": modules_map_path,\n },\n \"onnx_model_name\": ONNX_MODEL_NAME,\n \"optimize_model\": False, # <- For optimizing it later in the demo, we mark the flag as False\n \"framework_kwargs\": {\"input_signature\": [((32, 3, 224, 224), \"float32\")]},\n },\n local=True\n)", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> 2026-02-10 16:14:48,519 [info] Storing function: {\"db\":null,\"name\":\"onnx-utils-to-onnx\",\"uid\":\"95deb2c7dbf0460291efb25c48eeebd7\"}\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
projectuiditerstartendstatekindnamelabelsinputsparametersresultsartifact_uris
onnx-utils0Feb 10 14:14:49NaTcompletedrunonnx-utils-to-onnx
v3io_user=omerm
kind=local
owner=omerm
host=M-KCX16N69X3
model_name=mobilenetv2
model_path=/var/folders/rn/q8gs952n26982d36y50w_2rw0000gp/T/tmpvs5qvbxr/get-mobilenetv2-get-model/0/model
load_model_kwargs={'model_name': 'mobilenetv2', 'model_class': 'mobilenet_v2', 'modules_map': '/var/folders/rn/q8gs952n26982d36y50w_2rw0000gp/T/tmpvs5qvbxr/get-mobilenetv2-get-model/0/mobilenetv2_modules_map.json.json'}
onnx_model_name=onnx_mobilenetv2
optimize_model=False
framework_kwargs={'input_signature': [((32, 3, 224, 224), 'float32')]}
model=store://models/onnx-utils/onnx_mobilenetv2#0@95deb2c7dbf0460291efb25c48eeebd7^03e4286da44d015cf5465d43e809a504d15f7f63
\n", + "
\n", + "
\n", + "
\n", + " Title\n", + " ×\n", + "
\n", + " \n", + "
\n", + "
\n" + ] + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ], + "text/html": [ + " > to track results use the .show() or .logs() methods or click here to open in UI" + ] + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> 2026-02-10 16:14:53,862 [info] Run execution finished: {\"name\":\"onnx-utils-to-onnx\",\"status\":\"completed\"}\n" + ] + } ], - "outputs": [], - "execution_count": null + "execution_count": 4 }, { "cell_type": "markdown", "metadata": {}, "source": [ - "1.2.5. Now, listing the artifact directory we will see both our `tf.keras` model and the `onnx` model:" + "1.2.5. Now we verify the ONNX model was created:" ] }, { @@ -262,16 +758,29 @@ "metadata": { "pycharm": { "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2026-02-10T14:14:56.820411Z", + "start_time": "2026-02-10T14:14:56.817892Z" } }, "source": [ "import os\n", "\n", - "\n", - "print(os.listdir(ARTIFACT_PATH))" + "onnx_model_file = os.path.join(ARTIFACT_PATH, \"onnx-utils-to-onnx\", \"0\", \"model\", \"onnx_mobilenetv2.onnx\")\n", + "assert os.path.isfile(onnx_model_file), f\"ONNX model not found at {onnx_model_file}\"\n", + "print(f\"ONNX model created at: {onnx_model_file}\")" ], - "outputs": [], - "execution_count": null + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ONNX model created at: /var/folders/rn/q8gs952n26982d36y50w_2rw0000gp/T/tmpvs5qvbxr/onnx-utils-to-onnx/0/model/onnx_mobilenetv2.onnx\n" + ] + } + ], + "execution_count": 5 }, { "cell_type": "markdown", @@ -308,28 +817,281 @@ "metadata": { "pycharm": { "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2026-02-10T14:15:03.415997Z", + "start_time": "2026-02-10T14:15:00.637332Z" } }, - "source": [ - "onnx_utils_function.run(\n", - " handler=\"optimize\",\n", - " artifact_path=ARTIFACT_PATH,\n", - " params={\n", - " \"model_name\": ONNX_MODEL_NAME,\n", - " \"model_path\": to_onnx_run.output(ONNX_MODEL_NAME), # <- Take the logged model from the previous function.\n", - " \"optimized_model_name\": OPTIMIZED_ONNX_MODEL_NAME,\n", - " },\n", - " local=True\n", - ")" + "source": "# Construct the ONNX model path from the run directory structure:\nonnx_model_path = os.path.join(ARTIFACT_PATH, \"onnx-utils-to-onnx\", \"0\", \"model\")\n\nonnx_utils_function.run(\n handler=\"optimize\",\n output_path=ARTIFACT_PATH,\n params={\n \"model_path\": onnx_model_path,\n \"handler_init_kwargs\": {\"model_name\": ONNX_MODEL_NAME},\n \"optimized_model_name\": OPTIMIZED_ONNX_MODEL_NAME,\n },\n local=True\n)", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> 2026-02-10 16:15:00,639 [info] Storing function: {\"db\":null,\"name\":\"onnx-utils-optimize\",\"uid\":\"0c30d7af94814dcabde8152a1951fb5d\"}\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
projectuiditerstartendstatekindnamelabelsinputsparametersresultsartifact_uris
onnx-utils0Feb 10 14:15:01NaTcompletedrunonnx-utils-optimize
v3io_user=omerm
kind=local
owner=omerm
host=M-KCX16N69X3
model_path=/var/folders/rn/q8gs952n26982d36y50w_2rw0000gp/T/tmpvs5qvbxr/onnx-utils-to-onnx/0/model
handler_init_kwargs={'model_name': 'onnx_mobilenetv2'}
optimized_model_name=optimized_onnx_mobilenetv2
model=store://models/onnx-utils/optimized_onnx_mobilenetv2#0@0c30d7af94814dcabde8152a1951fb5d^599547984e83a664dc1c2708607d06731edb5ac2
\n", + "
\n", + "
\n", + "
\n", + " Title\n", + " ×\n", + "
\n", + " \n", + "
\n", + "
\n" + ] + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ], + "text/html": [ + " > to track results use the .show() or .logs() methods or click here to open in UI" + ] + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> 2026-02-10 16:15:03,414 [info] Run execution finished: {\"name\":\"onnx-utils-optimize\",\"status\":\"completed\"}\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } ], - "outputs": [], - "execution_count": null + "execution_count": 6 }, { "cell_type": "markdown", "metadata": {}, "source": [ - "2.2.2. And now our model was optimized and can be seen under the `ARTIFACT_PATH`:" + "2.2.2. And now our model was optimized. Let us verify:" ] }, { @@ -337,13 +1099,27 @@ "metadata": { "pycharm": { "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2026-02-10T14:15:05.748413Z", + "start_time": "2026-02-10T14:15:05.745309Z" } }, "source": [ - "print(os.listdir(ARTIFACT_PATH))" + "optimized_model_file = os.path.join(ARTIFACT_PATH, \"onnx-utils-optimize\", \"0\", \"model\", \"optimized_onnx_mobilenetv2.onnx\")\n", + "assert os.path.isfile(optimized_model_file), f\"Optimized ONNX model not found at {optimized_model_file}\"\n", + "print(f\"Optimized ONNX model created at: {optimized_model_file}\")" ], - "outputs": [], - "execution_count": null + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Optimized ONNX model created at: /var/folders/rn/q8gs952n26982d36y50w_2rw0000gp/T/tmpvs5qvbxr/onnx-utils-optimize/0/model/optimized_onnx_mobilenetv2.onnx\n" + ] + } + ], + "execution_count": 7 }, { "cell_type": "markdown", @@ -353,7 +1129,7 @@ } }, "source": [ - "Lastly, run this code to clean up the models:" + "Lastly, run this code to clean up all generated files and directories:" ] }, { @@ -361,23 +1137,22 @@ "metadata": { "pycharm": { "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2026-02-10T14:00:28.409998Z", + "start_time": "2026-02-10T13:57:21.679146Z" } }, - "source": [ - "import shutil\n", - "\n", - "\n", - "shutil.rmtree(ARTIFACT_PATH)" - ], + "source": "import shutil\n\n# Clean up the temporary artifact directory:\nif os.path.exists(ARTIFACT_PATH):\n shutil.rmtree(ARTIFACT_PATH)", "outputs": [], "execution_count": null } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "mlrun_functions", "language": "python", - "name": "python3" + "name": "mlrun_functions" }, "language_info": { "codemirror_mode": { diff --git a/functions/src/onnx_utils/requirements.txt b/functions/src/onnx_utils/requirements.txt index d3d7dfd68..912b3d7e5 100644 --- a/functions/src/onnx_utils/requirements.txt +++ b/functions/src/onnx_utils/requirements.txt @@ -1,11 +1,10 @@ tqdm~=4.67.1 tensorflow~=2.19.0 tf_keras~=2.19.0 -torch~=2.6.0 -torchvision~=0.21.0 +torch~=2.8 +torchvision~=0.23.0 onnx~=1.17.0 onnxruntime~=1.19.2 onnxoptimizer~=0.3.13 onnxmltools~=1.13.0 -tf2onnx~=1.16.1 -plotly~=5.23 +plotly~=5.23 \ No newline at end of file diff --git a/functions/src/onnx_utils/test_onnx_utils.py b/functions/src/onnx_utils/test_onnx_utils.py index 2e01782f5..59c6c2b38 100644 --- a/functions/src/onnx_utils/test_onnx_utils.py +++ 
b/functions/src/onnx_utils/test_onnx_utils.py @@ -17,6 +17,9 @@ import tempfile import mlrun +import pytest + +PROJECT_NAME = "onnx-utils" # Choose our model's name: MODEL_NAME = "model" @@ -27,41 +30,67 @@ # Choose our optimized ONNX version model's name: OPTIMIZED_ONNX_MODEL_NAME = f"optimized_{ONNX_MODEL_NAME}" +REQUIRED_ENV_VARS = [ + "MLRUN_DBPATH", + "MLRUN_ARTIFACT_PATH", + "V3IO_USERNAME", + "V3IO_ACCESS_KEY", +] -def _setup_environment() -> str: - """ - Setup the test environment, creating the artifacts path of the test. - :returns: The temporary directory created for the test artifacts path. +def _validate_environment_variables() -> bool: """ - artifact_path = tempfile.TemporaryDirectory().name - os.makedirs(artifact_path) - return artifact_path + Checks that all required Environment variables are set. + """ + environment_keys = os.environ.keys() + return all(key in environment_keys for key in REQUIRED_ENV_VARS) -def _cleanup_environment(artifact_path: str): +def _is_tf2onnx_available() -> bool: """ - Cleanup the test environment, deleting files and artifacts created during the test. - - :param artifact_path: The artifact path to delete. + Check if tf2onnx is installed (required for TensorFlow/Keras ONNX conversion). """ - # Clean the local directory: + try: + import tf2onnx + return True + except ImportError: + return False + + +@pytest.fixture(scope="session") +def onnx_project(): + """Create/get the MLRun project once per test session.""" + return mlrun.get_or_create_project(PROJECT_NAME, context="./") + + +@pytest.fixture(autouse=True) +def test_environment(onnx_project): + """Setup and cleanup test artifacts for each test.""" + artifact_path = tempfile.mkdtemp() + yield artifact_path + # Cleanup - only remove files/dirs from the directory containing this test file, + # never from an arbitrary CWD (which could be the project root). + test_dir = os.path.dirname(os.path.abspath(__file__)) for test_output in [ - *os.listdir(artifact_path), "schedules", "runs", "artifacts", "functions", + "model.pt", + "model.zip", + "model_modules_map.json", + "model_modules_map.json.json", + "onnx_model.onnx", + "optimized_onnx_model.onnx", ]: - test_output_path = os.path.abspath(f"./{test_output}") + test_output_path = os.path.join(test_dir, test_output) if os.path.exists(test_output_path): if os.path.isdir(test_output_path): shutil.rmtree(test_output_path) else: os.remove(test_output_path) - - # Clean the artifacts directory: - shutil.rmtree(artifact_path) + if os.path.exists(artifact_path): + shutil.rmtree(artifact_path) def _log_tf_keras_model(context: mlrun.MLClientCtx, model_name: str): @@ -114,42 +143,55 @@ def _log_pytorch_model(context: mlrun.MLClientCtx, model_name: str): model_handler.log() -def test_to_onnx_help(): +@pytest.mark.skipif( + condition=not _validate_environment_variables(), + reason="Project's environment variables are not set", +) +def test_to_onnx_help(test_environment): """ Test the 'to_onnx' handler, passing "help" in the 'framework_kwargs'. 
""" - # Setup the tests environment: - artifact_path = _setup_environment() + artifact_path = test_environment # Create the function: log_model_function = mlrun.code_to_function( filename="test_onnx_utils.py", name="log_model", + project=PROJECT_NAME, kind="job", image="mlrun/ml-models", ) # Run the function to log the model: - log_model_run = log_model_function.run( - handler="_log_tf_keras_model", - artifact_path=artifact_path, + log_model_function.run( + handler="_log_pytorch_model", + output_path=artifact_path, params={"model_name": MODEL_NAME}, local=True, ) + # Get artifact paths - construct from artifact_path and run structure + run_artifact_dir = os.path.join(artifact_path, "log-model--log-pytorch-model", "0") + model_path = os.path.join(run_artifact_dir, "model") + modules_map_path = os.path.join(run_artifact_dir, "model_modules_map.json.json") + # Import the ONNX Utils function: - onnx_function = mlrun.import_function("function.yaml") + onnx_function = mlrun.import_function("function.yaml", project=PROJECT_NAME) # Run the function, passing "help" in 'framework_kwargs' and see that no exception was raised: is_test_passed = True try: onnx_function.run( handler="to_onnx", - artifact_path=artifact_path, + output_path=artifact_path, params={ # Take the logged model from the previous function. - "model_path": log_model_run.status.artifacts[0]["spec"]["target_path"], - "load_model_kwargs": {"model_name": MODEL_NAME}, + "model_path": model_path, + "load_model_kwargs": { + "model_name": MODEL_NAME, + "model_class": "mobilenet_v2", + "modules_map": modules_map_path, + }, "framework_kwargs": "help", }, local=True, @@ -160,23 +202,28 @@ def test_to_onnx_help(): ) is_test_passed = False - # Cleanup the tests environment: - _cleanup_environment(artifact_path=artifact_path) - assert is_test_passed -def test_tf_keras_to_onnx(): +@pytest.mark.skipif( + condition=not _validate_environment_variables(), + reason="Project's environment variables are not set", +) +@pytest.mark.skipif( + condition=not _is_tf2onnx_available(), + reason="tf2onnx is not installed", +) +def test_tf_keras_to_onnx(test_environment): """ Test the 'to_onnx' handler, giving it a tf.keras model. """ - # Setup the tests environment: - artifact_path = _setup_environment() + artifact_path = test_environment # Create the function: log_model_function = mlrun.code_to_function( filename="test_onnx_utils.py", name="log_model", + project=PROJECT_NAME, kind="job", image="mlrun/ml-models", ) @@ -184,18 +231,18 @@ def test_tf_keras_to_onnx(): # Run the function to log the model: log_model_run = log_model_function.run( handler="_log_tf_keras_model", - artifact_path=artifact_path, + output_path=artifact_path, params={"model_name": MODEL_NAME}, local=True, ) # Import the ONNX Utils function: - onnx_function = mlrun.import_function("function.yaml") + onnx_function = mlrun.import_function("function.yaml", project=PROJECT_NAME) # Run the function to convert our model to ONNX: onnx_function_run = onnx_function.run( handler="to_onnx", - artifact_path=artifact_path, + output_path=artifact_path, params={ # Take the logged model from the previous function. 
"model_path": log_model_run.status.artifacts[0]["spec"]["target_path"], @@ -205,9 +252,6 @@ def test_tf_keras_to_onnx(): local=True, ) - # Cleanup the tests environment: - _cleanup_environment(artifact_path=artifact_path) - # Print the outputs list: print(f"Produced outputs: {onnx_function_run.outputs}") @@ -215,17 +259,21 @@ def test_tf_keras_to_onnx(): assert "model" in onnx_function_run.outputs -def test_pytorch_to_onnx(): +@pytest.mark.skipif( + condition=not _validate_environment_variables(), + reason="Project's environment variables are not set", +) +def test_pytorch_to_onnx(test_environment): """ Test the 'to_onnx' handler, giving it a pytorch model. """ - # Setup the tests environment: - artifact_path = _setup_environment() + artifact_path = test_environment # Create the function: log_model_function = mlrun.code_to_function( filename="test_onnx_utils.py", name="log_model", + project=PROJECT_NAME, kind="job", image="mlrun/ml-models", ) @@ -233,25 +281,30 @@ def test_pytorch_to_onnx(): # Run the function to log the model: log_model_run = log_model_function.run( handler="_log_pytorch_model", - artifact_path=artifact_path, + output_path=artifact_path, params={"model_name": MODEL_NAME}, local=True, ) # Import the ONNX Utils function: - onnx_function = mlrun.import_function("function.yaml") + onnx_function = mlrun.import_function("function.yaml", project=PROJECT_NAME) + + # Get artifact paths - construct from artifact_path and run structure + run_artifact_dir = os.path.join(artifact_path, "log-model--log-pytorch-model", "0") + model_path = os.path.join(run_artifact_dir, "model") + modules_map_path = os.path.join(run_artifact_dir, "model_modules_map.json.json") # Run the function to convert our model to ONNX: onnx_function_run = onnx_function.run( handler="to_onnx", - artifact_path=artifact_path, + output_path=artifact_path, params={ # Take the logged model from the previous function. - "model_path": log_model_run.status.artifacts[1]["spec"]["target_path"], + "model_path": model_path, "load_model_kwargs": { "model_name": MODEL_NAME, "model_class": "mobilenet_v2", - "modules_map": log_model_run.status.artifacts[0]["spec"]["target_path"], + "modules_map": modules_map_path, }, "onnx_model_name": ONNX_MODEL_NAME, "framework_kwargs": {"input_signature": [((32, 3, 224, 224), "float32")]}, @@ -259,9 +312,6 @@ def test_pytorch_to_onnx(): local=True, ) - # Cleanup the tests environment: - _cleanup_environment(artifact_path=artifact_path) - # Print the outputs list: print(f"Produced outputs: {onnx_function_run.outputs}") @@ -269,22 +319,25 @@ def test_pytorch_to_onnx(): assert "model" in onnx_function_run.outputs -def test_optimize_help(): +@pytest.mark.skipif( + condition=not _validate_environment_variables(), + reason="Project's environment variables are not set", +) +def test_optimize_help(test_environment): """ Test the 'optimize' handler, passing "help" in the 'optimizations'. 
""" - # Setup the tests environment: - artifact_path = _setup_environment() + artifact_path = test_environment # Import the ONNX Utils function: - onnx_function = mlrun.import_function("function.yaml") + onnx_function = mlrun.import_function("function.yaml", project=PROJECT_NAME) # Run the function, passing "help" in 'optimizations' and see that no exception was raised: is_test_passed = True try: onnx_function.run( handler="optimize", - artifact_path=artifact_path, + output_path=artifact_path, params={ "model_path": "", "optimizations": "help", @@ -297,69 +350,81 @@ def test_optimize_help(): ) is_test_passed = False - # Cleanup the tests environment: - _cleanup_environment(artifact_path=artifact_path) - assert is_test_passed -def test_optimize(): +@pytest.mark.skipif( + condition=not _validate_environment_variables(), + reason="Project's environment variables are not set", +) +def test_optimize(test_environment): """ - Test the 'optimize' handler, giving it a model from the ONNX zoo git repository. + Test the 'optimize' handler, giving it a pytorch model converted to ONNX. """ - # Setup the tests environment: - artifact_path = _setup_environment() + artifact_path = test_environment # Create the function: log_model_function = mlrun.code_to_function( filename="test_onnx_utils.py", name="log_model", + project=PROJECT_NAME, kind="job", image="mlrun/ml-models", ) # Run the function to log the model: - log_model_run = log_model_function.run( - handler="_log_tf_keras_model", - artifact_path=artifact_path, + log_model_function.run( + handler="_log_pytorch_model", + output_path=artifact_path, params={"model_name": MODEL_NAME}, local=True, ) + # Get artifact paths - construct from artifact_path and run structure + run_artifact_dir = os.path.join(artifact_path, "log-model--log-pytorch-model", "0") + model_path = os.path.join(run_artifact_dir, "model") + modules_map_path = os.path.join(run_artifact_dir, "model_modules_map.json.json") + # Import the ONNX Utils function: - onnx_function = mlrun.import_function("function.yaml") + onnx_function = mlrun.import_function("function.yaml", project=PROJECT_NAME) # Run the function to convert our model to ONNX: - to_onnx_function_run = onnx_function.run( + onnx_function.run( handler="to_onnx", - artifact_path=artifact_path, + output_path=artifact_path, params={ # Take the logged model from the previous function. - "model_path": log_model_run.status.artifacts[0]["spec"]["target_path"], - "load_model_kwargs": {"model_name": MODEL_NAME}, + "model_path": model_path, + "load_model_kwargs": { + "model_name": MODEL_NAME, + "model_class": "mobilenet_v2", + "modules_map": modules_map_path, + }, "onnx_model_name": ONNX_MODEL_NAME, + "framework_kwargs": {"input_signature": [((32, 3, 224, 224), "float32")]}, }, local=True, ) + # Get the ONNX model path from the to_onnx run output + onnx_run_artifact_dir = os.path.join( + artifact_path, "onnx-utils-to-onnx", "0" + ) + onnx_model_path = os.path.join(onnx_run_artifact_dir, "model") + # Run the function to optimize our model: optimize_function_run = onnx_function.run( handler="optimize", - artifact_path=artifact_path, + output_path=artifact_path, params={ # Take the logged model from the previous function. 
- "model_path": to_onnx_function_run.status.artifacts[0]["spec"][ - "target_path" - ], + "model_path": onnx_model_path, "handler_init_kwargs": {"model_name": ONNX_MODEL_NAME}, "optimized_model_name": OPTIMIZED_ONNX_MODEL_NAME, }, local=True, ) - # Cleanup the tests environment: - _cleanup_environment(artifact_path=artifact_path) - # Print the outputs list: print(f"Produced outputs: {optimize_function_run.outputs}") diff --git a/modules/src/langchain_mlrun/item.yaml b/modules/src/langchain_mlrun/item.yaml new file mode 100644 index 000000000..532cb4bd3 --- /dev/null +++ b/modules/src/langchain_mlrun/item.yaml @@ -0,0 +1,24 @@ +apiVersion: v1 +categories: +- langchain +- langgraph +- tracing +- monitoring +- llm +description: LangChain x MLRun integration - Orchestrate your LangChain code with MLRun. +example: langchain_mlrun.ipynb +generationDate: 2026-01-08:12-25 +hidden: false +labels: + author: Iguazio +mlrunVersion: 1.10.0 +name: langchain_mlrun +spec: + filename: langchain_mlrun.py + image: mlrun/mlrun + kind: generic + requirements: + - langchain~=1.2 + - pydantic-settings~=2.12 + - kafka-python~=2.3 +version: 0.0.1 \ No newline at end of file diff --git a/modules/src/langchain_mlrun/langchain_mlrun.ipynb b/modules/src/langchain_mlrun/langchain_mlrun.ipynb new file mode 100644 index 000000000..0e5a341e7 --- /dev/null +++ b/modules/src/langchain_mlrun/langchain_mlrun.ipynb @@ -0,0 +1,1046 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "7955da79-02cc-42fe-aee0-5456d3e386fd", + "metadata": {}, + "source": [ + "# LangChain ✕ MLRun Integration\n", + "\n", + "`langchain_mlrun` is a hub module that implements LangChain integration with MLRun. Using the module allows MLRun to orchestrate LangChain and LangGraph code, enabling tracing and monitoring batch workflows and realtime deployments.\n", + "___" + ] + }, + { + "cell_type": "markdown", + "id": "8392a3e1-d0a1-409a-ae68-fcc36858d30a", + "metadata": {}, + "source": [ + "## Main Components\n", + "\n", + "This is a short brief of the components available to import from the `langchain_mlrun` module. For full docs, see the documentation page.\n", + "\n", + "### Settings\n", + "\n", + "The module uses Pydantic settings classes that can be configured programmatically or via environment variables. The main class is `MLRunTracerSettings`. It contains two sub-settings:\n", + "* `MLRunTracerClientSettings` - Connection settings (stream path, container, endpoint info). Env prefix: `\"LC_MLRUN_TRACER_CLIENT_\"`\n", + "* `MLRunTracerMonitorSettings` - Controls what/how runs are captured (filters, labels, debug mode). 
Env prefix: `\"LC_MLRUN_TRACER_MONITOR_\"`\n", + "\n", + "For more information about each setting, see the class docstrings.\n", + "\n", + "#### Example - via code configuration\n", + "\n", + "```python\n", + "from langchain_mlrun import MLRunTracerSettings, MLRunTracerClientSettings, MLRunTracerMonitorSettings\n", + "\n", + "settings = MLRunTracerSettings(\n", + " client=MLRunTracerClientSettings(\n", + " stream_path=\"my-project/model-endpoints/stream-v1\",\n", + " container=\"projects\",\n", + " model_endpoint_name=\"my_endpoint\",\n", + " model_endpoint_uid=\"abc123\",\n", + " serving_function=\"my_function\",\n", + " ),\n", + " monitor=MLRunTracerMonitorSettings(\n", + " label=\"production\",\n", + " root_run_only=True, # Only monitor root runs, not child runs\n", + " tags_filter=[\"important\"], # Only monitor runs with this tag\n", + " ),\n", + ")\n", + "```\n", + "\n", + "#### Example - environment variable configuration\n", + "\n", + "```bash\n", + "export LC_MLRUN_TRACER_CLIENT_STREAM_PATH=\"my-project/model-endpoints/stream-v1\"\n", + "export LC_MLRUN_TRACER_CLIENT_CONTAINER=\"projects\"\n", + "export LC_MLRUN_TRACER_MONITOR_LABEL=\"production\"\n", + "export LC_MLRUN_TRACER_MONITOR_ROOT_RUN_ONLY=\"true\"\n", + "```\n", + "\n", + "### MLRun Tracer\n", + "\n", + "`MLRunTracer` is a LangChain-compatible tracer that converts LangChain `Run` objects into MLRun monitoring events and publishes them to a V3IO stream. \n", + "\n", + "Key points:\n", + "* **No inheritance required** - use it directly without subclassing.\n", + "* **Fully customizable via settings** - control filtering, summarization, and output format.\n", + "* **Custom summarizer support** - pass your own `run_summarizer_function` via settings to customize how runs are converted to events.\n", + "\n", + "### Monitoring Setup Utility Function\n", + "\n", + "`setup_langchain_monitoring()` is a utility function that creates the necessary MLRun infrastructure for LangChain monitoring. This is a **temporary workaround** until custom endpoint creation support is added to MLRun.\n", + "\n", + "The function returns a dictionary of environment variables to configure auto-tracing. 
See how to use it in the tutorial section below.\n", + "\n", + "### LangChain Monitoring Application\n", + "\n", + "`LangChainMonitoringApp` is a base class (inheriting from MLRun's `ModelMonitoringApplicationBase`) for building monitoring applications that process events from the MLRun Tracer.\n", + "\n", + "It offers several built-in helper methods and metrics for analyzing LangChain runs:\n", + "\n", + "* Helper methods:\n", + " * `get_structured_runs()` - Parse raw monitoring samples into structured run dictionaries with filtering options\n", + " * `iterate_structured_runs()` - Iterate over all runs including nested child runs\n", + "* Metric methods:\n", + " * `calculate_average_latency()` - Average latency across root runs\n", + " * `calculate_success_rate()` - Percentage of runs without errors\n", + " * `count_token_usage()` - Total input/output tokens from LLM runs\n", + " * `count_run_names()` - Count occurrences of each run name\n", + "\n", + "The base app can be used as-is, but it is recommended to extend it with your own custom monitoring logic.\n", + "___" + ] + }, + { + "cell_type": "markdown", + "id": "7e24e1a5-d80a-4b7e-9b94-57b24e8b39d7", + "metadata": {}, + "source": [ + "## How to Apply MLRun?\n", + "\n", + "### Auto Tracing\n", + "\n", + "Auto tracing automatically instruments all LangChain code by setting the `LC_MLRUN_MONITORING_ENABLED` environment variable and importing the module:\n", + "\n", + "```python\n", + "import os\n", + "os.environ[\"LC_MLRUN_MONITORING_ENABLED\"] = \"1\"\n", + "# Set other LC_MLRUN_TRACER_* environment variables as needed...\n", + "\n", + "# Import the module BEFORE any LangChain code\n", + "langchain_mlrun = mlrun.import_module(\"hub://langchain_mlrun\")\n", + "\n", + "# All LangChain/LangGraph code below will be automatically traced\n", + "chain.invoke(...)\n", + "```\n", + "\n", + "### Manual Tracing\n", + "\n", + "For more control, use the `mlrun_monitoring()` context manager to trace specific code blocks:\n", + "\n", + "```python\n", + "langchain_mlrun = mlrun.import_module(\"hub://langchain_mlrun\")\n", + "mlrun_monitoring = langchain_mlrun.mlrun_monitoring\n", + "MLRunTracerSettings = langchain_mlrun.MLRunTracerSettings\n", + "\n", + "# Optional: customize settings\n", + "settings = MLRunTracerSettings(...)\n", + "\n", + "with mlrun_monitoring(settings=settings) as tracer:\n", + " # Only LangChain code within this block will be traced\n", + " result = chain.invoke({\"topic\": \"MLRun\"})\n", + "```\n", + "___" + ] + }, + { + "cell_type": "markdown", + "id": "68b52d3d-a431-44fb-acd6-ea33fec37a49", + "metadata": {}, + "source": [ + "## Tutorial\n", + "\n", + "In this tutorial we'll show how to orchestrate LangChain based code with MLRun using the `langchain_mlrun` hub module.\n", + "\n", + "### Prerequisites\n", + "\n", + "Install MLRun and the `langchain_mlrun` requirements." 
+ ] + }, + { + "cell_type": "code", + "id": "caf72aa6-06e8-4a04-bfc4-409b39d255fe", + "metadata": {}, + "source": "!pip install mlrun langchain~=1.2 pydantic-settings~=2.12 kafka-python~=2.3", + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "vprq8wj4iqh", + "source": [ + "### Local Development Setup (Optional)\n", + "\n", + "> Skip this section if you're running inside a Jupyter instance deployed in the MLRun cluster.\n", + "\n", + "If you're running this notebook from your local machine, follow these steps:\n", + "\n", + "#### Step 1: Set Environment Variables\n", + "\n", + "Run the cell below to set up all required environment variables for local development." + ], + "metadata": {} + }, + { + "cell_type": "code", + "id": "9lc788zu3zi", + "source": [ + "import os\n", + "\n", + "# MLRun API endpoint:\n", + "# os.environ[\"MLRUN_DBPATH\"] = \"http://localhost:30070\"\n", + "\n", + "# Kafka Configuration:\n", + "# os.environ[\"KAFKA_BROKER\"] = \"\"\n", + "\n", + "# TDEngine Configuration:\n", + "# os.environ[\"TDENGINE_HOST\"] = \"\"\n", + "# os.environ[\"TDENGINE_PORT\"] = \"\"\n", + "# os.environ[\"TDENGINE_USER\"] = \"\"\n", + "# os.environ[\"TDENGINE_PASSWORD\"] = \"\"\n", + "\n", + "# MinIO/S3 Configuration:\n", + "# os.environ[\"AWS_ACCESS_KEY_ID\"] = \"\"\n", + "# os.environ[\"AWS_SECRET_ACCESS_KEY\"] = \"\"\n", + "# os.environ[\"AWS_ENDPOINT_URL_S3\"] = \"\"" + ], + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "#### Step 2: Set Up Port Forwarding\n", + "\n", + "Set up port-forwarding to access cluster services. Run these commands in separate terminal windows:\n", + "\n", + "```bash\n", + "# MLRun API\n", + "kubectl port-forward -n mlrun svc/mlrun-api 30070:8080\n", + "```\n", + "\n", + "```bash\n", + "# MinIO (S3-compatible storage)\n", + "kubectl port-forward -n mlrun svc/minio 9000:9000\n", + "```\n", + "\n", + "```bash\n", + "# Kafka (for CE mode) - requires /etc/hosts entry: 127.0.0.1 kafka-stream\n", + "kubectl port-forward -n mlrun svc/kafka-stream 9092:9092\n", + "```\n", + "\n", + "```bash\n", + "# TDEngine (for CE mode) - requires /etc/hosts entry: 127.0.0.1 tdengine-tsdb\n", + "kubectl port-forward -n mlrun svc/tdengine-tsdb 6041:6041\n", + "```" + ], + "id": "6d1d2d3c016ec62c" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "### Create Project\n", + "\n", + "We'll first create an MLRun project" + ], + "id": "4442f7ad1b0a8ee" + }, + { + "cell_type": "code", + "id": "2664df3e-d9c6-40dd-a215-29d60e4b4208", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-03T19:43:18.142870Z", + "start_time": "2026-02-03T19:43:10.068758Z" + } + }, + "source": [ + "import time\n", + "import datetime\n", + "import mlrun\n", + "\n", + "print(f\"MLRun version: {mlrun.__version__}\")\n", + "print(f\"CE Mode: {mlrun.mlconf.is_ce_mode()}\")\n", + "\n", + "project = mlrun.get_or_create_project(\"langchain-mlrun-tutorial\")\n", + "print(f\"Project: {project.name}\")" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MLRun version: 1.10.0\n", + "CE Mode: True\n", + "> 2026-02-03 21:43:18,053 [info] Loading project from path: {\"path\":\"./\",\"project_name\":\"langchain-mlrun-tutorial\",\"user_project\":false}\n", + "> 2026-02-03 21:43:18,141 [info] Project loaded successfully: {\"path\":\"./\",\"project_name\":\"langchain-mlrun-tutorial\",\"stored_in_db\":true}\n", + "Project: langchain-mlrun-tutorial\n" + ] + } + 
], + "execution_count": 3 + }, + { + "cell_type": "markdown", + "id": "33f28986-c158-47fd-97a6-74f69892b4eb", + "metadata": {}, + "source": "### Enable Monitoring\n\nTo use MLRun's monitoring feature in our project we first need to set up the monitoring infrastructure.\n\n- **MLRun CE**: Uses Kafka for streaming (automatically detected)\n- **MLRun Enterprise**: Uses V3IO for streaming (automatically detected)\n\nThe cell below automatically detects your MLRun mode and sets up the appropriate streaming infrastructure." + }, + { + "cell_type": "code", + "id": "d9d2fa66-0498-445d-ab4a-8370f46aec1e", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-03T19:49:22.700332Z", + "start_time": "2026-02-03T19:43:18.148037Z" + } + }, + "source": [ + "# Create datastore profiles (based on CE or Enterprise):\n", + "if mlrun.mlconf.is_ce_mode():\n", + " print(\"Setting up Kafka streaming for MLRun CE...\")\n", + " from mlrun.datastore.datastore_profile import DatastoreProfileKafkaStream, DatastoreProfileTDEngine\n", + " \n", + " stream_profile = DatastoreProfileKafkaStream(\n", + " name=\"kafka-stream-profile\",\n", + " brokers=os.environ[\"KAFKA_BROKER\"],\n", + " topics=[],\n", + " )\n", + " tsdb_profile = DatastoreProfileTDEngine(\n", + " name=\"tsdb-profile\",\n", + " user=os.environ[\"TDENGINE_USER\"],\n", + " password=os.environ[\"TDENGINE_PASSWORD\"],\n", + " host=os.environ[\"TDENGINE_HOST\"],\n", + " port=int(os.environ[\"TDENGINE_PORT\"]),\n", + " )\n", + " project.register_datastore_profile(stream_profile)\n", + " project.register_datastore_profile(tsdb_profile)\n", + "else: # Enterprise\n", + " print(\"Setting up V3IO streaming for MLRun Enterprise...\")\n", + " from mlrun.datastore import DatastoreProfileV3io\n", + " \n", + " stream_profile = DatastoreProfileV3io(name=\"v3io-ds\", v3io_access_key=os.environ[\"V3IO_ACCESS_KEY\"])\n", + " tsdb_profile = stream_profile\n", + " project.register_datastore_profile(stream_profile)\n", + "\n", + "# Enable monitoring in our project:\n", + "project.set_model_monitoring_credentials(\n", + " stream_profile_name=stream_profile.name,\n", + " tsdb_profile_name=tsdb_profile.name,\n", + ")\n", + "project.enable_model_monitoring(\n", + " base_period=1,\n", + " wait_for_deployment=True,\n", + ")\n", + "\n", + "print(\"Monitoring enabled successfully!\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "f23117fa-7b67-470c-80ca-976d14c2120e", + "metadata": {}, + "source": [ + "### Import `langchain_mlrun`\n", + "\n", + "Now we'll import `langchain_mlrun` from the hub." + ] + }, + { + "cell_type": "code", + "id": "2360cd49-b260-4140-bd16-138349e000b3", + "metadata": {}, + "source": [ + "# Import the module from the hub:\n", + "langchain_mlrun = mlrun.import_module(\"hub://langchain_mlrun\")\n", + "\n", + "# Import the utility function and monitoring application from the module:\n", + "setup_langchain_monitoring = langchain_mlrun.setup_langchain_monitoring\n", + "LangChainMonitoringApp = langchain_mlrun.LangChainMonitoringApp" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "de030131-ebaf-48f8-96ed-3c1013b5e260", + "metadata": {}, + "source": [ + "### Create Monitorable Endpoint\n", + "\n", + "Endpoints are the entities being monitored by MLRun. 
We'll use the `setup_langchain_monitoring()` utility function to create the model monitoring endpoint.\n", + "\n", + "For MLRun CE mode, you must pass the `kafka_stream_profile_name` parameter with the name of the registered Kafka stream profile.\n", + "\n", + "By default, the endpoint name will be `\"langchain_mlrun_endpoint\"` but you can change it by using the `model_endpoint_name` parameter." + ] + }, + { + "cell_type": "code", + "id": "0e9baf78-3d38-46bd-89dd-6f83760eaeb0", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-03T19:49:23.861085Z", + "start_time": "2026-02-03T19:49:23.412235Z" + } + }, + "source": [ + "# Pass kafka_stream_profile_name for CE mode (required)\n", + "env_vars = setup_langchain_monitoring(\n", + " kafka_stream_profile_name=stream_profile.name if mlrun.mlconf.is_ce_mode() else None\n", + ")" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating LangChain model endpoint\n", + "\n", + " [✓] Loading Project......................... Done (0.00s)\u001B[K\n", + " [✓] Creating Model.......................... Done (0.31s) \u001B[K\n", + " [✓] Creating Function....................... Done (0.04s) \u001B[K\n", + " [✓] Creating Model Endpoint................. Done (0.09s) \u001B[K\n", + "\n", + "✨ Done! LangChain monitoring model endpoint created successfully.\n", + "You can now set the following environment variables to enable MLRun tracing in your LangChain code:\n", + "\n", + "{\n", + " \"MLRUN_MONITORING_ENABLED\": \"1\",\n", + " \"MLRUN_TRACER_CLIENT_PROJECT\": \"langchain-mlrun-tutorial\",\n", + " \"MLRUN_TRACER_CLIENT_MODEL_ENDPOINT_NAME\": \"langchain_mlrun_endpoint\",\n", + " \"MLRUN_TRACER_CLIENT_MODEL_ENDPOINT_UID\": \"d1d2b2686772441cacf687b45cd48ffa\",\n", + " \"MLRUN_TRACER_CLIENT_SERVING_FUNCTION\": \"langchain_mlrun_function\",\n", + " \"MLRUN_TRACER_CLIENT_KAFKA_STREAM_PROFILE_NAME\": \"kafka-stream-profile\"\n", + "}\n", + "\n", + "To customize the monitoring behavior, you can also set additional environment variables prefixed with 'MLRUN_TRACER_MONITOR_'. Refer to the MLRun tracer documentation for more details.\n", + "\n" + ] + } + ], + "execution_count": 6 + }, + { + "cell_type": "markdown", + "id": "dd45c94b-ee05-449c-9336-0aa659e66bda", + "metadata": {}, + "source": [ + "### Setup Environment Variables for Auto Tracing\n", + "\n", + "We'll use the environment variables returned from `setup_langchain_monitoring` to setup the environment for auto-tracing. Read the printed outputs for more information." + ] + }, + { + "cell_type": "code", + "id": "1c1988f8-c80a-4bf2-bfb1-d43523fc161f", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-03T19:49:23.866556Z", + "start_time": "2026-02-03T19:49:23.864805Z" + } + }, + "source": [ + "os.environ.update(env_vars)" + ], + "outputs": [], + "execution_count": 7 + }, + { + "cell_type": "markdown", + "id": "d3f3b8e5-3538-4153-95da-e6d8776be3ac", + "metadata": {}, + "source": "### Run `langchain` or `langgraph` Code\n\nHere we have 3 functions, each using different method utilizing LLMs with `langchain` and `langgraph`:\n* `run_simple_chain` - Using `langchain`'s chains.\n* `run_simple_agent` - Using `langchain`'s `create_agent` function and `tool`s.\n* `run_langgraph_graph` - Using pure `langgraph`.\n\n> **Notice**: You don't need to set OpenAI API credentials, there is a mock `ChatModel` that will replace it if the credentials are not set in the environment. 
If you wish to use OpenAI models, make sure you `pip install langchain_openai` and set the `OPENAI_API_KEY` environment variable before continue to the next cell.\n\nBecause the auto-tracing environment is set, any run will be automatically traced and monitored!\n\nFeel free to adjust the code as you like.\n\n> **Remember**: To enable auto-tracing you do need to set the environment variables and import the `langchain_mlrun` module before any LangChain code. For batch jobs and realtime functions, make sure you set env vars in the MLRun function and add the import line `langchain_mlrun = mlrun.import_module(\"hub://langchain_mlrun\")` at the top of your code." + }, + { + "cell_type": "code", + "id": "94b4d4b0-8d10-4ad3-8f16-7b1b7daeac11", + "metadata": { + "tags": [], + "ExecuteTime": { + "end_time": "2026-02-03T19:49:24.899991Z", + "start_time": "2026-02-03T19:49:23.869475Z" + } + }, + "source": [ + "import os\n", + "from typing import Literal, TypedDict, Annotated, Sequence, Any, Callable\n", + "from operator import add\n", + "\n", + "from langchain_core.language_models import LanguageModelInput\n", + "from langchain_core.runnables import Runnable, RunnableLambda\n", + "from langchain_core.prompts import ChatPromptTemplate\n", + "from langchain_core.output_parsers import StrOutputParser\n", + "from langchain_core.language_models.fake_chat_models import FakeListChatModel, GenericFakeChatModel\n", + "from langchain.agents import create_agent\n", + "from langchain_core.messages import AIMessage, HumanMessage\n", + "from langchain_core.tools import tool, BaseTool\n", + "\n", + "from langgraph.graph import StateGraph, START, END\n", + "from langchain_core.messages import BaseMessage\n", + "\n", + "\n", + "def _check_openai_credentials() -> bool:\n", + " \"\"\"\n", + " Check if OpenAI API key is set in environment variables.\n", + "\n", + " :return: True if OPENAI_API_KEY is set, False otherwise.\n", + " \"\"\"\n", + " return \"OPENAI_API_KEY\" in os.environ\n", + "\n", + "\n", + "# Import ChatOpenAI only if OpenAI credentials are available (meaning `langchain-openai` must be installed).\n", + "if _check_openai_credentials():\n", + " from langchain_openai import ChatOpenAI\n", + "\n", + " \n", + "class _ToolEnabledFakeModel(GenericFakeChatModel):\n", + " \"\"\"\n", + " A fake chat model that supports tool binding for running agent tracing tests.\n", + " \"\"\"\n", + "\n", + " def bind_tools(\n", + " self,\n", + " tools: Sequence[\n", + " dict[str, Any] | type | Callable | BaseTool # noqa: UP006\n", + " ],\n", + " *,\n", + " tool_choice: str | None = None,\n", + " **kwargs: Any,\n", + " ) -> Runnable[LanguageModelInput, AIMessage]:\n", + " return self\n", + "\n", + "\n", + "#: Tag value for testing tag filtering.\n", + "_dummy_tag = \"dummy_tag\"\n", + "\n", + "\n", + "def run_simple_chain() -> str:\n", + " \"\"\"\n", + " Run a simple LangChain chain that gets a fact about a topic.\n", + " \"\"\"\n", + " # Build a simple chain: prompt -> llm -> str output parser\n", + " llm = ChatOpenAI(\n", + " model=\"gpt-4o-mini\",\n", + " tags=[_dummy_tag]\n", + " ) if _check_openai_credentials() else (\n", + " FakeListChatModel(\n", + " responses=[\n", + " \"MLRun is an open-source orchestrator for machine learning pipelines.\"\n", + " ],\n", + " tags=[_dummy_tag]\n", + " )\n", + " )\n", + " prompt = ChatPromptTemplate.from_template(\"Tell me a short fact about {topic}\")\n", + " chain = prompt | llm | StrOutputParser()\n", + "\n", + " # Run the chain:\n", + " response = chain.invoke({\"topic\": 
\"MLRun\"})\n", + " return response\n", + "\n", + "\n", + "def run_simple_agent():\n", + " \"\"\"\n", + " Run a simple LangChain agent that uses two tools to get weather and stock price.\n", + " \"\"\"\n", + " # Define the tools:\n", + " @tool\n", + " def get_weather(city: str) -> str:\n", + " \"\"\"Get the current weather for a specific city.\"\"\"\n", + " return f\"The weather in {city} is 22°C and sunny.\"\n", + "\n", + " @tool\n", + " def get_stock_price(symbol: str) -> str:\n", + " \"\"\"Get the current stock price for a symbol.\"\"\"\n", + " return f\"The stock price for {symbol} is $150.25.\"\n", + "\n", + " # Define the model:\n", + " model = ChatOpenAI(\n", + " model=\"gpt-4o-mini\",\n", + " tags=[_dummy_tag]\n", + " ) if _check_openai_credentials() else (\n", + " _ToolEnabledFakeModel(\n", + " messages=iter(\n", + " [\n", + " AIMessage(\n", + " content=\"\",\n", + " tool_calls=[\n", + " {\"name\": \"get_weather\", \"args\": {\"city\": \"London\"}, \"id\": \"call_abc123\"},\n", + " {\"name\": \"get_stock_price\", \"args\": {\"symbol\": \"AAPL\"}, \"id\": \"call_def456\"}\n", + " ]\n", + " ),\n", + " AIMessage(content=\"The weather in London is 22°C and AAPL is trading at $150.25.\")\n", + " ]\n", + " ),\n", + " tags=[_dummy_tag]\n", + " )\n", + " )\n", + "\n", + " # Create the agent:\n", + " agent = create_agent(\n", + " model=model,\n", + " tools=[get_weather, get_stock_price],\n", + " system_prompt=\"You are a helpful assistant with access to tools.\"\n", + " )\n", + "\n", + " # Run the agent:\n", + " return agent.invoke({\"messages\": [\"What is the weather in London and the stock price of AAPL?\"]})\n", + "\n", + "\n", + "def run_langgraph_graph():\n", + " \"\"\"\n", + " Run a LangGraph agent that uses reflection to correct its answer.\n", + " \"\"\"\n", + " # Define the graph state:\n", + " class AgentState(TypedDict):\n", + " messages: Annotated[list[BaseMessage], add]\n", + " attempts: int\n", + "\n", + " # Define the model:\n", + " model = ChatOpenAI(model=\"gpt-4o-mini\") if _check_openai_credentials() else (\n", + " _ToolEnabledFakeModel(\n", + " messages=iter(\n", + " [\n", + " AIMessage(content=\"There are 2 'r's in Strawberry.\"), # Mocking the failure\n", + " AIMessage(content=\"I stand corrected. S-t-r-a-w-b-e-r-r-y. There are 3 'r's.\"), # Mocking the fix\n", + " ]\n", + " )\n", + " )\n", + " )\n", + "\n", + " # Define the graph nodes and router:\n", + " def call_model(state: AgentState):\n", + " response = model.invoke(state[\"messages\"])\n", + " return {\"messages\": [response], \"attempts\": state[\"attempts\"] + 1}\n", + "\n", + " def reflect_node(state: AgentState):\n", + " prompt = \"Wait, count the 'r's again slowly, letter by letter. 
Are you sure?\"\n", + " return {\"messages\": [HumanMessage(content=prompt)]}\n", + "\n", + " def router(state: AgentState) -> Literal[\"reflect\", END]:\n", + " # Make sure there are 2 attempts at least for an answer:\n", + " if state[\"attempts\"] == 1:\n", + " return \"reflect\"\n", + " return END\n", + "\n", + " # Build the graph:\n", + " builder = StateGraph(AgentState)\n", + " builder.add_node(\"model\", call_model)\n", + " tagged_reflect_node = RunnableLambda(reflect_node).with_config(tags=[_dummy_tag])\n", + " builder.add_node(\"reflect\", tagged_reflect_node)\n", + " builder.add_edge(START, \"model\")\n", + " builder.add_conditional_edges(\"model\", router)\n", + " builder.add_edge(\"reflect\", \"model\")\n", + " graph = builder.compile()\n", + "\n", + " # Run the graph:\n", + " return graph.invoke({\"messages\": [HumanMessage(content=\"How many 'r's in Strawberry?\")], \"attempts\": 0})" + ], + "outputs": [], + "execution_count": 8 + }, + { + "cell_type": "markdown", + "id": "49964f96-89ba-4f61-8788-38290a877aa2", + "metadata": {}, + "source": "Let's create some traffic, we'll run whatever function you want in a loop to get some events. We take timestamps in order to use them later to run the monitoring application on the data we'll send." + }, + { + "cell_type": "code", + "id": "b7e6418d-76f4-4b18-9ef9-c5bb40b20545", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-03T22:05:54.601563Z", + "start_time": "2026-02-03T22:05:52.518385Z" + } + }, + "source": [ + "# Run LangChain code and now it should be tracked and monitored in MLRun:\n", + "start_timestamp = datetime.datetime.now() - datetime.timedelta(minutes=1)\n", + "for i in range(20):\n", + " run_simple_agent()\n", + "end_timestamp = datetime.datetime.now() + datetime.timedelta(minutes=5)" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> 2026-02-04 00:05:52,553 [info] Project loaded successfully: {\"project_name\":\"langchain-mlrun-tutorial\"}\n" + ] + } + ], + "execution_count": 13 + }, + { + "cell_type": "markdown", + "id": "d9085765-91fd-4d31-84b4-927ecf9cc455", + "metadata": {}, + "source": "> **Note**: Please wait a minute or two until the events are processed." + }, + { + "cell_type": "code", + "id": "85fae3e4-5f1b-4f0c-ba71-81060f10804f", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-03T19:50:26.655189Z", + "start_time": "2026-02-03T19:49:26.648461Z" + } + }, + "source": [ + "time.sleep(60)" + ], + "outputs": [], + "execution_count": 10 + }, + { + "cell_type": "markdown", + "id": "2475ebec-fc32-4884-9723-3ca9cfde577f", + "metadata": {}, + "source": [ + "### Test the LangChain Monitoring Application\n", + "\n", + "To test a monitoring application, we use the `evaluate` class method. We'll run an evaluation on the data we just sent. It is a small local job and should run fast.\n", + "\n", + "Keep an eye for the returned metrics from the monitoring application." 
+ ] + }, + { + "cell_type": "code", + "id": "3d046755-9153-497a-a024-5d63316e1f91", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-03T19:50:28.003195Z", + "start_time": "2026-02-03T19:50:26.670024Z" + } + }, + "source": [ + "LangChainMonitoringApp.evaluate(\n", + " func_name=\"langchain-monitoring-app-test\",\n", + " func_path=\"langchain_mlrun.py\",\n", + " run_local=True,\n", + " endpoints=[env_vars[\"LC_MLRUN_TRACER_CLIENT_MODEL_ENDPOINT_NAME\"]],\n", + " start=start_timestamp.isoformat(),\n", + " end=end_timestamp.isoformat(),\n", + ")" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> 2026-02-03 21:50:26,671 [info] Changing function name - adding `\"-batch\"` suffix: {\"func_name\":\"langchain-monitoring-app-test-batch\"}\n", + "> 2026-02-03 21:50:26,815 [warning] It is recommended to use k8s secret (specify secret_name), specifying aws_access_key/aws_secret_key directly is unsafe.\n", + "> 2026-02-03 21:50:26,829 [info] Storing function: {\"db\":\"http://localhost:30070\",\"name\":\"langchain-monitoring-app-test-batch--handler\",\"uid\":\"f2c3c94681094915beb2c5c1ccc0dac8\"}\n", + "> 2026-02-03 21:50:27,953 [warning] No data was found for any of the specified endpoints. No results were produced: {\"application_name\":\"langchain-monitoring-app-test-batch\",\"end\":\"2026-02-03T21:54:26.640556\",\"endpoints\":[\"langchain_mlrun_endpoint\"],\"start\":\"2026-02-03T21:48:24.904667\"}\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
projectuiditerstartendstatekindnamelabelsinputsparametersresults
langchain-mlrun-tutorial
...c0dac8
0Feb 03 19:50:26NaTcompletedrunlangchain-monitoring-app-test-batch--handler
kind=local
owner=Tomer_Weitzman
host=M-QXN63PHMF9
endpoints=['langchain_mlrun_endpoint']
start=2026-02-03T21:48:24.904667
end=2026-02-03T21:54:26.640556
base_period=None
write_output=False
existing_data_handling=fail_on_overlap
stream_profile=None
\n", + "
\n", + "
\n", + "
\n", + " Title\n", + " ×\n", + "
\n", + " \n", + "
\n", + "
\n" + ] + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ], + "text/html": [ + " > to track results use the .show() or .logs() methods " + ] + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> 2026-02-03 21:50:28,001 [info] Run execution finished: {\"name\":\"langchain-monitoring-app-test-batch--handler\",\"status\":\"completed\"}\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 11 + }, + { + "cell_type": "markdown", + "id": "eda724c3-27f3-4d28-a7ba-1e59b9be2a37", + "metadata": {}, + "source": "### Deploy the Monitoring Application\n\nAll that's left to do now is to deploy our monitoring application!" + }, + { + "cell_type": "code", + "id": "652b00d4-070d-4849-9784-4d461cb83eae", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-03T19:52:29.502406Z", + "start_time": "2026-02-03T19:50:28.009318Z" + } + }, + "source": "# Deploy the monitoring app:\nLangChainMonitoringApp.deploy(\n func_name=\"langchain-monitoring-app\",\n func_path=\"langchain_mlrun.py\",\n image=\"mlrun/mlrun\",\n requirements=[\n \"langchain\",\n \"pydantic-settings\",\n ],\n)", + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "c23bef7a-cbdb-4b22-a2d9-2edbfde5eb04", + "metadata": {}, + "source": [ + "Once it is deployed, you can run events again and see the monitoring application in MLRun UI in action:\n", + "\n", + "![mlrun ui example](./notebook_images/mlrun_ui.png)" + ] + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "", + "id": "fc994d2114a89a25" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/modules/src/langchain_mlrun/langchain_mlrun.py b/modules/src/langchain_mlrun/langchain_mlrun.py new file mode 100644 index 000000000..920354bfb --- /dev/null +++ b/modules/src/langchain_mlrun/langchain_mlrun.py @@ -0,0 +1,1840 @@ +# Copyright 2026 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +MLRun to LangChain integration - a tracer that converts LangChain Run objects into serializable event and send them to +MLRun monitoring. 
+""" + +from abc import ABC, abstractmethod +import copy +import importlib +import orjson +import os +import socket +from uuid import UUID +import threading +from contextlib import contextmanager +from contextvars import ContextVar +import datetime +from typing import Any, Callable, Generator, Optional + +from langchain_core.tracers import BaseTracer, Run +from langchain_core.tracers.context import register_configure_hook + +from pydantic import Field, field_validator, model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict +from uuid_utils import uuid7 + +import mlrun +from mlrun.runtimes import RemoteRuntime +from mlrun.model_monitoring.applications import ( + ModelMonitoringApplicationBase, ModelMonitoringApplicationMetric, + ModelMonitoringApplicationResult, MonitoringApplicationContext, +) +import mlrun.common.schemas.model_monitoring.constants as mm_constants + +#: Environment variable name to use MLRun monitoring tracer via LangChain global tracing system: +mlrun_monitoring_env_var = "LC_MLRUN_MONITORING_ENABLED" + + +class _MLRunEndPointClient(ABC): + """ + An MLRun model endpoint monitoring client base class to connect and send events on a monitoring stream. + """ + + def __init__( + self, + model_endpoint_name: str, + model_endpoint_uid: str, + serving_function: str | RemoteRuntime, + serving_function_tag: str | None = None, + project: str | mlrun.projects.MlrunProject = None, + ): + """ + Initialize an MLRun model endpoint monitoring client. + + :param model_endpoint_name: The monitoring endpoint related model name. + :param model_endpoint_uid: Model endpoint unique identifier. + :param serving_function: Serving function name or ``RemoteRuntime`` object. + :param serving_function_tag: Optional function tag (defaults to 'latest'). + :param project: Project name or ``MlrunProject``. If ``None``, uses the current project. + raise: MLRunInvalidArgumentError: If there is no current active project and no `project` argument was provided. + """ + # Store the provided info: + self._model_endpoint_name = model_endpoint_name + self._model_endpoint_uid = model_endpoint_uid + + # Load project: + if project is None: + try: + self._project_name = mlrun.get_current_project(silent=False).name + except mlrun.errors.MLRunInvalidArgumentError: + raise mlrun.errors.MLRunInvalidArgumentError( + "There is no current active project. Either use `mlrun.get_or_create_project` prior to " + "initializing the monitoring tracer or pass a project name to load. You can also set the " + "environment variable: 'LC_MLRUN_TRACER_CLIENT_PROJECT'." 
+ ) + elif isinstance(project, str): + self._project_name = project + else: + self._project_name = project.name + + # Load function: + if isinstance(serving_function, str): + self._serving_function_name = serving_function + self._serving_function_tag = serving_function_tag or "latest" + else: + self._serving_function_name = serving_function.metadata.name + self._serving_function_tag = ( + serving_function_tag or serving_function.metadata.tag + ) + + # Prepare the sample: + self._event_sample = { + "class": "CustomStream", + "worker": "0", + "model": self._model_endpoint_name, + "host": socket.gethostname(), + "function_uri": f"{self._project_name}/{self._serving_function_name}:{self._serving_function_tag}", + "endpoint_id": self._model_endpoint_uid, + "sampling_percentage": 100, + "request": {"inputs": [], "background_task_state": "succeeded"}, + "op": "infer", + "resp": { + "id": None, + "model_name": self._model_endpoint_name, + "outputs": [], + "timestamp": None, + "model_endpoint_uid": self._model_endpoint_uid, + }, + "when": None, + "microsec": 496, + "effective_sample_count": 1, + } + + @abstractmethod + def monitor( + self, + event_id: str, + label: str, + input_data: dict, + output_data: dict, + request_timestamp: str, + response_timestamp: str, + ): + """ + Monitor the provided event, sending it to the model endpoint monitoring stream. + + :param event_id: Unique event identifier used as the monitored record id. + :param label: Label for the run/event. + :param input_data: Serialized input data for the run. + :param output_data: Serialized output data for the run. + :param request_timestamp: Request/start timestamp in the format of '%Y-%m-%d %H:%M:%S%z'. + :param response_timestamp: Response/end timestamp in the format of '%Y-%m-%d %H:%M:%S%z'. + """ + pass + + def flush(self): + """ + Flush any buffered messages to ensure they are sent to the stream. + + For streaming backends that buffer messages (like Kafka), this ensures delivery. For backends that send + immediately (like V3IO), this may be a no-op. + """ + pass + + def _create_event( + self, + event_id: str, + label: str, + input_data: dict, + output_data: dict, + request_timestamp: str, + response_timestamp: str, + ) -> dict: + """ + Create a new event out of the stored event sample. + + :param event_id: Unique event identifier used as the monitored record id. + :param label: Label for the run/event. + :param input_data: Serialized input data for the run. + :param output_data: Serialized output data for the run. + :param request_timestamp: Request/start timestamp in the format of '%Y-%m-%d %H:%M:%S%z'. + :param response_timestamp: Response/end timestamp in the format of '%Y-%m-%d %H:%M:%S%z'. + + :returns: The event to send to the monitoring stream. + """ + # Copy the sample: + event = copy.deepcopy(self._event_sample) + + # Edit event with given parameters: + event["when"] = request_timestamp + event["request"]["inputs"].append(orjson.dumps({"label": label, "input": input_data}).decode('utf-8')) + event["resp"]["timestamp"] = response_timestamp + event["resp"]["outputs"].append(orjson.dumps(output_data).decode('utf-8')) + event["resp"]["id"] = event_id + + return event + + +class _V3IOMLRunEndPointClient(_MLRunEndPointClient): + """ + An MLRun model endpoint monitoring client to connect and send events on a V3IO stream. 
+ """ + + def __init__( + self, + monitoring_stream_path: str, + monitoring_container: str, + model_endpoint_name: str, + model_endpoint_uid: str, + serving_function: str | RemoteRuntime, + serving_function_tag: str | None = None, + project: str | mlrun.projects.MlrunProject = None, + ): + """ + Initialize an MLRun model endpoint monitoring client. + + :param monitoring_stream_path: V3IO stream path. + :param monitoring_container: V3IO container name. + :param model_endpoint_name: The monitoring endpoint related model name. + :param model_endpoint_uid: Model endpoint unique identifier. + :param serving_function: Serving function name or ``RemoteRuntime`` object. + :param serving_function_tag: Optional function tag (defaults to 'latest'). + :param project: Project name or ``MlrunProject``. If ``None``, uses the current project. + raise: MLRunInvalidArgumentError: If there is no current active project and no `project` argument was provided. + """ + super().__init__( + model_endpoint_name=model_endpoint_name, + model_endpoint_uid=model_endpoint_uid, + serving_function=serving_function, + serving_function_tag=serving_function_tag, + project=project, + ) + + import v3io + + # Store the provided info: + self._monitoring_stream_path = monitoring_stream_path + self._monitoring_container = monitoring_container + + # Initialize a V3IO client: + self._v3io_client = v3io.Client() + + def monitor( + self, + event_id: str, + label: str, + input_data: dict, + output_data: dict, + request_timestamp: str, + response_timestamp: str, + ): + """ + Monitor the provided event, sending it to the model endpoint monitoring stream. + + :param event_id: Unique event identifier used as the monitored record id. + :param label: Label for the run/event. + :param input_data: Serialized input data for the run. + :param output_data: Serialized output data for the run. + :param request_timestamp: Request/start timestamp in the format of '%Y-%m-%d %H:%M:%S%z'. + :param response_timestamp: Response/end timestamp in the format of '%Y-%m-%d %H:%M:%S%z'. + """ + # Copy the sample: + event = self._create_event( + event_id=event_id, + label=label, + input_data=input_data, + output_data=output_data, + request_timestamp=request_timestamp, + response_timestamp=response_timestamp, + ) + + # Push to stream: + self._v3io_client.stream.put_records( + container=self._monitoring_container, + stream_path=self._monitoring_stream_path, + records=[{"data": orjson.dumps(event).decode('utf-8')}], + ) + + +class _KafkaMLRunEndPointClient(_MLRunEndPointClient): + """ + An MLRun model endpoint monitoring client to connect and send events on a Kafka stream. + """ + + def __init__( + self, + kafka_stream_profile_name: str, + model_endpoint_name: str, + model_endpoint_uid: str, + serving_function: str | RemoteRuntime, + serving_function_tag: str | None = None, + project: str | mlrun.projects.MlrunProject = None, + kafka_linger_ms: int = 0, + ): + """ + Initialize an MLRun model endpoint monitoring client for Kafka. + + :param kafka_stream_profile_name: The name of the registered DatastoreProfileKafkaStream to use for Kafka + configuration. This profile should be registered via ``project.register_datastore_profile()`` and + contains all Kafka settings including broker, topic, SASL credentials, SSL config, etc. + :param model_endpoint_name: The monitoring endpoint related model name. + :param model_endpoint_uid: Model endpoint unique identifier. + :param serving_function: Serving function name or ``RemoteRuntime`` object. 
:param serving_function_tag: Optional function tag (defaults to 'latest').
+        :param project: Project name or ``MlrunProject``. If ``None``, uses the current project.
+        :param kafka_linger_ms: Kafka producer linger.ms setting controlling message batching. Messages are
+            accumulated for up to this duration before being sent as a batch. Default: 0 (messages are sent
+            immediately; the tracer settings default to 500ms).
+        raise: MLRunInvalidArgumentError: If there is no current active project and no `project` argument was provided.
+        """
+        super().__init__(
+            model_endpoint_name=model_endpoint_name,
+            model_endpoint_uid=model_endpoint_uid,
+            serving_function=serving_function,
+            serving_function_tag=serving_function_tag,
+            project=project,
+        )
+
+        from kafka import KafkaProducer
+        from mlrun.datastore.utils import KafkaParameters
+        from mlrun.common.model_monitoring.helpers import get_kafka_topic
+
+        # Get project object using resolved project name from parent:
+        project_obj = mlrun.get_or_create_project(self._project_name)
+
+        # Fetch the Kafka stream profile:
+        stream_profile = project_obj.get_datastore_profile(profile=kafka_stream_profile_name)
+
+        # Get profile attributes and convert to producer config:
+        profile_attrs = stream_profile.attributes()
+        kafka_params = KafkaParameters(kwargs=profile_attrs)
+        producer_config = kafka_params.producer()
+
+        # Extract broker and determine topic (use profile's topic if available, otherwise use MLRun's standard naming):
+        self._monitoring_broker = profile_attrs.get("brokers")
+        topics = profile_attrs.get("topics", [])
+        self._monitoring_topic = topics[0] if topics else get_kafka_topic(project=project_obj.name)
+
+        # Remove bootstrap_servers from producer_config to avoid duplicate argument error:
+        producer_config.pop("bootstrap_servers", None)
+
+        # Initialize a Kafka producer with full config from profile:
+        self._kafka_producer = KafkaProducer(
+            bootstrap_servers=self._monitoring_broker,
+            key_serializer=lambda k: k.encode("utf-8") if isinstance(k, str) else k,
+            value_serializer=(
+                lambda v: v if isinstance(v, bytes)
+                else orjson.dumps(v) if isinstance(v, dict)
+                else str(v).encode("utf-8")
+            ),
+            linger_ms=kafka_linger_ms,
+            **producer_config,
+        )
+
+    def monitor(
+        self,
+        event_id: str,
+        label: str,
+        input_data: dict,
+        output_data: dict,
+        request_timestamp: str,
+        response_timestamp: str,
+    ):
+        """
+        Monitor the provided event, sending it to the model endpoint monitoring stream.
+
+        :param event_id: Unique event identifier used as the monitored record id.
+        :param label: Label for the run/event.
+        :param input_data: Serialized input data for the run.
+        :param output_data: Serialized output data for the run.
+        :param request_timestamp: Request/start timestamp in the format of '%Y-%m-%d %H:%M:%S%z'.
+        :param response_timestamp: Response/end timestamp in the format of '%Y-%m-%d %H:%M:%S%z'.
+        """
+        # Create the event from the stored sample:
+        event = self._create_event(
+            event_id=event_id,
+            label=label,
+            input_data=input_data,
+            output_data=output_data,
+            request_timestamp=request_timestamp,
+            response_timestamp=response_timestamp,
+        )
+
+        # Push to stream (async - message is buffered):
+        self._kafka_producer.send(
+            topic=self._monitoring_topic,
+            value=event,  # Will be serialized by the value_serializer
+            key=self._model_endpoint_uid,
+        )
+
+    def flush(self):
+        """
+        Flush all buffered messages to ensure they are sent to Kafka.
+
+        Blocks until all buffered messages are delivered and acknowledged by the broker.
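+
+        Note: the tracer already calls ``flush`` at the end of every root run (see ``_persist_run``), so buffered
+        events are delivered even when ``kafka_linger_ms`` batching is enabled.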
+ """ + self._kafka_producer.flush() + + +class MLRunTracerClientSettings(BaseSettings): + """ + MLRun tracer monitoring client configurations. These are mandatory arguments for allowing MLRun to send monitoring + events to a specific model endpoint stream. + """ + + v3io_stream_path: str | None = None + """ + The V3IO stream path to send the events to. + """ + + v3io_container: str | None = None + """ + The V3IO stream container. + """ + + kafka_stream_profile_name: str | None = None + """ + The name of the registered DatastoreProfileKafkaStream to use for Kafka configuration. This profile should be + registered via ``project.register_datastore_profile()`` and contains all Kafka settings including broker, topic, + SASL credentials, SSL config, etc. + """ + + kafka_linger_ms: int = 500 + """ + The Kafka producer linger.ms setting controlling message batching (in milliseconds). Messages are accumulated for + up to this duration before being sent as a batch, reducing network overhead. + + The tracer always flushes at the end of each root run, guaranteeing delivery regardless of this setting. + Default: 500ms. Set to 0 to disable batching (each message sent immediately). + """ + + model_endpoint_name: str = ... + """ + The model endpoint name. + """ + + model_endpoint_uid: str = ... + """ + The model endpoint UID. + """ + + serving_function: str = ... + """ + The serving function name. + """ + + serving_function_tag: str | None = None + """ + The serving function tag. If not set, it will be 'latest' by default. + """ + + project: str | None = None + """ + The MLRun project name related to the serving function and model endpoint. + """ + + #: Pydantic model configuration to set the environment variable prefix. + model_config = SettingsConfigDict(env_prefix="LC_MLRUN_TRACER_CLIENT_") + + @model_validator(mode='after') + def validate_stream_settings(self) -> 'MLRunTracerClientSettings': + """ + Validate that either V3IO settings or stream profile name is provided, but not both or none. + + :returns: The validated settings instance. + """ + v3io_settings = all([self.v3io_container, self.v3io_stream_path]) + kafka_settings = self.kafka_stream_profile_name is not None + + if v3io_settings and kafka_settings: + raise ValueError("Provide either V3IO settings OR Kafka settings, not both.") + if not v3io_settings and not kafka_settings: + raise ValueError("You must provide either a complete V3IO settings or complete Kafka settings. See docs for more information") + return self + +class MLRunTracerMonitorSettings(BaseSettings): + """ + MLRun tracer monitoring configurations. These are optional arguments to customize the LangChain runs summarization + into monitorable MLRun endpoint events. If needed, a custom summarization can be passed. + """ + + label: str = "default" + """ + Label to use for all monitored runs. Can be used to differentiate between different monitored sources on the same + endpoint. + """ + + tags_filter: list[str] | None = None + """ + Filter runs by tags. Only runs with at least one tag in this list will be monitored. + If None, no tag-based filtering is applied and runs with any tags are considered. + Default: None. + """ + + run_types_filter: list[str] | None = None + """ + Filter runs by run types (e.g. "chain", "llm", "chat", "tool"). + Only runs whose `run_type` appears in this list will be monitored. + If None, no run-type filtering is applied. + Default: None. + """ + + names_filter: list[str] | None = None + """ + Filter runs by class/name. 
Only runs whose `name` appears in this list will be monitored. + If None, no name-based filtering is applied. + Default: None. + """ + + include_full_run: bool = False + """ + If True, include the complete serialized run dict (the output of `run._get_dicts_safe()`) + in the event outputs under the key `full_run`. Useful for debugging or when consumers need + the raw run payload. Default: False. + """ + + include_errors: bool = True + """ + If True, include run error information in the outputs under the `error` key. + If False, runs that contain an error may be skipped by the summarizer filters. + Default: True. + """ + + include_metadata: bool = True + """ + If True, include run metadata (environment, tool metadata, etc.) in the inputs under + the `metadata` key. Default: True. + """ + + include_latency: bool = True + """ + If True, include latency information in the outputs under the `latency` key. + Default: True. + """ + + root_run_only: bool = False + """ + If True, only the root/top-level run will be monitored and any child runs will be + ignored/removed from monitoring. Use when only the top-level run should produce events. + Default: False. + """ + + split_runs: bool = False + """ + If True, child runs are emitted as separate monitoring events (each run summarized and + sent individually). If False, child runs are nested inside the parent/root run event under + `child_runs`. Default: False. + """ + + run_summarizer_function: ( + str + | Callable[ + [Run, Optional[BaseSettings]], + Generator[tuple[dict, dict] | None, None, None], + ] + | None + ) = None + """ + A function to summarize a `Run` object into a tuple of inputs and outputs. Can be passed directly or via a full + module path ("a.b.c.my_summarizer" will be imported as `from a.b.c import my_summarizer`). + + A summarizer is a function that will be used to process a run into monitoring events. The function is expected to be + of type: + `Callable[[Run, Optional[BaseSettings]], Generator[tuple[dict, dict] | None, None, None]]`, meaning + get a run object and optionally a settings object and return a generator yielding tuples of serialized dictionaries, + the (inputs, outputs) to send to MLRun monitoring as events or `None` to skip monitoring this run. + """ + + run_summarizer_settings: str | BaseSettings | None = None + """ + Settings to pass to the run summarizer function. Can be passed directly or via a full module path to be imported + and initialized. If the summarizer function does not require settings, this can be left as None. + """ + + debug: bool = False + """ + If True, disable sending events to MLRun and instead route events to `debug_target_list` + or print them as JSON to stdout. Useful for unit tests and local debugging. Default: False. + """ + + debug_target_list: list[dict] | bool = False + """ + Optional list to which debug events will be appended when `debug` is True. + If set, each generated event dict will be appended to this list. If not set and `debug` is True, + events will be printed to stdout as JSON. Default: False. + """ + + #: Pydantic model configuration to set the environment variable prefix. + model_config = SettingsConfigDict(env_prefix="LC_MLRUN_TRACER_MONITOR_") + + @field_validator('debug_target_list', mode='before') + @classmethod + def convert_bool_to_list(cls, v): + """ + Convert a boolean `True` value to an empty list for `debug_target_list`. + + :param v: The value to validate. + + :returns: An empty list if `v` is True, otherwise the original value. 
+        """
+        if v is True:
+            return []
+        return v
+
+
+class MLRunTracerSettings(BaseSettings):
+    """
+    MLRun tracer settings to configure the tracer. The settings are split into two groups:
+
+    * `client`: settings required to connect and send events to the MLRun monitoring stream.
+    * `monitor`: settings controlling which LangChain runs are summarized and sent and how.
+    """
+
+    client: MLRunTracerClientSettings = Field(default_factory=MLRunTracerClientSettings)
+    """
+    Client configuration group (``MLRunTracerClientSettings``).
+
+    Contains the mandatory connection and endpoint information required to publish monitoring
+    events. Values may be supplied programmatically or via environment variables prefixed with
+    `LC_MLRUN_TRACER_CLIENT_`. See more at ``MLRunTracerClientSettings``.
+    """
+
+    monitor: MLRunTracerMonitorSettings = Field(default_factory=MLRunTracerMonitorSettings)
+    """
+    Monitoring configuration group (``MLRunTracerMonitorSettings``).
+
+    Controls what runs are captured, how they are summarized (including custom summarizer import
+    options), whether child runs are split or nested, and debug behavior. Values may be supplied
+    programmatically or via environment variables prefixed with `LC_MLRUN_TRACER_MONITOR_`.
+    See more at ``MLRunTracerMonitorSettings``.
+    """
+
+    #: Pydantic model configuration to set the environment variable prefix.
+    model_config = SettingsConfigDict(env_prefix="LC_MLRUN_TRACER_")
+
+
+class MLRunTracer(BaseTracer):
+    """
+    MLRun tracer for LangChain runs, allowing LangChain and LangGraph to be monitored in production using MLRun's
+    monitoring.
+
+    There are two usage modes for the MLRun tracer following LangChain tracing best practices:
+
+    1. **Manual Mode** - Using the ``mlrun_monitoring`` context manager::
+
+        from langchain_mlrun import mlrun_monitoring
+
+        with mlrun_monitoring(...) as tracer:
+            # LangChain code here.
+            pass
+
+    2. **Auto Mode** - Setting the `LC_MLRUN_MONITORING_ENABLED="1"` environment variable::
+
+        import langchain_mlrun
+
+        # All LangChain code will be automatically traced and monitored.
+        pass
+
+    To control how runs are summarized into the monitored events, the ``MLRunTracerSettings`` can be set.
+    As it is a Pydantic ``BaseSettings`` class, it can be done in two ways:
+
+    1. Initializing the settings classes and passing them to the context manager::
+
+        from langchain_mlrun import (
+            mlrun_monitoring,
+            MLRunTracerSettings,
+            MLRunTracerClientSettings,
+            MLRunTracerMonitorSettings,
+        )
+
+        my_settings = MLRunTracerSettings(
+            client=MLRunTracerClientSettings(),
+            monitor=MLRunTracerMonitorSettings(root_run_only=True),
+        )
+
+        with mlrun_monitoring(settings=my_settings) as tracer:
+            # LangChain code here.
+            pass
+
+    2. Or via environment variables following the prefix 'LC_MLRUN_TRACER_CLIENT_' for client settings and
+       'LC_MLRUN_TRACER_MONITOR_' for monitoring settings.
+    """
+
+    #: A singleton tracer for when using the tracer via environment variable to activate global tracing.
+    _singleton_tracer: "MLRunTracer | None" = None
+    #: A thread lock for initializing the tracer singleton safely.
+    _lock = threading.Lock()
+    #: A boolean flag to know whether the singleton was initialized.
+    _initialized = False
+
+    def __new__(cls, *args, **kwargs) -> "MLRunTracer":
+        """
+        Create or return an ``MLRunTracer`` instance.
+
+        When ``LC_MLRUN_MONITORING_ENABLED`` is not set to ``"1"``, a normal instance is returned.
+        When the env var is ``"1"``, a process-wide singleton is returned. Creation is thread-safe.
+ + :returns: MLRunTracer instance (singleton if 'auto' mode is active). + """ + # Check if needed to use a singleton as the user is using the MLRun tracer by setting the environment variable + # and not manually (via context manager): + if not cls._check_for_env_var_usage(): + return super(MLRunTracer, cls).__new__(cls) + + # Check if the singleton is set: + if cls._singleton_tracer is None: + # Acquire lock to initialize the singleton: + with cls._lock: + # Double-check after acquiring lock: + if cls._singleton_tracer is None: + cls._singleton_tracer = super(MLRunTracer, cls).__new__(cls) + + return cls._singleton_tracer + + def __init__(self, settings: MLRunTracerSettings = None, **kwargs): + """ + Initialize the tracer. + + :param settings: Settings to use for the tracer. If not passed, defaults are used and environment variables are + applied per Pydantic settings behavior. + :param kwargs: Passed to the base initializer. + """ + # Proceed with initialization only if singleton mode is not required or the singleton was not initialized: + if self._check_for_env_var_usage() and self._initialized: + return + + # Call the base tracer init: + super().__init__(**kwargs) + + # Set a UID for this instance: + self._uid = uuid7() + + # Set the settings: + self._settings = settings or MLRunTracerSettings() + self._client_settings = self._settings.client + self._monitor_settings = self._settings.monitor + + # Initialize the MLRun endpoint client: + self._mlrun_client = ( + self._get_mlrun_client() + if not self._monitor_settings.debug + else None + ) + + # In case the user passed a custom summarizer, import it: + self._custom_run_summarizer_function: ( + Callable[ + [Run, Optional[BaseSettings]], + Generator[tuple[dict, dict] | None, None, None], + ] + | None + ) = None + self._custom_run_summarizer_settings: BaseSettings | None = None + self._import_custom_run_summarizer() + + # Mark the initialization flag (for the singleton case): + self._initialized = True + + @property + def settings(self) -> MLRunTracerSettings: + """ + Access the effective settings. + + :returns: The settings used by this tracer. + """ + return self._settings + + def _get_mlrun_client(self) -> _MLRunEndPointClient: + """ + Create and return an MLRun model endpoint monitoring client based on the MLRun (CE or not) and current + configuration. + + :returns: An MLRun model endpoint monitoring client. + """ + if mlrun.mlconf.is_ce_mode(): + return _KafkaMLRunEndPointClient( + kafka_stream_profile_name=self._client_settings.kafka_stream_profile_name, + model_endpoint_name=self._client_settings.model_endpoint_name, + model_endpoint_uid=self._client_settings.model_endpoint_uid, + serving_function=self._client_settings.serving_function, + serving_function_tag=self._client_settings.serving_function_tag, + project=self._client_settings.project, + kafka_linger_ms=self._client_settings.kafka_linger_ms, + ) + return _V3IOMLRunEndPointClient( + monitoring_stream_path=self._client_settings.v3io_stream_path, + monitoring_container=self._client_settings.v3io_container, + model_endpoint_name=self._client_settings.model_endpoint_name, + model_endpoint_uid=self._client_settings.model_endpoint_uid, + serving_function=self._client_settings.serving_function, + serving_function_tag=self._client_settings.serving_function_tag, + project=self._client_settings.project, + ) + + def _import_custom_run_summarizer(self): + """ + Import or assign a custom run summarizer (and its custom settings) if configured. 
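+
+        A summarizer is a generator taking a run and optional settings and yielding ``(inputs, outputs)`` tuples
+        (or ``None`` to skip the run); a minimal, hypothetical example::
+
+            def my_summarizer(run, settings=None):
+                # Emit one event per run, or yield None to skip monitoring it.
+                yield {"inputs": run.inputs}, {"outputs": run.outputs}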
+ """ + # If the user did not pass a run summarizer function, return: + if not self._monitor_settings.run_summarizer_function: + return + + # Check if the function needs to be imported: + if isinstance(self._monitor_settings.run_summarizer_function, str): + self._custom_run_summarizer_function = self._import_from_module_path( + module_path=self._monitor_settings.run_summarizer_function + ) + else: + self._custom_run_summarizer_function = ( + self._monitor_settings.run_summarizer_function + ) + + # Check if the user passed settings as well: + if self._monitor_settings.run_summarizer_settings: + # Check if the settings need to be imported: + if isinstance(self._monitor_settings.run_summarizer_settings, str): + self._custom_run_summarizer_settings = self._import_from_module_path( + module_path=self._monitor_settings.run_summarizer_settings + )() + else: + self._custom_run_summarizer_settings = ( + self._monitor_settings.run_summarizer_settings + ) + + def _persist_run(self, run: Run, level: int = 0) -> None: + """ + Summarize the run (and its children) into MLRun monitoring events. + + Note: This will use the MLRun tracer's default summarization that can be configured via + ``MLRunTracerMonitorSettings``, unless a custom summarizer was provided (via the same settings). + + :param run: LangChain run object to process holding all the nested tree of runs. + :param level: The nesting level of the run (0 for root runs, incremented for child runs). + """ + try: + # Serialize the run: + serialized_run = self._serialize_run( + run=run, + include_child_runs=not (self._settings.monitor.root_run_only or self._settings.monitor.split_runs) + ) + + # Check for a user custom run summarizer function: + if self._custom_run_summarizer_function: + for summarized_run in self._custom_run_summarizer_function( + run, self._custom_run_summarizer_settings + ): + if summarized_run: + inputs, outputs = summarized_run + self._send_run_event( + event_id=serialized_run["id"], + inputs=inputs, + outputs=outputs, + start_time=run.start_time, + end_time=run.end_time, + ) + return + + # Check how to deal with the child runs, monitor them in separate events or as a single event: + if self._monitor_settings.split_runs and not self._settings.monitor.root_run_only: + # Monitor as separate events: + for child_run in run.child_runs: + self._persist_run(run=child_run, level=level + 1) + summarized_run = self._summarize_run(serialized_run=serialized_run, include_children=False) + if summarized_run: + inputs, outputs = summarized_run + inputs["child_level"] = level + self._send_run_event( + event_id=serialized_run["id"], + inputs=inputs, + outputs=outputs, + start_time=run.start_time, + end_time=run.end_time, + ) + return + + # Monitor the root event (include child runs if `root_run_only` is False): + summarized_run = self._summarize_run( + serialized_run=serialized_run, + include_children=not self._monitor_settings.root_run_only + ) + if not summarized_run: + return + inputs, outputs = summarized_run + inputs["child_level"] = level + self._send_run_event( + event_id=serialized_run["id"], + inputs=inputs, + outputs=outputs, + start_time=run.start_time, + end_time=run.end_time, + ) + finally: + # Flush buffered messages after root run completion to ensure delivery: + if level == 0 and self._mlrun_client: + self._mlrun_client.flush() + + def _serialize_run(self, run: Run, include_child_runs: bool) -> dict: + """ + Serialize a LangChain run into a dictionary. + + :param run: The run to serialize. 
+ :param include_child_runs: Whether to include child runs in the serialization. + + :returns: The serialized run dictionary. + """ + # In LangChain 1.2.3+, the Run model uses Pydantic v2 with child_runs marked as Field(exclude=True), so we + # must manually serialize child runs. Still excluding manually for future compatibility. In previous + # LangChain versions, Run was Pydantic v1, so we use dict. + serialized_run = ( + run.model_dump(exclude={"child_runs"}) + if hasattr(run, "model_dump") + else run.dict(exclude={"child_runs"}) + ) + + # Manually serialize child runs if needed: + if include_child_runs and run.child_runs: + serialized_run["child_runs"] = [ + self._serialize_run(child_run, include_child_runs=True) + for child_run in run.child_runs + ] + + return orjson.loads(orjson.dumps(serialized_run, default=self._serialize_default)) + + def _serialize_default(self, obj: Any): + """ + Default serializer for objects present in LangChain run that are not serializable by default JSON encoder. It + includes handling Pydantic v1 and v2 models, UUIDs, and datetimes. + + :param obj: The object to serialize. + + :returns: The serialized object. + """ + if isinstance(obj, UUID): + return str(obj) + if isinstance(obj, datetime.datetime): + return obj.isoformat() + if hasattr(obj, "model_dump"): + return orjson.loads(orjson.dumps(obj.model_dump(), default=self._serialize_default)) + if hasattr(obj, "dict"): + return orjson.loads(orjson.dumps(obj.dict(), default=self._serialize_default)) + return str(obj) + + def _filter_by_tags(self, serialized_run: dict) -> bool: + """ + Apply tag-based filtering. + + :param serialized_run: Serialized run dictionary. + + :returns: True if the run passes tag filters or if no tag filter is configured. + """ + # Check if the user enabled filtering by tags: + if not self._monitor_settings.tags_filter: + return True + + # Filter the run: + return not set(self._monitor_settings.tags_filter).isdisjoint( + serialized_run["tags"] + ) + + def _filter_by_run_types(self, serialized_run: dict) -> bool: + """ + Apply run-type filtering. + + :param serialized_run: Serialized run dictionary. + + :returns: True if the run's ``run_type`` is allowed or if no run-type filter is configured. + """ + # Check if the user enabled filtering by run types: + if not self._monitor_settings.run_types_filter: + return True + + # Filter the run: + return serialized_run["run_type"] in self._monitor_settings.run_types_filter + + def _filter_by_names(self, serialized_run: dict) -> bool: + """ + Apply class/name filtering. + + :param serialized_run: Serialized run dictionary. + + :returns: True if the run's ``name`` is allowed or if no name filter is configured. + """ + # Check if the user enabled filtering by class names: + if not self._monitor_settings.names_filter: + return True + + # Filter the run: + return serialized_run["name"] in self._monitor_settings.names_filter + + def _get_run_inputs(self, serialized_run: dict) -> dict[str, Any]: + """ + Build the inputs dictionary for a monitoring event. + + :param serialized_run: Serialized run dictionary. + + :returns: A dictionary containing inputs, run metadata and (optionally) additional metadata. 
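+            Roughly: ``{"inputs": ..., "run_type": ..., "run_name": ..., "tags": [...], "run_id": ...,
+            "start_timestamp": ..., "parent_run_id": ..., "metadata": {...}}`` (the last two keys are optional).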
+ """ + inputs = { + "inputs": serialized_run["inputs"], + "run_type": serialized_run["run_type"], + "run_name": serialized_run["name"], + "tags": serialized_run["tags"], + "run_id": serialized_run["id"], + "start_timestamp": serialized_run["start_time"], + } + if "parent_run_id" in serialized_run: + # Parent run ID is excluded when child runs are joined in the same event. When child runs are split, it is + # included and can be used to reconstruct the run tree if needed. + inputs = {**inputs, "parent_run_id": serialized_run["parent_run_id"]} + if self._monitor_settings.include_metadata and "metadata" in serialized_run: + inputs = {**inputs, "metadata": serialized_run["metadata"]} + + return inputs + + def _get_run_outputs(self, serialized_run: dict) -> dict[str, Any]: + """ + Build the outputs dictionary for a monitoring event. + + :param serialized_run: Serialized run dictionary. + + :returns: A dictionary with outputs and optional other collected info depending on monitor settings. + """ + outputs = {"outputs": serialized_run["outputs"], "end_timestamp": serialized_run["end_time"]} + if self._monitor_settings.include_latency and "latency" in serialized_run: + outputs = {**outputs, "latency": serialized_run["latency"]} + if self._monitor_settings.include_errors: + outputs = {**outputs, "error": serialized_run["error"]} + if self._monitor_settings.include_full_run: + outputs = {**outputs, "full_run": serialized_run} + + return outputs + + def _summarize_run(self, serialized_run: dict, include_children: bool) -> tuple[dict, dict] | None: + """ + Summarize a single run into (inputs, outputs) if it passes filters. + + :param serialized_run: Serialized run dictionary. + :param include_children: Whether to include child runs. + + :returns: The summarized run (inputs, outputs) tuple if the run should be monitored, otherwise ``None``. + """ + # Pass filters: + if not ( + self._filter_by_tags(serialized_run=serialized_run) + and self._filter_by_run_types(serialized_run=serialized_run) + and self._filter_by_names(serialized_run=serialized_run) + ): + return None + + # Check if needed to include errors: + if serialized_run["error"] and not self._monitor_settings.include_errors: + return None + + # Prepare the inputs and outputs: + inputs = self._get_run_inputs(serialized_run=serialized_run) + outputs = self._get_run_outputs(serialized_run=serialized_run) + + # Check if needed to include child runs: + if include_children: + outputs["child_runs"] = [] + for child_run in serialized_run.get("child_runs", []): + # Recursively summarize the child run: + summarized_child_run = self._summarize_run(serialized_run=child_run, include_children=True) + if summarized_child_run: + inputs_child, outputs_child = summarized_child_run + outputs["child_runs"].append( + { + "input_data": inputs_child, + "output_data": outputs_child, + } + ) + + return inputs, outputs + + def _send_run_event( + self, event_id: str, inputs: dict, outputs: dict, start_time: datetime.datetime, end_time: datetime.datetime + ): + """ + Send a monitoring event for a single run. + + Note: If monitor debug mode is enabled, appends to ``debug_target_list`` or prints JSON. + + :param event_id: Unique event identifier. + :param inputs: Inputs dictionary for the event. + :param outputs: Outputs dictionary for the event. + :param start_time: Request/start timestamp. + :param end_time: Response/end timestamp. 
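+
+        The emitted event is shaped as::
+
+            {"event_id": ..., "label": ..., "input_data": {"input_data": {...}},
+             "output_data": {"output_data": {...}},
+             "request_timestamp": "...", "response_timestamp": "..."}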
+        """
+        event = {
+            "event_id": event_id,
+            "label": self._monitor_settings.label,
+            "input_data": {"input_data": inputs},  # So it will be a single "input feature" in MLRun monitoring.
+            "output_data": {"output_data": outputs},  # So it will be a single "output feature" in MLRun monitoring.
+            "request_timestamp": start_time.strftime("%Y-%m-%d %H:%M:%S%z"),
+            "response_timestamp": end_time.strftime("%Y-%m-%d %H:%M:%S%z"),
+        }
+        if self._monitor_settings.debug:
+            if isinstance(self._monitor_settings.debug_target_list, list):
+                self._monitor_settings.debug_target_list.append(event)
+            else:
+                # Decode for printing, as `orjson.dumps` returns bytes:
+                print(orjson.dumps(event, option=orjson.OPT_INDENT_2 | orjson.OPT_APPEND_NEWLINE).decode())
+            return
+
+        self._mlrun_client.monitor(**event)
+
+    @staticmethod
+    def _check_for_env_var_usage() -> bool:
+        """
+        Check whether global env-var activated tracing is requested.
+
+        :returns: True when ``LC_MLRUN_MONITORING_ENABLED`` environment variable equals ``"1"``.
+        """
+        return os.environ.get(mlrun_monitoring_env_var, "0") == "1"
+
+    @staticmethod
+    def _import_from_module_path(module_path: str) -> Any:
+        """
+        Import an object from a full module path string.
+
+        :param module_path: Full dotted path, e.g. ``a.b.module.object``.
+
+        :returns: The imported object.
+        """
+        try:
+            module_name, object_name = module_path.rsplit(".", 1)
+            module = importlib.import_module(module_name)
+            obj = getattr(module, object_name)
+        except ValueError as value_error:
+            raise ValueError(
+                f"The provided '{module_path}' is not valid: it must have at least one '.'. "
+                f"If the class is locally defined, please add '__main__.MyObject' to the path."
+            ) from value_error
+        except ImportError as import_error:
+            raise ImportError(
+                f"Could not import '{module_path}'. Tried to import '{module_name}' and failed with the following "
+                f"error: {import_error}."
+            ) from import_error
+        except AttributeError as attribute_error:
+            raise AttributeError(
+                f"Could not import '{object_name}'. Tried to run 'from {module_name} import {object_name}' and could "
+                f"not find it: {attribute_error}"
+            ) from attribute_error
+
+        return obj
+
+
+#: MLRun monitoring context variable to set when the user wraps their code with `mlrun_monitoring`. From this context
+#  variable LangChain will get the tracer in a thread-safe way.
+mlrun_monitoring_var: ContextVar[MLRunTracer | None] = ContextVar(
+    "mlrun_monitoring", default=None
+)
+
+
+@contextmanager
+def mlrun_monitoring(settings: MLRunTracerSettings | None = None):
+    """
+    Context manager that enables MLRun tracing for LangChain code, monitoring the LangChain runs within its scope.
+
+    Example usage::
+
+        from langchain_mlrun import mlrun_monitoring, MLRunTracerSettings
+
+        settings = MLRunTracerSettings(...)
+        with mlrun_monitoring(settings=settings) as tracer:
+            # LangChain execution within this block will be traced by `tracer`.
+            ...
+
+    :param settings: The settings to use to configure the tracer.
+    """
+    mlrun_tracer = MLRunTracer(settings=settings)
+    token = mlrun_monitoring_var.set(mlrun_tracer)
+    try:
+        yield mlrun_tracer
+    finally:
+        mlrun_monitoring_var.reset(token)
+
+
+# Register a hook for LangChain to apply the MLRun tracer:
+register_configure_hook(
+    context_var=mlrun_monitoring_var,
+    inheritable=True,  # To allow inner runs (agent that uses a tool that uses an LLM...) to be traced.
+    env_var=mlrun_monitoring_env_var,
+    handle_class=MLRunTracer,
+)
+
+# Temporary convenience function to set up the monitoring infrastructure required for the tracer.
+def setup_langchain_monitoring( + project: str | mlrun.MlrunProject = None, + function_name: str = "langchain_mlrun_function", + model_name: str = "langchain_mlrun_model", + model_endpoint_name: str = "langchain_mlrun_endpoint", + v3io_container: str = "projects", + v3io_stream_path: str = None, + kafka_stream_profile_name: str = None, + kafka_linger_ms: int = 500, +) -> dict: + """ + Create a model endpoint in the given project to be used for LangChain monitoring with MLRun and returns the + necessary environment variables to configure the MLRun tracer client. The project should already exist and have + monitoring enabled:: + + project.set_model_monitoring_credentials( + stream_profile_name=..., + tsdb_profile_name=... + ) + + This function creates and logs dummy model and function in the specified project in order to create the model + endpoint for monitoring. It is a temporary workaround and will be added as a feature in a future MLRun version. + + :param project: The MLRun project name or object where to create the model endpoint. If None, the current active + project will be used. + :param function_name: The name of the serving function to create. + :param model_name: The name of the model to create. + :param model_endpoint_name: The name of the model endpoint to create. + :param v3io_container: The V3IO container where the monitoring stream is located (for MLRun Enterprise). + :param v3io_stream_path: The V3IO stream path for monitoring (for MLRun Enterprise). If None, + ``/model-endpoints/stream-v1`` will be used. + :param kafka_stream_profile_name: The name of the registered ``DatastoreProfileKafkaStream`` to use for Kafka + configuration (required for MLRun CE). This profile should be registered via + ``project.register_datastore_profile()`` and contains all Kafka settings including broker, topic, + SASL credentials, SSL config, etc. + :param kafka_linger_ms: Kafka producer linger.ms setting controlling message batching (default: 500ms). + Messages are accumulated for up to this duration before being sent as a batch, reducing network overhead. + The tracer always flushes at the end of each root run, guaranteeing delivery. Set to 0 to disable batching. + + :returns: A dictionary with the necessary environment variables to configure the MLRun tracer client. + raise: MLRunInvalidArgumentError: If no project is provided and there is no current active project. + """ + import io + import time + import sys + from contextlib import redirect_stdout, redirect_stderr + import tempfile + import pickle + import json + + from mlrun.features import Feature + + class ProgressStep: + """ + A context manager to display progress of a code block with timing and optional output suppression. + """ + + def __init__(self, label: str, indent: int = 2, width: int = 40, clean: bool = True): + """ + Initialize the ProgressStep context manager. + + :param label: The label to display for the progress step. + :param indent: The number of spaces to indent the label. + :param width: The width to pad the label for alignment. + :param clean: Whether to suppress stdout and stderr during the block execution. + """ + # Store parameters: + self._label = label + self._indent = indent + self._width = width + self._clean = clean + + # Internal state: + self._start_time = None + self._sink = io.StringIO() + self._stdout_redirect = None + self._stderr_redirect = None + self._last_line_length = 0 # To track the line printed when terminals don't support '\033[K'. 
+ + # Capture the stream currently in use (before and if clean is true and we redirect it): + self._terminal = sys.stdout + + def __enter__(self): + """ + Enter the context manager, starting the timer and printing the initial status. + """ + # Start timer: + self._start_time = time.perf_counter() + + # Print without newline (using \r to allow overwriting): + self._write(icon=" ", status="Running", new_line=False) + + # Silence all internal noise: + if self._clean: + self._stdout_redirect = redirect_stdout(self._sink) + self._stderr_redirect = redirect_stderr(self._sink) + self._stdout_redirect.__enter__() + self._stderr_redirect.__enter__() + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Exit the context manager, stopping the timer and printing the final status. + + :param exc_type: The exception type, if any. + :param exc_val: The exception value, if any. + :param exc_tb: The exception traceback, if any. + """ + # Restore stdout/stderr: + if self._clean: + self._stdout_redirect.__exit__(exc_type, exc_val, exc_tb) + self._stderr_redirect.__exit__(exc_type, exc_val, exc_tb) + + # Calculate elapsed time: + elapsed = time.perf_counter() - self._start_time + + # Move cursor back to start of line ('\r') and overwrite ('\033[K' clears the line to the right): + if exc_type is None: + self._write(icon="✓", status=f"Done ({elapsed:.2f}s)", new_line=True) + else: + self._write(icon="✕", status="Failed", new_line=True) + + def update(self, status: str): + """ + Update the status message displayed for the progress step. + + :param status: The new status message to display. + """ + self._write(icon=" ", status=status, new_line=False) + + def _write(self, icon: str, status: str, new_line: bool): + """ + Write the progress line to the terminal, handling line clearing for terminals that do not support it. + + :param icon: The icon to display (e.g., checkmark, cross, space). + :param status: The status message to display. + :param new_line: Whether to end the line with a newline character. + """ + # Construct the basic line + line = f"\r{' ' * self._indent}[{icon}] {self._label.ljust(self._width, '.')} {status}" + + # Calculate if we need to pad with spaces to clear the old, longer line: + padding = max(0, self._last_line_length - len(line)) + + # Add spaces to clear old text (add the ANSI clear for terminals that support it): + line = f"{line}{' ' * padding}\033[K" + + # Add newline if needed: + if new_line: + line += "\n" + + # Write to terminal: + self._terminal.write(line) + self._terminal.flush() + + # Update the max length seen so far: + self._last_line_length = len(line) + + print("Creating LangChain model endpoint\n") + + # Get the project: + with ProgressStep("Loading Project"): + if project is None: + try: + project = mlrun.get_current_project(silent=False) + except mlrun.errors.MLRunInvalidArgumentError: + raise mlrun.errors.MLRunInvalidArgumentError( + "There is no current active project. Either use `mlrun.get_or_create_project` prior to " + "creating the monitoring endpoint or pass a project name to load." 
+                )
+        if isinstance(project, str):
+            project = mlrun.load_project(name=project)
+
+    # Create and log the dummy model:
+    with ProgressStep("Creating Model") as progress_step:
+        # Check if the model already exists:
+        progress_step.update("Checking if model exists")
+        try:
+            dummy_model = project.get_artifact(key=model_name)
+        except mlrun.MLRunNotFoundError:
+            dummy_model = None
+        # If not, create and log it:
+        if not dummy_model:
+            progress_step.update(f"Logging model '{model_name}'")
+            with tempfile.TemporaryDirectory() as tmpdir:
+                # Create a dummy model file:
+                dummy_model_path = os.path.join(tmpdir, "for_langchain_mlrun_tracer.pkl")
+                with open(dummy_model_path, "wb") as f:
+                    pickle.dump({"dummy": "model"}, f)
+                # Log the model:
+                dummy_model = project.log_model(
+                    key=model_name,
+                    model_file=dummy_model_path,
+                    inputs=[Feature(value_type="str", name="input")],
+                    outputs=[Feature(value_type="str", name="output")]
+                )
+
+    # Create and set the dummy function:
+    with ProgressStep("Creating Function") as progress_step:
+        # Check if the function already exists:
+        progress_step.update("Checking if function exists")
+        try:
+            dummy_function = project.get_function(key=function_name)
+        except mlrun.MLRunNotFoundError:
+            dummy_function = None
+        # If not, create and save it:
+        if not dummy_function:
+            progress_step.update(f"Setting function '{function_name}'")
+            with tempfile.TemporaryDirectory() as tmpdir:
+                # Create a dummy function file:
+                dummy_function_code = """
+def handler(context, event):
+    return "ok"
+"""
+                dummy_function_path = os.path.join(tmpdir, "dummy_function.py")
+                with open(dummy_function_path, "w") as f:
+                    f.write(dummy_function_code)
+                # Set the function in the project:
+                dummy_function = project.set_function(
+                    func=dummy_function_path, name=function_name, image="mlrun/mlrun", kind="nuclio"
+                )
+                dummy_function.save()
+
+    # Create the model endpoint:
+    with ProgressStep("Creating Model Endpoint") as progress_step:
+        # Get the MLRun DB:
+        progress_step.update("Getting MLRun DB")
+        db = mlrun.get_run_db()
+        # Check if the model endpoint already exists:
+        progress_step.update("Checking if endpoint exists")
+        model_endpoint = project.list_model_endpoints(names=[model_endpoint_name]).endpoints
+        if model_endpoint:
+            model_endpoint = model_endpoint[0]
+        else:
+            progress_step.update("Creating model endpoint")
+            model_endpoint = mlrun.common.schemas.ModelEndpoint(
+                metadata=mlrun.common.schemas.ModelEndpointMetadata(
+                    project=project.name,
+                    name=model_endpoint_name,
+                    endpoint_type=mlrun.common.schemas.model_monitoring.EndpointType.NODE_EP,
+                ),
+                spec=mlrun.common.schemas.ModelEndpointSpec(
+                    function_name=dummy_function.metadata.name,
+                    function_tag="latest",
+                    model_path=dummy_model.uri,
+                    model_class="CustomStream",
+                ),
+                status=mlrun.common.schemas.ModelEndpointStatus(
+                    monitoring_mode=mm_constants.ModelMonitoringMode.enabled,
+                ),
+            )
+            db.create_model_endpoint(model_endpoint=model_endpoint)
+        # Wait for the model endpoint UID to be set:
+        progress_step.update("Waiting for model endpoint")
+        uid_exist_flag = False
+        while not uid_exist_flag:
+            model_endpoint = project.list_model_endpoints(names=[model_endpoint_name])
+            model_endpoint = model_endpoint.endpoints[0]
+            if model_endpoint.metadata.uid:
+                uid_exist_flag = True
+            else:
+                # Avoid busy-waiting while the endpoint is being registered:
+                time.sleep(1)
+
+    # Set parameters defaults:
+    v3io_stream_path = v3io_stream_path or f"{project.name}/model-endpoints/stream-v1"
+
+    if mlrun.mlconf.is_ce_mode():
+        if kafka_stream_profile_name is None:
+            raise ValueError(
+                "kafka_stream_profile_name is required for MLRun CE mode. Register a DatastoreProfileKafkaStream and "
+                "pass its name."
+            )
+        client_env_vars = {
+            "LC_MLRUN_TRACER_CLIENT_KAFKA_STREAM_PROFILE_NAME": kafka_stream_profile_name,
+            "LC_MLRUN_TRACER_CLIENT_KAFKA_LINGER_MS": str(kafka_linger_ms),
+        }
+    else:
+        client_env_vars = {
+            "LC_MLRUN_TRACER_CLIENT_V3IO_STREAM_PATH": v3io_stream_path,
+            "LC_MLRUN_TRACER_CLIENT_V3IO_CONTAINER": v3io_container,
+        }
+
+    # Prepare the environment variables:
+    env_vars = {
+        "LC_MLRUN_MONITORING_ENABLED": "1",
+        "LC_MLRUN_TRACER_CLIENT_PROJECT": project.name,
+        "LC_MLRUN_TRACER_CLIENT_MODEL_ENDPOINT_NAME": model_endpoint.metadata.name,
+        "LC_MLRUN_TRACER_CLIENT_MODEL_ENDPOINT_UID": model_endpoint.metadata.uid,
+        "LC_MLRUN_TRACER_CLIENT_SERVING_FUNCTION": function_name,
+        **client_env_vars
+    }
+    print("\n✨ Done! LangChain monitoring model endpoint created successfully.")
+    print("You can now set the following environment variables to enable MLRun tracing in your LangChain code:\n")
+    print(json.dumps(env_vars, indent=4))
+    print(
+        "\nTo customize the monitoring behavior, you can also set additional environment variables prefixed with "
+        "'LC_MLRUN_TRACER_MONITOR_'. Refer to the MLRun tracer documentation for more details.\n"
+    )
+
+    return env_vars
+
+
+class LangChainMonitoringApp(ModelMonitoringApplicationBase):
+    """
+    A base monitoring application for LangChain that calculates common metrics on LangChain runs traced with the MLRun
+    tracer.
+
+    The class can be inherited and extended to add custom metrics or override existing ones. It provides methods to
+    extract structured runs from the monitoring context and calculate metrics such as average latency, success rate,
+    token usage, and run name counts.
+
+    If inheriting, the main method to override is `do_tracking`, which performs the tracking on the monitoring context.
+    """
+
+    def do_tracking(self, monitoring_context: MonitoringApplicationContext) -> (
+        ModelMonitoringApplicationResult |
+        list[ModelMonitoringApplicationResult | ModelMonitoringApplicationMetric] |
+        dict[str, Any]
+    ):
+        """
+        The main function that performs tracking on the monitoring context. The LangChain monitoring app by default
+        will calculate all the provided metrics on the structured runs extracted from the monitoring context sample
+        dataframe.
+
+        :param monitoring_context: The monitoring context containing the sample dataframe.
+
+        :returns: The monitoring artifacts, metrics and results.
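+
+        Example of overriding this method in a subclass (an illustrative sketch that reuses the class's own helpers;
+        the subclass name is hypothetical)::
+
+            class MyLangChainApp(LangChainMonitoringApp):
+                def do_tracking(self, monitoring_context):
+                    structured_runs, _ = self.get_structured_runs(monitoring_context=monitoring_context)
+                    return [
+                        ModelMonitoringApplicationMetric(
+                            name="success_rate",
+                            value=self.calculate_success_rate(structured_runs=structured_runs),
+                        ),
+                    ]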
+ """ + # Get the structured runs from the monitoring context: + structured_runs, _ = self.get_structured_runs(monitoring_context=monitoring_context) + + # Calculate the metrics: + average_latency = self.calculate_average_latency(structured_runs=structured_runs) + success_rate = self.calculate_success_rate(structured_runs=structured_runs) + token_usage = self.count_token_usage(structured_runs=structured_runs) + run_name_counts = self.count_run_names(structured_runs=structured_runs) + + return [ + ModelMonitoringApplicationMetric( + name="average_latency", + value=average_latency, + ), + ModelMonitoringApplicationMetric( + name="success_rate", + value=success_rate, + ), + ModelMonitoringApplicationMetric( + name="total_input_tokens", + value=token_usage["total_input_tokens"], + ), + ModelMonitoringApplicationMetric( + name="total_output_tokens", + value=token_usage["total_output_tokens"], + ), + ModelMonitoringApplicationMetric( + name="combined_total_tokens", + value=token_usage["combined_total"], + ), + *[ModelMonitoringApplicationMetric( + name=f"run_name_counts_{run_name}", + value=count, + ) for run_name, count in run_name_counts.items()], + ] + + @staticmethod + def get_structured_runs( + monitoring_context: MonitoringApplicationContext, + labels_filter: list[str] = None, + tags_filter: list[str] = None, + run_name_filter: list[str] = None, + run_type_filter: list[str] = None, + flatten_child_runs: bool = False, + ignore_child_runs: bool = False, + ignore_errored_runs: bool = False, + ) -> tuple[list[dict], list[dict]]: + """ + Get the structured runs from the monitoring context sample dataframe. The sample dataframe contains the raw + input and output data as JSON strings - the way the MLRun tracer sends them as events to MLRun monitoring. This + function parses the JSON strings into structured dictionaries that can be used for further metrics calculations + and analysis. + + :param monitoring_context: The monitoring context containing the sample dataframe. + :param labels_filter: List of labels to filter the runs. Only runs with a label appearing in this list will + remain. If None, no filtering is applied. + :param tags_filter: List of tags to filter the runs. Only runs containing at least one tag from this list will + remain. If None, no filtering is applied. + :param run_name_filter: List of run names to filter the runs. Only runs with a name appearing in this list will + remain. If None, no filtering is applied. + :param run_type_filter: List of run types to filter the runs. Only runs with a type appearing in this list will + remain. If None, no filtering is applied. + :param flatten_child_runs: Whether to flatten child runs into the main runs list. If True, all child runs will + be extracted and added to the main runs list. If False, child runs will be kept nested within their parent + runs. + :param ignore_child_runs: Whether to ignore child runs completely. If True, child runs will be removed from the + output. If False, child runs will be processed according to the other parameters. + :param ignore_errored_runs: Whether to ignore runs that resulted in errors. If True, runs with errors will be + excluded from the output. If False, errored runs will be included. + + :returns: A list of structured run dictionaries that passed the filters and a list of samples that could not be + parsed due to errors. 
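+
+        Example (an illustrative sketch; keeps only LLM runs, flattened out of their parents, skipping errored ones)::
+
+            structured_runs, errored_samples = LangChainMonitoringApp.get_structured_runs(
+                monitoring_context=monitoring_context,
+                run_type_filter=["llm"],
+                flatten_child_runs=True,
+                ignore_errored_runs=True,
+            )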
+ """ + # Retrieve the input and output samples from the monitoring context: + samples = monitoring_context.sample_df[['input', 'output']].to_dict('records') + + # Prepare to collect structured samples: + structured_samples = [] + errored_samples = [] + + # Go over all samples: + for sample in samples: + try: + # Parse the input data into structured format: + parsed_input = orjson.loads(sample['input']) + label = parsed_input['label'] + parsed_input = parsed_input["input"]["input_data"] + # Parse the output data into structured format: + parsed_output = orjson.loads(sample['output'])["output_data"] + structured_samples.extend( + LangChainMonitoringApp._collect_run( + structured_input=parsed_input, + structured_output=parsed_output, + label=label, + labels_filter=labels_filter, + tags_filter=tags_filter, + run_name_filter=run_name_filter, + run_type_filter=run_type_filter, + flatten_child_runs=flatten_child_runs, + ignore_child_runs=ignore_child_runs, + ignore_errored_runs=ignore_errored_runs, + ) + ) + except Exception: + errored_samples.append(sample) + + return structured_samples, errored_samples + + @staticmethod + def _collect_run( + structured_input: dict, + structured_output: dict, + label: str, + child_level: int = 0, + labels_filter: list[str] = None, + tags_filter: list[str] = None, + run_name_filter: list[str] = None, + run_type_filter: list[str] = None, + flatten_child_runs: bool = False, + ignore_child_runs: bool = False, + ignore_errored_runs: bool = False, + ) -> list[dict]: + """ + Recursively collect runs from the structured input and output data, applying filters as specified. + + :param structured_input: The structured input data of the run. + :param structured_output: The structured output data of the run. + :param label: The label of the run. + :param child_level: The current child level of the run (0 for root runs). + :param labels_filter: Label filter as described in `get_structured_runs`. + :param tags_filter: Tag filter as described in `get_structured_runs`. + :param run_name_filter: Run name filter as described in `get_structured_runs`. + :param run_type_filter: Run type filter as described in `get_structured_runs`. + :param flatten_child_runs: Flag to flatten child runs as described in `get_structured_runs`. + :param ignore_child_runs: Flag to ignore child runs as described in `get_structured_runs`. + :param ignore_errored_runs: Flag to ignore errored runs as described in `get_structured_runs`. + + :returns: A list of structured run dictionaries that passed the filters. 
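+
+        Each collected entry follows the shape appended below (shown as an illustrative sketch)::
+
+            {
+                "label": label,
+                "input_data": structured_input,
+                "output_data": structured_output,
+                "child_level": child_level,
+            }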
+        """
+        # Prepare to collect runs:
+        runs = []
+
+        # Filter by label:
+        if labels_filter and label not in labels_filter:
+            return runs
+
+        # Handle child runs:
+        if "child_runs" in structured_output:
+            # Check if we need to ignore or flatten child runs:
+            if ignore_child_runs:
+                structured_output.pop("child_runs")
+            elif flatten_child_runs:
+                # Recursively collect child runs:
+                child_runs = structured_output.pop("child_runs")
+                flattened_runs = []
+                for child_run in child_runs:
+                    flattened_runs.extend(
+                        LangChainMonitoringApp._collect_run(
+                            structured_input=child_run["input_data"],
+                            structured_output=child_run["output_data"],
+                            label=label,
+                            child_level=child_level + 1,
+                            tags_filter=tags_filter,
+                            run_name_filter=run_name_filter,
+                            run_type_filter=run_type_filter,
+                            flatten_child_runs=flatten_child_runs,
+                            ignore_child_runs=ignore_child_runs,
+                            ignore_errored_runs=ignore_errored_runs,
+                        )
+                    )
+                runs.extend(flattened_runs)
+
+        # Filter by tags, run name, run type, and errors (a run is dropped when none of its tags appear in the filter):
+        if tags_filter and set(structured_input["tags"]).isdisjoint(tags_filter):
+            return runs
+        if run_name_filter and structured_input["run_name"] not in run_name_filter:
+            return runs
+        if run_type_filter and structured_input["run_type"] not in run_type_filter:
+            return runs
+        if ignore_errored_runs and structured_output.get("error", None):
+            return runs
+
+        # Collect the current run:
+        runs.append({"label": label, "input_data": structured_input, "output_data": structured_output,
+                     "child_level": child_level})
+        return runs
+
+    @staticmethod
+    def iterate_structured_runs(structured_runs: list[dict]) -> Generator[dict, None, None]:
+        """
+        Iterates over all runs in the structured samples, including child runs.
+
+        :param structured_runs: List of structured run samples.
+
+        :returns: A generator yielding each run structure.
+        """
+        # TODO: Add an option to stop at a certain child level.
+        for structured_run in structured_runs:
+            if "child_runs" in structured_run['output_data']:
+                for child_run in structured_run['output_data']['child_runs']:
+                    yield from LangChainMonitoringApp.iterate_structured_runs([{
+                        "label": structured_run['label'],
+                        "input_data": child_run['input_data'],
+                        "output_data": child_run['output_data'],
+                        "child_level": structured_run['child_level'] + 1
+                    }])
+            yield structured_run
+
+    @staticmethod
+    def count_run_names(structured_runs: list[dict]) -> dict[str, int]:
+        """
+        Counts occurrences of each run name in the structured samples.
+
+        :param structured_runs: List of structured run samples.
+
+        :returns: A dictionary with run names as keys and their counts as values.
+        """
+        # TODO: Add a nice plot artifact that will draw the bar chart for what is being used the most.
+        # Prepare to count run names:
+        run_name_counts = {}
+
+        # Go over all the runs:
+        for structured_run in LangChainMonitoringApp.iterate_structured_runs(structured_runs):
+            run_name = structured_run['input_data']['run_name']
+            if run_name in run_name_counts:
+                run_name_counts[run_name] += 1
+            else:
+                run_name_counts[run_name] = 1
+
+        return run_name_counts
+
+    @staticmethod
+    def count_token_usage(structured_runs: list[dict]) -> dict:
+        """
+        Calculates total tokens by only counting unique 'llm' type runs.
+
+        :param structured_runs: List of structured run samples.
+
+        :returns: A dictionary with total input tokens, total output tokens, and combined total tokens.
+        """
+        # TODO: Add a token count per model breakdown (a dictionary of model name to token counts). Pay attention that
+        #  different providers use different keys in the response metadata. We should implement a mapping for that so
+        #  each provider will have its own handler that will know how to extract the relevant info out of a run.
+        # Prepare to count tokens:
+        total_input_tokens = 0
+        total_output_tokens = 0
+
+        # Go over all the LLM typed runs:
+        for structured_run in LangChainMonitoringApp.iterate_structured_runs(structured_runs):
+            # Count only LLM type runs as chain runs may include duplicative information as they accumulate the tokens
+            # from the child runs:
+            if structured_run['input_data']['run_type'] != 'llm':
+                continue
+            # Look for the token count information:
+            outputs = structured_run['output_data']["outputs"]
+            # Newer implementations should have the metadata in the `AIMessage` kwargs under generations:
+            if "generations" in outputs:
+                for generation in outputs["generations"]:  # Iterate over generations.
+                    for sample in generation:  # Iterate over the generation batch.
+                        token_usage = sample.get("message", {}).get("kwargs", {}).get("usage_metadata", {})
+                        if token_usage:
+                            total_input_tokens += (
+                                token_usage.get('input_tokens', 0)
+                                or token_usage.get('prompt_tokens', 0)
+                            )
+                            total_output_tokens += (
+                                token_usage.get('output_tokens', 0) or
+                                token_usage.get('completion_tokens', 0)
+                            )
+            # Older implementations may have the metadata under `llm_output` (checked only when no generations exist,
+            # so the same usage is not counted twice):
+            elif "llm_output" in outputs:
+                token_usage = outputs["llm_output"].get("token_usage", {})
+                if token_usage:
+                    total_input_tokens += token_usage.get('input_tokens', 0) or token_usage.get('prompt_tokens', 0)
+                    total_output_tokens += (
+                        token_usage.get('output_tokens', 0) or
+                        token_usage.get('completion_tokens', 0)
+                    )
+
+        return {
+            "total_input_tokens": total_input_tokens,
+            "total_output_tokens": total_output_tokens,
+            "combined_total": total_input_tokens + total_output_tokens
+        }
+
+    @staticmethod
+    def calculate_success_rate(structured_runs: list[dict]) -> float:
+        """
+        Calculates the success rate across all runs.
+
+        :param structured_runs: List of structured run samples.
+
+        :returns: Success rate as a float ratio between 0 and 1.
+        """
+        # TODO: Add an option to see an errors breakdown by kind of error and maybe an option to show which run name
+        #  yielded most of the errors with artifacts showcasing it.
+        successful_count = 0
+        for structured_run in structured_runs:
+            if 'error' not in structured_run['output_data'] or structured_run['output_data']['error'] is None:
+                successful_count += 1
+        return successful_count / len(structured_runs) if structured_runs else 0.0
+
+    @staticmethod
+    def calculate_average_latency(structured_runs: list[dict]) -> float:
+        """
+        Calculates the average latency across all root runs.
+
+        :param structured_runs: List of structured run samples.
+
+        :returns: Average latency in milliseconds.
+        """
+        # TODO: Add an option to calculate latency per run name (to know which runs are slower/faster) and then return
+        #  an artifact showcasing it.
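+        # For example (illustrative): a root run starting at 10:00:00.000 and ending at 10:00:01.000 contributes
+        # 1000.0 milliseconds to the average, while its child runs are skipped to avoid double counting.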
+ # Prepare to calculate average latency: + total_latency = 0.0 + count = 0 + + # Go over all the root runs: + for structured_run in structured_runs: + # Skip child runs: + if structured_run["child_level"] > 0: + continue + # Check if latency is already provided: + if "latency" in structured_run['output_data']: + total_latency += structured_run['output_data']['latency'] + count += 1 + continue + # Calculate latency from timestamps: + start_time = datetime.datetime.fromisoformat(structured_run['input_data']['start_timestamp']) + end_time = datetime.datetime.fromisoformat(structured_run['output_data']['end_timestamp']) + total_latency += (end_time - start_time).total_seconds() * 1000 # Convert to milliseconds + count += 1 + + return total_latency / count if count > 0 else 0.0 diff --git a/modules/src/langchain_mlrun/notebook_images/mlrun_ui.png b/modules/src/langchain_mlrun/notebook_images/mlrun_ui.png new file mode 100644 index 000000000..9785eeae3 Binary files /dev/null and b/modules/src/langchain_mlrun/notebook_images/mlrun_ui.png differ diff --git a/modules/src/langchain_mlrun/requirements.txt b/modules/src/langchain_mlrun/requirements.txt new file mode 100644 index 000000000..2597fbe12 --- /dev/null +++ b/modules/src/langchain_mlrun/requirements.txt @@ -0,0 +1,4 @@ +pytest +langchain~=1.2 +pydantic-settings~=2.12 +kafka-python~=2.3 \ No newline at end of file diff --git a/modules/src/langchain_mlrun/test_langchain_mlrun.py b/modules/src/langchain_mlrun/test_langchain_mlrun.py new file mode 100644 index 000000000..cb6ca184c --- /dev/null +++ b/modules/src/langchain_mlrun/test_langchain_mlrun.py @@ -0,0 +1,1025 @@ +# Copyright 2026 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Literal, TypedDict, Annotated, Sequence, Any, Callable +from concurrent.futures import ThreadPoolExecutor +from operator import add + +import pytest +from langchain_core.language_models import LanguageModelInput +from langchain_core.runnables import Runnable, RunnableLambda +from pydantic import ValidationError + +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.output_parsers import StrOutputParser +from langchain_core.tracers import Run +from langchain_core.language_models.fake_chat_models import FakeListChatModel, GenericFakeChatModel +from langchain.agents import create_agent +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.tools import tool, BaseTool + +from langgraph.graph import StateGraph, START, END +from langchain_core.messages import BaseMessage +from pydantic_settings import BaseSettings, SettingsConfigDict + +from langchain_mlrun import ( + mlrun_monitoring, + MLRunTracer, + MLRunTracerSettings, + MLRunTracerClientSettings, + MLRunTracerMonitorSettings, + mlrun_monitoring_env_var, + LangChainMonitoringApp, +) + + +def _check_openai_credentials() -> bool: + """ + Check if OpenAI API key is set in environment variables. + + :return: True if OPENAI_API_KEY is set, False otherwise. 
+ """ + return "OPENAI_API_KEY" in os.environ + + +# Import ChatOpenAI only if OpenAI credentials are available (meaning `langchain-openai` must be installed). +if _check_openai_credentials(): + from langchain_openai import ChatOpenAI + + +class _ToolEnabledFakeModel(GenericFakeChatModel): + """ + A fake chat model that supports tool binding for running agent tracing tests. + """ + + def bind_tools( + self, + tools: Sequence[ + dict[str, Any] | type | Callable | BaseTool # noqa: UP006 + ], + *, + tool_choice: str | None = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, AIMessage]: + return self + + +#: Tag value for testing tag filtering. +_dummy_tag = "dummy_tag" + + +def _run_simple_chain() -> str: + """ + Run a simple LangChain chain that gets a fact about a topic. + """ + # Build a simple chain: prompt -> llm -> str output parser + llm = ChatOpenAI( + model="gpt-4o-mini", + tags=[_dummy_tag] + ) if _check_openai_credentials() else ( + FakeListChatModel( + responses=[ + "MLRun is an open-source orchestrator for machine learning pipelines." + ], + tags=[_dummy_tag] + ) + ) + prompt = ChatPromptTemplate.from_template("Tell me a short fact about {topic}") + chain = prompt | llm | StrOutputParser() + + # Run the chain: + response = chain.invoke({"topic": "MLRun"}) + return response + + +def _run_simple_agent(): + """ + Run a simple LangChain agent that uses two tools to get weather and stock price. + """ + # Define the tools: + @tool + def get_weather(city: str) -> str: + """Get the current weather for a specific city.""" + return f"The weather in {city} is 22°C and sunny." + + @tool + def get_stock_price(symbol: str) -> str: + """Get the current stock price for a symbol.""" + return f"The stock price for {symbol} is $150.25." + + # Define the model: + model = ChatOpenAI( + model="gpt-4o-mini", + tags=[_dummy_tag] + ) if _check_openai_credentials() else ( + _ToolEnabledFakeModel( + messages=iter( + [ + AIMessage( + content="", + tool_calls=[ + {"name": "get_weather", "args": {"city": "London"}, "id": "call_abc123"}, + {"name": "get_stock_price", "args": {"symbol": "AAPL"}, "id": "call_def456"} + ] + ), + AIMessage(content="The weather in London is 22°C and AAPL is trading at $150.25.") + ] + ), + tags=[_dummy_tag] + ) + ) + + # Create the agent: + agent = create_agent( + model=model, + tools=[get_weather, get_stock_price], + system_prompt="You are a helpful assistant with access to tools." + ) + + # Run the agent: + return agent.invoke({"messages": ["What is the weather in London and the stock price of AAPL?"]}) + + +def _run_langgraph_graph(): + """ + Run a LangGraph agent that uses reflection to correct its answer. + """ + + # Define the graph state: + class AgentState(TypedDict): + messages: Annotated[list[BaseMessage], add] + attempts: int + + # Define the model: + model = ChatOpenAI(model="gpt-4o-mini") if _check_openai_credentials() else ( + _ToolEnabledFakeModel( + messages=iter( + [ + AIMessage(content="There are 2 'r's in Strawberry."), # Mocking the failure + AIMessage(content="I stand corrected. S-t-r-a-w-b-e-r-r-y. There are 3 'r's."), # Mocking the fix + ] + ) + ) + ) + + # Define the graph nodes and router: + def call_model(state: AgentState): + response = model.invoke(state["messages"]) + return {"messages": [response], "attempts": state["attempts"] + 1} + + def reflect_node(state: AgentState): + prompt = "Wait, count the 'r's again slowly, letter by letter. Are you sure?" 
+        return {"messages": [HumanMessage(content=prompt)]}
+
+    def router(state: AgentState) -> Literal["reflect", END]:
+        # Make sure there are at least 2 attempts at an answer:
+        if state["attempts"] == 1:
+            return "reflect"
+        return END
+
+    # Build the graph:
+    builder = StateGraph(AgentState)
+    builder.add_node("model", call_model)
+    tagged_reflect_node = RunnableLambda(reflect_node).with_config(tags=[_dummy_tag])
+    builder.add_node("reflect", tagged_reflect_node)
+    builder.add_edge(START, "model")
+    builder.add_conditional_edges("model", router)
+    builder.add_edge("reflect", "model")
+    graph = builder.compile()
+
+    # Run the graph:
+    return graph.invoke({"messages": [HumanMessage(content="How many 'r's in Strawberry?")], "attempts": 0})
+
+
+#: List of example functions to run in tests, along with the expected number of monitored events when split runs are
+#: enabled.
+_run_suites: list[tuple[Callable, int]] = [
+    (_run_simple_chain, 4),
+    (_run_simple_agent, 9),
+    (_run_langgraph_graph, 9),
+]
+
+
+#: Dummy environment variables for testing.
+_dummy_environment_variables = {
+    "LC_MLRUN_TRACER_CLIENT_V3IO_STREAM_PATH": "dummy_stream_path",
+    "LC_MLRUN_TRACER_CLIENT_V3IO_CONTAINER": "dummy_container",
+    "LC_MLRUN_TRACER_CLIENT_MODEL_ENDPOINT_NAME": "dummy_model_name",
+    "LC_MLRUN_TRACER_CLIENT_MODEL_ENDPOINT_UID": "dummy_model_endpoint_uid",
+    "LC_MLRUN_TRACER_CLIENT_SERVING_FUNCTION": "dummy_serving_function",
+    "LC_MLRUN_TRACER_MONITOR_DEBUG": "true",
+    "LC_MLRUN_TRACER_MONITOR_DEBUG_TARGET_LIST": "true",
+    "LC_MLRUN_TRACER_MONITOR_SPLIT_RUNS": "true",
+}
+
+
+@pytest.fixture()
+def auto_mode_settings(monkeypatch):
+    """
+    Sets the environment variables to enable mlrun monitoring in 'auto' mode.
+    """
+    # Set environment variables for the duration of the test:
+    monkeypatch.setenv(mlrun_monitoring_env_var, "1")
+    for key, value in _dummy_environment_variables.items():
+        monkeypatch.setenv(key, value)
+
+    # Reset the singleton tracer to ensure fresh initialization:
+    MLRunTracer._singleton_tracer = None
+    MLRunTracer._initialized = False
+
+    yield
+
+    # Reset the singleton tracer after the test:
+    MLRunTracer._singleton_tracer = None
+    MLRunTracer._initialized = False
+
+
+@pytest.fixture
+def manual_mode_settings():
+    """
+    Sets the mandatory client settings and debug flag for the tests.
+    """
+    settings = MLRunTracerSettings(
+        client=MLRunTracerClientSettings(
+            v3io_stream_path="dummy_stream_path",
+            v3io_container="dummy_container",
+            model_endpoint_name="dummy_model_name",
+            model_endpoint_uid="dummy_model_endpoint_uid",
+            serving_function="dummy_serving_function",
+        ),
+        monitor=MLRunTracerMonitorSettings(
+            debug=True,
+            debug_target_list=[],
+            split_runs=True,  # Easier to test with split runs (filters can filter per run instead of inner events)
+        ),
+    )
+
+    yield settings
+
+
+def test_settings_init_via_env_vars(monkeypatch):
+    """
+    Test that settings are correctly initialized from environment variables.
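+
+    Example of the mapping being exercised (illustrative; the value matches `_dummy_environment_variables` above)::
+
+        os.environ["LC_MLRUN_TRACER_CLIENT_V3IO_STREAM_PATH"] = "dummy_stream_path"
+        settings = MLRunTracerSettings()
+        assert settings.client.v3io_stream_path == "dummy_stream_path"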
+ """ + #: First, ensure that without env vars, validation fails due to missing required fields: + with pytest.raises(ValidationError): + MLRunTracerSettings() + + # Now, set the environment variables for the client settings and debug flag: + for key, value in _dummy_environment_variables.items(): + monkeypatch.setenv(key, value) + + # Ensure that settings are now correctly initialized from env vars: + settings = MLRunTracerSettings() + assert settings.client.v3io_stream_path == "dummy_stream_path" + assert settings.client.v3io_container == "dummy_container" + assert settings.client.model_endpoint_name == "dummy_model_name" + assert settings.client.model_endpoint_uid == "dummy_model_endpoint_uid" + assert settings.client.serving_function == "dummy_serving_function" + assert settings.monitor.debug is True + + +@pytest.mark.parametrize( + "test_suite", [ + # Valid case: only v3io settings provided + ( + { + "v3io_stream_path": "dummy_stream_path", + "v3io_container": "dummy_container", + "model_endpoint_name": "dummy_model_name", + "model_endpoint_uid": "dummy_model_endpoint_uid", + "serving_function": "dummy_serving_function", + }, + True, + ), + # Invalid case: partial v3io settings provided + ( + { + "v3io_stream_path": "dummy_stream_path", + "model_endpoint_name": "dummy_model_name", + "model_endpoint_uid": "dummy_model_endpoint_uid", + "serving_function": "dummy_serving_function", + }, + False, + ), + # Valid case: only kafka settings provided + ( + { + "kafka_stream_profile_name": "dummy_stream_profile_name", + "model_endpoint_name": "dummy_model_name", + "model_endpoint_uid": "dummy_model_endpoint_uid", + "serving_function": "dummy_serving_function", + }, + True, + ), + # Invalid case: partial kafka settings provided + ( + { + "kafka_linger_ms": "1000", + "model_endpoint_name": "dummy_model_name", + "model_endpoint_uid": "dummy_model_endpoint_uid", + "serving_function": "dummy_serving_function", + }, + False, + ), + # Invalid case: both v3io and kafka settings provided + ( + { + "v3io_stream_path": "dummy_stream_path", + "v3io_container": "dummy_container", + "kafka_stream_profile_name": "dummy_stream_profile_name", + "model_endpoint_name": "dummy_model_name", + "model_endpoint_uid": "dummy_model_endpoint_uid", + "serving_function": "dummy_serving_function", + }, + False, + ), + # Invalid case: both v3io and kafka settings provided (partial) + ( + { + "v3io_container": "dummy_container", + "kafka_linger_ms": "1000", + "model_endpoint_name": "dummy_model_name", + "model_endpoint_uid": "dummy_model_endpoint_uid", + "serving_function": "dummy_serving_function", + }, + False, + ), + ] +) +def test_settings_v3io_kafka_combination(test_suite: tuple[dict[str, str], bool]): + """ + Test that settings validation enforces mutual exclusivity between v3io and kafka configurations. + + :param test_suite: A tuple containing environment variable overrides and a flag indicating + whether validation should pass. + """ + settings, should_pass = test_suite + + if should_pass: + MLRunTracerClientSettings(**settings) + else: + with pytest.raises(ValidationError): + MLRunTracerClientSettings(**settings) + + +def test_auto_mode_singleton_thread_safety(auto_mode_settings): + """ + Test that MLRunTracer singleton initialization is thread-safe in 'auto' mode. + + :param auto_mode_settings: Fixture to set up 'auto' mode environment and settings. 
+    """
+    # Initialize a list to hold tracer instances created in different threads:
+    tracer_instances = []
+
+    # Function to initialize the tracer in a thread:
+    def _init_tracer():
+        tracer = MLRunTracer()
+        return tracer
+
+    # Use ThreadPoolExecutor to simulate concurrent tracer initialization:
+    num_threads = 50
+    with ThreadPoolExecutor(max_workers=num_threads) as executor:
+        futures = [executor.submit(_init_tracer) for _ in range(num_threads)]
+        tracer_instances = [f.result() for f in futures]
+
+    # Check that every tracer in the list is the exact same instance (compared by UID):
+    unique_instances = set(tracer._uid for tracer in tracer_instances)
+
+    assert len(tracer_instances) == num_threads, "Not all threads returned a tracer instance. Test cannot proceed."
+    assert len(unique_instances) == 1, (
+        f"Thread-safety failure! {len(unique_instances)} different instances were created under high concurrency."
+    )
+    assert tracer_instances[0] is MLRunTracer(), "The global access point should return the same singleton."
+
+
+def test_manual_mode_multi_instances(manual_mode_settings: MLRunTracerSettings):
+    """
+    Test that MLRunTracer allows multiple instances in 'manual' mode.
+
+    :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings.
+    """
+    # Initialize a list to hold tracer instances created in different iterations:
+    tracer_instances = []
+
+    # Create multiple tracer instances:
+    num_instances = 50
+    for _ in range(num_instances):
+        tracer = MLRunTracer(settings=manual_mode_settings)
+        tracer_instances.append(tracer)
+
+    # Check that every tracer in the list is a different instance (compared by UID):
+    unique_instances = set(tracer._uid for tracer in tracer_instances)
+
+    assert len(tracer_instances) == num_instances, "Not all instances were created. Test cannot proceed."
+    assert len(unique_instances) == num_instances, (
+        f"Manual mode failure! {len(unique_instances)} unique instances were created instead of {num_instances}."
+    )
+
+
+@pytest.mark.parametrize("run_suites", _run_suites)
+def test_auto_mode(auto_mode_settings, run_suites: tuple[Callable, int]):
+    """
+    Test that MLRunTracer in 'auto' mode captures the debug target list after running a LangChain / LangGraph example
+    code.
+
+    :param auto_mode_settings: Fixture to set up 'auto' mode environment and settings.
+    :param run_suites: The function to run with the expected monitored events.
+    """
+    run_func, expected_events = run_suites
+
+    tracer = MLRunTracer()
+    assert len(tracer.settings.monitor.debug_target_list) == 0
+
+    print(run_func())
+    assert len(tracer.settings.monitor.debug_target_list) == expected_events
+
+
+@pytest.mark.parametrize("run_suites", _run_suites)
+def test_manual_mode(manual_mode_settings: MLRunTracerSettings, run_suites: tuple[Callable, int]):
+    """
+    Test that MLRunTracer in 'manual' mode captures the debug target list after running a LangChain / LangGraph
+    example code.
+
+    :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings.
+    :param run_suites: The function to run with the expected monitored events.
+    """
+    run_func, expected_events = run_suites
+
+    with mlrun_monitoring(settings=manual_mode_settings) as tracer:
+        print(run_func())
+        assert len(tracer.settings.monitor.debug_target_list) == expected_events
+
+
+def test_labeling(manual_mode_settings: MLRunTracerSettings):
+    """
+    Test that MLRunTracer attaches the configured label to every monitored event across the example runs.
+
+    :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings.
+    """
+    for i, (run_func, expected_events) in enumerate(_run_suites):
+        label = f"label_{i}"
+        manual_mode_settings.monitor.label = label
+        manual_mode_settings.monitor.debug_target_list.clear()
+        with mlrun_monitoring(settings=manual_mode_settings) as tracer:
+            print(run_func())
+            assert len(tracer.settings.monitor.debug_target_list) == expected_events
+            for event in tracer.settings.monitor.debug_target_list:
+                assert event["label"] == label
+
+
+@pytest.mark.parametrize(
+    "run_suites", [
+        run_suite + (filtered_events,)
+        for run_suite, filtered_events in zip(_run_suites, [1, 2, 1])
+    ]
+)
+def test_monitor_settings_tags_filter(
+    manual_mode_settings: MLRunTracerSettings,
+    run_suites: tuple[Callable, int, int],
+):
+    """
+    Test the `tags_filter` setting of MLRunTracer.
+
+    :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings.
+    :param run_suites: The function to run with the expected monitored events and filtered events.
+    """
+    run_func, expected_events, filtered_events = run_suites
+
+    manual_mode_settings.monitor.tags_filter = [_dummy_tag]
+
+    with mlrun_monitoring(settings=manual_mode_settings) as tracer:
+        print(run_func())
+        assert len(tracer.settings.monitor.debug_target_list) == filtered_events
+        for event in tracer.settings.monitor.debug_target_list:
+            assert not set(manual_mode_settings.monitor.tags_filter).isdisjoint(event["input_data"]["input_data"]["tags"])
+
+
+@pytest.mark.parametrize(
+    "run_suites", [
+        run_suite + (filtered_events,)
+        for run_suite, filtered_events in zip(_run_suites, [1, 3, 4])
+    ]
+)
+def test_monitor_settings_name_filter(
+    manual_mode_settings: MLRunTracerSettings,
+    run_suites: tuple[Callable, int, int],
+):
+    """
+    Test the `names_filter` setting of MLRunTracer.
+
+    :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings.
+    :param run_suites: The function to run with the expected monitored events and filtered events.
+    """
+    run_func, expected_events, filtered_events = run_suites
+
+    manual_mode_settings.monitor.names_filter = ["StrOutputParser", "get_weather", "model", "router"]
+
+    with mlrun_monitoring(settings=manual_mode_settings) as tracer:
+        print(run_func())
+        assert len(tracer.settings.monitor.debug_target_list) == filtered_events
+        for event in tracer.settings.monitor.debug_target_list:
+            assert event["input_data"]["input_data"]["run_name"] in manual_mode_settings.monitor.names_filter
+
+
+@pytest.mark.parametrize(
+    "run_suites", [
+        run_suite + (filtered_events,)
+        for run_suite, filtered_events in zip(_run_suites, [2, 7, 9])
+    ]
+)
+@pytest.mark.parametrize("split_runs", [True, False])
+def test_monitor_settings_run_type_filter(
+    manual_mode_settings: MLRunTracerSettings,
+    run_suites: tuple[Callable, int, int],
+    split_runs: bool
+):
+    """
+    Test the `run_types_filter` setting of MLRunTracer. The test also runs with split runs enabled and disabled:
+    when disabled, filtering out a parent run filters out all of its child runs by default. The test ensures the root
+    run always passes the filter, hence the expectation of a single event when split runs are disabled.
+
+    :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings.
+    :param run_suites: The function to run with the expected monitored events and filtered events.
+    :param split_runs: Whether to enable split runs in the monitor settings.
+    """
+    run_func, expected_events, filtered_events = run_suites
+    filtered_events = filtered_events if split_runs else 1
+
+    manual_mode_settings.monitor.run_types_filter = ["llm", "chain"]
+    manual_mode_settings.monitor.split_runs = split_runs
+
+    def recursive_check_run_types(run: dict):
+        assert run["input_data"]["run_type"] in manual_mode_settings.monitor.run_types_filter
+        if "child_runs" in run["output_data"]:
+            for child_run in run["output_data"]["child_runs"]:
+                recursive_check_run_types(child_run)
+
+    with mlrun_monitoring(settings=manual_mode_settings) as tracer:
+        print(run_func())
+        assert len(tracer.settings.monitor.debug_target_list) == filtered_events
+
+    for event in tracer.settings.monitor.debug_target_list:
+        event_run = {
+            "input_data": event["input_data"]["input_data"],
+            "output_data": event["output_data"]["output_data"],
+        }
+        recursive_check_run_types(run=event_run)
+
+
+@pytest.mark.parametrize("run_suites", _run_suites)
+@pytest.mark.parametrize("split_runs", [True, False])
+def test_monitor_settings_full_filter(
+    manual_mode_settings: MLRunTracerSettings,
+    run_suites: tuple[Callable, int],
+    split_runs: bool
+):
+    """
+    Test that a complete filter (not allowing any events to pass) won't fail the tracer.
+
+    :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings.
+    :param run_suites: The function to run with the expected monitored events.
+    :param split_runs: Whether to enable split runs in the monitor settings.
+    """
+    run_func, _ = run_suites
+
+    manual_mode_settings.monitor.run_types_filter = ["dummy_run_type"]
+    manual_mode_settings.monitor.split_runs = split_runs
+
+    with mlrun_monitoring(settings=manual_mode_settings) as tracer:
+        print(run_func())
+        assert len(tracer.settings.monitor.debug_target_list) == 0
+
+
+@pytest.mark.parametrize("run_suites", _run_suites)
+@pytest.mark.parametrize("split_runs", [True, False])
+@pytest.mark.parametrize("root_run_only", [True, False])
+def test_monitor_settings_split_runs_and_root_run_only(
+    manual_mode_settings: MLRunTracerSettings,
+    run_suites: tuple[Callable, int],
+    split_runs: bool,
+    root_run_only: bool,
+):
+    """
+    Test the `split_runs` and `root_run_only` settings of MLRunTracer.
+
+    :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings.
+    :param run_suites: The function to run with the expected monitored events.
+    :param split_runs: Whether to enable split runs in the monitor settings.
+    :param root_run_only: Whether to enable `root_run_only` in the monitor settings.
+ """ + run_func, expected_events = run_suites + + manual_mode_settings.monitor.split_runs = split_runs + manual_mode_settings.monitor.root_run_only = root_run_only + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + for run_iteration in range(1, 3): + print(run_func()) + if root_run_only: + assert len(tracer.settings.monitor.debug_target_list) == 1 * run_iteration + assert "child_runs" not in tracer.settings.monitor.debug_target_list[-1]["output_data"]["output_data"] + elif split_runs: + assert len(tracer.settings.monitor.debug_target_list) == expected_events * run_iteration + assert "child_runs" not in tracer.settings.monitor.debug_target_list[-1]["output_data"]["output_data"] + else: # split_runs disabled + assert len(tracer.settings.monitor.debug_target_list) == 1 * run_iteration + assert len(tracer.settings.monitor.debug_target_list[-1]["output_data"]["output_data"]["child_runs"]) != 0 + + +class _CustomRunSummarizerSettings(BaseSettings): + """ + Settings for the custom summarizer function. + """ + dummy_value: int = 21 + + model_config = SettingsConfigDict(env_prefix="TEST_CUSTOM_SUMMARIZER_SETTINGS_") + + +def _custom_run_summarizer(run: Run, settings: _CustomRunSummarizerSettings = None): + """ + A custom summarizer function for testing. + + :param run: The LangChain / LangGraph run to summarize. + :param settings: Optional settings for the summarizer. + """ + inputs = { + "run_id": run.id, + "input": run.inputs, + "from_settings": settings.dummy_value if settings else 0, + } + + def count_llm_calls(r: Run) -> int: + if not r.child_runs: + return 1 if r.run_type == "llm" else 0 + return sum(count_llm_calls(child) for child in r.child_runs) + + def count_tool_calls(r: Run) -> int: + if not r.child_runs: + return 1 if r.run_type == "tool" else 0 + return sum(count_tool_calls(child) for child in r.child_runs) + + outputs = { + "llm_calls": count_llm_calls(run), + "tool_calls": count_tool_calls(run), + "output": run.outputs + } + + yield inputs, outputs + + +@pytest.mark.parametrize("run_suites", _run_suites) +@pytest.mark.parametrize("run_summarizer_function", [ + _custom_run_summarizer, + "test_langchain_mlrun._custom_run_summarizer", +]) +@pytest.mark.parametrize("run_summarizer_settings", [ + _CustomRunSummarizerSettings(dummy_value=12), + "test_langchain_mlrun._CustomRunSummarizerSettings", + None, +]) +def test_monitor_settings_custom_run_summarizer( + manual_mode_settings: MLRunTracerSettings, + run_suites: tuple[Callable, int], + run_summarizer_function: Callable | str, + run_summarizer_settings: BaseSettings | str | None, +): + """ + Test the custom run summarizer that can be passed to MLRunTracer. + + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + :param run_suites: The function to run with the expected monitored events. + :param run_summarizer_function: The custom summarizer function or its import path. + :param run_summarizer_settings: The settings for the custom summarizer or its import path. 
+ """ + run_func, _ = run_suites + manual_mode_settings.monitor.run_summarizer_function = run_summarizer_function + manual_mode_settings.monitor.run_summarizer_settings = run_summarizer_settings + dummy_value_for_settings_from_env = 26 + os.environ["TEST_CUSTOM_SUMMARIZER_SETTINGS_DUMMY_VALUE"] = str(dummy_value_for_settings_from_env) + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + print(run_func()) + assert len(tracer.settings.monitor.debug_target_list) == 1 + + event = tracer.settings.monitor.debug_target_list[0] + if run_summarizer_settings: + if isinstance(run_summarizer_settings, str): + assert event["input_data"]["input_data"]["from_settings"] == dummy_value_for_settings_from_env + else: + assert event["input_data"]["input_data"]["from_settings"] == run_summarizer_settings.dummy_value + else: + assert event["input_data"]["input_data"]["from_settings"] == 0 + + +def test_monitor_settings_include_errors_field_presence(manual_mode_settings: MLRunTracerSettings): + """ + Test that when `include_errors` is True, the error field is present in outputs. + When `include_errors` is False, the error field is not added to outputs. + + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + """ + # Run with include_errors=True (default) and verify error field is present: + manual_mode_settings.monitor.include_errors = True + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + _run_simple_chain() + assert len(tracer.settings.monitor.debug_target_list) > 0 + + for event in tracer.settings.monitor.debug_target_list: + output_data = event["output_data"]["output_data"] + assert "error" in output_data, "error field should be present when include_errors is True" + + # Now run with include_errors=False and verify error field is excluded: + manual_mode_settings.monitor.include_errors = False + manual_mode_settings.monitor.debug_target_list.clear() + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + _run_simple_chain() + assert len(tracer.settings.monitor.debug_target_list) > 0 + + for event in tracer.settings.monitor.debug_target_list: + output_data = event["output_data"]["output_data"] + assert "error" not in output_data, "error field should be excluded when include_errors is False" + + +def test_monitor_settings_include_full_run(manual_mode_settings: MLRunTracerSettings): + """ + Test that when `include_full_run` is True, the complete serialized run is included in outputs. + + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + """ + manual_mode_settings.monitor.include_full_run = True + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + _run_simple_chain() + + assert len(tracer.settings.monitor.debug_target_list) > 0 + + for event in tracer.settings.monitor.debug_target_list: + output_data = event["output_data"]["output_data"] + assert "full_run" in output_data, "full_run should be included in outputs when include_full_run is True" + # Verify the full_run contains expected run structure: + assert "inputs" in output_data["full_run"] + assert "outputs" in output_data["full_run"] + + +def test_monitor_settings_include_metadata(manual_mode_settings: MLRunTracerSettings): + """ + Test that when `include_metadata` is False, metadata is excluded from inputs. + + Note: The fake models used in tests don't produce runs with metadata, so we can only + verify the "exclude" behavior. The code only adds metadata if the run actually contains it. 
+ + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + """ + # Run with include_metadata=False and verify metadata is excluded: + manual_mode_settings.monitor.include_metadata = False + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + _run_simple_chain() + assert len(tracer.settings.monitor.debug_target_list) > 0 + + # Check that metadata is not present in inputs: + for event in tracer.settings.monitor.debug_target_list: + input_data = event["input_data"]["input_data"] + assert "metadata" not in input_data, "metadata should be excluded when include_metadata is False" + + +def test_monitor_settings_include_latency(manual_mode_settings: MLRunTracerSettings): + """ + Test that when `include_latency` is False, latency is excluded from outputs. + + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + """ + manual_mode_settings.monitor.include_latency = False + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + _run_simple_chain() + assert len(tracer.settings.monitor.debug_target_list) > 0 + + for event in tracer.settings.monitor.debug_target_list: + assert "latency" not in event["output_data"]["output_data"], \ + "latency should be excluded when include_latency is False" + + +def test_import_from_module_path_errors(): + """ + Test that `_import_from_module_path` raises appropriate errors for invalid paths. + """ + # Test ValueError for path without a dot: + with pytest.raises(ValueError) as exc_info: + MLRunTracer._import_from_module_path("no_dot_path") + assert "must have at least one '.'" in str(exc_info.value) + + # Test ImportError for non-existent module: + with pytest.raises(ImportError) as exc_info: + MLRunTracer._import_from_module_path("nonexistent_module_xyz.SomeClass") + assert "Could not import" in str(exc_info.value) + + # Test AttributeError for non-existent attribute in existing module: + with pytest.raises(AttributeError) as exc_info: + MLRunTracer._import_from_module_path("os.nonexistent_attribute_xyz") + assert "Could not import" in str(exc_info.value) + + +#: Sample structured runs for testing LangChainMonitoringApp methods. 
+_sample_structured_runs = [ + { + "label": "test_label", + "child_level": 0, + "input_data": { + "run_name": "RunnableSequence", + "run_type": "chain", + "tags": ["tag1"], + "inputs": {"topic": "MLRun"}, + "start_timestamp": "2024-01-01T10:00:00+00:00", + }, + "output_data": { + "outputs": {"result": "test output"}, + "end_timestamp": "2024-01-01T10:00:01+00:00", + "error": None, + "child_runs": [ + { + "input_data": { + "run_name": "FakeListChatModel", + "run_type": "llm", + "tags": ["tag2"], + "inputs": {"prompt": "test"}, + "start_timestamp": "2024-01-01T10:00:00.100+00:00", + }, + "output_data": { + "outputs": { + "generations": [[{ + "message": { + "kwargs": { + "usage_metadata": { + "input_tokens": 10, + "output_tokens": 20, + } + } + } + }]] + }, + "end_timestamp": "2024-01-01T10:00:00.500+00:00", + "error": None, + }, + }, + ], + }, + }, + { + "label": "test_label", + "child_level": 0, + "input_data": { + "run_name": "SimpleAgent", + "run_type": "chain", + "tags": ["tag1"], + "inputs": {"query": "test query"}, + "start_timestamp": "2024-01-01T10:00:02+00:00", + }, + "output_data": { + "outputs": {"result": "agent output"}, + "end_timestamp": "2024-01-01T10:00:04+00:00", + "error": "SomeError: something went wrong", + }, + }, +] + + +def test_langchain_monitoring_app_iterate_structured_runs(): + """ + Test that `iterate_structured_runs` yields all runs including nested child runs. + """ + # Iterate over all runs: + all_runs = list(LangChainMonitoringApp.iterate_structured_runs(_sample_structured_runs)) + + # Should yield parent runs and child runs: + # - First sample: 1 parent + 1 child = 2 runs + # - Second sample: 1 parent = 1 run + # Total: 3 runs + assert len(all_runs) == 3 + + # Verify run names are as expected: + run_names = [r["input_data"]["run_name"] for r in all_runs] + assert "RunnableSequence" in run_names + assert "FakeListChatModel" in run_names + assert "SimpleAgent" in run_names + + +def test_langchain_monitoring_app_count_run_names(): + """ + Test that `count_run_names` correctly counts occurrences of each run name. + """ + counts = LangChainMonitoringApp.count_run_names(_sample_structured_runs) + + assert counts["RunnableSequence"] == 1 + assert counts["FakeListChatModel"] == 1 + assert counts["SimpleAgent"] == 1 + + +def test_langchain_monitoring_app_count_token_usage(): + """ + Test that `count_token_usage` correctly calculates total tokens from LLM runs. + """ + token_usage = LangChainMonitoringApp.count_token_usage(_sample_structured_runs) + + assert token_usage["total_input_tokens"] == 10 + assert token_usage["total_output_tokens"] == 20 + assert token_usage["combined_total"] == 30 + + +def test_langchain_monitoring_app_calculate_success_rate(): + """ + Test that `calculate_success_rate` returns the correct percentage of successful runs. 
+ """ + success_rate = LangChainMonitoringApp.calculate_success_rate(_sample_structured_runs) + + # First run has no error, second run has error: + # Success rate should be 1/2 = 0.5 + assert success_rate == 0.5 + + # Test with empty list: + empty_rate = LangChainMonitoringApp.calculate_success_rate([]) + assert empty_rate == 0.0 + + # Test with all successful runs: + successful_runs = [_sample_structured_runs[0]] # Only the first run which has no error + all_success_rate = LangChainMonitoringApp.calculate_success_rate(successful_runs) + assert all_success_rate == 1.0 + + +def test_langchain_monitoring_app_calculate_average_latency(): + """ + Test that `calculate_average_latency` returns the correct average latency across root runs. + """ + # Calculate average latency: + avg_latency = LangChainMonitoringApp.calculate_average_latency(_sample_structured_runs) + + # First run: 10:00:00 to 10:00:01 = 1000ms + # Second run: 10:00:02 to 10:00:04 = 2000ms + # Average: (1000 + 2000) / 2 = 1500ms + assert avg_latency == 1500.0 + + # Test with empty list: + empty_latency = LangChainMonitoringApp.calculate_average_latency([]) + assert empty_latency == 0.0 + + +def test_langchain_monitoring_app_calculate_average_latency_skips_child_runs(): + """ + Test that `calculate_average_latency` skips child runs (only calculates for root runs). + """ + # Create a sample with a child run that has child_level > 0: + runs_with_child = [ + { + "label": "test", + "child_level": 0, + "input_data": {"start_timestamp": "2024-01-01T10:00:00+00:00"}, + "output_data": {"end_timestamp": "2024-01-01T10:00:01+00:00"}, + }, + { + "label": "test", + "child_level": 1, # This is a child run, should be skipped + "input_data": {"start_timestamp": "2024-01-01T10:00:00+00:00"}, + "output_data": {"end_timestamp": "2024-01-01T10:00:10+00:00"}, # 10 seconds - would skew average + }, + ] + + # Calculate average latency: + avg_latency = LangChainMonitoringApp.calculate_average_latency(runs_with_child) + + # Should only consider the root run (1000ms), not the child run: + assert avg_latency == 1000.0 + + +def test_debug_mode_stdout(manual_mode_settings: MLRunTracerSettings, capsys): + """ + Test that debug mode prints to stdout when `debug_target_list` is not set (is False). + + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + :param capsys: Pytest fixture to capture stdout/stderr. + """ + # Set debug mode with debug_target_list=False (should print to stdout): + manual_mode_settings.monitor.debug = True + manual_mode_settings.monitor.debug_target_list = False + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + _run_simple_chain() + + # Capture stdout: + captured = capsys.readouterr() + + # Verify that JSON output was printed to stdout: + assert "event_id" in captured.out, "Event should be printed to stdout when debug_target_list is False" + assert "input_data" in captured.out + assert "output_data" in captured.out