Commenting here based on experience in IREE. I think having a list of affine maps that compose to arbitrary striding is very hard to support and reason about in a compilation stack. IMO, allowing the composition map to permit accesses where the least significant dimension is not the fastest-varying dimension just adds complexity to analysis passes that are trying to optimize the data accesses in the generated code: you not only need to look at the shape, you also need to look at the strides. It also turns some trivial ops that should be no-ops into non-trivial ones. For example, the following reshape is a no-op (it effectively just changes metadata):
%1 = reshape %0 : memref<?x?x?xf32> to memref<?x?xf32>
whereas this is not a no-op:
%1 = reshape %0 :
    memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (s0 + d0*s1 + d1*s2 + d2*s3)>> to
    memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (s0 + d0*s1 + d1*s2)>>
Whether it is a no-op depends on the dynamic values of the strides. IMO this conflates data layout transformations with a metadata change that just alters the shape (without having to move the underlying data). A more canonical representation is to have the data-movement ops represented explicitly.
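To make the dynamic dependence concrete, here is a small Python sketch (the function name is mine, not an MLIR API): a strided reshape can only be a pure metadata change when the strides happen to describe a dense row-major layout, and that is a runtime property of the stride values.

```python
def is_contiguous_row_major(shape, strides):
    """Return True if `strides` describe a dense row-major layout of
    `shape`, i.e. a reshape of this buffer is a pure metadata change."""
    expected = 1
    # Walk from the innermost (fastest-varying) dimension outward.
    for dim, stride in zip(reversed(shape), reversed(strides)):
        if stride != expected:
            return False
        expected *= dim
    return True

# Dense 2x3x4 buffer: strides (12, 4, 1) -> reshape is a no-op.
print(is_contiguous_row_major((2, 3, 4), (12, 4, 1)))  # True
# Same shape with a padded outer stride: data would have to move.
print(is_contiguous_row_major((2, 3, 4), (16, 4, 1)))  # False
```

Since the stride values (the s0, s1, ... symbols) are only known dynamically, a pass cannot statically fold the second reshape away; it has to either prove contiguity or emit a copy.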
The only valid use case of the affine maps in MemRefType used in IREE is accessing an N-D slice of a memref. There you only need the offset and extent of each dimension, and there is a deterministic ordering between the shape dimensions and the order in which the data is laid out, i.e. the innermost dimension is also the fastest-varying dimension (and so on).
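A sketch of why the slice case is simpler (function and parameter names are mine): with a fixed row-major convention, the strides are fully implied by the shape of the underlying buffer, so a slice is completely described by per-dimension offsets and extents.

```python
def slice_index(full_shape, offsets, idx):
    """Linear address of element `idx` inside an N-D slice of a
    row-major buffer of shape `full_shape`, starting at `offsets`.
    No per-dimension stride is stored: the innermost dimension is
    always the fastest varying, so strides derive from the shape."""
    addr = 0
    stride = 1
    for size, off, i in zip(reversed(full_shape),
                            reversed(offsets), reversed(idx)):
        addr += (off + i) * stride
        stride *= size
    return addr

# Element (0, 0) of a slice at offset (1, 2) in a 4x8 buffer:
print(slice_index((4, 8), (1, 2), (0, 0)))  # 1*8 + 2 = 10
```

Every piece of information here is statically attached to the access; nothing has to be recovered from where the value was created.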
A side effect of having the affine composition map is that an operation does not have all the information needed to be lowered to, say, the LLVM/SPIR-V dialects, since the values of the strides (s1, s2, etc.) are available at the time the value is created and not at the op use site. So either
- you need to carry this information during lowering using an additional data structure like the Memref descriptor, as is done in the conversion to LLVM dialect, or
- you need to walk the use-def chains to get the values of the strides.
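For reference, the first option amounts to carrying something shaped roughly like the following alongside every memref value. This is an illustrative Python model of the descriptor used in the conversion to the LLVM dialect (field names are approximate, not the exact lowering):

```python
from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class MemRefDescriptor:
    """Rough model of the per-value descriptor the LLVM lowering
    carries: base pointers, a static offset, and per-dimension
    sizes and strides (all as plain integers here)."""
    allocated_ptr: int   # base of the underlying allocation
    aligned_ptr: int     # possibly-aligned view of the base
    offset: int          # element offset of this view
    sizes: List[int]     # one extent per dimension
    strides: List[int]   # one stride per dimension

    def address_of(self, idx: Tuple[int, ...]) -> int:
        """aligned_ptr + offset + sum(i_k * stride_k)."""
        return self.aligned_ptr + self.offset + sum(
            i * s for i, s in zip(idx, self.strides))

# A 2x3 row-major view starting 4 elements into a buffer at 1000:
d = MemRefDescriptor(1000, 1000, 4, [2, 3], [3, 1])
print(d.address_of((1, 2)))  # 1000 + 4 + 1*3 + 2*1 = 1009
```

It is exactly this aggregate (pointers plus dynamic sizes and strides) that passes like SROA must later break apart, which is the cost discussed next.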
For the LLVM backend, using a memref descriptor might be OK for most cases since the LLVM (proper) passes like SROA are able to break up the memref descriptor before actual code-generation. For backends like SPIR-V it is unreasonable to expect all driver compilers to be able to handle such complex transformations (most driver compilers are JIT compilers, and they need to be cognizant of compilation time).
The counter-argument to the above is that if you don't want to support arbitrary striding, then the specific compilation stack errors out. That's fair, and that's what is done in IREE. The question for me is: this feature exists in MLIR but is not used in core, so is it paying for itself? There might be some simplifying assumptions that could reduce the complexity of, say, the LLVM lowering, if such a feature did not need to be supported (and my view here is that this is not a feature, but a problematic representation).