Intel AMX Vector Dialect

The Intel Advanced Matrix Extensions (AMX) provide a tile matrix multiply unit (TMUL), a tile control register (TILECFG), and eight tile registers TMM0 through TMM7 (TILEDATA). I am wrapping up an AMX vector dialect that bridges the semantic gap between MLIR concepts, such as 2-d vectors and memrefs, and the lower-level details of Intel AMX. When lowering to LLVM, the tile configuration setup is automatically taken care of by the backend.

Something like

  %0 = amx.tilezero : vector<8x8xi8>
  amx.tilestore %arg0[%c0, %c0], %0 : memref<?x?xi8>, vector<8x8xi8>

will map to something like this

        vxorps  zmm0, zmm0, zmm0
        vmovups zmmword ptr [rsp - 64], zmm0
        mov     byte ptr [rsp - 64], 1
        mov     byte ptr [rsp - 16], 8
        mov     word ptr [rsp - 48], 8
        ldtilecfg       [rsp - 64]
        mov     ax, 8
        tilezero        tmm0
        mov     ecx, 8
        tilestored      [rsi + rcx], tmm0
        tilerelease
        vzeroupper
        ret

Expect a PR with some sample integration tests (running on an emulator) soon.

Great! Power to the vectors 🙂

That looks promising, thanks @aartbik 🙂

This is looking good - thanks! Narrow question: why do you have vmovups instead of vmovaps? In this case, we’d prefer to align memrefs at 64-byte boundaries (vector<8x8xi8> elt types), so you should get aligned loads/stores. Wouldn’t there be a reasonable penalty for unaligned ones on AMX, unlike on Haswell, Skylake-X, etc.?

Note that the vmovups above prepares an initial zero value for the tile configuration on the stack for subsequent use by ldtilecfg (and is generated by the LLVM backend). The unaligned version is probably due to assumptions on stack alignment?

The 2-d vector in this example is written to by tilestored.

The AMX Vector Dialect is now ready for review. It is a rather large drop, so below is a brief description. Everything seems to work, including three full integration tests (running on a Sapphire Rapids emulator), but the documentation around AMX and its intrinsics is sometimes a bit “sparse”, so please let me know if you encounter issues. I tried to make this a nice demonstration of the power of MLIR’s progressive lowering from pleasant high-level VV abstractions into efficient, but much more detailed, low-level HWV code.

(1) The AMX dialect in MLIR itself is very concise; it consists of just a few new vector operations.

     amx.tilezero
     amx.tilemulf/tilemuli
     amx.tileload/tilestore

All these operations work on concepts familiar within the MLIR framework: 2-d vectors and memrefs. For example, clearing a tile within an enveloping buffer can be done as follows.

%1 = amx.tilezero : vector<16x16xi32>
amx.tilestore %arg0[%i, %j], %1 : memref<?x?xi32>, vector<16x16xi32>

The idiomatic TMUL operations look as follows.

 %0 = amx.tilemulf %a, %b, %c
    : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
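
An integer variant looks similar; a sketch along the same lines (quads of i8 accumulating into an i32 tile):

 %1 = amx.tilemuli %a, %b, %c
    : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>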

The dialect verifies that the types and shapes are actually supported by AMX.
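
For instance, something like the following would be rejected at verification time, since a tile holds at most 16 rows of at most 64 bytes each (a hypothetical example):

 %0 = amx.tilezero : vector<16x32xi32>  // error: 128-byte rows exceed the tile limits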

(2) The AMX dialect is lowered into an LLVMAMX dialect, which is closer to the compiler-oriented “internal” intrinsics of LLVM IR. The lowering takes care of some tedious details, such as providing tile parameters, stride computations, and instruction selection. The “amx.tilemulf” is lowered to “tdpbf16ps”, and “amx.tilemuli” to one of the “tdpbssd”, “tdpbsud”, “tdpbusd”, or “tdpbuud” intrinsics. All this is completely transparent to the higher levels, though.

The bf16 tile multiplication above lowers to the following.

%45 = "llvm_amx.tdpbf16ps"(%43, %42, %44, %40, %28, %28) : (i16, i16, i16, !llvm.array<16 x vector<16xf32>>, !llvm.array<16 x vector<32xbf16>>, !llvm.array<16 x vector<32xbf16>>) -> !llvm.array<16 x vector<16xf32>>

(3) This dialect is further lowered into the LLVM IR dialect. At this point, the x86_amx type comes into play.

%47 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 16, i16 64, i16 64, x86_amx %46, x86_amx %37, x86_amx %37)

(4) The LLVM IR dialect is eventually handed off to LLVM. The backend converts the intrinsics into instructions and adds the necessary tile configuration instructions.

A sample configuration set up and actual kernel instructions are illustrated below.

vpxord  %zmm0, %zmm0, %zmm0
vmovdqu64       %zmm0, -64(%rsp)
movb    $1, -64(%rsp)
movb    $16, -16(%rsp)
movw    $64, -48(%rsp)
movb    $16, -15(%rsp)
movw    $64, -46(%rsp)
ldtilecfg       -64(%rsp)
movq    24(%rsp), %rax
movq    48(%rsp), %rcx
.Ltmp2:
addq    %r8, %r8
movw    $64, %dx
movw    $16, %di
tileloadd       (%rsi,%r8), %tmm0
shlq    $2, %rcx
tileloadd       (%rax,%rcx), %tmm1
tdpbf16ps       %tmm0, %tmm0, %tmm1
tilestored      %tmm1, (%rax,%rcx)
tilerelease
vzeroupper
retq

I added several integration tests to make sure everything works as expected (on the emulator).
For example, storing a 16x16 vector into, say, a 19x19 buffer accounts for the proper stride of the enveloping buffer:

( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
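
A sketch of the MLIR behind such a test (the [1, 1] offset, the %buf/%c1 names, and the i32 element type are assumptions for illustration):

 %0 = amx.tilezero : vector<16x16xi32>
 amx.tilestore %buf[%c1, %c1], %0 : memref<19x19xi32>, vector<16x16xi32>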

Likewise, a matrix multiply of pairwise bf16 into f32 works as expected.

( ( 124, 144 ),    ( ( 1.0, 2.0, 3.0, 4.0 ),     ( (  9.0, 10.0, 11.0, 12.0 ),
  ( 308, 360 ) ) =   ( 5.0, 6.0, 7.0, 8.0 ) )  x   ( 13.0, 14.0, 15.0, 16.0 ) )
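
Each operand row holds packed bf16 pairs; the top-left result, for example, is computed as 1.0·9.0 + 2.0·10.0 + 3.0·13.0 + 4.0·14.0 = 124.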

Please let me know if you have any feedback.

Happy TMULing!

I have a high level question:
What is the motivation behind introducing an AMX-specific dialect rather than adding load/store ops to the tensorflow dialect operating on the tensor type? Is the idea here to introduce a more generic 2-d matrix type as a first-class citizen rather than implementing matrices as derived types or specializations of the tensor type? If that is the case, the dialect should be a general matrix dialect rather than being tied to AMX. A matrix would then map to the 2-d tiles in the case of AMX.
MLIR provides this freedom of adding new types and concepts, but it is even better if they get reused for other purposes. A matrix dialect could be used for other 2-d matmul units, whereas an AMX dialect would only be used for AMX.
Regarding optimizations, the AMX target-specific optimizations that you talked about can be applied at the LLVM IR + AMX intrinsics level. So if the dialect represents only AMX, what are the types of optimizations that I can only do there and not at a lower level? Again, if this is a general matrix type with fill, load, store, multiply, and add operations, we can come up with common optimizations and reuse them elsewhere for free.

We have such a dialect: the vector dialect (“vectors” is a historical misnomer here, since they are really n-dimensional constructs; but they differ from the tensor type, so the name was already taken). Have you read @nicolasvasilache’s brilliant Vector design doc, the earlier AVX512 case studies, or any of the RFCs for the Arm Neon/SVE and AVX512 dialects?

In a nutshell, the architecture-agnostic vector dialect (VV, for Virtual Vector level) is used as long as possible during rewriting and progressive lowering. In the end, an architecture-specific vector dialect (HWV, for Hardware Vector level) may (or may not) be used to take advantage of particular idioms. For many cases, we found that extremely good performance can indeed be obtained with just VV, relying on the LLVM “backend” to generate efficient code for the SIMD IR that is passed on. For very hardware-specific features (AMX being a very clear example, I would say), we introduce an HWV, at least until such a time that the lowering to LLVM is better understood.
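
At the VV level, for example, a tile multiplication can be expressed as a generic contraction over 2-d vectors, which a rewriting step can then map onto the AMX ops. A sketch (the trait name and the f32 shapes here are chosen just for illustration):

 #mm_trait = {
   indexing_maps = [
     affine_map<(m, n, k) -> (m, k)>,
     affine_map<(m, n, k) -> (k, n)>,
     affine_map<(m, n, k) -> (m, n)>
   ],
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 %0 = vector.contract #mm_trait %a, %b, %c
    : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>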

Thanks for working on this, Aart! I’m not working with AMX but I think this approach is also applicable to similar internal problems we have.

I think we all agree that optimizations should be implemented on a HW-agnostic dialect as much as we can. We recently introduced vector.load and vector.store ops with support for multi-dimensional vectors (D96185: [mlir][Vector] Introduce ‘vector.load’ and ‘vector.store’ ops). They are lower-level and more constrained than the vector transfer ops, which could make them suitable for the aforementioned HW-agnostic optimizations on 2-d matrices. It should be straightforward to convert them to the AMX dialect down the road.
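
For example (a sketch; the multi-dimensional result type is what the revision adds):

 %v = vector.load %mem[%i, %j] : memref<128x128xf32>, vector<16x16xf32>
 vector.store %v, %mem[%i, %j] : memref<128x128xf32>, vector<16x16xf32>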

This looks awesome. I’m curious how tileconfig (an expensive operation) will work if you put the tile multiply in a loop; I guess it will be pulled out of the loop. A more tricky question is what will happen if the tile multiply is in a parallel loop (multiple async.execute regions) with iterations running concurrently on different threads.

If I were writing something like this in C++, I’d add thread-local storage to keep track of the tile config per thread.

The AMX Dialect has landed in the main branch. Driven by review feedback, some minor changes were made compared to what is described above; most obviously, the LLVM IR AMX part of the dialect has been squashed into the main AMX dialect.

I did a quick experiment with something like this, with a fixed tile configuration throughout:

scf.for %i = %c0 to %c128 step %c16 {
  scf.for %j = %c0 to %c128 step %c16 {
    ...
    amx.tile_store %arg2[%i, %j], %4
      : memref<128x128xf32>, vector<16x16xf32>
  }
}

But alas, the “ldtilecfg” instruction was not hoisted out of the loop, even though its contents never change. Unless I am doing something wrong, I suspect this is simply a missing optimization in LLVM codegen (perhaps somebody from Intel can confirm, and say whether this is actively being worked on).

If you are interested in running the Intel AMX integration tests (on emulation), just follow the steps below.

During configuration, include the following definitions. With cmake, for example:

cmake  ....
   ....
   -DMLIR_INCLUDE_INTEGRATION_TESTS=ON \
   -DMLIR_RUN_AMX_TESTS=ON \
   -DINTEL_SDE_EXECUTABLE=<path to emulator>

Then run the tests as usual.

cmake --build . --target check-mlir-integration

@aartbik kindly offered to prepare a few slides about this for tomorrow’s meeting! Thanks @aartbik 🙂

The presentation/discussion from the meeting last Thursday is available: slides and recording, thanks again @aartbik for driving this!

I see, so this dialect is purposely very specific to AMX and cannot be applied to other 2-d TPUs that share very similar intrinsics: initialize or fill, load, store, multiply, and add. I have the following concerns though:

  • Passes like the verifier duplicate work that is already done in the backend (lowering to the AMX tile LLVM intrinsics). Moreover, verifying that sizes do not exceed the tile sizes is too early at this level, IMO. This will inhibit the compiler from performing transformations like tiling and unroll-and-jam to retrieve the hardware tile sizes and perform register reuse. Imposing hardware tile sizes should really happen at the backend level, after performing these kinds of optimizations.
  • Instead of the data layout (4-byte packing), it seems you are using per-byte computation. In _tile_dpbf16ps, a pair of bf16 elements constructs each 32-bit element, so the packing factor in this case is 2 (it is 4 for the 8-bit int type).
  • In this work, you are lowering from the vector dialect to the AMX dialect, and then to the LLVM IR AMX tile intrinsics. What are the optimizations that we cannot do at the vector level nor at the AMX tile intrinsics (backend) level that should happen at the AMX dialect level?

Thanks,
Dounia

I am not sure the verification efforts are wasted, since the error messages will be much friendlier in an MLIR context compared to errors generated much later by the backend on obscure internal intrinsics. Also, note that in the context in which this work is done, loop optimizations like tiling are supposed to be done by the higher levels in MLIR, while progressively lowering to the vector dialect, with each rewriting bringing the implementation closer and closer to the target ISA. Generating the AMX dialect would be the last step as far as MLIR is concerned.

I agree that if the LLVM backend would do all the proper tiling and instruction selection, we would not need this HWV dialect, but could simply use our VV dialect and rely on generic intrinsics and 2-d data types. But no such support exists today, so until that time, this dialect provides a convenient bridge to experiment with lowering strategies from MLIR into AMX (please read Nicolas’ vector dialect doc if you have not done so already; it explains this approach in detail).

I don’t understand your “it seems you are using per-byte computation” remark. The MLIR operations operate on pairs of bf16 into f32, or quads of int8 into int32, just as required by AMX.
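
In other words, for the i8 case each 32-bit result element accumulates dot products of quads, C[i][j] += A[i][4k+0]*B[k][4j+0] + ... + A[i][4k+3]*B[k][4j+3], summed over k (and analogously with pairs for bf16).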