[DialectConversion] Avoiding conversion for a function's entry BB

Hello,

I am trying to use the dialect conversion framework to apply a conversion pattern to the internal control flow of a function. This is part of implementing “detensoring” for linalg-on-tensor ops, which means converting op instances that take and produce only 0D tensors into equivalent ops that work directly on the underlying tensor element types. This is being implemented in these 2 patches: ⚙ D96271 [MLIR][LinAlg] Start detensoring implementation. and ⚙ D97148 {WIP: PLZ DON'T REVIEW YET}[MLIR][LinAlg] Detensorize interal CF.

To that end, we would like to avoid detensoring across function boundaries. This means that we would like to apply type conversion to the signatures of all basic blocks inside a function except the entry block. For example, given the following (contrived) function:

func @main(%arg0: tensor<i32>) -> tensor<i32> attributes {iree.module.export} {
  br ^bb1(%arg0: tensor<i32>)
^bb1(%0: tensor<i32>):
  return %0 : tensor<i32>
}

we would like to convert it to something similar to:

func @main(%arg0: tensor<i32>) -> tensor<i32> attributes {iree.module.export} {
  %ex_arg0 = tensor.extract %arg0[] : tensor<i32>
  br ^bb1(%ex_arg0: i32)
^bb1(%0: i32):
  %1 = tensor.from_elements %0 : tensor<1xi32>
  %2 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
  return %2 : tensor<i32>
}

I am trying to use the dialect conversion framework to handle control flow conversion in this patch: ⚙ D97148 {WIP: PLZ DON'T REVIEW YET}[MLIR][LinAlg] Detensorize interal CF. Note that the conversion within a single BB is implemented separately here: ⚙ D96271 [MLIR][LinAlg] Start detensoring implementation.

The difficulty I am currently facing is in properly handling the function’s entry block. In particular, this block should be left as is without converting its signature. I do this by passing an “identity” TypeConverter::SignatureConversion instance to ConversionPatternRewriter::convertRegionTypes(...) as you can see in Detensorize.cpp:198 (copied here):

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    ...
    TypeConverter::SignatureConversion result(type.getNumInputs());
    result.addInputs(type.getInputs());

    SmallVector<Type, 1> newResults;
    if (failed(rewriter.convertRegionTypes(&mlir::impl::getFunctionBody(op),
                                           *typeConverter, &result))) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }
    ...
}

The above code achieves the desired goal of not converting the entry BB’s arguments. Also, a target materialization is added in order to extract %arg0's element and pass the extracted value to the converted br op. However, during OperationConverter::finalize(...) -> OperationConverter::legalizeConvertedArgumentTypes(...) -> ArgConverter::materializeLiveConversions(...), the framework tries to create a source materialization for %arg0 in the entry BB. Looking at the implementation, the last method in that sequence (i.e. ArgConverter::materializeLiveConversions) invokes the source materialization hook with a non-empty value only if the argReplacementValue is different from the original value (llvm-project/DialectConversion.cpp at main · llvm/llvm-project · GitHub). This makes me suspect that I didn’t set up the framework properly in some way, because I believe a source materialization shouldn’t be needed in this situation.

Below is the debug output:

//===-------------------------------------------===//
Legalizing operation : 'func'(0x7fd297104510) {
  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'func -> ()' {
    ** Insert  : 'tensor.from_elements'(0x7fd2971168d8)
    ** Insert  : 'linalg.tensor_reshape'(0x7fd297116c18)

    //===-------------------------------------------===//
    Legalizing operation : 'func'(0x7fd297104510) {
    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tensor.from_elements'(0x7fd2971168d8) {
      %1 = "tensor.from_elements"(%0) : (i32) -> tensor<1xi32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'linalg.tensor_reshape'(0x7fd297116c18) {
      %2 = "linalg.tensor_reshape"(%1) {reassociation = []} : (tensor<1xi32>) -> tensor<i32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//
  } -> SUCCESS : pattern applied successfully
} -> SUCCESS
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'std.br'(0x7fd29710b2e0) {
  "std.br"(<<UNKNOWN SSA VALUE>>)[^bb1] : (tensor<i32>) -> ()

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'std.br -> ()' {
    ** Insert  : 'tensor.extract'(0x7fd297204088)

    //===-------------------------------------------===//
    Legalizing operation : 'std.br'(0x7fd29710b2e0) {
      "std.br"(%0)[^bb1] : (i32) -> ()

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tensor.extract'(0x7fd297204088) {
      %0 = "tensor.extract"(<<UNKNOWN SSA VALUE>>) : (tensor<i32>) -> i32

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//
  } -> SUCCESS : pattern applied successfully
} -> SUCCESS
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'std.return'(0x7fd29710b4b0) {
  "std.return"(<<UNKNOWN SSA VALUE>>) : (tensor<i32>) -> ()

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
Assertion failed: (inputs.size() == 1), function operator(), file /Users/ergawy/work/llvm-project/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp, line 91.
PLEASE submit a bug report to https://bugs.llvm.org/ and include the crash backtrace.
Stack dump:
0.      Program arguments: /Users/ergawy/work/llvm-project/build/bin/mlir-opt /Users/ergawy/work/llvm-project/mlir/test/Dialect/Linalg/detensorized_while.mlir -linalg-detensorize -func-detensorize -debug -print-ir-after-all
1.      Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0  mlir-opt                 0x00000001089cfa1b llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 43
1  mlir-opt                 0x00000001089ce728 llvm::sys::RunSignalHandlers() + 248
2  mlir-opt                 0x00000001089d0077 SignalHandler(int) + 295
3  libsystem_platform.dylib 0x00007fff6d7495fd _sigtramp + 29
4  libsystem_platform.dylib 0x0000000000000b40 _sigtramp + 18446603338679809376
5  libsystem_c.dylib        0x00007fff6d61f808 abort + 120
6  libsystem_c.dylib        0x00007fff6d61eac6 err + 0
7  mlir-opt                 0x000000010a4e1b01 std::__1::__function::__func<std::__1::function<llvm::Optional<mlir::Value> (mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)> mlir::TypeConverter::wrapMaterialization<mlir::Ty
pe, (anonymous namespace)::DetensorizeTypeConverter::DetensorizeTypeConverter()::'lambda0'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)>((anonymous namespace)::DetensorizeTypeConverter::DetensorizeTypeConverter()::'lamb
da0'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)&&)::'lambda'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location), std::__1::allocator<std::__1::function<llvm::Optional<mlir::Value> (mlir::OpBuilder&, mlir:
:Type, mlir::ValueRange, mlir::Location)> mlir::TypeConverter::wrapMaterialization<mlir::Type, (anonymous namespace)::DetensorizeTypeConverter::DetensorizeTypeConverter()::'lambda0'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Lo
cation)>((anonymous namespace)::DetensorizeTypeConverter::DetensorizeTypeConverter()::'lambda0'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)&&)::'lambda'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)>,
 llvm::Optional<mlir::Value> (mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)>::operator()(mlir::OpBuilder&, mlir::Type&&, mlir::ValueRange&&, mlir::Location&&) (.cold.3) + 33
8  mlir-opt                 0x0000000108ba6058 std::__1::__function::__func<std::__1::function<llvm::Optional<mlir::Value> (mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)> mlir::TypeConverter::wrapMaterialization<mlir::Ty
pe, (anonymous namespace)::DetensorizeTypeConverter::DetensorizeTypeConverter()::'lambda0'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)>((anonymous namespace)::DetensorizeTypeConverter::DetensorizeTypeConverter()::'lamb
da0'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)&&)::'lambda'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location), std::__1::allocator<std::__1::function<llvm::Optional<mlir::Value> (mlir::OpBuilder&, mlir:
:Type, mlir::ValueRange, mlir::Location)> mlir::TypeConverter::wrapMaterialization<mlir::Type, (anonymous namespace)::DetensorizeTypeConverter::DetensorizeTypeConverter()::'lambda0'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Lo
cation)>((anonymous namespace)::DetensorizeTypeConverter::DetensorizeTypeConverter()::'lambda0'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)&&)::'lambda'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)>,
 llvm::Optional<mlir::Value> (mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)>::operator()(mlir::OpBuilder&, mlir::Type&&, mlir::ValueRange&&, mlir::Location&&) + 232
9  mlir-opt                 0x00000001092857ab (anonymous namespace)::OperationConverter::convertOperations(llvm::ArrayRef<mlir::Operation*>) + 3499
10 mlir-opt                 0x0000000109286e09 mlir::applyPartialConversion(mlir::Operation*, mlir::ConversionTarget&, mlir::FrozenRewritePatternList const&, llvm::DenseSet<mlir::Operation*, llvm::DenseMapInfo<mlir::Operation*> >*) + 73
11 mlir-opt                 0x0000000108ba7c76 (anonymous namespace)::FuncDetensorize::runOnFunction() + 1590
12 mlir-opt                 0x000000010921cad0 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 512
13 mlir-opt                 0x000000010921cf65 mlir::detail::OpToOpPassAdaptor::runPipeline(llvm::iterator_range<llvm::pointee_iterator<std::__1::unique_ptr<mlir::Pass, std::__1::default_delete<mlir::Pass> >*, mlir::Pass> >, mlir::Operati
on*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 133
14 mlir-opt                 0x0000000109222bc4 mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_8::operator()(llvm::MutableArrayRef<mlir::OpPassManager>) const + 452
15 mlir-opt                 0x000000010921dc21 mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool) + 1761
16 mlir-opt                 0x000000010921cc67 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 919
17 mlir-opt                 0x000000010921f76a mlir::PassManager::run(mlir::Operation*) + 762
18 mlir-opt                 0x00000001091faf9d performActions(llvm::raw_ostream&, bool, bool, llvm::SourceMgr&, mlir::MLIRContext*, mlir::PassPipelineCLParser const&) + 397
19 mlir-opt                 0x00000001091f90f0 processBuffer(llvm::raw_ostream&, std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer> >, bool, bool, bool, bool, mlir::PassPipelineCLParser const&, mlir::Dia
lectRegistry&) + 304
20 mlir-opt                 0x00000001091f9c84 mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&, bool) + 2788
21 mlir-opt                 0x00000001087bcfbc main + 140
22 libdyld.dylib            0x00007fff6d550cc9 start + 1

Any pointers on where I might have gone wrong?

Kareem

I have a hunch the DetensorizeTypeConverter’s conversion for TensorType might not be set up quite right. Specifically, for an entry block, if you want the argument types to be legal, they should convert to themselves. I commented on ⚙ D96271 [MLIR][LinAlg] Start detensoring implementation. where I think the logic is going wrong for entry block arguments.

I think the legality of entry block arguments should be driven by the conversion target (line 224 in Detensorize.cpp in the patch, relevant code copied below) and not the type converter: DetensorizeTypeConverter. I might be wrong but the reason I think this is the proper way is that the type converter is agnostic to the value being converted; it doesn’t differentiate between a value that’s an entry block’s argument and a value that’s an internal block argument.

    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
      return std::all_of(std::next(op.getBody().begin()), op.getBody().end(),
                         [&](Block &block) {
                           return typeConverter.isLegal(
                               block.getArgumentTypes());
                         });
    });

Does that make sense?

That dynamic legality callback you shared is just calling isLegal on the type converter, so I think this still comes down to what types the type converter considers legal. Calling typeConverter.isLegal just checks if the types convert to themselves: llvm-project/DialectConversion.cpp at 3b148d6f991181a1b8f089c4bc2126e1a6c1212d · llvm/llvm-project · GitHub.

I think the entry block arguments do not convert to themselves using the logic in DetensorizeTypeConverter, so they are not legal and the framework has to legalize them, which is why you are seeing materializations. You should be able to keep the type converter logic as-is if you ensure the entry block arguments are always legal in that callback.

This is another way of rephrasing the problem. Since you want different behavior for entry blocks versus the internal blocks, you will need to differentiate between the two somewhere. Where does this happen? Right now, there doesn’t appear to be any distinction in the type converter or the dynamic legality callback.
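
To make the point about isLegal concrete, here is a small self-contained toy model in plain C++. It is not MLIR code: types are plain strings, `convertType` is a hypothetical stand-in for a detensoring rule (a 0D `tensor<T>` becomes `T`), and `isFunctionLegal` only mirrors the shape of the dynamic legality callback quoted above by skipping the first (entry) block. All names are made up for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <iterator>
#include <string>
#include <vector>

using Type = std::string;            // toy type: its textual form
using Block = std::vector<Type>;     // toy block: its argument types
using Function = std::vector<Block>; // toy function: blocks, entry first

// Hypothetical rule: 0D "tensor<T>" -> "T"; anything else is unchanged.
Type convertType(const Type &t) {
  const std::string prefix = "tensor<";
  if (t.size() > prefix.size() + 1 &&
      t.compare(0, prefix.size(), prefix) == 0 && t.back() == '>') {
    Type elem = t.substr(prefix.size(), t.size() - prefix.size() - 1);
    assert(!elem.empty()); // guaranteed by the size check above
    // A shaped tensor such as "tensor<1xi32>" starts with a dimension,
    // not with a letter, so it is left alone.
    if (std::isalpha(static_cast<unsigned char>(elem.front())))
      return elem;
  }
  return t;
}

// A type is legal exactly when it converts to itself.
bool isLegal(const Type &t) { return convertType(t) == t; }

// Legal when every non-entry block has only legal argument types.
bool isFunctionLegal(const Function &fn) {
  if (fn.empty())
    return true;
  return std::all_of(std::next(fn.begin()), fn.end(), [](const Block &block) {
    return std::all_of(block.begin(), block.end(), isLegal);
  });
}
```

In this model the entry block's `tensor<i32>` argument is not legal according to `isLegal`; the function only counts as legal because the callback never inspects that block.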

This feels buggy to me:

      bool isDroppedArg = argReplacementValue == origArg;
      if (argReplacementValue.getType() == origArg.getType() && !isDroppedArg)
        continue;

Maybe the code needs to be changed to better distinguish the “I didn’t change it at all” and “dropped the arg” cases. @River707? What’s the best way to have an entry block SignatureConversion that just means “don’t change it”?

One thing to try, which I recall helping me in the past, is to change a single call to SignatureConversion::addInputs(ArrayRef<Type> types) into N calls to SignatureConversion::addInputs(unsigned origInputNo, ArrayRef<Type> types) with a single one for each type. (I needed to do that in DecomposeCallGraphTypes for a reason I didn’t fully explore).
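
The difference between the two addInputs overloads can be illustrated with a toy model in plain C++. This is a sketch of the bookkeeping only, assuming that a conversion records, per original input, whether it was remapped. `ToySignatureConversion`, `makeIdentity`, and their members are hypothetical names, not the MLIR class.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

using Type = std::string; // toy type: its textual form

// Toy model (not the MLIR class) of a signature conversion.
class ToySignatureConversion {
public:
  explicit ToySignatureConversion(std::size_t numOrigInputs)
      : remapped(numOrigInputs, false) {}

  // Append new types that do not correspond to any original input.
  void addInputs(const std::vector<Type> &types) {
    newTypes.insert(newTypes.end(), types.begin(), types.end());
  }

  // Record that original input `origInputNo` maps to `types`, then append.
  void addInputs(std::size_t origInputNo, const std::vector<Type> &types) {
    assert(origInputNo < remapped.size() && "input index out of range");
    assert(!remapped[origInputNo] && "input was already remapped");
    remapped[origInputNo] = true;
    addInputs(types);
  }

  // An original input that was never remapped counts as dropped.
  bool isDropped(std::size_t origInputNo) const {
    return !remapped[origInputNo];
  }

  const std::vector<Type> &getConvertedTypes() const { return newTypes; }

private:
  std::vector<bool> remapped;
  std::vector<Type> newTypes;
};

// Build an "identity" conversion with one remapping call per input.
ToySignatureConversion makeIdentity(const std::vector<Type> &inputs) {
  ToySignatureConversion result(inputs.size());
  for (std::size_t i = 0; i < inputs.size(); ++i)
    result.addInputs(i, std::vector<Type>{inputs[i]});
  return result;
}
```

In this model both variants produce the same list of converted types; only the per-input variant records that each original argument still exists.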

I could have sworn that SignatureConversion had better documentation, but I guess not. The way it is supposed to work is that a SignatureConversion is the complete description of the new signature, i.e. if one of the original inputs wasn’t remapped, it doesn’t exist anymore. This invariant is documented on the addInputs method that is used in the original post:

    /// Append new input types to the signature conversion, this should only be
    /// used if the new types are not intended to remap an existing input.
    void addInputs(ArrayRef<Type> types);

If an argument gets remapped to itself, right now you would likely need to remap each individual argument to itself with the same type. That is pretty ugly when you are doing it for the entire block, though; we just haven’t really run into this use case yet. For the use case described above, I’d probably just add a method, something like convertNonEntryRegionTypes, to ConversionPatternRewriter that converts the signature of every block but the entry block. I don’t think we should try to adapt SignatureConversion for this case, given that you shouldn’t need to provide a signature conversion at all.

– River
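
The idea behind such a helper can be pictured with a toy model in plain C++: convert the argument types of every block except the first. This is a sketch of the idea only. `convertNonEntryBlockTypes`, `Region`, and the string-based `Type` are hypothetical stand-ins and say nothing about the signature an actual ConversionPatternRewriter method would have.

```cpp
#include <cassert>
#include <functional>
#include <iterator>
#include <string>
#include <vector>

using Type = std::string;          // toy type: its textual form
using Block = std::vector<Type>;   // toy block: its argument types
using Region = std::vector<Block>; // toy region: blocks, entry first

// Apply `convert` to the argument types of every block except the entry
// block, which keeps its original signature.
void convertNonEntryBlockTypes(
    Region &region, const std::function<Type(const Type &)> &convert) {
  assert(convert && "a conversion function is required");
  if (region.empty())
    return;
  for (auto it = std::next(region.begin()); it != region.end(); ++it)
    for (Type &argType : *it)
      argType = convert(argType);
}
```

Applied to the contrived example from the original post, the entry block would keep `tensor<i32>` while `^bb1` would take an `i32`, with no signature conversion object needed for the entry block.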


Thank you all for your help. I added the method as @River707 suggested and updated the review (⚙ D97148 [MLIR][LinAlg] Detensorize interal function control flow.).
