From 3fd7b96241ce54155fcd515a3bbb58b3f83b5ed3 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 27 Feb 2026 18:50:10 -0800 Subject: [PATCH] Fix null assignment to fields PiperOrigin-RevId: 876500682 --- .../dev/cel/common/internal/ProtoAdapter.java | 33 +++++++++++------ .../common/values/ProtoCelValueConverter.java | 2 +- .../cel/common/internal/ProtoAdapterTest.java | 17 ++++++--- .../test/java/dev/cel/conformance/BUILD.bazel | 17 --------- .../dev/cel/conformance/ConformanceTest.java | 2 +- .../cel/runtime/CelLiteInterpreterTest.java | 5 +++ .../test/resources/nullAssignability.baseline | 35 +++++++++++++++++++ .../dev/cel/testing/BaseInterpreterTest.java | 25 +++++++++++++ 8 files changed, 103 insertions(+), 33 deletions(-) create mode 100644 runtime/src/test/resources/nullAssignability.baseline diff --git a/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java b/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java index 962a9d2e9..b6648a5b8 100644 --- a/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java +++ b/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java @@ -204,8 +204,29 @@ public Optional adaptFieldToValue(FieldDescriptor fieldDescriptor, Objec @SuppressWarnings({"unchecked", "rawtypes"}) public Optional adaptValueToFieldType( FieldDescriptor fieldDescriptor, Object fieldValue) { - if (isWrapperType(fieldDescriptor) && fieldValue.equals(NullValue.NULL_VALUE)) { - return Optional.empty(); + if (fieldValue instanceof NullValue) { + // `null` assignment to fields indicate that the field would not be set + // in a protobuf message (e.g: Message{msg_field: null} -> Message{}) + // + // We explicitly check below for invalid null assignments, such as repeated + // or map fields. (e.g: Message{repeated_field: null} -> Error) + if (fieldDescriptor.isMapField() + || fieldDescriptor.isRepeated() + || fieldDescriptor.getJavaType() != FieldDescriptor.JavaType.MESSAGE + || WellKnownProto.JSON_STRUCT_VALUE + .typeName() + .equals(fieldDescriptor.getMessageType().getFullName()) + || WellKnownProto.JSON_LIST_VALUE + .typeName() + .equals(fieldDescriptor.getMessageType().getFullName())) { + throw new IllegalArgumentException("Unsupported field type"); + } + + String typeFullName = fieldDescriptor.getMessageType().getFullName(); + if (!WellKnownProto.ANY_VALUE.typeName().equals(typeFullName) + && !WellKnownProto.JSON_VALUE.typeName().equals(typeFullName)) { + return Optional.empty(); + } } if (fieldDescriptor.isMapField()) { Descriptor entryDescriptor = fieldDescriptor.getMessageType(); @@ -370,14 +391,6 @@ private static String typeName(Descriptor protoType) { return protoType.getFullName(); } - private static boolean isWrapperType(FieldDescriptor fieldDescriptor) { - if (fieldDescriptor.getJavaType() != FieldDescriptor.JavaType.MESSAGE) { - return false; - } - String fieldTypeName = fieldDescriptor.getMessageType().getFullName(); - return WellKnownProto.isWrapperType(fieldTypeName); - } - private static int intCheckedCast(long value) { try { return Ints.checkedCast(value); diff --git a/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java b/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java index c7b829e13..565c65438 100644 --- a/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java +++ b/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java @@ -67,7 +67,7 @@ protected Object fromWellKnownProto(MessageLiteOrBuilder msg, WellKnownProto wel try { unpackedMessage = dynamicProto.unpack((Any) message); } catch (InvalidProtocolBufferException e) { - throw new IllegalStateException( + throw new IllegalArgumentException( "Unpacking failed for message: " + message.getDescriptorForType().getFullName(), e); } return toRuntimeValue(unpackedMessage); diff --git a/common/src/test/java/dev/cel/common/internal/ProtoAdapterTest.java b/common/src/test/java/dev/cel/common/internal/ProtoAdapterTest.java index 61c71e4a6..91e0e22db 100644 --- a/common/src/test/java/dev/cel/common/internal/ProtoAdapterTest.java +++ b/common/src/test/java/dev/cel/common/internal/ProtoAdapterTest.java @@ -150,10 +150,7 @@ public static List data() { @Test public void adaptValueToProto_bidirectionalConversion() { DynamicProto dynamicProto = DynamicProto.create(DefaultMessageFactory.INSTANCE); - ProtoAdapter protoAdapter = - new ProtoAdapter( - dynamicProto, - CelOptions.current().build()); + ProtoAdapter protoAdapter = new ProtoAdapter(dynamicProto, CelOptions.current().build()); assertThat(protoAdapter.adaptValueToProto(value, proto.getDescriptorForType().getFullName())) .isEqualTo(proto); assertThat(protoAdapter.adaptProtoToValue(proto)).isEqualTo(value); @@ -181,6 +178,18 @@ public void adaptAnyValue_hermeticTypes_bidirectionalConversion() { @RunWith(JUnit4.class) public static class AsymmetricConversionTest { + + @Test + public void unpackAny_celNullValue() throws Exception { + ProtoAdapter protoAdapter = new ProtoAdapter(DYNAMIC_PROTO, CelOptions.DEFAULT); + Any any = + (Any) + protoAdapter.adaptValueToProto( + dev.cel.common.values.NullValue.NULL_VALUE, "google.protobuf.Any"); + Object unpacked = protoAdapter.adaptProtoToValue(any); + assertThat(unpacked).isEqualTo(dev.cel.common.values.NullValue.NULL_VALUE); + } + @Test public void adaptValueToProto_asymmetricFloatConversion() { ProtoAdapter protoAdapter = new ProtoAdapter(DYNAMIC_PROTO, CelOptions.DEFAULT); diff --git a/conformance/src/test/java/dev/cel/conformance/BUILD.bazel b/conformance/src/test/java/dev/cel/conformance/BUILD.bazel index d6b2296e5..fb2b1a159 100644 --- a/conformance/src/test/java/dev/cel/conformance/BUILD.bazel +++ b/conformance/src/test/java/dev/cel/conformance/BUILD.bazel @@ -124,14 +124,6 @@ _TESTS_TO_SKIP_LEGACY = [ "string_ext/format", "string_ext/format_errors", - # TODO: Fix null assignment to a field - "proto2/set_null/single_message", - "proto2/set_null/single_duration", - "proto2/set_null/single_timestamp", - "proto3/set_null/single_message", - "proto3/set_null/single_duration", - "proto3/set_null/single_timestamp", - # Future features for CEL 1.0 # TODO: Strong typing support for enums, specified but not implemented. "enums/strong_proto2", @@ -162,7 +154,6 @@ _TESTS_TO_SKIP_PLANNER = [ "string_ext/format_errors", # TODO: Check behavior for go/cpp - "basic/functions/unbound", "basic/functions/unbound_is_runtime_error", # TODO: Ensure overflow occurs on conversions of double values which might not work properly on all platforms. @@ -177,14 +168,6 @@ _TESTS_TO_SKIP_PLANNER = [ # Skip until fixed. "parse/receiver_function_names", - # TODO: Fix null assignment to a field - "proto2/set_null/single_message", - "proto2/set_null/single_duration", - "proto2/set_null/single_timestamp", - "proto3/set_null/single_message", - "proto3/set_null/single_duration", - "proto3/set_null/single_timestamp", - # Type inference edgecases around null(able) assignability. # These type check, but resolve to a different type. # list(int), want list(wrapper(int)) diff --git a/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java b/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java index 437e50fea..5a25fb9d9 100644 --- a/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java +++ b/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java @@ -210,10 +210,10 @@ public void evaluate() throws Throwable { } CelRuntime runtime = getRuntime(test, usePlanner); - Program program = runtime.createProgram(response.getAst()); ExprValue result = null; CelEvaluationException error = null; try { + Program program = runtime.createProgram(response.getAst()); result = toExprValue(program.eval(getBindings(test)), response.getAst().getResultType()); } catch (CelEvaluationException e) { error = e; diff --git a/runtime/src/test/java/dev/cel/runtime/CelLiteInterpreterTest.java b/runtime/src/test/java/dev/cel/runtime/CelLiteInterpreterTest.java index b3a1f2efa..1d1a316c0 100644 --- a/runtime/src/test/java/dev/cel/runtime/CelLiteInterpreterTest.java +++ b/runtime/src/test/java/dev/cel/runtime/CelLiteInterpreterTest.java @@ -54,6 +54,11 @@ public void dynamicMessage_dynamicDescriptor() throws Exception { // All the tests below rely on message creation with fields populated. They are excluded for time // being until this support is added. + @Override + public void nullAssignability() throws Exception { + skipBaselineVerification(); + } + @Override public void wrappers() throws Exception { skipBaselineVerification(); diff --git a/runtime/src/test/resources/nullAssignability.baseline b/runtime/src/test/resources/nullAssignability.baseline new file mode 100644 index 000000000..47b9c7a0d --- /dev/null +++ b/runtime/src/test/resources/nullAssignability.baseline @@ -0,0 +1,35 @@ +Source: TestAllTypes{single_int64_wrapper: null}.single_int64_wrapper == null +=====> +bindings: {} +result: true + +Source: TestAllTypes{}.single_int64_wrapper == null +=====> +bindings: {} +result: true + +Source: has(TestAllTypes{single_int64_wrapper: null}.single_int64_wrapper) +=====> +bindings: {} +result: false + +Source: TestAllTypes{single_value: null}.single_value == null +=====> +bindings: {} +result: true + +Source: has(TestAllTypes{single_value: null}.single_value) +=====> +bindings: {} +result: true + +Source: TestAllTypes{single_timestamp: null}.single_timestamp == timestamp(0) +=====> +bindings: {} +result: true + +Source: has(TestAllTypes{single_timestamp: null}.single_timestamp) +=====> +bindings: {} +result: false + diff --git a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java index bc67e8218..144ada5a8 100644 --- a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java +++ b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java @@ -2122,6 +2122,31 @@ public void wrappers() throws Exception { runTest(); } + @Test + public void nullAssignability() throws Exception { + setContainer(CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage())); + source = "TestAllTypes{single_int64_wrapper: null}.single_int64_wrapper == null"; + runTest(); + + source = "TestAllTypes{}.single_int64_wrapper == null"; + runTest(); + + source = "has(TestAllTypes{single_int64_wrapper: null}.single_int64_wrapper)"; + runTest(); + + source = "TestAllTypes{single_value: null}.single_value == null"; + runTest(); + + source = "has(TestAllTypes{single_value: null}.single_value)"; + runTest(); + + source = "TestAllTypes{single_timestamp: null}.single_timestamp == timestamp(0)"; + runTest(); + + source = "has(TestAllTypes{single_timestamp: null}.single_timestamp)"; + runTest(); + } + @Test public void longComprehension() { ImmutableList l = LongStream.range(0L, 1000L).boxed().collect(toImmutableList());