Canonicalization of 'x + (+0.0)' in tosa

I found that the TOSA canonicalization tested in mlir/test/Dialect/Tosa/canonicalize.mlir folds ‘x + (+0.0)’ into ‘x’:

// CHECK-LABEL: @add_zero_float
func @add_zero_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
  // CHECK: return %arg0
  // CHECK-NOT: tosa.add
  %zeros = "tosa.const"() {value = dense<0.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
  %1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
  return %1 : tensor<2x3xf32>
}

If an element of %arg0 is -0.0, this fold is incorrect, because under IEEE 754 round-to-nearest the result of -0.0 + (+0.0) is +0.0, not -0.0.
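
For reference, this zero-sign rule can be checked outside of MLIR with a few lines of C++ (a standalone illustration, not part of the test above):

#include <cmath>
#include <cstdio>

int main() {
  double x = -0.0;
  double folded = x;         // what the canonicalization returns for x + 0.0
  double computed = x + 0.0; // IEEE 754 round-to-nearest: -0.0 + (+0.0) = +0.0
  std::printf("folded:   signbit = %d\n", std::signbit(folded));   // prints 1
  std::printf("computed: signbit = %d\n", std::signbit(computed)); // prints 0
}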
My question is: does TOSA’s canonicalization assume unsafe FP math by default?

@sjarus

It appears that this is an area where the TOSA specification isn’t clear on behavior. In the specification, we require support for signed zero, and define multiply and divide by +/-0, but we don’t define addition with -0.0. So the current canonicalization is taking advantage of a grey area.

I think the right thing to do is to disable this specific canonicalization for FP, leaving it in place for integer operations. In addition, I’ll try to clarify this behavior in the specification. (Section 1.9 of the TOSA specification at mlplatform.org has the details of what we’ve defined for FP behavior.)
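
For illustration, here is a rough C++ sketch of how such a guard could look in the AddOp fold. The signature and accessor names (fold, input1) follow common MLIR idioms but are hypothetical, not the exact upstream code:

// Hypothetical sketch: keep the x + 0 -> x fold for integer types only.
OpFoldResult tosa::AddOp::fold(ArrayRef<Attribute> operands) {
  auto resultTy = getType().dyn_cast<ShapedType>();
  if (!resultTy)
    return {};
  // For floats, -0.0 + (+0.0) == +0.0, so dropping the add can change the
  // sign of a zero result; only fold when the element type is an integer.
  if (!resultTy.getElementType().isa<IntegerType>())
    return {};
  // Fold only when the right operand is a splat of integer zero.
  auto rhs = operands[1].dyn_cast_or_null<DenseElementsAttr>();
  if (rhs && rhs.isSplat() && rhs.getSplatValue<APInt>().isZero())
    return input1();
  return {};
}

LLVM’s own InstSimplify draws the analogous line for scalar FP: it folds x + (-0.0) to x unconditionally, but folds x + (+0.0) only when it knows the sign of zero does not matter.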

@rsuderman

Thank you for the quick reply. The specification document in the link looks great.

Hello @eric-k, I think there is a similar problem with tosa.conv2d in the tosa-to-linalg pass.
The MLIR code

func @conv(%arg0: tensor<2x4x4x3xf32>, %arg1: tensor<16x3x6x3xf32>, %arg2: tensor<16xf32>) -> tensor<2x6x9x16xf32> {
    %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [2, 2, 5, 5], stride = [1, 1]} : (tensor<2x4x4x3xf32>, tensor<16x3x6x3xf32>, tensor<16xf32>) -> tensor<2x6x9x16xf32>
    return %0 : tensor<2x6x9x16xf32>
}

is transformed into

#map0 = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d3)>
module  {
  func @conv(%arg0: tensor<2x4x4x3xf32>, %arg1: tensor<16x3x6x3xf32>, %arg2: tensor<16xf32>) -> tensor<2x6x9x16xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %0 = linalg.pad_tensor %arg0 low[0, 2, 5, 0] high[0, 2, 5, 0]  {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):  // no predecessors
      linalg.yield %cst : f32
    } : tensor<2x4x4x3xf32> to tensor<2x8x14x3xf32>
    %cst_0 = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
    %1 = linalg.init_tensor [3, 6, 3, 16] : tensor<3x6x3x16xf32>
    %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<16x3x6x3xf32>) outs(%1 : tensor<3x6x3x16xf32>) {
    ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
      linalg.yield %arg3 : f32
    } -> tensor<3x6x3x16xf32>
    %3 = linalg.init_tensor [2, 6, 9, 16] : tensor<2x6x9x16xf32>
    %cst_1 = arith.constant 0.000000e+00 : f32
    %4 = linalg.fill(%cst_1, %3) : f32, tensor<2x6x9x16xf32> -> tensor<2x6x9x16xf32> 
    %5 = linalg.init_tensor [2, 6, 9, 16] : tensor<2x6x9x16xf32>
    %6 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%0, %2 : tensor<2x8x14x3xf32>, tensor<3x6x3x16xf32>) outs(%4 : tensor<2x6x9x16xf32>) -> tensor<2x6x9x16xf32>
    %7 = linalg.generic {indexing_maps = [#map2, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %6 : tensor<16xf32>, tensor<2x6x9x16xf32>) outs(%5 : tensor<2x6x9x16xf32>) {
    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
      %8 = arith.addf %arg3, %arg4 : f32
      linalg.yield %8 : f32
    } -> tensor<2x6x9x16xf32>
    return %7 : tensor<2x6x9x16xf32>
  }
}

after going through the tosa-to-linalg pass, the padding element becomes +0.0 (%cst).
According to the TOSA specification v0.23.0, the padded elements of tosa.conv2d should not be added to the output elements. I think this tosa-to-linalg transformation is incorrect, since ‘x + (+0.0)’ is not ‘x’: the output tensor could have different element values (for example, a different sign of zero).
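
To make the mismatch concrete, here is a scalar C++ sketch of one output element whose only in-bounds contribution is -0.0 (input -0.0, weight 1.0); the lowered form also accumulates the pad value times the weight:

#include <cmath>
#include <cstdio>

int main() {
  double real = -0.0 * 1.0;          // the only in-bounds product: -0.0
  double per_spec = real;            // spec: padded elements contribute nothing
  double lowered = real + 0.0 * 1.0; // lowering: -0.0 + (+0.0) = +0.0
  std::printf("spec:    signbit = %d\n", std::signbit(per_spec)); // prints 1
  std::printf("lowered: signbit = %d\n", std::signbit(lowered));  // prints 0
}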

That’s another good catch. This one is harder to resolve, because the common software path has a simpler implementation if it is allowed to add the +0.0.

Do you have a particular use case in mind where -0.0 + (+0.0) = +0.0 is significant? Understanding how this is used would help me reason about how to update the specification language, or about potential changes in linalg.

Thanks.

Thank you for the reply!
I don’t think the transformation causes any significant practical problem.
I only pointed it out because I am interested in formalizing the TOSA operations, and there seemed to be a mismatch between the specification and the transformation.
