Generate some of the boilerplate for reference implementation specification

PiperOrigin-RevId: 229735735
This commit is contained in:
Jacques Pienaar 2019-01-17 06:29:05 -08:00 committed by jpienaar
parent 0eebe6ffd9
commit 8cb1781657
3 changed files with 51 additions and 29 deletions

View File

@ -10,20 +10,20 @@ func @fn() {
}
// CHECK: block {
// CHECK-NEXT: for(idx($8)=$12 to $4 step $13) {
// CHECK-NEXT: for(idx($9)=$12 to $5 step $13) {
// CHECK-NEXT: for(idx($10)=$12 to $6 step $13) {
// CHECK-NEXT: for(idx($11)=$12 to $7 step $13) {
// CHECK-NEXT: lhs($14) = store( ... );
// CHECK-NEXT: };
// CHECK-NEXT: };
// CHECK-NEXT: };
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK: for(idx($12)=$16 to $4 step $17) {
// CHECK: for(idx($13)=$16 to $5 step $17) {
// CHECK: for(idx($14)=$16 to $6 step $17) {
// CHECK: for(idx($15)=$16 to $7 step $17) {
// CHECK: lhs($18) = store( ... );
// CHECK: };
// CHECK: };
// CHECK: };
// CHECK: }
// CHECK: }
// CHECK: block {
// CHECK-NEXT: for(idx($21)=$23 to $19 step $24) {
// CHECK-NEXT: for(idx($22)=$23 to $20 step $24) {
// CHECK-NEXT: lhs($25) = store( ... );
// CHECK-NEXT: };
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK: for(idx($27)=$29 to $23 step $30) {
// CHECK: for(idx($28)=$29 to $24 step $30) {
// CHECK: lhs($31) = store( ... );
// CHECK: };
// CHECK: }
// CHECK: }

View File

@ -11,13 +11,6 @@ def X_AddOp : Op<"x.add">,
// TODO: extract referenceImplementation to Op.
// TODO: shrink the reference implementation
code referenceImplementation = [{
auto *lhsMemRef = *(f->getArguments().begin());
auto *rhsMemRef = *(f->getArguments().begin() + 1);
auto *resultMemRef = *(f->getArguments().begin() + 2);
Bindable lhs, rhs, result;
auto lhsShape = emitter.makeBoundSizes(lhsMemRef);
auto ivs = makeBindables(lhsShape.size());
Bindable zero, one;
// Same bindable, all equal to `zero`.
@ -33,4 +26,4 @@ def X_AddOp : Op<"x.add">,
}];
}
// CHECK: printRefImplementation
// CHECK: printRefImplementation

View File

@ -44,16 +44,45 @@ static void emitReferenceImplementations(const RecordKeeper &recordKeeper,
<< " edsc::ScopedEDSCContext raiiContext;\n"
<< " Stmt block;\n"
<< " FuncBuilder builder(f);\n"
<< " if (false) {}";
<< "if (false) {}";
for (auto *def : defs) {
Operator op(def);
auto ref = def->getValueInit("referenceImplementation");
if (!ref)
continue;
os << "else if (opName == \"" << op.getOperationName() << "\") {\n"
<< " edsc::MLIREmitter emitter(&builder, f->getLoc());\n"
<< ref->getAsUnquotedString() << "\n"
<< "}\n";
os << " else if (opName == \"" << op.getOperationName() << "\") {\n"
<< " edsc::MLIREmitter emitter(&builder, f->getLoc());\n";
// Create memrefs for the operands. Operand $x has variable name xMemRef.
for (auto arg : op.getOperands()) {
if (!arg.name)
PrintFatalError(def->getLoc(), "all operands must be named");
os << formatv(" mlir::BlockArgument* {0}MemRef;\n",
arg.name->getAsUnquotedString());
}
os << " mlir::BlockArgument* resultMemRef;\n";
os << " {\n auto opIt = f->getArguments().begin();\n";
for (auto arg : op.getOperands()) {
os.indent(4) << arg.name->getAsUnquotedString() << "MemRef = *opIt++;\n";
}
os.indent(4) << "resultMemRef = *opIt++;\n";
os << " }\n";
for (auto arg : op.getOperands()) {
os << formatv(" Bindable {0}; (void){0};\n",
arg.name->getAsUnquotedString());
}
os << " Bindable result;\n";
for (auto arg : op.getOperands()) {
os.indent(2) << formatv(
"auto {0}Shape = emitter.makeBoundSizes({0}MemRef); "
"(void){0}Shape;\n",
arg.name->getAsUnquotedString());
}
// Print the EDSC.
os << ref->getAsUnquotedString() << "\n}";
}
os << " else {"
<< " f->emitError(\"no reference implementation for \" + opName);\n"