Unranked tensors in TCP

We have discussed broadcasting and some interesting aspects of ranked cases in [RFC] Declarative "named ops" in the Linalg dialect - #30 by _sean_silva, but I don’t feel like we have discussed much about handling of unranked tensors. Linalg can’t represent unranked, for example. My personal view is that unranked is a separate problem (even for “elementwise add”), best modeled at a level of ops that have implicit broadcasting and shape-related runtime errors happening “inside the ops”, rather than detected via reified shape calculations like we would prefer to use at the “linalg named ops” level of abstraction. Examples:

def f(t: Tensor):
    while cond():
        t = stack(t, t)
    return t

In this example, each iteration of the loop increases the rank by 1.

def f(lhs: Tensor, rhs: Tensor):
  return lhs + rhs

In this example, the broadcasting behavior (and the corresponding error reporting) cannot be determined statically.

I think this example makes it clear that there is a need for an “ML frontend” dialect layer separate from the abstraction level discussed in the other thread I referenced above, with support for unranked tensors being a key feature. One of the key transformations at this level is shape/rank inference, to make as much of the program as possible amenable to powerful lower-level code generation techniques. Another is multiversioning / PGO-guided (or JIT-feedback-guided) specialization toward the same end.
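
To illustrate the multiversioning idea, here is a toy Python sketch (an assumption for illustration, not an existing MLIR API): specialize once per observed rank combination and dispatch on the observed ranks, which is roughly the behavior a JIT-feedback-guided specializer would reify in IR.

import numpy as np

_specializations = {}

def _compile_add_for_ranks(lhs_rank: int, rhs_rank: int):
    # Stand-in for "generate rank-specialized code": return a closure that
    # assumes the given ranks and uses numpy broadcasting for the ranked case.
    def ranked_add(lhs, rhs):
        assert lhs.ndim == lhs_rank and rhs.ndim == rhs_rank
        return lhs + rhs
    return ranked_add

def add(lhs, rhs):
    # Dispatch on the observed ranks (e.g. gathered from JIT/profile feedback)
    # and specialize once per rank combination.
    key = (lhs.ndim, rhs.ndim)
    if key not in _specializations:
        _specializations[key] = _compile_add_for_ranks(*key)
    return _specializations[key](lhs, rhs)

# e.g. add(np.ones((2, 3)), np.ones(3)) and add(np.ones((4, 1, 5)), np.ones(5))
# populate two separate specializations.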

This level of abstraction can also play well with other frontend constructs, like the following Python code, which has equivalents in most modern ML frameworks:

def f(t: Tensor, n: int) -> List[Tensor]:
    l = []
    for i in range(n):
        if i == 0:
            l.append(t)
        else:
            l.append(t * l[i-1] * i)
    return l

In this example, the frontend framework would potentially bring its own list type (which might have some frontend-specific semantics) but reuse a tensor multiplication op, with implicit broadcasting and unranked handling, that MLIR provides.


Thanks for bringing this up. I agree. I also see the need to handle unranked code at two levels. For the front-end, this is interesting

  1. when coming from a (rank-)dynamically typed front-end language
  2. to be able to specify library functions independently of rank

Being able to model this in a shared dialect would enable the reuse of specialization and shape inference specifications. However, one might be able to achieve the same by using dialects for modelling the shape behavior and interfaces to model specialization. Either way, I see enabling reuse at this level as very valuable.

Secondly, for code generation, it is interesting if one wants to execute dynamically ranked programs without first inferring shapes and generating specialized code, as in a true AOT setting.

This is not catered for in LinAlg, as it is focused on modelling rank-specific programs. Handling rank-specific programs is a big and important part of the problem, and in my experience one models rank-dynamic and rank-specific computations differently at the code generation level anyway. So this is not a criticism of LinAlg.

The question is whether we would design a counterpart to LinAlg for rank-generic code and then lower higher-level ML dialects into it (having LinAlg and X side by side), or whether we specify a shared dialect on top and define lowerings from that dialect to LinAlg for the known-rank case and to other code generation otherwise. I expect the latter to have more potential for sharing lowering logic.

Do you see significant opportunities for code generation for dynamically ranked programs, beyond just dispatching into precompiled libraries of rank-generic kernels? I would like to hear more about this. (I guess I typically associate the term “code generation” with lowering to loops, vectors, etc. which is not really possible in a rank-generic setting; perhaps you could elaborate on what you mean here).

I guess, what aspects of what linalg does carry over into the rank-generic setting? I’m having a hard time picturing what a “counterpart to LinAlg for rank-generic code” would even look like.

Also, I think at this level of abstraction, things like convolution window sizes should be dynamic values instead of attributes.

The question is where these rank-generic kernels come from. Are those all written by hand or are some (or all) generated?

For convolution, in the rank generic case, you still have a fixed set of dimensions the convolution computes over and an unknown number of dimensions you iterate over. So you need to be able to express the indexing in a flexible enough way to represent this. Using index-vectors instead of scalar index values is a first step and you probably want an efficient way to express transposes.
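
As a small illustration of the index-vector point, a numpy sketch (an assumption for illustration, not Linalg semantics; the function name and signature are made up):

import numpy as np

def load(t: np.ndarray, batch_index: tuple, i: int, j: int):
    # `batch_index` has length t.ndim - 2 and addresses the variadic leading
    # (iterated-over) dims; the trailing two dims, which the kernel actually
    # computes over, keep ordinary scalar indices.
    assert len(batch_index) == t.ndim - 2
    return t[tuple(batch_index) + (i, j)]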

I am not claiming to have an answer here (or even that producing efficient code is feasible, this is mostly a completeness thing), I was just wondering whether this is an area we are at all interested in.

A generalized LinAlg would need rank-independent affine maps for starters.

In cases like conv or batchmatmul, I would expect one to reshape such that the unknown dimensions you iterate over are flattened into a single dimension. The problem is then reduced to the ranked case.

It sounded like @sanjoy_das_google was interested in this.

But I agree that it’s not clear how this fits in with the overall TCP effort. I guess I was trying to use this as a strawman for the question: should TCP ops verifier-check that the inputs are ranked tensors? I think they should, but I’ve heard dissenting opinions in the past.

If the unranked dimensions are not next to each other, you would need to transpose first to make this possible (I consider a reshape a logical operation that manipulates the shape of the tensor without changing the layout of elements). That is what I meant by an efficient transpose.

This is what I was wondering about as well. Should TCP be able to serve as a vehicle to express unranked computations (so that rank inference can be done on it and it can serve as a language for expressing rank-independent “library” code)? And, as a follow-on, if so, would we expect to always specialize to a fixed rank for code generation, or would we also want to generate code?

Do you know of any situation where a “variadic” list of dimensions is not adjacent? Or a case with two independent variadic lists of dimensions in the same shape? I’m curious whether a couple of primitives could be used to funnel most of the unranked code into ranked primitives in a clean way.

E.g.

%result = "tcp_frontend.batch_matmul"(%lhs, %rhs) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>

would be lowered into

%ranked_lhs, %ranked_rhs, %rank_collapse_record =
    "tcp.collapse_variadic_dims"(%lhs, %rhs)
    { lhs_dims = [VARIADIC, -2, -1], rhs_dims = [VARIADIC, -2, -1]}
    : (tensor<*xf32>, tensor<*xf32>)
    -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, !tcp.rank_collapse_record)
%ranked_result = "tcp.batch_matmul"(%ranked_lhs, %ranked_rhs) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
%result = "tcp.reconstruct_variadic_dims"(%ranked_result, %rank_collapse_record)
    { result_dims = [VARIADIC, 1, 2] }
    : (tensor<?x?x?xf32>, !tcp.rank_collapse_record) -> tensor<*xf32>

In this code, !tcp.rank_collapse_record saves the information that was lost when collapsing a variadic number of batch dims to a single dim. The tcp.collapse_variadic_dims op slices out certain dimensions from an unranked shape, with a special VARIADIC token used to capture a variadic list of dimensions into a single dimension of the output, saving that information into the rank collapse record. That is, tcp.collapse_variadic_dims does the following:

  1. it constructs a rank-3 %ranked_lhs consisting of, first, all leading variadic dimensions flattened into one, then the -2nd dimension of the input shape, followed by the -1st.
  2. it constructs a rank-3 %ranked_rhs similarly (though in general the dimension spec could be different); the special “VARIADIC” tokens embed an assertion (with error bailout) that these dimensions match the VARIADIC dimensions on the LHS.
  3. it packages up the information needed to reconstruct the VARIADIC leading shape into the !tcp.rank_collapse_record.

The tcp.batch_matmul is then a ranked operator.

Finally, the tcp.reconstruct_variadic_dims reconstructs the variadic leading dimensions by inserting the VARIADIC dimension list from the rank collapse record in place of dim 0 of %ranked_result, followed by dim 1 of %ranked_result, followed by dim 2 of %ranked_result.
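
To make those semantics concrete, here is a rough Python model of the whole decomposition (the names mirror the IR sketch above, but none of this is actual TCP code; numpy reshape and einsum merely stand in for the collapse, batch_matmul, and reconstruct steps):

from typing import NamedTuple, Tuple
import numpy as np

class RankCollapseRecord(NamedTuple):
    variadic_shape: Tuple[int, ...]  # the leading dims that were collapsed away

def collapse_variadic_dims(lhs: np.ndarray, rhs: np.ndarray):
    lhs_batch, rhs_batch = lhs.shape[:-2], rhs.shape[:-2]
    # The implicit check carried by the VARIADIC tokens, reified as a runtime error.
    if lhs_batch != rhs_batch:
        raise ValueError("variadic batch dimensions do not match")
    record = RankCollapseRecord(lhs_batch)
    # Flatten the variadic leading dims into a single batch dim (rank 3).
    return (lhs.reshape(-1, *lhs.shape[-2:]),
            rhs.reshape(-1, *rhs.shape[-2:]),
            record)

def reconstruct_variadic_dims(ranked_result: np.ndarray, record: RankCollapseRecord):
    # Splice the saved variadic dims back in, in place of dim 0 of the result.
    return ranked_result.reshape(*record.variadic_shape, *ranked_result.shape[1:])

def frontend_batch_matmul(lhs: np.ndarray, rhs: np.ndarray) -> np.ndarray:
    ranked_lhs, ranked_rhs, record = collapse_variadic_dims(lhs, rhs)
    ranked_result = np.einsum("bij,bjk->bik", ranked_lhs, ranked_rhs)  # the ranked batch matmul
    return reconstruct_variadic_dims(ranked_result, record)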

After applying a transformation like this locally to the code, further transformations can clean it up. For example, if shape inference can somehow prove that %lhs and %rhs are already ranked, the tcp.collapse_variadic_dims could be removed from the program.

As with the tcp.clean_for_matmul ops in Is the TCP “matmul” op marked NoSideEffect? - #5 by _sean_silva, I don’t necessarily believe that we literally need ops like tcp.collapse_variadic_dims. This could all be open-coded in the shape dialect or with other TCP-specific ops.

In fact, all of this can be seen as just reifying code “hidden inside” the tcp_frontend.batch_matmul. For example, if one looks at tf.BatchMatMulV2 runtime code here, one sees a large amount of runtime code checking things, doing shape computations, and issuing errors. What we are doing at this stage is reifying all that runtime code statically so we can optimize it and statically see the nucleus tcp.batch_matmul op.

(Actually, further lowerings can be seen that way as well. E.g. lowering to operate on buffers reifies buffer allocations, which are then subject to their own transformations; lowering to commands on CUDA streams or Vulkan/IREE command buffers is another layer of reification.)

I’m +1 on this. I think such considerations are closer to the numpy/frontend dialect world and that one would go through levels of specialization that introduce static behavior from dynamic code.

For anything that handles loops, tiling, vectorization and other things that TCP/Linalg on buffers and Affine do, we want static ranks.

I am a bit torn here … on one hand it is tempting to try and generalize every (i, j, k) into (i..., j..., k...) and ping-pong between linearizations and delinearizations. OTOH is it worth subjecting ourselves to the pain of doing this in MLIR today?

In other words what is the expected benefit? I know of the following cases:

  1. “rank-agnostic” library calls whose first action is to switch on the rank + sizes and dispatch to different versions. Let this live in C++ land; it’s easier to create/maintain/extend there.
  2. “rank-agnostic” kernels that really want to work for everything (e.g. add 2 “structured” tensor blobs as long as they have the same number of elements). This ends up spending an inordinate amount of effort in linearize/delinearize ping-pongs (see the small sketch after this list). Same remark: it’s better left in C++ for now, I think.
  3. “special” rank-agnostic kernels that are so simple that they fold into a static n-D counterpart. This generally only works on contiguous buffers and pointwise stuff. This feels like a top-level cast.
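
For concreteness, a minimal sketch of the linearize/delinearize pair referred to in (2) (plain Python, not tied to any dialect); a truly rank-agnostic kernel ends up computing something like this around every element access, which is where the overhead comes from:

def linearize(idx, shape):
    # Row-major flattening of a multi-dimensional index into a single offset.
    flat = 0
    for i, extent in zip(idx, shape):
        flat = flat * extent + i
    return flat

def delinearize(flat, shape):
    # Inverse of linearize: recover the multi-dimensional index from the offset.
    idx = []
    for extent in reversed(shape):
        idx.append(flat % extent)
        flat //= extent
    return tuple(reversed(idx))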

Bottom line: before we embark on doing unranked codegen, I’d really love to see very compelling cases where libraries are 1) say an order of magnitude off, 2) relevant for us to address, and 3) something that pure unranked codegen would solve (casting to ranked in the process does not qualify).

I am not aware of such cases at this time.

I would expect that the tcp.batch_matmul has arguments that specify which dimensions of the inputs it applies to. As the operation itself is ranked, the number of such inputs is known. The point here is that those numbers need to be dynamic values.

You can then essentially compute what transpose to do to make these dimensions adjacent and then reshape it into the required ranked tensor. Those are your steps 1 and 2, and one can certainly hide them behind a compound op with later lowering. If the operands are static and the ranks are known, you could also lower this directly to linalg.
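
A rough numpy sketch of that transpose-then-reshape normalization (assumed behavior for illustration, not a TCP op; `op_dims` stands for the dynamically supplied dimensions the ranked op applies to):

import numpy as np

def normalize_operand(t: np.ndarray, op_dims):
    op_dims = [d % t.ndim for d in op_dims]  # allow negative dimension indices
    batch_dims = [d for d in range(t.ndim) if d not in op_dims]
    # Step 1: the transpose that makes the (variadic) batch dims adjacent and leading.
    t = np.transpose(t, batch_dims + op_dims)
    # Step 2: a logical reshape that flattens the batch dims, reducing the
    # problem to the fixed-rank case (rank 1 + len(op_dims)).
    return t.reshape(-1, *t.shape[len(batch_dims):])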

+1

I agree. I think it is very powerful to be able to express all these things in IR rather than hiding them in some runtime implementation. That is one of the powers that gradual lowering gives. We can start with a very high-level IR where operations have fat semantics and then gradually lower to ops with simpler semantics and explicit encoding of checks and transformations. The hope being that we no longer need to implement all the special versions of operations but instead can generate them from one definition.

Buffer allocation is a good example for this. Keeping buffers implicit (fat semantics) makes algebraic optimizations and fusion easier. Making buffers explicit later unlocks optimization potential.

Can you explain this in more detail? The semantics I am understanding from what you wrote sound more like an extension of tcp_frontend.batch_matmul to incorporate lists of collapsing/free dims as dynamic arguments, rather than something related to tcp.batch_matmul as I have set it out here.

Yes, I was referring to the frontend variant. And I am not disagreeing with what you say. I just tried to tie this back to my initial comment that you need an efficient way to do transposes.