Hello!
I’m doing some experiments to try to better understand the general vector abstraction in MLIR, and I have some questions/comments. This directly relates to, and adds some points to, this previous post ([MLIR] Multidimensional vector abstraction). Let me share a small example to better illustrate my findings and goals. Imagine that we have the following scalar function:
```mlir
func @scalar_test(%in_out : memref<21xf32>) {
  affine.for %i = 0 to 21 {
    %ld = affine.load %in_out[%i] : memref<21xf32>
    %add = addf %ld, %ld : f32
    affine.store %add, %in_out[%i] : memref<21xf32>
  }
  return
}
```
Then, we want to write a vector counterpart. Let’s assume that we target AVX2 but we want to write it in a generic way so that we can leverage some of the affine/loop/std optimizations, and then lower it to a hypothetical AVX2 dialect later on, or even leave the hw-specific lowering to LLVM. This is the first thing I wrote, which is not correct in MLIR:
```mlir
func @cast_test(%in_out : memref<21xf32>) {
  %c4 = constant 4 : index
  %c20 = constant 20 : index
  // Process 16 elements, 8 elements at a time: potentially lowered to YMM ops/regs.
  %vec8 = memref_cast %in_out : memref<21xf32> to memref<?xvector<8xf32>>
  affine.for %i = 0 to 2 {
    %ld8 = affine.load %vec8[%i] : memref<?xvector<8xf32>>
    %add8 = addf %ld8, %ld8 : vector<8xf32>
    affine.store %add8, %vec8[%i] : memref<?xvector<8xf32>>
  }
  // Process 4 elements (scalars 16..19; the index is in vector<4xf32> units):
  // potentially lowered to XMM ops/regs.
  %vec4 = memref_cast %in_out : memref<21xf32> to memref<?xvector<4xf32>>
  %ld4 = affine.load %vec4[%c4] : memref<?xvector<4xf32>>
  %add4 = addf %ld4, %ld4 : vector<4xf32>
  affine.store %add4, %vec4[%c4] : memref<?xvector<4xf32>>
  // Process 1 element: scalar ops/regs.
  %ld1 = affine.load %in_out[%c20] : memref<21xf32>
  %add1 = addf %ld1, %ld1 : f32
  affine.store %add1, %in_out[%c20] : memref<21xf32>
  return
}
```
The code above is not allowed because `vector` is considered a memref element type and `memref_cast` cannot change the element type of a memref. This means that we cannot `memref_cast` a vector to a scalar, or even a vector to another vector with a different length (e.g., `memref_cast %in_out : memref<?xvector<8xf32>> to memref<?xvector<4xf32>>`).
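For context, `memref_cast` does allow adding or erasing static shape information, as long as the element type is preserved. A minimal sketch of what is and isn’t accepted, based on my understanding of the std dialect (`%buf` is a hypothetical `memref<21xf32>` value):

```mlir
// Legal: only static/dynamic shape information changes; element type stays f32.
%dyn = memref_cast %buf : memref<21xf32> to memref<?xf32>
%stat = memref_cast %dyn : memref<?xf32> to memref<21xf32>
// Illegal: the element type changes (f32 -> vector<4xf32>).
// %vec = memref_cast %buf : memref<21xf32> to memref<?xvector<4xf32>>
```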
Then, I gave `std.view` a try and wrote something like this:
```mlir
func @view_test(%in_out : memref<84xi8>) {
  %c2 = constant 2 : index
  %c4 = constant 4 : index
  %c5 = constant 5 : index
  %c20 = constant 20 : index
  // Are the strides correct?
  %vec8 = view %in_out[][%c2] : memref<84xi8> to memref<?xvector<8xf32>>
  %vec4 = view %in_out[][%c5] : memref<84xi8> to memref<?xvector<4xf32>>
  %scalar = view %in_out[][] : memref<84xi8> to memref<21xf32>
  // Process 16 elements, 8 elements at a time: potentially lowered to YMM ops/regs.
  affine.for %i = 0 to 2 {
    %ld8 = affine.load %vec8[%i] : memref<?xvector<8xf32>>
    %add8 = addf %ld8, %ld8 : vector<8xf32>
    affine.store %add8, %vec8[%i] : memref<?xvector<8xf32>>
  }
  // Process 4 elements (scalars 16..19; the index is in vector<4xf32> units):
  // potentially lowered to XMM ops/regs.
  %ld4 = affine.load %vec4[%c4] : memref<?xvector<4xf32>>
  %add4 = addf %ld4, %ld4 : vector<4xf32>
  affine.store %add4, %vec4[%c4] : memref<?xvector<4xf32>>
  // Process 1 element: scalar ops/regs.
  %ld1 = affine.load %scalar[%c20] : memref<21xf32>
  %add1 = addf %ld1, %ld1 : f32
  affine.store %add1, %scalar[%c20] : memref<21xf32>
  return
}
```
This seems to compile, which is great! However, I see a couple of drawbacks here:
- My function is no longer type safe. Using views requires 1-D `i8` memrefs, so we basically have to drop all shape and element type information from the memref parameter, make it opaque, and therefore rely on the caller to pass a buffer with the expected layout.
- Vectorizing a single loop in a function impacts the function signature and all other uses of the vectorized memref in the function.
Using the vector dialect would be another option, but I think that would mean moving my code to another level of abstraction and probably losing the ability to apply affine optimizations to it, right?
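For concreteness, here is roughly what I imagine the vector-dialect version would look like for the 16-element part (a sketch only; I haven’t verified the exact `vector.transfer_read`/`vector.transfer_write` syntax, and the remainder handling is elided):

```mlir
func @vector_dialect_test(%in_out : memref<21xf32>) {
  %f0 = constant 0.0 : f32
  // Process scalars 0..15, 8 at a time, going through vector transfer ops
  // instead of reinterpreting the memref element type.
  affine.for %i = 0 to 16 step 8 {
    %v = vector.transfer_read %in_out[%i], %f0 : memref<21xf32>, vector<8xf32>
    %r = addf %v, %v : vector<8xf32>
    vector.transfer_write %r, %in_out[%i] : vector<8xf32>, memref<21xf32>
  }
  // The remaining 5 elements would need a vector<4xf32> transfer plus a scalar op.
  return
}
```

Note that the loop now steps in scalar elements rather than vector elements, which already changes the shape of the loop nest the affine passes would see.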
My general feeling is that `memref_cast` is currently a bit too constrained, and that there is no other simple option for memref casts that only “change the number of elements read/written per access” (scalar<->vector or vector<->vector). Views are really powerful, but I think using them for these casts is overkill; they were introduced to address a different and more complex kind of problem.
I guess I can summarize the questions and design decisions I would like to understand better as follows:
- Why is `vector` a memref element type?
- Why can’t a `memref_cast` convert between: a) a scalar and a vector with the same “element” type; b) two vectors with different vector lengths and the same “element” type?
- What does it mean for an `alloc` or a block argument (or any non-memory op on a memref type) to have a vector element type? Isn’t this unnecessarily adding/enforcing how data has to be read/written at a point where only allocation/shape/layout information should be needed?
- What is the best way to represent vector code suitable for the affine/std domain?
Thanks in advance!
Diego