MLIR Support for Sparse Tensors

Note that this is more or less a continuation of the Sparse Representation discussion, but now focused specifically on the progress of MLIR support for sparse tensors.

In the distant past, I pioneered the idea of letting a compiler automatically convert annotated dense linear algebra Fortran to semantically equivalent sparse code, which resulted in the MT1 compiler. But when we started exploring a similar approach in MLIR for sparse tensors, I was pleasantly surprised to find the excellent work of Fredrik Kjolstad et al. in the Tensor Algebra Compiler, which formalizes the automatic transformation of annotated tensor algebra to sparse code in an elegant and rigorous manner.

For Q4 this year, we plan to prototype similar ideas in an independent implementation in MLIR. A PDF with a brief description of our plans is posted below, and we welcome your early input. Big kudos to Fredrik and the whole TACO team for clearly explaining their approach in great detail in a series of very well-written papers, which made planning this project for MLIR super easy!

MLIR Support for Sparse Tensors.pdf (229.7 KB)


Some exciting progress. I finished a first prototype implementation of Fredrik’s “sparse iteration theory”. Next comes the tedious work of mapping the merge lattices onto actual MLIR code.

For example, given a kernel like this (expressed in Linalg as discussed in the doc)

A(i,j) = SUM_k,l B(i,k,l) * C(k,j) * D(l,j)

the sparse compiler computes the following topologically sorted loop indices (where i_0 = i, i_1 = j, etc.):

i_0 < i_2 < i_3 < i_1
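To make the ordering concrete, here is a minimal dense loop skeleton in C that follows this topological order i, k, l, j (the dimension bounds I, K, L, J are placeholders purely for illustration; the actual sparse code of course only iterates over stored entries):

// Dense skeleton illustrating the loop order i < k < l < j chosen by
// the topological sort (i_0 < i_2 < i_3 < i_1).
for (int i = 0; i < I; i++)
  for (int k = 0; k < K; k++)
    for (int l = 0; l < L; l++)
      for (int j = 0; j < J; j++)
        A[i][j] += B[i][k][l] * C[k][j] * D[l][j];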

Furthermore, given a kernel like this

a(i) = (b(i) + c(i)) + (d(i) + e(i))

we get the proper 15 lattice points shown below, one for every nonempty subset of the four sparse operands (here i_xy denotes index x for tensor y, and the subexpression is spelled out after the /).

{
  lat( i_00 i_01 i_02 i_03 / ((tensor_0 + tensor_1) + (tensor_2 + tensor_3)))
  lat( i_00 i_01 i_02 / ((tensor_0 + tensor_1) + tensor_2))
  lat( i_00 i_01 i_03 / ((tensor_0 + tensor_1) + tensor_3))
  lat( i_00 i_02 i_03 / (tensor_0 + (tensor_2 + tensor_3)))
  lat( i_01 i_02 i_03 / (tensor_1 + (tensor_2 + tensor_3)))
  lat( i_00 i_02 / (tensor_0 + tensor_2))
  lat( i_00 i_03 / (tensor_0 + tensor_3))
  lat( i_01 i_02 / (tensor_1 + tensor_2))
  lat( i_01 i_03 / (tensor_1 + tensor_3))
  lat( i_00 i_01 / (tensor_0 + tensor_1))
  lat( i_02 i_03 / (tensor_2 + tensor_3))
  lat( i_00 / tensor_0)
  lat( i_01 / tensor_1)
  lat( i_02 / tensor_2)
  lat( i_03 / tensor_3)
}

In contrast,

a(i) = (b(i) * c(i)) * (d(i) * e(i))

yields the expected singleton lattice, since a conjunction only contributes where all operands are nonzero:

{ 
  lat( i_00 i_01 i_02 i_03 / ((tensor_0 * tensor_1) * (tensor_2 * tensor_3)))
}

Next is mapping this back to an MLIR representation.

Before generating actual MLIR code, I found it easier to get the code-generation structure right by simply emitting pseudo-code to the debugging output.

The kernel:

x(i) = a(i) + b(i) * c(i)

looks as follows when expressed as a generic Linalg operation:

func @generic_op_vec1d(%arga: tensor<?xf32>,
                       %argb: tensor<?xf32>,
                       %argc: tensor<?xf32>) -> tensor<?xf32> {
  %0 = linalg.generic #trait_vec1d
      ins(%arga,
            %argb,
            %argc : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
    ^bb(%a: f32, %b: f32, %c: f32):
      %m = mulf %b, %c : f32
      %s = addf %a, %m : f32
      linalg.yield %s : f32
  } -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

The newly implemented code generation recursively traverses the merge lattices (built from the kernel's SSA form) to emit the following pseudo-code to the debugging output:

while ( i_00 i_01 i_02 )
  if ( i_00 i_01 i_02 )
    tensor_out := (tensor_0 + (tensor_1 * tensor_2));
  if ( i_01 i_02 )
    tensor_out := (tensor_1 * tensor_2);
  if ( i_00 )
    tensor_out := tensor_0;
while ( i_01 i_02 )
  if ( i_01 i_02 )
    tensor_out := (tensor_1 * tensor_2);
while ( i_00 )
  tensor_out := tensor_0;
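As a rough scalar reading of this pseudo-code in C, with TACO-style pos/crd/vals names as an assumption, a zero-initialized dense output x, and the successive ifs read as an if/else-if cascade (as in the actual generated code further below):

#define MIN2(x, y) ((x) < (y) ? (x) : (y))

int pa = a_pos[0], pb = b_pos[0], pc = c_pos[0];
// co-iterate a, b, c while all three still have stored entries
while (pa < a_pos[1] && pb < b_pos[1] && pc < c_pos[1]) {
  int ia = a_crd[pa], ib = b_crd[pb], ic = c_crd[pc];
  int i = MIN2(ia, MIN2(ib, ic));
  if (ia == i && ib == i && ic == i)
    x[i] = a_vals[pa] + b_vals[pb] * c_vals[pc];
  else if (ib == i && ic == i)
    x[i] = b_vals[pb] * c_vals[pc];
  else if (ia == i)
    x[i] = a_vals[pa];
  pa += (ia == i); pb += (ib == i); pc += (ic == i);
}
// runs only if a was exhausted first: co-iterate the conjunction b * c
while (pb < b_pos[1] && pc < c_pos[1]) {
  int ib = b_crd[pb], ic = c_crd[pc];
  int i = MIN2(ib, ic);
  if (ib == i && ic == i)
    x[i] = b_vals[pb] * c_vals[pc];
  pb += (ib == i); pc += (ic == i);
}
// remaining entries of a (b * c contributes nothing more)
while (pa < a_pos[1]) {
  x[a_crd[pa]] = a_vals[pa];
  pa++;
}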

A more elaborate kernel like:

x(i) = a(i) + b(i) + c(i)

generates the following pseudo-code:

while ( i_00 i_01 i_02 )
  if ( i_00 i_01 i_02 )
    tensor_out := (tensor_2 + (tensor_0 + tensor_1));
  if ( i_00 i_02 )
    tensor_out := (tensor_2 + tensor_0);
  if ( i_01 i_02 )
    tensor_out := (tensor_2 + tensor_1);
  if ( i_00 i_01 )
    tensor_out := (tensor_0 + tensor_1);
  if ( i_02 )
    tensor_out := tensor_2;
  if ( i_00 )
    tensor_out := tensor_0;
  if ( i_01 )
    tensor_out := tensor_1;
while ( i_00 i_02 )
  if ( i_00 i_02 )
    tensor_out := (tensor_2 + tensor_0);
  if ( i_02 )
    tensor_out := tensor_2;
  if ( i_00 )
    tensor_out := tensor_0;
while ( i_01 i_02 )
  if ( i_01 i_02 )
    tensor_out := (tensor_2 + tensor_1);
  if ( i_02 )
    tensor_out := tensor_2;
  if ( i_01 )
    tensor_out := tensor_1;
while ( i_00 i_01 )
  if ( i_00 i_01 )
    tensor_out := (tensor_0 + tensor_1);
  if ( i_00 )
    tensor_out := tensor_0;
  if ( i_01 )
    tensor_out := tensor_1;
while ( i_02 )
  tensor_out := tensor_2;
while ( i_00 )
  tensor_out := tensor_0;
while ( i_01 )
  tensor_out := tensor_1;

Since these look good to me (does anyone spot anything strange?), the next step is linking this to an actual MLIR representation of the tensors in either dense or sparse format. This will probably take a while…

I am still very enthusiastic, with lots of progress this week! Most of the actual MLIR codegen is done now too, pending while-loops for true co-iteration, but anything expressible with for-loops already works. Given a kernel “x(i) = a(i) * b(i)”, expressed in Linalg as follows with b annotated as sparse:

#trait_vec_1d = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>,  // b
    affine_map<(i) -> (i)>   // x out
  ],
  sparse = [
    [ "D" ],  // a
    [ "S" ],  // b
    [ "D" ]   // x
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) * b(i)"
}
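Here “D” denotes a dense dimension and “S” a sparse (compressed) dimension. In C, and with TACO-style naming as an assumption, the compressed storage this implies for b looks roughly as follows (the pointer/index/value memrefs in the generated code below roughly correspond to these arrays):

// Compressed storage for the sparse vector b:
//   b_pos[0] .. b_pos[1] : range of stored entries
//   b_crd[p]             : index of the p-th stored entry
//   b_vals[p]            : value of the p-th stored entry
int   *b_pos;    // two entries for a single compressed dimension
int   *b_crd;    // one entry per nonzero
float *b_vals;   // one entry per nonzero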

The sparse compiler component in MLIR now generates the following code. Note that bufferization is done “locally”, since we don’t have a global solution for propagating sparse types yet.

     ... allocs for local bufferization ....
    %6 = load %2[%c0] : memref<?xindex>
    %7 = load %2[%c1] : memref<?xindex>
    scf.for %arg2 = %6 to %7 step %c1 {
      %9 = load %3[%arg2] : memref<?xindex>
      %10 = load %1[%9] : memref<?xf32>
      %11 = load %4[%arg2] : memref<?xf32>
      %12 = mulf %10, %11 : f32
      store %12, %5[%9] : memref<?xf32>
    }

In comparison, TACO would yield the following code:

for (int32_t ib = b1_pos[0]; ib < b1_pos[1]; ib++) {
  int32_t i = b1_crd[ib];
  x_vals[i] = a_vals[i] * b_vals[ib];
}

I hope to send out a CL next week.

Thanks to @ftynse, we now have while-loops in the SCF dialect as well, which greatly sped up development of the sparse compiler part! I expect to send out the first CL any day now.

Given a kernel like

x(i) = a(i) + b(i)

with both a and b annotated as sparse vectors, MLIR now lowers the linalg operation for this into the following (note the elaborate yield structure to get the conditional induction right):

    %7 = load %0[%c0] : memref<?xindex>
    %8 = load %0[%c1] : memref<?xindex>
    %9 = load %3[%c0] : memref<?xindex>
    %10 = load %3[%c1] : memref<?xindex>
    %11:2 = scf.while (%arg2 = %7, %arg3 = %9) : (index, index) -> (index, index) {
      %13 = cmpi "ult", %arg2, %8 : index
      %14 = cmpi "ult", %arg3, %10 : index
      %15 = and %13, %14 : i1
      scf.condition(%15) %arg2, %arg3 : index, index
    } do {
    ^bb0(%arg2: index, %arg3: index):  // no predecessors
      %13 = load %1[%arg2] : memref<?xindex>
      %14 = load %4[%arg3] : memref<?xindex>
      %15 = cmpi "ult", %14, %13 : index
      %16 = select %15, %14, %13 : index
      %17 = cmpi "eq", %13, %16 : index
      %18 = cmpi "eq", %14, %16 : index
      %19 = and %17, %18 : i1
      %20:2 = scf.if %19 -> (index, index) {
        %21 = load %2[%arg2] : memref<?xf32>
        %22 = load %5[%arg3] : memref<?xf32>
        %23 = addf %21, %22 : f32
        store %23, %6[%16] : memref<32xf32>
        %24 = addi %arg2, %c1 : index
        %25 = addi %arg3, %c1 : index
        scf.yield %24, %25 : index, index
      } else {
        %21 = cmpi "eq", %13, %16 : index
        %22:2 = scf.if %21 -> (index, index) {
          %23 = load %2[%arg2] : memref<?xf32>
          store %23, %6[%16] : memref<32xf32>
          %24 = addi %arg2, %c1 : index
          scf.yield %24, %arg3 : index, index
        } else {
          %23 = cmpi "eq", %14, %16 : index
          %24:2 = scf.if %23 -> (index, index) {
            %25 = load %5[%arg3] : memref<?xf32>
            store %25, %6[%16] : memref<32xf32>
            %26 = addi %arg3, %c1 : index
            scf.yield %arg2, %26 : index, index
          } else {
            scf.yield %arg2, %arg3 : index, index
          }
          scf.yield %24#0, %24#1 : index, index
        }
        scf.yield %22#0, %22#1 : index, index
      }
      scf.yield %20#0, %20#1 : index, index
    }
    scf.for %arg2 = %11#0 to %8 step %c1 {
      %13 = load %1[%arg2] : memref<?xindex>
      %14 = load %2[%arg2] : memref<?xf32>
      store %14, %6[%13] : memref<32xf32>
    }
    scf.for %arg2 = %11#1 to %10 step %c1 {
      %13 = load %4[%arg2] : memref<?xindex>
      %14 = load %5[%arg2] : memref<?xf32>
      store %14, %6[%13] : memref<32xf32>
    }

For comparison, this would be TACO’s output (more readable at source level):

  int32_t ia = a1_pos[0];
  int32_t pa1_end = a1_pos[1];
  int32_t ib = b1_pos[0];
  int32_t pb1_end = b1_pos[1];
  while (ia < pa1_end && ib < pb1_end) {
    int32_t ia0 = a1_crd[ia];
    int32_t ib0 = b1_crd[ib];
    int32_t i = TACO_MIN(ia0,ib0);
    if (ia0 == i && ib0 == i) {
      x_vals[i] = a_vals[ia] + b_vals[ib];
    }
    else if (ia0 == i) {
      x_vals[i] = a_vals[ia];
    }
    else {
      x_vals[i] = b_vals[ib];
    }
    ia += (int32_t)(ia0 == i);
    ib += (int32_t)(ib0 == i);
  }
  while (ia < pa1_end) {
    int32_t i = a1_crd[ia];
    x_vals[i] = a_vals[ia];
    ia++;
  }
  while (ib < pb1_end) {
    int32_t i = b1_crd[ib];
    x_vals[i] = b_vals[ib];
    ib++;
  }

For those interested, the first CL is out for review (https://reviews.llvm.org/D90994). Note that a lot still needs to be done: running and testing the generated code comes next, then measuring and improving performance. Also, propagating the sparse bufferization still needs to be addressed.

The first prototype CL has landed, so it was time for a quick sanity check. Starting with a “dense” Linalg expression for a column-wise matrix x vector kernel, I let MLIR automatically generate the sparse code. Then I replaced the “localized buffers” with parameters of a surrounding method and hand-coded the initialization of the sparse storage scheme (with no changes to the core code).

For the test matrix fidapm37.mtx (9152x9152 with 765,944 nonzero elements) in single precision, this runs in 760 microseconds (about 2 GFlops), which is about the same time the Eigen library takes, and it produces the same results. Likewise, for fidap011.mtx (16614x16614 with 1,091,362 nonzero elements), the two run in about 1070 vs. 950 microseconds respectively (a bit over 2 GFlops).

func @matvec(
       %0 : memref<?xindex>,    // pointer array
       %1 : memref<?xindex>,    // index array
       %2 : memref<?xf32>,      // value array
       %3 : memref<9152xf32>,   // input vector
       %4 : memref<9152xf32>) { // output vector
    .... automatically generated sparse code ....
}
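As a scalar reference for what this computes, here is one plausible C reading of the column-wise kernel, assuming the usual pointer/index/value convention (pointer delimits each column, index holds the row of each stored entry, and x starts out zeroed); the actual generated sparse code is elided above:

// x(i) = SUM_j A(i,j) * b(j), traversing A column by column and
// scattering updates into the dense output x.
for (int j = 0; j < 9152; j++)
  for (int p = pointer[j]; p < pointer[j + 1]; p++)
    x[index[p]] += value[p] * b[j];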

A pending CL (https://reviews.llvm.org/D91978) adds some rudimentary parallelization strategies to the sparse compiler for the generated for-loops (outer/all, dense-only/all). This work will combine nicely with the work @ezhulenev is doing on executing parallel loops with a thread pool on the CPU.

And right before the break, the sparse compiler also gained a vectorization strategy, at least for the innermost for-loops. Consider the kernel “x(i) = a(i) * b”, with a annotated sparse and b a scalar:

#trait_s = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>   // x (out)
  ],
  sparse = [
    [ "S" ],  // a
    [ "D" ]   // x
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) * b"
}

  %0 = linalg.generic #trait_s
    ins(%arga: tensor<32xf32>) {
      ^bb(%a: f32):
        %s = mulf %a, %argb : f32
        linalg.yield %s : f32
  } -> tensor<32xf32>
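In scalar form, and again with TACO-style a_pos/a_crd/a_vals names as an assumption, this kernel is simply:

// x(i) = a(i) * b, with a compressed: the stored entries of a are read
// contiguously and the results are scattered into the dense output x.
for (int p = a_pos[0]; p < a_pos[1]; p++)
  x[a_crd[p]] = a_vals[p] * b;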

With AVX512, the default sparse conversion + SIMD will now yield:

    vbroadcastss    zmm0, xmm0
.LBB0_2:                                
    vmovups zmm1, zmmword ptr [rdi + 4*rax]
    vmulps  zmm2, zmm0, zmmword ptr [rsi + 4*rax]
    kxnorw  k1, k0, k0
    vscatterdps     zmmword ptr [rdx + 4*zmm1] {k1}, zmm2
    add     rax, 16
    cmp     rax, rcx
    jl      .LBB0_2

Happy holidays everyone!

Nice :), happy holidays!

For those interested, the vectorization strategy in the sparse compiler is now available for review. The generated masked memory operations interact nicely with subsequent vector dialect optimizations, such as folding constant masks. Note, however, that for masks that remain we will later have to add loop optimizations that split such loops into a full unconditional vector loop and a scalar cleanup loop (preferably as a new, independent, reusable pass).
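As a rough sketch of that future splitting, in C with a vector length of 16 (process_vector and process_scalar are hypothetical placeholders for the unmasked vector body and the scalar body):

// Split a loop over [lo, hi) into a main loop that runs only full
// vector iterations and a scalar cleanup loop for the remainder.
int i = lo;
for (; i + 16 <= hi; i += 16)
  process_vector(i);   // one unmasked vector iteration over [i, i+16)
for (; i < hi; i++)
  process_scalar(i);   // fewer than 16 iterations left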

For example, given a vector kernel in index notation like “x(i) = a(i) * b(i)”, with all vectors marked dense we get the following (the vector length evenly divides the iteration space):

    scf.for %arg2 = %c0 to %c1024 step %c16 {
      %4 = vector.transfer_read %0[%arg2], %cst {masked = [false]}
        : memref<1024xf32>, vector<16xf32>
      %5 = vector.transfer_read %1[%arg2], %cst {masked = [false]}
        : memref<1024xf32>, vector<16xf32>
      %6 = mulf %4, %5 : vector<16xf32>
      vector.transfer_write %6, %2[%arg2] {masked = [false]}
        : vector<16xf32>, memref<1024xf32>
    }

and with vector “a” marked sparse we get the following (masking is required to deal with the symbolic, unknown trip count):

    scf.for %arg2 = %6 to %8 step %c16 {
      %10 = subi %8, %arg2 : index
      %11 = vector.create_mask %10 : vector<16xi1>
      %12 = vector.maskedload %1[%arg2], %11, %cst
        : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
      %13 = vector.maskedload %2[%arg2], %11, %cst_0
        : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
      %14 = vector.gather %3[%12], %11, %cst_0
        : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
           into vector<16xf32>
      %15 = mulf %13, %14 : vector<16xf32>
      vector.scatter %4[%12], %11, %15
        : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
    }

The latter maps to the following AVX512 code.

loop:
        vpbroadcastd    zmm1, esi
        vpcmpgtd        k1, zmm1, zmm0
        vmovdqu32       zmm1 {k1} {z}, zmmword ptr [rdx + 4*rax]
        vxorps  xmm2, xmm2, xmm2
        vmovups zmm3 {k1} {z}, zmmword ptr [rdi + 4*rax]
        kmovq   k2, k1
        vgatherdps      zmm2 {k2}, zmmword ptr [r9 + 4*zmm1]
        vmulps  zmm2, zmm3, zmm2
        vscatterdps     zmmword ptr [r8 + 4*zmm1] {k1}, zmm2
        add     rax, 16
        add     esi, -16
        cmp     rax, rcx
        jl      loop
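For reference, the scalar form of this kernel (again with TACO-style names as an assumption) makes the memory access pattern obvious: the stored entries of a are read contiguously, b is gathered at the stored indices, and x is scattered at those same indices, which is exactly the vgatherdps/vscatterdps pair above:

for (int p = a_pos[0]; p < a_pos[1]; p++) {
  int i = a_crd[p];          // stored index
  x[i] = a_vals[p] * b[i];   // gather from b, scatter into x
}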

Note that the strategies can also be combined. For example, a kernel like “x(i) = SUM_j A(i,j) * B(i,j)”, with A’s innermost dimension marked as sparse, can yield the following parallel-vector-reduction loop nest:

scf.parallel (%arg3) = (%c0) to (%c512) step (%c1) {
  %6 = load %0[%arg3] : memref<?xi32>
  %7 = index_cast %6 : i32 to index
  %8 = addi %arg3, %c1 : index
  %9 = load %0[%8] : memref<?xi32>
  %10 = index_cast %9 : i32 to index
  %11 = scf.for %arg4 = %7 to %10 step %c16
             iter_args(%arg5 = %cst_0) -> (vector<16xf32>) {
    %14 = subi %10, %arg4 : index
    %15 = vector.create_mask %14 : vector<16xi1>
    %16 = vector.maskedload %1[%arg4], %15, %cst
     : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
    %17 = vector.maskedload %2[%arg4], %15, %cst_0
     : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
    %18 = vector.gather %3[%16], %15, %cst_0
     : memref<512x1024xf32>, vector<16xi32>,
          vector<16xi1>, vector<16xf32> into vector<16xf32>
    %19 = mulf %17, %18 : vector<16xf32>
    %20 = addf %arg5, %19 : vector<16xf32>
    scf.yield %20 : vector<16xf32>
  }
  %12 = load %4[%arg3] : memref<512xf32>
  %13 = vector.reduction "add", %11, %12 : vector<16xf32> into f32
  store %13, %4[%arg3] : memref<512xf32>
  scf.yield
}
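For reference, here is a rough scalar C form of this combined kernel (A_pos/A_crd/A_vals names are assumptions): the outer loop over the 512 rows is the one that gets parallelized, and the inner loop over the stored entries of each row is the one that gets vectorized into the 16-wide reduction above.

for (int i = 0; i < 512; i++) {                    // parallel loop
  float sum = x[i];                                // reduction init
  for (int p = A_pos[i]; p < A_pos[i + 1]; p++)    // vectorized loop
    sum += A_vals[p] * B[i][A_crd[p]];
  x[i] = sum;
}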