forked from OSchip/llvm-project
[mlir] SCF: provide function_ref builders for IfOp
Now that OpBuilder is available in `build` functions, it becomes possible to populate the "then" and "else" regions directly when building the "if" operation. This is desirable in more structured forms of builders, especially in when conditionals are mixed with loops. Provide new `build` APIs taking callbacks for body constructors, similarly to scf::ForOp, and replace more clunky edsc::BlockBuilder uses with these. The original APIs remain available and go through the new implementation. Differential Revision: https://reviews.llvm.org/D80527
This commit is contained in:
parent
5ee902bb5f
commit
cadb7ccf2c
|
@ -82,6 +82,16 @@ scf::ValueVector loopNestBuilder(
|
|||
Value lb, Value ub, Value step, ValueRange iterArgInitValues,
|
||||
function_ref<scf::ValueVector(Value, ValueRange)> fun = nullptr);
|
||||
|
||||
/// Adapters for building if conditions using the builder and the location
|
||||
/// stored in ScopedContext. 'thenBody' is mandatory, 'elseBody' can be omitted
|
||||
/// if the condition should not have an 'else' part.
|
||||
ValueRange
|
||||
conditionBuilder(TypeRange results, Value condition,
|
||||
function_ref<scf::ValueVector()> thenBody,
|
||||
function_ref<scf::ValueVector()> elseBody = nullptr);
|
||||
ValueRange conditionBuilder(Value condition, function_ref<void()> thenBody,
|
||||
function_ref<void()> elseBody = nullptr);
|
||||
|
||||
} // namespace edsc
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -24,6 +24,8 @@
|
|||
namespace mlir {
|
||||
namespace scf {
|
||||
|
||||
void buildTerminatedBody(OpBuilder &builder, Location loc);
|
||||
|
||||
#include "mlir/Dialect/SCF/SCFOpsDialect.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
|
|
@ -238,7 +238,18 @@ def IfOp : SCF_Op<"if",
|
|||
OpBuilder<"OpBuilder &builder, OperationState &result, "
|
||||
"Value cond, bool withElseRegion">,
|
||||
OpBuilder<"OpBuilder &builder, OperationState &result, "
|
||||
"TypeRange resultTypes, Value cond, bool withElseRegion">
|
||||
"TypeRange resultTypes, Value cond, bool withElseRegion">,
|
||||
OpBuilder<
|
||||
"OpBuilder &builder, OperationState &result, TypeRange resultTypes, "
|
||||
"Value cond, "
|
||||
"function_ref<void(OpBuilder &, Location)> thenBuilder "
|
||||
" = buildTerminatedBody, "
|
||||
"function_ref<void(OpBuilder &, Location)> elseBuilder = nullptr">,
|
||||
OpBuilder<
|
||||
"OpBuilder &builder, OperationState &result, Value cond, "
|
||||
"function_ref<void(OpBuilder &, Location)> thenBuilder "
|
||||
" = buildTerminatedBody, "
|
||||
"function_ref<void(OpBuilder &, Location)> elseBuilder = nullptr">
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
|
|
@ -235,39 +235,38 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
|
|||
SmallVector<Type, 1> resultType;
|
||||
if (options.unroll)
|
||||
resultType.push_back(vectorType);
|
||||
auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
|
||||
ScopedContext::getLocation(), resultType, inBoundsCondition,
|
||||
/*withElseRegion=*/true);
|
||||
|
||||
// 3.a. If in-bounds, progressively lower to a 1-D transfer read.
|
||||
BlockBuilder(&ifOp.thenRegion().front(), Append())([&] {
|
||||
Value vector = load1DVector(majorIvsPlusOffsets);
|
||||
// 3.a.i. If `options.unroll` is true, insert the 1-D vector in the
|
||||
// aggregate. We must yield and merge with the `else` branch.
|
||||
if (options.unroll) {
|
||||
vector = vector_insert(vector, result, majorIvs);
|
||||
(loop_yield(vector));
|
||||
return;
|
||||
}
|
||||
// 3.a.ii. Otherwise, just go through the temporary `alloc`.
|
||||
std_store(vector, alloc, majorIvs);
|
||||
});
|
||||
// 3. If in-bounds, progressively lower to a 1-D transfer read, otherwise
|
||||
// splat a 1-D vector.
|
||||
ValueRange ifResults = conditionBuilder(
|
||||
resultType, inBoundsCondition,
|
||||
[&]() -> scf::ValueVector {
|
||||
Value vector = load1DVector(majorIvsPlusOffsets);
|
||||
// 3.a. If `options.unroll` is true, insert the 1-D vector in the
|
||||
// aggregate. We must yield and merge with the `else` branch.
|
||||
if (options.unroll) {
|
||||
vector = vector_insert(vector, result, majorIvs);
|
||||
return {vector};
|
||||
}
|
||||
// 3.b. Otherwise, just go through the temporary `alloc`.
|
||||
std_store(vector, alloc, majorIvs);
|
||||
return {};
|
||||
},
|
||||
[&]() -> scf::ValueVector {
|
||||
Value vector = std_splat(minorVectorType, xferOp.padding());
|
||||
// 3.c. If `options.unroll` is true, insert the 1-D vector in the
|
||||
// aggregate. We must yield and merge with the `then` branch.
|
||||
if (options.unroll) {
|
||||
vector = vector_insert(vector, result, majorIvs);
|
||||
return {vector};
|
||||
}
|
||||
// 3.d. Otherwise, just go through the temporary `alloc`.
|
||||
std_store(vector, alloc, majorIvs);
|
||||
return {};
|
||||
});
|
||||
|
||||
// 3.b. If not in-bounds, splat a 1-D vector.
|
||||
BlockBuilder(&ifOp.elseRegion().front(), Append())([&] {
|
||||
Value vector = std_splat(minorVectorType, xferOp.padding());
|
||||
// 3.a.i. If `options.unroll` is true, insert the 1-D vector in the
|
||||
// aggregate. We must yield and merge with the `then` branch.
|
||||
if (options.unroll) {
|
||||
vector = vector_insert(vector, result, majorIvs);
|
||||
(loop_yield(vector));
|
||||
return;
|
||||
}
|
||||
// 3.b.ii. Otherwise, just go through the temporary `alloc`.
|
||||
std_store(vector, alloc, majorIvs);
|
||||
});
|
||||
if (!resultType.empty())
|
||||
result = *ifOp.results().begin();
|
||||
result = *ifResults.begin();
|
||||
} else {
|
||||
// 4. Guaranteed in-bounds, progressively lower to a 1-D transfer read.
|
||||
Value loaded1D = load1DVector(majorIvsPlusOffsets);
|
||||
|
@ -336,11 +335,8 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
|
|||
if (inBoundsCondition) {
|
||||
// 2.a. If the condition is not null, we need an IfOp, to write
|
||||
// conditionally. Progressively lower to a 1-D transfer write.
|
||||
auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
|
||||
ScopedContext::getLocation(), TypeRange{}, inBoundsCondition,
|
||||
/*withElseRegion=*/false);
|
||||
BlockBuilder(&ifOp.thenRegion().front(),
|
||||
Append())([&] { emitTransferWrite(majorIvsPlusOffsets); });
|
||||
conditionBuilder(inBoundsCondition,
|
||||
[&] { emitTransferWrite(majorIvsPlusOffsets); });
|
||||
} else {
|
||||
// 2.b. Guaranteed in-bounds. Progressively lower to a 1-D transfer write.
|
||||
emitTransferWrite(majorIvsPlusOffsets);
|
||||
|
|
|
@ -159,3 +159,51 @@ mlir::scf::ValueVector mlir::edsc::loopNestBuilder(
|
|||
iterArgInitValues.end());
|
||||
});
|
||||
}
|
||||
|
||||
static std::function<void(OpBuilder &, Location)>
|
||||
wrapIfBody(function_ref<scf::ValueVector()> body, TypeRange expectedTypes) {
|
||||
(void)expectedTypes;
|
||||
return [=](OpBuilder &builder, Location loc) {
|
||||
ScopedContext context(builder, loc);
|
||||
scf::ValueVector returned = body();
|
||||
assert(ValueRange(returned).getTypes() == expectedTypes &&
|
||||
"'if' body builder returned values of unexpected type");
|
||||
builder.create<scf::YieldOp>(loc, returned);
|
||||
};
|
||||
}
|
||||
|
||||
ValueRange
|
||||
mlir::edsc::conditionBuilder(TypeRange results, Value condition,
|
||||
function_ref<scf::ValueVector()> thenBody,
|
||||
function_ref<scf::ValueVector()> elseBody) {
|
||||
assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
|
||||
assert(thenBody && "thenBody is mandatory");
|
||||
|
||||
auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
|
||||
ScopedContext::getLocation(), results, condition,
|
||||
wrapIfBody(thenBody, results), wrapIfBody(elseBody, results));
|
||||
return ifOp.getResults();
|
||||
}
|
||||
|
||||
static std::function<void(OpBuilder &, Location)>
|
||||
wrapZeroResultIfBody(function_ref<void()> body) {
|
||||
return [=](OpBuilder &builder, Location loc) {
|
||||
ScopedContext context(builder, loc);
|
||||
body();
|
||||
builder.create<scf::YieldOp>(loc);
|
||||
};
|
||||
}
|
||||
|
||||
ValueRange mlir::edsc::conditionBuilder(Value condition,
|
||||
function_ref<void()> thenBody,
|
||||
function_ref<void()> elseBody) {
|
||||
assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
|
||||
assert(thenBody && "thenBody is mandatory");
|
||||
|
||||
ScopedContext::getBuilderRef().create<scf::IfOp>(
|
||||
ScopedContext::getLocation(), condition, wrapZeroResultIfBody(thenBody),
|
||||
elseBody ? llvm::function_ref<void(OpBuilder &, Location)>(
|
||||
wrapZeroResultIfBody(elseBody))
|
||||
: llvm::function_ref<void(OpBuilder &, Location)>(nullptr));
|
||||
return {};
|
||||
}
|
||||
|
|
|
@ -35,6 +35,11 @@ SCFDialect::SCFDialect(MLIRContext *context)
|
|||
>();
|
||||
}
|
||||
|
||||
/// Default callback for IfOp builders. Inserts a yield without arguments.
|
||||
void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
|
||||
builder.create<scf::YieldOp>(loc);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ForOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -338,20 +343,43 @@ void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
|
|||
|
||||
void IfOp::build(OpBuilder &builder, OperationState &result,
|
||||
TypeRange resultTypes, Value cond, bool withElseRegion) {
|
||||
auto addTerminator = [&](OpBuilder &nested, Location loc) {
|
||||
if (resultTypes.empty())
|
||||
IfOp::ensureTerminator(*nested.getInsertionBlock()->getParent(), nested,
|
||||
loc);
|
||||
};
|
||||
|
||||
build(builder, result, resultTypes, cond, addTerminator,
|
||||
withElseRegion ? addTerminator
|
||||
: function_ref<void(OpBuilder &, Location)>());
|
||||
}
|
||||
|
||||
void IfOp::build(OpBuilder &builder, OperationState &result,
|
||||
TypeRange resultTypes, Value cond,
|
||||
function_ref<void(OpBuilder &, Location)> thenBuilder,
|
||||
function_ref<void(OpBuilder &, Location)> elseBuilder) {
|
||||
assert(thenBuilder && "the builder callback for 'then' must be present");
|
||||
|
||||
result.addOperands(cond);
|
||||
result.addTypes(resultTypes);
|
||||
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
Region *thenRegion = result.addRegion();
|
||||
thenRegion->push_back(new Block());
|
||||
if (resultTypes.empty())
|
||||
IfOp::ensureTerminator(*thenRegion, builder, result.location);
|
||||
builder.createBlock(thenRegion);
|
||||
thenBuilder(builder, result.location);
|
||||
|
||||
Region *elseRegion = result.addRegion();
|
||||
if (withElseRegion) {
|
||||
elseRegion->push_back(new Block());
|
||||
if (resultTypes.empty())
|
||||
IfOp::ensureTerminator(*elseRegion, builder, result.location);
|
||||
}
|
||||
if (!elseBuilder)
|
||||
return;
|
||||
|
||||
builder.createBlock(elseRegion);
|
||||
elseBuilder(builder, result.location);
|
||||
}
|
||||
|
||||
void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
|
||||
function_ref<void(OpBuilder &, Location)> thenBuilder,
|
||||
function_ref<void(OpBuilder &, Location)> elseBuilder) {
|
||||
build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder);
|
||||
}
|
||||
|
||||
static LogicalResult verify(IfOp op) {
|
||||
|
|
Loading…
Reference in New Issue