@@ -5466,6 +5466,42 @@ def prog(x):
54665466 if _VALIDATE_MODEL :
54675467 assert_model_is_valid (prog , {"x" : (2 , 4 )})
54685468
5469+ def test_divide_to_multiply_skip_size (self ):
5470+ @mb .program (input_specs = [mb .TensorSpec (shape = (42 ,))])
5471+ def prog (x ):
5472+ div_const = mb .range_1d (start = 1. , end = 43. , step = 1. )
5473+
5474+ div_val_1 = np .random .rand (42 ).astype (np .float32 )
5475+ div_const_1 = mb .const (val = div_val_1 )
5476+
5477+ real_div = mb .real_div (x = x , y = div_const_1 )
5478+
5479+ return mb .real_div (x = real_div , y = div_const )
5480+
5481+ assert_op_count_match (prog , expect = 2 , op = "real_div" )
5482+ assert_op_count_match (prog , expect = 0 , op = "mul" )
5483+
5484+ def check_counts (divs , muls , const_skip = False ):
5485+ new_prog = copy .deepcopy (prog )
5486+ if const_skip is None :
5487+ PASS_REGISTRY ["common::const_elimination" ](new_prog )
5488+ elif const_skip :
5489+ const_elim = copy .deepcopy (PASS_REGISTRY ["common::const_elimination" ])
5490+ const_elim .skip_const_by_size = const_skip
5491+ const_elim (new_prog )
5492+ PASS_REGISTRY ["common::divide_to_multiply" ](new_prog )
5493+ assert_same_output_names (prog , new_prog )
5494+ assert_op_count_match (new_prog , expect = divs , op = "real_div" )
5495+ assert_op_count_match (new_prog , expect = muls , op = "mul" )
5496+
5497+ check_counts (divs = 1 , muls = 1 )
5498+ check_counts (divs = 0 , muls = 2 , const_skip = None )
5499+ check_counts (divs = 1 , muls = 1 , const_skip = 32 )
5500+ check_counts (divs = 0 , muls = 2 , const_skip = 64 )
5501+
5502+ if _VALIDATE_MODEL :
5503+ assert_model_is_valid (prog , {"x" : (42 ,)})
5504+
54695505
54705506class TestSelectOptimization :
54715507 @pytest .mark .parametrize (
0 commit comments