# [RFC] Reshape Ops Restructuring
## Motivation
Currently, there are 5 operations that can reshape tensors/memrefs in MLIR Core. They have a lot in common, and it is worth investigating whether some of them can be merged, removed, or restructured. For example, the reshape operations in the Linalg dialect (`linalg.reshape` and `linalg.tensor_reshape`) are similar to `memref.reshape` and might be moved to the MemRef and Tensor dialects respectively. Currently, the Tensor dialect does not have an operation that can reshape tensors.
## Overview
### `memref.reshape`
Both the source and destination can be ranked or unranked. The shape argument is a 1-D memref with a static or dynamic size. The data is never copied or modified. The element type and the address space stay the same. Only the identity layout is supported.
Example

```mlir
// Reshape statically-shaped memref.
%dst = memref.reshape %src(%shape)
    : (memref<4x1xf32>, memref<1xi32>) to memref<4xf32>

// Reshape dynamically-shaped memref.
%dst = memref.reshape %src(%shape)
    : (memref<?xf32>, memref<?xi32>) to memref<*xf32>
```
Note that using a 1-D memref to pass the shape is quite unnatural when the size is static. It would make sense to extend this op, or to create another one, with a `Variadic<Index>` shape argument (see Discourse question).

`memref.reshape` can be transformed to `memref.reinterpret_cast` for statically-sized shape arguments.

`memref.reshape`, when lowered to LLVM, computes strides for the given shape and creates a new memref descriptor.
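As an illustration of what that lowering computes, here is a minimal Python sketch (not the actual MLIR lowering code) of deriving the identity-layout strides that go into the new descriptor:

```python
def identity_strides(shape):
    """Compute row-major (identity-layout) strides for a shape, as the
    LLVM lowering of memref.reshape does when filling in the sizes and
    strides of the new memref descriptor."""
    strides = [1] * len(shape)
    # The innermost dimension has stride 1; each outer stride is the
    # product of all inner dimension sizes.
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

print(identity_strides([4]))        # [1]
print(identity_strides([2, 3, 4]))  # [12, 4, 1]
```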
### `memref.reinterpret_cast`
The source memref can be ranked or unranked; the destination is always ranked. The data is never copied or modified. The element type and the address space stay the same. The operation can override the sizes/strides/offset of a memref.
Example

```mlir
memref.reinterpret_cast %ranked to
    offset: [0],
    sizes: [%size0, 10],
    strides: [1, %stride1]
    : memref<?x?xf32> to memref<?x10xf32, offset: 0, strides: [1, ?]>

memref.reinterpret_cast %unranked to
    offset: [%offset],
    sizes: [%size0, %size1],
    strides: [%stride0, %stride1]
    : memref<*xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
```
As with `memref.reshape`, this operation creates a new memref descriptor when lowered to LLVM.
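As a rough picture of that descriptor manipulation, here is a Python sketch; the field names mirror the LLVM-level memref descriptor (allocated pointer, aligned pointer, offset, sizes, strides) but the types and helper are illustrative assumptions, not MLIR code:

```python
from dataclasses import dataclass, replace

@dataclass
class MemRefDescriptor:
    # Fields of the LLVM-level memref descriptor.
    allocated_ptr: int
    aligned_ptr: int
    offset: int
    sizes: list
    strides: list

def reinterpret_cast(desc, offset, sizes, strides):
    """reinterpret_cast keeps the pointers to the underlying data but
    overrides the offset/sizes/strides, producing a new descriptor."""
    return replace(desc, offset=offset, sizes=sizes, strides=strides)

src = MemRefDescriptor(0x1000, 0x1000, 0, [4], [1])
dst = reinterpret_cast(src, 0, [2, 2], [2, 1])
print(dst.sizes, dst.strides)  # [2, 2] [2, 1]
```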
### `linalg.reshape`
This operation uses reassociation maps to collapse the dimensions of whichever of the source and destination memref types has the higher rank. The data is never copied or modified.
Example

```mlir
// Dimension collapse (i, j) -> i' and k -> k'
%1 = linalg.reshape %0 [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
    memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>

// Dimension expansion i -> (i', j') and (k) -> (k')
%1 = linalg.reshape %0 [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
    memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
```
Unlike `memref.reshape`, this op captures more information about the resulting shape, which can be beneficial for canonicalization/folding. It also does not support unranked memrefs.

There was a plan to allow non-contiguous groupings, e.g. permutations, to be performed by this op. If we allow only contiguous groupings, then this operation will model only expansion/collapsing of dimensions with a permutation.
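To make the grouping semantics concrete, here is a small Python sketch (an illustration, not MLIR code) that applies contiguous reassociation groups, written in the index-list form proposed below (e.g. `[[0, 1], [2]]`), to a static shape:

```python
from math import prod

def collapse_shape(shape, reassociation):
    """Collapse `shape` according to contiguous reassociation groups:
    each group of source dimensions becomes one result dimension whose
    size is the product of the grouped sizes.
    E.g. shape [2, 3, 4] with groups [[0, 1], [2]] -> [6, 4]."""
    return [prod(shape[d] for d in group) for group in reassociation]

print(collapse_shape([2, 3, 4], [[0, 1], [2]]))  # [6, 4]
```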
### `linalg.tensor_reshape`
Surprisingly enough, it is the same as `linalg.reshape`, but on tensors instead of memrefs.
### `vector.reshape`
This operation is quite different from the operations above. It does not require the numbers of elements in the input and output vectors to coincide, i.e. there can be undefined elements.
### Outside MLIR Core
There are many more reshape operations outside MLIR Core, e.g. `mhlo.reshape`, `mhlo.dynamic_reshape`, `tf.reshape`, `tfl.reshape`.
## Potential restructuring
### Changes to `linalg.reshape`
The operation uses reassociation maps to model the grouping of dimensions. Using affine maps for this is overkill and hard to read; we can use `IndexType` arguments to represent the groupings instead.

Depending on the relationship between the source and destination ranks, `linalg.reshape` either expands or collapses dimensions. This is also quite confusing, because one has to compare the ranks of the input and output types before understanding which type the reassociation maps were applied to. We can split `linalg.reshape` into `memref.collapse_shape` and `memref.expand_shape`.
```mlir
// Dimension collapse (i, j) -> i' and k -> k'
%1 = memref.collapse_shape %0 [[0, 1], [2]] :
    memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>

// Dimension expansion i -> (i', j') and (k) -> (k')
%1 = memref.expand_shape %0 [[0, 1], [2]] :
    memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
```
A `linalg.reshape` that neither collapses nor expands dimensions is a `memref.transpose`.
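Like the reshape ops, a transpose of a memref is metadata-only: it permutes the sizes and strides in the descriptor without touching the data. A minimal Python sketch (illustrative, not MLIR code):

```python
def transpose_descriptor(sizes, strides, permutation):
    """Permute the sizes and strides of a memref-like descriptor.
    E.g. an (i, j) -> (j, i) transpose uses permutation [1, 0]."""
    new_sizes = [sizes[p] for p in permutation]
    new_strides = [strides[p] for p in permutation]
    return new_sizes, new_strides

# Transposing a row-major 2x3 memref: element (j, i) of the result
# aliases element (i, j) of the source; no data moves.
print(transpose_descriptor([2, 3], [3, 1], [1, 0]))  # ([3, 2], [1, 3])
```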
### Changes to `memref.reshape`
`memref.reshape` can be extended to accept `Variadic<Index>` shape arguments, getting rid of the buffer allocated to store the shape, which is unnecessary in most cases where the rank of the output is known. It would also make it easy to transform `memref.collapse_shape` and `memref.expand_shape` into `memref.reshape`.
```mlir
%dst = memref.reshape %src(%0, 4, %1)
    : (memref<4x50xf32>, index, index, index) to memref<?x4x?xf32>
```
### Changes to the `tensor` dialect

`tensor.reshape`, `tensor.expand_shape`, and `tensor.collapse_shape` should be added. Should `tensor.transpose` be added as well?