Bufferize dynamic shape control flow

Hi all,

We are currently upstreaming our dynamic shape compiler (a.k.a. DISC). I’d like to discuss support for dynamic shape control flow in advance and would appreciate any feedback.

1, Control flow overview

  • Structured control flow (e.g. mhlo/lmhlo WhileOp, IfOp): uses regions to represent conditional/while semantics.
  • CFG: the traditional form, based on branch ops. More general, but harder to optimize. (A sketch contrasting the two forms follows.)
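For illustration, here is a minimal sketch of the same loop in both forms (written with the upstream scf dialect and std branch ops for concreteness; the condition and body computations are elided):

// Structured form: the loop is a single op holding a condition region and a
// body region; dataflow is explicit via region arguments.
%r = scf.while (%v = %init) : (tensor<?x?xf32>) -> tensor<?x?xf32> {
  %pred = ... : i1
  scf.condition(%pred) %v : tensor<?x?xf32>
} do {
^bb0(%v: tensor<?x?xf32>):
  %next = ... : tensor<?x?xf32>
  scf.yield %next : tensor<?x?xf32>
}

// CFG form: the same loop as basic blocks and branch ops.
  br ^cond(%init : tensor<?x?xf32>)
^cond(%v: tensor<?x?xf32>):
  %pred = ... : i1
  cond_br %pred, ^body(%v : tensor<?x?xf32>), ^exit(%v : tensor<?x?xf32>)
^body(%v: tensor<?x?xf32>):
  %next = ... : tensor<?x?xf32>
  br ^cond(%next : tensor<?x?xf32>)
^exit(%v: tensor<?x?xf32>):
  // use %v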

2, Dynamic shape control flow

The shapes of values inside a while/if body are unknown at compile time and may vary between iterations.

An example:

func @dynamic_shape_while(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %1 = "mhlo.while"(%arg1) ( {
  ^bb0(%targ1: tensor<?x?xf32>):
    %true = mhlo.constant dense<1> : tensor<i1>
    %0 = "mhlo.multiply"(%targ1, %targ1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
    %1 = "mhlo.subtract"(%targ1, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>  
    %2 = "mhlo.compare"(%targ1, %1) {comparison_direction = "LT"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
    %3 = "mhlo.reduce"(%2, %true) ( {
    ^bb0(%lhs: tensor<i1>, %rhs: tensor<i1>):
       %4 = "mhlo.and"(%lhs, %rhs) : (tensor<i1>, tensor<i1>) -> tensor<i1>
      "mhlo.return"(%4) : (tensor<i1>) -> ()
    }) {dimensions = dense<[0,1]> : tensor<2xi64>} : (tensor<?x?xi1>, tensor<i1>) -> tensor<i1>
    "mhlo.return"(%1) : (tensor<i1>) -> ()
  },  {
  ^bb0(%targ1: tensor<?x?xf32>):
     %0 = "mhlo.multiply"(%targ1, %targ1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
     %1 = "mhlo.concatenate"(%0, %0) { dimension = 0 : i64 } : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
     "mhlo.return"(%1) : (tensor<?x?xf32>) -> ()
  }) : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
  return %1 : tensor<?x?xf32>
}

3, Bufferization for dynamic shape control flow

3.1 Solution #1: mhlo while → CFG at the tensor level → bufferization.

To the best of my knowledge, the bufferization (and buffer deallocation) passes do not support dynamic shapes very well. For example, the deallocation pass cannot handle dynamic CFG loops according to this and this. In fact, I think it’s very hard to implement a general deallocation pass for dynamic shape control flow in CFG form. Please correct me if I’m wrong.
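For example, here is a hand-written sketch (not actual pass output) of the pattern I believe is hard: a buffer flows around the back edge as a block argument, and its size may change every iteration, so there is no static program point where a dealloc is always correct:

^cond(%buf: memref<?x?xf32>):
  %pred = ...
  cond_br %pred, ^body(%buf), ^exit(%buf)
^body(%buf: memref<?x?xf32>):
  // sizes are only known at runtime and may differ from %buf's
  %next = memref.alloc(...) : memref<?x?xf32>
  // fill %next using %buf
  // %buf should be freed here, but on the first iteration %buf may be the
  // function argument (or an alias of it), which must not be freed
  br ^cond(%next)
^exit(%buf: memref<?x?xf32>):
  // is %buf a function argument or a loop-allocated buffer? Not statically
  // known without inserting copies up front.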

3.2 Solution #2: mhlo while → lmhlo while (bufferize in structured form) → CFG at the buffer level.

This is our current implementation. However, we have to extend the lmhlo.while/if ops, since they do not support dynamic shapes well. Take the lmhlo.while op as an example.

// original definition
def LHLO_WhileOp: LHLO_Op<"while", [
      DeclareOpInterfaceMethods<RegionBranchOpInterface>,
      DeclareOpInterfaceMethods<LoopLikeOpInterface>]> {
  let summary = "While operator";
  let description = [{
    Returns the result of executing a body function until the cond body returns
    true.

    See https://www.tensorflow.org/xla/operation_semantics#while.
  }];
  let arguments = (ins
    Arg<Variadic<LHLO_PredBuffer>, "", [MemWrite]>:$cond_val,
    OptionalAttr<I64Attr>:$trip_count);

  let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
}

// Dynamic shape version for LHLO_WhileOp
def LHLO_WhileOp: LHLO_Op<"dynamic_while", [
      DeclareOpInterfaceMethods<RegionBranchOpInterface>,
      DeclareOpInterfaceMethods<LoopLikeOpInterface>]> {
  let summary = "While operator";
  let description = [{
    Returns the result of executing a body function until the cond body returns
    true.

    See https://www.tensorflow.org/xla/operation_semantics#while.
  }];
  let arguments = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$args,
    OptionalAttr<I64Attr>:$trip_count);
    
  let results = (outs Variadic<LHLO_Buffer>:$results);

  let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
}

The main differences are:

  • dynamic_while op does not accept or return tuple types.
  • dynamic_while op accepts a list of buffers and returns a list of buffers.
  • the buffers returned from the dynamic_while op are transferred to the caller, and it is the caller’s responsibility to deallocate these buffers correctly.
  • the while body needs to deallocate its operands.

Some examples:

func @test(%arg0: memref<?x?xf32>) -> memref<?x?xf32> {
  // ownership of %0 and %1 is transferred to the caller
  %0, %1 = "lmhlo.dynamic_while"(%arg0, %arg0) ( {
    // condition region
  }, {
  ^bb0(%targ0: memref<?x?xf32>, %targ1: memref<?x?xf32>):
    // loop body:
    // allocate new buffers (sizes may change)
    // use %targ0 and %targ1 to fill the new buffers
    // deallocate %targ0 and %targ1
    // return the new buffers
  }) : (memref<?x?xf32>, memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>)
  memref.dealloc %0 : memref<?x?xf32>
  return %1: memref<?x?xf32>
}

// lower dynamic_while to CFG
^init(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>):
  %0 = memref.alloc(...) // alloc_like(%arg0)
  %1 = memref.alloc(...) // alloc_like(%arg1)
  "lmhlo.copy"(%arg0, %0) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.copy"(%arg1, %1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  br ^cond(%0, %1)
^cond(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>):
  %pred = ...
  cond_br %pred, ^body(%arg0, %arg1), ^exit(%arg0, %arg1)
^body(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>):
  // allocate new buffers: %targ0, %targ1 (size may change)
  // use %arg0 and %arg1 to fill the buffer
  // deallocate %arg0 and %arg1
  br ^cond(%targ0, %targ1)
^exit(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>):
  // return %arg0 and %arg1

Disadvantages:

  • redundant copies, especially if the loop body does not execute at all.
  • buffer management is different inside the loop region (e.g. operands need to be deallocated).

Another problem is how to compose this with the existing bufferization passes.

Any suggestions?

Thanks!

Hi - thanks for the note. We’ve been on a similar journey with IREE for a while now and have realized some good simplifications. We’re not done yet, but maybe some of what we have learned will be useful. There are so many different levels to solve this problem that what we have done may or may not translate – so just putting it out there as a data point. Since we’ve landed in a pretty stable place, we’re preparing a round of tech talks to present in the ~next month with details of our approach.

IREE takes this option, lowering all general control flow early to CFG (i.e. MHLO/TOSA while/if). This certainly has implications and detractors but we have found it easier to do the kinds of things we want to do with general program control flow in CFG form, reserving structured control flow for representing loop nests and computations. IREE also has a strong memory model and device/task placement model that we lower into early while still in the tensor domain, and this simplifies buffer planning a lot. I suspect you are going to have a hard time implementing good/general buffer alloc/dealloc at this level with just the current structural elements in MHLO today. Another difference is that at this top-level, IREE’s buffers are not just naked pointers (i.e. memrefs) but ref-counted entities that make it easier to reason about ownership transfers between (especially asynchronous or non-straightline) components of the system. Beyond simple programs, we found that it gets increasingly twisty to do everything one wants to be doing without introducing at least minimal runtime constructs with strong contracts for managing things.

I’m not an expert on LMHLO (we don’t use it), but it seems to me for the structures and design space you are operating in, extending LMHLO constructs to carry the information you need is a reasonable option. As you note, if you have extra entities to carry, you have to do that somewhere. Without a stronger memory/ownership model, I don’t immediately see a better way to avoid the disadvantages you note, but I wouldn’t be surprised if there are some special case optimizations you can do to improve some of it.

Thanks for your reply.

This is a very interesting idea. Are there any more details about it?

We tried something like this before (using ref-counted buffers on the runtime side rather than on the compiler side) and used it to avoid redundant copies. For example:

// Original case (redundant alloc & copy)
func @test(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) -> memref<?x?xf32> {
  %0 = "lmhlo.dynamic_if"(%arg0, %arg1) ({
    // true branch
    // Here is the redundant alloc & copy
    %0 = memref.alloc(...) // alloc_like(%arg0)
    "lmhlo.copy"(%arg0, %0) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
    return %0 : memref<?x?xf32>
  }, {
    // false branch
    // Here is the redundant alloc & copy
    %0 = memref.alloc(...) // alloc_like(%arg1)
    "lmhlo.copy"(%arg1, %0) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
    return %0 : memref<?x?xf32>
  })

  %1 = memref.alloc(...) // alloc_like(%0)
  "lmhlo.abs"(%0, %1) : (memref<?x?xf32>, memref<?x?xf32>) -> memref<?x?xf32>
  memref.dealloc(%0) : memref<?x?xf32>
  return %1 : memref<?x?xf32>
}

After optimization:

// After optimization:
//   Introduce an op that increases the ref-count of a buffer (memref) on the
//   runtime side and returns a new memref to the compiler side (which can be
//   deallocated separately).
// The optimization below is based on the assumption that a buffer is only
// written once and is read-only after that.

func @test(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) -> memref<?x?xf32> {
  %0 = "lmhlo.dynamic_if"(%arg0, %arg1) ({
    // true branch
    // %0 and %arg0 share the same underlying buffer.
    // Deallocating %0 or %arg0 decreases the reference count by 1 on the
    // runtime side; the buffer is really freed only when the ref-count is zero.
    %0 = "ral.read_only_copy"(%arg0) : (memref<?x?xf32>) -> memref<?x?xf32>
    return %0 : memref<?x?xf32>
  }, {
    // false branch
    // %0 and %arg1 share the same underlying buffer.
    // Deallocating %0 or %arg1 decreases the reference count by 1 on the
    // runtime side; the buffer is really freed only when the ref-count is zero.
    %0 = "ral.read_only_copy"(%arg1) : (memref<?x?xf32>) -> memref<?x?xf32>
    return %0 : memref<?x?xf32>
  })

  %1 = memref.alloc(...) // alloc_like(%0)
  "lmhlo.abs"(%0, %1) : (memref<?x?xf32>, memref<?x?xf32>) -> memref<?x?xf32>
  memref.dealloc(%0) : memref<?x?xf32>
  return %1 : memref<?x?xf32>
}

The above example only uses ref-counts on the runtime side; the compiler side still uses plain memrefs directly, so it is coupled with the runtime. It may enable more aggressive optimizations if a ref-counted memory type system is introduced on the compiler side.

Hi,

This is slightly off topic, but why do you need to ultimately lower to CFG? Can you keep structured control flow throughout the pipeline?

Can you solve both of these issues by not deallocating operands in the while body in the first iteration?

– Sanjoy

@sanjoy_das_google Thanks.

1, We currently use host-side and device-side joint codegen. Thus, the control flow ops eventually need to be lowered to CFG and then to LLVM IR.
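For concreteness, this roughly corresponds to ending the pipeline with the upstream conversion passes, e.g. (the actual DISC pipeline has more steps in between):

mlir-opt input.mlir -convert-scf-to-std -convert-std-to-llvm

i.e. structured control flow ops are first flattened into blocks with br/cond_br, and the resulting CFG is then translated into the LLVM dialect.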

2, We could choose not to deallocate the operand buffers of the body region of a while op in the first iteration. However, in that case we have to track potential buffer aliases, which may not always be easy.

For example:

result_buffer1, result_buffer2 = while(input_buffer1, input_buffer2) {
   %cond = ...
   // %0 may alias either input_buffer1 or input_buffer2 depending on %cond,
   // so whether the body may deallocate it is not statically known.
   %0 = if (%cond, input_buffer1, input_buffer2) {
      return input_buffer1;
   } else {
      return input_buffer2;
   }
   %1 = memref.alloc(...) 
   use(%1)

   forward_to_next_iteration_or_break_with(%1, %0)
}

One possibility (which is just a variant of what you and Stella suggested above) is to have explicit operations & values for managing memref refcounts. Each tensor could then be mapped to a memref and a refcount, and the emitted IR would explicitly manipulate this refcount.

So

^bb0(%x: tensor, %y: tensor)
  %a = add(%x, %y)

would become:

^bb0(%x: memref, %x_refcnt : refcnt, %y: memref, %y_refcnt : refcnt)
  // Increment by <number of uses>-1
  inc_refcount(%x_refcnt, 0)
  inc_refcount(%y_refcnt, 0)

  %a_memref = alloc()
  %a_refcnt = alloc_refcnt() // == 1 at initialization
  add(%x, %y, %a_memref)

  // dec_refcount may deallocate the memref
  dec_refcount(%x, %x_refcnt, 1)
  dec_refcount(%y, %y_refcnt, 1)

The benefit of making these refcounts explicit as values & operations is that we should be able to optimize them out in most simple cases, allowing XLA-style memory planning. In truly dynamic cases like the example you have, the refcount operations will persist and will provide correct allocation & deallocation. refcnt could just be memref<i32>, or a custom type.

(And, just to be clear, in this world the LHLO while loop body will consume & produce a list of memrefs and refcounts.)
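For example, in a simple straight-line case the refcount of a locally allocated buffer never escapes and is statically known at every point, so (with the same hypothetical ops as above) the refcount operations fold away into a plain dealloc:

// Before folding: refcounts carried explicitly.
%t_memref = alloc()
%t_refcnt = alloc_refcnt()             // == 1 at initialization
add(%x, %y, %t_memref)
%a_memref = alloc()
%a_refcnt = alloc_refcnt()
mul(%t_memref, %t_memref, %a_memref)
dec_refcount(%t_memref, %t_refcnt, 1)  // count provably reaches 0 here

// After folding: %t_refcnt does not escape, so the alloc_refcnt /
// dec_refcount pair reduces to a statically placed dealloc, XLA-style.
%t_memref = alloc()
add(%x, %y, %t_memref)
%a_memref = alloc()
mul(%t_memref, %t_memref, %a_memref)
dealloc(%t_memref)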
