[RFC][Standard] Memref cast ops
TF code generation for unranked and dynamically-shaped inputs required additional memref cast ops that don’t modify the underlying data, but change the memref descriptor.
Example: CWise unary op
For unranked code generation for cwise ops we need an operation that can take an UnrankedMemRefType and flatten it into a 1D vector, i.e. memref<?xelement_type>. After the unary operation is computed, the vector has to be reshaped back to its original shape.
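The flatten, compute, reshape-back round trip can be sketched outside MLIR. The Python below is only a model under stated assumptions: `Descriptor`, `flatten`, `reshape_like` and `cwise_unary` are hypothetical names, the buffer is a plain list, the input is assumed to be contiguous row-major, and the op is applied in place for brevity.

```python
from dataclasses import dataclass
from math import prod
from typing import Callable, List


@dataclass
class Descriptor:
    """Hypothetical stand-in for a memref descriptor (not an MLIR API)."""
    buffer: List[float]  # underlying data, shared by every view
    offset: int
    sizes: List[int]
    strides: List[int]


def row_major_strides(sizes: List[int]) -> List[int]:
    """Strides of a contiguous row-major layout."""
    return [prod(sizes[d + 1:]) for d in range(len(sizes))]


def flatten(d: Descriptor) -> Descriptor:
    """View a contiguous memref of any rank as a 1D memref of all elements."""
    if d.strides != row_major_strides(d.sizes):
        raise ValueError("this sketch assumes contiguous row-major data")
    return Descriptor(d.buffer, d.offset, [prod(d.sizes)], [1])


def reshape_like(flat: Descriptor, sizes: List[int]) -> Descriptor:
    """Restore the original sizes on a flattened view without copying data."""
    if prod(sizes) != flat.sizes[0]:
        raise ValueError("element count must match")
    return Descriptor(flat.buffer, flat.offset, list(sizes), row_major_strides(sizes))


def cwise_unary(d: Descriptor, fn: Callable[[float], float]) -> Descriptor:
    """Apply fn to every element through the 1D view, then restore the shape."""
    flat = flatten(d)
    for i in range(flat.sizes[0]):
        idx = flat.offset + i * flat.strides[0]
        flat.buffer[idx] = fn(flat.buffer[idx])
    return reshape_like(flat, d.sizes)
```

Only the descriptor changes between the steps; the buffer object is the same throughout, which is the property the cast ops are meant to have.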
StaticMemRefCastOp, DynamicMemRefCastOp and ReshapeMemRefCastOp are now a part of LMHLO. Conversion patterns to LLVM also live within the TF code base. Since these operations are quite generic, they might provide value to a much wider audience.
Moreover, there is already MemRefCastOp in Standard used for type erasure/inference. It might become confusing to have so many cast ops.
Before moving the LMHLO memref cast operations to Standard, I suggest revisiting their design.
Previous discussions:
Current State Overview
MemRefCastOp in Standard
Argument
AnyRankedOrUnrankedMemRef:$source
Application
- Erasing rank
memref_cast %buf : memref<4x?xf32> to memref<*xf32>
- Erasing static offset, size, stride information
memref_cast %buf : memref<12x4xf32, offset:5, strides: [4, 1]> to
memref<12x4xf32, offset:?, strides: [?, ?]>
- Cast unranked memref to a concrete shape
memref_cast %buf : memref<*xf32> to memref<4x?xf32>
- Cast dynamically-shaped memref to a concrete shape
memref_cast %buf : memref<12x4xf32, offset:?, strides: [?, ?]> to
memref<12x4xf32, offset:5, strides: [4, 1]>
Remarks
The name of the operation does not reflect what is actually happening. This might become a problem for IR readability when more memref cast ops are added to the Standard dialect. This operation can be used for both type erasure and type inference. I think it would be helpful to split this op into MemRefErasingCastOp and MemRefInferringCastOp.
StaticMemRefCastOp in LMHLO
Argument
MemRefOf<AnyType>:$source
Application
- Modify the offset, sizes and strides of a statically shaped memref.
lmhlo.static_memref_cast %buf : memref<1x5xf32> ->
memref<5xf32, offset: 2, strides: [1]>
DynamicMemRefCastOp in LMHLO
Arguments:
MemRefOf<AnyType>:$source
Variadic<Index>:$sizes
Variadic<Index>:$strides
Application
- Modify sizes and strides of a memref using values computed at runtime.
lmhlo.dynamic_memref_cast %buf(%size0, %size1)[%stride0, %stride1]
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
Remark
This operation is almost equivalent to lmhlo.static_memref_cast, except that it takes sizes and strides from the $sizes and $strides arguments rather than from the result type.
ReshapeMemRefCastOp in LMHLO
Arguments:
AnyRankedOrUnrankedMemRef:$source
MemRefRankOf<[AnySignlessInteger], [1]>:$shape
Application
- Modify sizes of an unranked/ranked memref when the $shape argument has static length
// Reshape statically-shaped memref.
lmhlo.reshape_memref_cast %buf(%shape)
: (memref<4x1xf32>, memref<2xi32>) to memref<2x2xf32>
// Reshape unranked memref.
lmhlo.reshape_memref_cast %buf0(%shape0)
: (memref<*xf32>, memref<1xi32>) to memref<?xf32>
- Modify sizes of an unranked/ranked memref when the $shape argument has dynamic length
// Reshape dynamically-shaped 1D memref.
lmhlo.reshape_memref_cast %buf(%shape)
: (memref<?xf32>, memref<?xi32>) to memref<*xf32>
// Reshape unranked memref.
lmhlo.reshape_memref_cast %buf0(%shape0)
: (memref<*xf32>, memref<?xi32>) to memref<*xf32>
Remark
This operation can reshape any ranked/unranked memref using a $shape of static or dynamic length, and can produce either a ranked or an unranked memref. Note that it is not possible to get a ranked memref result if the $shape argument has dynamic length.
When this operation is converted to the LLVM dialect, strides are computed in the ConversionPattern. At the moment, the strides model the identity map only.
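For illustration, identity-map strides for a given shape can be computed as below. This is a Python sketch of the arithmetic only, not the actual ConversionPattern; `identity_strides` and `linear_index` are hypothetical helper names.

```python
from typing import List, Sequence


def identity_strides(sizes: Sequence[int]) -> List[int]:
    """Strides of a contiguous row-major layout for the given sizes."""
    strides = [1] * len(sizes)
    # Walk from the second-innermost dimension outwards: each stride is the
    # next-inner stride multiplied by the next-inner size.
    for dim in range(len(sizes) - 2, -1, -1):
        strides[dim] = strides[dim + 1] * sizes[dim + 1]
    return strides


def linear_index(indices: Sequence[int], offset: int, strides: Sequence[int]) -> int:
    """Position in the underlying buffer addressed by a multi-dimensional index."""
    return offset + sum(i * s for i, s in zip(indices, strides))
```

With these strides the innermost dimension is always contiguous, so a non-identity layout (for example a transposed view) cannot be expressed by the reshape.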
Also note that dynamic_memref_cast is almost identical to reshape_memref_cast when $shape has static length. There are two minor differences that can be eliminated in the new design: reshape_memref_cast supports an unranked $source input, and dynamic_memref_cast has a $strides argument.
Proposed ops
I propose to split std.memref_cast into std.memref_erasing_cast and std.memref_inferring_cast, depending on whether the cast adds more information to the result type or not.
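The split criterion can be made concrete with a small classifier. The Python below is a sketch under assumptions: `MemRefType` and `classify_cast` are hypothetical names, `None` models `?` (or `*` when used for `sizes`), ranked types are given explicit strides, and a cast that both gains and loses static information is rejected, which the text above does not specify.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

Dim = Optional[int]  # None models a dynamic value ('?')


@dataclass(frozen=True)
class MemRefType:
    """Hypothetical model of a memref type; sizes=None models memref<*x...>."""
    sizes: Optional[Tuple[Dim, ...]]
    offset: Dim = None
    strides: Optional[Tuple[Dim, ...]] = None


def _fields(t: MemRefType) -> Tuple[Dim, ...]:
    # Offset, sizes and strides in a fixed order, so two types of equal rank
    # can be compared field by field.
    return (t.offset,) + tuple(t.sizes) + tuple(t.strides or ())


def classify_cast(src: MemRefType, dst: MemRefType) -> str:
    """Return 'erasing' or 'inferring' for a cast from src to dst."""
    if src.sizes is None and dst.sizes is None:
        raise ValueError("unranked to unranked is not a cast")
    if dst.sizes is None:
        return "erasing"      # rank is dropped
    if src.sizes is None:
        return "inferring"    # rank is added
    if len(src.sizes) != len(dst.sizes):
        raise ValueError("rank mismatch")
    gained = lost = False
    for s, d in zip(_fields(src), _fields(dst)):
        if s is None and d is not None:
            gained = True
        elif s is not None and d is None:
            lost = True
        elif s != d:
            raise ValueError("static values disagree")
    if gained and lost:
        raise ValueError("mixed cast: assumed unsupported in this sketch")
    # A cast that adds nothing to the result type counts as erasing.
    return "inferring" if gained else "erasing"
```

For instance, casting `memref<12x4xf32, offset:5, strides: [4, 1]>` to the fully dynamic layout classifies as erasing, and the reverse direction as inferring.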
lmhlo.static_memref_cast, lmhlo.dynamic_memref_cast and the static-length shape case of lmhlo.reshape_memref_cast will be merged into std.memref_reinterpret_cast (thanks to Uday for the op name).
The dynamic-length shape case of lmhlo.reshape_memref_cast will become std.memref_dynamic_reshape.
I am open to suggestions, especially when it comes to the naming and printing of the ops.
MemRefErasingCastOp in Standard
Argument
MemRefOf<AnyType>:$source
Application
- Erasing rank
memref_erasing_cast %buf : memref<4x?xf32> to memref<*xf32>
- Erasing static offset, size, stride information
memref_erasing_cast %buf : memref<12x4xf32, offset:5, strides: [4, 1]> to
memref<12x4xf32, offset:?, strides: [?, ?]>
MemRefInferringCastOp in Standard
Argument
AnyRankedOrUnrankedMemRef:$source
Application
- Cast unranked memref to a concrete shape
memref_inferring_cast %buf : memref<*xf32> to memref<4x?xf32>
- Cast dynamically-shaped memref to a concrete shape
memref_inferring_cast %buf : memref<12x4xf32, offset:?, strides: [?, ?]> to
memref<12x4xf32, offset:5, strides: [4, 1]>
MemRefReinterpretCastOp in Standard
Arguments:
AnyRankedOrUnrankedMemRef:$source
Index:$offset
Variadic<Index>:$sizes
Variadic<Index>:$strides
Application
- Modify the offset, sizes and strides of an unranked/ranked memref using values computed at runtime.
memref_reinterpret_cast %buf [
offset: %offset, sizes: %size0, %size1, strides: %stride0, %stride1
] : memref<?x?xf32> -> memref<?x?xf32, offset: ?, strides: [?, ?]>
memref_reinterpret_cast %unranked [
offset: %offset, sizes: %size0, %size1, strides: %stride0, %stride1
] : memref<*xf32> -> memref<?x?xf32, offset: ?, strides: [?, ?]>
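The intended semantics can be pictured as building a new descriptor over the same buffer. The Python below is a sketch, not the lowering: `Descriptor`, `reinterpret_cast` and `load` are hypothetical names, and no bounds checking is done, on the assumption that the caller is responsible for choosing in-bounds values.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class Descriptor:
    """Hypothetical stand-in for a memref descriptor (not an MLIR API)."""
    buffer: List[float]
    offset: int
    sizes: List[int]
    strides: List[int]


def reinterpret_cast(src: Descriptor, offset: int,
                     sizes: Sequence[int], strides: Sequence[int]) -> Descriptor:
    """Return a view of src's buffer with caller-supplied offset, sizes and strides."""
    if len(sizes) != len(strides):
        raise ValueError("sizes and strides must have the same rank")
    # The buffer is shared, not copied; only the descriptor fields change.
    return Descriptor(src.buffer, offset, list(sizes), list(strides))


def load(d: Descriptor, indices: Sequence[int]) -> float:
    """Read one element through the descriptor's offset and strides."""
    return d.buffer[d.offset + sum(i * s for i, s in zip(indices, d.strides))]
```

Because the strides are explicit operands, the same op can express a shifted window or a transposed view, which a shape-only reshape cannot.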
MemRefDynamicReshapeOp in Standard
Arguments:
AnyRankedOrUnrankedMemRef:$source
MemRefRankOf<[AnySignlessInteger], [1]>:$shape // strictly memref<?xIntegerType>
Application
- Modify sizes and strides of an unranked/ranked memref using values computed at runtime.
// Reshape dynamically-shaped memref.
memref_dynamic_reshape %src(%shape)
: (memref<?x?xf32>, memref<?xi32>) to memref<*xf32>
// Reshape unranked memref.
memref_dynamic_reshape %src(%shape)
: (memref<*xf32>, memref<?xi32>) to memref<*xf32>
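The reason the result has to be unranked is that the rank itself is the length of %shape, which is only known at runtime. A Python sketch of that idea follows; `UnrankedDescriptor` and `dynamic_reshape` are hypothetical names, the source is assumed contiguous, identity strides are assumed, and the element-count check is an assumption rather than something stated above.

```python
from dataclasses import dataclass
from math import prod
from typing import List, Sequence


@dataclass
class UnrankedDescriptor:
    """Hypothetical model of memref<*xf32>: the rank is a runtime value."""
    rank: int
    buffer: List[float]
    offset: int
    sizes: List[int]
    strides: List[int]


def dynamic_reshape(src: UnrankedDescriptor, shape: Sequence[int]) -> UnrankedDescriptor:
    """Reshape using a shape operand whose length is only known at runtime."""
    sizes = [int(s) for s in shape]
    if prod(sizes) != prod(src.sizes):
        raise ValueError("element count must be preserved (assumption of this sketch)")
    # Identity (row-major) strides: product of all inner sizes.
    strides = [prod(sizes[d + 1:]) for d in range(len(sizes))]
    # The result rank is read from the shape operand, not from any static type.
    return UnrankedDescriptor(len(sizes), src.buffer, src.offset, sizes, strides)
```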