LLVM Discussion Forums

Unranked tensors in TCP

We have discussed broadcasting and some interesting aspects of ranked cases in [RFC] Declarative "named ops" in the Linalg dialect , but I don’t feel we have discussed much about handling unranked tensors. Linalg can’t represent unranked, for example. My personal view is that unranked is a separate problem (even for “elementwise add”), best modeled at a level of ops that have implicit broadcasting, with shape-related runtime errors happening “inside the ops” rather than detected via reified shape calculations like we would prefer to use at the “linalg named ops” level of abstraction. Examples:

def f(t: Tensor):
    while cond():
        t = stack(t, t)
    return t

In this example, each iteration of the loop increases the rank by 1.
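A NumPy sketch of the same behavior (np.stack standing in for the stack op above): each iteration adds a leading dimension, so when the trip count is dynamic the final rank cannot be known statically.

```python
import numpy as np

# NumPy analogue of the snippet above: each np.stack adds a leading
# dimension, so the rank grows by one per iteration.
t = np.ones((2,))          # rank 1
for _ in range(3):         # stand-in for the dynamic `while cond()` loop
    t = np.stack((t, t))   # rank grows by 1, new leading dim of size 2
print(t.ndim)   # 4
print(t.shape)  # (2, 2, 2, 2)
```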

def f(lhs: Tensor, rhs: Tensor):
  return lhs + rhs

In this example there is statically indeterminate broadcasting and error reporting.
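To make the “errors inside the op” point concrete, here is a NumPy sketch (NumPy's broadcasting rules standing in for the implicit-broadcasting semantics): whether the add broadcasts successfully or fails is only discoverable at run time.

```python
import numpy as np

# Statically indeterminate broadcasting: the same `+` either broadcasts
# or raises, depending on runtime shapes.
def f(lhs, rhs):
    return lhs + rhs  # broadcasting rules applied dynamically

print(f(np.ones((3, 1)), np.ones((4,))).shape)  # (3, 4)
try:
    f(np.ones((3,)), np.ones((4,)))
except ValueError as e:
    print("runtime shape error:", e)
```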

I think this example makes it clear that there is a need for an “ML frontend” dialect layer separate from the abstraction level discussed in the other thread I referenced above, with support for unranked being a key feature. One of the key transformations at this level is shape/rank inference, which allows as much of the program as possible to be amenable to powerful lower-level code generation techniques. Another is multiversioning / PGO-guided (or JIT-feedback-guided) specialization toward the same end.

This level of abstraction also can play well with other frontend constructs, like the following python code which has equivalents in most modern ML frameworks:

def f(t: Tensor, n: int) -> List[Tensor]:
    l = []
    for i in range(n):
        if i == 0:
            l.append(t)
        else:
            l.append(t * l[i-1] * i)
    return l

In this example, the frontend framework would potentially bring their own list type (which might have some frontend-specific semantics) but reuse a tensor multiplication op with implicit broadcasting and handling of unranked that MLIR provides.


Thanks for bringing this up. I agree. I also see the ability to handle unranked code at two levels. For the front-end, this is interesting

  1. when coming from a (rank-)dynamically typed front-end language
  2. to be able to specify library functions independently of rank

Being able to model this in a shared dialect would enable the reuse of specialization and shape inference specifications. However, one might be able to achieve the same by using dialects for modelling the shape behavior and interfaces to model specialization. Either way, I see enabling reuse at this level as very valuable.

Secondly, for code generation, it is interesting if one wants to execute dynamically ranked programs without the need to infer shapes and generate code first. Like in a true AOT setting.

This is not catered for in LinAlg, as it is focused on modelling rank-specific programs. Handling rank-specific programs is a big and important part of the problem, and in my experience one models rank-dynamic and rank-specific computations differently at the code generation level anyway. So this is not a criticism of LinAlg.

The question is whether we would design a counterpart to LinAlg for rank-generic code and then lower higher-level ML dialects into it (having LinAlg and X side by side) or whether we specify a dialect on top that is shared and specify lowerings from that dialect to LinAlg for the known-rank case and other code-generation otherwise. I expect the latter to have more potential for sharing lowering logic.

Do you see significant opportunities for code generation for dynamically ranked programs, beyond just dispatching into precompiled libraries of rank-generic kernels? I would like to hear more about this. (I guess I typically associate the term “code generation” with lowering to loops, vectors, etc. which is not really possible in a rank-generic setting; perhaps you could elaborate on what you mean here).

I guess, what aspects of what linalg does carry over into the rank-generic setting? I’m having a hard time picturing what a “counterpart to LinAlg for rank-generic code” would even look like.

Also, I think at this level of abstraction, things like convolution window sizes should be dynamic values instead of attributes.

The question is where these rank-generic kernels come from. Are those all written by hand or are some (or all) generated?

For convolution, in the rank generic case, you still have a fixed set of dimensions the convolution computes over and an unknown number of dimensions you iterate over. So you need to be able to express the indexing in a flexible enough way to represent this. Using index-vectors instead of scalar index values is a first step and you probably want an efficient way to express transposes.

I am not claiming to have an answer here (or even that producing efficient code is feasible, this is mostly a completeness thing), I was just wondering whether this is an area we are at all interested in.

A generalized LinAlg would need rank-independent affine maps for starters.

In such cases like conv or batchmatmul, I would expect that one would reshape such that the unknown dimensions you iterate over are flattened into a single dimension. It is then reduced to the ranked case.
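A NumPy sketch of this flatten-to-ranked idea (np.matmul standing in for the ranked batch-matmul kernel): collapse the unknown leading dimensions into one, run the rank-3 kernel, then restore the batch shape.

```python
import numpy as np

# Reduce a rank-generic batch matmul to the ranked case: flatten all
# leading "batch" dims into one, run a rank-3 kernel, reshape back.
def rank_generic_batch_matmul(lhs, rhs):
    batch = lhs.shape[:-2]                       # unknown number of dims
    lhs3 = lhs.reshape((-1,) + lhs.shape[-2:])   # collapse to rank 3
    rhs3 = rhs.reshape((-1,) + rhs.shape[-2:])
    out3 = np.matmul(lhs3, rhs3)                 # the ranked kernel
    return out3.reshape(batch + out3.shape[-2:])

a = np.ones((2, 3, 4, 5))
b = np.ones((2, 3, 5, 6))
print(rank_generic_batch_matmul(a, b).shape)  # (2, 3, 4, 6)
```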

It sounded like @sanjoy_das_google was interested in this.

But I agree that it’s not clear how this fits in with the overall TCP effort. I guess I was trying to use this as a strawman for the question: should TCP ops verifier-check that the inputs are ranked tensors? I think they should, but I’ve heard dissenting opinions in the past.

If the unranked dimensions are not next to each other, you would need to transpose first to make this possible (I consider a reshape a logical operation that manipulates the shape of the tensor without changing the layout of elements). That is what I meant by an efficient transpose.
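A NumPy sketch of why the transpose comes first (the axis positions here are hypothetical): when the variadic dimensions are not adjacent, a moveaxis makes them adjacent so a single flattening reshape becomes legal.

```python
import numpy as np

# Suppose the "variadic" dims sit at axes 0 and 2. Make them adjacent
# with a transpose (moveaxis), then flatten them into one dimension.
x = np.ones((2, 4, 3, 5))           # variadic dims at axes 0 and 2
x = np.moveaxis(x, 2, 1)            # now (2, 3, 4, 5): variadic dims adjacent
x = x.reshape((-1,) + x.shape[2:])  # flatten them into one: (6, 4, 5)
print(x.shape)  # (6, 4, 5)
```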

This is what I was wondering about, as well. Should TCP be able to serve as a vehicle to express unranked computations (so that rank inference can be done on it and it can serve as a language for expressing rank-independent “library” code), and, as a follow-on, if so, would we expect to always specialize to rank for code generation, or would we also want to generate rank-generic code?

Do you know of any situation where a “variadic” list of dimensions are not adjacent? Or a case with two independent variadic lists of dimensions in the same shape? I’m curious if a couple primitives could be used to funnel most of the unranked code into ranked primitives in a clean way.


%result = "tcp_frontend.batch_matmul"(%lhs, %rhs) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>

would be lowered into

%ranked_lhs, %ranked_rhs, %rank_collapse_record =
    "tcp.collapse_variadic_dims"(%lhs, %rhs)
    { lhs_dims = [VARIADIC, -2, -1], rhs_dims = [VARIADIC, -2, -1]}
    : (tensor<*xf32>, tensor<*xf32>)
    -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, !tcp.rank_collapse_record)
%ranked_result = "tcp.batch_matmul"(%ranked_lhs, %ranked_rhs) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
%result = "tcp.reconstruct_variadic_dims"(%ranked_result, %rank_collapse_record)
    { result_dims = [VARIADIC, 1, 2] }
    : (tensor<?x?x?xf32>, !tcp.rank_collapse_record) -> tensor<*xf32>

In this code, !tcp.rank_collapse_record saves the information that was lost when collapsing a variadic number of batch dims to a single dim. The tcp.collapse_variadic_dims op slices out certain dimensions from an unranked shape, with a special VARIADIC token used to capture a variadic list of dimensions into a single dimension of the output, saving that information into the rank collapse record. That is, tcp.collapse_variadic_dims does the following:

  1. it constructs a rank-3 %ranked_lhs consisting of first all leading variadic dimensions flattened into one, then the -2nd dimension of the input shape, followed by the -1st.
  2. it constructs a rank-3 %ranked_rhs result similarly (though in general the dimension spec could be different), and the special “VARIADIC” tokens embed an assertion (with error bailout) that it matches the VARIADIC dimensions on the LHS.
  3. it packages up the information needed to reconstruct the VARIADIC leading shape into the !tcp.rank_collapse_record.

The tcp.batch_matmul is then a ranked operator.

Finally, the tcp.reconstruct_variadic_dims reconstructs the variadic leading dimensions by inserting the VARIADIC dimension list from the rank collapse record in place of dim 0 of %ranked_result, followed by dim 1 of %ranked_result, followed by dim 2 of %ranked_result.

After applying a transformation like this locally to the code, further transformations can clean it up. For example, if shape inference can somehow prove that %lhs and %rhs are already ranked, the tcp.collapse_variadic_dims could be removed from the program.
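The op pair can be modeled in Python for illustration (the helper names are hypothetical, and the record is modeled simply as the saved batch shape): collapse_variadic_dims flattens the variadic batch dims, asserting they match across operands, and reconstruct_variadic_dims reinserts them.

```python
import numpy as np

# Hypothetical Python model of tcp.collapse_variadic_dims /
# tcp.reconstruct_variadic_dims. The "record" is the flattened-away
# variadic batch shape, saved so it can be reinserted later.
def collapse_variadic_dims(lhs, rhs):
    lhs_batch, rhs_batch = lhs.shape[:-2], rhs.shape[:-2]
    # the VARIADIC-token assertion (with error bailout):
    assert lhs_batch == rhs_batch, "VARIADIC dims mismatch"
    lhs3 = lhs.reshape((-1,) + lhs.shape[-2:])   # rank-3 operands
    rhs3 = rhs.reshape((-1,) + rhs.shape[-2:])
    return lhs3, rhs3, lhs_batch                 # record = batch shape

def reconstruct_variadic_dims(result3, record):
    return result3.reshape(record + result3.shape[-2:])

lhs = np.ones((7, 2, 3, 4))
rhs = np.ones((7, 2, 4, 5))
l3, r3, rec = collapse_variadic_dims(lhs, rhs)
out = reconstruct_variadic_dims(np.matmul(l3, r3), rec)  # ranked kernel
print(out.shape)  # (7, 2, 3, 5)
```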

As with the tcp.clean_for_matmul ops in the thread “Is the TCP "matmul" op marked NoSideEffect?”, I don’t necessarily believe that we literally need ops like tcp.collapse_variadic_dims. This could all be open-coded in the shape dialect or with other TCP-specific ops.

In fact, all of this can be seen as just reifying code “hidden inside” the tcp_frontend.batch_matmul. For example, if one looks at tf.BatchMatMulV2 runtime code here, one sees a large amount of runtime code checking things, doing shape computations, and issuing errors. What we are doing at this stage is reifying all that runtime code statically so we can optimize it and statically see the nucleus tcp.batch_matmul op.

(actually, further lowerings can be seen that way as well. e.g. lowering to operate on buffers reifies buffer allocations, which are then subject to their own transformation; lowering to commands on CUDA streams or vulkan/IREE command buffers is another layer of reification)

I’m +1 on this. I think such considerations are closer to the numpy/frontend dialect world and that one would go through levels of specialization that introduce static behavior from dynamic code.

For anything that handles loops, tiling, vectorization and other things that TCP/Linalg on buffers and Affine do, we want static ranks.

I am a bit torn here … on one hand it is tempting to try and generalize every (i, j, k) into (i..., j..., k...) and ping-pong between linearizations and delinearizations. OTOH is it worth subjecting ourselves to the pain of doing this in MLIR today?

In other words what is the expected benefit? I know of the following cases:

  1. “rank-agnostic” library calls whose first action is to switch on the rank + sizes and dispatch to different versions. Let this live in C++ land; it’s easier to create/maintain/extend there.
  2. “rank-agnostic” kernels that really want to work for everything (e.g. add 2 “structured” tensor blobs as long as they have the same number of elements). This ends up spending an inordinate amount of effort in linearize/delinearize ping-pongs. Same remark: it’s better left in C++ for now, I think.
  3. “special” rank-agnostic kernels that are so simple that they fold into a static n-D counterpart. This generally only works on contiguous buffers and pointwise stuff. This feels like a top-level cast.
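Case 3 can be sketched in NumPy terms (assuming C-contiguous buffers): the pointwise op folds into a rank-1 kernel regardless of the inputs' rank.

```python
import numpy as np

# A pointwise op on contiguous buffers folds into a static 1-D
# counterpart: flatten, apply the rank-1 kernel, reshape back.
def rank_agnostic_add(a, b):
    assert a.size == b.size
    assert a.flags["C_CONTIGUOUS"] and b.flags["C_CONTIGUOUS"]
    out = a.ravel() + b.ravel()   # the static rank-1 kernel
    return out.reshape(a.shape)

print(rank_agnostic_add(np.ones((2, 3, 4)), np.ones((24,))).shape)  # (2, 3, 4)
```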

Bottom line: before we embark on doing unranked codegen, I’d really love to see very compelling reasons for which libraries are 1) say an order of magnitude off, 2) relevant for us to address, and 3) such that pure unranked codegen would solve the problem (casting to ranked in the process does not qualify).

I am not aware of such cases at this time.

I would expect that the tcp.batch_matmul has arguments that specify which dimensions of the inputs it applies to. As the operation itself is ranked, the number of such inputs is known. The point here is that those numbers need to be dynamic values.

You can then essentially compute what transpose to do to make these dimensions adjacent and then reshape the result into the required ranked tensor. That is your steps 1 and 2, and one can certainly hide it behind a compound op with later lowering. If the operands are static and ranks are known, you could also lower this directly to linalg.


I agree. I think it is very powerful to be able to express all these things in IR rather than hiding them in some runtime implementation. That is one of the powers that gradual lowering gives. We can start with a very high-level IR where operations have fat semantics and then gradually lower to ops with simpler semantics and explicit encoding of checks and transformations. The hope is that we no longer need to implement all the special versions of operations but instead can generate them from one definition.

Buffer allocation is a good example for this. Keeping buffers implicit (fat semantics) makes algebraic optimizations and fusion easier. Making buffers explicit later unlocks optimization potential.

Can you explain this in more detail? The semantics I am understanding from what you wrote sound more like an extension of tcp_frontend.batch_matmul to incorporate lists of collapsing/free dims as dynamic arguments, rather than something related to tcp.batch_matmul as I have set it out here.

Yes, I was referring to the frontend variant. And I am not disagreeing with what you say. I just tried to tie this back to my initial comment that you need an efficient way to do transposes.