[RFC] Reshape Ops Restructuring


Motivation

Currently, there are 5 operations that can reshape tensors/memrefs in MLIR Core.
They have a lot in common, and it would be useful to investigate whether some of them can be merged, removed or restructured. For example, the reshape operations in the Linalg dialect (linalg.reshape and linalg.tensor_reshape) are similar to memref.reshape and might be moved to the MemRef and Tensor dialects, respectively. Currently, the Tensor dialect does not have an operation that can reshape tensors.

Overview

memref.reshape

Both the source and the destination can be ranked or unranked. The shape argument is a 1D memref with a static or dynamic size. The data is never copied or modified. The element type and the address space stay the same. Only the identity layout is supported.

Example

// Reshape statically-shaped memref.
%dst = memref.reshape %src(%shape)
            : (memref<4x1xf32>, memref<1xi32>) to memref<4xf32>

// Reshape dynamically-shaped memref.
%dst = memref.reshape %src(%shape)
            : (memref<?xf32>, memref<?xi32>) to memref<*xf32>

Note that using a 1D memref to pass the shape is quite unnatural when its size is static. It would make sense to extend this op, or to create another one, with a Variadic<Index> shape argument (see Discourse question).

memref.reshape can be transformed to memref.reinterpret_cast for statically-sized shape args.

memref.reshape, when lowered to LLVM, computes strides for the given shape and creates a new memref descriptor.
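As a rough illustration of the arithmetic this lowering has to perform, here is a minimal Python sketch that computes row-major (identity-layout) strides for a shape. It models only the math, not the actual LLVM lowering code, and the name `identity_strides` is hypothetical.

```python
from typing import List

def identity_strides(shape: List[int]) -> List[int]:
    """Row-major strides for an identity-layout memref of the given shape.

    The innermost dimension has stride 1; every outer stride is the
    product of all sizes to its right.
    """
    strides = [1] * len(shape)
    running = 1
    for dim in reversed(range(len(shape))):
        strides[dim] = running
        running *= shape[dim]
    return strides

# Reshaping memref<4x1xf32> to memref<4xf32> keeps the buffer and only
# recomputes the descriptor metadata: sizes [4], strides [1].
print(identity_strides([4]))        # [1]
print(identity_strides([2, 3, 4]))  # [12, 4, 1]
```

Because only the identity layout is supported, the strides are fully determined by the shape operand, which is why no data movement is needed.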

memref.reinterpret_cast

The source memref can be ranked/unranked, the destination is always ranked. The data is never copied or modified. The element type and the address space stay the same. The operation can override sizes/strides/offset of a memref.

Example

memref.reinterpret_cast %ranked to
     offset: [0],
     sizes: [%size0, 10],
     strides: [1, %stride1]
   : memref<?x?xf32> to memref<?x10xf32, offset: 0, strides: [1, ?]>

memref.reinterpret_cast %unranked to
     offset: [%offset],
     sizes: [%size0, %size1],
     strides: [%stride0, %stride1]
   : memref<*xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>

As with memref.reshape, this operation creates a new memref descriptor when lowered to LLVM.
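To make "creates a new memref descriptor" concrete, here is a small Python model in which a ranked descriptor is reduced to a shared buffer plus offset, sizes and strides, and reinterpret_cast returns a new descriptor over the same buffer. This is a sketch under simplifying assumptions: the real LLVM descriptor also carries separate allocated and aligned pointers, and `MemRefDescriptor`, `reinterpret_cast` and `load` are hypothetical names.

```python
from dataclasses import dataclass
from typing import List

@dataclass
class MemRefDescriptor:
    buffer: List[float]   # shared underlying storage, never copied
    offset: int
    sizes: List[int]
    strides: List[int]

    def load(self, *indices: int) -> float:
        # Linearized address: offset + sum(index_i * stride_i).
        assert len(indices) == len(self.sizes), "rank mismatch"
        for i, size in zip(indices, self.sizes):
            assert 0 <= i < size, "index out of bounds"
        linear = self.offset + sum(i * s for i, s in zip(indices, self.strides))
        return self.buffer[linear]

def reinterpret_cast(src: MemRefDescriptor, offset: int,
                     sizes: List[int], strides: List[int]) -> MemRefDescriptor:
    # Only the metadata changes; the result aliases src.buffer.
    return MemRefDescriptor(src.buffer, offset, list(sizes), list(strides))

data = [float(x) for x in range(12)]
m = MemRefDescriptor(data, 0, [3, 4], [4, 1])
# View the same 12 elements as a row-major 2x6 memref.
v = reinterpret_cast(m, offset=0, sizes=[2, 6], strides=[6, 1])
print(v.load(1, 2))          # 8.0
print(v.buffer is m.buffer)  # True
```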

linalg.reshape

This operation uses reassociation maps to contract the shape of the source or the destination memref type, whichever has the higher rank. The data is never copied or modified.

Example

// Dimension collapse (i, j) -> i' and k -> k'
%1 = linalg.reshape %0 [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
     memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>

// Dimension expansion i -> (i', j') and (k) -> (k')
%1 = linalg.reshape %0 [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
     memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>

Unlike memref.reshape, this op captures more information about the resulting shape, which can be beneficial for canonicalization/folding. Also, it does not support unranked memrefs.
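Because the reassociation states which source dimensions are merged, the collapsed result shape can be inferred statically. A minimal sketch, assuming dynamic sizes are represented as None (standing in for `?`) and the reassociation as a list of lists of source dimension indices; `infer_collapsed_shape` is a hypothetical helper, not an MLIR API.

```python
from typing import List, Optional

Dim = Optional[int]  # None stands for a dynamic size ('?')

def infer_collapsed_shape(src_shape: List[Dim],
                          reassociation: List[List[int]]) -> List[Dim]:
    result: List[Dim] = []
    for group in reassociation:
        size: Dim = 1
        for d in group:
            if src_shape[d] is None:
                size = None  # any dynamic member makes the whole group dynamic
                break
            size *= src_shape[d]
        result.append(size)
    return result

print(infer_collapsed_shape([4, 5, 6], [[0, 1], [2]]))     # [20, 6]
print(infer_collapsed_shape([None, 5, 6], [[0, 1], [2]]))  # [None, 6]
```

With memref.reshape, by contrast, the same information is hidden inside the shape operand, so a folder cannot recover it without inspecting that buffer.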

There was a plan to allow non-contiguous groupings, e.g. permutations, to be performed by this op. If we allow only contiguous groupings, then this operation will model only the expansion/collapsing of dimensions, without a permutation.
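A contiguous grouping is one whose groups, read in order, enumerate 0, 1, ..., rank - 1 without gaps or reordering. A minimal check could look like the sketch below; `is_contiguous_reassociation` is hypothetical, and it assumes every group is non-empty (it does not handle the special case of collapsing to a rank-0 type).

```python
from typing import List

def is_contiguous_reassociation(reassociation: List[List[int]], rank: int) -> bool:
    """True if the groups split [0, rank) into consecutive, ordered runs."""
    expected = 0
    for group in reassociation:
        if not group:
            return False  # empty groups are rejected (an assumption of this sketch)
        for d in group:
            if d != expected:
                return False
            expected += 1
    return expected == rank

print(is_contiguous_reassociation([[0, 1], [2]], 3))  # True
print(is_contiguous_reassociation([[1, 0], [2]], 3))  # False: permutation
print(is_contiguous_reassociation([[0, 2], [1]], 3))  # False: non-contiguous
```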

linalg.tensor_reshape

Surprisingly enough, it is the same as linalg.reshape, but it operates on tensors instead of memrefs.

vector.reshape

This operation is quite different from the ones above. It does not require the number of elements in the input and output vectors to match, i.e. there can be undefined elements.

Outside MLIR Core

Many more reshape operations exist outside MLIR Core, e.g. mhlo.reshape, mhlo.dynamic_reshape, tf.reshape and tfl.reshape.

Potential restructuring

Changes to linalg.reshape

The operation uses reassociation maps to model the grouping of dimensions. Using affine maps for this is overkill and hard to read. We can use lists of integer dimension indices instead, e.g. [[0, 1], [2]].
Depending on the relationship between the source and destination ranks, linalg.reshape either expands or collapses dimensions. This is also confusing, because one has to compare the ranks of the input and output types before understanding which type the reassociation maps apply to. We can split linalg.reshape into memref.collapse_shape and memref.expand_shape.

// Dimension collapse (i, j) -> i' and k -> k'
%1 = memref.collapse_shape %0 [[0, 1], [2]] :
     memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>

// Dimension expansion i -> (i', j') and (k) -> (k')
%1 = memref.expand_shape %0 [[0, 1], [2]] :
     memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
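To check that the index lists carry the same information as the affine maps used by linalg.reshape, here is a small sketch that regenerates the maps from a list of groups. It assumes the groups are contiguous, so the map domain is the rank of the higher-rank type; `to_affine_maps` is a hypothetical name and the output uses the same shorthand as the linalg.reshape examples above, with d0, d1, ... instead of i, j, k.

```python
from typing import List

def to_affine_maps(reassociation: List[List[int]]) -> List[str]:
    # The maps range over all dimensions of the higher-rank type.
    rank = sum(len(group) for group in reassociation)
    dims = ", ".join(f"d{i}" for i in range(rank))
    maps = []
    for group in reassociation:
        results = ", ".join(f"d{i}" for i in group)
        maps.append(f"({dims}) -> ({results})")
    return maps

print(to_affine_maps([[0, 1], [2]]))
# ['(d0, d1, d2) -> (d0, d1)', '(d0, d1, d2) -> (d2)']
```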

A linalg.reshape that neither collapses nor expands dimensions (i.e. one whose reassociation only permutes them) is a memref.transpose.
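Like the other view-like ops here, a transpose only rewrites descriptor metadata: it permutes sizes and strides. A minimal sketch, assuming permutation[k] names the source dimension that becomes result dimension k (matching a permutation map such as (i, j) -> (j, i)); `transpose_descriptor` is a hypothetical helper.

```python
from typing import List, Tuple

def transpose_descriptor(sizes: List[int], strides: List[int],
                         permutation: List[int]) -> Tuple[List[int], List[int]]:
    assert sorted(permutation) == list(range(len(sizes))), "not a permutation"
    # Result dimension k reads source dimension permutation[k].
    new_sizes = [sizes[p] for p in permutation]
    new_strides = [strides[p] for p in permutation]
    return new_sizes, new_strides

# A row-major 3x4 view becomes a 4x3 view with swapped strides.
print(transpose_descriptor([3, 4], [4, 1], [1, 0]))  # ([4, 3], [1, 4])
```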

Changes to memref.reshape

memref.reshape can be extended to accept Variadic<Index> shape arguments. This gets rid of the buffer allocated to store the shape, which is unnecessary in most cases where the rank of the output is known. It would also make it easy to transform memref.collapse_shape and memref.expand_shape into memref.reshape.

%dst = memref.reshape %src(%0, 4, %1)
            : (memref<4x50xf32>, index, index) to memref<?x4x?xf32>
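With shape arguments that mix static sizes and SSA values, a verifier can only decide element-count equality when every size is known; otherwise the check has to be deferred to runtime or left as UB. A minimal sketch, with None standing for a dynamic size and `static_reshape_check` as a hypothetical name (a more refined version could still test divisibility when only some sizes are dynamic):

```python
from math import prod
from typing import List, Optional

def static_reshape_check(src_shape: List[Optional[int]],
                         dst_shape: List[Optional[int]]) -> Optional[bool]:
    """True/False when decidable statically, None when it depends on runtime values."""
    if None in src_shape or None in dst_shape:
        return None
    return prod(src_shape) == prod(dst_shape)

print(static_reshape_check([4, 50], [None, 4, None]))  # None
print(static_reshape_check([4, 50], [2, 4, 25]))       # True
```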

Changes to tensor dialect

tensor.reshape, tensor.expand_shape, tensor.collapse_shape should be added. Should tensor.transpose be added as well?

Transformations/Conversions diagram



Great, thanks for writing this RFC @pifon2a, this will be a welcome restructuring!

I’ll just raise the points we already discussed offline for the sake of discussion.
One aspect that still seems up in the air is whether each of the memref ops is:

  • Assumed correct by construction + UB (+ possibly runtime checks), or
  • Verified at construction time, and if so, under what rules.

At this point in spacetime, linalg.reshape wants to statically guarantee that the reshape does not move data and that the tensor types are compatible so it performs type inference and verification. But there are cases where it cannot do better than fail. Some of the annoying cases are:

  • splitting a ? in memref.expand_shape is ambiguous; we may want to add dynamic operands to help there, with a “must divide or UB” requirement.
  • handling of memref<?x?x?xf32, offset:?, strides:[?, ?, 1]> leads to ambiguous cases. What we would want is something like memref<axbxcxf32, offset:?, strides:[bxc, c, 1]>, which captures contiguity at the group boundaries: reshapes are only valid at contiguous boundaries (otherwise we need div and mod and lower-level representations).
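To make the first point concrete: when an expanded group contains at most one ?, its size can be recovered at runtime by division under a “must divide” requirement, while two or more ? in one group cannot be split without extra operands. A minimal sketch, assuming None stands for ?, the groups are contiguous and in order, and static sizes are positive; `infer_expanded_sizes` is hypothetical and raises ValueError where the op would otherwise have UB or need a runtime check.

```python
from typing import List, Optional

def infer_expanded_sizes(src_sizes: List[int],
                         result_shape: List[Optional[int]],
                         reassociation: List[List[int]]) -> List[int]:
    """Resolve the dynamic result sizes of an expand from runtime source sizes."""
    sizes: List[int] = []
    for src_size, group in zip(src_sizes, reassociation):
        dynamic = [d for d in group if result_shape[d] is None]
        static_product = 1
        for d in group:
            if result_shape[d] is not None:
                static_product *= result_shape[d]
        if len(dynamic) > 1:
            raise ValueError("ambiguous: more than one '?' in a group")
        if src_size % static_product != 0:
            raise ValueError("'must divide' requirement violated")
        if not dynamic and static_product != src_size:
            raise ValueError("static sizes do not match the source size")
        for d in group:  # contiguous groups, so appending keeps result order
            size = result_shape[d]
            sizes.append(src_size // static_product if size is None else size)
    return sizes

# memref<12x5xf32> expanded with [[0, 1], [2]] into memref<?x4x5xf32>.
print(infer_expanded_sizes([12, 5], [None, 4, 5], [[0, 1], [2]]))  # [3, 4, 5]
```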

Related to the transformation/conversion diagram, collapse_shape and expand_shape also compose with certain transformations in linalg-land. This is beyond the scope of this refactoring, but it is the reason for keeping track of the expansion/contraction list<list<index>>.
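As one small example of why keeping the list<list<index>> around pays off, two consecutive collapses can be folded into a single collapse by composing their reassociations. This is an illustrative Python sketch, not MLIR's implementation; `compose_collapses` is a hypothetical name.

```python
from typing import List

def compose_collapses(first: List[List[int]],
                      second: List[List[int]]) -> List[List[int]]:
    """Reassociation of collapse(collapse(x, first), second) as a single collapse.

    `second` indexes dimensions of the intermediate type, each of which is one
    group of `first`; the composed group concatenates those groups.
    """
    return [[d for mid in group for d in first[mid]] for group in second]

# 4-D -> 2-D via [[0, 1], [2, 3]], then 2-D -> 1-D via [[0, 1]].
print(compose_collapses([[0, 1], [2, 3]], [[0, 1]]))  # [[0, 1, 2, 3]]
```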

Looking forward to seeing this restructuring land!

Could you expand a bit on the transformations/conversions plan? For instance, will any passes be removed (if so, which ones), and will new passes be added, or will functionality be added to existing passes?

I’ll admit, I don’t fully understand the sequence of passes necessary to go from linalg.tensor_reshape to memref.reinterpret_cast (see, for example, Bufferizing linalg.tensor_reshape?), but I love the idea of making it simpler.

There is no plan to remove any passes/patterns/conversions. There will be new patterns to bufferize tensor.reshape to memref.reshape, to convert memref.expand_shape to memref.reinterpret_cast and so on.

To go from linalg.tensor_reshape (which, after the restructuring, becomes tensor.expand_shape or tensor.collapse_shape) to memref.reinterpret_cast, one has to bufferize it first, e.g. tensor.collapse_shape to memref.collapse_shape. After that, it can be transformed into memref.reshape if possible, or directly into memref.reinterpret_cast otherwise.

I commented on the revision, but I’m repeating it here. I think the reassociation attribute can be useful in general outside of Linalg (for example, the sparse attribute also has the same concept!).
Could we get a first-class attribute to model this instead of arrays of arrays of attributes?
Ideally, we would expose a “real” ReassociationAttr class with a safer API (both accessors and constructors).
It could then also be stored more densely in memory: basically, the attribute could be defined in TableGen with just two ArrayRef<int64_t>. It may also be useful for verification purposes to store the rank as well (alternatively, it can expose a verifyInvariant method that takes a rank, and operations that have this attribute would call it in their verifiers).
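One possible reading of “two ArrayRef<int64_t>” is a CSR-like layout: a flat array of dimension indices plus an array of group end offsets. The Python model below sketches that layout together with a verifyInvariant-style check that takes a rank. The encoding and all names (`DenseReassociation`, `from_groups`, `verify_invariants`) are assumptions for illustration, not the attribute that was eventually implemented.

```python
from dataclasses import dataclass
from typing import List

@dataclass
class DenseReassociation:
    indices: List[int]     # all dimension indices, group after group
    group_ends: List[int]  # exclusive end offset of each group in `indices`

    @staticmethod
    def from_groups(groups: List[List[int]]) -> "DenseReassociation":
        indices: List[int] = []
        ends: List[int] = []
        for group in groups:
            indices.extend(group)
            ends.append(len(indices))
        return DenseReassociation(indices, ends)

    def groups(self) -> List[List[int]]:
        out: List[List[int]] = []
        start = 0
        for end in self.group_ends:
            out.append(self.indices[start:end])
            start = end
        return out

    def verify_invariants(self, rank: int) -> bool:
        # The last group end must cover the whole flat index array.
        last_end = self.group_ends[-1] if self.group_ends else 0
        if last_end != len(self.indices):
            return False
        # Every group must be non-empty (strictly increasing end offsets).
        starts = [0] + self.group_ends[:-1]
        if any(end <= start for start, end in zip(starts, self.group_ends)):
            return False
        # Groups must be contiguous, ordered, and cover [0, rank).
        return self.indices == list(range(rank))

r = DenseReassociation.from_groups([[0, 1], [2]])
print(r.indices, r.group_ends)  # [0, 1, 2] [2, 3]
print(r.verify_invariants(3))   # True
```

Storing the rank in the attribute would let it verify itself; the `rank` parameter here models the alternative where each operation passes its own rank in its verifier.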


Having used ArrayAttr<ArrayAttr<I64Attr>>, I couldn’t agree more that it would be nice to make it a first-class citizen. I will send an RFC and implement it after I am done with this restructuring.


+1 for better APIs, thanks for the suggestion!