LLVM Discussion Forums

[RFC] Tensors with unknown element types

Recently, while @stellaraccident and I have been working on npcomp/TCF/TCP, and in discussions in the context of other systems, the need for a tensor type with an unknown element type has come up.

This is extremely common when attempting to faithfully model frontends whose tensor types come from a dynamic language like Python (numpy, PyTorch, and TensorFlow are concrete examples). For example, in the absence of further refinement, the Python code numpy.add(x, y) maps to an op something like numpy.add %x, %y : (tensor<*xany>, tensor<*xany>) -> tensor<*xany>. The same is true for TensorFlow and PyTorch. Ideally we would like to do that type refinement algorithm inside MLIR so it can be shared across all of them (and also because we generally like MLIR :stuck_out_tongue: ).

The fact that the element type is currently mandatory is, I think, just a historical artifact – TensorFlow’s GraphDef protocol buffer representation has no way of expressing an unknown element type for a tensor (even though TensorFlow ops can be polymorphic), and I think MLIR just inherited that accidentally.

There are two obvious ways of dealing with this (I would like to hear others):

  1. Add built-in support to tensor types for knowing that the element type is unknown, much as we do for dimensions with unknown extents. This could mean either allowing getElementType() to return nullptr, or having getElementType() internally assert(hasElementType()), much as getShape() aborts for unranked ShapedTypes today (see the sketch after this list).
  2. Add an any type and have, for example, tensor<?xany>. (we can bikeshed the name; in npcomp we call it !numpy.any_dtype).
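To make option 1 concrete, here is a minimal sketch of what defensive call sites might look like; hasElementType() and the two handler functions are hypothetical, not actual MLIR API:

    // Rough sketch of option 1. hasElementType() does not exist today; it
    // would mirror the hasRank()/getShape() contract for unranked tensors.
    void process(Value value) {
      if (auto tensorTy = value.getType().dyn_cast<TensorType>()) {
        if (tensorTy.hasElementType())
          useElementType(tensorTy.getElementType());  // safe: type is known
        else
          handleUnknownElementType(tensorTy);         // hypothetical fallback
      }
    }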

I tend to lean towards 2, because:

A. It uses an existing extension point. It also scales nicely to future needs, e.g. refining any -> oneof<f32, f64> in a type inference pass.
B. I’m really worried about getElementType() returning nullptr or aborting. That seems ripe for a long tail of random compiler crashes. Pragmatically, ODS predicates like AnyFloat will need to be rewritten to have null checks, e.g. $_self && $_self.isa<FloatType>(), which is inconvenient, easy to forget, and a lot of churn.

Thoughts?

I prefer having an “unknown” type as well.
However, it can be done in a dialect-specific way already (like !numpy.any_dtype…). The main consideration in doing this in a more endorsed way in the infrastructure is going to be defining all the helpers and ODS predicates with respect to unknown. For example, is SameOperandsElementType OK when you have an “unknown” dtype and something else?

For me, the main reason I’d like the “unknown” type to be in core is so that the {numpy,pytorch,tensorflow} -> TCF conversion doesn’t have to do a type conversion, since things always get much more annoying when having to do type conversions.

As far as ODS predicates, I don’t think we should give any special treatment to the “unknown” type. That keeps my mental model simple for how they work :slight_smile:

Not really: shape refinement of GraphDef does allow for an unknown dtype; we are more restrictive. It was more that there was no pressing need for it, and so we didn’t want to add complexity until there was a use case where alternatives could be considered. Also, we had discussed generic types, and these interactions are clearer.

Which reminds me of same-shape constraints: today these are checked statically, whereas some cases prefer a dynamic check (e.g., * == 10x10 has different answers depending on which). That is true here too.

That seemed to be the consensus when this was discussed last, and I think it is a good starting point.

If it were only an artifact of not having done shape/type inference/refinement yet, we could have added it to the shape dialect. But if we consider this for more general frontend use, that may not make sense (e.g., it could be an actual type that makes it all the way to runtime and is used for dynamic dispatch).

Which means you’d need an op that is effectively element_cast to make the type system happy if you have a tensor<*xany> and a tensor<10x!tf.resource<20x30xf32>> (the latter may dynamically be equivalent to the former, and/or that may be established during shape inference) [which I think is another plus for a dialect-specific unknown: the unknown could occur inside a subtype of a dialect type, or some dialects may not allow unknowns there and can enforce that at construction time]. It also means all current rules that verify types (e.g., add can only operate on integer or floating point operands) would need to be updated (unless one gives special treatment to unknown, which could be subclasses of an unknown base class - which isn’t a type in and of itself :slight_smile: ).
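As a rough sketch of what such a cast might look like (the op name and the “any” spelling are illustrative, not an actual dialect):

    // Refine the general type to a specific one once inference establishes it:
    %0 = "mydialect.element_cast"(%arg)
        : (tensor<*xany>) -> tensor<10x!tf.resource<20x30xf32>>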

Type conversion here would mostly be done during function specialization and shape/type refinement, and there you would be doing the type conversion anyway. I don’t think this would happen during arbitrary conversions, and so there would be dedicated conversions.

Actually, that is what !shape.element_type is for shaped types (of which tensor is one). But dialects would need to accept it as a valid element type for their operations.

It really didn’t have anything to do with TensorFlow’s modeling of this - it is because we wanted to have additional invariants implied by the type. This is similar to the “why aren’t memref and index allowed inside a memref” discussion - by omitting them, address calculation logic didn’t need to be target-specific.

It isn’t clear to me that it makes sense to change Tensor to allow an unspecified element type. We have operations that are type directed - e.g. integer vs floating point, signed vs unsigned, etc. Furthermore, changing this would mean that lots of things that take a Tensor would need defensive logic to detect and ignore/handle the missing type.

Beyond that, it isn’t clear to me that the dynamically typed case intersects much with the existing tensor type - you don’t get layouts, static shapes or anything else coming with it.

I would recommend introducing a new “untyped tensor” type in your dialect, and use type/shape inference to specialize towards std.tensor.

-Chris

Ah, cool. Thanks for the info!

It seems like that same argument would apply to suggesting that dialects have their own “unranked tensor”, which we do support in the core tensor type. Do you see any fundamental difference between the unranked case and the untyped case that would make us want to treat them differently from an infra perspective (e.g., whether they are supported in the core tensor type)?

(actually, the fact that you may be mixing unranked, untyped, and various combinations of them seems to suggest that we need to handle them in a unified way, rather than as discrete types)

I’m open to this and was probably headed there anyway because tensor doesn’t actually match the semantics of what is being represented at this level (more that it was a convenient way to do less typing to get started). Balancing that, though, if we do end up wanting generic algorithms for doing this kind of type refinement, we’re going to want corresponding types to go with them that can be generically manipulated. It is probably ok to say whatever core dialect implements such things should also bring types and/or interfaces to represent it.

(This also touches on my annoyance that the ShapedType type hierarchy is closed and in std: it is currently not possible to create a dialect type that is tensor-like in this way and have it not be completely disjoint. ShapedType is acting as a “type trait” but since we lack that concept generically, the tendency is to want to put more things into that type hierarchy because they share parts of the trait)

These frontend-isomorphic and transformation oriented dialects tend to be “islands” already, in that they are quite disconnected from the backend op and type systems, requiring analysis and transformation to get to the lower layers. It seems cleaner to me to make it their responsibility to model their semantics and generality levels correctly, which implies a more expansive type system, corresponding casts, etc. In some cases std.tensor may be enough, but if it’s not, new types are cheap (so long as they don’t need to cross too many dialect boundaries and have a clear layer beyond which they cannot exist).

I am going to nitpick the history here, though: it does seem somewhat coincidental that the existing std.tensor type hierarchy is sufficient to represent TensorFlow/XLA-HLO frontends but no others (to my knowledge, others that have a concept named “tensor” or “ndarray” all have somewhat different semantics and would not be representable with the built-in tensor types). This would be less of an either/or issue if we had traits of some kind representing the key parts of the tensor type (versus a closed type hierarchy, which creates subtle pressure to both conform to what is there and inch the std types forward over time).

That removes all interest in this to me: if this “any”/“unknown” type isn’t handled with its own specific semantics, then what does it bring over your own dialect type? Saving a straightforward type conversion isn’t very motivating to me.

To expand, if we don’t have predicates handling this and clear rules about this element type, then we can’t have:

"std.addi"(%a, %b) : (tensor<?x!std.unknown>, tensor<10xi32>) -> tensor<?x!std.unknown>

And if we can’t have this, then we can’t have an IR where you start with %b being tensor<?x!std.unknown> and then type/shape inference refines it to tensor<10xi32>.
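As an illustrative sketch (reusing the types from the example above), such refinement would rewrite the IR in place:

    // Before type/shape inference:
    "std.addi"(%a, %b) : (tensor<?x!std.unknown>, tensor<10xi32>) -> tensor<?x!std.unknown>
    // After inference refines %b and the result:
    "std.addi"(%a, %b) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>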

Basically I see similar needs for the dynamic element type as what we have to do for the unranked (and dynamic broadcast) cases.

The main thing I was objecting to was changing ODS, which I don’t think we really need to touch at all. To be clear, things like “is tensor type compatible” would need to be upgraded to be aware of it (but that’s like one function?).

Or one way to look at it is: our standard tensor types today model a type lattice which does not have a way of representing an unknown element type. For example, we can do join(?xf32, ?x?xf32) -> *xf32, but join(?xf32, ?xf64) cannot be modeled. This proposal is to extend that lattice to include the ability to model an unknown element type. Places that reason about that lattice would need to be updated (and that should be a relatively small set of C++ helpers, not anything in ODS per se?).

This actually has the nice side effect of making the core tensor type lattice closed under the lattice join operations. It then has a well-defined “top” type of tensor<* x unknown> which it currently lacks.
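A minimal sketch of what a join helper over the extended lattice might look like; UnknownType and the helper itself are hypothetical, not actual MLIR API:

    // Join two element types in the extended lattice (illustrative only).
    Type joinElementTypes(Type lhs, Type rhs, MLIRContext *ctx) {
      if (lhs == rhs)
        return lhs;                  // equal element types join to themselves
      return UnknownType::get(ctx);  // otherwise join to the "unknown" top
    }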

I’ve never seen a type conversion that was “straightforward”, at least with MLIR’s current infra.

This is definitely a judgement call - there is a spectrum here with many valid points in the design space. The things I pointed out above (having type-specific operations) lean towards requiring a dtype, but this doesn’t apply to having dynamic shape. Let me be clear though - I agree completely that an approach with dynamic dtype “could work”, I just don’t think it would be worth it on balance.

Let me give you a similar but different example: the type system for the FIRRTL IR (spec) has signed and unsigned integer types, but the width is optional - width inference turns unspecified widths into specified ones. This is an important design point for that IR, but it doesn’t mean that we should remove the width specifier from the standard IntegerType: we should make it so that an MLIR FIRRTL dialect can define its own integer types. In practice, this already works very well.
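For illustration, such dialect-specific types might be written along these lines (the syntax is hypothetical, not an actual FIRRTL dialect):

    !firrtl.uint        // unsigned integer, width to be inferred
    !firrtl.uint<8>     // unsigned integer with a known width of 8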

I completely agree, I would love to see this fixed! We have other things that are overly specific to std types as well, e.g. integer attributes really want to have an IntegerType, which isn’t ideal for domains that have their own integer types.

MLIR is already very good about having extensible operations, we should keep working to make sure the type system is equally extensible IMO.

-Chris

Could you elaborate on that? I’m especially struggling to find arguments that don’t apply equally to supporting unranked in the core tensor type.

+1, much as each dialect seems to have its own flavor of “add”, perhaps they could each have their own flavor of “tensor” with similarly low overhead, and we could work toward that.

(but also to Stella’s point above, w.r.t. the current state of things I feel like if an unknown dtype had been needed to round-trip TensorFlow’s GraphDef, the core tensor type would probably “coincidentally” support it)

Personally, I’ve shied away from proposing anything to fix this because I am not confident enough in my C++ skills to build something of the generality of the Operation system that is also memory-efficient/performant to the degree that core IR constructs need to be.

But this could be a lot easier for Types since they are a) Not type-erased like Operation at the C++ level, and b) Are much less dense in real programs, possibly requiring less gymnastics to squeeze every last bit out and still retain usability.

Has anyone given any thought to having extensible type traits for IR Types?

I already mentioned examples upthread. std.addf vs std.addi and other design decisions don’t make sense with a dynamically typed dtype, but do with dynamic shapes.

How is std.addf operating on “unknown” element type different from a matmul operating on unranked tensors (or ranked tensor with unknown dimensions)?
I see quite an analogy: in one case we want to refine the shapes, in the other the element type, but fundamentally it does not seem much different: we have operations with restrictions that can’t be entirely statically checked.


Because we don’t have std.add that is dynamic on element type, which you would have to have to actually use a dynamic element type.

There are good reasons we don’t have this too - such an operation would have to carry a union of the info needed for signed/unsigned integers as well as floating point rounding etc. This was discussed at length very early in MLIR evolution, see here and here.

-Chris

It isn’t clear to me why we can’t just use std.addi with a statically “unknown” type; it just adds a constraint that the unknown types have to be integer types, but you don’t know which integer statically.
Just like std.addi with two unranked tensors has the expectation that they match at runtime (or after inference).
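As a sketch of the analogy (reusing the !std.unknown spelling from upthread):

    // Today: ranks/shapes unknown statically; they must match at runtime.
    "std.addi"(%a, %b) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
    // Analogously: element type unknown statically; it must resolve to some
    // integer type at runtime or after inference.
    "std.addi"(%c, %d) : (tensor<10x!std.unknown>, tensor<10x!std.unknown>) -> tensor<10x!std.unknown>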

It also does not prevent other dialects from using tensors with their own operations: for instance, tf.Add works on both integer and floating point operands.

I am curious how an untyped tensor would be plumbed through to all the backends. Take LLVM and SPIR-V as current backends for MLIR. Both of these backends can handle dynamic shapes, i.e. you can think of a sequence of instructions that uses dynamic shape information and performs pointer arithmetic. When the type is not known, I am not sure I can think of a similar sequence of instructions that can handle the type being dynamic (should the code contain an integer add or a floating point add?). These backends talk to infrastructures/specs that are not type-polymorphic. I might be missing something here, but dynamic shapes and dynamic types seem to be totally different things.
The way I see it, before reaching the “backends” the type needs to be resolved. So in some sense, that kills what I see as the initial requirement of the type being truly dynamic (where the compiled module can handle different types of input). The other option is multi-versioning, which seems like a very heavy hammer (and there is a potential combinatorial explosion of versions).

I don’t think it makes sense for any backends we can currently envision, and from my perspective, the main reason for even having unranked as allowed (to say nothing of untyped) is to facilitate scenarios where types are being refined to legality. Having very loose constraints on these low level ops seems like a big hammer to apply for the needs of some tooling which needs to operate for a period of time with constraint violations (which can currently be done by disabling verification for limited scopes).

I would probably go the opposite way in this discussion and say that the std ops should have had stricter type constraints because it is someone’s job at the high level to refine the types and ensure that sufficient dynamic checks/casts are placed such that the stricter types can be assumed. We can always have higher level ops that do whatever we want, but the low level ops like std.addi need to be pretty closely aligned with code that can actually be generated.

I think that having looser types on these low level ops without some further framework for specifying the dynamic constraints that must be valid is going to cause trouble (I’m untangling this now on the xla side).

(The other side of me is sympathetic because type conversions in MLIR are really hard right now and the verification strictness adds to the pain)

It’s totally fine for a backend to be statically typed. In fact encouraged IMO :slight_smile:

But in many compilation flows there are many compilation steps before one gets to those statically typed backends, where any number of lowering or legalization processes can reduce a more dynamic problem to one that can be handled by the statically typed backends. (much as a backend that only supports static shapes does not negate the value of representing dynamically shaped tensors at higher levels of abstraction)

Also, keep in mind that not all compilation flows lower all the way to machine code or something low-level like SPIR-V. It’s totally fine to have a compiler that reads in a fairly dynamic program, such as a restricted Python subset + “numpy” (TorchScript is a great example), and spits out an optimized program at the same level of abstraction. Such programs, in the absence of any other information, have dynamic element types and dynamic ranks, but many transformations can still be done.

Also, I’ll note that many transformations can be done without knowing the element type. For example:

  • Recognizing that a slice/view operation is an identity.
  • Folding a gather of a gather.
  • Folding a reshape of a reshape.
  • Folding a transpose of a transpose (sketched below).
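For instance, a minimal sketch of the last case; the frontend.transpose op is hypothetical:

    // transpose(transpose(%x)) with mutually inverse permutations folds to
    // %x, regardless of whether the element type is known:
    %0 = "frontend.transpose"(%x) {permutation = [1, 0]}
        : (tensor<?x?xany>) -> tensor<?x?xany>
    %1 = "frontend.transpose"(%0) {permutation = [1, 0]}
        : (tensor<?x?xany>) -> tensor<?x?xany>
    // ==> %1 folds away to %x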