Skip to content

Comments

[Feature][IR] Make implicit conversions for binary arithmetic operations optional#462

Draft
Fangtangtang wants to merge 6 commits into cornell-zhang:main from
Fangtangtang:optional-casting
Draft

[Feature][IR] Make implicit conversions for binary arithmetic operations optional#462
Fangtangtang wants to merge 6 commits into cornell-zhang:main from
Fangtangtang:optional-casting

Conversation

@Fangtangtang
Copy link
Collaborator

Description

This PR enables users to optionally disable implicit casting for some binary arithmetic operations.

Problems

Previously, the constructed IR contained many casting operations, some of which were redundant.

e.g.

VLEN = 256
ELEN = 32

def test_vadd():
    @df.region()
    def top():
        @df.kernel(mapping=[1])
        def VEC(
            A: uint256[1],
            B: uint256[1],
            C: uint256[1],
        ):
            for i in allo.grid(VLEN // ELEN, name="vec_nest"):
                C[0][i * ELEN : (i + 1) * ELEN] = (
                    A[0][i * ELEN : (i + 1) * ELEN] + B[0][i * ELEN : (i + 1) * ELEN]
                )

The constructed IR is

module {
  func.func @VEC_0(%arg0: memref<1xi256>, %arg1: memref<1xi256>, %arg2: memref<1xi256>) attributes {df.kernel, itypes = "uuu", otypes = "", stypes = "___"} {
    %c1_i66 = arith.constant 1 : i66
    %c32_i66 = arith.constant 32 : i66
    %c1_i34 = arith.constant 1 : i34
    %c32_i64 = arith.constant 32 : i64
    affine.for %arg3 = 0 to 8 {
      %0 = affine.load %arg0[0] {from = "A", unsigned} : memref<1xi256>
      %1 = arith.index_cast %arg3 : index to i64
      %2 = arith.muli %1, %c32_i64 : i64
      %3 = arith.index_cast %arg3 : index to i34
      %4 = arith.addi %3, %c1_i34 : i34
      %5 = arith.extsi %4 : i34 to i66
      %6 = arith.muli %5, %c32_i66 : i66
      %7 = arith.subi %6, %c1_i66 : i66
      %8 = arith.index_cast %2 : i64 to index
      %9 = arith.index_cast %7 : i66 to index
      %10 = allo.get_slice(%0 : i256, %9, %8) -> i32
      %11 = affine.load %arg1[0] {from = "B", unsigned} : memref<1xi256>
      %12 = allo.get_slice(%11 : i256, %9, %8) -> i32
      %13 = arith.extui %10 {unsigned} : i32 to i33
      %14 = arith.extui %12 {unsigned} : i32 to i33
      %15 = arith.addi %13, %14 {unsigned} : i33
      %16 = arith.trunci %15 {unsigned} : i33 to i32
      %17 = affine.load %arg2[0] {from = "C", unsigned} : memref<1xi256>
      %18 = allo.set_slice(%17 : i256, %9, %8, %16 : i32) -> i256
      affine.store %18, %arg2[0] {to = "C"} : memref<1xi256>
    } {loop_name = "i", op_name = "vec_nest", unroll = 0 : i32}
    return
  }
  func.func @top(%arg0: memref<1xi256>, %arg1: memref<1xi256>, %arg2: memref<1xi256>) attributes {dataflow, itypes = "uuu"} {
    call @VEC_0(%arg0, %arg1, %arg2) {last} : (memref<1xi256>, memref<1xi256>, memref<1xi256>) -> ()
    return
  }
}

and for the vitis_hls backend, the generated kernel.cpp contains

void VEC_0(
  ap_uint<256> v0[1],
  ap_uint<256> v1[1],
  ap_uint<256> v2[1]
) {	// L2
  l_vec_nest_i: for (int i = 0; i < 8; i++) {	// L7
  #pragma HLS unroll
    ap_uint<256> v4 = v0[0];	// L8
    int64_t v5 = i;	// L9
    int64_t v6 = v5 * 32;	// L10
    ap_int<34> v7 = i;	// L11
    ap_int<34> v8 = v7 + 1;	// L12
    ap_int<66> v9 = v8;	// L13
    ap_int<66> v10 = v9 * 32;	// L14
    ap_int<66> v11 = v10 - 1;	// L15
    int v12 = v6;	// L16
    int v13 = v11;	// L17
    int32_t v14;
    ap_int<256> v14_tmp = v4;
    v14 = v14_tmp(v13, v12);	// L18
    ap_uint<256> v15 = v1[0];	// L19
    int32_t v16;
    ap_int<256> v16_tmp = v15;
    v16 = v16_tmp(v13, v12);	// L20
    ap_uint<33> v17 = v14;	// L21
    ap_uint<33> v18 = v16;	// L22
    ap_uint<33> v19 = v17 + v18;	// L23
    uint32_t v20 = v19;	// L24
    ap_uint<256> v21 = v2[0];	// L25
    ap_int<256> v22;
    ap_int<256> v22_tmp = v21;
    v22_tmp(v13, v12) = v20;
    v22 = v22_tmp;	// L26
    v2[0] = v22;	// L27
  }
}

Proposed Solutions

Add simplified typing rules for binary arithmetic operations, allowing users to control the behavior with a flag.

Examples

The simplified typing rules generally follow the C++ implicit conversion rules and only apply to

- int8, int16, int32, int64
- uint8, uint16, uint32, uint64
- index (int32)
- bfloat16, float16, float32, float64

Users can export USE_LESS_CASTING=1 to enable these rules for binary arithmetic operations. Currently we don't provide a fallback for unsupported data types.

Applying the simplified typing rules, the IR for the same program above looks like this:

module {
  func.func @VEC_0(%arg0: memref<1xi256>, %arg1: memref<1xi256>, %arg2: memref<1xi256>) attributes {df.kernel, itypes = "uuu", otypes = "", stypes = "___"} {
    %c1_i32 = arith.constant 1 : i32
    %c32_i32 = arith.constant 32 : i32
    affine.for %arg3 = 0 to 8 {
      %0 = affine.load %arg0[0] {from = "A", unsigned} : memref<1xi256>
      %1 = arith.index_cast %arg3 : index to i32
      %2 = arith.muli %1, %c32_i32 : i32
      %3 = arith.addi %1, %c1_i32 : i32
      %4 = arith.muli %3, %c32_i32 : i32
      %5 = arith.subi %4, %c1_i32 : i32
      %6 = arith.index_cast %2 : i32 to index
      %7 = arith.index_cast %5 : i32 to index
      %8 = allo.get_slice(%0 : i256, %7, %6) -> i32
      %9 = affine.load %arg1[0] {from = "B", unsigned} : memref<1xi256>
      %10 = allo.get_slice(%9 : i256, %7, %6) -> i32
      %11 = arith.addi %8, %10 {unsigned} : i32
      %12 = affine.load %arg2[0] {from = "C", unsigned} : memref<1xi256>
      %13 = allo.set_slice(%12 : i256, %7, %6, %11 : i32) -> i256
      affine.store %13, %arg2[0] {to = "C"} : memref<1xi256>
    } {loop_name = "i", op_name = "vec_nest", unroll = 0 : i32}
    return
  }
  func.func @top(%arg0: memref<1xi256>, %arg1: memref<1xi256>, %arg2: memref<1xi256>) attributes {dataflow, itypes = "uuu"} {
    call @VEC_0(%arg0, %arg1, %arg2) {last} : (memref<1xi256>, memref<1xi256>, memref<1xi256>) -> ()
    return
  }
}

in kernel.cpp

void VEC_0(
  ap_uint<256> v0[1],
  ap_uint<256> v1[1],
  ap_uint<256> v2[1]
) {	// L2
  l_vec_nest_i: for (int i = 0; i < 8; i++) {	// L5
  #pragma HLS unroll
    ap_uint<256> v4 = v0[0];	// L6
    int v5 = i * 32;	// L7
    int v6 = i + 1;	// L8
    int v7 = v6 * 32;	// L9
    int v8 = v7 - 1;	// L10
    int32_t v9;
    ap_int<256> v9_tmp = v4;
    v9 = v9_tmp(v8, v5);	// L11
    ap_uint<256> v10 = v1[0];	// L12
    int32_t v11;
    ap_int<256> v11_tmp = v10;
    v11 = v11_tmp(v8, v5);	// L13
    int32_t v12 = v9 + v11;	// L14
    ap_uint<256> v13 = v2[0];	// L15
    ap_int<256> v14;
    ap_int<256> v14_tmp = v13;
    v14_tmp(v8, v5) = v12;
    v14 = v14_tmp;	// L16
    v2[0] = v14;	// L17
  }
}

Checklist

Please make sure to review and check all of these items:

  • PR's title starts with a category (e.g. [Bugfix], [IR], [Builder], etc)
  • All changes have test coverage (It would be good to provide ~2 different test cases to test the robustness of your code)
  • Pass the formatting check locally
  • Code is well-documented

@zhangzhiru
Copy link

Can we discuss this in one of the Allo meetings later?

@Fangtangtang
Copy link
Collaborator Author

sure!

@Fangtangtang
Copy link
Collaborator Author

[update]

The previous casting rule also seems inappropriate for the AIE backend.

For example,

    Ty = int32
    M = 1024
    @df.region()
    def top():
        @df.kernel(mapping=[1])
        def core(A: Ty[M], B: Ty[M], C: Ty[M]):
            C[:] = A + B

    A = np.random.randint(0, 100, M).astype(np.int32)
    B = np.random.randint(0, 100, M).astype(np.int32)
    mod = df.build(top, target="aie")
    C = np.zeros(M).astype(np.int32)
    mod(A, B, C)
    np.testing.assert_allclose(C, A + B)

may produce incorrect results.

The correctness even depends on the value ranges of A and B. I think that's because the backend does not properly support types such as i33, leading to undefined behavior during lowering.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants