[mlir] SCF: provide function_ref builders for IfOp

Now that OpBuilder is available in `build` functions, it becomes possible to
populate the "then" and "else" regions directly when building the "if"
operation. This is desirable in more structured forms of builders, especially
in when conditionals are mixed with loops. Provide new `build` APIs taking
callbacks for body constructors, similarly to scf::ForOp, and replace more
clunky edsc::BlockBuilder uses with these. The original APIs remain available
and go through the new implementation.

Differential Revision: https://reviews.llvm.org/D80527
This commit is contained in:
Alex Zinenko 2020-05-25 18:55:41 +02:00
parent 5ee902bb5f
commit cadb7ccf2c
6 changed files with 139 additions and 44 deletions

View File

@ -82,6 +82,16 @@ scf::ValueVector loopNestBuilder(
Value lb, Value ub, Value step, ValueRange iterArgInitValues,
function_ref<scf::ValueVector(Value, ValueRange)> fun = nullptr);
/// Adapters for building if conditions using the builder and the location
/// stored in ScopedContext. 'thenBody' is mandatory, 'elseBody' can be omitted
/// if the condition should not have an 'else' part.
ValueRange
conditionBuilder(TypeRange results, Value condition,
function_ref<scf::ValueVector()> thenBody,
function_ref<scf::ValueVector()> elseBody = nullptr);
ValueRange conditionBuilder(Value condition, function_ref<void()> thenBody,
function_ref<void()> elseBody = nullptr);
} // namespace edsc
} // namespace mlir

View File

@ -24,6 +24,8 @@
namespace mlir {
namespace scf {
void buildTerminatedBody(OpBuilder &builder, Location loc);
#include "mlir/Dialect/SCF/SCFOpsDialect.h.inc"
#define GET_OP_CLASSES

View File

@ -238,7 +238,18 @@ def IfOp : SCF_Op<"if",
OpBuilder<"OpBuilder &builder, OperationState &result, "
"Value cond, bool withElseRegion">,
OpBuilder<"OpBuilder &builder, OperationState &result, "
"TypeRange resultTypes, Value cond, bool withElseRegion">
"TypeRange resultTypes, Value cond, bool withElseRegion">,
OpBuilder<
"OpBuilder &builder, OperationState &result, TypeRange resultTypes, "
"Value cond, "
"function_ref<void(OpBuilder &, Location)> thenBuilder "
" = buildTerminatedBody, "
"function_ref<void(OpBuilder &, Location)> elseBuilder = nullptr">,
OpBuilder<
"OpBuilder &builder, OperationState &result, Value cond, "
"function_ref<void(OpBuilder &, Location)> thenBuilder "
" = buildTerminatedBody, "
"function_ref<void(OpBuilder &, Location)> elseBuilder = nullptr">
];
let extraClassDeclaration = [{

View File

@ -235,39 +235,38 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
SmallVector<Type, 1> resultType;
if (options.unroll)
resultType.push_back(vectorType);
auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
ScopedContext::getLocation(), resultType, inBoundsCondition,
/*withElseRegion=*/true);
// 3.a. If in-bounds, progressively lower to a 1-D transfer read.
BlockBuilder(&ifOp.thenRegion().front(), Append())([&] {
Value vector = load1DVector(majorIvsPlusOffsets);
// 3.a.i. If `options.unroll` is true, insert the 1-D vector in the
// aggregate. We must yield and merge with the `else` branch.
if (options.unroll) {
vector = vector_insert(vector, result, majorIvs);
(loop_yield(vector));
return;
}
// 3.a.ii. Otherwise, just go through the temporary `alloc`.
std_store(vector, alloc, majorIvs);
});
// 3. If in-bounds, progressively lower to a 1-D transfer read, otherwise
// splat a 1-D vector.
ValueRange ifResults = conditionBuilder(
resultType, inBoundsCondition,
[&]() -> scf::ValueVector {
Value vector = load1DVector(majorIvsPlusOffsets);
// 3.a. If `options.unroll` is true, insert the 1-D vector in the
// aggregate. We must yield and merge with the `else` branch.
if (options.unroll) {
vector = vector_insert(vector, result, majorIvs);
return {vector};
}
// 3.b. Otherwise, just go through the temporary `alloc`.
std_store(vector, alloc, majorIvs);
return {};
},
[&]() -> scf::ValueVector {
Value vector = std_splat(minorVectorType, xferOp.padding());
// 3.c. If `options.unroll` is true, insert the 1-D vector in the
// aggregate. We must yield and merge with the `then` branch.
if (options.unroll) {
vector = vector_insert(vector, result, majorIvs);
return {vector};
}
// 3.d. Otherwise, just go through the temporary `alloc`.
std_store(vector, alloc, majorIvs);
return {};
});
// 3.b. If not in-bounds, splat a 1-D vector.
BlockBuilder(&ifOp.elseRegion().front(), Append())([&] {
Value vector = std_splat(minorVectorType, xferOp.padding());
// 3.a.i. If `options.unroll` is true, insert the 1-D vector in the
// aggregate. We must yield and merge with the `then` branch.
if (options.unroll) {
vector = vector_insert(vector, result, majorIvs);
(loop_yield(vector));
return;
}
// 3.b.ii. Otherwise, just go through the temporary `alloc`.
std_store(vector, alloc, majorIvs);
});
if (!resultType.empty())
result = *ifOp.results().begin();
result = *ifResults.begin();
} else {
// 4. Guaranteed in-bounds, progressively lower to a 1-D transfer read.
Value loaded1D = load1DVector(majorIvsPlusOffsets);
@ -336,11 +335,8 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
if (inBoundsCondition) {
// 2.a. If the condition is not null, we need an IfOp, to write
// conditionally. Progressively lower to a 1-D transfer write.
auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
ScopedContext::getLocation(), TypeRange{}, inBoundsCondition,
/*withElseRegion=*/false);
BlockBuilder(&ifOp.thenRegion().front(),
Append())([&] { emitTransferWrite(majorIvsPlusOffsets); });
conditionBuilder(inBoundsCondition,
[&] { emitTransferWrite(majorIvsPlusOffsets); });
} else {
// 2.b. Guaranteed in-bounds. Progressively lower to a 1-D transfer write.
emitTransferWrite(majorIvsPlusOffsets);

View File

@ -159,3 +159,51 @@ mlir::scf::ValueVector mlir::edsc::loopNestBuilder(
iterArgInitValues.end());
});
}
static std::function<void(OpBuilder &, Location)>
wrapIfBody(function_ref<scf::ValueVector()> body, TypeRange expectedTypes) {
(void)expectedTypes;
return [=](OpBuilder &builder, Location loc) {
ScopedContext context(builder, loc);
scf::ValueVector returned = body();
assert(ValueRange(returned).getTypes() == expectedTypes &&
"'if' body builder returned values of unexpected type");
builder.create<scf::YieldOp>(loc, returned);
};
}
ValueRange
mlir::edsc::conditionBuilder(TypeRange results, Value condition,
function_ref<scf::ValueVector()> thenBody,
function_ref<scf::ValueVector()> elseBody) {
assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
assert(thenBody && "thenBody is mandatory");
auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
ScopedContext::getLocation(), results, condition,
wrapIfBody(thenBody, results), wrapIfBody(elseBody, results));
return ifOp.getResults();
}
static std::function<void(OpBuilder &, Location)>
wrapZeroResultIfBody(function_ref<void()> body) {
return [=](OpBuilder &builder, Location loc) {
ScopedContext context(builder, loc);
body();
builder.create<scf::YieldOp>(loc);
};
}
ValueRange mlir::edsc::conditionBuilder(Value condition,
function_ref<void()> thenBody,
function_ref<void()> elseBody) {
assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
assert(thenBody && "thenBody is mandatory");
ScopedContext::getBuilderRef().create<scf::IfOp>(
ScopedContext::getLocation(), condition, wrapZeroResultIfBody(thenBody),
elseBody ? llvm::function_ref<void(OpBuilder &, Location)>(
wrapZeroResultIfBody(elseBody))
: llvm::function_ref<void(OpBuilder &, Location)>(nullptr));
return {};
}

