Hello all!
I’m trying to use the CodeGen strategy for matrix multiplication and have run into the following issue.
Consider the multiplication of two matrices: 1024X1024 * 1024X1024 = 1024X1024.
Case 1:
For the 1st case the tile sizes used are (512, 256), which match the matrix dimensions perfectly, and the microkernel dims are 6X2. That means that the L2D buffer size for the LHS tile is 512X256Xsizeof(float)=524288 bytes and the L1D buffer size for the RHS tile is 512X2X8Xsizeof(float)=32768 bytes, which are suitable values for the target platform's caches. I apply the following sequence of transformations (useFullTileBuffers is set to true)
tile(2, 512).promote(1).tile(0, 256).promote(0).tile(1, 16).promote(1).tile(0, 6).tile(2, 8).promote({0, 1, 2}).vectorize()
and get the following MLIR-IR:
#map0 = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>
#map1 = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
#map2 = affine_map<(d0, d1) -> (d0 * 512 + d1)>
#map3 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
.... ....
func @sgemm(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<1024x1024xf32>) {
.... ....
%cst_1 = constant dense<0.000000e+00> : vector<16xf32>
%0 = memref.alloca() : memref<6xvector<16xf32>>
%1 = memref.alloca() : memref<6xvector<16xf32>>
%cst_2 = constant dense<0.000000e+00> : vector<8xf32>
%2 = memref.alloca() : memref<6xvector<8xf32>>
%3 = memref.alloca() : memref<6xvector<8xf32>>
%4 = memref.alloca() : memref<8xvector<16xf32>>
%5 = memref.alloca() : memref<6xvector<16xf32>>
%6 = memref.alloca() : memref<6xvector<16xf32>>
%7 = memref.alloca() {alignment = 32 : i64} : memref<6x16xf32>
%8 = memref.alloca() {alignment = 32 : i64} : memref<6x8xf32>
%9 = memref.alloca() {alignment = 32 : i64} : memref<6x16xf32>
%10 = memref.alloc() : memref<1024x1024xf32>
linalg.fill(%10, %cst) : memref<1024x1024xf32>, f32
scf.for %arg3 = %c0 to %c1024 step %c512 {
%11 = memref.subview %arg0[0, %arg3] [1024, 512] [1, 1] : memref<1024x1024xf32> to memref<1024x512xf32, #map0>
%12 = memref.subview %arg1[%arg3, 0] [512, 1024] [1, 1] : memref<1024x1024xf32> to memref<512x1024xf32, #map0>
%13 = memref.alloc(%c2097152) : memref<?xi8>
%14 = memref.view %13[%c0][] : memref<?xi8> to memref<512x1024xf32>
%15 = memref.subview %14[0, 0] [512, 1024] [1, 1] : memref<512x1024xf32> to memref<512x1024xf32, #map1>
linalg.fill(%14, %cst) : memref<512x1024xf32>, f32
linalg.copy(%12, %15) : memref<512x1024xf32, #map0>, memref<512x1024xf32, #map1>
scf.for %arg4 = %c0 to %c1024 step %c256 {
%16 = memref.subview %11[%arg4, 0] [256, 512] [1, 1] : memref<1024x512xf32, #map0> to memref<256x512xf32, #map0>
%17 = memref.subview %10[%arg4, 0] [256, 1024] [1, 1] : memref<1024x1024xf32> to memref<256x1024xf32, #map0>
%18 = memref.alloc(%c524288) : memref<?xi8>
%19 = memref.view %18[%c0][] : memref<?xi8> to memref<256x512xf32>
%20 = memref.subview %19[0, 0] [256, 512] [1, 1] : memref<256x512xf32> to memref<256x512xf32, #map2>
linalg.fill(%19, %cst) : memref<256x512xf32>, f32
linalg.copy(%16, %20) : memref<256x512xf32, #map0>, memref<256x512xf32, #map2>
scf.for %arg5 = %c0 to %c1024 step %c16 {
%21 = memref.subview %14[0, %arg5] [512, 16] [1, 1] : memref<512x1024xf32> to memref<512x16xf32, #map0>
%22 = memref.subview %17[0, %arg5] [256, 16] [1, 1] : memref<256x1024xf32, #map0> to memref<256x16xf32, #map0>
%23 = memref.alloc(%c32768) : memref<?xi8>
%24 = memref.view %23[%c0][] : memref<?xi8> to memref<512x16xf32>
%25 = memref.subview %24[0, 0] [512, 16] [1, 1] : memref<512x16xf32> to memref<512x16xf32, #map3>
linalg.fill(%24, %cst) : memref<512x16xf32>, f32
linalg.copy(%21, %25) : memref<512x16xf32, #map0>, memref<512x16xf32, #map3>
scf.for %arg6 = %c0 to %c256 step %c6 {
%26 = affine.min #map4(%arg6)
%27 = memref.subview %19[%arg6, 0] [%26, 512] [1, 1] : memref<256x512xf32> to memref<?x512xf32, #map5>
%28 = affine.min #map4(%arg6)
%29 = memref.subview %22[%arg6, 0] [%28, 16] [1, 1] : memref<256x16xf32, #map0> to memref<?x16xf32, #map0>
%30 = cmpi sle, %c6, %28 : index
%31:3 = scf.if %30 -> (memref<?x16xf32, #map6>, index, index) {
%39 = memref.cast %29 : memref<?x16xf32, #map0> to memref<?x16xf32, #map6>
scf.yield %39, %c0, %c0 : memref<?x16xf32, #map6>, index, index
} else {
affine.for %arg7 = 0 to 6 {
%43 = cmpi slt, %arg7, %28 : index
scf.if %43 {
%44 = vector.transfer_read %29[%arg7, %c0], %cst {in_bounds = [true]} : memref<?x16xf32, #map0>, vector<16xf32>
memref.store %44, %0[%arg7] : memref<6xvector<16xf32>>
} else {
memref.store %cst_1, %0[%arg7] : memref<6xvector<16xf32>>
}
}
%39 = vector.type_cast %0 : memref<6xvector<16xf32>> to memref<vector<6x16xf32>>
%40 = memref.load %39[] : memref<vector<6x16xf32>>
%41 = vector.type_cast %7 : memref<6x16xf32> to memref<vector<6x16xf32>>
memref.store %40, %41[] : memref<vector<6x16xf32>>
%42 = memref.cast %7 : memref<6x16xf32> to memref<?x16xf32, #map6>
scf.yield %42, %c0, %c0 : memref<?x16xf32, #map6>, index, index
}
affine.for %arg7 = 0 to 6 {
%39 = affine.apply #map7(%arg7, %31#1)
%40 = vector.transfer_read %31#0[%39, %31#2], %cst {in_bounds = [true]} : memref<?x16xf32, #map6>, vector<16xf32>
memref.store %40, %1[%arg7] : memref<6xvector<16xf32>>
}
%32 = vector.type_cast %1 : memref<6xvector<16xf32>> to memref<vector<6x16xf32>>
%33 = memref.load %32[] : memref<vector<6x16xf32>>
%34 = scf.for %arg7 = %c0 to %c512 step %c8 iter_args(%arg8 = %33) -> (vector<6x16xf32>) {
%39 = memref.subview %27[0, %arg7] [%26, 8] [1, 1] : memref<?x512xf32, #map5> to memref<?x8xf32, #map5>
%40 = memref.subview %24[%arg7, 0] [8, 16] [1, 1] : memref<512x16xf32> to memref<8x16xf32, #map8>
%41 = memref.alloc(%c192) : memref<?xi8>
%42 = memref.alloc(%c512) : memref<?xi8>
%43 = memref.alloc(%c384) : memref<?xi8>
%44 = cmpi sle, %c6, %26 : index
%45:3 = scf.if %44 -> (memref<?x8xf32, #map6>, index, index) {
%78 = memref.cast %39 : memref<?x8xf32, #map5> to memref<?x8xf32, #map6>
scf.yield %78, %c0, %c0 : memref<?x8xf32, #map6>, index, index
} else {
affine.for %arg9 = 0 to 6 {
%82 = cmpi slt, %arg9, %26 : index
scf.if %82 {
%83 = vector.transfer_read %39[%arg9, %c0], %cst {in_bounds = [true]} : memref<?x8xf32, #map5>, vector<8xf32>
memref.store %83, %2[%arg9] : memref<6xvector<8xf32>>
} else {
memref.store %cst_2, %2[%arg9] : memref<6xvector<8xf32>>
}
}
%78 = vector.type_cast %2 : memref<6xvector<8xf32>> to memref<vector<6x8xf32>>
%79 = memref.load %78[] : memref<vector<6x8xf32>>
%80 = vector.type_cast %8 : memref<6x8xf32> to memref<vector<6x8xf32>>
memref.store %79, %80[] : memref<vector<6x8xf32>>
%81 = memref.cast %8 : memref<6x8xf32> to memref<?x8xf32, #map6>
scf.yield %81, %c0, %c0 : memref<?x8xf32, #map6>, index, index
}
affine.for %arg9 = 0 to 6 {
%78 = affine.apply #map7(%arg9, %45#1)
%79 = vector.transfer_read %45#0[%78, %45#2], %cst {in_bounds = [true]} : memref<?x8xf32, #map6>, vector<8xf32>
memref.store %79, %3[%arg9] : memref<6xvector<8xf32>>
}
%46 = vector.type_cast %3 : memref<6xvector<8xf32>> to memref<vector<6x8xf32>>
%47 = memref.load %46[] : memref<vector<6x8xf32>>
affine.for %arg9 = 0 to 8 {
%78 = vector.transfer_read %40[%arg9, %c0], %cst {in_bounds = [true]} : memref<8x16xf32, #map8>, vector<16xf32>
memref.store %78, %4[%arg9] : memref<8xvector<16xf32>>
}
%48 = vector.type_cast %4 : memref<8xvector<16xf32>> to memref<vector<8x16xf32>>
%49 = memref.load %48[] : memref<vector<8x16xf32>>
%50 = vector.transpose %49, [1, 0] : vector<8x16xf32> to vector<16x8xf32>
%51 = vector.transpose %47, [1, 0] : vector<6x8xf32> to vector<8x6xf32>
%52 = vector.transpose %50, [1, 0] : vector<16x8xf32> to vector<8x16xf32>
%53 = vector.extract %51[0] : vector<8x6xf32>
%54 = vector.extract %52[0] : vector<8x16xf32>
%55 = vector.outerproduct %53, %54, %cst_0 {kind = #vector.kind<add>} : vector<6xf32>, vector<16xf32>
.... ....
%77 = addf %arg8, %76 : vector<6x16xf32>
.... ....
scf.yield %77 : vector<6x16xf32>
}
%35 = cmpi sle, %c6, %28 : index
%36:3 = scf.if %35 -> (memref<?x16xf32, #map6>, index, index) {
%39 = memref.cast %29 : memref<?x16xf32, #map0> to memref<?x16xf32, #map6>
scf.yield %39, %c0, %c0 : memref<?x16xf32, #map6>, index, index
} else {
%39 = memref.cast %9 : memref<6x16xf32> to memref<?x16xf32, #map6>
scf.yield %39, %c0, %c0 : memref<?x16xf32, #map6>, index, index
}
%37 = vector.type_cast %5 : memref<6xvector<16xf32>> to memref<vector<6x16xf32>>
memref.store %34, %37[] : memref<vector<6x16xf32>>
affine.for %arg7 = 0 to 6 {
%39 = affine.apply #map7(%arg7, %36#2)
%40 = memref.load %5[%arg7] : memref<6xvector<16xf32>>
vector.transfer_write %40, %36#0[%39, %36#2] {in_bounds = [true]} : vector<16xf32>, memref<?x16xf32, #map6>
}
%38 = xor %35, %true : i1
scf.if %38 {
%39 = vector.type_cast %9 : memref<6x16xf32> to memref<vector<6x16xf32>>
%40 = memref.load %39[] : memref<vector<6x16xf32>>
%41 = vector.type_cast %6 : memref<6xvector<16xf32>> to memref<vector<6x16xf32>>
memref.store %40, %41[] : memref<vector<6x16xf32>>
affine.for %arg7 = 0 to 6 {
%42 = cmpi slt, %arg7, %28 : index
scf.if %42 {
%43 = memref.load %6[%arg7] : memref<6xvector<16xf32>>
vector.transfer_write %43, %29[%arg7, %c0] {in_bounds = [true]} : vector<16xf32>, memref<?x16xf32, #map0>
}
}
}
}
.... ....
return
}
In fact, this code, after being compiled and executed, shows decent GFLOPS performance values (~100 GFLOPS, while peak performance is ~140 GFLOPS) — but only if the input matrix sizes are divisible by the tile values. If not, the incomplete zero-filled edge tiles are processed like normal tiles fully filled with data, so the performance values are reduced by ~15-30%. I.e., the cause of the GFLOPS reduction is the processing in (3) of edge tiles partially filled with data and partially filled with zeros.
Case 2:
Now the I/O matrices are the same, but the tile sizes are (480, 330).
I apply the following sequence of transformations (useFullTileBuffers is still set to true)
tile(2, 480).promote(1).tile(0, 330).promote(0).tile(1, 16).promote(1).tile(0, 6).tile(2, 8).promote({0, 1, 2}).vectorize()
and get the following MLIR-IR:
#map0 = affine_map<(d0) -> (480, -d0 + 1024)>
#map1 = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>
#map2 = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
#map3 = affine_map<(d0) -> (330, -d0 + 1024)>
.... ....
func @sgemm(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<1024x1024xf32>) {
.... ....
%0 = memref.alloca() : memref<6xvector<16xf32>>
%1 = memref.alloca() : memref<6xvector<16xf32>>
%2 = memref.alloca() : memref<6xvector<8xf32>>
%3 = memref.alloca() : memref<8xvector<16xf32>>
%4 = memref.alloca() : memref<6xvector<16xf32>>
%5 = memref.alloca() : memref<6xvector<16xf32>>
%6 = memref.alloca() {alignment = 32 : i64} : memref<6x16xf32>
%7 = memref.alloca() {alignment = 32 : i64} : memref<6x16xf32>
%8 = memref.alloc() : memref<1024x1024xf32>
linalg.fill(%8, %cst) : memref<1024x1024xf32>, f32
scf.for %arg3 = %c0 to %c1024 step %c480 {
%9 = affine.min #map0(%arg3)
%10 = memref.subview %arg0[0, %arg3] [1024, %9] [1, 1] : memref<1024x1024xf32> to memref<1024x?xf32, #map1>
%11 = affine.min #map0(%arg3)
%12 = memref.subview %arg1[%arg3, 0] [%11, 1024] [1, 1] : memref<1024x1024xf32> to memref<?x1024xf32, #map1>
%13 = memref.alloc(%c1966080) : memref<?xi8>
%14 = memref.view %13[%c0][] : memref<?xi8> to memref<480x1024xf32>
%15 = memref.subview %14[0, 0] [%11, 1024] [1, 1] : memref<480x1024xf32> to memref<?x1024xf32, #map2>
(1) linalg.fill(%14, %cst) : memref<480x1024xf32>, f32
linalg.copy(%12, %15) : memref<?x1024xf32, #map1>, memref<?x1024xf32, #map2>
scf.for %arg4 = %c0 to %c1024 step %c330 {
%16 = affine.min #map3(%arg4)
%17 = memref.subview %10[%arg4, 0] [%16, %9] [1, 1] : memref<1024x?xf32, #map1> to memref<?x?xf32, #map1>
%18 = affine.min #map3(%arg4)
%19 = memref.subview %8[%arg4, 0] [%18, 1024] [1, 1] : memref<1024x1024xf32> to memref<?x1024xf32, #map1>
%20 = memref.alloc(%c633600) : memref<?xi8>
%21 = memref.view %20[%c0][] : memref<?xi8> to memref<330x480xf32>
%22 = memref.subview %21[0, 0] [%16, %9] [1, 1] : memref<330x480xf32> to memref<?x?xf32, #map4>
(2) linalg.fill(%21, %cst) : memref<330x480xf32>, f32
linalg.copy(%17, %22) : memref<?x?xf32, #map1>, memref<?x?xf32, #map4>
scf.for %arg5 = %c0 to %c1024 step %c16 {
%23 = memref.subview %14[0, %arg5] [480, 16] [1, 1] : memref<480x1024xf32> to memref<480x16xf32, #map1>
%24 = memref.subview %19[0, %arg5] [%18, 16] [1, 1] : memref<?x1024xf32, #map1> to memref<?x16xf32, #map1>
%25 = memref.alloc(%c30720) : memref<?xi8>
%26 = memref.view %25[%c0][] : memref<?xi8> to memref<480x16xf32>
%27 = memref.subview %26[0, 0] [480, 16] [1, 1] : memref<480x16xf32> to memref<480x16xf32, #map5>
(3) linalg.fill(%26, %cst) : memref<480x16xf32>, f32
linalg.copy(%23, %27) : memref<480x16xf32, #map1>, memref<480x16xf32, #map5>
(4) scf.for %arg6 = %c0 to %c330 step %c6 {
%28 = memref.subview %21[%arg6, 0] [6, 480] [1, 1] : memref<330x480xf32> to memref<6x480xf32, #map6>
%29 = affine.min #map7(%18, %arg6)
%30 = memref.subview %24[%arg6, 0] [%29, 16] [1, 1] : memref<?x16xf32, #map1> to memref<?x16xf32, #map1>
%31 = cmpi sle, %c6, %29 : index
%32:3 = scf.if %31 -> (memref<?x16xf32, #map8>, index, index) {
%40 = memref.cast %30 : memref<?x16xf32, #map1> to memref<?x16xf32, #map8>
scf.yield %40, %c0, %c0 : memref<?x16xf32, #map8>, index, index
} else {
affine.for %arg7 = 0 to 6 {
%44 = cmpi slt, %arg7, %29 : index
scf.if %44 {
%45 = vector.transfer_read %30[%arg7, %c0], %cst {in_bounds = [true]} : memref<?x16xf32, #map1>, vector<16xf32>
memref.store %45, %0[%arg7] : memref<6xvector<16xf32>>
} else {
memref.store %cst_1, %0[%arg7] : memref<6xvector<16xf32>>
}
}
%40 = vector.type_cast %0 : memref<6xvector<16xf32>> to memref<vector<6x16xf32>>
%41 = memref.load %40[] : memref<vector<6x16xf32>>
%42 = vector.type_cast %6 : memref<6x16xf32> to memref<vector<6x16xf32>>
memref.store %41, %42[] : memref<vector<6x16xf32>>
%43 = memref.cast %6 : memref<6x16xf32> to memref<?x16xf32, #map8>
scf.yield %43, %c0, %c0 : memref<?x16xf32, #map8>, index, index
}
affine.for %arg7 = 0 to 6 {
%40 = affine.apply #map9(%arg7, %32#1)
%41 = vector.transfer_read %32#0[%40, %32#2], %cst {in_bounds = [true]} : memref<?x16xf32, #map8>, vector<16xf32>
memref.store %41, %1[%arg7] : memref<6xvector<16xf32>>
}
%33 = vector.type_cast %1 : memref<6xvector<16xf32>> to memref<vector<6x16xf32>>
%34 = memref.load %33[] : memref<vector<6x16xf32>>
(5) %35 = scf.for %arg7 = %c0 to %c480 step %c8 iter_args(%arg8 = %34) -> (vector<6x16xf32>) {
%40 = memref.subview %28[0, %arg7] [6, 8] [1, 1] : memref<6x480xf32, #map6> to memref<6x8xf32, #map6>
%41 = memref.subview %26[%arg7, 0] [8, 16] [1, 1] : memref<480x16xf32> to memref<8x16xf32, #map10>
%42 = memref.alloc(%c192) : memref<?xi8>
%43 = memref.alloc(%c512) : memref<?xi8>
%44 = memref.alloc(%c384) : memref<?xi8>
affine.for %arg9 = 0 to 6 {
%77 = vector.transfer_read %40[%arg9, %c0], %cst {in_bounds = [true]} : memref<6x8xf32, #map6>, vector<8xf32>
memref.store %77, %2[%arg9] : memref<6xvector<8xf32>>
}
%45 = vector.type_cast %2 : memref<6xvector<8xf32>> to memref<vector<6x8xf32>>
%46 = memref.load %45[] : memref<vector<6x8xf32>>
affine.for %arg9 = 0 to 8 {
%77 = vector.transfer_read %41[%arg9, %c0], %cst {in_bounds = [true]} : memref<8x16xf32, #map10>, vector<16xf32>
memref.store %77, %3[%arg9] : memref<8xvector<16xf32>>
}
%47 = vector.type_cast %3 : memref<8xvector<16xf32>> to memref<vector<8x16xf32>>
%48 = memref.load %47[] : memref<vector<8x16xf32>>
%49 = vector.transpose %48, [1, 0] : vector<8x16xf32> to vector<16x8xf32>
%50 = vector.transpose %46, [1, 0] : vector<6x8xf32> to vector<8x6xf32>
%51 = vector.transpose %49, [1, 0] : vector<16x8xf32> to vector<8x16xf32>
%52 = vector.extract %50[0] : vector<8x6xf32>
%53 = vector.extract %51[0] : vector<8x16xf32>
%54 = vector.outerproduct %52, %53, %cst_0 {kind = #vector.kind<add>} : vector<6xf32>, vector<16xf32>
.... ....
%76 = addf %arg8, %75 : vector<6x16xf32>
.... ....
scf.yield %76 : vector<6x16xf32>
}
%36 = cmpi sle, %c6, %29 : index
%37:3 = scf.if %36 -> (memref<?x16xf32, #map8>, index, index) {
%40 = memref.cast %30 : memref<?x16xf32, #map1> to memref<?x16xf32, #map8>
scf.yield %40, %c0, %c0 : memref<?x16xf32, #map8>, index, index
} else {
%40 = memref.cast %7 : memref<6x16xf32> to memref<?x16xf32, #map8>
scf.yield %40, %c0, %c0 : memref<?x16xf32, #map8>, index, index
}
%38 = vector.type_cast %4 : memref<6xvector<16xf32>> to memref<vector<6x16xf32>>
memref.store %35, %38[] : memref<vector<6x16xf32>>
affine.for %arg7 = 0 to 6 {
%40 = affine.apply #map9(%arg7, %37#2)
%41 = memref.load %4[%arg7] : memref<6xvector<16xf32>>
vector.transfer_write %41, %37#0[%40, %37#2] {in_bounds = [true]} : vector<16xf32>, memref<?x16xf32, #map8>
}
%39 = xor %36, %true : i1
scf.if %39 {
%40 = vector.type_cast %7 : memref<6x16xf32> to memref<vector<6x16xf32>>
%41 = memref.load %40[] : memref<vector<6x16xf32>>
%42 = vector.type_cast %5 : memref<6xvector<16xf32>> to memref<vector<6x16xf32>>
memref.store %41, %42[] : memref<vector<6x16xf32>>
affine.for %arg7 = 0 to 6 {
%43 = cmpi slt, %arg7, %29 : index
scf.if %43 {
%44 = memref.load %5[%arg7] : memref<6xvector<16xf32>>
vector.transfer_write %44, %30[%arg7, %c0] {in_bounds = [true]} : vector<16xf32>, memref<?x16xf32, #map1>
}
}
}
}
.... ....
return
}
This code runs two times slower than the code from ‘Case 1’, i.e. the performance value is ~50 GFLOPS.
If I comment out the tile-buffer zero-fillings {(1), (2), (3)} and properly change the upper bounds of loops {(4), (5)},
scf.for %arg6 = %c0 to %c330 step %c6 ==>> scf.for %arg6 = %c0 to %16 step %c6
%35 = scf.for %arg7 = %c0 to %c480 step %c8 iter_args(%arg8 = %34) → (vector<6x16xf32>) ==>>
%35 = scf.for %arg7 = %c0 to %11 step %c8 iter_args(%arg8 = %34) → (vector<6x16xf32>)
it restores the former performance value of ~100 GFLOPS. I.e., it starts to work well even with imperfect tile values.
Case 3:
Now the I/O matrices and the tile sizes (480, 330) are the same, but I try to get rid of the zero-fillings and set proper upper bounds by turning off the useFullTileBuffers flag. So, I apply the following sequence of transformations
tile(2, 480).promote(1, false, false).tile(0, 330).promote(0, false, false).tile(1, 16).promote(1, false, false).tile(0, 6).tile(2, 8).promote({0, 1, 2}).vectorize()));
and get the following MLIR-IR:
#map0 = affine_map<(d0) -> (480, -d0 + 1024)>
#map1 = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>
#map2 = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
#map3 = affine_map<(d0) -> (330, -d0 + 1024)>
.... ....
func @sgemm(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<1024x1024xf32>) {
.... ....
%0 = memref.alloca() : memref<6xvector<16xf32>>
%1 = memref.alloca() : memref<6xvector<16xf32>>
%cst_1 = constant dense<0.000000e+00> : vector<8xf32>
%2 = memref.alloca() : memref<6xvector<8xf32>>
%3 = memref.alloca() : memref<6xvector<8xf32>>
%cst_2 = constant dense<0.000000e+00> : vector<16xf32>
%4 = memref.alloca() : memref<8xvector<16xf32>>
%5 = memref.alloca() : memref<6xvector<16xf32>>
%6 = memref.alloca() : memref<6xvector<16xf32>>
%7 = memref.alloca() {alignment = 32 : i64} : memref<6x16xf32>
%8 = memref.alloca() {alignment = 32 : i64} : memref<6x8xf32>
%9 = memref.alloca() {alignment = 32 : i64} : memref<6x16xf32>
%10 = memref.alloc() : memref<1024x1024xf32>
linalg.fill(%10, %cst) : memref<1024x1024xf32>, f32
scf.for %arg3 = %c0 to %c1024 step %c480 {
%11 = affine.min #map0(%arg3)
%12 = memref.subview %arg0[0, %arg3] [1024, %11] [1, 1] : memref<1024x1024xf32> to memref<1024x?xf32, #map1>
%13 = affine.min #map0(%arg3)
%14 = memref.subview %arg1[%arg3, 0] [%13, 1024] [1, 1] : memref<1024x1024xf32> to memref<?x1024xf32, #map1>
%15 = memref.alloc(%c1966080) : memref<?xi8>
%16 = memref.view %15[%c0][] : memref<?xi8> to memref<480x1024xf32>
%17 = memref.subview %16[0, 0] [%13, 1024] [1, 1] : memref<480x1024xf32> to memref<?x1024xf32, #map2>
linalg.copy(%14, %17) : memref<?x1024xf32, #map1>, memref<?x1024xf32, #map2>
scf.for %arg4 = %c0 to %c1024 step %c330 {
%18 = affine.min #map3(%arg4)
%19 = memref.subview %12[%arg4, 0] [%18, %11] [1, 1] : memref<1024x?xf32, #map1> to memref<?x?xf32, #map1>
%20 = affine.min #map3(%arg4)
%21 = memref.subview %10[%arg4, 0] [%20, 1024] [1, 1] : memref<1024x1024xf32> to memref<?x1024xf32, #map1>
%22 = memref.alloc(%c633600) : memref<?xi8>
%23 = memref.view %22[%c0][] : memref<?xi8> to memref<330x480xf32>
%24 = memref.subview %23[0, 0] [%18, %11] [1, 1] : memref<330x480xf32> to memref<?x?xf32, #map4>
linalg.copy(%19, %24) : memref<?x?xf32, #map1>, memref<?x?xf32, #map4>
scf.for %arg5 = %c0 to %c1024 step %c16 {
%25 = memref.subview %17[0, %arg5] [%13, 16] [1, 1] : memref<?x1024xf32, #map2> to memref<?x16xf32, #map1>
%26 = memref.subview %21[0, %arg5] [%20, 16] [1, 1] : memref<?x1024xf32, #map1> to memref<?x16xf32, #map1>
%27 = memref.alloc(%c30720) : memref<?xi8>
%28 = memref.view %27[%c0][] : memref<?xi8> to memref<480x16xf32>
%29 = memref.subview %28[0, 0] [%13, 16] [1, 1] : memref<480x16xf32> to memref<?x16xf32, #map5>
linalg.copy(%25, %29) : memref<?x16xf32, #map1>, memref<?x16xf32, #map5>
scf.for %arg6 = %c0 to %18 step %c6 {
%30 = affine.min #map6(%18, %arg6)
%31 = memref.subview %24[%arg6, 0] [%30, %11] [1, 1] : memref<?x?xf32, #map4> to memref<?x?xf32, #map7>
%32 = affine.min #map6(%20, %arg6)
%33 = memref.subview %26[%arg6, 0] [%32, 16] [1, 1] : memref<?x16xf32, #map1> to memref<?x16xf32, #map1>
%34 = cmpi sle, %c6, %32 : index
%35:3 = scf.if %34 -> (memref<?x16xf32, #map8>, index, index) {
%43 = memref.cast %33 : memref<?x16xf32, #map1> to memref<?x16xf32, #map8>
scf.yield %43, %c0, %c0 : memref<?x16xf32, #map8>, index, index
} else {
affine.for %arg7 = 0 to 6 {
%47 = cmpi slt, %arg7, %32 : index
scf.if %47 {
%48 = vector.transfer_read %33[%arg7, %c0], %cst {in_bounds = [true]} : memref<?x16xf32, #map1>, vector<16xf32>
memref.store %48, %0[%arg7] : memref<6xvector<16xf32>>
} else {
memref.store %cst_2, %0[%arg7] : memref<6xvector<16xf32>>
}
}
%43 = vector.type_cast %0 : memref<6xvector<16xf32>> to memref<vector<6x16xf32>>
%44 = memref.load %43[] : memref<vector<6x16xf32>>
%45 = vector.type_cast %7 : memref<6x16xf32> to memref<vector<6x16xf32>>
memref.store %44, %45[] : memref<vector<6x16xf32>>
%46 = memref.cast %7 : memref<6x16xf32> to memref<?x16xf32, #map8>
scf.yield %46, %c0, %c0 : memref<?x16xf32, #map8>, index, index
}
affine.for %arg7 = 0 to 6 {
%43 = affine.apply #map9(%arg7, %35#1)
%44 = vector.transfer_read %35#0[%43, %35#2], %cst {in_bounds = [true]} : memref<?x16xf32, #map8>, vector<16xf32>
memref.store %44, %1[%arg7] : memref<6xvector<16xf32>>
}
%36 = vector.type_cast %1 : memref<6xvector<16xf32>> to memref<vector<6x16xf32>>
%37 = memref.load %36[] : memref<vector<6x16xf32>>
%38 = scf.for %arg7 = %c0 to %11 step %c8 iter_args(%arg8 = %37) -> (vector<6x16xf32>) {
%43 = affine.min #map10(%11, %arg7)
%44 = memref.subview %31[0, %arg7] [%30, %43] [1, 1] : memref<?x?xf32, #map7> to memref<?x?xf32, #map7>
%45 = affine.min #map10(%13, %arg7)
%46 = memref.subview %29[%arg7, 0] [%45, 16] [1, 1] : memref<?x16xf32, #map5> to memref<?x16xf32, #map11>
%47 = memref.alloc(%c192) : memref<?xi8>
%48 = memref.alloc(%c512) : memref<?xi8>
%49 = memref.alloc(%c384) : memref<?xi8>
%50 = cmpi sle, %c6, %30 : index
%51 = cmpi sle, %c8, %43 : index
%52 = and %50, %51 : i1
%53:3 = scf.if %52 -> (memref<?x?xf32, #map8>, index, index) {
%86 = memref.cast %44 : memref<?x?xf32, #map7> to memref<?x?xf32, #map8>
scf.yield %86, %c0, %c0 : memref<?x?xf32, #map8>, index, index
} else {
affine.for %arg9 = 0 to 6 {
%90 = cmpi slt, %arg9, %30 : index
scf.if %90 {
%91 = vector.transfer_read %44[%arg9, %c0], %cst : memref<?x?xf32, #map7>, vector<8xf32>
memref.store %91, %2[%arg9] : memref<6xvector<8xf32>>
} else {
memref.store %cst_1, %2[%arg9] : memref<6xvector<8xf32>>
}
}
%86 = vector.type_cast %2 : memref<6xvector<8xf32>> to memref<vector<6x8xf32>>
%87 = memref.load %86[] : memref<vector<6x8xf32>>
%88 = vector.type_cast %8 : memref<6x8xf32> to memref<vector<6x8xf32>>
memref.store %87, %88[] : memref<vector<6x8xf32>>
%89 = memref.cast %8 : memref<6x8xf32> to memref<?x?xf32, #map8>
scf.yield %89, %c0, %c0 : memref<?x?xf32, #map8>, index, index
}
affine.for %arg9 = 0 to 6 {
%86 = affine.apply #map9(%arg9, %53#1)
%87 = vector.transfer_read %53#0[%86, %53#2], %cst {in_bounds = [true]} : memref<?x?xf32, #map8>, vector<8xf32>
memref.store %87, %3[%arg9] : memref<6xvector<8xf32>>
}
%54 = vector.type_cast %3 : memref<6xvector<8xf32>> to memref<vector<6x8xf32>>
%55 = memref.load %54[] : memref<vector<6x8xf32>>
affine.for %arg9 = 0 to 8 {
%86 = cmpi slt, %arg9, %45 : index
scf.if %86 {
%87 = vector.transfer_read %46[%arg9, %c0], %cst {in_bounds = [true]} : memref<?x16xf32, #map11>, vector<16xf32>
memref.store %87, %4[%arg9] : memref<8xvector<16xf32>>
} else {
memref.store %cst_2, %4[%arg9] : memref<8xvector<16xf32>>
}
}
%56 = vector.type_cast %4 : memref<8xvector<16xf32>> to memref<vector<8x16xf32>>
%57 = memref.load %56[] : memref<vector<8x16xf32>>
%58 = vector.transpose %57, [1, 0] : vector<8x16xf32> to vector<16x8xf32>
%59 = vector.transpose %55, [1, 0] : vector<6x8xf32> to vector<8x6xf32>
%60 = vector.transpose %58, [1, 0] : vector<16x8xf32> to vector<8x16xf32>
%61 = vector.extract %59[0] : vector<8x6xf32>
%62 = vector.extract %60[0] : vector<8x16xf32>
%63 = vector.outerproduct %61, %62, %cst_0 {kind = #vector.kind<add>} : vector<6xf32>, vector<16xf32>
%64 = vector.extract %59[1] : vector<8x6xf32>
%65 = vector.extract %60[1] : vector<8x16xf32>
%66 = vector.outerproduct %64, %65, %63 {kind = #vector.kind<add>} : vector<6xf32>, vector<16xf32>
.... ....
%85 = addf %arg8, %84 : vector<6x16xf32>
.... ....
scf.yield %85 : vector<6x16xf32>
}
%39 = cmpi sle, %c6, %32 : index
%40:3 = scf.if %39 -> (memref<?x16xf32, #map8>, index, index) {
%43 = memref.cast %33 : memref<?x16xf32, #map1> to memref<?x16xf32, #map8>
scf.yield %43, %c0, %c0 : memref<?x16xf32, #map8>, index, index
} else {
%43 = memref.cast %9 : memref<6x16xf32> to memref<?x16xf32, #map8>
scf.yield %43, %c0, %c0 : memref<?x16xf32, #map8>, index, index
}
%41 = vector.type_cast %5 : memref<6xvector<16xf32>> to memref<vector<6x16xf32>>
memref.store %38, %41[] : memref<vector<6x16xf32>>
affine.for %arg7 = 0 to 6 {
%43 = affine.apply #map9(%arg7, %40#2)
%44 = memref.load %5[%arg7] : memref<6xvector<16xf32>>
vector.transfer_write %44, %40#0[%43, %40#2] {in_bounds = [true]} : vector<16xf32>, memref<?x16xf32, #map8>
}
%42 = xor %39, %true : i1
scf.if %42 {
%43 = vector.type_cast %9 : memref<6x16xf32> to memref<vector<6x16xf32>>
%44 = memref.load %43[] : memref<vector<6x16xf32>>
%45 = vector.type_cast %6 : memref<6xvector<16xf32>> to memref<vector<6x16xf32>>
memref.store %44, %45[] : memref<vector<6x16xf32>>
affine.for %arg7 = 0 to 6 {
%46 = cmpi slt, %arg7, %32 : index
scf.if %46 {
%47 = memref.load %6[%arg7] : memref<6xvector<16xf32>>
vector.transfer_write %47, %33[%arg7, %c0] {in_bounds = [true]} : vector<16xf32>, memref<?x16xf32, #map1>
}
}
}
}
.... ....
return
}
Here there are no zero-fillings and the loop upper bounds are set properly, but this code runs approximately 3 times slower than the code from ‘Case 1’.
So, the question is: how can I get good and stable performance values (~100 GFLOPS) using legal options and settings for all possible tile values, even when those values do not evenly divide the matrix dimensions?