diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index d91edbfce6d..9e382032c1d 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -1246,7 +1246,7 @@ Result> MakeListElementReference( return MakeDirectReference(std::move(expr), std::move(ref_segment)); } -Result> EncodeSubstraitCall( + Result> EncodeSubstraitCall( const SubstraitCall& call, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { ARROW_ASSIGN_OR_RAISE(uint32_t anchor, ext_set->EncodeFunction(call.id())); @@ -1272,19 +1272,18 @@ Result> EncodeSubstraitCa " arguments but no argument could be found at index ", i); } } - for (const auto& option : call.options()) { substrait::FunctionOption* fn_option = scalar_fn->add_options(); fn_option->set_name(option.first); for (const auto& opt_val : option.second) { - std::string* pref = fn_option->add_preference(); - *pref = opt_val; + *fn_option->add_preference() = opt_val; } } return scalar_fn; } + Result>> DatumToLiterals( const Datum& datum, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { @@ -1365,83 +1364,87 @@ Result> ToProto( auto call = CallNotNull(expr); if (call->function_name == "case_when") { - auto conditions = call->arguments[0].call(); - if (conditions && conditions->function_name == "make_struct") { - // catch the special case of calls convertible to IfThen - auto if_then_ = std::make_unique(); - - // don't try to convert argument 0 of the case_when; we have to convert the elements - // of make_struct individually - std::vector> arguments( - call->arguments.size() - 1); - for (size_t i = 1; i < call->arguments.size(); ++i) { - ARROW_ASSIGN_OR_RAISE(arguments[i - 1], - ToProto(call->arguments[i], ext_set, conversion_options)); - } + auto conditions = call->arguments[0].call(); + if (conditions && conditions->function_name == "make_struct") { + auto if_then_ = std::make_unique(); - for (size_t i = 0; i < conditions->arguments.size(); ++i) { - ARROW_ASSIGN_OR_RAISE(auto cond_substrait, ToProto(conditions->arguments[i], - ext_set, conversion_options)); - auto clause = std::make_unique(); - clause->set_allocated_if_(cond_substrait.release()); - clause->set_allocated_then(arguments[i].release()); - if_then_->mutable_ifs()->AddAllocated(clause.release()); - } + std::vector> converted_args; + converted_args.reserve(call->arguments.size()); - if_then_->set_allocated_else_(arguments.back().release()); + for (size_t i = 1; i < call->arguments.size(); ++i) { + ARROW_ASSIGN_OR_RAISE( + auto arg, ToProto(call->arguments[i], ext_set, conversion_options)); + converted_args.push_back(std::move(arg)); + } - out->set_allocated_if_then(if_then_.release()); - return out; + for (size_t i = 0; i < conditions->arguments.size(); ++i) { + ARROW_ASSIGN_OR_RAISE( + auto cond, ToProto(conditions->arguments[i], ext_set, conversion_options)); + auto clause = std::make_unique(); + clause->set_allocated_if_(cond.release()); + clause->set_allocated_then(converted_args[i].release()); + if_then_->mutable_ifs()->AddAllocated(clause.release()); } + + if_then_->set_allocated_else_(converted_args.back().release()); + out->set_allocated_if_then(if_then_.release()); + return out; } +} // the remaining function pattern matchers only convert the function itself, so we // should be able to convert all its arguments first here - std::vector> arguments(call->arguments.size()); - for (size_t i = 0; i < arguments.size(); ++i) { - ARROW_ASSIGN_OR_RAISE(arguments[i], - ToProto(call->arguments[i], ext_set, conversion_options)); - } - if (call->function_name == "struct_field") { - // catch the special case of calls convertible to a StructField - const auto& field_options = - checked_cast(*call->options); - const DataType& struct_type = *call->arguments[0].type(); - DCHECK_EQ(struct_type.id(), Type::STRUCT); - - ARROW_ASSIGN_OR_RAISE(auto field_path, field_options.field_ref.FindOne(struct_type)); - out = std::move(arguments[0]); - for (int index : field_path.indices()) { - ARROW_ASSIGN_OR_RAISE(out, MakeStructFieldReference(std::move(out), index)); - } - return out; + ARROW_ASSIGN_OR_RAISE( + auto base, ToProto(call->arguments[0], ext_set, conversion_options)); + + const auto& field_options = + checked_cast(*call->options); + const DataType& struct_type = *call->arguments[0].type(); + + ARROW_ASSIGN_OR_RAISE(auto field_path, field_options.field_ref.FindOne(struct_type)); + out = std::move(base); + + for (int index : field_path.indices()) { + ARROW_ASSIGN_OR_RAISE(out, MakeStructFieldReference(std::move(out), index)); } + return out; +} + if (call->function_name == "list_element") { - // catch the special case of calls convertible to a ListElement - if (arguments[0]->has_selection() && - arguments[0]->selection().has_direct_reference()) { - if (arguments[1]->has_literal() && arguments[1]->literal().literal_type_case() == - substrait::Expression::Literal::kI32) { - return MakeListElementReference(std::move(arguments[0]), - arguments[1]->literal().i32()); - } - } + ARROW_ASSIGN_OR_RAISE( + auto base, ToProto(call->arguments[0], ext_set, conversion_options)); + ARROW_ASSIGN_OR_RAISE( + auto offset, ToProto(call->arguments[1], ext_set, conversion_options)); + + if (base->has_selection() && + offset->has_literal() && + offset->literal().literal_type_case() == + substrait::Expression::Literal::kI32) { + return MakeListElementReference(std::move(base), offset->literal().i32()); } +} + if (call->function_name == "if_else") { - // catch the special case of calls convertible to IfThen - auto if_clause = std::make_unique(); - if_clause->set_allocated_if_(arguments[0].release()); - if_clause->set_allocated_then(arguments[1].release()); + ARROW_ASSIGN_OR_RAISE( + auto if_, ToProto(call->arguments[0], ext_set, conversion_options)); + ARROW_ASSIGN_OR_RAISE( + auto then_, ToProto(call->arguments[1], ext_set, conversion_options)); + ARROW_ASSIGN_OR_RAISE( + auto else_, ToProto(call->arguments[2], ext_set, conversion_options)); - auto if_then = std::make_unique(); - if_then->mutable_ifs()->AddAllocated(if_clause.release()); - if_then->set_allocated_else_(arguments[2].release()); + auto clause = std::make_unique(); + clause->set_allocated_if_(if_.release()); + clause->set_allocated_then(then_.release()); - out->set_allocated_if_then(if_then.release()); - return out; + auto if_then = std::make_unique(); + if_then->mutable_ifs()->AddAllocated(clause.release()); + if_then->set_allocated_else_(else_.release()); + + out->set_allocated_if_then(if_then.release()); + return out; } else if (call->function_name == "cast") { auto cast = std::make_unique(); @@ -1456,44 +1459,45 @@ Result> ToProto( return Status::Invalid("Substrait is only capable of representing unsafe casts"); } - if (arguments.size() != 1) { - return Status::Invalid( - "A call to the cast function must have exactly one argument"); - } - - cast->set_allocated_input(arguments[0].release()); - - ARROW_ASSIGN_OR_RAISE(std::unique_ptr to_type, - ToProto(*cast_options->to_type.type, /*nullable=*/true, ext_set, - conversion_options)); + if (call->arguments.size() != 1) { + return Status::Invalid( + "A call to the cast function must have exactly one argument"); +} - cast->set_allocated_type(to_type.release()); +ARROW_ASSIGN_OR_RAISE( + auto input, ToProto(call->arguments[0], ext_set, conversion_options)); +cast->set_allocated_input(input.release()); out->set_allocated_cast(cast.release()); return out; } else if (call->function_name == "is_in") { - auto or_list = std::make_unique(); + auto or_list = std::make_unique(); - if (arguments.size() != 1) { - return Status::Invalid( - "A call to the is_in function must have exactly one argument"); - } + if (call->arguments.size() != 1) { + return Status::Invalid( + "A call to the is_in function must have exactly one argument"); + } - or_list->set_allocated_value(arguments[0].release()); - std::shared_ptr is_in_options = - internal::checked_pointer_cast(call->options); + ARROW_ASSIGN_OR_RAISE( + auto value, ToProto(call->arguments[0], ext_set, conversion_options)); + or_list->set_allocated_value(value.release()); - // TODO(GH-36420) Acero does not currently handle nulls correctly - ARROW_ASSIGN_OR_RAISE( - std::vector> options, - DatumToLiterals(is_in_options->value_set, ext_set, conversion_options)); - for (auto& option : options) { - or_list->mutable_options()->AddAllocated(option.release()); - } - out->set_allocated_singular_or_list(or_list.release()); - return out; + std::shared_ptr is_in_options = + internal::checked_pointer_cast(call->options); + + ARROW_ASSIGN_OR_RAISE( + std::vector> options, + DatumToLiterals(is_in_options->value_set, ext_set, conversion_options)); + + for (auto& option : options) { + or_list->mutable_options()->AddAllocated(option.release()); } + out->set_allocated_singular_or_list(or_list.release()); + return out; +} + + // other expression types dive into extensions immediately Result maybe_converter = ext_set->registry()->GetArrowToSubstraitCall(call->function_name);