[RFC] Multi-root PDL patterns for kernel matching

Abstract

The goal of this RFC is to extend the current implementation of PDL to support patterns with multiple roots. This extension is needed to match complex patterns that arise when matching kernels in model parallel training of neural networks. We achieve this by extending the pdl_interp dialect, updating the pdl-to-pdl_interp lowering, and generalizing the execution model of the PDL bytecode to support iterative execution.

1. Introduction

Motivation

PDL is a high-level abstraction for the rewrite pattern infrastructure in MLIR. A PDL pattern consists of a pattern (match) region and a rewrite region. The pattern region specifies a DAG of operations starting from one or more operands. The rewrite region specifies, for a single root operation, how this operation and its descendants are altered. Notice the asymmetry: although the pattern region is a general DAG, the current semantics of the rewrite region force the DAG to have a single root (in effect, a tree, possibly with some shared inner nodes). Traditionally, this limitation has not been too restrictive: the rewrite rules used for optimizations or lowerings often operate on a single root, rewriting a tree of operations. However, model parallel training of neural networks may tie together the forward and the backward execution of a kernel, because doing so leverages locality, allowing the weights to be stored locally and used in both the forward and the backward pass. This use case requires the patterns to be general DAGs.

As an example, consider matching the following MLIR snippet, representing a 2-layer perceptron:

func @main(%arg0: tensor<2x20xf32>, %arg1: tensor<2xi32>, %arg2: tensor<256xf32>, %arg3: tensor<20x256xf32>, %arg4: tensor<256x10xf32>) -> (tensor<f32>, tensor<256xf32>, tensor<20x256xf32>, tensor<256x10xf32>) {
  %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
  %1 = "tf.Const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32>
  %2 = "tf.Const"() {value = dense<5.000000e-01> : tensor<2x1xf32>} : () -> tensor<2x1xf32>

  // FC layer 1 (forward)
  %3 = "tf.MatMul"(%arg0, %arg3) {transpose_a = false, transpose_b = false} : (tensor<2x20xf32>, tensor<20x256xf32>) -> tensor<2x256xf32>
  %4 = "tf.BiasAdd"(%3, %arg2) {data_format = "NHWC"} : (tensor<2x256xf32>, tensor<256xf32>) -> tensor<2x256xf32>
  %5 = "tf.Relu"(%4) : (tensor<2x256xf32>) -> tensor<2x256xf32>

  // FC layer 2 (forward)
  %6 = "tf.MatMul"(%5, %arg4) {transpose_a = false, transpose_b = false} : (tensor<2x256xf32>, tensor<256x10xf32>) -> tensor<2x10xf32>

  // Softmax cross entropy
  %loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%6, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<2xf32>, tensor<2x10xf32>)
  %7 = "tf.Mean"(%loss, %0) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
  %8 = "tf.PreventGradient"(%backprop) : (tensor<2x10xf32>) -> tensor<2x10xf32>
  %9 = "tf.Mul"(%8, %2) : (tensor<2x10xf32>, tensor<2x1xf32>) -> tensor<2x10xf32>

  // FC layer 2 (gradient computation)
  %10 = "tf.MatMul"(%9, %arg4) {transpose_a = false, transpose_b = true} : (tensor<2x10xf32>, tensor<256x10xf32>) -> tensor<2x256xf32>
  %11 = "tf.MatMul"(%5, %9) {transpose_a = true, transpose_b = false} : (tensor<2x256xf32>, tensor<2x10xf32>) -> tensor<256x10xf32>

  // FC layer 1 (gradient computation)
  %12 = "tf.ReluGrad"(%10, %5) : (tensor<2x256xf32>, tensor<2x256xf32>) -> tensor<2x256xf32>
  %13 = "tf.BiasAddGrad"(%12) {data_format = "NHWC"} : (tensor<2x256xf32>) -> tensor<256xf32>
  %14 = "tf.MatMul"(%arg0, %12) {transpose_a = true, transpose_b = false} : (tensor<2x20xf32>, tensor<2x256xf32>) -> tensor<20x256xf32>

  // FC layer 1 (parameter update)
  %15 = "tf.Mul"(%14, %1) : (tensor<20x256xf32>, tensor<f32>) -> tensor<20x256xf32>
  %16 = "tf.Sub"(%arg3, %15) : (tensor<20x256xf32>, tensor<20x256xf32>) -> tensor<20x256xf32>
  %17 = "tf.Mul"(%13, %1) : (tensor<256xf32>, tensor<f32>) -> tensor<256xf32>
  %18 = "tf.Sub"(%arg2, %17) : (tensor<256xf32>, tensor<256xf32>) -> tensor<256xf32>

  // FC layer 2 (parameter update)
  %19 = "tf.Mul"(%11, %1) : (tensor<256x10xf32>, tensor<f32>) -> tensor<256x10xf32>
  %20 = "tf.Sub"(%arg4, %19) : (tensor<256x10xf32>, tensor<256x10xf32>) -> tensor<256x10xf32>

  return %7, %18, %16, %20 : tensor<f32>, tensor<256xf32>, tensor<20x256xf32>, tensor<256x10xf32>
}

This example includes a fully connected (FC) layer with bias and a rectified linear activation unit (ReLU), a second FC layer without bias/ReLU, and a softmax cross entropy unit.

Suppose that we wish to match the first FC layer, including the forward activations and the parameter update, omitting the gradient computation for brevity of the example. This can be achieved with the following PDL pattern:

pdl.pattern : benefit(5) {
  %act = pdl.operand
  %weight = pdl.operand
  %bias = pdl.operand
  %weight_grad = pdl.operand
  %bias_grad = pdl.operand
  %lr = pdl.operand
  %vec = pdl.type
  %mat = pdl.type
  %matmul = pdl.operation "tf.MatMul"(%act, %weight : !pdl.value, !pdl.value) -> (%vec : !pdl.type)
  %matmul_result = pdl.result 0 of %matmul
  %biasadd = pdl.operation "tf.BiasAdd"(%matmul_result, %bias : !pdl.value, !pdl.value) -> (%vec : !pdl.type)
  %biasadd_result = pdl.result 0 of %biasadd
  %relu = pdl.operation "tf.Relu"(%biasadd_result : !pdl.value) -> (%vec : !pdl.type)
  %weight_mul = pdl.operation "tf.Mul"(%weight_grad, %lr : !pdl.value, !pdl.value) -> (%mat : !pdl.type)
  %weight_update = pdl.result 0 of %weight_mul
  %weight_sub = pdl.operation "tf.Sub"(%weight, %weight_update : !pdl.value, !pdl.value) -> (%mat : !pdl.type)
  %bias_mul = pdl.operation "tf.Mul"(%bias_grad, %lr : !pdl.value, !pdl.value) -> (%vec : !pdl.type)
  %bias_update = pdl.result 0 of %bias_mul
  %bias_sub = pdl.operation "tf.Sub"(%bias, %bias_update : !pdl.value, !pdl.value) -> (%vec : !pdl.type)

  pdl.rewrite %relu, %weight_sub, %bias_sub with "rewriter"
}

This pattern includes six operands: input activations %act, input weight %weight, input bias %bias, weight gradient %weight_grad, bias gradient %bias_grad, and the learning rate %lr (the weight/bias gradients are normally computed from the ReLU gradient; we specify them here as operands to keep our example pattern small). The pattern contains three roots: the ReLU operation %relu producing output activations, the subtraction %weight_sub producing the updated weight, and the subtraction %bias_sub producing the updated bias. The visualization of the pattern is shown below; note that it is impossible to match such a pattern using the current PDL infrastructure.

              act ---> matmul ---> biasadd ---> relu ---> act_out
                         ^            ^
                         |            |
                       weight       bias
                         |            |
                         v            v
    weight_out <--- weight_sub    bias_sub ---> bias_out
                         ^            ^
                         |            |
   weight_grad ---> weight_mul    bias_mul <--- bias_grad
                         ^            ^
                         |            |
                         +---- lr ----+

Objective

As shown in the example above, complex patterns that arise when matching kernels in model parallel training of neural networks can be DAGs with multiple roots. In this RFC, we describe how to extend the current implementation of PDL to support such patterns without changes to the MLIR rewrite model.

Approach

A PDL pattern is executed by lowering it to a lower-level abstraction, represented in the pdl_interp dialect. The current version of the pdl-to-pdl_interp lowering converts the pattern to a collection of predicates on positions, obtained by traversing the pattern from the given root downwards. Each position is a sequence of operand indices; for example, in the expression foo(bar, baz(x, y)), foo is the root operation at position [], bar is a leaf at position [0], and x is a leaf at position [1, 0]. Intuitively, these positions are obtained by “hanging” the pattern DAG from the root, arbitrarily resolving the positions of nodes reachable by multiple paths from the root.
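
For intuition, resolving a position against a concrete root operation might look as follows in C++ (a minimal sketch using the MLIR API; resolve is a hypothetical helper for illustration, not part of the actual lowering, which manipulates positions symbolically rather than resolving them eagerly):

#include "mlir/IR/Operation.h"
#include "llvm/ADT/SmallVector.h"

// A position is a sequence of operand indices, e.g., [1, 0] for x in
// foo(bar, baz(x, y)). The empty position denotes the root operation itself.
using Position = llvm::SmallVector<unsigned, 4>;

// Walk downward from `root`, taking the operand with the given index at
// each step of the position.
mlir::Value resolve(mlir::Operation *root, const Position &position) {
  mlir::Operation *op = root;
  mlir::Value value;
  for (unsigned index : position) {
    if (!op)
      return mlir::Value(); // tried to descend below a block argument
    value = op->getOperand(index);
    op = value.getDefiningOp(); // null when `value` is a block argument
  }
  return value;
}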

In the fully connected layer example shown above, such an approach is not possible: no matter which root we choose, we cannot access all the nodes in the pattern solely via operand (downward) traversal. One option to address this problem is to match the DAG underneath each root separately and then join these DAGs on the common operations, attributes, and types. Although this is possible, it represents a significant departure from the current rewrite architecture of MLIR and thus carries significant risk. Instead, we propose a simpler approach, where the pattern is “hung” by one of the candidate roots, and some of the edges are traversed upwards (towards the operations accepting the source value) as needed. Because multiple operations can accept the same value, we need to perform (a limited form of) subgraph isomorphism search. We implement this search by introducing iteration to the execution of the pdl_interp bytecode.

Outline

The rest of this document is organized as follows. In Section 2, we describe the new operations we propose to add to the pdl_interp dialect, including one new iterative operation. In Section 3, we discuss how such an operation can be supported in the bytecode interpreter. In Section 4, we explain how to select the optimal root from which to “hang” the pattern. Finally, in Section 5, we discuss the changes to the pdl-to-pdl_interp lowering needed to integrate these pieces. These four sections coincide with the commits implementing the multi-root support that accompany this RFC (to be submitted shortly).

2. New operations in the pdl_interp dialect

The pdl_interp dialect provides a lower-level abstraction than the PDL dialect and is intended to be directly translated to the bytecode executing the match and rewrites. Here, we focus solely on the matcher portion of pdl_interp; no changes to the rewrite portion are needed.

The matcher region of the pdl_interp dialect consists of a sequence of blocks. Each block performs zero or more assignments to variables and is terminated by a single predicate check that branches to subsequent blocks. The assignments of values / operations are done in the order of the traversal; this is necessary, because each assignment “unlocks” the access to the values below it. The predicates, however, are not tested in the order of traversal; this is because the order of predicate evaluation is optimized over all the patterns given to the pdl-to-pdl_interp lowering. The execution of the matcher follows the branching until we reach a block containing the pdl_interp.finalize operation.

The current matching approach always follows the operations by downward traversal, going from an operation to its operands, starting from the (single) specified root. Such execution is deterministic: once the root operation is fixed, the values / operations associated with each (nested) operand are fixed. In the generalization proposed here, we also allow ourselves to traverse upward (going from a value to its consumers). Such execution is iterative, because there may be multiple operations accepting the value. In order to support upward traversal, we introduce two new operations:

  • pdl_interp.get_accepting_ops: Returns a list of operations accepting the given value, or a range of values, at the specified position. Thus, if two operations %op1 = "foo"(%val) and %op2 = "bar"(%val) accept %val at position 0, %ops = pdl_interp.get_accepting_ops of %val : !pdl.value at 0 will return a list (range) containing both of them. This allows us to traverse upwards from a value to the operations accepting that value (see the C++ sketch below).
  • pdl_interp.choose_op: Iteratively chooses one operation from a range of operations. Thus, writing %op = pdl_interp.choose_op from %ops in the example above selects %op1 and %op2, one after the other.
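
For intuition, the upward traversal performed by pdl_interp.get_accepting_ops corresponds roughly to the following use of the MLIR use-list API in C++ (a sketch for the single-value case; the actual implementation also handles ranges of values):

#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"

// Collect the operations that accept `value` as their operand number
// `operandIndex`. In the example above, getAcceptingOps(%val, 0) returns
// a list containing %op1 and %op2.
llvm::SmallVector<mlir::Operation *, 4>
getAcceptingOps(mlir::Value value, unsigned operandIndex) {
  llvm::SmallVector<mlir::Operation *, 4> accepting;
  for (mlir::OpOperand &use : value.getUses())
    if (use.getOperandNumber() == operandIndex)
      accepting.push_back(use.getOwner());
  return accepting;
}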

A fragment of the pdl_interp matcher for the FC layer pattern is shown below:

func @matcher(%arg0: !pdl.operation) {
  %0 = pdl_interp.get_operand 0 of %arg0
  %1 = pdl_interp.get_defining_op of %0 : !pdl.value
  pdl_interp.is_not_null %1 : !pdl.operation -> ^bb2, ^bb1
^bb1:  // 57 preds
  pdl_interp.finalize
^bb2:  // pred: ^bb0
  %2 = pdl_interp.get_operand 0 of %1
  %3 = pdl_interp.get_defining_op of %2 : !pdl.value
  pdl_interp.is_not_null %3 : !pdl.operation -> ^bb3, ^bb1
^bb3:  // pred: ^bb2
  %4 = pdl_interp.get_operand 1 of %1
  %5 = pdl_interp.get_accepting_ops of %4 : !pdl.value at 0
  %6 = pdl_interp.choose_op from %5
  ...
^bb5:  // pred: ^bb4
  pdl_interp.check_operation_name of %arg0 is "tf.Relu" -> ^bb6, ^bb1
  ...
^bb10:  // pred: ^bb9
  pdl_interp.check_operation_name of %1 is "tf.BiasAdd" -> ^bb11, ^bb1
  ...
^bb18:  // pred: ^bb17
  pdl_interp.check_operation_name of %6 is "tf.Sub" -> ^bb19, ^bb1
^bb19:  // pred: ^bb18
  pdl_interp.check_operation_name of %3 is "tf.MatMul" -> ^bb20, ^bb1

In this matcher, %arg0 is the ReLU operation (%relu), %1 is the bias addition (%biasadd), %3 is the matrix-matrix multiply (%matmul), and %4 is the input bias (%bias). Note that in block ^bb3, we extract the operations accepting the input bias at position 0 (%5) and then iteratively choose an operation from this list (%6). The order of predicate checks could be further optimized, so that we match the operation name "tf.BiasAdd" before we extract its bias operand (%4) and make an iterative choice. Such an optimization does not affect correctness, and we plan to address it later by testing on a large collection of input patterns.

3. Iterative bytecode execution

Before we discuss the generalizations needed to implement the new pdl_interp.choose_op operation, it is useful to review how the bytecode is generated and executed presently. This will make it easier to understand the proposed changes.

The PDL bytecode is stored in a flat array as a sequence of fields and addresses. The bytecode executor maintains a pointer (address) to the current instruction and iteratively executes the pdl_interp operations until we reach pdl_interp.finalize. The translation between the pdl_interp dialect and the bytecode is mostly direct, with one noteworthy optimization: memory allocation. The variables declared in the pdl_interp matcher often do not interact, because they live in unrelated branches of the execution. In order to save memory, the bytecode generator performs the analogue of register allocation, with an unlimited number of registers whose count we try to minimize. This optimization is performed by storing, for each variable, the interval map of bytecode positions where this variable is used. The generator then greedily merges the storage of variables whose interval maps do not overlap. This procedure greatly reduces the number of registers (memory locations) used.
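
To make the greedy merging concrete, below is a simplified model in C++ (hypothetical names and simplified data structures; the actual generator tracks liveness with interval maps over bytecode positions, as described above):

#include <map>
#include <vector>

// A variable's live ranges over bytecode positions, as disjoint
// [start, end) intervals keyed by their start.
using LiveRanges = std::map<unsigned, unsigned>;

// Two sets of ranges overlap if any pair of intervals intersects.
static bool overlaps(const LiveRanges &a, const LiveRanges &b) {
  for (auto [aStart, aEnd] : a)
    for (auto [bStart, bEnd] : b)
      if (aStart < bEnd && bStart < aEnd)
        return true;
  return false;
}

// Greedily place each variable into the first register whose accumulated
// ranges do not overlap the variable's ranges; open a new register otherwise.
std::vector<unsigned> allocateRegisters(const std::vector<LiveRanges> &vars) {
  std::vector<LiveRanges> registers;
  std::vector<unsigned> assignment(vars.size());
  for (unsigned v = 0; v < vars.size(); ++v) {
    unsigned r = 0;
    while (r < registers.size() && overlaps(registers[r], vars[v]))
      ++r;
    if (r == registers.size())
      registers.emplace_back();
    registers[r].insert(vars[v].begin(), vars[v].end());
    assignment[v] = r;
  }
  return assignment;
}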

In order to support the new pdl_interp.choose_op operation, we make the following two changes to the bytecode execution and generation:

  1. In addition to the current address, we maintain a stack of addresses at which execution should resume. Whenever we encounter a pdl_interp.choose_op that has not exhausted its iteration range, we push the current address onto the stack and proceed with the execution as usual. Once we reach pdl_interp.finalize, we do not abort the computation right away. Rather, if the stack is non-empty, we resume the execution at the topmost address. Thus, we effectively perform recursion within the current (sequential) implementation of the bytecode executor (see the sketch after this list).
  2. Because pdl_interp.finalize no longer terminates the execution, we may need to access values that were previously assigned (if the execution goes back to pdl_interp.choose_op). For this execution to work correctly, we need to extend the lifetimes of the values that are alive at the time an iterative operation is performed. These lifetimes need to be extended in all the branches that the bytecode interpreter might take.
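
Sketched in C++, the modified executor loop might look as follows (hypothetical opcodes and helpers; the actual ByteCodeExecutor is considerably more involved):

#include <cstdint>
#include <vector>

// Hypothetical opcodes and helpers standing in for the real bytecode.
enum OpCode : uint8_t { ChooseOp, Finalize /* ... */ };
bool hasNextChoice(const uint8_t *instr);          // elements left to iterate?
const uint8_t *bindChoiceAndAdvance(const uint8_t *instr);

void execute(const std::vector<uint8_t> &bytecode) {
  const uint8_t *curr = bytecode.data();
  // Change 1: a stack of addresses at which execution should resume.
  std::vector<const uint8_t *> resumeStack;
  while (true) {
    switch (static_cast<OpCode>(*curr)) {
    case ChooseOp:
      // If the iteration range is not yet exhausted, remember this address
      // so that finalize can come back for the remaining choices.
      if (hasNextChoice(curr))
        resumeStack.push_back(curr);
      curr = bindChoiceAndAdvance(curr);
      break;
    case Finalize:
      // Do not abort right away: resume at the most recent unfinished
      // choose_op, effectively recursing within the sequential executor.
      if (resumeStack.empty())
        return;
      curr = resumeStack.back();
      resumeStack.pop_back();
      break;
      // ... all other opcodes execute as before ...
    }
  }
}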

These changes are sufficient to implement the proposed extension and may be useful in further development of PDL.

4. Root ordering

Having described the extensions to the pdl_interp dialect and the underlying bytecode, we now examine the task of determining the starting node for matching and how the pattern is traversed. Here, the intuition is twofold:

  • Intuition 1: The iterative pdl_interp.choose_op operations are costly, both because there may be multiple consumers (operations) of the same value, and because there is overhead (mostly in memory) to having iterative operations in our bytecode. Unfortunately, it is hard to quantify the former, because we do not know the number of uses of a value until the pattern is matched (although we could assume that the pattern lists all uses of a value). Therefore, we choose a simple metric of minimizing the number of pdl_interp.choose_op operations, which is the same as minimizing the number of upward traversals.
  • Intuition 2: We only need to consider roots when selecting the starting node for matching and how the pattern is traversed. Clearly, selecting a non-root node can only increase the number of iterative operations. Furthermore, the traversals have to include all roots, and once we reach a root, there are no more upward traversals needed for the part of the pattern underneath this root. Therefore, we can use roots as anchors in our traversal search, where traversing one root after another corresponds to a path in the underlying pattern.

Following this intuition, we treat optimal traversal as a root ordering problem. We form a graph over the roots specified in pdl.rewrite, where two roots are connected by a directed edge if the target root can be reached (via a chain of operations) from the source root in the underlying pattern. We restrict the path connecting the two candidate roots to contain only the nodes in the subgraphs underneath these two roots. The cost of an edge is the smallest number of upward traversals (edges) required to go from the source to the target root, and the connector is a Value in the intersection of the two subtrees rooted at the source and the target root that realizes this smallest number of upward traversals. Optimal root ordering is then formulated as the problem of finding a spanning arborescence (i.e., a directed spanning tree) of minimal weight.

Consider once again the pattern shown in Section 1, repeated below:

              act ---> matmul ---> biasadd ---> relu ---> act_out
                         ^            ^
                         |            |
                       weight       bias
                         |            |
                         v            v
    weight_out <--- weight_sub    bias_sub ---> bias_out
                         ^            ^
                         |            |
   weight_grad ---> weight_mul    bias_mul <--- bias_grad
                         ^            ^
                         |            |
                         +---- lr ----+

The cost of the edge relu -> weight_sub is 1 (the out-edge traversal weight -> weight_sub), with weight being the connector, and similarly for relu -> bias_sub (cost 1, connector bias). The cost of the edge weight_sub -> relu is 3 (traversals weight -> matmul -> biasadd -> relu, connector weight), while the cost of the edge bias_sub -> relu is 2 (traversals bias -> biasadd -> relu, connector bias). The edges between weight_sub and bias_sub in the cost graph each have weight 2, because they need to go through the common node lr. It is easy to see that the optimal root for this pattern is relu, resulting in the spanning arborescence relu -> {weight_sub, bias_sub} with weight 2.

In order to determine the spanning arborescence (directed spanning tree) of minimum weight, we use Edmonds’ algorithm. The worst-case computational complexity of this algorithm is O(N^3) for a single root, where N is the number of specified roots. The pdl-to-pdl_interp lowering calls this algorithm as a subroutine N times (once for each candidate root), so the overall complexity of root ordering is O(N^4). If needed, this complexity could be reduced to O(N^3) with a more efficient algorithm. However, the underlying implementation is very efficient, and N in our instances tends to be very small (< 10). Therefore, we believe that the proposed (asymptotically suboptimal) implementation will suffice for now.
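
Schematically, the root-ordering driver looks like this in C++ (hypothetical signatures; edmonds() stands for the minimum spanning arborescence solver discussed above, and connectors are omitted):

#include <limits>
#include <vector>

// graph[source][target] holds the smallest number of upward traversals
// needed to reach `target` from `source` via the cheapest connector.
using CostGraph = std::vector<std::vector<unsigned>>;

// Minimum spanning arborescence rooted at `root`, returned as a parent map
// (parent[i] == -1 at the root); O(N^3) per call.
std::vector<int> edmonds(const CostGraph &graph, unsigned root,
                         unsigned &totalCost);

// Try every candidate root and keep the cheapest arborescence: O(N^4).
std::vector<int> pickOptimalRoot(const CostGraph &graph) {
  unsigned bestCost = std::numeric_limits<unsigned>::max();
  std::vector<int> bestParents;
  for (unsigned root = 0, n = graph.size(); root < n; ++root) {
    unsigned cost = 0;
    std::vector<int> parents = edmonds(graph, root, cost);
    if (cost < bestCost) {
      bestCost = cost;
      bestParents = parents;
    }
  }
  return bestParents;
}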

5. Lowering from PDL to pdl_interp

Having described the building blocks of our approach, we now turn to integrating them in the pdl-to-pdl_interp dialect lowering. The input to the lowering is a collection of patterns, whose public-facing API remains unchanged, except for the pdl.rewrite operation now taking multiple candidate roots. In generalizing the lowering, we perform the following steps:

  1. Build the cost graph among the candidate roots: this follows the construction described in the previous section and computes the parent map of operations for each candidate root as a by-product.
  2. Compute the optimal root and the corresponding minimum directed spanning tree (arborescence) as described in the previous section.
  3. Gather the predicates at the optimal root, and then, for each edge of the spanning tree, follow the path in the underlying pattern from the connector to the target root. We collect the predicates for the subgraph rooted at each node (operation) on that path, guarding against going back to the positions where we came from.
  4. Extend the OperationPos predicate to allow tracing the operations accepting a value (the opposite of the operation defining a value).

With these changes in place, the public-facing API of the PDL pattern (from the perspective of the MLIR rewrite framework) remains unchanged: it accepts a single (optimal) root, performing the matches entirely within the bytecode interpreter. We have unit-tested the changes and will provide an integration test matching a multi-layer perceptron (MLP).

Thanks for the RFC! I’ve been a bit swamped internally over the past week (perf season), but intend to give this more attention this week.

– River

Thanks, River. Looking forward to your review.

Stano

This is a really great RFC! Reading it, it seems like a good direction to move towards. Comments inline.

I would use pdl_interp.get_users instead, as it matches the C++ API.

Is it a requirement that this be a value, as opposed to Value or ValueRange? I’m thinking of situations where a range of results (possibly all) of an operation are used.

Another thing I’m a tad bit apprehensive about for this operation is that it kind of obfuscates the control flow of the matcher. It isn’t clear at first glance that this is going to be iterated from this point on. I don’t have any real suggestions off the top of my head, given that the other parts of the pattern can be shared, but it’s just a feeling.

I’m okay with starting from here. PDL is going to be getting a lot of optimization attention at some point soon, so starting with something clean that works well sounds good to me. We can always optimize when necessary.

I think it would also be nice if we could automatically detect when a pattern has additional “roots” aside from the one given to pdl.rewrite (in the case of a single root). I assume that we could use the same infrastructure as you have described above, but we would implicitly treat the single root as the “optimal” root (I think, unless I misread your description). IMO this would simplify quite a bit of work for frontends (including the one I’m building right now), as the user wouldn’t need to provide all of the roots explicitly (i.e. the mental model would be close to what you expect from the C++ API).

The concepts in general look good to me, so +1. I’ll start looking over the code this next week and start review. I’ll be OOO the 8-18th, but will likely have comments on most of the reviews before then.

– River

I had a similar feeling. Would it perhaps help readability to rename this to something like pdl_interp.iteratively_choose_op or pdl_interp.choose_op_iterative?

We can certainly rename the op; the current name has its root in the original thinking of treating the state machine as nondeterministic - making simultaneous choices from a range of options. Something like pdl_interp.foreach would probably capture all we want.

I agree that, with the current proposal, the control flow is not all that obvious. One way to address that would be to specify a region within the pdl_interp.foreach operation that gets executed for each choice. Then it would be clearer what gets executed repeatedly, and maybe even the liveness checker would work better, too. On the flip side, this will likely require bigger changes to the lowering and the bytecode generator (I am happy to undertake these if there is consensus that it would make the code cleaner).

I don’t think that the full power of a foreach is necessary, but regions would help readability. You could have pdl_interp.iterate_ops(%ops) (%op : !pdl.operation) { ... } implemented the same way with an address stack and pdl_interp.pop_and_jmp. And instead of complicating the behaviour of pdl_interp.finalize, you could move it to its own block and preload its address onto the stack.

On the other hand, if a foreach would be useful going forward, I’m all for it.

Thank you for your feedback @Mogball. I will attempt to implement this using regions. As for the operation indicating the next iteration (or the end if moving past the last element), how about we call it pdl_interp.continue?