diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLOperator.java index e914090f0c7..404b92b85ae 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLOperator.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLOperator.java @@ -5,7 +5,8 @@ package org.opensearch.sql.opensearch.planner.physical; -import java.util.ArrayList; +import static org.opensearch.sql.utils.MLCommonsConstants.CATEGORY_FIELD; + import java.util.Collections; import java.util.HashMap; import java.util.Iterator; @@ -14,6 +15,7 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.Row; import org.opensearch.ml.common.output.MLOutput; @@ -42,28 +44,40 @@ public class MLOperator extends MLCommonsOperatorActions { @Override public void open() { super.open(); - DataFrame inputDataFrame = generateInputDataset(input); Map args = processArgs(arguments); - MLOutput mlOutput = getMLOutput(inputDataFrame, args, nodeClient); - final Iterator inputRowIter = inputDataFrame.iterator(); + // Check if category_field is provided + String categoryField = + arguments.containsKey(CATEGORY_FIELD) + ? (String) arguments.get(CATEGORY_FIELD).getValue() + : null; + // Only need to check train here, as action should be already checked in ml client. final boolean isPrediction = ((String) args.get("action")).equals("train") ? false : true; - // For train, only one row to return. - final Iterator trainIter = - new ArrayList() { - { - add("train"); - } - }.iterator(); - final Iterator resultRowIter = - isPrediction ? ((MLPredictionOutput) mlOutput).getPredictionResult().iterator() : null; + final Iterator trainIter = Collections.singletonList("train").iterator(); + + // For prediction mode, handle both categorized and non-categorized cases + List> inputDataFrames = + generateCategorizedInputDataset(input, categoryField); + List mlOutputs = + inputDataFrames.stream() + .map(pair -> getMLOutput(pair.getRight(), args, nodeClient)) + .toList(); + Iterator> inputDataFramesIter = inputDataFrames.iterator(); + Iterator mlOutputIter = mlOutputs.iterator(); + iterator = - new Iterator() { + new Iterator<>() { + private DataFrame inputDataFrame = null; + private Iterator inputRowIter = null; + private MLOutput mlOutput = null; + private Iterator resultRowIter = null; + @Override public boolean hasNext() { if (isPrediction) { - return inputRowIter.hasNext(); + return (inputRowIter != null && inputRowIter.hasNext()) + || inputDataFramesIter.hasNext(); } else { boolean res = trainIter.hasNext(); if (res) { @@ -75,8 +89,19 @@ public boolean hasNext() { @Override public ExprValue next() { - return buildPPLResult( - isPrediction, inputRowIter, inputDataFrame, mlOutput, resultRowIter); + if (isPrediction) { + if (inputRowIter == null || !inputRowIter.hasNext()) { + Pair pair = inputDataFramesIter.next(); + inputDataFrame = pair.getLeft(); + inputRowIter = inputDataFrame.iterator(); + mlOutput = mlOutputIter.next(); + resultRowIter = ((MLPredictionOutput) mlOutput).getPredictionResult().iterator(); + } + return buildPPLResult(true, inputRowIter, inputDataFrame, mlOutput, resultRowIter); + } else { + // train case + return buildPPLResult(false, null, null, mlOutputs.getFirst(), null); + } } }; } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLOperatorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLOperatorTest.java index 5004965d584..ebcd2b639e8 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLOperatorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLOperatorTest.java @@ -15,6 +15,7 @@ import static org.mockito.Mockito.when; import static org.opensearch.sql.utils.MLCommonsConstants.ACTION; import static org.opensearch.sql.utils.MLCommonsConstants.ALGO; +import static org.opensearch.sql.utils.MLCommonsConstants.CATEGORY_FIELD; import static org.opensearch.sql.utils.MLCommonsConstants.KMEANS; import static org.opensearch.sql.utils.MLCommonsConstants.PREDICT; import static org.opensearch.sql.utils.MLCommonsConstants.TRAIN; @@ -144,6 +145,21 @@ public void testOpenPredict() { } } + @Test + public void testOpenPredictWithCategoryField() { + setUpPredict(); + // Add category_field parameter + arguments.put(CATEGORY_FIELD, AstDSL.stringLiteral("region")); + + try (MockedStatic mlClientMockedStatic = Mockito.mockStatic(MLClient.class)) { + when(MLClient.getMLClient(any(NodeClient.class))).thenReturn(machineLearningNodeClient); + mlOperator.open(); + assertTrue(mlOperator.hasNext()); + assertNotNull(mlOperator.next()); + assertFalse(mlOperator.hasNext()); + } + } + @Test public void testOpenTrain() { setUpTrain();