Skip to content

Commit 0434e85

Browse files
authored
avoid overwriting skip_const_by_size via divide_to_multiply (#2629)
* avoid overwriting `skip_const_by_size` via `divide_to_multiply` * unit test for inverse divide const size skip
1 parent ffe1cfe commit 0434e85

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

coremltools/converters/mil/mil/passes/defs/optimize_elementwise_binary.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,12 @@ def _divide_to_multiply_block(self, block):
3939
# to a floating point number. If x or y was originally an integer, and y becomes
4040
# a floating point number, then the original type
4141
# signature (with integer output) would not be preserved.
42-
if op.op_type == "real_div" and op.y.val is not None and _types.is_float(op.x.dtype):
42+
if (
43+
op.op_type == "real_div"
44+
and op.y.val is not None
45+
and op.y.op.op_type == "const"
46+
and _types.is_float(op.x.dtype)
47+
):
4348
new_y_val = np.array(1.0, dtype=op.y.val.dtype) / op.y.val
4449
if not np.isfinite(new_y_val).all():
4550
continue

coremltools/converters/mil/mil/passes/tests/test_passes.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5466,6 +5466,42 @@ def prog(x):
54665466
if _VALIDATE_MODEL:
54675467
assert_model_is_valid(prog, {"x": (2, 4)})
54685468

5469+
def test_divide_to_multiply_skip_size(self):
5470+
@mb.program(input_specs=[mb.TensorSpec(shape=(42,))])
5471+
def prog(x):
5472+
div_const = mb.range_1d(start=1., end=43., step=1.)
5473+
5474+
div_val_1 = np.random.rand(42).astype(np.float32)
5475+
div_const_1 = mb.const(val=div_val_1)
5476+
5477+
real_div = mb.real_div(x=x, y=div_const_1)
5478+
5479+
return mb.real_div(x=real_div, y=div_const)
5480+
5481+
assert_op_count_match(prog, expect=2, op="real_div")
5482+
assert_op_count_match(prog, expect=0, op="mul")
5483+
5484+
def check_counts(divs, muls, const_skip=False):
5485+
new_prog = copy.deepcopy(prog)
5486+
if const_skip is None:
5487+
PASS_REGISTRY["common::const_elimination"](new_prog)
5488+
elif const_skip:
5489+
const_elim = copy.deepcopy(PASS_REGISTRY["common::const_elimination"])
5490+
const_elim.skip_const_by_size = const_skip
5491+
const_elim(new_prog)
5492+
PASS_REGISTRY["common::divide_to_multiply"](new_prog)
5493+
assert_same_output_names(prog, new_prog)
5494+
assert_op_count_match(new_prog, expect=divs, op="real_div")
5495+
assert_op_count_match(new_prog, expect=muls, op="mul")
5496+
5497+
check_counts(divs=1, muls=1)
5498+
check_counts(divs=0, muls=2, const_skip=None)
5499+
check_counts(divs=1, muls=1, const_skip=32)
5500+
check_counts(divs=0, muls=2, const_skip=64)
5501+
5502+
if _VALIDATE_MODEL:
5503+
assert_model_is_valid(prog, {"x": (42,)})
5504+
54695505

54705506
class TestSelectOptimization:
54715507
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)