Skip to content

[BUG][IR][AIE] Inappropriate type casting for aie backend #484

@Fangtangtang

Description

@Fangtangtang

Describe the bug
Allo performs implicit type casting, but different backends support different data types. The current typing rules were designed for the HLS backends and are incompatible with the AIE backend, which leads to incorrect results.

To Reproduce
Run `tests/dataflow/aie/test_tp.py`.

Buggy output

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=1e-05

Mismatched elements: 31 / 64 (48.4%)
Max absolute difference: 2671354
Max relative difference: 2.31548599
 x: array([[1624830, 1924335, 2056509, 1984767, 1817787, 1652966, 2190470,
        1419622],
       [ 926552, 1225078,  976136, 1081069,  993961,  934372, 1197088,...
 y: array([[1624830, 1924335, 2056509, 1984767, 1817787, 1652966, 2190470,
        1419622],
       [ 926552, 1225078,  976136, 1081069,  993961,  934372, 1197088,...

Note: the corresponding Allo MLIR is:

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @gemm0_0(%arg0: memref<8x8xi32>, %arg1: memref<8x4xi32>, %arg2: !allo.stream<memref<8x4xi32>, 2>) attributes {df.kernel, itypes = "ss_", otypes = "", stypes = "__o", tag = "gemm0_()"} {
    %c0_i32 = arith.constant 0 : i32
    %alloc = memref.alloc() : memref<8x4xi32>
    linalg.fill {op_name = "matmul_init_zero_0"} ins(%c0_i32 : i32) outs(%alloc : memref<8x4xi32>)
    linalg.matmul {cast = #linalg.type_fn<cast_signed>, op_name = "matmul_1"} ins(%arg0, %arg1 : memref<8x8xi32>, memref<8x4xi32>) outs(%alloc : memref<8x4xi32>)
    allo.stream_put(%arg2, [], %alloc) : !allo.stream<memref<8x4xi32>, 2> contains memref<8x4xi32>
    return
  }
  func.func @gemm1_0(%arg0: memref<4x8xi32>, %arg1: !allo.stream<memref<8x4xi32>, 2>, %arg2: !allo.stream<memref<8x8xi32>, 2>) attributes {df.kernel, itypes = "s__", otypes = "", stypes = "_io", tag = "gemm1_()"} {
    %c0_i32 = arith.constant 0 : i32
    %0 = allo.stream_get(%arg1, []) : !allo.stream<memref<8x4xi32>, 2> -> memref<8x4xi32>
    %alloc = memref.alloc() : memref<8x8xi32>
    linalg.fill {op_name = "matmul_init_zero_0"} ins(%c0_i32 : i32) outs(%alloc : memref<8x8xi32>)
    linalg.matmul {cast = #linalg.type_fn<cast_signed>, op_name = "matmul_1"} ins(%0, %arg0 : memref<8x4xi32>, memref<4x8xi32>) outs(%alloc : memref<8x8xi32>)
    allo.stream_put(%arg2, [], %alloc) : !allo.stream<memref<8x8xi32>, 2> contains memref<8x8xi32>
    return
  }
  func.func @acc_0(%arg0: memref<8x8xi32>, %arg1: !allo.stream<memref<8x8xi32>, 2>) attributes {df.kernel, itypes = "s_", otypes = "", stypes = "_i", tag = "acc_()"} {
    %c0_i32 = arith.constant 0 : i32
    %alloc = memref.alloc() {name = "Z_out"} : memref<8x8xi32>
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<8x8xi32>)
    affine.for %arg2 = 0 to 2 {
      %0 = allo.stream_get(%arg1, []) : !allo.stream<memref<8x8xi32>, 2> -> memref<8x8xi32>
      %alloc_0 = memref.alloc() : memref<8x8xi33>
      linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc : memref<8x8xi32>) outs(%alloc_0 : memref<8x8xi33>) attrs =  {cast_from = "i32", cast_to = "i33"} {
      ^bb0(%in: i32, %out: i33):
        %1 = arith.extsi %in : i32 to i33
        linalg.yield %1 : i33
      }
      %alloc_1 = memref.alloc() : memref<8x8xi33>
      linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : memref<8x8xi32>) outs(%alloc_1 : memref<8x8xi33>) attrs =  {cast_from = "i32", cast_to = "i33"} {
      ^bb0(%in: i32, %out: i33):
        %1 = arith.extsi %in : i32 to i33
        linalg.yield %1 : i33
      }
      %alloc_2 = memref.alloc() : memref<8x8xi33>
      linalg.add {op_name = "add_0"} ins(%alloc_0, %alloc_1 : memref<8x8xi33>, memref<8x8xi33>) outs(%alloc_2 : memref<8x8xi33>)
      %alloc_3 = memref.alloc() : memref<8x8xi32>
      linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_2 : memref<8x8xi33>) outs(%alloc_3 : memref<8x8xi32>) attrs =  {cast_from = "i33", cast_to = "i32"} {
      ^bb0(%in: i33, %out: i32):
        %1 = arith.trunci %in : i33 to i32
        linalg.yield %1 : i32
      }
      memref.copy %alloc_3, %alloc {to = "Z_out"} : memref<8x8xi32> to memref<8x8xi32>
    } {loop_name = "i", op_name = "S_i_0"}
    memref.copy %alloc, %arg0 {to = "Z"} : memref<8x8xi32> to memref<8x8xi32>
    return
  }
  func.func @top(%arg0: memref<8x8xi32>, %arg1: memref<8x8xi32>, %arg2: memref<8x8xi32>, %arg3: memref<8x8xi32>) attributes {dataflow, itypes = "ssss"} {
    %0 = allo.stream_construct() {name = "Y_0"} : !allo.stream<memref<8x4xi32>, 2>
    %1 = allo.stream_construct() {name = "Y_1"} : !allo.stream<memref<8x4xi32>, 2>
    %2 = allo.stream_construct() {name = "part_Z_0"} : !allo.stream<memref<8x8xi32>, 2>
    %3 = allo.stream_construct() {name = "part_Z_1"} : !allo.stream<memref<8x8xi32>, 2>
    return
  }
}

The generated `top.mlir` (for mlir-aie) includes the following core:

    %core_0_4 = aie.core(%tile_0_4) {
      %c0 = arith.constant 0 : index
      %c1 = arith.constant 1 : index
      %c9223372036854775807 = arith.constant 9223372036854775807 : index
      scf.for %arg0 = %c0 to %c9223372036854775807 step %c1 {
        %0 = aie.objectfifo.acquire @fifo_8(Produce, 1) : !aie.objectfifosubview<memref<8x8xi32>>
        %1 = aie.objectfifo.subview.access %0[0] : !aie.objectfifosubview<memref<8x8xi32>> -> memref<8x8xi32>
        func.call @fill_zeros_i32_8_8_vector(%1) {lib = "fill_zeros_i32_8_8_vector"} : (memref<8x8xi32>) -> ()
        affine.for %arg1 = 0 to 2 {
          %2 = scf.index_switch %arg1 -> memref<8x8xi32> 
          case 1 {
            %3 = aie.objectfifo.acquire @part_Z_1(Consume, 1) : !aie.objectfifosubview<memref<8x8xi32>>
            %4 = aie.objectfifo.subview.access %3[0] : !aie.objectfifosubview<memref<8x8xi32>> -> memref<8x8xi32>
            scf.yield %4 : memref<8x8xi32>
          }
          default {
            %3 = aie.objectfifo.acquire @part_Z_0(Consume, 1) : !aie.objectfifosubview<memref<8x8xi32>>
            %4 = aie.objectfifo.subview.access %3[0] : !aie.objectfifosubview<memref<8x8xi32>> -> memref<8x8xi32>
            scf.yield %4 : memref<8x8xi32>
          }
          memref.copy %2, %buffer_0_4 : memref<8x8xi32> to memref<8x8xi32>
          scf.index_switch %arg1 
          case 1 {
            aie.objectfifo.release @part_Z_1(Consume, 1)
            scf.yield
          }
          default {
            aie.objectfifo.release @part_Z_0(Consume, 1)
          }
          affine.for %arg2 = 0 to 8 {
            affine.for %arg3 = 0 to 8 {
              %3 = affine.load %1[%arg2, %arg3] : memref<8x8xi32>
              %4 = arith.extsi %3 : i32 to i33
              affine.store %4, %buffer_0_4_0[%arg2, %arg3] : memref<8x8xi33>
            }
          }
          affine.for %arg2 = 0 to 8 {
            affine.for %arg3 = 0 to 8 {
              %3 = affine.load %buffer_0_4[%arg2, %arg3] : memref<8x8xi32>
              %4 = arith.extsi %3 : i32 to i33
              affine.store %4, %buffer_0_4_1[%arg2, %arg3] : memref<8x8xi33>
            }
          }
          affine.for %arg2 = 0 to 8 {
            affine.for %arg3 = 0 to 8 {
              %3 = affine.load %buffer_0_4_0[%arg2, %arg3] : memref<8x8xi33>
              %4 = affine.load %buffer_0_4_1[%arg2, %arg3] : memref<8x8xi33>
              %5 = arith.addi %3, %4 : i33
              affine.store %5, %buffer_0_4_2[%arg2, %arg3] : memref<8x8xi33>
            }
          }
          affine.for %arg2 = 0 to 8 {
            affine.for %arg3 = 0 to 8 {
              %3 = affine.load %buffer_0_4_2[%arg2, %arg3] : memref<8x8xi33>
              %4 = arith.trunci %3 : i33 to i32
              affine.store %4, %buffer_0_4_3[%arg2, %arg3] : memref<8x8xi32>
            }
          }
          memref.copy %buffer_0_4_3, %1 {to = "Z_out"} : memref<8x8xi32> to memref<8x8xi32>
        } {loop_name = "i", op_name = "S_i_0"}
        aie.objectfifo.release @fifo_8(Produce, 1)
      }
      aie.end
    } {link_with = "external1.o"}

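In `acc`, the addition of two i32 tensors is typed as i33. Both operands are sign-extended from i32 to i33 (`arith.extsi`), added in i33, and the sum is truncated back to i32 (`arith.trunci`). These i33 operations can be lowered for the AIE backend, but they lead to incorrect results.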

Expected behavior
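In this example, i32 values should not be cast to i33. Because the sum is truncated back to i32 anyway, performing the addition directly in i32 yields the same bits without requiring i33 support on the AIE backend.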

Additional context
Refer to #462.

Metadata

Metadata

Assignees

Labels

bug — Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions