Goals: This doc proposes a specific way of representing quantization constraints and implementing quantization transformations in MLIR for different types of targets.
Non-Goals: This doc doesn’t aim to provide an end-to-end general solution starting from authoring languages. It is also not prescriptive about the levels or dialects at which quantization should happen, other than requiring that the target constraints of the quantization dialect be specified and resolved atomically.
Background
Quantization, as defined here, is the process of transforming a machine learning program into an approximate representation that uses the lower-precision operations available on the target device; these operations act on values whose precision has been reduced via an affine transformation.
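To make the affine transformation concrete, the sketch below assumes the common uniform scheme real ≈ scale × (q - zero_point), with q held in a signed 8-bit storage range. The struct and function names (AffineParams, Quantize, Dequantize) are illustrative only and are not part of any MLIR API.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical affine quantization parameters: real ≈ scale * (q - zero_point).
struct AffineParams {
  double scale;        // real-valued step between adjacent quantized values
  int32_t zero_point;  // quantized value that represents real 0.0
  int32_t qmin;        // storage_type_min
  int32_t qmax;        // storage_type_max
};

// Maps a real value into the storage domain: round to nearest, then saturate.
int32_t Quantize(double real, const AffineParams& p) {
  long q = std::lround(real / p.scale) + p.zero_point;
  return static_cast<int32_t>(std::clamp<long>(q, p.qmin, p.qmax));
}

// Recovers the approximate real value represented by a stored integer.
double Dequantize(int32_t q, const AffineParams& p) {
  return p.scale * (q - p.zero_point);
}
```

With scale = 0.5 and zero_point = -10, the real value 1.0 is stored as -8 and dequantizes back to exactly 1.0, real 0.0 maps exactly to the zero point, and out-of-range values saturate at storage_type_min/max.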
Quantizing a source model (i.e. a program) so that it can run on the target device requires modeling two types of information:
- A set of target constraints. This involves representing the device’s execution capabilities, for instance, which “kernels” (units of computation) the device provides and what their signatures are.
- A set of model constraints. This involves representing information (via mlir::quant::QuantizedType (QT) below) about how to quantize the values in the source model. For instance, with “quantization aware training” you can annotate the model with the scale and zero point learned during training (via mlir::quant::UniformQuantizedType (UQT) below). Another example is a “post-training quantization” process, where you can annotate the model with recorded high-precision statistics, i.e. min and max values (via mlir::quant::CalibratedQuantizedType below).
Proposal
This proposal describes a quantization module that provides mechanisms for modeling these two types of constraints, together with a set of compiler passes that resolve them for a given model until it becomes a materialized and executable quantized program. This module, which we call the “Propagation module”, is designed to fit into the quantization life cycle of a model as depicted in the following diagram:
Another important aspect of the proposed Propagation module is that the target constraints of the quantization dialect should be specified and resolved atomically. This means that, as far as the quantization passes and constraints are concerned, there is no need to look inside the boundaries defined by a “logical” kernel (i.e. everything happens at the boundaries of the operation: its inputs and outputs).
Constraints for quantization
Target Constraints
Per-kernel target constraints: the kernels or lowerings that must be available to run (or even code-generate) the quantized model.
- Key: <kernel_name, kernel_signature>. The kernel_signature is a sequence of tuples (storage_type, storage_type_min/max), one for each input and output (mlir::quant::AnyQuantizedType can be used). Each tuple describes the supported quantization type of the corresponding kernel port. (Quantization usually doesn’t treat the inputs and outputs of a kernel differently, so we use “port” to refer to an indexed input or output of the kernel.) A tuple can also be a wildcard. The signature in the lookup key can be partially specified; the unspecified part is derived from the first matched signature, and a default value is used where that signature has a wildcard.
- Value: QuantFn(inputs: [QT], outputs: [QT]) -> ([int], [int]). This constraint is a (C++) function that derives the quantization parameters of some ports from the other ports of the kernel (a port is identified by whether it is an input or an output, and by its index). If it succeeds (i.e. some QT is updated to mlir::quant::UniformQuantizedType), it returns the indices of the updated ports, so that a work-list algorithm can infer the quantization parameters of adjacent ops. Shorthands such as “same_scale”, “fixed_scale”, “multiplier_scale” and “multiplier_accumulator_scale” can be used for common patterns.
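A minimal sketch of the QuantFn contract and two of the shorthands, assuming a simplified port model in which each port either holds uniform parameters or is still unresolved. UQT, PortQT, FillUnresolved, SameScale and FixedScale are hypothetical names, and the sketch deliberately omits “multiplier_scale” and “multiplier_accumulator_scale”, whose exact semantics are not spelled out here.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-in for the parameters of mlir::quant::UniformQuantizedType.
struct UQT {
  double scale;
  int32_t zero_point;
};

// A port stays unresolved (std::nullopt) until a constraint assigns it a UQT.
using PortQT = std::optional<UQT>;

// Indices of the updated (inputs, outputs).
using UpdatedPorts = std::pair<std::vector<int>, std::vector<int>>;

// QuantFn(inputs: [QT], outputs: [QT]) -> ([int], [int]); ports are updated in place.
using QuantFn = std::function<UpdatedPorts(std::vector<PortQT>&, std::vector<PortQT>&)>;

// Assigns `value` to every unresolved port and records the updated indices.
void FillUnresolved(std::vector<PortQT>& ports, const UQT& value, std::vector<int>& updated) {
  for (std::size_t i = 0; i < ports.size(); ++i)
    if (!ports[i]) {
      ports[i] = value;
      updated.push_back(static_cast<int>(i));
    }
}

// "same_scale": once any port is resolved, every other port gets the same parameters.
UpdatedPorts SameScale(std::vector<PortQT>& inputs, std::vector<PortQT>& outputs) {
  UpdatedPorts updated;
  PortQT known;
  for (const PortQT& p : inputs) if (p && !known) known = p;
  for (const PortQT& p : outputs) if (p && !known) known = p;
  if (!known) return updated;  // nothing to derive from yet
  FillUnresolved(inputs, *known, updated.first);
  FillUnresolved(outputs, *known, updated.second);
  return updated;
}

// "fixed_scale": outputs get target-mandated parameters regardless of the inputs.
QuantFn FixedScale(UQT fixed) {
  return [fixed](std::vector<PortQT>&, std::vector<PortQT>& outputs) {
    UpdatedPorts updated;
    FillUnresolved(outputs, fixed, updated.second);
    return updated;
  };
}
```

Returning only the newly resolved indices is what lets a driver push exactly the affected neighbors onto a work list instead of revisiting the whole model.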
To apply these target constraints to backends with fixed kernels (e.g. as currently done for TensorFlow Lite), kernel fusion may be required to wrap the high-precision ops of the annotated model into individual “logical” kernels.
Model Constraints
Per-tensor model constraints: per-tensor information that restricts quantization or provides information required for it. These are modeled as MLIR element types that extend mlir::quant::QuantizedType:
- mlir::quant::UniformQuantizedType (a tuple of <scale, zero point, storage_type, storage_type_min/max>): the full quantization parameters for the tensor. The final target kernel should satisfy the storage_type and storage_type_min/max restrictions for this tensor, as well as the scale restrictions among all of the kernel’s inputs and outputs.
- mlir::quant::AnyQuantizedType (bit width): quantize this tensor with at least the specified bit width. The final target kernel should provide sufficient storage for this tensor.
- mlir::quant::CalibratedQuantizedType (a pair of floating-point min/max): use these calibration statistics as hints to calculate the quantization parameters. The bit width needs to be inferred from the chosen target kernels.
- mlir::quant::NoQuantizedType: keep this tensor as is even if the quantized kernel is available.
This list can be extended with other quantization types, e.g. bfloat16.
Also, note that this proposal is not prescriptive about the MLIR transformations (which take place before the quantization passes) used to set these constraints: it should be possible to set them with ops from the MLIR QuantOps dialect (mlir::quant::qcast, mlir::quant::dcast, mlir::quant::stats ops, mlir::quant::region, etc.), or by replacing a tensor’s element type if the op allows it.
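Because a CalibratedQuantizedType only carries hints, its parameters can be computed once the chosen target kernel fixes the storage range. The sketch below uses the widely used asymmetric min/max scheme (widen the range to include zero, then derive scale and zero point); the formula is an assumption, since this proposal doesn’t prescribe one, and CalibratedQT, UniformQT and ResolveCalibrated are hypothetical stand-ins rather than the MLIR classes.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical stand-in for mlir::quant::CalibratedQuantizedType: recorded statistics.
struct CalibratedQT {
  double min, max;
};

// Hypothetical stand-in for mlir::quant::UniformQuantizedType: resolved parameters.
struct UniformQT {
  double scale;
  int32_t zero_point;
  int32_t storage_min, storage_max;
};

// Resolves calibration statistics into uniform parameters once the chosen
// target kernel has fixed the storage range [qmin, qmax].
UniformQT ResolveCalibrated(const CalibratedQT& stats, int32_t qmin, int32_t qmax) {
  // Widen the range to include 0 so that real 0.0 is exactly representable.
  double lo = std::min(stats.min, 0.0);
  double hi = std::max(stats.max, 0.0);
  double scale = (hi - lo) / static_cast<double>(qmax - qmin);
  if (scale == 0.0) scale = 1.0;  // degenerate all-zero range: any scale works
  // Pick the zero point that maps `lo` to qmin, then keep it inside storage.
  long zp = std::lround(qmin - lo / scale);
  zp = std::clamp<long>(zp, qmin, qmax);
  return {scale, static_cast<int32_t>(zp), qmin, qmax};
}
```

For example, statistics of [-1.0, 3.0] resolved against an int8 kernel ([-128, 127]) give scale = 4/255 and zero_point = -64; the same statistics against a 16-bit kernel would yield a much finer scale, which is why the bit width has to come from the chosen kernel.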
A source model with these specifications is called an annotated model.
Quantization region: the “logical” kernel
A generic region op can be used to wrap high-precision ops from the source dialect into a target kernel. The values used by the enclosed ops but defined outside the region are listed explicitly as op arguments (kernel inputs), and the values defined by the enclosed ops and used outside the region are listed as op returns (kernel outputs). Any model constraints on these kernel inputs/outputs are converted to quantization attributes (input_specs/output_specs) of the region ops. These attributes and the kernel name form a lookup key to retrieve the target constraints. Any model constraints on the tensors inside the region will NOT be considered during quantization; they are only used as settings for the “real” target kernels or as hints for code generation. A tentative implementation of the region op is:
%.. = quant.region (imported_ssa_values) {
  …high-precision ops with model constraints on intermediate values…
  quant.return exported_ssa_values
} {input_specs = [QT], output_specs = [QT], kernel = "kernel_name"} :
  (high-precision tensor types) -> (high-precision tensor types)
The advantage of this region op design is that it doesn’t compromise the integrity of the high-precision ops’ dialect. Quantization passes treat each region op as an atomic unit and only update its quantization attributes (from QTs to UQTs). The high-precision ops can be inlined after quantization for other optimizations (which need to be quantization-aware and handle the quantization parameters properly).
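The kernel name plus the input_specs/output_specs of a region op form the lookup key for target constraints. The sketch below shows one way such a lookup could resolve a partially specified signature: each port of the query carries only a minimum bit width (as an AnyQuantizedType would) or nothing at all, the first registered signature that satisfies it wins, and its values fill in the unspecified ports. StorageSpec, KernelEntry, kDefaultSpec and LookupKernel are hypothetical, and treating a registry wildcard as an 8-bit default is an assumption of the sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical port type: storage bit width plus storage range.
struct StorageSpec {
  int bits;
  int32_t min, max;
};
// std::nullopt plays the role of a wildcard in a registered signature.
using PortSpec = std::optional<StorageSpec>;

struct KernelEntry {
  std::string name;
  std::vector<PortSpec> signature;  // inputs followed by outputs
};

// Value substituted for a wildcard port (assumption for this sketch).
constexpr StorageSpec kDefaultSpec{8, -128, 127};

// Returns the fully resolved signature of the first matching entry, if any.
// min_bits[i] == std::nullopt means port i is unspecified in the query.
std::optional<std::vector<StorageSpec>> LookupKernel(
    const std::vector<KernelEntry>& registry, const std::string& name,
    const std::vector<std::optional<int>>& min_bits) {
  for (const KernelEntry& entry : registry) {
    if (entry.name != name || entry.signature.size() != min_bits.size()) continue;
    std::vector<StorageSpec> resolved;
    bool ok = true;
    for (std::size_t i = 0; ok && i < min_bits.size(); ++i) {
      StorageSpec spec = entry.signature[i].value_or(kDefaultSpec);
      // AnyQuantizedType semantics: storage must be at least the requested width.
      ok = !min_bits[i] || spec.bits >= *min_bits[i];
      resolved.push_back(spec);
    }
    if (ok) return resolved;
  }
  return std::nullopt;
}
```

With two “conv2d” entries registered (8-bit first, 16-bit second), a query of {12 bits, unspecified, unspecified} skips the 8-bit signature and resolves all three ports from the 16-bit one.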
Propagation pass: resolving target constraints
The propagation pass resolves target constraints for the annotated models in a greedy way. Functions are inlined or duplicated for quantization because different call sites might use different quantization parameters. The pass iterates over the region ops (essentially the target kernels) and uses the “logical” kernel name and the bit widths of the ports to look up the target constraints. It then applies the returned constraint (a QuantFn or a shorthand) to the ports, updates their quantization parameters and returns the indices of the updated ports. The users and defining ops of these modified ports can be identified from the returned indices and are considered in the next iterations. If there is a conflict, i.e. a quantization parameter has already been set for a visited port, a “rescale” is required to resolve it. At each step, the algorithm guarantees that the constraints of the current kernel are satisfied. This might introduce “rescale” ops into the quantized model, but they don’t necessarily introduce an end-to-end performance penalty. The algorithm is also deterministic, so a “rescale” can be eliminated by annotating the model at the right locations with the right restrictions.
Propagation stops when no more quantization parameters can be updated, and its results are used to update the quantization attributes of the region ops. After quantization parameter propagation, the model is called a quantized model. A model is fully quantized when all the kernels in the model have quantization parameters for all their quantizable ports, i.e. every QT in the quantization attributes has been set to a UQT.
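The propagation loop can be sketched as a work list of ports. To stay short, the sketch assumes that a quantization parameter is just a scale, that every kernel uses “same_scale”, and that a conflict is recorded on the edge instead of being materialized as a rescale op. One design choice worth noting: a kernel’s QuantFn runs as soon as any of its ports is resolved, so every kernel stays self-consistent at each step and later disagreements surface as rescales on edges. Kernel, PortRef, Edge, Propagate and IsFullyQuantized are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical, simplified model: a port's parameter is just a scale;
// std::nullopt means "not resolved yet".
using Param = std::optional<double>;

struct Kernel {
  std::vector<Param> ports;  // inputs and outputs in one list, for brevity
  // QuantFn: fills unresolved ports and returns the indices it updated.
  std::function<std::vector<std::size_t>(std::vector<Param>&)> quant_fn;
};

struct PortRef { std::size_t kernel, port; };
struct Edge { PortRef a, b; };  // connects a producer output to a consumer input

bool SamePort(PortRef x, PortRef y) { return x.kernel == y.kernel && x.port == y.port; }

// "same_scale" shorthand applied to every port of a kernel.
std::vector<std::size_t> SameScale(std::vector<Param>& ports) {
  Param known;
  for (const Param& p : ports)
    if (p) { known = p; break; }
  std::vector<std::size_t> updated;
  for (std::size_t i = 0; known && i < ports.size(); ++i)
    if (!ports[i]) { ports[i] = known; updated.push_back(i); }
  return updated;
}

// Greedy work-list propagation; returns the edges that need a "rescale".
std::vector<Edge> Propagate(std::vector<Kernel>& kernels, const std::vector<Edge>& edges) {
  std::deque<PortRef> worklist;  // resolved ports whose value must reach neighbors
  // Run a kernel's constraint immediately so that the kernel stays consistent.
  auto apply = [&](std::size_t k) {
    if (!kernels[k].quant_fn) return;
    for (std::size_t p : kernels[k].quant_fn(kernels[k].ports)) worklist.push_back({k, p});
  };
  for (std::size_t k = 0; k < kernels.size(); ++k) {
    for (std::size_t p = 0; p < kernels[k].ports.size(); ++p)
      if (kernels[k].ports[p]) worklist.push_back({k, p});  // annotated ports are seeds
    apply(k);
  }
  std::vector<Edge> rescales;
  while (!worklist.empty()) {
    PortRef from = worklist.front();
    worklist.pop_front();
    double value = *kernels[from.kernel].ports[from.port];
    for (const Edge& e : edges) {
      PortRef to{};
      if (SamePort(e.a, from)) to = e.b;
      else if (SamePort(e.b, from)) to = e.a;
      else continue;
      Param& target = kernels[to.kernel].ports[to.port];
      if (!target) {  // unresolved neighbor port: adopt the value and continue
        target = value;
        worklist.push_back(to);
        apply(to.kernel);
      } else if (*target != value) {  // conflict: both sides keep their values
        bool seen = false;
        for (const Edge& r : rescales)
          seen = seen || (SamePort(r.a, e.a) && SamePort(r.b, e.b));
        if (!seen) rescales.push_back(e);
      }
    }
  }
  return rescales;
}

// Fully quantized: every port of every kernel has a resolved parameter.
bool IsFullyQuantized(const std::vector<Kernel>& kernels) {
  for (const Kernel& k : kernels)
    for (const Param& p : k.ports)
      if (!p) return false;
  return true;
}
```

With one producer annotated at scale 0.5 feeding an intermediate same-scale op, and another annotated at 0.25, both feeding a same-scale concat, the sketch resolves every port and reports exactly one rescale, on the edge where the two scales meet. Each port is assigned at most once, which guarantees termination.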
Export: materializing quantization
The quantized model can be consumed by different target backends to generate an executable model. This involves rewriting the “logical” kernels into real kernels:
- Fixed kernel targets: The quantized region ops can be rewritten to the registered ops in the target dialect, with the quantization parameters of the ops set from the latest quantization attributes. If the quantized region ops are coarser than the target dialect ops, the legalization patterns need to derive the quantization parameters for the intermediate tensors from those of the original tensors;
- Code-gen kernel targets: The quantized region ops are inlined, and the quantization parameters of the ports are expressed by code-gen primitive ops (element-wise multiplication, shift, etc.). Compiler optimizations can be applied to eliminate redundant primitives. For example, the “scaling” of the inputs can be pushed forward and combined with the “scaling” of the outputs, so that rescaling only happens on the outputs.
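For a matmul-like kernel, the int32 accumulator already carries the scale s_a × s_b of its two inputs, so the input scalings can be folded into a single output multiplier M = s_a × s_b / s_out that is applied with integer primitives only. The sketch below uses the fixed-point “mantissa plus shift” decomposition commonly used for integer-only inference; it assumes 0 < M < 1, an int8 output, and an arithmetic right shift on negative values (guaranteed from C++20 on). QuantizeMultiplier and Requantize are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Splits a real multiplier 0 < m < 1 into a Q31 mantissa and a right shift so
// that m ≈ mantissa * 2^-31 * 2^-right_shift.
void QuantizeMultiplier(double m, int32_t* mantissa, int* right_shift) {
  assert(m > 0.0 && m < 1.0);
  int exponent = 0;
  double frac = std::frexp(m, &exponent);  // m = frac * 2^exponent, frac in [0.5, 1)
  int64_t q = std::llround(frac * static_cast<double>(int64_t{1} << 31));
  if (q == (int64_t{1} << 31)) {  // rounding pushed frac up to 1.0
    q /= 2;
    ++exponent;
  }
  *mantissa = static_cast<int32_t>(q);
  *right_shift = -exponent;
}

// Integer-only rescaling of an int32 accumulator into the int8 output domain:
// round(acc * m) + out_zero_point, using one multiply and one shift.
int32_t Requantize(int32_t acc, int32_t mantissa, int right_shift, int32_t out_zero_point) {
  int64_t product = static_cast<int64_t>(acc) * mantissa;  // 31 fractional bits
  int total_shift = 31 + right_shift;
  int64_t rounding = int64_t{1} << (total_shift - 1);
  int64_t scaled = (product + rounding) >> total_shift;  // round half up
  return static_cast<int32_t>(std::clamp<int64_t>(scaled + out_zero_point, -128, 127));
}
```

With s_a = 0.5, s_b = 0.25 and s_out = 0.5, M = 0.25 decomposes into mantissa 2^30 and right shift 1, and an accumulator of 102 requantizes to 26 (25.5 rounded half up). The input scales never appear as separate operations, which is the “push forward and combine” optimization described above.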
Other aspects of a quantization framework
The following are aspects of a quantization framework (per the diagram above) that are relevant but not directly covered in this proposal. The module proposed here should be generally applicable to any such implementations of a quantization framework.
Intermediate transformations: being as accurate as possible
There are multiple ways to generate the annotated model from the source model, for example:
- quantization-aware training: the quantization emulation ops are converted to UQT;
- post-training calibration: the tensor statistics are converted to mlir::quant::CalibratedQuantizedType;
- parameter search: the explored bit widths are converted to mlir::quant::AnyQuantizedType.
This framework doesn’t restrict which approach should be used; in practice, a combination of different approaches can refine the model annotation for higher accuracy. Each approach can use its own annotation representation, and all of them are eventually converted to the quantization attributes of the quantization region ops.
Kernel fusion pass: from ops to kernels
The kernel fusion pass (separate from the quantization passes described above) encapsulates the high-precision ops of the model into low-precision target kernels from the target dialect. It can be pattern-matching or heuristic based, and customized for different targets:
- Fixed kernel targets: This is essentially a set of legalization patterns from source dialect to target dialect;
- Code-gen kernel targets: Some well-known patterns (heuristics), such as “matmul->add”, are captured for better code-gen quality. More aggressive code-gen tools, such as XLA, might want to fuse all the linear ops between non-linear ops into a single kernel. The decision to pursue a low-precision rewrite can be reversed later in the process by restoring the encapsulated high-precision ops to the IR with appropriate casts. This can be useful as a fallback for multi-target scenarios.
After kernel fusion, all quantizable ops from the source dialect are wrapped by the quantization region ops.
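A pattern-based fusion pass can be sketched on a straight-line op list: the “matmul->add” pattern becomes one logical kernel, and every other op is wrapped in its own region so that all quantizable ops end up inside a region. Op, Region, FuseKernels and the “fully_connected” kernel name are hypothetical; a real pass would match on the MLIR use-def graph rather than on list adjacency.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical, simplified IR: a straight-line list of ops, each consuming
// the result of the previous one.
struct Op {
  std::string name;
};

// A "logical" kernel: the kernel name used in the lookup key plus the
// high-precision ops it encapsulates (the body of the region op).
struct Region {
  std::string kernel;
  std::vector<Op> body;
};

// Fuses "matmul" followed by "add" into one kernel; every other op is
// wrapped in a region of its own.
std::vector<Region> FuseKernels(const std::vector<Op>& ops) {
  std::vector<Region> regions;
  for (std::size_t i = 0; i < ops.size(); ++i) {
    if (ops[i].name == "matmul" && i + 1 < ops.size() && ops[i + 1].name == "add") {
      regions.push_back({"fully_connected", {ops[i], ops[i + 1]}});
      ++i;  // the add has been consumed by the fused kernel
    } else {
      regions.push_back({ops[i].name, {ops[i]}});
    }
  }
  return regions;
}
```

For the sequence matmul, add, relu, add, the sketch produces three regions (fully_connected, relu, add); only the first add is fused because only it directly follows a matmul.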
The kernel fusion pass can run before or after quantization (and multiple times!). Fusion after quantization can be used to legalize the quantized model from one dialect to another, and the legalization rules need to handle the quantization parameters properly. The kernels can also be implied by information from authoring languages with an op hierarchy (a level of ops in the hierarchy can be mapped to the kernels directly). This can be discussed in a separate design doc.