Refactoring Ops into their own functions

Hi there,

I’m looking for a tool for refactoring ops (specifically linalg::GenericOp) into their own functions like the following:

//before
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 808 : i32}}  {
  func @function(%arg0: tensor<64x128x1024xf32>, %arg1: tensor<64x128x1024xf32>) -> tensor<64x128x1024xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "args_0,args_0_1", outputs = "Identity"}} {
    %0 = linalg.init_tensor [64, 128, 1024] : tensor<64x128x1024xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<64x128x1024xf32>, tensor<64x128x1024xf32>) outs(%0 : tensor<64x128x1024xf32>) {
    ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
      %2 = addf %arg2, %arg3 : f32
      linalg.yield %2 : f32
    } -> tensor<64x128x1024xf32>
    return %1 : tensor<64x128x1024xf32>
  }
}
// after
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 808 : i32}}  {
  func @genOp1(%arg0: tensor<64x128x1024xf32>, %arg1: tensor<64x128x1024xf32>, %0: tensor<64x128x1024xf32>) -> tensor<64x128x1024xf32> {
    %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<64x128x1024xf32>, tensor<64x128x1024xf32>) outs(%0 : tensor<64x128x1024xf32>) {
    ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
      %2 = addf %arg2, %arg3 : f32
      linalg.yield %2 : f32
    } -> tensor<64x128x1024xf32>
    return %1 : tensor<64x128x1024xf32>
  }
  func @function(%arg0: tensor<64x128x1024xf32>, %arg1: tensor<64x128x1024xf32>) -> tensor<64x128x1024xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "args_0,args_0_1", outputs = "Identity"}} {
    %0 = linalg.init_tensor [64, 128, 1024] : tensor<64x128x1024xf32>
    %1 = call @genOp1(%arg0, %arg1, %0) : (tensor<64x128x1024xf32>, tensor<64x128x1024xf32>, tensor<64x128x1024xf32>) -> tensor<64x128x1024xf32>
    return %1 : tensor<64x128x1024xf32>
  }
}
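
In case no ready-made utility exists, I assume the outlining itself could be done with a small custom pass. Here is a rough C++-style pseudocode sketch of what I have in mind (all names are illustrative and untested; the exact builder/clone idioms may differ across MLIR versions):

```cpp
// Pseudocode sketch: outline each linalg.generic into its own function.
// pickTileSizes-style details omitted; names like outlineGenericOps are mine.
void outlineGenericOps(ModuleOp module) {
  int idx = 0;
  module.walk([&](linalg::GenericOp op) {
    OpBuilder builder(module.getContext());
    builder.setInsertionPointToStart(module.getBody());

    // The operands of the generic op become the new function's arguments.
    auto funcType = builder.getFunctionType(op->getOperandTypes(),
                                            op->getResultTypes());
    auto func = builder.create<FuncOp>(op.getLoc(),
                                       "genOp" + std::to_string(++idx),
                                       funcType);

    // Clone the op into the new function body, remapping operands to
    // the entry-block arguments.
    Block *body = func.addEntryBlock();
    builder.setInsertionPointToStart(body);
    BlockAndValueMapping mapping;
    for (auto pair : llvm::zip(op->getOperands(), body->getArguments()))
      mapping.map(std::get<0>(pair), std::get<1>(pair));
    Operation *clone = builder.clone(*op, mapping);
    builder.create<ReturnOp>(op.getLoc(), clone->getResults());

    // Replace the original op with a call to the new function.
    builder.setInsertionPoint(op);
    auto call = builder.create<CallOp>(op.getLoc(), func, op->getOperands());
    op->replaceAllUsesWith(call.getResults());
    op->erase();
  });
}
```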

Is there a tool in the MLIR framework that does this? My motivation: I need to tile these generic ops with tile sizes that depend on their input and output tensor shapes, but createLinalgTilingPass() applies the same tile sizes to every linalg.generic within a funcOp. Since each generic op needs its own tile sizes, my plan is to outline each one into its own function, run createLinalgTilingPass() with different tile sizes on each of those functions, and then inline them back into the original function. Is there a better way to achieve this?
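
For reference, the per-function tiling step of my plan would look roughly like this (pseudocode; pickTileSizes is a hypothetical helper encoding my shape-based policy, and the exact PassManager / createLinalgTilingPass signatures may differ across MLIR versions):

```cpp
// Pseudocode sketch: run the tiling pass separately on each outlined function.
for (FuncOp func : module.getOps<FuncOp>()) {
  if (func.getName() == "function")  // skip the original entry function
    continue;
  // pickTileSizes is hypothetical: choose sizes from the tensor shapes.
  SmallVector<int64_t, 4> sizes = pickTileSizes(func);
  PassManager pm(module.getContext(), FuncOp::getOperationName());
  pm.addPass(createLinalgTilingPass(sizes));
  if (failed(pm.run(func))) {
    // handle the error
  }
}
```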

Thank you in advance