diff --git a/coremltools/converters/mil/mil/passes/defs/optimize_elementwise_binary.py b/coremltools/converters/mil/mil/passes/defs/optimize_elementwise_binary.py index ddd147248..a26a0adae 100644 --- a/coremltools/converters/mil/mil/passes/defs/optimize_elementwise_binary.py +++ b/coremltools/converters/mil/mil/passes/defs/optimize_elementwise_binary.py @@ -39,7 +39,12 @@ def _divide_to_multiply_block(self, block): # to a floating point number. If x or y was originally an integer, and y becomes # a floating point number, then the original type # signature (with integer output) would not be preserved. - if op.op_type == "real_div" and op.y.val is not None and _types.is_float(op.x.dtype): + if ( + op.op_type == "real_div" + and op.y.val is not None + and op.y.op.op_type == "const" + and _types.is_float(op.x.dtype) + ): new_y_val = np.array(1.0, dtype=op.y.val.dtype) / op.y.val if not np.isfinite(new_y_val).all(): continue diff --git a/coremltools/converters/mil/mil/passes/tests/test_passes.py b/coremltools/converters/mil/mil/passes/tests/test_passes.py index 668523020..ca61d2700 100644 --- a/coremltools/converters/mil/mil/passes/tests/test_passes.py +++ b/coremltools/converters/mil/mil/passes/tests/test_passes.py @@ -5466,6 +5466,42 @@ def prog(x): if _VALIDATE_MODEL: assert_model_is_valid(prog, {"x": (2, 4)}) + def test_divide_to_multiply_skip_size(self): + @mb.program(input_specs=[mb.TensorSpec(shape=(42,))]) + def prog(x): + div_const = mb.range_1d(start=1., end=43., step=1.) + + div_val_1 = np.random.rand(42).astype(np.float32) + div_const_1 = mb.const(val=div_val_1) + + real_div = mb.real_div(x=x, y=div_const_1) + + return mb.real_div(x=real_div, y=div_const) + + assert_op_count_match(prog, expect=2, op="real_div") + assert_op_count_match(prog, expect=0, op="mul") + + def check_counts(divs, muls, const_skip=False): + new_prog = copy.deepcopy(prog) + if const_skip is None: + PASS_REGISTRY["common::const_elimination"](new_prog) + elif const_skip: + const_elim = copy.deepcopy(PASS_REGISTRY["common::const_elimination"]) + const_elim.skip_const_by_size = const_skip + const_elim(new_prog) + PASS_REGISTRY["common::divide_to_multiply"](new_prog) + assert_same_output_names(prog, new_prog) + assert_op_count_match(new_prog, expect=divs, op="real_div") + assert_op_count_match(new_prog, expect=muls, op="mul") + + check_counts(divs=1, muls=1) + check_counts(divs=0, muls=2, const_skip=None) + check_counts(divs=1, muls=1, const_skip=32) + check_counts(divs=0, muls=2, const_skip=64) + + if _VALIDATE_MODEL: + assert_model_is_valid(prog, {"x": (42,)}) + class TestSelectOptimization: @pytest.mark.parametrize(