Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

package org.opensearch.sql.opensearch.planner.physical;

import java.util.ArrayList;
import static org.opensearch.sql.utils.MLCommonsConstants.CATEGORY_FIELD;

import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
Expand All @@ -14,6 +15,7 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.Row;
import org.opensearch.ml.common.output.MLOutput;
Expand Down Expand Up @@ -42,28 +44,40 @@ public class MLOperator extends MLCommonsOperatorActions {
@Override
public void open() {
super.open();
DataFrame inputDataFrame = generateInputDataset(input);
Map<String, Object> args = processArgs(arguments);

MLOutput mlOutput = getMLOutput(inputDataFrame, args, nodeClient);
final Iterator<Row> inputRowIter = inputDataFrame.iterator();
// Check if category_field is provided
String categoryField =
arguments.containsKey(CATEGORY_FIELD)
? (String) arguments.get(CATEGORY_FIELD).getValue()
: null;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If categoryField is null, generateCategorizedInputDataset will throw an NPE.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How so? generateCategorizedInputDataset has null checking:

ExprValue categoryValue = categoryField == null ? null : tupleValue.get(categoryField);

If we want, we can add a @Nullable annotation to that field to document that contract in the signature

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, it depends on what kind of Map is used. It seems HashMap can handle a null key in computeIfAbsent(key), but ConcurrentHashMap and other kinds of Map throw an NPE.


// Only need to check train here, as action should be already checked in ml client.
final boolean isPrediction = ((String) args.get("action")).equals("train") ? false : true;
// For train, only one row to return.
final Iterator<String> trainIter =
new ArrayList<String>() {
{
add("train");
}
}.iterator();
final Iterator<Row> resultRowIter =
isPrediction ? ((MLPredictionOutput) mlOutput).getPredictionResult().iterator() : null;
final Iterator<String> trainIter = Collections.singletonList("train").iterator();

// For prediction mode, handle both categorized and non-categorized cases
List<Pair<DataFrame, DataFrame>> inputDataFrames =
generateCategorizedInputDataset(input, categoryField);
List<MLOutput> mlOutputs =
inputDataFrames.stream()
.map(pair -> getMLOutput(pair.getRight(), args, nodeClient))
.toList();
Iterator<Pair<DataFrame, DataFrame>> inputDataFramesIter = inputDataFrames.iterator();
Iterator<MLOutput> mlOutputIter = mlOutputs.iterator();

iterator =
new Iterator<ExprValue>() {
new Iterator<>() {
private DataFrame inputDataFrame = null;
private Iterator<Row> inputRowIter = null;
private MLOutput mlOutput = null;
private Iterator<Row> resultRowIter = null;

@Override
public boolean hasNext() {
if (isPrediction) {
return inputRowIter.hasNext();
return (inputRowIter != null && inputRowIter.hasNext())
|| inputDataFramesIter.hasNext();
} else {
boolean res = trainIter.hasNext();
if (res) {
Expand All @@ -75,8 +89,19 @@ public boolean hasNext() {

@Override
public ExprValue next() {
return buildPPLResult(
isPrediction, inputRowIter, inputDataFrame, mlOutput, resultRowIter);
if (isPrediction) {
if (inputRowIter == null || !inputRowIter.hasNext()) {
Pair<DataFrame, DataFrame> pair = inputDataFramesIter.next();
inputDataFrame = pair.getLeft();
inputRowIter = inputDataFrame.iterator();
mlOutput = mlOutputIter.next();
resultRowIter = ((MLPredictionOutput) mlOutput).getPredictionResult().iterator();
}
return buildPPLResult(true, inputRowIter, inputDataFrame, mlOutput, resultRowIter);
} else {
// train case
return buildPPLResult(false, null, null, mlOutputs.getFirst(), null);
}
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import static org.mockito.Mockito.when;
import static org.opensearch.sql.utils.MLCommonsConstants.ACTION;
import static org.opensearch.sql.utils.MLCommonsConstants.ALGO;
import static org.opensearch.sql.utils.MLCommonsConstants.CATEGORY_FIELD;
import static org.opensearch.sql.utils.MLCommonsConstants.KMEANS;
import static org.opensearch.sql.utils.MLCommonsConstants.PREDICT;
import static org.opensearch.sql.utils.MLCommonsConstants.TRAIN;
Expand Down Expand Up @@ -144,6 +145,21 @@ public void testOpenPredict() {
}
}

/** Opens the operator in predict mode with the optional category_field argument supplied. */
@Test
public void testOpenPredictWithCategoryField() {
  setUpPredict();
  // Supply the optional category_field argument before opening the operator.
  arguments.put(CATEGORY_FIELD, AstDSL.stringLiteral("region"));

  try (MockedStatic<MLClient> mockedMlClient = Mockito.mockStatic(MLClient.class)) {
    // Route the static client lookup to the mocked ML node client.
    mockedMlClient
        .when(() -> MLClient.getMLClient(any(NodeClient.class)))
        .thenReturn(machineLearningNodeClient);

    mlOperator.open();

    // The operator should yield exactly one non-null row and then be exhausted.
    assertTrue(mlOperator.hasNext());
    var onlyRow = mlOperator.next();
    assertNotNull(onlyRow);
    assertFalse(mlOperator.hasNext());
  }
}

@Test
public void testOpenTrain() {
setUpTrain();
Expand Down
Loading