I tried this with my own type, but it's throwing errors.
Here is what I have:
struct ThreadGroupTypeStorage : public TypeStorage {
ThreadGroupTypeStorage(ArrayRef<Thread> threads) : threads(threads) {}
using KeyTy = ArrayRef<Thread>;
static auto hashKey(const KeyTy &key) -> llvm::hash_code {
return llvm::hash_value(key);
}
auto operator==(const KeyTy &key) const -> bool {
return KeyTy(threads) == key;
}
static auto construct(TypeStorageAllocator &allocator, KeyTy key)
-> ThreadGroupTypeStorage * {
return new (allocator.allocate<ThreadGroupTypeStorage>())
ThreadGroupTypeStorage(key);
}
ArrayRef<Thread> threads;
};
/// Storage for the Thread type; the only data it carries is a ThreadType.
struct ThreadTypeStorage : public TypeStorage {
  /// Key used for uniquing: just the ThreadType.
  using KeyTy = ThreadType;

  ThreadTypeStorage(ThreadType type) : type(type) {}

  /// Hashes a key via llvm::hash_value.
  static llvm::hash_code hashKey(const KeyTy &key) {
    return llvm::hash_value(key);
  }

  /// Returns true when the stored ThreadType equals `key`.
  bool operator==(const KeyTy &key) const { return type == key; }

  /// Allocates memory from `allocator` and constructs the storage in place.
  static ThreadTypeStorage *construct(TypeStorageAllocator &allocator,
                                      KeyTy key) {
    auto *rawStorage = allocator.allocate<ThreadTypeStorage>();
    return new (rawStorage) ThreadTypeStorage(key);
  }

  /// Accessor for the stored ThreadType.
  ThreadType getType() const { return type; }

  ThreadType type;
};
/// Type handle for a thread, built on Type::TypeBase with
/// detail::ThreadTypeStorage as its storage class.
struct Thread : public Type::TypeBase<Thread, Type, detail::ThreadTypeStorage> {
// Inherit the constructors of the TypeBase instantiation.
using Base::Base;
/// Returns a Thread in `context` without an explicit ThreadType.
/// NOTE(review): which ThreadType this overload uses is decided by its
/// definition, which is not shown here.
static Thread get(MLIRContext *context);
/// Returns a Thread in `context` for the given ThreadType.
static Thread get(MLIRContext *context, ThreadType type);
};
/// Type handle for a group of threads, built on Type::TypeBase with
/// detail::ThreadGroupTypeStorage as its storage class.
struct ThreadGroup
: public Type::TypeBase<ThreadGroup, Type, detail::ThreadGroupTypeStorage> {
// Inherit the constructors of the TypeBase instantiation.
using Base::Base;
/// Returns a ThreadGroup in `context` for the given list of threads.
static ThreadGroup get(MLIRContext *context, ArrayRef<Thread> elements);
/// Returns a ThreadGroup in `context` without an explicit element list.
/// NOTE(review): presumably an empty group; confirm against the definition,
/// which is not shown here.
static ThreadGroup get(MLIRContext *context);
};
My Ops:
// Op with mnemonic "thread": takes `type` (ThreadTypeEnum) and `id`
// (IndexAttr) and produces a single Thread result named `out`.
// Declared with the NoSideEffect trait.
def Thread_CreateOp : Pulse_Op<"thread", [NoSideEffect]> {
let arguments = (ins ThreadTypeEnum:$type, IndexAttr:$id);
let results = (outs Thread:$out);
}
// Op with mnemonic "thread_group": declares no arguments and produces a
// single ThreadGroup result named `out`. Declared with the NoSideEffect trait.
// NOTE(review): the expected IR in this post passes thread values as operands
// to the group-creation op, but no `arguments` are declared here; confirm
// which form is intended.
def ThreadGroup_CreateOp : Pulse_Op<"thread_group", [NoSideEffect]> {
let results = (outs ThreadGroup:$out);
}
To give a bit of background on the IRs:
Original
%a = foo.create_core {id = 0 : i32} : !foo.core
foo.execute (%a) : (!foo.core) -> ()
Expected conversion:
%b1 = bar.create_thread {id = 0 : i32} : !bar.thread
%b2 = bar.create_thread {id = 1 : i32} : !bar.thread
%b3 = bar.create_thread {id = 2 : i32} : !bar.thread
%b = bar.create_thread_group(%b1, %b2, %b3) : !bar.thread_group
bar.schedule (%b) : (!bar.thread_group) -> ()
In my conversion code, I'm doing:
- When converting create_core
Thread t1 = rewriter
.create<Thread_CreateOp>(
loc, Thread::get(ctx), ThreadType::Lazy,
rewriter.getIndexAttr(originalOp.id().getValue()))
.out();
.getType()
.dyn_cast<Thread>();
Thread t2 = rewriter
.create<Thread_CreateOp>(
loc, Thread::get(ctx), ThreadType::Lazy,
rewriter.getIndexAttr(originalOp.id().getValue()))
.out();
.getType()
.dyn_cast<Thread>();
ThreadGroup tg = rewriter.create<ThreadGroup_CreateOp>(
loc, ThreadGroup::get(ctx, ArrayRef({t1, t2}})));
rewriter.replaceOp(originalOp, {tg.out()});
- When converting execute
I will look up the ThreadGroup in the operands and go from there.
But here is the error I am getting:
Legalizing operation : 'quir.builtin_U'(0x7fe273504790) {
"foo.execute" (%a) : (!foo.core) -> ())
* Fold {
} -> FAILURE : unable to fold
* Pattern : 'foo.execute-> ()' {
** Failure : unable to materialize a conversion for operand #0, from '!bar.thread_group' to '!bar.thread_group'
} -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern