LLVM Discussion Forums

Beginner Q: Help with loops/affine/linalg

Hi folks, I’m just getting started with MLIR and trying to learn what is possible. I hope this is the right place to ask usability questions (if not, please pardon my ignorance).

I’m looking at some simple examples and trying to see what the MLIR framework supports for loop fusion.

I wrote up a simple loop with affine.for and affine.load/affine.store, and I was able to get --affine-loop-fusion to fuse my loops (hooray!)
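
For illustration, here is a minimal producer/consumer pair of that shape (a made-up example, not my actual code); --affine-loop-fusion fuses the two loops into one:

    func @toy(%A : memref<16xf32>, %B : memref<16xf32>) {
      %cf1 = constant 1.0 : f32
      // Producer: fill %A.
      affine.for %i = 0 to 16 {
        affine.store %cf1, %A[%i] : memref<16xf32>
      }
      // Consumer: copy %A into %B; fusion merges this with the loop above.
      affine.for %i = 0 to 16 {
        %v = affine.load %A[%i] : memref<16xf32>
        affine.store %v, %B[%i] : memref<16xf32>
      }
      return
    }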

I then tried something a bit different, and maybe something not intended to be supported. I tried mixing linalg operations (slice/matmul) inside my affine.for loop and no longer see loop fusion happening. Stepping through the code, I see that the (affine loop) fusion code looks only for affine.load/affine.store. This kind of makes sense, as I suppose those are the only memory operations it fully understands.

So my question for folks here is: what are some suggestions on how to accomplish what I’m trying to do?

  • I want to be able to track the dependencies between the loops.
  • I want to enable loop fusion.
  • I do not want to lower my matrix multiplication into linear operations.

My example MLIR code is below:

    func @test(%kernel1 : memref<256x128xf32>, %kernel2 : memref<256x256xf32>, %P0 : memref<128x1xf32>, %out : memref<256x1xf32>) {

    %c0 = constant 0 : index
    %c1 = constant 1 : index
    %c2 = constant 2 : index
    %c128 = constant 128 : index
    %c256 = constant 256 : index

    %r0to128 = linalg.range %c0:%c128:%c1 : !linalg.range
    %r0to256 = linalg.range %c0:%c256:%c1 : !linalg.range
    %r0to1 = linalg.range %c0:%c1:%c1 : !linalg.range

    // do a CrossProduct (256x128, 128x1) -> 256x1
    // in two iterations, doing 128 elements at a time
    affine.for %0 = 0 to 2 {
        %min = muli %c128, %0 : index
        %max = addi %min, %c128 : index
        %r0 = linalg.range %min:%max:%c1 : !linalg.range

        %sub_kernel1 = linalg.slice %kernel1[%r0, %r0to128] : memref<256x128xf32>, !linalg.range, !linalg.range, memref<128x128xf32>
        %sub_out = linalg.slice %out[%r0, %r0to1] : memref<256x1xf32>, !linalg.range, !linalg.range, memref<128x1xf32> 
        linalg.matmul(%sub_kernel1, %P0, %sub_out) : memref<128x128xf32>, memref<128x1xf32>, memref<128x1xf32>
    }

    // do a CrossProduct (256x256, 256x1) -> 256x1
    // in two iterations, doing 128 elements at a time
    affine.for %0 = 0 to 2 {
        %min = muli %c128, %0 : index
        %max = addi %min, %c128 : index
        %r0 = linalg.range %min:%max:%c1 : !linalg.range

        %sub_kernel2 = linalg.slice %kernel2[%r0, %r0to256] : memref<256x256xf32>, !linalg.range, !linalg.range, memref<128x256xf32>
        %sub_out = linalg.slice %out[%r0, %r0to1] : memref<256x1xf32>, !linalg.range, !linalg.range, memref<128x1xf32> 
        linalg.matmul(%sub_kernel2, %out, %sub_out) : memref<128x256xf32>, memref<256x1xf32>, memref<128x1xf32>
    }

    return
    }

thanks,

Ian Bearman
Principal Software Engineering Manager
Microsoft Visual C++ Team: Optimization & Code Generation
/* Making your code faster, smaller, smarter! */

Hello @manbearian,

This is expected: affine.for does not understand linalg constructs, and there is no plan for that atm, AFAIK.

These are key reasons why Linalg exists in the first place (see the Rationale).

Despite being higher-level than SSA, the notion of a loop is still too fine-grained for this purpose.
In this case, I’d recommend using Linalg fusion. This transformation is inspired by Halide; you can think of it as “tile and fuse producer(s) while preserving the mapping to libraries”. It is fundamentally coarser-grained than single- or multi-loop fusion and preserves the type of semantics you care about. Note that in the limit, with tile sizes = {1, data_size} plus canonicalizations, you get close to loop fusion but still have guarantees regarding mapping to a library: the mapping resists transformation by design.

This transformation can otherwise be achieved with the following loop transformations: multi-loop fusion + multi-loop tiling + multi-loop distribution of the tiled loops + inferring matmuls and friends from scalar tiled loops. There have been quite a few efforts over the years to get these things to scale beyond simple kernels, including our own efforts in previous lives :).
The Linalg rationale more generally explains why we think it is time to step back and reconsider abstractions. Of course, if affine.for does solve parts of your problem, you can use Linalg, do some transformations, lower to affine.for, and do some more transformations before going to loop.for and lower-level abstractions.

The issue you will encounter is that Linalg tiling produces loop.for instead of affine.for. This is only a short-term limitation, and it only matters if you really want the affine analyses on the outer-tile (a.k.a. inter-tile) loops, which I would expect to be an almost non-existent case (I’d be happy to discuss different use cases and scenarios; we have some experience in the field :wink: ).

The reason linalg tiling uses loop.for is that parametric tiling is an important adaptability and load-balancing tool that affine.for does not support.
Still, if needed, it is trivial to make linalg tiling emit affine.for, but you’ll be limited to static constant tile sizes.

HTH!

Hi, thanks for replying. What you wrote was indeed helpful. My background is more in general compiler optimization than in HPC and loop analysis, so I’m still trying to learn a lot of the ideas used in the various MLIR dialects and trying to map them to what I do know.

I don’t believe that affine.for is required for what I’m trying to do; I was just using it as a tool to represent the iterative nature of the operations I’m trying to model in MLIR.

What is the right operation to use at the linalg level to represent an iterative operation? I couldn’t figure out from reading the documentation how to represent my iterations.

Your continued help is appreciated,

Ian Bearman
Principal Software Engineering Manager
Microsoft Visual C++ Team: Optimization & Code Generation
/* Making your code faster, smaller, smarter! */

This isn’t actually accurate. It is possible to represent parametrically tiled loops with affine.for.

    loop.for %i = 0 to %N step %B
    affine.for %i = 0 to affine_map<()[s0, s1] -> (s0 ceildiv s1)>()[%N, %B]

The approach is to normalize the step to one by scaling your bounds (since %B is expected/known to be positive). I think I’ve mentioned this in the past: there isn’t a loop.for that an affine.for cannot structurally represent.

The above of course makes the lower and upper bound maps semi-affine, which many of the analyses on affine.for ops don’t support, but there are still canonicalizations and simplifications that just work (composition, unused/duplicate operand removal, constant propagation/folding). In any case, it only provides more than what loop.for does in terms of structure, and the lowering from affine.for to loop.for always succeeds. With affine grayboxes, even the bounds can trivially become pure affine in cases like these, and the restrictions on what can be used as dims/symbols effectively go away.
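
Concretely, the normalization looks like this (an illustrative sketch; "use" stands in for an arbitrary consumer of the induction value):

    func @normalize(%N : index, %B : index) {
      %c0 = constant 0 : index
      // Original loop with a parametric step %B:
      loop.for %i = %c0 to %N step %B {
        "use"(%i) : (index) -> ()
      }
      // Step-normalized form: a unit-step affine.for with ceildiv(%N, %B)
      // iterations; the original induction value is recovered in the body.
      // The bound map is semi-affine because of the symbolic divisor.
      affine.for %ii = 0 to affine_map<()[s0, s1] -> (s0 ceildiv s1)>()[%N, %B] {
        %i = muli %ii, %B : index
        "use"(%i) : (index) -> ()
      }
      return
    }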

@nicolasvasilache has already answered most of your questions above. Note that the fusion pass on affine.for ops mainly takes into account only the affine loads/stores as far as memory goes (and then the SSA edges). The affine fusion utilities themselves could be reused to perform the actual fusion you want (the mechanics of it) on the snippet you show above. They do not, however, have the support to analyze the validity (correctness) of fusing the two nests you have. The affine dialect also does not / cannot depend on the linalg dialect, since the latter is higher level. One would have to design/use op interfaces to opaquely query the necessary information from such linalg ops if one wants to build something automatic that performs such fusion as an affine dialect transform.

Thanks again for your help. Something that still isn’t clear to me: what is the right operation to use at the linalg level to represent an iterative operation? I couldn’t figure out from reading the documentation how to represent my iterations.

Sorry for the delay; the past few days have been a bit tricky administratively.
Atm linalg does not have the notion of an iterative operation; the thinking is to use loops for those, but then you would hit the same fusion limitations that you already described.

An alternative would be to try and express your computations at a higher level, something like:

    linalg.matmul_like_1(%kernel_1, %P0, %out) :
      (memref<256x128>, memref<128x1>, memref<256x1>)
    linalg.matmul_like_2(%kernel_2, %out, %out2) :
      (memref<256x256>, memref<256x1>, memref<256x1>)

Then tileAndFuse with a proper tile size could produce the form you want.

I am using linalg.matmul_like because it is likely not just a matmul; you may be able to use a proper linalg.generic op.
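
For the first multiplication, that could look roughly like the following (a sketch patterned after the matmul example in the Linalg documentation; attribute names such as args_in/args_out may differ slightly between versions):

    #matvec_accesses = [
      affine_map<(m, n, k) -> (m, k)>,
      affine_map<(m, n, k) -> (k, n)>,
      affine_map<(m, n, k) -> (m, n)>
    ]
    #matvec_trait = {
      args_in = 2,
      args_out = 1,
      indexing_maps = #matvec_accesses,
      iterator_types = ["parallel", "parallel", "reduction"]
    }

    // out(m, n) += kernel1(m, k) * P0(k, n), with n == 1 here.
    linalg.generic #matvec_trait %kernel1, %P0, %out {
    ^bb0(%a: f32, %b: f32, %c: f32):
      %d = mulf %a, %b : f32
      %e = addf %c, %d : f32
      linalg.yield %e : f32
    } : memref<256x128xf32>, memref<128x1xf32>, memref<256x1xf32>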

However, you may notice that I am using %out2 to refer to the second output.
This is because, as I read your code, it wasn’t clear to me what semantics you wish to express.

Looking at your second loop, I seem to see a loop-carried dependence:

  1. %0 = 0 computes %sub_kernel2 * %out but updates the first half of %out in place (I’ll call the result %modified_out)?
  2. %0 = 1 computes %sub_kernel2 * %modified_out.

Is this loop-carried dependence intentional?
If so, I’d suggest just writing 4 linalg.matmul ops with the proper std.subview ops and the proper dependencies.

Since I do not know whether the above is a real, interesting kernel or just a small proxy to get the ball rolling, it is a bit hard to give definitive advice.

Would it be possible to describe the computation you’d like to express in some other form (maybe equational)? Since it looks a bit non-standard, I want to be sure I understand what you are trying to achieve.

As a side note, I would suggest using std.subview, which is “understood” by linalg tiling and fusion.
linalg.slice is meant more for “rank-reducing” uses; it could canonicalize into std.subview in your cases, but that isn’t done atm.

Even when canonicalizations are there, putting all the patterns together into useful transformations is still going to be necessary.

So I’d suggest just using std.subview unless you need to reduce ranks.
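
For example, the slice of %out inside your first loop could be rewritten along these lines (a sketch reusing %min, %c0, %c1, and %c128 from your snippet; the exact strided result-type annotation may differ between versions):

    // Take rows [%min, %min + 128) of %out as a 128x1 view.
    // Operand groups are [offsets][sizes][strides]; with SSA operands the
    // result shape, offset, and strides are dynamic in the type.
    %sub_out = subview %out[%min, %c0][%c128, %c1][%c1, %c1]
        : memref<256x1xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>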

Thanks!

Thanks again for getting back to me. We’re all dealing with the craziness as best we can.

You’re right about the loop-carried dependency on %out; it was unintentional. The intent was to have a temporary storage location between the loops.

The kernel I wrote here was intended to be something simple so I could understand the IR. It doesn’t map directly to any real code, but it does represent the larger problem I’m trying to solve: I’d like to represent an environment in which matrix-vector multiplication of a fixed size is a basic unit of execution. So this example represents a simple case of two MVMs (matrix-vector multiplies) that feed each other and are wider than the fixed size (128 in this case). The naive lowering generates a loop for each one, which ideally would be fused into a single loop. I was playing with the current dialects to see if I could directly represent this view of the program.

On using std.subview… I don’t see it documented anywhere; is it new? I thought I remembered seeing it somewhere, but it operated on values instead of memrefs? I cannot find anything on it now looking at https://mlir.llvm.org/docs/Dialects/Standard/.

thanks,
Ian

An alternative is also to use vector.contract, which has semantics close to linalg.generic and can be used to represent such things.
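
For instance, a fixed-size 128x128 matrix-vector multiply could be written as follows (a sketch; %A, %x, and %acc are hypothetical values of the annotated vector types):

    #mv_accesses = [
      affine_map<(i, k) -> (i, k)>,
      affine_map<(i, k) -> (k)>,
      affine_map<(i, k) -> (i)>
    ]
    #mv_trait = {
      indexing_maps = #mv_accesses,
      iterator_types = ["parallel", "reduction"]
    }

    // y(i) += A(i, k) * x(k), with the reduction over k.
    %y = vector.contract #mv_trait %A, %x, %acc
        : vector<128x128xf32>, vector<128xf32> into vector<128xf32>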

If you want to stay on pure memref<scalar>, then I’d suggest looking at a higher-level (e.g. 3-D) linalg.generic.

In TensorComprehensions-like / einsum notation, I would write this to resemble the following (where “r_” denotes a reduction dimension):

    out(i, j) += kernel_1(i, j, r_k) * P0(r_k)
    out2(i, j) += kernel_2(i, j, r_k, r_l) * out(r_k, r_l)

    linalg.generic(%kernel_1, %P0, %out) #some_traits_1 : { region }
      (memref<2x128x128>, memref<128>, memref<2x128>)
    linalg.generic(%kernel_2, %out, %out2) #some_traits_2 : { region }
      (memref<2x128x2x128>, memref<2x128>, memref<2x128>)

If you then took this form and tileAndFuse’d the second one by {2}, it should give you the form you want.

With vectors it could look like:

    linalg.generic(%kernel_1, %P0, %out) #some_traits_3 : { region }
      (memref<2xvector<128x128>>, memref<vector<128>>, memref<2xvector<128>>)
    linalg.generic(%kernel_2, %out, %out2) #some_traits_4 : { region }
      (memref<2x2xvector<128x128>>, memref<2xvector<128>>, memref<2xvector<128>>)

Notice, however, that I have cheated by turning memref<256x256> into memref<2x2xvector<128x128>>, which requires moving data around.
Ultimately it depends on how much control you have over the input language / data allocation etc., and how much you want to rely on analysis and transformations vs. using guaranteed abstractions like vectors.

One way we are attacking some of these things is bottom-up from the vector level: this uses C++ metaprogramming to create MLIR that operates on vectors and benchmarks vector contraction.

Weird, the docs do not seem to have been updated for some time. Here it is: https://github.com/llvm/llvm-project/blob/master/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td#L1625. It used to be linalg.subview and graduated to std.subview ~6 months back.

@joker-eph @jpienaar is there something special to do to trigger doc updates?

StandardOps currently doesn’t use the auto-generated documentation. I’m working on fixing this right now.

Should be fixed whenever https://reviews.llvm.org/D76743 gets submitted.