I tried to order my thoughts around this and came up with the following description of what I would like to do:
These are descriptions of what the operations should compute, not how they should be implemented (see the next section for that).
Reduction
Outputs a scalar containing the sum / product / minimum / maximum / bitwise AND / bitwise OR / bitwise XOR of all elements in the input array.
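As a semantic sketch (not an implementation), each reduction collapses the array with an associative combiner; in plain Python this is just `functools.reduce`:

```python
from functools import reduce
import operator

data = [5, 3, 8, 1]

# Arithmetic reductions over the whole array.
assert reduce(operator.add, data) == 17
assert reduce(operator.mul, data) == 120
assert reduce(min, data) == 1
assert reduce(max, data) == 8

# Bitwise reductions over unsigned integer inputs.
bits = [0b1101, 0b1011, 0b0111]
assert reduce(operator.and_, bits) == 0b0001
assert reduce(operator.or_, bits) == 0b1111
assert reduce(operator.xor, bits) == 0b0001
```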
Exclusive / Inclusive Scan
For each element, reduce all elements up to that element, including or excluding the element itself, respectively.
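For reference, the intended semantics match `itertools.accumulate`; the exclusive variant is the inclusive one seeded with the operator's identity and shifted by one:

```python
from itertools import accumulate
import operator

data = [3, 1, 4, 1, 5]

# Inclusive scan: element i holds the reduction of data[0..i].
inclusive = list(accumulate(data, operator.add))
assert inclusive == [3, 4, 8, 9, 14]

# Exclusive scan: element i holds the reduction of data[0..i-1],
# seeded with the identity of the operator (0 for addition).
exclusive = list(accumulate(data, operator.add, initial=0))[:-1]
assert exclusive == [0, 3, 4, 8, 9]
```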
Aggregation
Partitions (key-value-pair) elements by the provided unsigned integer key. Note: This is not a generic sort, as it is not comparison-based and cannot handle floats or other types of keys.
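A minimal sketch of these semantics (the helper name `aggregate` is mine, not from the proposal) is a stable counting-sort-style partition, which is exactly why it only works on small unsigned integer keys:

```python
def aggregate(pairs, num_keys):
    """Stable partition of (key, value) pairs by unsigned integer key.

    Counting-sort style: not comparison-based, so keys must be small
    unsigned integers, not floats or arbitrary ordered types.
    """
    buckets = [[] for _ in range(num_keys)]
    for key, value in pairs:
        buckets[key].append((key, value))
    # Concatenating the buckets yields the partitioned output; elements
    # within a bucket keep their original (stable) order.
    return [pair for bucket in buckets for pair in bucket]

pairs = [(2, 'a'), (0, 'b'), (2, 'c'), (1, 'd')]
assert aggregate(pairs, 3) == [(0, 'b'), (1, 'd'), (2, 'a'), (2, 'c')]
```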
Selection
An aggregation with a 1-bit key, where one of the two partitions is discarded. Thus, it selects elements by a boolean condition.
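In plain Python the observable result is ordinary filtering (the helper name `select` is hypothetical):

```python
def select(values, keep):
    """Selection as a 1-bit aggregation: partition by a boolean key
    and discard the partition whose key is False."""
    return [v for v, k in zip(values, keep) if k]

values = [10, 20, 30, 40]
keep = [True, False, True, False]
assert select(values, keep) == [10, 30]
```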
Binning
Sorts first using aggregation, then reduces each partition to a bin. The number of distinct keys in the input thus equals the number of bins in the output (which is a histogram).
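A sketch of the resulting histogram semantics, using a count as the per-partition reduction (the function name `binning` is mine):

```python
from collections import Counter

def binning(keys):
    """Histogram: aggregate (sort) by key, then reduce each partition
    to one bin. One output bin per distinct key in the input."""
    counts = Counter(keys)
    return sorted(counts.items())  # (key, bin value) per distinct key

assert binning([2, 0, 2, 1, 2]) == [(0, 1), (1, 1), (2, 3)]
```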
Run Length Encoding
The input is already sorted (e.g. by an aggregation); count the number of elements (which share the same key) in each partition.
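On sorted input this is exactly `itertools.groupby` plus a count per run; a minimal sketch:

```python
from itertools import groupby

def run_length_encode(sorted_keys):
    """Count consecutive elements sharing the same key; the input is
    assumed to be already partitioned/sorted by key."""
    return [(key, sum(1 for _ in run)) for key, run in groupby(sorted_keys)]

assert run_length_encode([0, 0, 1, 2, 2, 2]) == [(0, 2), (1, 1), (2, 3)]
```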
FFT
Swaps the time and frequency domains of a complex input, producing a complex output; complex numbers are encoded as vec2 / 2D floats. The reverse direction can be obtained by additionally multiplying with a 1/N scaling factor.
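As a reference for what any FFT implementation must compute, here is the naive O(N²) DFT; it also demonstrates the 1/N scaling on the reverse direction:

```python
import cmath

def dft(x, inverse=False):
    """Naive O(N^2) DFT, as a semantic reference for the FFT.
    The inverse flips the twiddle sign and applies the 1/N factor."""
    n = len(x)
    sign = 1 if inverse else -1
    out = [sum(x[k] * cmath.exp(sign * 2j * cmath.pi * j * k / n)
               for k in range(n)) for j in range(n)]
    if inverse:
        out = [v / n for v in out]
    return out

signal = [1 + 0j, 2 - 1j, 0 + 1j, -1 + 2j]
spectrum = dft(signal)
roundtrip = dft(spectrum, inverse=True)
assert all(abs(a - b) < 1e-9 for a, b in zip(signal, roundtrip))
```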
Convolution
Uses one input array as a sliding window over another input array and outputs the dot product for each shift offset.
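The direct (non-FFT) form of this is a short sketch:

```python
def sliding_dot(window, signal):
    """Slide `window` over `signal` and emit the dot product at each
    valid shift offset (direct sliding-window correlation)."""
    w, n = len(window), len(signal)
    return [sum(window[j] * signal[i + j] for j in range(w))
            for i in range(n - w + 1)]

assert sliding_dot([1, 2], [3, 4, 5]) == [11, 14]
```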
- Reduction: Already exists up to the workgroup level
- Scan, Aggregation, Selection, Binning, Run Length Encoding: Single-pass Parallel Prefix Scan with Decoupled Look-back
- FFT: Cooley–Tukey algorithm with power-of-two radices.
- Convolution: Perform one FFT for each of the two inputs, multiply them element-wise, and apply another (inverse) FFT to the product to get the output.
Many of these operations can be built by composition from the lower levels upwards.
It might make sense to expose these lower-level versions as well.
- Low level: Subgroup / warp / wavefront / SIMD-vector
- Mid level: Workgroup / thread block
- High level: Kernel dispatch / grid
Scan, FFT and convolution can be done in arbitrary dimensions like this:
- Run 1D once for each row
- Run 1D once for each column
- etc …
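As an illustration of the row-then-column composition, a 2D inclusive sum scan built from the 1D primitive (this produces a summed-area table):

```python
from itertools import accumulate

def scan_2d(matrix):
    """2D inclusive sum scan from the 1D primitive: run the 1D scan
    over every row, then over every column of the result."""
    rows = [list(accumulate(row)) for row in matrix]
    cols = [list(accumulate(col)) for col in zip(*rows)]  # transpose
    return [list(row) for row in zip(*cols)]              # transpose back

m = [[1, 2],
     [3, 4]]
# Each output element holds the sum of the rectangle from (0, 0) to it.
assert scan_2d(m) == [[1, 3], [4, 10]]
```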
I would start out by implementing reduction for the other two levels, then do the scan-related operations, and finally the FFT. I guess some would be a set of op-specific passes inside the GPU dialect, while others would need lowering to spv, nvvm, etc.