Linalg.generic for full BLAS matmul expression

Hi all,
I have been trying for a few days to implement the full BLAS expression of GEMM in Linalg, but I am hitting a wall and was hoping for some help. For the sake of clarity, this is the full GEMM expression:

C = alpha*A*B + beta*C

Where alpha and beta are scalars.
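
As a point of reference, here is a minimal pure-Python sketch of that expression. It is only meant to pin down the semantics, not to suggest an implementation strategy; the function name `gemm_reference` and the list-of-lists layout are assumptions made for illustration.

```python
def gemm_reference(alpha, A, B, beta, C):
    """Return alpha*A*B + beta*C as a new matrix; C itself is not modified."""
    M, K, N = len(A), len(B), len(B[0])
    out = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            # Reduce over k into a private accumulator first.
            acc = 0.0
            for k in range(K):
                acc += A[i][k] * B[k][j]
            # Apply both scalars exactly once per output element.
            out[i][j] = alpha * acc + beta * C[i][j]
    return out
```

Note that this version allocates a separate output, which is exactly the temporary the rest of this post tries to avoid.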

I was able to get the correct result by doing the matmul first (alpha*A*B) and then the beta*C element-wise. However, this sequence needs a temporary variable to store the matmul.

The most efficient way to do this is to first calculate beta*C in place and then to run the A*B calculation on C, still in place. Despite many attempts, I have not been able to generate optimal and correct MLIR for this.

This is the Linalg code I tried (EDIT: in this code I am pre-transposing matrix A and ignoring the multiplication by alpha, so what I am really doing is C = trans(A)*B + beta*C):

// mlir-opt --linalg-comprehensive-module-bufferize  --convert-linalg-to-loops %s
!type_A = type tensor<2048x2048xf32>
!type_B = type tensor<2048x2048xf32>
!type_C = type tensor<2048x2048xf32>

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#scalar_map_0 = affine_map<(d0, d1) -> ()>
#scalar_map_1 = affine_map<(d0, d1, d2) -> ()>
#identity_map = affine_map<(d0, d1) -> (d0, d1)>

func @gemm(%A : !type_A {linalg.buffer_layout = #identity_map, linalg.inplaceable = false}, 
           %B : !type_B {linalg.buffer_layout = #identity_map, linalg.inplaceable = false}, 
           %C : !type_C {linalg.buffer_layout = #identity_map, linalg.inplaceable = true}, %alpha : f32, %beta : f32) -> !type_C {

    // %1 = beta * C
    %1 = linalg.generic {
      indexing_maps = [#scalar_map_0, #map0],
      iterator_types = ["parallel", "parallel"]}
      ins(%beta :  f32)
      outs(%C: !type_C) {
        ^bb0(%be :f32, %c: f32):
          %out = arith.mulf %be, %c : f32
          linalg.yield %out : f32
      } -> !type_C

    // %2 = alpha*A*B + %1 = alpha*A*B + beta*C
    %2 = linalg.generic
      {indexing_maps = [ affine_map<(m, n, k) -> (k, m)>,
                        affine_map<(m, n, k) -> (k, n)>,
                        #scalar_map_1,
                        affine_map<(m, n, k) -> (m, n)>],
       iterator_types = ["parallel", "parallel", "reduction"]}

      ins(%A, %B, %alpha:  !type_A, !type_B,  f32)
      outs(%1: !type_C) {
      ^bb0(%a: f32, %b: f32, %al : f32, %c: f32) :
        %d = arith.mulf %a, %b: f32
        %e = arith.addf %c, %d: f32 
        linalg.yield %e : f32
      } -> !type_C

    return %2 : !type_C
}

This is not doing what I want. Indeed, this is the result:

module  {
  func @gemm(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2048x2048xf32>, %arg3: f32, %arg4: f32) {
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %c2048 = arith.constant 2048 : index
    scf.for %arg5 = %c0 to %c2048 step %c1 {
      scf.for %arg6 = %c0 to %c2048 step %c1 {
        %0 = memref.load %arg2[%arg5, %arg6] : memref<2048x2048xf32>
        %1 = arith.mulf %arg4, %0 : f32
        memref.store %1, %arg2[%arg5, %arg6] : memref<2048x2048xf32>
      }
    }
    scf.for %arg5 = %c0 to %c2048 step %c1 {
      scf.for %arg6 = %c0 to %c2048 step %c1 {
        scf.for %arg7 = %c0 to %c2048 step %c1 {
          %0 = memref.load %arg0[%arg7, %arg5] : memref<2048x2048xf32>
          %1 = memref.load %arg1[%arg7, %arg6] : memref<2048x2048xf32>
          %2 = memref.load %arg2[%arg5, %arg6] : memref<2048x2048xf32>
          %3 = arith.mulf %0, %1 : f32
          %4 = arith.addf %2, %3 : f32
          memref.store %4, %arg2[%arg5, %arg6] : memref<2048x2048xf32>
        }
      }
    }
    return
  }
}

The problem is in the inner loop. What this loop is basically doing is:

for k = 0:K{
   C[i,j] += A[i,k]*B[k,j];
}

But this is wrong, because we are accumulating beta K times, so we end up with C = A*B + K*beta*C. What I really want is:

tmp = 0;
for k = 0:K{
   tmp += A[i,k]*B[k,j];
}
C[i,j] += tmp;

One way to achieve this is to create a temporary tensor %tmp initialized to zero, use that tensor for the matmul, and then add another linalg.generic that does %out = %C + %tmp.

For completeness, this is what the code would look like:

// same maps as before
func @gemm(%A : !type_A {linalg.buffer_layout = #identity_map, linalg.inplaceable = false}, 
           %B : !type_B {linalg.buffer_layout = #identity_map, linalg.inplaceable = false}, 
           %C : !type_C {linalg.buffer_layout = #identity_map, linalg.inplaceable = true}, %alpha : f32, %beta : f32) -> !type_C {

    %cst = arith.constant 0.0 : f32
    %init = linalg.init_tensor [2048, 2048] : !type_C
    %0 = linalg.fill(%cst, %init) : f32, !type_C -> !type_C

    // %1 = beta * C
    %1 = linalg.generic {
      indexing_maps = [#scalar_map_0, #map0],
      iterator_types = ["parallel", "parallel"]}
      ins(%beta :  f32)
      outs(%C: !type_C) {
        ^bb0(%be :f32, %c: f32):
          %out = arith.mulf %be, %c : f32
          linalg.yield %out : f32
      } -> !type_C

    // %2 = alpha*A*B 
    %2 = linalg.generic
      {indexing_maps = [ affine_map<(m, n, k) -> (k, m)>,
                        affine_map<(m, n, k) -> (k, n)>,
                        #scalar_map_1,
                        affine_map<(m, n, k) -> (m, n)>],
       iterator_types = ["parallel", "parallel", "reduction"]}

      ins(%A, %B, %alpha:  !type_A, !type_B,  f32)
      outs(%0: !type_C) {
      ^bb0(%a: f32, %b: f32, %al : f32, %c: f32) :
        %d = arith.mulf %a, %b: f32
        %e = arith.addf %c, %d: f32 
        linalg.yield %e : f32
      } -> !type_C

    // %3 = %1 + %2 = beta*C + A*B
    %3 = linalg.generic {
      indexing_maps = [#map0, #map0],
      iterator_types = ["parallel", "parallel"]}
      ins(%2:  !type_C)
      outs(%1: !type_C) {
        ^bb0(%x :f32, %y: f32):
          %out = arith.addf %x, %y : f32
          linalg.yield %out : f32
      } -> !type_C

    return %3 : !type_C
}

But now I again have a temporary tensor that I want to get rid of. One way to do this is to fuse the addition into the matmul, and this is where I am stuck.

Despite many attempts, I was not able to achieve this fusion.

So, I have two/three questions (and a lot of gratitude for any answer :slight_smile: ) :

  • Is it possible to generate the optimal correct result by using only two linalg.generic operations?
  • From what I understood, @nicolasvasilache mentioned that this is not possible in Linalg. If this is the case, how hard would it be to add support for something like this (if it makes sense to add it at all)?
  • If adding support for this is not viable, what is the right way to fuse the third linalg.generic into the matmul?

cc @chelini

You should be able to write it as:

tmp = A * B
C  = alpha * tmp + beta * C

For larger fusions, we are working on better transformation control that will make it easier to target and compose transformations on the IR, to avoid the need for heuristics.

If you want heuristics, you may look at what IREE is doing (ping @MaheshRavishankar for a pointer).

Are the affine maps of the second linalg.generic correct? I was expecting affine_map<(m, n, k) -> (m, k)> instead of affine_map<(m, n, k) -> (k, m)>. Also, where is the multiplication with the alpha scalar in the first example?

The example is indeed not correct, but the issue is more general: multiplication by a scalar is not something you want to do within a perfectly nested loop that has a reduction, as it is the same problem as the first one Giuseppe encountered.

That is where you need separate ops for “reduction” and “linear” behavior.
The desired final code structure comes from fusion.

Note this is the dual behavior of starting from

tmp = 0;
for k = 0:K{
   tmp += A[i,k]*B[k,j];
}
C[i,j] += tmp;

and needing to privatize/array expand tmp to get a good tiled and vectorized impl.

I think this has come up a couple of times. Let's split the problem into two. First, let's consider the version of GEMM that is

D_{ij} = A_{ik} * B_{kj} + C_{ij}

This is directly what a linalg.matmul represents, and it works as expected. Connecting this to the vector dialect makes sure that you read an element (or a vector of elements) of C_{ij} into registers, compute A_{ik} * B_{kj} in registers as well, and write out the final result to memory. (I am not going into details of what the vector dialect code looks like; I can dig it up from IREE's generated code.)

The current Linalg named operations do not allow you to express

D_{ij} = alpha * A_{ik} * B_{kj} + beta * C_{ij}

It comes down to how to handle alpha and beta.

First, let's say beta == 1. For this case, we need a new version of linalg.matmul (or need to adapt it), say linalg.matmul_alpha. One option is that it takes three ins operands: A, B and alpha, with the indexing map for alpha being affine_map<(d0, d1, d2) -> ()>. The other option is that alpha is not part of the ins list of linalg.matmul_alpha and is just implicitly captured in the region of the op

^bb0(%arg0 : f32, %arg1 : f32, %arg2 : f32) :
    %0 = mulf %arg0, %arg1 : f32
    %1 = mulf %0, %alpha : f32
    %2 = addf %1, %arg2 : f32
    linalg.yield %2 : f32

Next, beta != 1 is a harder case. I cannot see an easy way to extend linalg.matmul to handle it.

There is a “work around” kind of solution that works in many cases. You can split the computation into three separate statements

A_alpha_{ij} = alpha * A_{ij}
temp_{ij} = A_alpha_{ik} * B_{kj} + C_{ij}
D_{ij} = beta * temp_{ij}

The first and the last statements are just linalg.generic operations which could be fused with the producer of A_{ij} (or consumer of D_{ij}), which is typically possible in many situations…

Hi @chelini , @nicolasvasilache ,
Thanks for your replies. About the correctness: sorry, I should have mentioned that I am doing (for now) C = trans(A)*B + beta*C. The maps should be correct for this use case.

So, after a chat with @nicolasvasilache, I think I understand things more clearly. Basically, it is algorithmically impossible to fuse the element-wise addition into the reduction: as soon as we start accumulating, we can only add beta*C after that. Indeed, in my previous example, I had:

for i = 0:M{
   for j= 0:N{
      tmp = 0;
      for k = 0:K{
            tmp += A[i,k]*B[k,j];
       }
       C[i,j] += tmp;
    }
}

The tmp variable (a scalar) is used just before the accumulation loop. If I pull the accumulation loop up, this becomes:

for i = 0:M{
   tmp[N] = {0};
   for k = 0:K{
      for j= 0:N{
            tmp[j] += A[i,k]*B[k,j];
       }
    }
    for j = 0:N{
       C[i,j] += tmp[j];
    }
}

And if I pull the accumulation loop even higher, I will have a tmp[M,N], which is a copy of my entire matrix!
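
To make the growth of the temporary concrete, here is a small Python sketch of the first two loop orders. It assumes plain A*B (no transpose) and uses hypothetical helper names; it only illustrates that hoisting the k loop above the j loop turns the scalar accumulator into a row of N accumulators while producing the same values.

```python
def update_scalar_tmp(A, B, C):
    """k innermost: one scalar accumulator per (i, j)."""
    M, K, N = len(A), len(B), len(B[0])
    for i in range(M):
        for j in range(N):
            tmp = 0.0
            for k in range(K):
                tmp += A[i][k] * B[k][j]
            C[i][j] += tmp


def update_row_tmp(A, B, C):
    """k hoisted above j: the accumulator becomes a row of N values."""
    M, K, N = len(A), len(B), len(B[0])
    for i in range(M):
        tmp = [0.0] * N
        for k in range(K):
            for j in range(N):
                tmp[j] += A[i][k] * B[k][j]
        # The update of C needs its own j loop once tmp is an array.
        for j in range(N):
            C[i][j] += tmp[j]
```

Hoisting k one level further would, by the same reasoning, need an M x N accumulator.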

One last thing

Does the same hold for alpha?
You mentioned that we have the same issue for alpha as well. But is that the case? Because I can do C = alpha*A*B like this:

for i = 0:M{
   for j= 0:N{
      for k = 0:K{
            C[i,j] += alpha*A[i,k]*B[k,j];
       }
    }
}

And the result “should” be correct.
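
A quick way to convince oneself is the sketch below (Python, hypothetical function name, plain A*B assumed). Since alpha multiplies every term of the sum, scaling inside the reduction gives alpha * sum_k A[i,k]*B[k,j] by distributivity, at the cost of K multiplications by alpha per output element instead of one. Note that with `+=` the result is added to whatever C already holds, so C must start at zero to get exactly alpha*A*B.

```python
def matmul_alpha_inplace(alpha, A, B, C):
    """Accumulate alpha*A*B into C, scaling each product inside the k loop."""
    M, K, N = len(A), len(B), len(B[0])
    for i in range(M):
        for j in range(N):
            for k in range(K):
                C[i][j] += alpha * A[i][k] * B[k][j]
    return C
```

In floating point the rounding may differ slightly from scaling once after the reduction, so the two forms are equal only up to rounding in general.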

Thanks again for your help,
Giuseppe

Actually, I missed a case above. With this representation

A_alpha_{ij} = alpha * A_{ij}
temp_{ij} = A_alpha_{ik} * B_{kj} + C_{ij}
D_{ij} = beta * temp_{ij}

Just the last two statements can be tiled and fused. So you compute a tile of the second statement and use it immediately to compute a tile of the third statement. Then, if you connect this to the vector dialect, it behaves as expected: the k reduction happens in registers and the scalar multiplication is done before the result is written out to memory. (The IREE backend handles these the same way as well.)
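
As a rough illustration of that fusion, here is a Python sketch with a 1x1 "tile". It mirrors the three statements exactly as written above, so it computes D = beta * (A_alpha * B + C); the function names are hypothetical and this is not how Linalg or IREE implement tiling.

```python
def three_statements_unfused(alpha, A, B, beta, C):
    """Materialize A_alpha and temp as full matrices, then scale by beta."""
    M, K, N = len(A), len(B), len(B[0])
    A_alpha = [[alpha * a for a in row] for row in A]
    temp = [[C[i][j] + sum(A_alpha[i][k] * B[k][j] for k in range(K))
             for j in range(N)] for i in range(M)]
    return [[beta * t for t in row] for row in temp]


def last_two_fused(alpha, A, B, beta, C):
    """Fuse statements 2 and 3: no full-size temp, beta applied at write-out."""
    M, K, N = len(A), len(B), len(B[0])
    A_alpha = [[alpha * a for a in row] for row in A]
    D = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            acc = C[i][j]  # stands in for the value held in a register
            for k in range(K):
                acc += A_alpha[i][k] * B[k][j]
            D[i][j] = beta * acc  # scale once, just before the store
    return D
```

The only temporary left in the fused version is the scalar `acc`, which is the per-tile accumulator in this simplified picture.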

Hi @MaheshRavishankar ,
Thanks a lot for your answer (sorry, didn’t see it when posted, because the page did not update in time).

Yes, I agree that when beta==1 everything becomes simpler.

About this case, when beta != 1:

A_alpha_{ij} = alpha * A_{ij}
temp_{ij} = A_alpha_{ik} * B_{kj} + C_{ij}
D_{ij} = beta * temp_{ij}

Can’t I always do:

temp_{ij} = alpha * A_{ik} * B_{kj} + C_{ij}
D_{ij} = beta * temp_{ij}

Also, is it correct to say that D can be fused into temp only after the (first) accumulation? The problem is that using the BLIS approach I have 5 levels of tiling, and I start accumulating at the second level.

Thanks again for walking me through this,
Giuseppe

Hi @MaheshRavishankar , @nicolasvasilache ,
I have thought more about this. So I agree about the fusion: we cannot fuse the addition of beta*C when beta != 1.

But, as @MaheshRavishankar said, we can fuse the addition when beta == 1.

So, let’s forget about fusion for a moment. In my original post I was doing:

%C = beta*C 
%C = A*B + %C

You see, now the second operation is back to being a matmul with beta=1. I don’t want to fuse anything in this case, I just want to avoid creating a copy of %C. @nicolasvasilache do you know how I can do this in Linalg?

Thanks,
Giuseppe

Hi @MaheshRavishankar , @nicolasvasilache ,
As I mentioned on the IREE channel, I think I was looking at the problem in the wrong way. Indeed, the original snippet I wrote in my first post was bufferized in the correct way and was vectorized in the way I wanted.
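
For anyone landing here later, a small Python sketch of the two loop nests from the bufferized output above (written for plain A*B rather than trans(A)*B, with a hypothetical function name) shows why no temporary is needed: the first nest scales each element of C exactly once, and the second nest only adds products on top of that, so beta is never re-applied inside the reduction.

```python
def gemm_inplace_two_step(A, B, beta, C):
    """Compute C = A*B + beta*C in place using two separate loop nests."""
    M, K, N = len(A), len(B), len(B[0])
    # Nest 1: C = beta * C, one multiplication per element.
    for i in range(M):
        for j in range(N):
            C[i][j] = beta * C[i][j]
    # Nest 2: C += A * B, accumulating directly into the scaled C.
    for i in range(M):
        for j in range(N):
            for k in range(K):
                C[i][j] += A[i][k] * B[k][j]
    return C
```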

Thanks again for walking me through this!
