Abstract
The goal of this RFC is to extend the current implementation of PDL to support patterns with multiple roots. This extension is needed to match complex patterns that arise in the process of matching kernels in model parallel training of neural networks. We achieve this extension by appropriately extending the pdl_interp dialect, updating the pdl-to-pdl_interp lowering, and generalizing the execution model of the PDL bytecode to support iterative execution.
1. Introduction
Motivation
PDL is a high-level abstraction for the rewrite pattern infrastructure in MLIR. A PDL pattern consists of a pattern (match) region and a rewrite region. The pattern region specifies a DAG of operations starting from one or more operands. The rewrite region specifies, for a single root operation, how this operation and its descendants are altered. Notice the asymmetry: although the pattern region is a general DAG, the current semantics for the rewrite region forces the DAG to have a single root (that is, a tree, possibly with shared interior nodes). Traditionally, this limitation has not been too restrictive: the rewrite rules used for optimization or lowerings often operate on a single root, rewriting a tree of operations. However, model parallel training of neural networks may tie together the forward and the backward execution of a kernel. This is because doing so leverages locality, allowing the weights to be stored locally to simultaneously implement both the forward and the backward pass. This use case requires the patterns to be general DAGs.
As an example, consider matching the following MLIR snippet, representing a 2-layer perceptron:
func @main(%arg0: tensor<2x20xf32>, %arg1: tensor<2xi32>, %arg2: tensor<256xf32>, %arg3: tensor<20x256xf32>, %arg4: tensor<256x10xf32>) -> (tensor<f32>, tensor<256xf32>, tensor<20x256xf32>, tensor<256x10xf32>) {
  %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
  %1 = "tf.Const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32>
  %2 = "tf.Const"() {value = dense<5.000000e-01> : tensor<2x1xf32>} : () -> tensor<2x1xf32>
  // FC layer 1 (forward)
  %3 = "tf.MatMul"(%arg0, %arg3) {transpose_a = false, transpose_b = false} : (tensor<2x20xf32>, tensor<20x256xf32>) -> tensor<2x256xf32>
  %4 = "tf.BiasAdd"(%3, %arg2) {data_format = "NHWC"} : (tensor<2x256xf32>, tensor<256xf32>) -> tensor<2x256xf32>
  %5 = "tf.Relu"(%4) : (tensor<2x256xf32>) -> tensor<2x256xf32>
  // FC layer 2 (forward)
  %6 = "tf.MatMul"(%5, %arg4) {transpose_a = false, transpose_b = false} : (tensor<2x256xf32>, tensor<256x10xf32>) -> tensor<2x10xf32>
  // Softmax cross entropy
  %loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%6, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<2xf32>, tensor<2x10xf32>)
  %7 = "tf.Mean"(%loss, %0) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
  %8 = "tf.PreventGradient"(%backprop) : (tensor<2x10xf32>) -> tensor<2x10xf32>
  %9 = "tf.Mul"(%8, %2) : (tensor<2x10xf32>, tensor<2x1xf32>) -> tensor<2x10xf32>
  // FC layer 2 (gradient computation)
  %10 = "tf.MatMul"(%9, %arg4) {transpose_a = false, transpose_b = true} : (tensor<2x10xf32>, tensor<256x10xf32>) -> tensor<2x256xf32>
  %11 = "tf.MatMul"(%5, %9) {transpose_a = true, transpose_b = false} : (tensor<2x256xf32>, tensor<2x10xf32>) -> tensor<256x10xf32>
  // FC layer 1 (gradient computation)
  %12 = "tf.ReluGrad"(%10, %5) : (tensor<2x256xf32>, tensor<2x256xf32>) -> tensor<2x256xf32>
  %13 = "tf.BiasAddGrad"(%12) {data_format = "NHWC"} : (tensor<2x256xf32>) -> tensor<256xf32>
  %14 = "tf.MatMul"(%arg0, %12) {transpose_a = true, transpose_b = false} : (tensor<2x20xf32>, tensor<2x256xf32>) -> tensor<20x256xf32>
  // FC layer 1 (parameter update)
  %15 = "tf.Mul"(%14, %1) : (tensor<20x256xf32>, tensor<f32>) -> tensor<20x256xf32>
  %16 = "tf.Sub"(%arg3, %15) : (tensor<20x256xf32>, tensor<20x256xf32>) -> tensor<20x256xf32>
  %17 = "tf.Mul"(%13, %1) : (tensor<256xf32>, tensor<f32>) -> tensor<256xf32>
  %18 = "tf.Sub"(%arg2, %17) : (tensor<256xf32>, tensor<256xf32>) -> tensor<256xf32>
  // FC layer 2 (parameter update)
  %19 = "tf.Mul"(%11, %1) : (tensor<256x10xf32>, tensor<f32>) -> tensor<256x10xf32>
  %20 = "tf.Sub"(%arg4, %19) : (tensor<256x10xf32>, tensor<256x10xf32>) -> tensor<256x10xf32>
  return %7, %18, %16, %20 : tensor<f32>, tensor<256xf32>, tensor<20x256xf32>, tensor<256x10xf32>
}
This example includes a fully connected (FC) layer with bias and a rectified linear activation unit (ReLU), a second FC layer without bias/ReLU, and a softmax cross entropy unit.
Suppose that we wish to match the first FC layer, including the forward activations and the parameter update, omitting the gradient computation for brevity of the example. This can be achieved with the following PDL pattern:
pdl.pattern : benefit(5) {
  %act = pdl.operand
  %weight = pdl.operand
  %bias = pdl.operand
  %weight_grad = pdl.operand
  %bias_grad = pdl.operand
  %lr = pdl.operand
  %vec = pdl.type
  %mat = pdl.type
  %matmul = pdl.operation "tf.MatMul"(%act, %weight : !pdl.value, !pdl.value) -> (%vec : !pdl.type)
  %matmul_result = pdl.result 0 of %matmul
  %biasadd = pdl.operation "tf.BiasAdd"(%matmul_result, %bias : !pdl.value, !pdl.value) -> (%vec : !pdl.type)
  %biasadd_result = pdl.result 0 of %biasadd
  %relu = pdl.operation "tf.Relu"(%biasadd_result : !pdl.value) -> (%vec : !pdl.type)
  %weight_mul = pdl.operation "tf.Mul"(%weight_grad, %lr : !pdl.value, !pdl.value) -> (%mat : !pdl.type)
  %weight_update = pdl.result 0 of %weight_mul
  %weight_sub = pdl.operation "tf.Sub"(%weight, %weight_update : !pdl.value, !pdl.value) -> (%mat : !pdl.type)
  %bias_mul = pdl.operation "tf.Mul"(%bias_grad, %lr : !pdl.value, !pdl.value) -> (%vec : !pdl.type)
  %bias_update = pdl.result 0 of %bias_mul
  %bias_sub = pdl.operation "tf.Sub"(%bias, %bias_update : !pdl.value, !pdl.value) -> (%vec : !pdl.type)
  pdl.rewrite %relu, %weight_sub, %bias_sub with "rewriter"
}
This pattern includes six operands: input activations %act, input weight %weight, input bias %bias, weight gradient %weight_grad, bias gradient %bias_grad, and the learning rate %lr (the weight/bias gradients are normally computed from the ReLU gradient; we specify them here as operands to keep our example pattern small). The pattern contains three roots: the ReLU operation %relu producing output activations, the subtraction %weight_sub producing the updated weight, and the subtraction %bias_sub producing the updated bias. The visualization of the pattern is shown below; note that it is impossible to match such a pattern using the current PDL infrastructure.
act ---> matmul ---> biasadd ---> relu ---> act_out
            ^           ^
            |           |
         weight        bias
            |           |
            v           v
weight_out <--- weight_sub      bias_sub ---> bias_out
                    ^               ^
                    |               |
weight_grad ---> weight_mul     bias_mul <--- bias_grad
                     ^              ^
                     |              |
                     +----- lr -----+
Objective
As shown in the example above, complex patterns that arise in the process of matching kernels in model parallel training of neural networks can be DAGs with multiple roots. In this RFC, we describe how to extend the current implementation of PDL to support such patterns without changes to the MLIR rewrite model.
Approach
A PDL pattern is executed by lowering it to a lower-level abstraction, represented in the pdl_interp dialect. The current version of the pdl-to-pdl_interp lowering converts the pattern to a collection of predicates on positions, obtained by traversing the pattern from the given root downwards. Each position is a sequence of operand indices; for example, in the expression foo(bar, baz(x, y)), foo is the root operation at position [], bar is a leaf at position [0], and x is a leaf at position [1, 0]. Intuitively, these positions are obtained by "hanging" the pattern DAG from the root, arbitrarily resolving the positions of nodes reachable by multiple paths from the root.
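To make the notion of positions concrete, here is a small Python sketch (purely illustrative; the actual lowering is implemented in C++ inside MLIR) that computes the operand-index path of every node by "hanging" a pattern from its root:

```python
# Hypothetical sketch: compute the position (operand-index path) of each node
# in a pattern tree by downward traversal from the root. Nodes are modeled as
# (name, operand-list) tuples; this is not the real PDL representation.

def positions(node, prefix=()):
    """Map each node name to its operand-index path from the root."""
    name, operands = node
    result = {name: list(prefix)}
    for i, operand in enumerate(operands):
        result.update(positions(operand, prefix + (i,)))
    return result

# The expression foo(bar, baz(x, y)) from the text.
pattern = ("foo", [("bar", []), ("baz", [("x", []), ("y", [])])])
print(positions(pattern))
# {'foo': [], 'bar': [0], 'baz': [1], 'x': [1, 0], 'y': [1, 1]}
```

This reproduces the positions listed above: foo at [], bar at [0], and x at [1, 0].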
In the fully connected layer example shown above, such an approach is not possible: no matter which root we choose, we cannot access all the nodes in the pattern solely via operand (downward) traversal. One option to address this problem is to match the DAG underneath each root separately and then join these DAGs on the common operations, attributes, and types. Although this is possible, it represents a significant departure from the current rewrite architecture of MLIR and thus carries significant risk. Instead, we propose a simpler approach, where the pattern is "hung" from one of the candidate roots, and we traverse some of the edges upwards (towards the operation accepting the source value) as needed. Because multiple operations can accept the same value, we need to perform (a limited form of) subgraph isomorphism search. We implement this search by introducing iteration to the execution of the pdl_interp bytecode.
Outline
The rest of this document is organized as follows. In Section 2, we describe the new operations we propose to add to the pdl_interp dialect, including one new iterative operation. In Section 3, we discuss how such an operation can be supported in the bytecode interpreter. In Section 4, we explain how to select the optimal root from which to "hang" the pattern. Finally, in Section 5, we discuss the changes to the pdl-to-pdl_interp lowering needed to integrate these two changes. These four sections coincide with the commits implementing the multi-root support that accompany this RFC (to be submitted shortly).
2. New operations in the pdl_interp dialect
The pdl_interp dialect provides a lower-level abstraction than the PDL dialect and is intended to be directly translated to the bytecode executing the match and rewrites. Here, we focus solely on the matcher portion of pdl_interp; we need no changes to the rewrite portion of pdl_interp.
The matcher region of the pdl_interp dialect consists of a sequence of blocks. Each block performs zero or more assignments to variables and is terminated by a single predicate check that branches to subsequent blocks. The assignments of values / operations are done in the order of the traversal; this is necessary, because each assignment "unlocks" the access to the values below it. The predicates, however, are not tested in the order of traversal, because the order of predicate evaluation is optimized over all the patterns given to the pdl-to-pdl_interp lowering. The execution of the matcher follows the branching until we reach a block containing the pdl_interp.finalize operation.
The current matching approach always follows the operations by downward traversal, going from an operation to its operands, starting from the (single) specified root. Such execution is constant: once the root operation is fixed, the values / operations associated with each (nested) operand are fixed. In the generalization proposed here, we will allow ourselves to traverse upward (going from a value to its consumers). Such execution is iterative, because there may be multiple operations accepting the value. In order to support upward traversal, we introduce two new operations:
- pdl_interp.get_accepting_ops: Returns a list of operations accepting the given value or a range of values at the specified position. Therefore, if there are two operations %op1 = "foo"(%val) and %op2 = "bar"(%val) accepting %val at position 0, %ops = pdl_interp.get_accepting_ops of %val : !pdl.value at 0 will return a list (range) containing both of them. This allows us to traverse upwards from a value to the operations accepting that value.
- pdl_interp.choose_op: Iteratively chooses one operation from a range of operations. Therefore, writing %op = pdl_interp.choose_op from %ops in the example above selects %op1 and %op2, one after another.
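The semantics of these two operations can be sketched in Python over a toy def-use representation (the names and representation here are illustrative, not the real pdl_interp API):

```python
# Illustrative sketch of get_accepting_ops: upward traversal from a value to
# the operations that use it at a given operand position. Ops are modeled as
# (name, operand-list) tuples; this is not the real MLIR data structure.

def get_accepting_ops(ops, value, position):
    """All ops whose operand at `position` is `value`."""
    return [op for op in ops
            if len(op[1]) > position and op[1][position] == value]

# %op1 = "foo"(%val), %op2 = "bar"(%val), and one op using %val at position 1.
ops = [("foo", ["%val"]), ("bar", ["%val"]), ("baz", ["%other", "%val"])]

candidates = get_accepting_ops(ops, "%val", 0)
print([name for name, _ in candidates])
# ['foo', 'bar']

# choose_op then tries each candidate in turn:
for op in candidates:
    print("trying", op[0])
```

The matcher visits both candidates, which is exactly the iterative behavior choose_op introduces.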
A fragment of the pdl_interp matcher for the FC layer pattern is shown below:
func @matcher(%arg0: !pdl.operation) {
  %0 = pdl_interp.get_operand 0 of %arg0
  %1 = pdl_interp.get_defining_op of %0 : !pdl.value
  pdl_interp.is_not_null %1 : !pdl.operation -> ^bb2, ^bb1
^bb1:  // 57 preds
  pdl_interp.finalize
^bb2:  // pred: ^bb0
  %2 = pdl_interp.get_operand 0 of %1
  %3 = pdl_interp.get_defining_op of %2 : !pdl.value
  pdl_interp.is_not_null %3 : !pdl.operation -> ^bb3, ^bb1
^bb3:  // pred: ^bb2
  %4 = pdl_interp.get_operand 1 of %1
  %5 = pdl_interp.get_accepting_ops of %4 : !pdl.value at 0
  %6 = pdl_interp.choose_op from %5
  ...
^bb5:  // pred: ^bb4
  pdl_interp.check_operation_name of %arg0 is "tf.Relu" -> ^bb6, ^bb1
  ...
^bb10:  // pred: ^bb9
  pdl_interp.check_operation_name of %1 is "tf.BiasAdd" -> ^bb11, ^bb1
  ...
^bb18:  // pred: ^bb17
  pdl_interp.check_operation_name of %6 is "tf.Sub" -> ^bb19, ^bb1
^bb19:  // pred: ^bb18
  pdl_interp.check_operation_name of %3 is "tf.MatMul" -> ^bb20, ^bb1
In this matcher, %arg0 is the ReLU operation (%relu), %1 is the bias addition (%biasadd), %3 is the matrix-matrix multiply (%matmul), and %4 is the input bias (%bias). Note that in block ^bb3, we extract the operations accepting the input bias at position 0 (%5) and then iteratively choose an operation from this list (%6). Note that the order of predicate checks could be further optimized, so that we match the operation name "tf.BiasAdd" before we extract its first operand (%bias) and make an iterative choice. Such optimization does not affect correctness, and we plan to address it later by testing on a large collection of input patterns.
3. Iterative bytecode execution
Before we discuss the generalizations needed to implement the new pdl_interp.choose_op operation, it is useful to review how the bytecode is currently generated and executed. This will make it easier to understand the proposed changes.
The PDL bytecode is stored in a flat array as a sequence of fields and addresses. The bytecode executor maintains a pointer (address) to the current instruction and iteratively executes the pdl_interp operations until we reach pdl_interp.finalize. The translation between the pdl_interp dialect and the bytecode is fairly direct, with one noteworthy optimization: memory allocation. The variables declared in the pdl_interp matcher often do not interact, because they are in unrelated branches of the execution. In order to save memory, the bytecode generator performs the analogue of register allocation, with an unlimited number of registers whose count we try to minimize. This optimization is performed by storing, for each variable, the interval map of bytecode positions where this variable is used. The generator then greedily merges the storage of variables whose interval maps do not overlap. This procedure greatly reduces the number of registers (memory locations) used.
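The greedy merging step can be illustrated with a minimal Python sketch, under simplifying assumptions: each variable has a single live interval of bytecode positions (the real generator uses interval maps and may differ in detail):

```python
# Hypothetical sketch of interval-based register merging: a variable reuses
# a register whenever its live interval does not overlap any interval
# already assigned to that register.

def allocate(intervals):
    """Map each variable to a register index, merging disjoint lifetimes."""
    registers = []   # each register holds the list of intervals assigned to it
    assignment = {}
    for var, (start, end) in intervals.items():
        for idx, used in enumerate(registers):
            if all(end < s or start > e for s, e in used):
                used.append((start, end))   # reuse: no overlap with this register
                assignment[var] = idx
                break
        else:
            registers.append([(start, end)])  # no reusable register: open a new one
            assignment[var] = len(registers) - 1
    return assignment

# %a dies before %c is born, so they share register 0; %b overlaps both.
lifetimes = {"%a": (0, 3), "%b": (2, 6), "%c": (4, 8)}
print(allocate(lifetimes))
# {'%a': 0, '%b': 1, '%c': 0}
```

Three variables fit in two registers because the lifetimes of %a and %c never overlap.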
In order to support the new pdl_interp.choose_op operation, we make the following two changes to the bytecode execution and generation:
- In addition to the current address, we maintain a stack of addresses where the execution should resume. Whenever we encounter a pdl_interp.choose_op that has not exhausted its iteration range, we push the current address onto the stack and proceed with the execution as normal. Once we reach pdl_interp.finalize, we do not abort the computation right away. Rather, if the stack is non-empty, we resume the execution at the topmost address. Thus, we effectively perform recursion within the current (sequential) implementation of the bytecode executor.
- Because pdl_interp.finalize no longer terminates the execution, we may need to access values that were previously assigned (if the execution goes back to pdl_interp.choose_op). For this execution to work correctly, we need to extend the lifetime of the values that are deemed alive at the time an iterative operation is performed. This lifetime needs to be extended in all the branches that the bytecode interpreter might take.
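The resume-stack mechanism can be sketched as follows, using a Python iterator in place of a real instruction range (the actual executor works on a flat bytecode array; this only illustrates the control flow for a single choose_op):

```python
# Simplified sketch of the resume stack: choose_op pushes a resume point,
# and finalize pops back to it instead of terminating, so all choices are
# eventually visited.

def run_matcher(candidates, predicate):
    """Collect every candidate for which the matcher body succeeds."""
    matches = []
    resume = []                        # stack of resume points
    resume.append(iter(candidates))    # choose_op: push a resume point
    while resume:
        op = next(resume[-1], None)    # take the next choice at the top
        if op is None:
            resume.pop()               # iteration range exhausted: drop it
            continue
        # ... straight-line "bytecode" between choose_op and finalize ...
        if predicate(op):
            matches.append(op)         # a successful match reached finalize
        # finalize: loop back to the resume stack rather than aborting
    return matches

ops = ["tf.Sub", "tf.Mul", "tf.Sub"]
print(run_matcher(ops, lambda name: name == "tf.Sub"))
# ['tf.Sub', 'tf.Sub']
```

Nested choose_op operations would simply push further resume points, giving the recursion described above.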
These changes are sufficient to implement the proposed extension and may be useful in further development of PDL.
4. Root ordering
Having described the extensions to the pdl_interp dialect and the underlying bytecode, we now examine the task of determining the starting node for matching and how the pattern is traversed. Here, the intuition is twofold:
- Intuition 1: The iterative pdl_interp.choose_op operations are costly, both because there may be multiple consumers (operations) of the same value, and because there is overhead (mostly in memory) to having iterative operations in our bytecode. Unfortunately, it is hard to quantify the former, because we do not know the number of uses of a value until the pattern is matched (although we could assume that the pattern lists all uses of a value). Therefore, we choose the simple metric of minimizing the number of pdl_interp.choose_op operations, which is the same as minimizing the number of upward traversals.
- Intuition 2: We only need to consider roots when selecting the starting node for matching and deciding how the pattern is traversed. Clearly, selecting a non-root node can only increase the number of iterative operations. Furthermore, the traversals have to include all roots, and once we reach a root, no more upward traversals are needed for the part of the pattern underneath this root. Therefore, we can use roots as anchors in our traversal search, where traversing one root after another corresponds to a path in the underlying pattern.
Following this intuition, we treat optimal traversal as a root ordering problem. We form a graph over the specified roots, provided in pdl.rewrite, where two roots are connected by a directed edge if the target root can be connected (via a chain of operations) in the underlying pattern to the source root. We place a restriction that the path connecting the two candidate roots must only contain nodes in the subgraphs underneath these two roots. The cost of an edge is the smallest number of upward traversals (edges) required to go from the source to the target root, and the connector is a Value in the intersection of the two subtrees rooted at the source and target roots that results in that smallest number of upward traversals. Optimal root ordering is then formulated as the problem of finding a spanning arborescence (i.e., a directed spanning tree) of minimal weight.
Consider once again the pattern shown in Section 1, repeated below:
act ---> matmul ---> biasadd ---> relu ---> act_out
            ^           ^
            |           |
         weight        bias
            |           |
            v           v
weight_out <--- weight_sub      bias_sub ---> bias_out
                    ^               ^
                    |               |
weight_grad ---> weight_mul     bias_mul <--- bias_grad
                     ^              ^
                     |              |
                     +----- lr -----+
The cost of the edge relu -> weight_sub is 1 (the out-edge traversal weight -> weight_sub), with weight being the connector, and similarly for relu -> bias_sub (cost 1, connector bias). The cost of the edge weight_sub -> relu is 3 (traversals weight -> matmul -> biasadd -> relu, connector weight), while the cost of the edge bias_sub -> relu is 2 (traversals bias -> biasadd -> relu, connector bias). The edges between weight_sub and bias_sub in the cost graph each have weight 2, because they need to go through the common node lr. It is easy to see that the optimal root for this pattern is relu, resulting in the spanning arborescence relu -> {weight_sub, bias_sub} with weight 2.
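This worked example can be checked mechanically. The sketch below encodes the edge costs stated above and brute-forces the minimum-weight spanning arborescence by enumeration (the real lowering uses Edmonds' algorithm; enumeration is only viable here because the number of roots is tiny):

```python
# Brute-force minimum spanning arborescence over the root cost graph from
# the FC-layer example. Edge costs are taken from the text above.
from itertools import product

roots = ["relu", "weight_sub", "bias_sub"]
cost = {("relu", "weight_sub"): 1, ("relu", "bias_sub"): 1,
        ("weight_sub", "relu"): 3, ("bias_sub", "relu"): 2,
        ("weight_sub", "bias_sub"): 2, ("bias_sub", "weight_sub"): 2}

def best_arborescence():
    best = None
    for root in roots:
        others = [r for r in roots if r != root]
        # Choose a parent for every non-root node; keep only valid trees.
        for parents in product(roots, repeat=len(others)):
            edges = list(zip(parents, others))
            if any(p == c for p, c in edges):
                continue
            # A valid arborescence must reach every root from the chosen one.
            reached, frontier = {root}, [root]
            while frontier:
                n = frontier.pop()
                for p, c in edges:
                    if p == n and c not in reached:
                        reached.add(c)
                        frontier.append(c)
            if reached == set(roots):
                w = sum(cost[p, c] for p, c in edges)
                if best is None or w < best[0]:
                    best = (w, root, edges)
    return best

print(best_arborescence())
# (2, 'relu', [('relu', 'weight_sub'), ('relu', 'bias_sub')])
```

The enumeration confirms the conclusion above: hanging the pattern from relu yields the cheapest traversal, with total weight 2.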
In order to determine the spanning arborescence (directed spanning tree) of minimum weight, we use Edmonds' algorithm. The worst-case computational complexity of this algorithm is O(N^3) for a single root, where N is the number of specified roots. The pdl-to-pdl_interp lowering calls this algorithm as a subroutine N times (once for each candidate root), so the overall complexity of root ordering is O(N^4). If needed, this complexity could be reduced to O(N^3) with a more efficient algorithm. However, note that the underlying implementation is very efficient, and N in our instances tends to be very small (< 10). Therefore, we believe that the proposed (asymptotically suboptimal) implementation will suffice for now.
5. Lowering from PDL to pdl_interp
Having described the building blocks of our approach, we now turn to integrating them in the pdl-to-pdl_interp dialect lowering. The input to the lowering is a collection of patterns, whose public-facing API remains unchanged, except for the pdl.rewrite operation now taking multiple candidate roots. In generalizing the lowering, we perform the following steps:
- Build the cost graph among the candidate roots: this follows the construction described in the previous section and computes the parent map of operations for each candidate root as a by-product.
- Compute the optimal root and the corresponding minimum directed spanning tree (arborescence) as described in the previous section.
- Gather the predicates at the optimal root, and then, for each edge of the spanning tree, follow the path in the underlying pattern from the connector to the target root. We collect the predicates for the subgraph rooted at each node (operation) on that path, guarding against going back to the positions where we came from.
- The OperationPos predicate has been extended to allow tracing the operation accepting a value (the opposite of the operation defining a value).
With these changes in place, the public-facing API of the PDL pattern (from the perspective of the MLIR rewrite framework) remains unchanged: it accepts a single (optimal) root, performing the matches entirely within the bytecode interpreter. We have unit-tested the changes and will provide an integration test matching a multi-layer perceptron (MLP).