forked from OSchip/llvm-project
Generate some of the boilerplate for reference implementation specification
PiperOrigin-RevId: 229735735
This commit is contained in:
parent
0eebe6ffd9
commit
8cb1781657
|
@ -10,20 +10,20 @@ func @fn() {
|
|||
}
|
||||
|
||||
// CHECK: block {
|
||||
// CHECK-NEXT: for(idx($8)=$12 to $4 step $13) {
|
||||
// CHECK-NEXT: for(idx($9)=$12 to $5 step $13) {
|
||||
// CHECK-NEXT: for(idx($10)=$12 to $6 step $13) {
|
||||
// CHECK-NEXT: for(idx($11)=$12 to $7 step $13) {
|
||||
// CHECK-NEXT: lhs($14) = store( ... );
|
||||
// CHECK-NEXT: };
|
||||
// CHECK-NEXT: };
|
||||
// CHECK-NEXT: };
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK: for(idx($12)=$16 to $4 step $17) {
|
||||
// CHECK: for(idx($13)=$16 to $5 step $17) {
|
||||
// CHECK: for(idx($14)=$16 to $6 step $17) {
|
||||
// CHECK: for(idx($15)=$16 to $7 step $17) {
|
||||
// CHECK: lhs($18) = store( ... );
|
||||
// CHECK: };
|
||||
// CHECK: };
|
||||
// CHECK: };
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: block {
|
||||
// CHECK-NEXT: for(idx($21)=$23 to $19 step $24) {
|
||||
// CHECK-NEXT: for(idx($22)=$23 to $20 step $24) {
|
||||
// CHECK-NEXT: lhs($25) = store( ... );
|
||||
// CHECK-NEXT: };
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK: for(idx($27)=$29 to $23 step $30) {
|
||||
// CHECK: for(idx($28)=$29 to $24 step $30) {
|
||||
// CHECK: lhs($31) = store( ... );
|
||||
// CHECK: };
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
|
|
|
@ -11,13 +11,6 @@ def X_AddOp : Op<"x.add">,
|
|||
// TODO: extract referenceImplementation to Op.
|
||||
// TODO: shrink the reference implementation
|
||||
code referenceImplementation = [{
|
||||
auto *lhsMemRef = *(f->getArguments().begin());
|
||||
auto *rhsMemRef = *(f->getArguments().begin() + 1);
|
||||
auto *resultMemRef = *(f->getArguments().begin() + 2);
|
||||
|
||||
Bindable lhs, rhs, result;
|
||||
auto lhsShape = emitter.makeBoundSizes(lhsMemRef);
|
||||
|
||||
auto ivs = makeBindables(lhsShape.size());
|
||||
Bindable zero, one;
|
||||
// Same bindable, all equal to `zero`.
|
||||
|
@ -33,4 +26,4 @@ def X_AddOp : Op<"x.add">,
|
|||
}];
|
||||
}
|
||||
|
||||
// CHECK: printRefImplementation
|
||||
// CHECK: printRefImplementation
|
||||
|
|
|
@ -44,16 +44,45 @@ static void emitReferenceImplementations(const RecordKeeper &recordKeeper,
|
|||
<< " edsc::ScopedEDSCContext raiiContext;\n"
|
||||
<< " Stmt block;\n"
|
||||
<< " FuncBuilder builder(f);\n"
|
||||
<< " if (false) {}";
|
||||
<< "if (false) {}";
|
||||
for (auto *def : defs) {
|
||||
Operator op(def);
|
||||
auto ref = def->getValueInit("referenceImplementation");
|
||||
if (!ref)
|
||||
continue;
|
||||
os << "else if (opName == \"" << op.getOperationName() << "\") {\n"
|
||||
<< " edsc::MLIREmitter emitter(&builder, f->getLoc());\n"
|
||||
<< ref->getAsUnquotedString() << "\n"
|
||||
<< "}\n";
|
||||
os << " else if (opName == \"" << op.getOperationName() << "\") {\n"
|
||||
<< " edsc::MLIREmitter emitter(&builder, f->getLoc());\n";
|
||||
|
||||
// Create memrefs for the operands. Operand $x has variable name xMemRef.
|
||||
for (auto arg : op.getOperands()) {
|
||||
if (!arg.name)
|
||||
PrintFatalError(def->getLoc(), "all operands must be named");
|
||||
os << formatv(" mlir::BlockArgument* {0}MemRef;\n",
|
||||
arg.name->getAsUnquotedString());
|
||||
}
|
||||
os << " mlir::BlockArgument* resultMemRef;\n";
|
||||
os << " {\n auto opIt = f->getArguments().begin();\n";
|
||||
for (auto arg : op.getOperands()) {
|
||||
os.indent(4) << arg.name->getAsUnquotedString() << "MemRef = *opIt++;\n";
|
||||
}
|
||||
os.indent(4) << "resultMemRef = *opIt++;\n";
|
||||
os << " }\n";
|
||||
|
||||
for (auto arg : op.getOperands()) {
|
||||
os << formatv(" Bindable {0}; (void){0};\n",
|
||||
arg.name->getAsUnquotedString());
|
||||
}
|
||||
os << " Bindable result;\n";
|
||||
|
||||
for (auto arg : op.getOperands()) {
|
||||
os.indent(2) << formatv(
|
||||
"auto {0}Shape = emitter.makeBoundSizes({0}MemRef); "
|
||||
"(void){0}Shape;\n",
|
||||
arg.name->getAsUnquotedString());
|
||||
}
|
||||
|
||||
// Print the EDSC.
|
||||
os << ref->getAsUnquotedString() << "\n}";
|
||||
}
|
||||
os << " else {"
|
||||
<< " f->emitError(\"no reference implementation for \" + opName);\n"
|
||||
|
|
Loading…
Reference in New Issue