scf.for range folding optimization pass

Hi all,

Per an open design issue on the IREE project ("Upstream improvements for scf.for folding on loop ranges", google/iree issue #5547 on GitHub), I have started hacking on a pass that performs range folding and hoisting on scf.for. Consider this (simplified) example from the linked issue:

func @simple_mul_dispatch_0(%arg0: !vmvx.interface, %arg1: !vmvx.buffer, %arg2: index) {
  %c0 = constant 0 : index
  %c10 = constant 10 : index
  %c1 = constant 1 : index
  %c4 = constant 4 : index
  %0 = vmvx.interface.binding<%arg0 : !vmvx.interface>[0] : !vmvx.buffer
  %1 = vmvx.interface.binding<%arg0 : !vmvx.interface>[1] : !vmvx.buffer
  %2 = vmvx.interface.binding<%arg0 : !vmvx.interface>[2] : !vmvx.buffer
  // ...
  scf.for %arg14 = %c0 to %c10 step %c1 {
    %13 = addi %arg2, %arg14 : index
    %15 = muli %13, %c4 : index
    %16 = vmvx.buffer.load<%0 : !vmvx.buffer>[%15] : i32
    %17 = vmvx.buffer.load<%1 : !vmvx.buffer>[%15] : i32
    %18 = muli %16, %17 : i32
    vmvx.buffer.store<%2 : !vmvx.buffer>[%15], %18 : i32
  }
  return
}

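%arg14 is used only once, in the calculation of %13, and %arg2 is defined outside the loop. We can therefore fold the addition into the loop bounds: iterate %arg14 from %arg2 (that is, %c0 + %arg2) to %arg2 + %c10 and use it directly in place of %13. The add is then computed once, outside the loop, instead of 10 times: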

%arg2p = addi %arg2, %c10 : index
scf.for %arg14 = %arg2 to %arg2p step %c1 {
  %15 = muli %arg14, %c4 : index
  %16 = vmvx.buffer.load<%0 : !vmvx.buffer>[%15] : i32
  %17 = vmvx.buffer.load<%1 : !vmvx.buffer>[%15] : i32
  %18 = muli %16, %17 : i32
  vmvx.buffer.store<%2 : !vmvx.buffer>[%15], %18 : i32
}
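As a quick sanity check (a sketch, not part of the pass), the additive fold can be modeled in Python. The helper names are hypothetical. The model treats index as an unbounded integer, so it ignores overflow, and it assumes a positive step, as scf.for requires.

```python
def original_indices(arg2: int, lb: int, ub: int, step: int) -> list[int]:
    """Values of %13 = addi %arg2, %arg14 over the original loop."""
    # Python's range with a positive step matches scf.for: iterate while iv < ub.
    return [arg2 + iv for iv in range(lb, ub, step)]


def folded_indices(arg2: int, lb: int, ub: int, step: int) -> list[int]:
    """Values of %arg14 after shifting both bounds by the loop-invariant %arg2."""
    return list(range(lb + arg2, ub + arg2, step))


# The example loop: %arg14 = 0 to 10 step 1, with %arg2 = 7.
assert original_indices(7, 0, 10, 1) == folded_indices(7, 0, 10, 1)
```

Shifting both bounds by the same loop-invariant amount preserves the trip count, so each iteration sees exactly the value %13 used to hold.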

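Likewise, we can repeat this on the muli to scale the range. Both bounds and the step are multiplied by %c4, and the loads and the store then index with %arg14 directly: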

%arg2p = addi %arg2, %c10 : index
%arg2lb = muli %arg2, %c4 : index
%arg2pub = muli %arg2p, %c4 : index
scf.for %arg14 = %arg2lb to %arg2pub step %c4 {
  %16 = vmvx.buffer.load<%0 : !vmvx.buffer>[%arg14] : i32
  %17 = vmvx.buffer.load<%1 : !vmvx.buffer>[%arg14] : i32
  %18 = muli %16, %17 : i32
  vmvx.buffer.store<%2 : !vmvx.buffer>[%arg14], %18 : i32
}
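The same model extends to the multiplicative fold. The point it illustrates is that both bounds and the step must be scaled. Scaling only the upper bound produces a different index sequence. Again, the names are hypothetical and overflow is ignored. The sketch only accepts a positive scale factor, because a non-positive one would make the new step invalid for scf.for.

```python
def original_offsets(arg2: int, lb: int, ub: int, step: int, scale: int) -> list[int]:
    """Values of %15 = muli (addi %arg2, %arg14), scale over the original loop."""
    return [(arg2 + iv) * scale for iv in range(lb, ub, step)]


def folded_offsets(arg2: int, lb: int, ub: int, step: int, scale: int) -> list[int]:
    """Values of %arg14 after folding both the addi and the muli into the loop."""
    if scale <= 0:
        raise ValueError("sketch only models a positive scale factor")
    # Fold the addi: shift both bounds by the loop-invariant %arg2.
    lb, ub = lb + arg2, ub + arg2
    # Fold the muli: scale both bounds *and* the step. For scale > 0 this keeps
    # the iteration order and the trip count unchanged.
    return list(range(lb * scale, ub * scale, step * scale))


# %arg2 = 7, loop 0 to 10 step 1, %c4 = 4: offsets 28, 32, ..., 64.
assert original_offsets(7, 0, 10, 1, 4) == folded_offsets(7, 0, 10, 1, 4)
```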

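There are restrictions on when this can be applied. At a minimum, the induction variable must have a single use, that use must be an addi or muli whose other operand is defined outside the loop, and for muli the scale factor must keep the step positive. Within those limits, I believe this optimization will help IREE a lot with the addressing math used in the reference VM.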
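To make the restrictions concrete, here is a minimal Python sketch of a legality check plus a fold driver over a toy IR. Everything in it (Op, Loop, foldable_user, fold_once, example_loop) is hypothetical and is not MLIR's C++ API. The conditions it checks are one plausible, conservative set: a single IV user, which must be an addi or muli, a loop-invariant other operand, and a known positive factor for muli. They are not necessarily what the real pass checks. Bounds are kept as symbolic strings, and constant folding is not modeled.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Op:
    kind: str                      # "addi", "muli", or anything else ("load", ...)
    operands: list[str]            # SSA value names, e.g. ["%arg2", "%arg14"]
    result: Optional[str] = None


@dataclass
class Loop:
    iv: str                        # induction variable, e.g. "%arg14"
    lb: str                        # bounds and step kept as symbolic strings
    ub: str
    step: str
    body: list[Op]
    outside: set[str]              # values defined outside the loop
    constants: dict[str, int] = field(default_factory=dict)


def foldable_user(loop: Loop) -> Optional[Op]:
    """Return the single addi/muli that can be folded into the bounds, if any."""
    users = [op for op in loop.body if loop.iv in op.operands]
    if len(users) != 1:
        return None                # the IV must have exactly one user
    user = users[0]
    if user.kind not in ("addi", "muli") or user.operands.count(loop.iv) != 1:
        return None                # e.g. a load, or addi %iv, %iv
    other = next(v for v in user.operands if v != loop.iv)
    if other not in loop.outside:
        return None                # the other operand must be loop-invariant
    if user.kind == "muli" and loop.constants.get(other, 0) <= 0:
        return None                # conservatively require a known positive factor
    return user


def fold_once(loop: Loop) -> bool:
    """Fold one addi/muli into the loop bounds; return False if nothing folds."""
    user = foldable_user(loop)
    if user is None:
        return False
    other = next(v for v in user.operands if v != loop.iv)
    sym = "+" if user.kind == "addi" else "*"
    loop.lb = f"({loop.lb} {sym} {other})"
    loop.ub = f"({loop.ub} {sym} {other})"
    if user.kind == "muli":
        loop.step = f"({loop.step} * {other})"   # scaling also scales the step
    # The IV now takes the values the folded op used to compute.
    loop.body.remove(user)
    for op in loop.body:
        op.operands = [loop.iv if v == user.result else v for v in op.operands]
    return True


def example_loop() -> Loop:
    """The loop body from the original example, as toy ops."""
    return Loop(
        iv="%arg14", lb="%c0", ub="%c10", step="%c1",
        body=[
            Op("addi", ["%arg2", "%arg14"], "%13"),
            Op("muli", ["%13", "%c4"], "%15"),
            Op("load", ["%0", "%15"], "%16"),
            Op("load", ["%1", "%15"], "%17"),
            Op("muli", ["%16", "%17"], "%18"),
            Op("store", ["%2", "%15", "%18"]),
        ],
        outside={"%arg2", "%c4", "%0", "%1", "%2"},
        constants={"%c4": 4},
    )
```

Running fold_once to a fixpoint on example_loop() performs the two folds described above and then stops, because %arg14 ends up with three users (the two loads and the store).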

My questions are: is this transformation worth a pass in the MLIR codebase (so that IREE and others can benefit from it upstream), and are there any concerns about it?

Thanks!


This seems general enough to be implemented in the SCF dialect, thanks!

I’ve got a patch up for review at D104289 ("First crack at for loop range folding pass"). I am new to the MLIR project and would appreciate any tips on how to move forward with this, including comments on correctness. I am not claiming this is complete yet; it is probably far from it. :)