[RFC][Standard] Memref cast ops

TF code generation for unranked and dynamically-shaped inputs required additional memref cast ops that do not modify the underlying data but change the memref descriptor.

Example: CWise unary op
For unranked code generation for cwise ops we need an operation that can take an UnrankedMemRefType and flatten it into a 1D memref, i.e. memref<?xelement_type>. After the unary operation is computed, the result has to be reshaped back to its original shape.
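A rough sketch of this flatten-compute-restore pattern, using the LMHLO cast ops described below (value names and the shape operands are illustrative):

%flat = lmhlo.reshape_memref_cast %input(%flat_shape)
    : (memref<*xf32>, memref<1xi32>) to memref<?xf32>
// ... run the cwise unary op on the 1D buffer, producing %flat_result ...
%result = lmhlo.reshape_memref_cast %flat_result(%orig_shape)
    : (memref<?xf32>, memref<?xi32>) to memref<*xf32>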

StaticMemRefCastOp, DynamicMemRefCastOp and ReshapeMemRefCastOp are now part of LMHLO. Conversion patterns to LLVM also live in the TF code base. Since these operations are quite generic, they might provide value to a much wider audience.

Moreover, there is already MemRefCastOp in Standard used for type erasure/inference. It might become confusing to have so many cast ops.

Before moving the LMHLO memref cast operations to Standard, I suggest revisiting their design.

Previous discussions:

Current State Overview

MemRefCastOp in Standard

Argument

  • AnyRankedOrUnrankedMemRef:$source

Application

  • Erasing rank
memref_cast %buf : memref<4x?xf32> to memref<*xf32>
  • Erasing static offset, size, stride information
memref_cast %buf : memref<12x4xf32, offset:5, strides: [4, 1]> to
                   memref<12x4xf32, offset:?, strides: [?, ?]> 
  • Cast unranked memref to a concrete shape
memref_cast %buf : memref<*xf32> to memref<4x?xf32>
  • Cast dynamically-shaped memref to a concrete shape
memref_cast %buf : memref<12x4xf32, offset:?, strides: [?, ?]> to
                   memref<12x4xf32, offset:5, strides: [4, 1]>

Remarks
The name of the operation does not reflect what is actually happening. It might become a problem for IR readability when more memref cast ops are added to the Standard dialect. This operation can be used both for type erasure and for type inference. I think it would be helpful to split this op into MemRefErasingCastOp and MemRefInferringCastOp.

StaticMemRefCastOp in LMHLO

Argument

  • MemRefOf<AnyType>:$source

Application

  • Modify the offset, sizes and strides of a statically shaped memref.
lmhlo.static_memref_cast %buf : memref<1x5xf32> ->
                                memref<5xf32, offset: 2, strides: [1]>

DynamicMemRefCastOp in LMHLO

Arguments:

  • MemRefOf<AnyType>:$source
  • Variadic<Index>:$sizes
  • Variadic<Index>:$strides

Application

  • Modify the sizes and strides of a memref using values computed at runtime.
lmhlo.dynamic_memref_cast %buf(%size0, %size1)[%stride0, %stride1]
    : memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>

Remark
This operation is almost equivalent to lmhlo.static_memref_cast, except that it takes the sizes and strides from the $sizes and $strides arguments instead of extracting them from the result type.
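To illustrate the correspondence (a hedged sketch, assuming %size0/%size1 hold 12/4 and %stride0/%stride1 hold 4/1), the two casts below produce the same descriptor; only the encoding of sizes and strides differs:

lmhlo.static_memref_cast %buf : memref<12x4xf32> ->
                                memref<12x4xf32, offset: 0, strides: [4, 1]>
lmhlo.dynamic_memref_cast %buf(%size0, %size1)[%stride0, %stride1]
    : memref<12x4xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>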

ReshapeMemRefCastOp in LMHLO

Arguments:

  • AnyRankedOrUnrankedMemRef:$source
  • MemRefRankOf<[AnySignlessInteger], [1]>:$shape

Application

  • Modify sizes of an unranked/ranked memref when the $shape argument has static length
// Reshape statically-shaped memref.
lmhlo.reshape_memref_cast %buf(%shape)
      : (memref<4x1xf32>, memref<2xi32>) to memref<2x2xf32>
      
// Reshape unranked memref.
lmhlo.reshape_memref_cast %buf0(%shape0)
     : (memref<*xf32>, memref<1xi32>) to memref<?xf32>
  • Modify sizes of an unranked/ranked memref when the $shape argument has dynamic length
// Reshape dynamically-shaped 1D memref.
lmhlo.reshape_memref_cast %buf(%shape)
    : (memref<?xf32>, memref<?xi32>) to memref<*xf32>

// Reshape unranked memref.
lmhlo.reshape_memref_cast %buf0(%shape0)
    : (memref<*xf32>, memref<?xi32>) to memref<*xf32>

Remark
This operation is capable of reshaping any ranked/unranked memref with a dynamic or static shape length, outputting an unranked or ranked memref. Note that it is not possible to get a ranked memref result if the $shape argument has a dynamic length.
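For example, the following would be invalid, because the result rank cannot be derived from a dynamically-sized shape operand:

// Invalid: rank of the result is unknown.
lmhlo.reshape_memref_cast %buf(%shape)
    : (memref<?xf32>, memref<?xi32>) to memref<?x?xf32>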

When this operation is converted to LLVM dialect, strides are computed in the ConversionPattern. At the moment, the strides model the identity map only.

Also note that dynamic_memref_cast is almost identical to reshape_memref_cast when $shape has static length. There are two minor differences that can be eliminated in the new design: reshape_memref_cast supports unranked $source input, and dynamic_memref_cast has a $strides argument.
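To illustrate (a hedged sketch, assuming the values stored in %shape match %size0/%size1 and the strides encode the identity layout), the following two casts compute equivalent descriptors:

lmhlo.reshape_memref_cast %buf(%shape)
    : (memref<?x?xf32>, memref<2xi32>) to memref<?x?xf32>
lmhlo.dynamic_memref_cast %buf(%size0, %size1)[%stride0, %stride1]
    : memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>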

Proposed ops

I propose to split std.memref_cast into std.memref_erasing_cast and std.memref_inferring_cast, depending on whether the cast adds information to the result type or erases it.

lmhlo.static_memref_cast, lmhlo.dynamic_memref_cast and the static-length shape case of lmhlo.reshape_memref_cast will be merged into std.memref_reinterpret_cast (thanks to Uday for the op name).

The dynamic-length shape case of lmhlo.reshape_memref_cast will become std.memref_dynamic_reshape.

I am open to suggestions, especially when it comes to naming and printing of the ops.

MemRefErasingCastOp in Standard

Argument

  • MemRefOf<AnyType>:$source

Application

  • Erasing rank
memref_erasing_cast %buf : memref<4x?xf32> to memref<*xf32>
  • Erasing static offset, size, stride information
memref_erasing_cast %buf : memref<12x4xf32, offset:5, strides: [4, 1]> to
                           memref<12x4xf32, offset:?, strides: [?, ?]> 

MemRefInferringCastOp in Standard

Argument

  • AnyRankedOrUnrankedMemRef:$source

Application

  • Cast unranked memref to a concrete shape
memref_inferring_cast %buf : memref<*xf32> to memref<4x?xf32>
  • Cast dynamically-shaped memref to a concrete shape
memref_inferring_cast %buf : memref<12x4xf32, offset:?, strides: [?, ?]> to
                             memref<12x4xf32, offset:5, strides: [4, 1]>

MemRefReinterpretCastOp in Standard

Arguments:

  • MemRefOf<AnyType>:$source
  • Index:$offset
  • Variadic<Index>:$sizes
  • Variadic<Index>:$strides

Application

  • Modify the offset, sizes and strides of an unranked/ranked memref using values computed at runtime.
memref_reinterpret_cast %buf [
  offset: %offset, sizes: %size0, %size1, strides: %stride0, %stride1
] : memref<?x?xf32> -> memref<?x?xf32, offset: ?, strides: [?, ?]>

memref_reinterpret_cast %unranked [
  offset: %offset, sizes: %size0, %size1, strides: %stride0, %stride1
] : memref<*xf32> -> memref<?x?xf32, offset: ?, strides: [?, ?]>

MemRefDynamicReshapeOp in Standard

Arguments:

  • AnyRankedOrUnrankedMemRef:$source
  • MemRefRankOf<[AnySignlessInteger], [1]>:$shape // strictly memref<?xIntegerType>

Application

  • Modify the sizes of an unranked/ranked memref using a shape computed at runtime (strides are derived for the identity layout).
// Reshape dynamically-shaped memref.
memref_dynamic_reshape %src(%shape)
    : (memref<?x?xf32>, memref<?xi32>) to memref<*xf32>

// Reshape unranked memref.
memref_dynamic_reshape %src(%shape)
    : (memref<*xf32>, memref<?xi32>) to memref<*xf32>

This looks really good to me. It’s nice and important to have these ops move from lmhlo into the standard dialect. For memref_cast splitting, I can’t immediately tell the pros and cons of splitting into two (inferring and erasing). Can you say something about the benefits of the split?

For memref_dynamic_reshape, the shape could be a static memref, right? In that case, would you require the result type to be ranked? For example:

memref_dynamic_reshape %src(%shape)
    : (memref<?x?xf32>, memref<3xi32>) to memref<?x?x?xf32>

(One can always convert to an unranked one with memref_erasing_cast; so it may be good to keep this op less loaded.)

On a minor note, should the shape be an i32 memref or an index memref? It’s really modifying the shape as opposed to sizes and strides, right? (The latter is derived.)

For memref_dynamic_reshape, the shape could be a static memref, right? In that case, would you require the result type to be ranked?

That’s an interesting one. We can either allow static memref shape arguments and then have a transformation that converts memref_dynamic_reshape to memref_reinterpret_cast, or we can disallow static memref shapes completely, so that users construct memref_reinterpret_cast from the start. If we allow static memref shapes, then we should rename the op to just memref_reshape.
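A hypothetical sketch of that transformation, assuming %shape is a constant buffer holding [2, 2] (the constants %c0/%c1/%c2 are illustrative index values):

memref_dynamic_reshape %src(%shape)
    : (memref<4xf32>, memref<2xi32>) to memref<2x2xf32>
// ==> could be rewritten into
memref_reinterpret_cast %src [
  offset: %c0, sizes: %c2, %c2, strides: %c2, %c1
] : memref<4xf32> -> memref<2x2xf32, offset: 0, strides: [2, 1]>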

On a minor note, should the shape be an i32 memref or an index memref? It’s really modifying the shape as opposed to sizes and strides, right? (The latter is derived.)

memref<index> does not exist in MLIR: since index can have a different size depending on the platform, a memref<index> cannot be allocated. Only tensor<index> exists. tensor<index> is not used here in order to skip buffer allocation for it.

For memref_cast splitting, I can’t immediately tell the pros and cons of splitting into two (inferring and erasing). Can you say something about the benefits of the split?

Mostly readability and less confusion. If these ops are split, nobody would be able to accidentally write something like

memref_cast %buf : memref<4x?xf32> to memref<?x4xf32>

Also, having memref_cast and memref_reinterpret_cast in Standard at the same time is a bit confusing.


That was one of the reasons it didn’t exist originally, but there was a proposal/discussion on adding it. I thought that went through, but I guess not.

Regarding the dynamic_reshape one, enforcing a fully dynamic shape as the $shape operand could be weird. You’d want instead to allow representing static shapes and canonicalize/simplify them away. Rewriting to reinterpret_cast sounds better.

Splitting memref_cast sounds fine to me.

Sounds good. Will do.

Thanks for looking into unifying these! I agree with the general direction but have a question:

MemRefReinterpretCastOp and MemRefDynamicReshapeOp seem to serve the same purpose, but the dynamic version is more restricted. One could express MemRefDynamicReshapeOp as a MemRefDynamicReinterpretCastOp if one were also to compute offsets, strides, etc. Could you explain the motivation for not doing this?

Also, should we have a static MemRefReshapeOp that also avoids the explicit offset and stride computations?

Final remark, I think MemRefInferringCastOp is misleading in its name. The op does not infer a new type. Instead, it is more like a refinement to a specified type. MemRefRefiningCastOp maybe? But that is at the bike-shedding level.

MemRefReinterpretCastOp and MemRefDynamicReshapeOp seem to serve the same purpose, but the dynamic version is more restricted. One could express MemRefDynamicReshapeOp as a MemRefDynamicReinterpretCastOp if one were also to compute offsets, strides, etc. Could you explain the motivation for not doing this?

I did not want to have MemRefDynamicReinterpretCastOp because it would require an additional memref<?> argument for the strides and hence an allocation for it. Also, I am not sure I see how it would be used. Will users ever want to specify non-default strides for unranked memrefs?

Also, should we have a static MemRefReshapeOp that also avoids the explicit offset and stride computations?

Do you mean the thing that was discussed above, i.e. allowing memref_reshape to take a static shape and then transforming it to memref_reinterpret_cast in that case?

Final remark, I think MemRefInferringCastOp is misleading in its name. The op does not infer a new type. Instead, it is more like a refinement to a specified type. MemRefRefiningCastOp maybe? But that is at the bike-shedding level.

I like the MemRefRefiningCastOp name. Thanks!

I agree: we can add it when it is needed, and then MemRefDynamicReshapeOp can lower to it. It would still make sense to have the latter form for convenience and also to encode that the result has an identity mapping, which would otherwise likely be hard to infer.

Did not read the full thread. I agree with @bondhugula. I would allow static shapes, mostly for rewrite convenience. When gradually specializing code to static shapes, it would be helpful if memref_reshape accepted a static shape and was then canonicalized away.

+1 to memref_refining_cast in place of using inferring.

Thanks for attempting to bring some order to these things!

I have some general comments:

  • I’m not entirely convinced by the separation between erasing/refining casts. We actually do support memref_cast memref<4x?xf32> to memref<?x4xf32> and exercise it in tests, for example. I think a new operation is clearly justified when we want to pattern match it and do different things. Maybe we need to push this separation further: “reinterpret cast” seems also erasing for shape information, but may be seen as “refining” when it introduces rank into hitherto unranked memref.
  • I’m not a fan of the “reinterpret cast” name. Because of C++, it feels like it would support casting element types, but it is not the case.
  • Let’s be proactive and define the semantics of these ops wrt aliasing, because it looks like we can lose the benefits of the structured type really fast here. For example, I would consider it a reasonable limitation that, after the cast, one can access either exactly the same set of elements as before, or a subset. Otherwise, given a memref<?xf32, offset: ? (actually 0), strides: [2]>, we can obtain a memref<?xf32, offset: ? (actually 1), strides: [2]> that points to completely different data (see the sketch after this list). IMO, such behavior should be undefined and the compiler should be able to make some assumptions about the subsets of data being accessed. Otherwise, we risk losing lots of static information real quick and should just use pointers instead…
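A minimal sketch of the offending case (value names illustrative, using the proposed reinterpret syntax): starting from a descriptor over the even elements, a single cast yields a descriptor over the odd elements, i.e. completely disjoint data, without changing the type:

// %evens : memref<?xf32, offset: ?, strides: [2]>, offset is actually 0
%odds = memref_reinterpret_cast %evens [
  offset: %c1, sizes: %n, strides: %c2
] : memref<?xf32, offset: ?, strides: [2]> -> memref<?xf32, offset: ?, strides: [2]>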

I would attempt to build a design space of various conversions, and look at which dimensions of the space materialize as new operations. For example, the dimensions may be:

  • adds/removes information;
  • if adds information, is the addition static or dynamic;
  • affects sizes/offset/strides/rank.

We can then say:

  • if I want to remove information about sizes, use op1; (memref<42x42xf32> -> memref<?x?xf32>)
  • if I want to add static information about rank, use op2; (memref<*xf32> -> memref<?x?xf32>)
  • if I want to add dynamic information about rank, use op3; (memref<*xf32>, index... -> memref<*xf32>)
  • etc.

Ideally, we should cover the entire space with the set of ops we have. Note that the flip side of having orthogonal ops is composability. Think about it as having a sequence of addi/muli vs. having an affine.apply. If we want to reason about the effect of a composition of operations (again, aliasing is my battle horse here), it tends to get exponentially harder with the number of operations in the sequence (for analysis, the ideal situation is to be able to fold everything into one op and disallow chaining, like we did with affine.apply). That being said, I don’t necessarily argue for having a single uber-op (I can write an internal combinator for the analysis), but am rather offering a different perspective to consider.

Side note, maybe we should not insist on calling these ops something_cast. memref_erase_rank and memref_refine_rank sound just fine, for example.

@pifon2a, in your original post’s description of the current memref_cast, have you missed the cases where the cast is from a static to a dynamic memref and vice versa (erasing shape as opposed to rank)? I.e., memref<42x10xf32> -> memref<?x?xf32>, memref<?x10xf32> -> memref<?x?xf32>, or memref<?x?xf32> -> memref<10x10xf32>. When memref_cast was first introduced in MLIR, these were the only things it did (same rank, replacing '?'s with constants and the other way round).

Thank you for the comments.

I’m not entirely convinced by the separation between erasing/refining casts. We actually do support memref_cast memref<4x?xf32> to memref<?x4xf32> and exercise it in tests, for example.

Why would we ever need such a conversion outside of the tests? It can be expressed as a composition of memref_refining_cast and memref_erasing_cast. I liked your idea of splitting even further and dropping _cast. memref_cast in that case would become 4 ops:

  1. refine_rank
  2. erase_rank
  3. refine_shape
  4. erase_shape

Side questions to you about performance: if we have several chained ops that create new LLVM structs for memref descriptors, e.g. first refine_rank, then refine_shape, then reshape, does that affect compilation or runtime? What can be canonicalized/folded away?

I think a new operation is clearly justified when we want to pattern match it and do different things. Maybe we need to push this separation further: “reinterpret cast” seems also erasing for shape information, but may be seen as “refining” when it introduces rank into hitherto unranked memref.

Actually, nothing stops us from making reinterpret_cast accept and output ranked memrefs only. That’s a good point. It can be prepended by a refine_rank operation (or whatever name we select).
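A hedged sketch of that composition, using the refine_rank spelling floated above (names are not final, the reinterpret syntax is the one proposed earlier):

%ranked = refine_rank %buf : memref<*xf32> to memref<?x?xf32>
memref_reinterpret_cast %ranked [
  offset: %offset, sizes: %size0, %size1, strides: %stride0, %stride1
] : memref<?x?xf32> -> memref<?x?xf32, offset: ?, strides: [?, ?]>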

I’m not a fan of the “reinterpret cast” name. Because of C++, it feels like it would support casting element types, but it is not the case.

I am fine renaming it if someone comes up with a better name.

…should be undefined and the compiler should be able to make some assumptions about the subsets of data being accessed. …

Yes, it definitely should be. But not all of these checks can happen at compile time. Are there any plans to have something like a debug mode in which the compiler would insert more IR for verification at runtime?

I forgot, thanks for noticing. I will update it now.

Added later: it seems to be impossible to edit my first post in-place. I can post the new set of ops after we all agree.

Compilation time is a bit hard to predict: we get more ops to traverse (and potentially a couple of extra switches/dyn-casts in the traversal code), but each op is simpler (so fewer switches on the “subtype” of the op in the traversal code). So ultimately I think these should be equivalent.

As for runtime, shape manipulation is insertelement/extractelement. Currently, each cast unconditionally writes all elements of the new descriptor and conditionally reads elements of the old descriptor (dynamic components only). I don’t expect much difference in these. If we allow the descriptor to contain undef/garbage for static parts of the shapes, refine_rank would become a noop in the lowered code, but that is beyond the scope of this proposal. Rank is trickier because unranked descriptors need an allocation. With a special pair of ops to create unranked memrefs, it will become easier to track the lifetime of these allocations IMO.

Regarding canonicalization, again, I expect the patterns to become simpler: erase_rank/refine_rank can cancel out in the absence of intermediate uses, and the same is true for the shapes. It should also be possible to fold chains of refines into a single op and bring it closer to its use.
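For instance, a fold along these lines (op spellings as in the list above, purely illustrative):

%r = refine_rank %buf : memref<*xf32> to memref<?x?xf32>
%u = erase_rank %r : memref<?x?xf32> to memref<*xf32>
// ==> if %r has no other uses, all uses of %u can be replaced with %buf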

We may also want a story for pseudo-void* to memref casts (currently done by std.view) and the inverse (currently impossible) casts that would give you an equivalent of reinterpret.
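For reference, std.view goes from a raw byte buffer to a typed memref, e.g.:

%typed = view %raw[%offset][] : memref<2048xi8> to memref<64x4xf32>

and there is currently no op for the opposite direction.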

We already have std.assert; we can add a pass that sprinkles assertions around memref ops and let the user call it. Alternatively, we can have an option for the std-to-llvm conversion that injects such assertions and lets them get converted by the infra.
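A possible shape for such an injected check, guarding a rank-refining cast (pseudo-IR; exact op spellings may vary):

%rank = rank %arg : memref<*xf32>
%c2 = constant 2 : index
%ok = cmpi "eq", %rank, %c2 : index
assert %ok, "expected a rank-2 memref"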

I like the idea of a separate pass that adds assertions where it is necessary. I will start implementing the ops then.
