
MLIR Support for Sparse Tensors

Note that this is more or less a continuation of the Sparse Representation discussion, but now focused specifically on the progress of MLIR support for sparse tensors.

In the distant past, I pioneered the idea of letting a compiler automatically convert annotated dense linear algebra in Fortran into semantically equivalent sparse code, which resulted in the MT1 compiler. When we started exploring a similar approach for sparse tensors in MLIR, I was pleasantly surprised to find the excellent work of Fredrik Kjolstad et al. on the Tensor Algebra Compiler (TACO), which formalizes the automatic transformation of annotated tensor algebra into sparse code in an elegant and rigorous manner.

For Q4 this year, we plan to prototype similar ideas in an independent implementation in MLIR. A PDF with a brief description of our plans is posted below, and we welcome your early input. Big kudos to Fredrik and the whole TACO team for explaining their approach in great detail in a series of very well-written papers, which made planning this project for MLIR super easy!

MLIR Support for Sparse Tensors.pdf (229.7 KB)


Some exciting progress: I finished a first prototype implementation of Fredrik’s “sparse iteration theory”. Next comes the tedious work of mapping the merge lattices onto actual MLIR code.

For example, given a kernel like this (expressed in Linalg as discussed in the doc)

A(i,j) = SUM_k,l B(i,k,l) * C(k,j) * D(l,j)

the sparse compiler computes the following topologically sorted loop indices (where i_0 = i, i_1 = j, etc.):

i_0 < i_2 < i_3 < i_1
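
This order is just a topological sort of the per-tensor constraints: assuming each tensor's indices must be visited in its dimension order, B(i,k,l) imposes i_0 < i_2 < i_3, while C(k,j) and D(l,j) impose i_2 < i_1 and i_3 < i_1, respectively. For illustration only, here is a minimal standalone C++ sketch of such a sort (hypothetical encoding, not the actual prototype code):

#include <cstdio>
#include <utility>
#include <vector>

// Kahn-style topological sort over the loop-order constraints implied
// by visiting each tensor's indices in dimension order (illustrative
// sketch only, not the actual prototype code).
int main() {
  const int n = 4; // loop indices i_0..i_3 (i_0=i, i_1=j, i_2=k, i_3=l)
  // An edge (u,v) means "loop u must be nested outside loop v":
  // B(i,k,l) yields i_0 < i_2 < i_3, C(k,j) yields i_2 < i_1,
  // and D(l,j) yields i_3 < i_1.
  std::vector<std::pair<int, int>> edges = {{0, 2}, {2, 3}, {2, 1}, {3, 1}};
  std::vector<std::vector<int>> succ(n);
  std::vector<int> indeg(n, 0);
  for (auto [u, v] : edges) { succ[u].push_back(v); indeg[v]++; }
  std::vector<int> work;
  for (int u = 0; u < n; u++)
    if (indeg[u] == 0) work.push_back(u);
  while (!work.empty()) {
    int u = work.back(); work.pop_back();
    printf("i_%d ", u);
    for (int v : succ[u])
      if (--indeg[v] == 0) work.push_back(v);
  }
  printf("\n"); // prints: i_0 i_2 i_3 i_1
}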

Furthermore, given a kernel like this

a(i) = (b(i) + c(i)) + (d(i) + e(i))

we get the expected 15 lattice points shown below, one for each nonempty subset of the four operands (here i_xy denotes index x for tensor y, and the subexpression is spelled out after the /):

{
  lat( i_00 i_01 i_02 i_03 / ((tensor_0 + tensor_1) + (tensor_2 + tensor_3)))
  lat( i_00 i_01 i_02 / ((tensor_0 + tensor_1) + tensor_2))
  lat( i_00 i_01 i_03 / ((tensor_0 + tensor_1) + tensor_3))
  lat( i_00 i_02 i_03 / (tensor_0 + (tensor_2 + tensor_3)))
  lat( i_01 i_02 i_03 / (tensor_1 + (tensor_2 + tensor_3)))
  lat( i_00 i_02 / (tensor_0 + tensor_2))
  lat( i_00 i_03 / (tensor_0 + tensor_3))
  lat( i_01 i_02 / (tensor_1 + tensor_2))
  lat( i_01 i_03 / (tensor_1 + tensor_3))
  lat( i_00 i_01 / (tensor_0 + tensor_1))
  lat( i_02 i_03 / (tensor_2 + tensor_3))
  lat( i_00 / tensor_0)
  lat( i_01 / tensor_1)
  lat( i_02 / tensor_2)
  lat( i_03 / tensor_3)
}

In contrast,

a(i) = (b(i) * c(i)) * (d(i) * e(i))

yields the singleton lattice, as expected:

{ 
  lat( i_00 i_01 i_02 i_03 / ((tensor_0 * tensor_1) * (tensor_2 * tensor_3)))
}
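
If I read the TACO papers correctly, these counts follow from two simple construction rules: a multiplication takes the conjunction of its operands' lattices (only the pairwise unions of points survive, since a product contributes only where all operands have a value), whereas an addition takes the disjunction (the pairwise unions plus the points of each operand alone). Hence four added tensors yield 2^4 - 1 = 15 points, while four multiplied tensors yield exactly one. A small standalone C++ sketch (hypothetical types, not the prototype's data structures) that reproduces the counts:

#include <cstdio>
#include <set>
#include <vector>

// A lattice point is the set of sparse operands that still contribute;
// a merge lattice is simply a collection of such points.
using Point = std::set<int>;
using Lattice = std::vector<Point>;

static Point merge(const Point &a, const Point &b) {
  Point r = a;
  r.insert(b.begin(), b.end());
  return r;
}

// Conjunction (multiplication): only the pairwise unions survive.
static Lattice conj(const Lattice &l1, const Lattice &l2) {
  Lattice r;
  for (const Point &p1 : l1)
    for (const Point &p2 : l2)
      r.push_back(merge(p1, p2));
  return r;
}

// Disjunction (addition): the pairwise unions plus each operand's own points.
static Lattice disj(const Lattice &l1, const Lattice &l2) {
  Lattice r = conj(l1, l2);
  r.insert(r.end(), l1.begin(), l1.end());
  r.insert(r.end(), l2.begin(), l2.end());
  return r;
}

int main() {
  Lattice b = {{0}}, c = {{1}}, d = {{2}}, e = {{3}};
  // a(i) = (b(i) + c(i)) + (d(i) + e(i)) yields 15 points.
  printf("%zu\n", disj(disj(b, c), disj(d, e)).size());
  // a(i) = (b(i) * c(i)) * (d(i) * e(i)) yields 1 point.
  printf("%zu\n", conj(conj(b, c), conj(d, e)).size());
}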

Next is mapping this back to an MLIR representation.

Before emitting actual MLIR code, I found it easier to get the code generation structure right by first emitting pseudo-code on the debugging output.

The kernel:

x(i) = a(i) + b(i) * c(i)

expressed as a generic Linalg operation, looks as follows:

func @generic_op_vec1d(%arga: tensor<?xf32>,
                       %argb: tensor<?xf32>,
                       %argc: tensor<?xf32>) -> tensor<?xf32> {
  %0 = linalg.generic #trait_vec1d
      ins(%arga, %argb, %argc
          : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
    ^bb(%a: f32, %b: f32, %c: f32):
      %0 = mulf %b, %c : f32
      %1 = addf %a, %0 : f32
      linalg.yield %1 : f32
  } -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

The newly implemented code generation recursively traverses the merge lattices (built from the SSA form) to emit the following pseudo-code on the debugging output:

while ( i_00 i_01 i_02 )
  if ( i_00 i_01 i_02 )
    tensor_out := (tensor_0 + (tensor_1 * tensor_2));
  if ( i_01 i_02 )
    tensor_out := (tensor_1 * tensor_2);
  if ( i_00 )
    tensor_out := tensor_0;
while ( i_01 i_02 )
  if ( i_01 i_02 )
    tensor_out := (tensor_1 * tensor_2);
while ( i_00 )
  tensor_out := tensor_0;
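
The shape of this output follows directly from the lattices: one while loop per lattice point (largest first), and within each loop an if case for every lattice point whose index set is a subset of the loop's, where a loop driven by a single index needs no case distinction. A standalone C++ sketch of that emission step (hypothetical types, not the actual implementation, which works on SSA form) that reproduces the pseudo-code above:

#include <algorithm>
#include <cstdio>
#include <set>
#include <string>
#include <vector>

// One lattice point: the indices that drive the case, plus the
// subexpression evaluated for it (kept as a string in this sketch).
struct LatPoint {
  std::set<int> iters;
  std::string expr;
};

static void printIters(const std::set<int> &s) {
  for (int i : s) printf(" i_0%d", i);
  printf(" ");
}

// Emit one while loop per lattice point (the vector is kept in
// decreasing order of index-set size), and inside each loop an if
// case for every point whose indices are a subset of the loop's.
static void emit(const std::vector<LatPoint> &lattice) {
  for (const LatPoint &w : lattice) {
    printf("while (");
    printIters(w.iters);
    printf(")\n");
    if (w.iters.size() == 1) { // single index: no case distinction needed
      printf("  tensor_out := %s;\n", w.expr.c_str());
      continue;
    }
    for (const LatPoint &c : lattice) {
      if (!std::includes(w.iters.begin(), w.iters.end(),
                         c.iters.begin(), c.iters.end()))
        continue;
      printf("  if (");
      printIters(c.iters);
      printf(")\n    tensor_out := %s;\n", c.expr.c_str());
    }
  }
}

int main() {
  // x(i) = a(i) + b(i) * c(i)
  emit({{{0, 1, 2}, "(tensor_0 + (tensor_1 * tensor_2))"},
        {{1, 2}, "(tensor_1 * tensor_2)"},
        {{0}, "tensor_0"}});
}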

A more elaborate kernel like:

x(i) = a(i) + b(i) + c(i)

generates the following:

while ( i_00 i_01 i_02 )
  if ( i_00 i_01 i_02 )
    tensor_out := (tensor_2 + (tensor_0 + tensor_1));
  if ( i_00 i_02 )
    tensor_out := (tensor_2 + tensor_0);
  if ( i_01 i_02 )
    tensor_out := (tensor_2 + tensor_1);
  if ( i_00 i_01 )
    tensor_out := (tensor_0 + tensor_1);
  if ( i_02 )
    tensor_out := tensor_2;
  if ( i_00 )
    tensor_out := tensor_0;
  if ( i_01 )
    tensor_out := tensor_1;
while ( i_00 i_02 )
  if ( i_00 i_02 )
    tensor_out := (tensor_2 + tensor_0);
  if ( i_02 )
    tensor_out := tensor_2;
  if ( i_00 )
    tensor_out := tensor_0;
while ( i_01 i_02 )
  if ( i_01 i_02 )
    tensor_out := (tensor_2 + tensor_1);
  if ( i_02 )
    tensor_out := tensor_2;
  if ( i_01 )
    tensor_out := tensor_1;
while ( i_00 i_01 )
  if ( i_00 i_01 )
    tensor_out := (tensor_0 + tensor_1);
  if ( i_00 )
    tensor_out := tensor_0;
  if ( i_01 )
    tensor_out := tensor_1;
while ( i_02 )
  tensor_out := tensor_2;
while ( i_00 )
  tensor_out := tensor_0;
while ( i_01 )
  tensor_out := tensor_1;

Since these look good to me (does anyone spot something strange?), the next step is to link this to an actual MLIR representation of the tensors in either dense or sparse format. This will probably take a while…