Skip to content
32 changes: 32 additions & 0 deletions api/dynamic_tests_v2/interp_bilinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,38 @@ def build_graph(self, config):
if config.backward:
self.append_gradients(out, [x])

def compute_flop_and_byte(self, config):
# at least one of out_shape and scale must be set
x_shape = config.x_shape
out_size = config.size

assert (config.scale_factor is not None or out_size is not None
), "at least one of out_shape and scale must be set"
# config.size has higher priority than config.scale_factor
if isinstance(out_size, (list, tuple)):
out_shape = x_shape[0:-len(out_size)] + out_size
# scale_factor shouldn`t to be a float in bilinear mode
elif isinstance(config.scale_factor, (list, tuple)):
scale_length = len(config.scale_factor)
change_out = x_shape[-scale_length:]
scale_out = [
i * j for i, j in zip(change_out, config.scale_factor)
]
out_shape = x_shape[0:-scale_length] + scale_out

# forward flops, sub*10 + mul*9 + div*1 + add*3
forward_flop = numel(out_shape) * 23

# forward byte, read 4 address to compute 1 address
read_byte = 4 * numel(out_shape) * sizeof(config.x_dtype)
write_byte = numel(out_shape) * sizeof(config.x_dtype)
forward_byte = read_byte + write_byte
if not config.backward:
return forward_flop, forward_byte
else:
# to be implemented.
return None, None


class TorchInterpBilinear(PytorchAPIBenchmarkBase):
def build_graph(self, config):
Expand Down