
Shape type dialect

Hi all,

We’ve mentioned the shape dialect in our documents and on the mailing list a few times, but mostly in the context of shape inference rather than as a dialect on its own. The goal of the shape dialect is to develop, in MLIR core, a dialect for specifying and solving shaped-type constraints. Of particular interest, it will be used for ShapedType (for example, to determine the output shape and element type of a result) as a representation of shape functions.

I’ll describe the dialect following the MLIR developer guide’s guidelines for contributing a new dialect; here are the elements:

What is the overall goal of the dialect? What is the first implementation milestone?

The goal of the dialect is to describe the shape constraints of operations in a manner that can be consumed by different backends (either during compilation or by a runtime). Shape functions should be expressible without being tied to a specific use case or obscured by lowering considerations (e.g., returning early when an error is encountered should be possible without being expressly part of the shape function description). The first milestone would be to be able to represent all operation shape constraints in OpBase and to automatically generate both verification and build methods for operations fully constrained by their shape functions.

How does it fit into the MLIR dialect ecosystem? Connection: how does it connect to the existing dialects in a compilation pipeline(s)?

The dialect is intended to be used by ODS/DRR and other dialects to define & solve shape constraints for operations defined in other dialects. But in turn, functions in this dialect could be generated from other dialects (e.g., a shape function for a structured op could be generated from the iteration description of the op).

Consolidation: is there already a dialect with a similar goal or matching abstractions; if so, can it be improved instead of adding a new one?

None in MLIR core. There are existing library functions, and different users have different shape inference implementations, but no dialect with the same goal. The standard and affine arithmetic operations could be used to specify shape functions too, and for statically shaped computations they might be sufficient, but for dynamic shapes they would unnecessarily complicate/obscure the functions (e.g., shape.add operating on two dimensions that may be dynamic should produce the correct addition result directly, whereas representing this with the standard dialect would require inserting additional conditionals inside the shape function to handle the dynamic cases).
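As a rough, hypothetical sketch (generic op form, since nothing here is upstream yet; the -1 “dynamic” sentinel in the second function is only an assumed encoding, used for contrast), adding two possibly-unknown sizes stays a single op in the shape dialect, whereas over plain index values the dynamic case has to be handled by hand:

    func @add_sizes(%a: !shape.size, %b: !shape.size) -> !shape.size {
      // "Unknown" is part of the !shape.size value itself, so one op suffices.
      %sum = "shape.add"(%a, %b) : (!shape.size, !shape.size) -> !shape.size
      return %sum : !shape.size
    }

    func @add_dims(%a: index, %b: index) -> index {
      // With plain indices, an assumed -1 sentinel for "dynamic" must be
      // checked and propagated explicitly.
      %c-1 = constant -1 : index
      %a_dyn = cmpi "eq", %a, %c-1 : index
      %b_dyn = cmpi "eq", %b, %c-1 : index
      %either_dyn = or %a_dyn, %b_dyn : i1
      %raw = addi %a, %b : index
      %res = select %either_dyn, %c-1, %raw : index
      return %res : index
    }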

Reuse: how does it generalize to similar but slightly different use-cases? What is the community of users that it is serving?

A successful implementation should allow specifying shape functions for compilers, runtimes and different solvers:

  • It enables automatically generating build methods for operations whose shape constraints fully define them. Functions in this dialect would be used to generate new build methods, making it easier to add new operations where verification and inference are based on the same constraints.
  • It could be used to generate constraints for symbolic solvers.
  • Shape functions could be reified into dynamic allocation/checking calls for runtime memory checking/allocation (a rough sketch follows this list).
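As a hypothetical sketch of the last point (not taken from the proposal; the standard-dialect ops and the particular lowering shown are assumptions), a reified shape function for a shape-preserving op on memref<?x4xf32> could boil down to a dim query feeding a dynamic allocation:

    func @reified_alloc(%arg0: memref<?x4xf32>) -> memref<?x4xf32> {
      // Recover the dynamic extent of dimension 0 at runtime.
      %d0 = dim %arg0, 0 : memref<?x4xf32>
      // Allocate the result buffer using the reified shape computation.
      %out = alloc(%d0) : memref<?x4xf32>
      return %out : memref<?x4xf32>
    }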

In particular this would be used to represent TensorFlow, TFLite and XLA shape functions. nGraph, the tensor compute dialect and others could have similar needs.

Who are the future contributors/maintainers beyond those who propose the dialect?

Folks who need shape inference for their ops and want a common dialect to perform optimization, reification and constraint solving with.


I think you know I’m supportive of this being developed in tree, and since you and @herhut have put so much thought into this, I’d love to see the work carried out in view of the community. I know you all have a private dialect tablegen file you’ve been collaborating on, and having seen it, I can say that the context helps. Do you think it would be valuable to copy it into a gist and include it here for others to see/comment on?

My experience with this is that it would be great if we could prioritize getting the shape types and small set of supporting ops in tree, since it is really hard for anything to take shape (no pun intended) without those. Getting the types in place allows for better project-specific collaboration and attempts to build algorithms for doing the various tasks (which can then give us confidence in approaches that we want in-tree).

Also, we are accumulating a dialect of similar things that it would be good to dedup and base on some common types/ops. This one is part of a couple-of-weeks spike that we’ve been doing to see what it takes to get (ranked) dynamic shapes threaded through our stack e2e. We’ve been approaching it from the post-shape-inference perspective (assuming that something has already inferred reasonable shapes for everything), focusing more on threading the dynamic shapes through the rest of our compiler/runtime stack. We’ve learned a couple of things so far:

  • Having good value types for this is critical, and those may be at a different granularity compared to the “analysis types” that get used for solving shape constraints.
  • Particularly, since the case of ranked shapes is such an attractive codegen target, we have a type and supporting ops specifically for that and have defined it in a way that it degrades away completely in the static case.
  • We haven’t solved the unranked case yet, but we presume such shape values will always be backed by a heavier-weight runtime type than the ranked case, and would prefer to model it as such with explicit casts/checkpoints where we can move between the worlds.
  • This is in contrast to the ranked and static case, where we have found it very convenient to use the dialect conversion infra to expand the dynamic dims to scalar SSA values/args when moving across parts of the codegen stack (a small sketch follows this list).
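As a sketch of what that expansion can look like after a dialect conversion (builtin types only; no project-specific type or op names implied), a function with one dynamic dimension simply grows an extra index argument, and in the fully static case nothing extra is needed at all:

    // Before conversion: the dynamic extent is implicit in the tensor type.
    func @before(%t: tensor<?x4xf32>) -> tensor<?x4xf32> {
      return %t : tensor<?x4xf32>
    }

    // After conversion: the dynamic extent travels as a scalar SSA argument.
    func @after(%t: tensor<?x4xf32>, %d0: index) -> tensor<?x4xf32> {
      return %t : tensor<?x4xf32>
    }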

I think that the primitive types and ops could all co-exist in the same dialect and would be a value add for MLIR. We may just want to keep an eye on the different levels of the problem and be open to modeling things at the different granularities that we may need for analysis and transformation/codegen.

Thanks!

With some consultation work with others - it’s been bottlenecked on unrelated tasks (and funnily we are closer to a theoretical analysis of a couple of constructs - it’s a very fun rabbit hole).

I completely agree, it is difficult to visualize what it means or how it works without that. It was difficult to think of staging: I didn’t want to dump ops on folks, but iterating slowly can hide the reason why :slight_smile: (e.g., “rank 1 int64 tensors are all you need”). I plan to push out for review a small set of ops that we can start with next week. Keeping it perhaps to ~6 ops, with no pretty printing/parsing [although that is much easier now :wink: ], and then we can discuss, iterate and expand.

The shape function description addition to the shape inference doc is up for review, in case you have some comments there.

I agree, and I think that is another reason why how one describes it and how one interprets it should be decoupled.

+1 , that fits in with the above.

Thank you for the proposal Jacques!

With the goals that you have described, I think this dialect will compose really nicely with the rest of the ecosystem. It definitely fills a niche that is currently being handled by ad hoc traits; subsuming many/all of the existing shape-related traits/constraints would be a great win IMO.

I’m also +1 on mainly focusing on the types to start with. Types generally end up more important than operations in many cases, so nailing those first seems best.

– River

Hi Jacques,

I am coming a little late to the party.

This sounds and look really great, I’d like to discuss 2 additional points I do not see mentioned:

  1. How do you see this interacting with shape information that can be easily extracted from AffineMap specifications? In particular, it is possible to write a very generic shape inference procedure that works for a large class of computations in the Tensor Compute domain with relatively little effort. See for example this impl. based on Halide. Porting this to AffineMap + SSA values should be very easy and even shorter. I’d recommend we consider what the Shape dialect brings in addition to that (e.g. gather/scatter, indirect load, external shape compute functions, sparse data types etc.) that can’t be represented much more easily with AffineMap, and be deliberate about the simplest way of attacking each problem. I definitely agree with the comment about avoiding obscuring the functions in dynamic cases. Can we see some concrete examples of intermixing the static affine and dynamic cases?
  2. Related to 1., I would be very wary to see implicit broadcast semantics carry over into that space. AFAIK it can almost always be replaced by explicit and non-surprising affine semantics. I think implicit broadcast should not leak into MLIR core. Here again, examples would go a long way.

Sorry for the delay in responding to this!

It commonalizes it :slight_smile: Information that can be easily extracted is then represented in the same form as things that would have to be manually coded. So you have a common representation format and no special cases (you extract the shape from the easy-to-extract format, but then can solve it with the general one). I see these as decoupled, with one feeding into the other.

I’m not sure why you are referring to AffineMap here, shape dialect has nothing to do with affine maps.

Yes, there will be a large class covered by compute domains, but that doesn’t cover all of TensorFlow’s shape functions, so these could be one element here. But again, I see that more as a source for generating the output in the shape dialect, independent of the solving, inference and reification of shape functions.

Sure, look at the ops being described in the pending rev.

That isn’t being proposed. What is being proposed is that one can write a shape function that says that operand X and Y to this op have implicit broadcast behavior. But it is an explicit broadcast function in the shape dialect.

Can you be more specific about where to find such an example in the revision?
The revision is fairly large and does not include any tests and/or examples (beyond the few snippets with the individual ops).

Would you be able to provide here in the proposal some example of pseudo-IR illustrating this dialect in action?

I think the question may be “if you were to use AffineMap to express your shape functions, would it be enough? And then could you just use affine.apply and let the regular optimizer do its job”.
Basically that comes back to my question above: I think a few examples of what you’re trying to model here may be useful!
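For the purely arithmetic part, that route could look something like this minimal sketch (the concat-style example and the assumption that the extents are already available as index values are mine, not from the thread):

    func @concat_dim(%a: index, %b: index) -> index {
      // The output extent along the concatenated axis is an affine expression
      // over the two input extents, treated as symbols.
      %sum = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%a, %b]
      return %sum : index
    }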

On another angle: modeling “exact” shape functions is a great start and will get us a long way. However, I don’t see a mention of propagating non-exact constraints, in particular a common one: bounds.
For example, XLA has a “dynamic padder” feature where the input of a computation is not statically shaped but bounded: the compiler can then infer max bounds for the entire computation, which comes in handy at buffer allocation / scheduling time!
How do you see this playing with your work in this dialect in the future?

Sure, let me copy the larger one over

    func @shape_num_elements(%shape : !shape.type) -> !shape.size {
      // Seed the reduction with 1 and multiply in every dimension of the shape.
      %0 = "shape.constant_dim"() {value = 1 : i32} : () -> !shape.size
      %1 = "shape.reduce"(%shape, %0) ( {
        ^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
          %acc = "shape.mul"(%lci, %dim) :
            (!shape.size, !shape.size) -> !shape.size
          "shape.return"(%acc) : (!shape.size) -> ()
        }) : (!shape.type, !shape.size) -> (!shape.size)
      return %1 : !shape.size
    }

This is without the pretty printing/parsing form :slight_smile: But you can express the reduction without explicitly having ifKnown/ifRanked etc. in every part. tf.Reshape was the one I was playing with most, and there one ends up with multiple different conditionals that make it difficult to see the flow (not that the final tf.Reshape shape function suddenly becomes all that pretty :wink: )

Yes, and I think that was the first answer: AffineMaps would not be able to handle the general case (e.g., tf.Reshape requires computing 2 products and dividing one by the other, which I don’t think I could express as a single map). I want this to get very close to being able to express all of TF’s shape functions (which makes things tricky; otherwise I could have shape functions operating only on shapes, but that would not be sufficient).
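For reference, a rough generic-form sketch of that computation, reusing the shape_num_elements function above (“shape.div” is hypothetical here, and skipping the single unknown entry in the target shape is elided):

    func @reshape_output_dim(%operand_shape: !shape.type,
                             %target_shape: !shape.type) -> !shape.size {
      // Product of all operand extents.
      %operand_elems = call @shape_num_elements(%operand_shape)
          : (!shape.type) -> !shape.size
      // Product of the (known) target extents.
      %known_elems = call @shape_num_elements(%target_shape)
          : (!shape.type) -> !shape.size
      // The missing extent is the quotient of the two products.
      %missing = "shape.div"(%operand_elems, %known_elems)
          : (!shape.size, !shape.size) -> !shape.size
      return %missing : !shape.size
    }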

We should still do that where possible! Lowering the shape functions to std/affine etc. and then to code, to reuse the same lowering paths. From that point of view shape functions can be a consumer of affine for describing the input/output space relation, but also a producer for lowering/code generation.

Yes indeed, that is one of the main goals: to unify these. Propagating upper bounds fits in with the declarative goal. E.g., if one has shape.join [?<n, 10], [?, ?] then one has a new result shape [?<n, 10], and so the upper bound is propagated through the declaration of “equality”. Now, it depends on the inference context/approach whether it is propagated. E.g., a fully dynamic one could drop it on the floor, while others could store it in the context or the lowered type (or even a lowered operation). The decoupling of the description from the usage is my interest in having it be a separate dialect.
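In generic form, such a join constraint could simply appear as an op, with the bound handling left to the consumer (a sketch only; the op name follows the usage above but is not upstream):

    func @join(%lhs: !shape.type, %rhs: !shape.type) -> !shape.type {
      // Declares equality of the two shapes; a solver may propagate an upper
      // bound carried by either operand into the result.
      %joined = "shape.join"(%lhs, %rhs) : (!shape.type, !shape.type) -> !shape.type
      return %joined : !shape.type
    }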

But that is also a good point, in that different cases require different specification (e.g., shape functions for op-by-op launch, where all values and shapes are known, need not consider anything beyond static shapes, but we’d want a single source so that we could specialize into the different forms), and so a more general inference could be constrained by the information provided by less exact shape functions. This is not unique to here; I see it as similar to op definitions: the more you specify, the more precise the modelling becomes, but it is good to keep in mind. That is where common building blocks help to reduce the number of places to expand.

Thanks!

One more question: how is !shape.size different from the index type? (And why not reuse it?)
After all, this is how the standard dialect represents shape dimensions for ShapedType.
You have operations like "shape.mul"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size, and I am not sure how they differ from a multiplication on the index type.

Answering myself, I forgot I already looked this up in the past and the answer is in the initial patch (which wasn’t linked from this thread before I think?):

`shape.size` represents a non-negative integer with support for being unknown and invalid.

Hi Jacques,

Affine maps can handle products of shape symbols and divisions by shape symbols. Here’s one example.

What are all the operations you’ll need in your shape functions? Multiplications, additions, divisions, and modulos should all be fine with affine maps, and all canonicalizations on affine map operations work in the presence of that. Would you ever need the exponential or logarithmic function? These can’t be captured inside of affine maps.

Can you say which revision? I couldn’t find any in D73572, the only one linked in the original post.

Has this revision been posted for review? I couldn’t find this example in D73944 either.

Thanks.

In this case I’d need the equivalent of

\prod_{i=0}^{n} a_i \Big/ \prod_{i=0}^{n} b_i \, H(b_i)

where H is the Heaviside function, for one of the indices. And of course I’d be very interested if broadcastable could be represented as an affine map :slight_smile:

I’m not 100% sure: in general, with static shapes, one can do arbitrary arithmetic. Now, I think that would be rare and that the majority can just be expressed with affine maps (well, even equality and broadcastable get us quite far: ~400 ops I can add shape support for with 4 primitives [broadcastable, scalar, join, no output]). Then there are also ops for which computing the shape is as much work as executing the op itself, and so it may not be worth defining an exact shape function for them.

Yes sorry, Mehdi linked it above.

In https://reviews.llvm.org/D73944 search for shape_num_elements and you’ll see it in the ops. And I need to fix a typo (I originally called the dialect si … but decided that might have folks think of SI ;-))

Thanks! It looks like you clearly need arbitrary functions. I think you’ll get folding canonicalizations on such arbitrary IR in shape functions as well anyway. It’s just that your shape function, when represented as an affine map, would:
(a) perhaps be easier to analyze and to extract / infer structure from, and
(b) have a lower IR memory footprint, but that probably doesn’t matter at the high-level stages of the IR this is all for (with tensor types / compute graph stages).

On the flip side, having both affine maps and arbitrary functions would provide you two things/choices to deal with, sort of fragmenting/duplicating things.

Out of curiosity, are n(a_i), n(b_i) known at compile time? If they are, you can still use maps. Just define each of the Heaviside function outputs as a symbol a priori and use those symbols in the affine map. You can do:

    affine_map<()[ah0, ah1, ah2, bh0, bh1, bh2] -> ((ah0 * ah1 * ah2) floordiv (bh0 * bh1 * bh2))>

And if really needed, any analysis surrounding shapes can be made aware of the Heaviside ops, say, to help determine ranges for those symbols. However, one is lost if these n(a_i), n(b_i) are themselves unknown at compile time :slight_smile:, in which case you’ll anyway need unknown-trip-count loops in your shape function, and this issue is common to unknown-rank tensors.