# [RFC] Reshape Ops Restructuring

## Motivation

Currently, there are 5 operations that can reshape tensors/memrefs in MLIR Core.

They have a lot in common, and it would be useful to investigate whether it is possible to merge/remove/restructure some of them. For example, the reshape operations in the Linalg dialect (`linalg.reshape` and `linalg.tensor_reshape`) are similar to `memref.reshape` and might be moved to the MemRef and Tensor dialects respectively. Currently, the Tensor dialect does not have an operation that can reshape tensors.

## Overview

### `memref.reshape`

Both source and destination can be unranked or ranked. Shape argument is a 1D memref with a static or dynamic size. The data is never copied or modified. The element type and the address space stay the same. Only the identity layout is supported.

Example

```
// Reshape statically-shaped memref.
%dst = memref.reshape %src(%shape)
: (memref<4x1xf32>, memref<1xi32>) to memref<4xf32>
// Reshape dynamically-shaped memref.
%dst = memref.reshape %src(%shape)
: (memref<?xf32>, memref<?xi32>) to memref<*xf32>
```

Note that using a 1D memref to pass the shape is quite unnatural when the size is static. It would make sense to extend this op or create another one that has a `Variadic<Index>` shape argument (see the Discourse question).

`memref.reshape` can be transformed to `memref.reinterpret_cast` for statically-sized shape arguments.

`memref.reshape`, when lowered to LLVM, computes strides for the given shape and creates a new memref descriptor.
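As a rough illustration, here is a minimal Python sketch (not the actual lowering code) of the stride computation for an identity layout: each stride is the product of the trailing sizes.

```python
def identity_strides(shape):
    # Identity (row-major) layout: the stride of the last dim is 1, and
    # each earlier stride is the product of all trailing sizes.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

# A memref<2x3x4xf32> with identity layout has strides [12, 4, 1].
print(identity_strides([2, 3, 4]))
```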

### `memref.reinterpret_cast`

The source memref can be ranked/unranked, the destination is always ranked. The data is never copied or modified. The element type and the address space stay the same. The operation can override sizes/strides/offset of a memref.

Example

```
memref.reinterpret_cast %ranked to
offset: [0],
sizes: [%size0, 10],
strides: [1, %stride1]
: memref<?x?xf32> to memref<?x10xf32, offset: 0, strides: [1, ?]>
memref.reinterpret_cast %unranked to
offset: [%offset],
sizes: [%size0, %size1],
strides: [%stride0, %stride1]
: memref<*xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
```

As with `memref.reshape`, this operation creates a new memref descriptor when lowered to LLVM.

### `linalg.reshape`

This operation uses reassociation maps to contract the shape of the source memref type or the destination type, depending on which one has the higher rank. The data is never copied or modified.

Example

```
// Dimension collapse (i, j) -> i' and k -> k'
%1 = linalg.reshape %0 [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>
// Dimension expansion i -> (i', j') and (k) -> (k')
%1 = linalg.reshape %0 [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
```

Unlike `memref.reshape`, this op captures more information about the resulting shape, which can be beneficial for canonicalization/folding. Also, it does not support unranked memrefs.

There was a plan to allow non-contiguous groupings, e.g. permutations, to be performed by this op. If we allow only contiguous groupings, then this operation will model only expansion/collapsing of dimensions without a permutation.
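To make the reassociation semantics concrete, here is a minimal Python sketch (a hypothetical helper, not MLIR code): the maps above group source dimensions as `[[0, 1], [2]]`, and each collapsed dimension is the product of its group.

```python
from math import prod

def apply_reassociation(shape, groups):
    # Each result dimension is the product of one contiguous group of
    # source dimensions, e.g. [[0, 1], [2]] collapses (i, j, k) to (i*j, k).
    return [prod(shape[d] for d in group) for group in groups]

# Collapsing memref<2x3x4xf32> with [[0, 1], [2]] yields shape 6x4.
print(apply_reassociation([2, 3, 4], [[0, 1], [2]]))
```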

### `linalg.tensor_reshape`

Surprisingly enough, it is the same as `linalg.reshape`, except that it operates on tensors instead of memrefs.

### `vector.reshape`

This operation is quite different from the operations above. It does not require the number of elements in the input and output vectors to coincide, i.e. there can be undefined elements.

### Outside MLIR Core

There are many more reshape operations outside MLIR Core, e.g. `mhlo.reshape`, `mhlo.dynamic_reshape`, `tf.reshape`, `tfl.reshape`.

## Potential restructuring

### Changes to `linalg.reshape`

The operation uses reassociation maps to model the grouping of dimensions. Using affine maps for this is overkill and hard to read. We can use `IndexType` arguments to represent the grouping instead.

Depending on the relationship between the source and destination ranks, `linalg.reshape` either expands or collapses dimensions. This is also quite confusing, because one has to compare the ranks of the input and output types before understanding which type the reassociation maps were applied to. We can split `linalg.reshape` into `memref.collapse_shape` and `memref.expand_shape`.

```
// Dimension collapse (i, j) -> i' and k -> k'
%1 = memref.collapse_shape %0 [[0, 1], [2]] :
memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>
// Dimension expansion i -> (i', j') and (k) -> (k')
%1 = memref.expand_shape %0 [[0, 1], [2]] :
memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
```
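One motivation for the split is that the two directions carry different amounts of information: collapsing fully determines the result shape, while expanding does not (a dimension of size 6 can split into 2x3 or 3x2). A minimal Python sketch (a hypothetical verifier, not MLIR code) of the expansion-side check, assuming the reassociation groups index the higher-rank result type:

```python
from math import prod

def is_valid_expansion(src_shape, groups, result_shape):
    # Expanding is ambiguous (6 -> 2x3 or 3x2), so the result shape is
    # part of the op; each source dim must equal the product of the
    # result dims in its group.
    return len(src_shape) == len(groups) and all(
        src_shape[i] == prod(result_shape[d] for d in group)
        for i, group in enumerate(groups))

print(is_valid_expansion([6, 4], [[0, 1], [2]], [2, 3, 4]))  # True
print(is_valid_expansion([6, 4], [[0, 1], [2]], [2, 4, 4]))  # False
```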

The `linalg.reshape` that neither collapses nor expands dimensions is a `memref.transpose`.

### Changes to `memref.reshape`

`memref.reshape` can be extended to accept `Variadic<Index>` shape arguments to get rid of the buffer allocated to store the shape, which is unnecessary in most cases when the rank of the output is known. It would also make it easy to transform `memref.collapse_shape` and `memref.expand_shape` to `memref.reshape`.

```
%dst = memref.reshape %src(%0, 4, %1)
: (memref<4x50xf32>, index, index, index) to memref<?x4x?xf32>
```

### Changes to the `tensor` dialect

`tensor.reshape`, `tensor.expand_shape`, and `tensor.collapse_shape` should be added. Should `tensor.transpose` be added as well?