View File

@ -35,6 +35,11 @@ SCFDialect::SCFDialect(MLIRContext *context)
>();
}
/// Default callback for IfOp builders. Inserts a yield without arguments.
void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
builder.create<scf::YieldOp>(loc);
}
//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//
@ -338,20 +343,43 @@ void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
void IfOp::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes, Value cond, bool withElseRegion) {
auto addTerminator = [&](OpBuilder &nested, Location loc) {
if (resultTypes.empty())
IfOp::ensureTerminator(*nested.getInsertionBlock()->getParent(), nested,
loc);
};
build(builder, result, resultTypes, cond, addTerminator,
withElseRegion ? addTerminator
: function_ref<void(OpBuilder &, Location)>());
}
void IfOp::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes, Value cond,
function_ref<void(OpBuilder &, Location)> thenBuilder,
function_ref<void(OpBuilder &, Location)> elseBuilder) {
assert(thenBuilder && "the builder callback for 'then' must be present");
result.addOperands(cond);
result.addTypes(resultTypes);
OpBuilder::InsertionGuard guard(builder);
Region *thenRegion = result.addRegion();
thenRegion->push_back(new Block());
if (resultTypes.empty())
IfOp::ensureTerminator(*thenRegion, builder, result.location);
builder.createBlock(thenRegion);
thenBuilder(builder, result.location);
Region *elseRegion = result.addRegion();
if (withElseRegion) {
elseRegion->push_back(new Block());
if (resultTypes.empty())
IfOp::ensureTerminator(*elseRegion, builder, result.location);
}
if (!elseBuilder)
return;
builder.createBlock(elseRegion);
elseBuilder(builder, result.location);
}
void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
function_ref<void(OpBuilder &, Location)> thenBuilder,
function_ref<void(OpBuilder &, Location)> elseBuilder) {
build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder);
}
static LogicalResult verify(IfOp op) {