[RFC][Linalg] Add Concatenate operation to Linalg

I would like to propose adding a concatenate op to Linalg. I’ve been prototyping concatenate op support in IREE for a while. When adapting the patterns to the tensor world, I hit the limits of Linalg again. There are some workarounds, but the most promising solution to me is to add a concatenate op. This was discussed a long time ago; I’ve summarized some context from that discussion below and added more thoughts.

Option 1 – lowering the op to an indexed_generic op

This is the first prototype I had. The idea is to lower a concatenate op to an indexed_generic op and use the indices to yield the correct value. With this approach, extra dimensions/loops are needed because of how loop bounds are inferred. The main problem is that there are two input tensors, of shapes (2, 2) and (3, 2), and the generic op needs to infer the bounds for d0 and d1 from both of them.

The first tensor will tell you that (d0, d1) is in the range [0, 2) x [0, 2).
The second tensor will tell you that (d0, d1) is in the range [2, 5) x [0, 2).
These ranges don’t match, which is undefined behavior according to the first property of linalg.generic (see the 'linalg' Dialect - MLIR documentation).
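The mismatch can be made concrete with a small Python model. The helper names are hypothetical, and this is not how MLIR actually computes loop bounds. It only restates the two ranges listed above and checks whether they coincide, which a single loop nest would require.

```python
def implied_ranges(offset_d0, shape):
    """Half-open (d0, d1) ranges implied by an operand that is read at
    output rows [offset_d0, offset_d0 + rows) (hypothetical model)."""
    rows, cols = shape
    return ((offset_d0, offset_d0 + rows), (0, cols))


def ranges_agree(*ranges):
    """A single generic loop nest needs every operand to imply the same ranges."""
    return all(r == ranges[0] for r in ranges)


first = implied_ranges(0, (2, 2))   # tensor (2, 2) placed at rows [0, 2)
second = implied_ranges(2, (3, 2))  # tensor (3, 2) placed at rows [2, 5)
print(first, second, ranges_agree(first, second))
# ((0, 2), (0, 2)) ((2, 5), (0, 2)) False
```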

Here is a snippet of the scf dialect lowering that concatenates memref<2x2xi32> and memref<3x2xi32> into memref<5x2xi32>:

scf.for %arg3 = %c0 to %c2 step %c1 {
  scf.for %arg4 = %c0 to %c5 step %c1 {
    scf.for %arg5 = %c0 to %c2 step %c1 {
      scf.for %arg6 = %c0 to %c3 step %c1 {
        %0 = load %arg1[%arg5, %arg3] : memref<2x2xi32>
        %1 = load %arg2[%arg6, %arg3] : memref<3x2xi32>
        %2 = load %arg0[%arg4, %arg3] : memref<5x2xi32>
        %3 = subi %arg4, %c0 : index
        %4 = cmpi "eq", %3, %arg5 : index
        %5 = cmpi "sge", %arg4, %c0 : index
        %6 = cmpi "slt", %arg4, %c2 : index
        %7 = and %5, %6 : i1
        %8 = select %7, %0, %0 : i32
        %9 = subi %arg4, %c2 : index
        %10 = cmpi "eq", %9, %arg6 : index
        %11 = or %4, %10 : i1
        %12 = cmpi "sge", %arg4, %c2 : index
        %13 = cmpi "slt", %arg4, %c5 : index
        %14 = and %12, %13 : i1
        %15 = select %14, %1, %8 : i32
        %16 = select %11, %15, %2 : i32
        store %16, %arg0[%arg4, %arg3] : memref<5x2xi32>
      }
    }
  }
}

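This is just a workaround to express concat in Linalg, and it is not the way to go. It generates a 4-deep loop nest that runs 2 × 5 × 2 × 3 = 60 innermost iterations, each with three loads and a store, to produce a 10-element result.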
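To check the cost claim, here is a direct Python transcription of the loop nest above. It is a sketch: nested lists stand in for the memrefs, and the comments map each variable back to %arg3..%arg6 and the SSA values. It confirms that the nest produces the correct 5x2 result but executes 60 innermost iterations to write 10 elements.

```python
def concat_via_scf_loops(a, b):
    """Transcription of the scf loop nest: a is 2x2, b is 3x2, result is 5x2.

    Returns the result and the number of innermost iterations executed.
    """
    out = [[0] * 2 for _ in range(5)]  # %arg0 : memref<5x2xi32>
    iterations = 0
    for col in range(2):                  # %arg3
        for row in range(5):              # %arg4: output row
            for row_a in range(2):        # %arg5: row of a
                for row_b in range(3):    # %arg6: row of b
                    iterations += 1
                    va = a[row_a][col]        # %0
                    vb = b[row_b][col]        # %1
                    current = out[row][col]   # %2
                    # %4 / %10 / %11: is this the iteration that owns the element?
                    hit = (row - 0 == row_a) or (row - 2 == row_b)
                    # %8 is select(%7, %0, %0), i.e. always %0; %15 picks b's
                    # value when the output row lies in [2, 5).
                    picked = vb if 2 <= row < 5 else va
                    out[row][col] = picked if hit else current  # %16 + store
    return out, iterations


result, n = concat_via_scf_loops([[1, 2], [3, 4]], [[5, 6], [7, 8], [9, 10]])
print(result)  # [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
print(n)       # 60
```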

Option 2 – lowering the op to a bunch of subviews/subtensors + copy/subtensor_insert

The current solution in IREE is to lower the op to subviews + copies in the buffer world. I have a prototype that achieves similar behavior in the tensor world by using subtensor and subtensor_insert. The idea is to create a tensor with linalg.fill and insert each operand into it.

E.g.,

Input:

func @concatenate(%0: tensor<2x2xi32>, %1: tensor<2x3xi32>) -> tensor<2x5xi32> {
  %2 = "mhlo.concatenate"(%0, %1) {
    dimension = 1
  } : (tensor<2x2xi32>, tensor<2x3xi32>) -> tensor<2x5xi32>
  return %2 : tensor<2x5xi32>
}

Output:

func @concatenate(%0: tensor<2x2xi32>, %1: tensor<2x3xi32>) -> tensor<2x5xi32> {
  %c0_i32 = constant 0 : i32
  %init_tensor = linalg.init_tensor [2, 5] : tensor<2x5xi32>
  %filled_tensor = linalg.fill(%init_tensor, %c0_i32) : tensor<2x5xi32>, i32 -> tensor<2x5xi32>
  %sub1 = subtensor_insert %0 into %filled_tensor[0, 0] [2, 2] [1, 1] : tensor<2x2xi32> into tensor<2x5xi32>
  %sub2 = subtensor_insert %1 into %filled_tensor[0, 2] [2, 3] [1, 1] : tensor<2x3xi32> into tensor<2x5xi32>
  return %filled_tensor : tensor<2x5xi32>
}

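The issue here is that %sub1 and %sub2 have no uses, so DCE will remove them. subtensor_insert returns a new tensor value instead of updating %filled_tensor in place, and the function returns the unmodified fill result. So far, I don’t have a working prototype with this option.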
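A small Python model of value semantics shows why this fails. Assumption: subtensor_insert behaves like a pure function that returns a new tensor, modeled here with a deep copy. This mirrors the tensor-world semantics and is not an actual implementation. Even if %sub1 and %sub2 survived DCE, neither one holds the full result, because each inserts into the original %filled_tensor.

```python
import copy


def fill(shape, value):
    rows, cols = shape
    return [[value] * cols for _ in range(rows)]


def subtensor_insert(src, dest, offsets):
    """Value-semantics model: returns a NEW tensor and leaves `dest` untouched."""
    result = copy.deepcopy(dest)
    r0, c0 = offsets
    for i, row in enumerate(src):
        for j, v in enumerate(row):
            result[r0 + i][c0 + j] = v
    return result


a = [[1, 2], [3, 4]]         # tensor<2x2xi32>
b = [[5, 6, 7], [8, 9, 10]]  # tensor<2x3xi32>
filled = fill((2, 5), 0)

# As written above: both inserts read `filled`, and `filled` is returned.
sub1 = subtensor_insert(a, filled, (0, 0))
sub2 = subtensor_insert(b, filled, (0, 2))
print(filled)  # [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]
print(sub2)    # [[0, 0, 5, 6, 7], [0, 0, 8, 9, 10]]  (a is missing)

# Threading each result into the next insert keeps both writes live.
chained = subtensor_insert(b, subtensor_insert(a, filled, (0, 0)), (0, 2))
print(chained)  # [[1, 2, 5, 6, 7], [3, 4, 8, 9, 10]]
```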

Option 3 – adding a linalg.concat op

Adding a special Linalg op looks most promising to me.

Input in the tensor world:

%1 = add (...) : tensor< a x ...>
%2 = add (...) : tensor< b x ...>
%3 = linalg.concat(%1, %2) : tensor<(a + b) x ...>

Proposed lowering of a dedicated linalg.concat op (i.e., not an indexed_generic):

%3 = alloc (...) : memref<(a + b) x ...>
%1 = subview %3[][][] : memref<a x ...>
%2 = subview %3[][][] : memref<b x ...>
add(..., %1)
add(..., %2)

The issue is fusion: the split between %1 and %2 propagates into every op the adds fuse with, until fusion stops. Also, this is not real fusion. We end up with two “fused columns”, one for everything that depends on %1 and one for everything that depends on %2.
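The in-place idea behind this lowering can be sketched in Python with one flat buffer and zero-copy views. Assumptions: row-major layout and concatenation along the leading dimension, so each subview is a contiguous slice. `add_into` is a hypothetical stand-in for a producer that writes into its output buffer.

```python
from array import array


def add_into(out, x, y):
    """Hypothetical producer: elementwise add written directly into `out`."""
    for i in range(len(out)):
        out[i] = x[i] + y[i]


a_rows, b_rows, cols = 2, 3, 2
buf = array("i", [0] * ((a_rows + b_rows) * cols))  # alloc: memref<(a + b) x cols>
view = memoryview(buf)
sub1 = view[: a_rows * cols]  # subview receiving the first add's result
sub2 = view[a_rows * cols:]   # subview receiving the second add's result

add_into(sub1, [1, 1, 1, 1], [0, 1, 2, 3])
add_into(sub2, [10] * 6, [0, 1, 2, 3, 4, 5])
print(list(buf))  # [1, 2, 3, 4, 10, 11, 12, 13, 14, 15]
```

Neither producer ever sees a concat: the split is baked into the buffers they write. That is why the split then propagates into anything they fuse with.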

However, having a linalg.concat operation makes sense to me because it fills a gap in Linalg. I have two versions of the linalg.concat op in mind.

Simple linalg.concat op

A linalg.concat op takes a variadic list of tensors and a dimension index, and produces their concatenation along that dimension.

E.g.,

%0 = linalg ... : tensor<2x2xi32>
%1 = linalg ... : tensor<2x3xi32>
%c1 = constant 1 : index
%2 = linalg.concat %0, %1 along dim %c1 : (tensor<2x2xi32>, tensor<2x3xi32>) -> tensor<2x5xi32>
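The shape rule such an op's verifier would need can be sketched as a Python shape function. This is an assumption about the op's semantics, mirroring mhlo.concatenate, and not an existing MLIR API. All operands must have the same rank and agree on every dimension except the concat dimension, whose sizes are summed.

```python
def concat_result_shape(shapes, dim):
    """Result shape of concatenating tensors with `shapes` along `dim`."""
    rank = len(shapes[0])
    if not 0 <= dim < rank:
        raise ValueError(f"dim {dim} out of range for rank {rank}")
    for s in shapes:
        if len(s) != rank:
            raise ValueError("all operands must have the same rank")
        for d in range(rank):
            if d != dim and s[d] != shapes[0][d]:
                raise ValueError(f"size mismatch in non-concat dim {d}")
    result = list(shapes[0])
    result[dim] = sum(s[dim] for s in shapes)
    return tuple(result)


print(concat_result_shape([(2, 2), (2, 3)], 1))  # (2, 5)
```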

During bufferization, you can allocate a buffer for the result and copy each operand into the corresponding subview of that buffer.

%0 = linalg ... : memref<2x2xi32>
%1 = linalg ... : memref<2x3xi32>
%buf = alloc (...) : memref<2x5xi32>
%sub1 = subview %buf [0, 0] [2, 2] [1, 1] : memref<...>
linalg.copy %0, %sub1 : memref<...>
%sub2 = subview %buf [0, 2] [2, 3] [1, 1] : memref<...>
linalg.copy %1, %sub2 : memref<...>
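The offsets and sizes of the subviews above follow from a running offset along the concat dimension. Here is a Python sketch restricted to 2-D operands. Nested lists stand in for memrefs, and `concat_into_buffer` is a hypothetical name.

```python
def concat_into_buffer(operands, dim):
    """Allocate the result, then copy each 2-D operand into a subview placed
    at a running offset along `dim`. Returns the buffer and the subview specs."""
    rows = sum(len(op) for op in operands) if dim == 0 else len(operands[0])
    cols = sum(len(op[0]) for op in operands) if dim == 1 else len(operands[0][0])
    buf = [[0] * cols for _ in range(rows)]  # alloc
    subviews = []
    offset = 0
    for op in operands:
        size = (len(op), len(op[0]))
        off = (offset, 0) if dim == 0 else (0, offset)
        subviews.append((off, size))  # subview %buf [off] [size] [1, 1]
        for i in range(size[0]):      # linalg.copy %op, %sub
            for j in range(size[1]):
                buf[off[0] + i][off[1] + j] = op[i][j]
        offset += size[dim]
    return buf, subviews


buf, subviews = concat_into_buffer([[[1, 2], [3, 4]], [[5, 6, 7], [8, 9, 10]]], 1)
print(subviews)  # [((0, 0), (2, 2)), ((0, 2), (2, 3))]
print(buf)       # [[1, 2, 5, 6, 7], [3, 4, 8, 9, 10]]
```

The computed specs match the two subviews in the snippet: [0, 0] [2, 2] and [0, 2] [2, 3].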

Non-simple linalg.concat op

Essentially, a linalg.concat operation is a “collection” of N linalg operations, where N is the number of operands to the concat, and each of them produces a single result tensor. In its simplest form, each of the concatenated operations is a trivial op that just returns its input.

%result = linalg.concat [%0], [%1], [%2], .... [%n] {
   %r0 = linalg... %0
   %r1 = linalg...  %1
   ...
   %rn = linalg... %n
} : tensor<....>

The square brackets around each operand of the top-level linalg.concat operation indicate the arguments that are “forwarded” to each of the individual linalg operations. So the concat operation takes N lists of values.

Fusing this concat operation with its producers is easy. Let’s say %n is produced by another linalg operation as follows:

%n = linalg %a, %b
%result = linalg.concat ... [%n] {
   ...
   %rn = linalg... %n
}

After fusion, you get:

%result = linalg.concat .... [%a, %b] {
   ...
   %rn = linalg... %a, %b 
}

The new Linalg operation producing %rn is just obtained by fusing the operation producing %n and the old operation producing %rn.
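The producer-fusion rule above amounts to composing functions per region. Here is a Python sketch under simplifying assumptions. Each inner op is a Python callable over nested lists, and concatenation is along dim 0. `Region` and `fuse_producer` are hypothetical names, not MLIR APIs. Fusion replaces %n with %a and %b in the forwarded list and composes the producer into the inner op, and the concatenated result stays the same.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class Region:
    """One inner op of the concat: a callable plus its forwarded operands."""
    fn: Callable
    args: List


def concat_rows(regions):
    """Evaluate every inner op and concatenate the results along dim 0."""
    out = []
    for r in regions:
        out.extend(r.fn(*r.args))
    return out


def fuse_producer(region, producer_fn, producer_args):
    """Swap the region's operand for the producer's operands and compose ops."""
    inner = region.fn
    return Region(lambda *xs: inner(producer_fn(*xs)), list(producer_args))


def identity(t):
    return t


def add(x, y):
    return [[p + q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]


t0 = [[1, 2]]
a, b = [[10, 20], [30, 40]], [[1, 1], [2, 2]]
n = add(a, b)  # %n = linalg %a, %b, outside the concat

before = concat_rows([Region(identity, [t0]), Region(identity, [n])])
fused = fuse_producer(Region(identity, [n]), add, [a, b])
after = concat_rows([Region(identity, [t0]), fused])
print(after)  # [[1, 2], [11, 21], [32, 42]]
```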

When converting to buffers, you can allocate a buffer for the result of the outer linalg.concat operation, then split it according to the concat specification to get the result buffers for the inner linalg operations. This yields a sequence of operations that computes the linalg.concat “in place”. Effectively, the operation would not exist in the buffer world.

(@nicolasvasilache @MaheshRavishankar @asaadaldien @antiagainst @ThomasRaoux for visibility)

I now have a working prototype for option 2: “Adapt mhlo.concatenate lowering to Linalg on tensors” by hanhanW (google/iree Pull Request #4954). Each subtensor_insert now consumes the result of the previous one, so none of them is dead.

Input:

func @concatenate(%arg0: tensor<2x2xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x5xi32> {
  %0 = "mhlo.concatenate"(%arg0, %arg1)
    {dimension = 1 : i64} : (tensor<2x2xi32>, tensor<2x3xi32>) -> tensor<2x5xi32>
  return %0 : tensor<2x5xi32>
}

Output:

func @concatenate(%arg0: tensor<2x2xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x5xi32> {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %c0_i32 = constant 0 : i32
  %c2 = constant 2 : index
  %c3 = constant 3 : index
  %0 = linalg.init_tensor [2, 5] : tensor<2x5xi32>
  %1 = linalg.fill(%0, %c0_i32) : tensor<2x5xi32>, i32 -> tensor<2x5xi32>
  %2 = subtensor_insert %arg0 into %1[%c0, %c0] [%c2, %c2] [%c1, %c1] : tensor<2x2xi32> into tensor<2x5xi32>
  %3 = subtensor_insert %arg1 into %2[%c0, %c2] [%c2, %c3] [%c1, %c1] : tensor<2x3xi32> into tensor<2x5xi32>
  return %3 : tensor<2x5xi32>
}

Thanks Hanhan. Let’s use the prototype you have for now (this is going into IREE anyway). Once we have a good idea of how bufferization fits in, we might be able to pick up one of the better versions of concatenate you described above. Thanks for the detailed documentation; it’s nice to have all the options considered and listed for future reference.