LLVM Discussion Forums

[StandardOps] Add `ReshapeMemrefCastOp` to Standard

For unranked code generation of cwise ops we need an operation that can take an UnrankedMemRefType and flatten it into a 1-D memref, i.e. memref<?xelement_type>. After the element-wise operation is computed, the result has to be reshaped back to its original shape. So what we need is a reshape operation that supports unranked memrefs.

The operation will take a shape memref as a second argument. Having variadic IndexType size arguments is not possible, since we don’t know how many there are in the dynamically-ranked case. Using a TensorType to pass the shape is also not optimal, because we would then have to care about allocating it.
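To make the intended flow concrete, here is a sketch of the flatten/compute/restore pattern, assuming the shape buffers `%flat_shape` (containing the total element count) and `%orig_shape` (containing the original sizes) have already been filled in; those names are hypothetical:

```mlir
// 1. Flatten the unranked operand to 1-D.
%flat = reshape_memref_cast %arg(%flat_shape)
         : (memref<*xf32>, memref<1xindex>) to memref<?xf32>
// 2. Apply the element-wise computation on the 1-D memref
//    (e.g. inside a loop nest or a library call).
// 3. Restore the original dynamic rank and shape.
%result = reshape_memref_cast %flat(%orig_shape)
           : (memref<?xf32>, memref<?xindex>) to memref<*xf32>
```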

Here are some example applications of the op:

a. Both source and destination types are ranked memref types.

```mlir
// Reshape statically-shaped memref.
%dst = reshape_memref_cast %src(%shape)
         : (memref<4x1xf32>, memref<1xindex>) to memref<4xf32>
```

b. Source type is ranked, destination type is unranked.

```mlir
// Reshape dynamically-shaped 1D memref.
%dst = reshape_memref_cast %src(%shape)
         : (memref<?xf32>, memref<?xindex>) to memref<*xf32>
```

c. Source type is unranked, destination type is ranked.

```mlir
// Flatten unranked memref.
%dst = reshape_memref_cast %src(%shape)
         : (memref<*xf32>, memref<1xindex>) to memref<?xf32>
```

Implementation: https://reviews.llvm.org/D83068

Thanks for the proposal, I have several questions regarding the semantics:

  • Does this operation imply only cast semantics, i.e. the underlying data is never copied/moved?
  • In any case, the specification needs a more precise description of the supported cases. Can it cast memref<4x2xf32> to memref<2x4xf32>? To memref<8xf32>? To memref<16xf32>? To memref<64xi8>? Can it cast memref<*xf32> to any other memref rank than 1 or the rank of the underlying memref?
  • What is the behavior when the types don’t match dynamically, regardless of what “match” means?
  • Can this be used to replace linalg.reshape?
  • How does it work for strided memrefs?

Independently, I don’t think I understand the issue with variadic arguments. It looks like only the source or the target memref can be unranked. Therefore, the number of sizes seems always relatable to the rank of the other memref. Even if it wasn’t, we still can have a list of variadic arguments with dynamic length.

  • Does this operation imply only cast semantics, i.e. the underlying data is never copied/moved?

No, the data is never copied.

  • In any case, the specification needs a more precise description of the supported cases.

I will add it to the PR.

Can it cast memref<4x2xf32> to memref<2x4xf32>?

Yes.

To memref<8xf32>?

Yes.

To memref<16xf32>?

Yes.

To memref<64xi8>?

No, I will add a constraint that the element type must match.

Can it cast memref<*xf32> to any other memref rank than 1 or the rank of the underlying memref?

Yes.

  • What is the behavior when the types don’t match dynamically, regardless of what “match” means?

Undefined, like for std.memref_cast when an unranked memref is cast back.
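For reference, a sketch of the analogous std.memref_cast round trip: casting to unranked always succeeds, while casting back to a ranked type whose static shape does not match the runtime shape is undefined behavior.

```mlir
// Erase the rank; this cast always succeeds.
%u = memref_cast %arg : memref<4xf32> to memref<*xf32>
// Casting back is only defined if the runtime type actually is memref<4xf32>.
%r = memref_cast %u : memref<*xf32> to memref<4xf32>
```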

  • Can this be used to replace linalg.reshape?

I am not sure about that. linalg.reshape provides a way to conveniently map/group dimensions, and it can be fused with other Linalg ops, but it probably doesn’t work with unranked memrefs. I consider ReshapeMemrefCastOp a lower-level op.

  • How does it work for strided memrefs?

It doesn’t. I will add more verification checks.

Independently, I don’t think I understand the issue with variadic arguments. It looks like only the source or the target memref can be unranked.

I have already removed this from the PR. You can now cast unranked->unranked.

Therefore, the number of sizes seems always relatable to the rank of the other memref.

The number of sizes is always the rank of the output memref in both ranked/unranked variants.
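To illustrate the unranked-to-unranked variant mentioned above, using the same syntax as the earlier examples:

```mlir
// Unranked-to-unranked reshape: the runtime rank of %dst equals the
// number of elements stored in the %shape buffer.
%dst = reshape_memref_cast %src(%shape)
         : (memref<*xf32>, memref<?xindex>) to memref<*xf32>
```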

Even if it wasn’t, we still can have a list of variadic arguments with dynamic length.

I didn’t know that. Is there an example of such op?

In my mind memref always has a strong tie with buffers and dynamic memory allocation, so having a better way to express those shape dimensions would be nice. Using multiple variadic operands is one way to go; I’m also wondering whether it makes sense to use what we have in the shape dialect, given that I think it’s for exactly that purpose. :slight_smile: This would make the standard dialect depend on the shape dialect for types, but I guess the shape dialect, as a general utility, will be in the dependency graph anyway once it’s ready for prime time.

I think Alex meant AttrSizedOperandSegments. You can search the codebase to find its users.

Not really, Alex didn’t mean AttrSizedOperandSegments.

Alex proposed to support passing shape as a memref and as a list of indices to get rid of materializing shape at least for statically-ranked shapes.

```tablegen
let arguments = (ins AnyMemRef:$input,
                     Variadic<AnyTypeOf<[AnyMemRef, Index]>>:$shape_description);
```
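With such a signature, the op could hypothetically be written in either form (sketched syntax, not what the current PR implements):

```mlir
// Shape as individual index operands, for a statically-ranked result.
%a = reshape_memref_cast %src(%d0, %d1)
       : (memref<?xf32>, index, index) to memref<?x?xf32>
// Shape as a memref, for a dynamically-ranked result.
%b = reshape_memref_cast %src(%shape)
       : (memref<?xf32>, memref<?xindex>) to memref<*xf32>
```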

Also, since we have so many cast ops for memrefs (MemRefCastOp, ReshapeMemRefCastOp, the upcoming ReinterpretMemRefCastOp), we might want to understand which set of ops is actually needed.