forked from OSchip/llvm-project
Generate some of the boilerplate for reference implementation specification
PiperOrigin-RevId: 229735735
This commit is contained in:
parent
0eebe6ffd9
commit
8cb1781657
|
@ -10,20 +10,20 @@ func @fn() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK: block {
|
// CHECK: block {
|
||||||
// CHECK-NEXT: for(idx($8)=$12 to $4 step $13) {
|
// CHECK: for(idx($12)=$16 to $4 step $17) {
|
||||||
// CHECK-NEXT: for(idx($9)=$12 to $5 step $13) {
|
// CHECK: for(idx($13)=$16 to $5 step $17) {
|
||||||
// CHECK-NEXT: for(idx($10)=$12 to $6 step $13) {
|
// CHECK: for(idx($14)=$16 to $6 step $17) {
|
||||||
// CHECK-NEXT: for(idx($11)=$12 to $7 step $13) {
|
// CHECK: for(idx($15)=$16 to $7 step $17) {
|
||||||
// CHECK-NEXT: lhs($14) = store( ... );
|
// CHECK: lhs($18) = store( ... );
|
||||||
// CHECK-NEXT: };
|
// CHECK: };
|
||||||
// CHECK-NEXT: };
|
// CHECK: };
|
||||||
// CHECK-NEXT: };
|
// CHECK: };
|
||||||
// CHECK-NEXT: }
|
// CHECK: }
|
||||||
// CHECK-NEXT: }
|
// CHECK: }
|
||||||
// CHECK: block {
|
// CHECK: block {
|
||||||
// CHECK-NEXT: for(idx($21)=$23 to $19 step $24) {
|
// CHECK: for(idx($27)=$29 to $23 step $30) {
|
||||||
// CHECK-NEXT: for(idx($22)=$23 to $20 step $24) {
|
// CHECK: for(idx($28)=$29 to $24 step $30) {
|
||||||
// CHECK-NEXT: lhs($25) = store( ... );
|
// CHECK: lhs($31) = store( ... );
|
||||||
// CHECK-NEXT: };
|
// CHECK: };
|
||||||
// CHECK-NEXT: }
|
// CHECK: }
|
||||||
// CHECK-NEXT: }
|
// CHECK: }
|
||||||
|
|
|
@ -11,13 +11,6 @@ def X_AddOp : Op<"x.add">,
|
||||||
// TODO: extract referenceImplementation to Op.
|
// TODO: extract referenceImplementation to Op.
|
||||||
// TODO: shrink the reference implementation
|
// TODO: shrink the reference implementation
|
||||||
code referenceImplementation = [{
|
code referenceImplementation = [{
|
||||||
auto *lhsMemRef = *(f->getArguments().begin());
|
|
||||||
auto *rhsMemRef = *(f->getArguments().begin() + 1);
|
|
||||||
auto *resultMemRef = *(f->getArguments().begin() + 2);
|
|
||||||
|
|
||||||
Bindable lhs, rhs, result;
|
|
||||||
auto lhsShape = emitter.makeBoundSizes(lhsMemRef);
|
|
||||||
|
|
||||||
auto ivs = makeBindables(lhsShape.size());
|
auto ivs = makeBindables(lhsShape.size());
|
||||||
Bindable zero, one;
|
Bindable zero, one;
|
||||||
// Same bindable, all equal to `zero`.
|
// Same bindable, all equal to `zero`.
|
||||||
|
@ -33,4 +26,4 @@ def X_AddOp : Op<"x.add">,
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK: printRefImplementation
|
// CHECK: printRefImplementation
|
||||||
|
|
|
@ -44,16 +44,45 @@ static void emitReferenceImplementations(const RecordKeeper &recordKeeper,
|
||||||
<< " edsc::ScopedEDSCContext raiiContext;\n"
|
<< " edsc::ScopedEDSCContext raiiContext;\n"
|
||||||
<< " Stmt block;\n"
|
<< " Stmt block;\n"
|
||||||
<< " FuncBuilder builder(f);\n"
|
<< " FuncBuilder builder(f);\n"
|
||||||
<< " if (false) {}";
|
<< "if (false) {}";
|
||||||
for (auto *def : defs) {
|
for (auto *def : defs) {
|
||||||
Operator op(def);
|
Operator op(def);
|
||||||
auto ref = def->getValueInit("referenceImplementation");
|
auto ref = def->getValueInit("referenceImplementation");
|
||||||
if (!ref)
|
if (!ref)
|
||||||
continue;
|
continue;
|
||||||
os << "else if (opName == \"" << op.getOperationName() << "\") {\n"
|
os << " else if (opName == \"" << op.getOperationName() << "\") {\n"
|
||||||
<< " edsc::MLIREmitter emitter(&builder, f->getLoc());\n"
|
<< " edsc::MLIREmitter emitter(&builder, f->getLoc());\n";
|
||||||
<< ref->getAsUnquotedString() << "\n"
|
|
||||||
<< "}\n";
|
// Create memrefs for the operands. Operand $x has variable name xMemRef.
|
||||||
|
for (auto arg : op.getOperands()) {
|
||||||
|
if (!arg.name)
|
||||||
|
PrintFatalError(def->getLoc(), "all operands must be named");
|
||||||
|
os << formatv(" mlir::BlockArgument* {0}MemRef;\n",
|
||||||
|
arg.name->getAsUnquotedString());
|
||||||
|
}
|
||||||
|
os << " mlir::BlockArgument* resultMemRef;\n";
|
||||||
|
os << " {\n auto opIt = f->getArguments().begin();\n";
|
||||||
|
for (auto arg : op.getOperands()) {
|
||||||
|
os.indent(4) << arg.name->getAsUnquotedString() << "MemRef = *opIt++;\n";
|
||||||
|
}
|
||||||
|
os.indent(4) << "resultMemRef = *opIt++;\n";
|
||||||
|
os << " }\n";
|
||||||
|
|
||||||
|
for (auto arg : op.getOperands()) {
|
||||||
|
os << formatv(" Bindable {0}; (void){0};\n",
|
||||||
|
arg.name->getAsUnquotedString());
|
||||||
|
}
|
||||||
|
os << " Bindable result;\n";
|
||||||
|
|
||||||
|
for (auto arg : op.getOperands()) {
|
||||||
|
os.indent(2) << formatv(
|
||||||
|
"auto {0}Shape = emitter.makeBoundSizes({0}MemRef); "
|
||||||
|
"(void){0}Shape;\n",
|
||||||
|
arg.name->getAsUnquotedString());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print the EDSC.
|
||||||
|
os << ref->getAsUnquotedString() << "\n}";
|
||||||
}
|
}
|
||||||
os << " else {"
|
os << " else {"
|
||||||
<< " f->emitError(\"no reference implementation for \" + opName);\n"
|
<< " f->emitError(\"no reference implementation for \" + opName);\n"
|
||||||
|
|
Loading…
Reference in New Issue