Revert "[mlir][Linalg] Improve region support in Linalg ops."

This reverts commit 973e133b76.

It triggers an issue in gcc5 that require investigation, the build is
broken with:

/tmp/ccdpj3B9.s: Assembler messages:
/tmp/ccdpj3B9.s:5821: Error: symbol `_ZNSt17_Function_handlerIFvjjEUljjE2_E9_M_invokeERKSt9_Any_dataOjS6_' is already defined
/tmp/ccdpj3B9.s:5860: Error: symbol `_ZNSt14_Function_base13_Base_managerIUljjE2_E10_M_managerERSt9_Any_dataRKS3_St18_Manager_operation' is already defined
This commit is contained in:
Mehdi Amini 2021-02-12 18:15:15 +00:00
parent a7538fee3a
commit 3f22547fd1
11 changed files with 280 additions and 316 deletions

View File

@ -1056,6 +1056,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
//===------------------------------------------------------------------===// //===------------------------------------------------------------------===//
// Other static interface methods. // Other static interface methods.
//===------------------------------------------------------------------===// //===------------------------------------------------------------------===//
StaticInterfaceMethod<
/*desc=*/[{
Create an operation of the current type with the given location,
operands, and attributes.
}],
/*retTy=*/"Operation *",
/*methodName=*/"create",
(ins "OpBuilder &":$builder, "Location":$loc, "TypeRange":$resultTypes,
"ValueRange":$operands,
"ArrayRef<NamedAttribute>":$attributes), [{
return builder.create<ConcreteOp>(
loc, resultTypes, operands, attributes);
}]
>,
InterfaceMethod< InterfaceMethod<
/*desc=*/[{ /*desc=*/[{
Clone the current operation with the given location and operands. This Clone the current operation with the given location and operands. This
@ -1068,13 +1082,14 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
(ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
"ValueRange":$operands), "ValueRange":$operands),
[{ [{
BlockAndValueMapping bvm; BlockAndValueMapping map;
OperationState state( unsigned numRegions = $_op->getNumRegions();
loc, ConcreteOp::getOperationName(), operands, resultTypes, Operation *res = create(b, loc, resultTypes, operands, $_op->getAttrs());
$_op->getAttrs()); assert(res->getNumRegions() == numRegions && "inconsistent # regions");
for (Region &r : $_op->getRegions()) for (unsigned ridx = 0; ridx < numRegions; ++ridx)
r.cloneInto(state.addRegion(), bvm); $_op->getRegion(ridx).cloneInto(
return b.createOperation(state); &res->getRegion(ridx), map);
return res;
}] }]
>, >,
StaticInterfaceMethod< StaticInterfaceMethod<
@ -1083,7 +1098,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
Returns a null function if this named op does not define a region Returns a null function if this named op does not define a region
builder. builder.
}], }],
/*retTy=*/"std::function<void(Block &, ValueRange)>", /*retTy=*/"std::function<void(Block &)>",
/*methodName=*/"getRegionBuilder", /*methodName=*/"getRegionBuilder",
(ins), (ins),
[{ return ConcreteOp::getRegionBuilder(); }] [{ return ConcreteOp::getRegionBuilder(); }]

View File

@ -110,13 +110,14 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
AnyStridedMemRef:$output, AnyStridedMemRef:$output,
OptionalAttr<AffineMapAttr>:$inputPermutation, OptionalAttr<AffineMapAttr>:$inputPermutation,
OptionalAttr<AffineMapAttr>:$outputPermutation); OptionalAttr<AffineMapAttr>:$outputPermutation);
let regions = (region AnyRegion:$region);
let builders = [ // TODO: this should go away once the usage of OptionalAttr triggers emission
OpBuilderDAG<(ins "Value":$input, "Value":$output, // of builders with default arguments left unspecified.
CArg<"AffineMap", "AffineMap()">:$inputPermutation, let builders = [OpBuilderDAG<(ins "Value":$input, "Value":$output),
CArg<"AffineMap", "AffineMap()">:$outputPermutation, [{
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>]; return build(
$_builder, $_state, input, output, AffineMapAttr(), AffineMapAttr());
}]>];
let extraClassDeclaration = structuredOpsDecls # [{ let extraClassDeclaration = structuredOpsDecls # [{
ValueRange inputs() { return getOperands().take_front(); } ValueRange inputs() { return getOperands().take_front(); }
@ -145,31 +146,24 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
Value getSource() { return input();} Value getSource() { return input();}
Value getTarget() { return output(); } Value getTarget() { return output(); }
static void regionBuilder(Block &block, ValueRange captures); static std::function<void(Block &)> getRegionBuilder() {
static std::function<void(Block &block, ValueRange captures)> return nullptr;
getRegionBuilder() {
return &regionBuilder;
} }
static unsigned getNumRegionArgs() { return 2; }
}]; }];
let verifier = [{ return ::verify(*this); }]; let verifier = [{ return ::verify(*this); }];
let assemblyFormat = [{ let assemblyFormat = [{
`(` $input `,` $output `)` attr-dict `:` `(` operands `)` attr-dict `:` type(operands)
type($input) `,` type($output)
custom<CopyOpRegion>($region, ref(type($input)), ref(type($input)))
}]; }];
let hasFolder = 1; let hasFolder = 1;
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
let skipDefaultBuilders = 1;
} }
def FillOp : LinalgStructured_Op<"fill", []> { def FillOp : LinalgStructured_Op<"fill", []> {
let arguments = (ins AnyShaped:$output, let arguments = (ins AnyShaped:$output,
AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value); AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
let results = (outs Optional<AnyRankedTensor>:$result); let results = (outs Optional<AnyRankedTensor>:$result);
let regions = (region AnyRegion:$region);
let extraClassDeclaration = structuredOpsDecls # [{ let extraClassDeclaration = structuredOpsDecls # [{
ValueRange inputs() { return {}; } ValueRange inputs() { return {}; }
ValueRange outputs() { return getOperands().take_front(); } ValueRange outputs() { return getOperands().take_front(); }
@ -189,18 +183,13 @@ def FillOp : LinalgStructured_Op<"fill", []> {
extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)}); extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
} }
static void regionBuilder(Block &block, ValueRange captures); static std::function<void(Block &)> getRegionBuilder() {
static std::function<void(Block &block, ValueRange captures)> return nullptr;
getRegionBuilder() {
return &regionBuilder;
} }
static unsigned getNumRegionArgs() { return 1; }
}]; }];
let assemblyFormat = [{ let assemblyFormat = [{
`(` $output `,` $value `)` attr-dict `:` `(` operands `)` attr-dict `:` type(operands) (`->` type($result)^)?
type($output) `,` type($value) (`->` type($result)^)?
custom<FillOpRegion>($region, ref(type($output)), ref($value))
}]; }];
let builders = [ let builders = [
@ -279,8 +268,7 @@ class PoolingBase_Op<string mnemonic, list<OpTrait> props>
return padding().getValue().getValue<int64_t>({i, 1}); return padding().getValue().getValue<int64_t>({i, 1});
} }
static std::function<void(Block &, ValueRange captures)> getRegionBuilder() static std::function<void(Block &)> getRegionBuilder() {
{
return nullptr; return nullptr;
} }
}]; }];
@ -531,7 +519,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
library_call()->str() : "op_has_no_registered_library_name"; library_call()->str() : "op_has_no_registered_library_name";
} }
static std::function<void(Block &, ValueRange)> getRegionBuilder() { static std::function<void(Block &)> getRegionBuilder() {
return nullptr; return nullptr;
} }
}]; }];

View File

@ -154,13 +154,7 @@ LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
if (in == op.input() && out == op.output()) if (in == op.input() && out == op.output())
return failure(); return failure();
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
if (!libraryCallName)
return failure();
rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, libraryCallName.getValue(), TypeRange(),
createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), {in, out}));
return success(); return success();
} }

View File

@ -27,6 +27,8 @@ Operation *mlir::edsc::makeGenericLinalgOp(
ArrayRef<StructuredIndexed> outputs, TypeRange resultTensorTypes, ArrayRef<StructuredIndexed> outputs, TypeRange resultTensorTypes,
function_ref<void(ValueRange)> regionBuilder, ArrayRef<Value> otherValues, function_ref<void(ValueRange)> regionBuilder, ArrayRef<Value> otherValues,
ArrayRef<Attribute> otherAttributes) { ArrayRef<Attribute> otherAttributes) {
OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
// Build maps // Build maps
SmallVector<SmallVector<AffineExpr, 4>, 4> exprsList; SmallVector<SmallVector<AffineExpr, 4>, 4> exprsList;
exprsList.reserve(inputs.size() + outputs.size()); exprsList.reserve(inputs.size() + outputs.size());
@ -52,10 +54,13 @@ Operation *mlir::edsc::makeGenericLinalgOp(
resultTensorTypes, resultTensorTypes,
inputValues, inputValues,
outputValues, outputValues,
maps, builder.getAffineMapArrayAttr(maps),
iteratorStrTypes, builder.getStrArrayAttr(iteratorStrTypes),
""/*doc*/, StringAttr() /*doc*/,
""/*library_call*/) StringAttr() /*library_call*/,
ArrayAttr() /*sparse*/
/* TODO: other attributes in op */
)
.getOperation(); .getOperation();
// clang-format on // clang-format on

View File

@ -33,53 +33,32 @@ using namespace mlir;
using namespace mlir::linalg; using namespace mlir::linalg;
/// Forward declarations. /// Forward declarations.
/// Generic entry point to create the block for the region of a LinalgOp.
/// This is used by both named structured ops created by ods-gen and by manually
/// defined C++ ops.
/// This is used by both builders and parsers.
/// This function creates the block in the region with arguments corresponding
/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
/// to be ShapedType.
template <typename NamedStructuredOpType> template <typename NamedStructuredOpType>
static void fillStructuredOpRegion( static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
OpBuilder &opBuilder, Region &region, TypeRange inputTypes, OperationState &result,
TypeRange outputTypes, ValueRange captures = {}, TypeRange inputTypes,
std::function<void(unsigned, unsigned)> errorHandler = [](unsigned, TypeRange outputTypes);
unsigned) {});
/// Generic entry point to create both the region and the block of a LinalgOp.
template <typename NamedStructuredOpType>
static void
createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result,
TypeRange inputTypes, TypeRange outputTypes,
ValueRange captures = {});
/// Common parsing and printing used for both named structured ops created by
/// ods-gen and by manually defined C++ ops. Does not handle regions.
static ParseResult static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
SmallVectorImpl<Type> &inputTypes, SmallVectorImpl<Type> &inputTypes,
SmallVectorImpl<Type> &outputTypes); SmallVectorImpl<Type> &outputTypes);
template <typename NamedStructuredOpType>
static void printCommonStructuredOpParts(OpAsmPrinter &p,
NamedStructuredOpType op);
/// Specific parsing and printing for named structured ops created by ods-gen.
template <typename NamedStructuredOpType> template <typename NamedStructuredOpType>
static ParseResult static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
TypeRange inputTypes, TypeRange outputTypes, TypeRange inputTypes, TypeRange outputTypes);
ArrayRef<OpAsmParser::OperandType> captures = {});
static ParseResult static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser, parseNamedStructuredOpResults(OpAsmParser &parser,
SmallVectorImpl<Type> &resultTypes); SmallVectorImpl<Type> &resultTypes);
template <typename NamedStructuredOpType> template <typename NamedStructuredOpType>
static ParseResult static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, OperationState &result);
ArrayRef<OpAsmParser::OperandType> captures = {});
template <typename NamedStructuredOpType>
static void printCommonStructuredOpParts(OpAsmPrinter &p,
NamedStructuredOpType op);
static void printNamedStructuredOpResults(OpAsmPrinter &p, static void printNamedStructuredOpResults(OpAsmPrinter &p,
TypeRange resultTypes); TypeRange resultTypes);
@ -122,136 +101,14 @@ static LogicalResult foldMemRefCast(Operation *op) {
return success(folded); return success(folded);
} }
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
void CopyOp::regionBuilder(Block &block, ValueRange captures) {
using namespace edsc::intrinsics;
assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args");
(linalg_yield(block.getArgument(0)));
}
void CopyOp::build(OpBuilder &builder, OperationState &result, Value input,
Value output, AffineMap inputPermutation,
AffineMap outputPermutation,
ArrayRef<NamedAttribute> namedAttrs) {
result.addOperands({input, output});
result.addAttributes(namedAttrs);
if (inputPermutation)
result.addAttribute("inputPermutation",
AffineMapAttr::get(inputPermutation));
if (outputPermutation)
result.addAttribute("outputPermutation",
AffineMapAttr::get(outputPermutation));
result.addRegion();
fillStructuredOpRegion<CopyOp>(builder, *result.regions.front(),
TypeRange{input.getType()},
TypeRange{output.getType()});
}
ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType,
Type outputType) {
OpBuilder opBuilder(parser.getBuilder().getContext());
fillStructuredOpRegion<CopyOp>(opBuilder, r, TypeRange{inputType},
TypeRange{outputType});
return success();
}
/// CopyOp region is elided when printing.
void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
static LogicalResult verify(CopyOp op) {
auto outputViewType = op.getOutputShapedType(0);
auto inputViewType = op.getInputShapedType(0);
if (inputViewType.getElementType() != outputViewType.getElementType())
return op.emitOpError("expects views of the same type");
if (inputViewType.getRank() != outputViewType.getRank())
return op.emitOpError("expects views of the same rank");
auto rank = op.getNumParallelLoops();
auto inputPermutationMap = op.inputPermutation();
if (inputPermutationMap) {
if (inputPermutationMap->getNumInputs() != rank)
return op.emitOpError("expects optional input_permutation map of rank ")
<< rank;
if (!inputPermutationMap->isPermutation())
return op.emitOpError(
"expects optional input_permutation map to be a permutation");
}
auto outputPermutationMap = op.outputPermutation();
if (outputPermutationMap) {
if (outputPermutationMap->getNumInputs() != rank)
return op.emitOpError("expects optional output_permutation map of rank ")
<< rank;
if (!outputPermutationMap->isPermutation())
return op.emitOpError(
"expects optional output_permutation map to be a permutation");
}
if (rank == 0 && inputPermutationMap)
return op.emitOpError("expected no input permutation when rank == 0");
if (rank == 0 && outputPermutationMap)
return op.emitOpError("expected no output permutation when rank == 0");
return success();
}
void CopyOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
effects.emplace_back(MemoryEffects::Read::get(), input(),
SideEffects::DefaultResource::get());
effects.emplace_back(MemoryEffects::Write::get(), output(),
SideEffects::DefaultResource::get());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// FillOp // FillOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void FillOp::regionBuilder(Block &block, ValueRange captures) {
using namespace edsc::intrinsics;
assert(captures.size() == 1 && "FillOp regionBuilder expects 1 capture");
(linalg_yield(captures));
}
void FillOp::build(OpBuilder &builder, OperationState &result, Value output, void FillOp::build(OpBuilder &builder, OperationState &result, Value output,
Value value) { Value value) {
build(builder, result, output.getType().dyn_cast<RankedTensorType>(), output, build(builder, result, output.getType().dyn_cast<RankedTensorType>(), output,
value); value);
fillStructuredOpRegion<FillOp>(builder, *result.regions.front(), TypeRange{},
TypeRange{output.getType()}, value);
}
ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type outputType,
OpAsmParser::OperandType valueRef) {
OpBuilder opBuilder(parser.getBuilder().getContext());
// Resolve `valueRef` into `value` at parse time so we can build the region
// with captures.
SmallVector<Value> value;
parser.resolveOperand(valueRef, getElementTypeOrSelf(outputType), value);
fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{},
TypeRange{outputType}, value);
return success();
}
/// FillOp region is elided when printing.
void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {}
static LogicalResult verify(FillOp op) {
auto viewType = op.getOutputShapedType(0);
auto fillType = op.value().getType();
if (viewType.getElementType() != fillType)
return op.emitOpError("expects fill type to match view elemental type");
if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
return op.emitOpError(
"expected fill op with no result value to use memref type");
}
return success();
}
void FillOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
if (output().getType().isa<MemRefType>())
effects.emplace_back(MemoryEffects::Write::get(), output(),
SideEffects::DefaultResource::get());
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -576,6 +433,7 @@ void InitTensorOp::build(OpBuilder &b, OperationState &result,
result.addAttributes(attrs); result.addAttributes(attrs);
} }
static LogicalResult verify(InitTensorOp op) { static LogicalResult verify(InitTensorOp op) {
RankedTensorType resultType = op.getType(); RankedTensorType resultType = op.getType();
SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range( SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
@ -1556,6 +1414,68 @@ static LogicalResult verify(linalg::YieldOp op) {
/////// Operations corresponding to library calls defined with Tablegen //////// /////// Operations corresponding to library calls defined with Tablegen ////////
void FillOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
if (output().getType().isa<MemRefType>())
effects.emplace_back(MemoryEffects::Write::get(), output(),
SideEffects::DefaultResource::get());
}
static LogicalResult verify(FillOp op) {
auto viewType = op.getOutputShapedType(0);
auto fillType = op.value().getType();
if (viewType.getElementType() != fillType)
return op.emitOpError("expects fill type to match view elemental type");
if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
return op.emitOpError(
"expected fill op with no result value to use memref type");
}
return success();
}
void CopyOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
effects.emplace_back(MemoryEffects::Read::get(), input(),
SideEffects::DefaultResource::get());
effects.emplace_back(MemoryEffects::Write::get(), output(),
SideEffects::DefaultResource::get());
}
static LogicalResult verify(CopyOp op) {
auto outputViewType = op.getOutputShapedType(0);
auto inputViewType = op.getInputShapedType(0);
if (inputViewType.getElementType() != outputViewType.getElementType())
return op.emitOpError("expects views of the same type");
if (inputViewType.getRank() != outputViewType.getRank())
return op.emitOpError("expects views of the same rank");
auto rank = op.getNumParallelLoops();
auto inputPermutationMap = op.inputPermutation();
if (inputPermutationMap) {
if (inputPermutationMap->getNumInputs() != rank)
return op.emitOpError("expects optional input_permutation map of rank ")
<< rank;
if (!inputPermutationMap->isPermutation())
return op.emitOpError(
"expects optional input_permutation map to be a permutation");
}
auto outputPermutationMap = op.outputPermutation();
if (outputPermutationMap) {
if (outputPermutationMap->getNumInputs() != rank)
return op.emitOpError("expects optional output_permutation map of rank ")
<< rank;
if (!outputPermutationMap->isPermutation())
return op.emitOpError(
"expects optional output_permutation map to be a permutation");
}
if (rank == 0 && inputPermutationMap)
return op.emitOpError("expected no input permutation when rank == 0");
if (rank == 0 && outputPermutationMap)
return op.emitOpError("expected no output permutation when rank == 0");
return success();
}
template <typename LinalgPoolingOp> template <typename LinalgPoolingOp>
static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op, static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
ArrayRef<Attribute> attrs, ArrayRef<Attribute> attrs,
@ -1788,25 +1708,14 @@ OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen. // Auto-generated Linalg named ops.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Generic entry point to create the block for the region of a LinalgOp.
/// This is used by both named structured ops created by ods-gen and by manually
/// defined C++ ops.
/// This is used by both builders and parsers.
/// This function creates the block in the region with arguments corresponding
/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
/// to be ShapedType.
template <typename NamedStructuredOpType> template <typename NamedStructuredOpType>
static void static void buildNamedStructuredOpRegionAndAttributesImpl(
fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
TypeRange inputTypes, TypeRange outputTypes, TypeRange outputTypes,
ValueRange captures, std::function<void(unsigned, unsigned)> errorHandler) {
std::function<void(unsigned, unsigned)> errorHandler) {
assert(llvm::all_of(inputTypes, [](Type t) { return t.isa<ShapedType>(); }));
assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
// TODO: atm all operands go through getElementTypeOrSelf, // TODO: atm all operands go through getElementTypeOrSelf,
// reconsider when we have evidence we need to. // reconsider when we have evidence we need to.
SmallVector<Type, 8> argTypes; SmallVector<Type, 8> argTypes;
@ -1816,7 +1725,7 @@ fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
// RAII. // RAII.
OpBuilder::InsertionGuard guard(opBuilder); OpBuilder::InsertionGuard guard(opBuilder);
Block *body = opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes); Block *body = opBuilder.createBlock(&region, {}, argTypes);
unsigned actual = body->getNumArguments(); unsigned actual = body->getNumArguments();
unsigned expected = NamedStructuredOpType::getNumRegionArgs(); unsigned expected = NamedStructuredOpType::getNumRegionArgs();
if (expected != actual) if (expected != actual)
@ -1824,30 +1733,53 @@ fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
opBuilder.setInsertionPointToStart(body); opBuilder.setInsertionPointToStart(body);
mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc()); mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc());
NamedStructuredOpType::regionBuilder(*body, captures); NamedStructuredOpType::regionBuilder(*body);
// indexing_maps is an auto-generated method. // indexing_maps is an auto-generated method.
// iterator_types is an auto-generated method. // iterator_types is an auto-generated method.
} }
/// Generic entry point to create both the region and the block of a LinalgOp.
template <typename NamedStructuredOpType> template <typename NamedStructuredOpType>
void createAndFillStructuredOpRegion(OpBuilder &opBuilder, void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
OperationState &result, OperationState &result,
TypeRange inputTypes, TypeRange inputTypes,
TypeRange outputTypes, TypeRange outputTypes) {
ValueRange captures) {
Region &region = *result.addRegion(); Region &region = *result.addRegion();
fillStructuredOpRegion<NamedStructuredOpType>( buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
opBuilder, region, inputTypes, outputTypes, captures, opBuilder, region, inputTypes, outputTypes,
[&](unsigned expected, unsigned actual) { [&](unsigned expected, unsigned actual) {
llvm::errs() << "region expects " << expected << " args, got "
<< actual;
assert(expected != actual && "incorrect number of arguments"); assert(expected != actual && "incorrect number of arguments");
}); });
} }
/// Common parsing used for both named structured ops created by ods-gen and by template <typename NamedStructuredOpType>
/// manually defined C++ ops. Does not handle regions. static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
TypeRange inputTypes, TypeRange outputTypes) {
ParseResult res = success();
OpBuilder opBuilder(parser.getBuilder().getContext());
buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
opBuilder, region, inputTypes, outputTypes,
[&](unsigned expected, unsigned actual) {
res = parser.emitError(parser.getCurrentLocation(),
llvm::formatv("region expects {0} args, got {1}",
expected, actual));
});
return res;
}
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
SmallVectorImpl<Type> &resultTypes) {
if (succeeded(parser.parseOptionalArrow()))
if (parser.parseTypeList(resultTypes))
return failure();
return success();
}
static ParseResult static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
SmallVectorImpl<Type> &inputTypes, SmallVectorImpl<Type> &inputTypes,
@ -1888,56 +1820,8 @@ parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
} }
template <typename NamedStructuredOpType> template <typename NamedStructuredOpType>
static void printCommonStructuredOpParts(OpAsmPrinter &p, static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
NamedStructuredOpType op) { OperationState &result) {
if (!op.inputs().empty())
p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
if (!op.outputs().empty())
p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
}
//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
TypeRange inputTypes, TypeRange outputTypes,
ArrayRef<OpAsmParser::OperandType> captures) {
ParseResult res = success();
OpBuilder opBuilder(parser.getBuilder().getContext());
// Resolve `captures` into `capturedValues` at parse time so we can build the
// region with captures.
SmallVector<Value> capturedValues;
fillStructuredOpRegion<NamedStructuredOpType>(
opBuilder, region, inputTypes, outputTypes, capturedValues,
[&](unsigned expected, unsigned actual) {
res = parser.emitError(
parser.getCurrentLocation(),
llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
"region expects {0} args, got {1}",
expected, actual));
region.front().dump();
});
return res;
}
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
SmallVectorImpl<Type> &resultTypes) {
if (succeeded(parser.parseOptionalArrow()))
if (parser.parseTypeList(resultTypes))
return failure();
return success();
}
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
ArrayRef<OpAsmParser::OperandType> captures) {
// TODO: Enable when ods-gen supports captures.
assert(captures.empty() && "unexpected captures for named structured ops");
SmallVector<Type, 1> inputTypes, outputTypes; SmallVector<Type, 1> inputTypes, outputTypes;
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
return failure(); return failure();
@ -1951,7 +1835,7 @@ parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
std::unique_ptr<Region> region = std::make_unique<Region>(); std::unique_ptr<Region> region = std::make_unique<Region>();
if (parseNamedStructuredOpRegion<NamedStructuredOpType>( if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
parser, *region, inputTypes, outputTypes, captures)) parser, *region, inputTypes, outputTypes))
return failure(); return failure();
result.addRegion(std::move(region)); result.addRegion(std::move(region));
@ -1965,6 +1849,15 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
p.printOptionalArrowTypeList(resultTypes); p.printOptionalArrowTypeList(resultTypes);
} }
template <typename NamedStructuredOpType>
static void printCommonStructuredOpParts(OpAsmPrinter &p,
NamedStructuredOpType op) {
if (!op.inputs().empty())
p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
if (!op.outputs().empty())
p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
}
template <typename NamedStructuredOpType> template <typename NamedStructuredOpType>
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) { static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
p << op.getOperationName(); p << op.getOperationName();
@ -1986,10 +1879,6 @@ static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) {
return verifyGenericOp<NamedStructuredOpType>(op); return verifyGenericOp<NamedStructuredOpType>(op);
} }
//===----------------------------------------------------------------------===//
// Canonicalizers and Folders.
//===----------------------------------------------------------------------===//
namespace { namespace {
struct EraseDeadLinalgOp : public RewritePattern { struct EraseDeadLinalgOp : public RewritePattern {
EraseDeadLinalgOp(PatternBenefit benefit = 1) EraseDeadLinalgOp(PatternBenefit benefit = 1)

View File

@ -49,7 +49,7 @@ static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp,
indexingMaps, iterators, indexingMaps, iterators,
[&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) { [&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
edsc::ScopedContext scope(bodyBuilder, loc); edsc::ScopedContext scope(bodyBuilder, loc);
regionBuilder(*bodyBuilder.getBlock(), /*captures=*/{}); regionBuilder(*bodyBuilder.getBlock());
}); });
} }

View File

@ -52,6 +52,14 @@ static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b,
return res; return res;
} }
static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs,
Optional<AffineMap> permutation) {
return permutation ? applyMapToValues(ScopedContext::getBuilderRef(),
ScopedContext::getLocation(),
permutation.getValue(), ivs)
: SmallVector<Value, 4>(ivs.begin(), ivs.end());
}
template <typename IndexedValueType, typename OpType> template <typename IndexedValueType, typename OpType>
static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues, static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
ArrayRef<SmallVector<Value, 8>> indexing, ArrayRef<SmallVector<Value, 8>> indexing,
@ -170,6 +178,40 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
outputBuffers); outputBuffers);
} }
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
assert(copyOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto nPar = copyOp.getNumParallelLoops();
assert(nPar == allIvs.size());
auto inputIvs =
permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation());
auto outputIvs =
permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
SmallVector<Value, 8> iivs(inputIvs.begin(), inputIvs.end());
SmallVector<Value, 8> oivs(outputIvs.begin(), outputIvs.end());
IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0));
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
// an n-D loop nest; with or without permutations.
// clang-format off
nPar > 0 ? O(oivs) = I(iivs) :
O() = I();
// clang-format on
}
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
assert(fillOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto nPar = fillOp.getNumParallelLoops();
assert(nPar == allIvs.size());
auto ivs = SmallVector<Value, 4>(allIvs.begin(), allIvs.begin() + nPar);
IndexedValueType O(fillOp.getOutputBuffer(0));
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
// an n-D loop nest; with or without permutations.
nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value();
}
// Create a padded view into the given `input` tensor using the 'indices' // Create a padded view into the given `input` tensor using the 'indices'
// to access the tensor. `skipPadding` lists the dimensions for which no padding // to access the tensor. `skipPadding` lists the dimensions for which no padding
// is needed e.g. the non-spatial dimensions for convolutions. // is needed e.g. the non-spatial dimensions for convolutions.
@ -491,8 +533,8 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
assert(iterArgs.empty() && "unexpected iterArgs"); assert(iterArgs.empty() && "unexpected iterArgs");
allIvs.append(ivs.begin(), ivs.end()); allIvs.append(ivs.begin(), ivs.end());
llvm::TypeSwitch<Operation *>(op) llvm::TypeSwitch<Operation *>(op)
.Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp, .Case<CopyOp, FillOp, ConvOp, PoolingMaxOp, PoolingMinOp,
IndexedGenericOp, LinalgOp>([&](auto op) { PoolingSumOp, IndexedGenericOp, LinalgOp>([&](auto op) {
emitScalarImplementation<IndexedValueTy>(allIvs, op); emitScalarImplementation<IndexedValueTy>(allIvs, op);
}) })
.Default([&](Operation *op) { assert(false && "unexpected op"); }); .Default([&](Operation *op) { assert(false && "unexpected op"); });

View File

@ -267,7 +267,7 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
llvm::map_range(linalgOp.getShapedOperandTypes(), llvm::map_range(linalgOp.getShapedOperandTypes(),
[](ShapedType t) { return t.getElementType(); })); [](ShapedType t) { return t.getElementType(); }));
block->addArguments(elementTypes); block->addArguments(elementTypes);
linalgOp.getRegionBuilder()(*block, /*captures=*/{}); linalgOp.getRegionBuilder()(*block);
} }
Block *block = &region->front(); Block *block = &region->front();
@ -333,26 +333,24 @@ static bool hasOnlyScalarElementwiseOp(Region &r) {
// Return true if the op is an element-wise linalg op. // Return true if the op is an element-wise linalg op.
static bool isElementwise(Operation *op) { static bool isElementwise(Operation *op) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op); auto genericOp = dyn_cast<linalg::GenericOp>(op);
if (!linalgOp) if (!genericOp)
return false; return false;
if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops()) if (genericOp.getNumLoops() != genericOp.getNumParallelLoops())
return false; return false;
// TODO: relax the restrictions on indexing map. // TODO: relax the restrictions on indexing map.
for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) { for (unsigned i = 0, e = genericOp.getNumOutputs(); i < e; i++) {
if (!linalgOp.getOutputIndexingMap(i).isIdentity()) if (!genericOp.getOutputIndexingMap(i).isIdentity())
return false; return false;
} }
// Currently bound the input indexing map to minor identity as other // Currently bound the input indexing map to minor identity as other
// permutations might require adding transpose ops to convert the vector read // permutations might require adding transpose ops to convert the vector read
// to the right shape. // to the right shape.
for (unsigned i = 0, e = linalgOp.getNumInputs(); i < e; i++) { for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) {
if (!linalgOp.getInputIndexingMap(i).isMinorIdentity()) if (!genericOp.getInputIndexingMap(i).isMinorIdentity())
return false; return false;
} }
if (linalgOp->getNumRegions() != 1) return hasOnlyScalarElementwiseOp(genericOp.getRegion());
return false;
return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
} }
static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder, static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
@ -395,6 +393,9 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
for (Type outputTensorType : linalgOp.getOutputTensorTypes()) for (Type outputTensorType : linalgOp.getOutputTensorTypes())
if (!outputTensorType.cast<ShapedType>().hasStaticShape()) if (!outputTensorType.cast<ShapedType>().hasStaticShape())
return failure(); return failure();
if (isa<linalg::FillOp, linalg::CopyOp>(op))
return success();
if (isElementwise(op)) if (isElementwise(op))
return success(); return success();
return success(isaContractionOpInterface(linalgOp)); return success(isaContractionOpInterface(linalgOp));
@ -406,12 +407,43 @@ Optional<VectorizedLinalgOp> mlir::linalg::vectorizeLinalgOp(OpBuilder &builder,
return llvm::None; return llvm::None;
edsc::ScopedContext scope(builder, op->getLoc()); edsc::ScopedContext scope(builder, op->getLoc());
// In the case of 0-D memrefs, return null and special case to scalar load or
// store later.
if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
// Vectorize fill as a vector.broadcast.
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
<< "Rewrite linalg.fill as vector.broadcast: " << *op);
VectorizedLinalgOp res;
if (Value v = buildVectorWrite(builder, fillOp.value(), fillOp.output()))
res.tensorResults.push_back(v);
return res;
}
if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
// Vectorize copy as a vector.transfer_read+vector.transfer_write.
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
<< "Rewrite linalg.copy as vector.transfer_read + "
"vector.transfer_write: "
<< *op);
Value vector = buildVectorRead(builder, copyOp.input());
VectorizedLinalgOp res;
if (Value v = buildVectorWrite(builder, vector, copyOp.output()))
res.tensorResults.push_back(v);
return res;
}
if (isElementwise(op)) { if (isElementwise(op)) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
<< "Vectorize linalg op as a generic: " << *op); << "Vectorize linalg op as a generic: " << *op);
return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op)); return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
} }
// TODO: as soon as Copy and FillOp. get a region builder, replace all the
// above by:
// if (isa<FillOp, CopyOp>(op) || isElementwise(op)) {
// LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
// << "Vectorize linalg op as a generic: " << *op);
// return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
// }
return vectorizeContraction(builder, cast<LinalgOp>(op)); return vectorizeContraction(builder, cast<LinalgOp>(op));
} }

View File

@ -1,4 +1,5 @@
// RUN: mlir-opt -copy-removal -split-input-file %s | FileCheck %s // RUN: mlir-opt -copy-removal -split-input-file %s
//| FileCheck %s
// All linalg copies except the linalg.copy(%1, %9) must be removed since the // All linalg copies except the linalg.copy(%1, %9) must be removed since the
// defining operation of %1 and its DeallocOp have been defined in another block. // defining operation of %1 and its DeallocOp have been defined in another block.
@ -255,7 +256,7 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>, %result: memref<2xf32>)
%tmp2 = math.exp %gen2_arg0 : f32 %tmp2 = math.exp %gen2_arg0 : f32
linalg.yield %tmp2 : f32 linalg.yield %tmp2 : f32
} }
linalg.copy(%temp, %result) : memref<2xf32>, memref<2xf32> "linalg.copy"(%temp, %result) : (memref<2xf32>, memref<2xf32>) -> ()
dealloc %temp : memref<2xf32> dealloc %temp : memref<2xf32>
// CHECK: return // CHECK: return
return return
@ -291,7 +292,7 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>){
linalg.yield %tmp2 : f32 linalg.yield %tmp2 : f32
} }
// CHECK: linalg.copy // CHECK: linalg.copy
linalg.copy(%temp, %to) : memref<2xf32>, memref<2xf32> "linalg.copy"(%temp, %to) : (memref<2xf32>, memref<2xf32>) -> ()
dealloc %temp : memref<2xf32> dealloc %temp : memref<2xf32>
return return
} }
@ -354,7 +355,7 @@ func @check_with_affine_dialect(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg
} }
// CHECK-NOT: linalg.copy // CHECK-NOT: linalg.copy
// CHECK-NOT: dealloc // CHECK-NOT: dealloc
linalg.copy(%0, %arg2) : memref<4xf32>, memref<4xf32> "linalg.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
dealloc %0 : memref<4xf32> dealloc %0 : memref<4xf32>
//CHECK: return //CHECK: return
return return

View File

@ -23,7 +23,7 @@
// IMPL-NEXT: map2 = simplifyAffineMap(map2); // IMPL-NEXT: map2 = simplifyAffineMap(map2);
// IMPL-NEXT: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 }); // IMPL-NEXT: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 });
// //
// IMPL: void Test1Op::regionBuilder(Block &block, ValueRange captures) { // IMPL: void Test1Op::regionBuilder(Block &block) {
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); // IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]); // IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);
@ -47,7 +47,7 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
// IMPL: AffineMap::get(3, 3, {d2, d1}, context) // IMPL: AffineMap::get(3, 3, {d2, d1}, context)
// IMPL: AffineMap::get(3, 3, {d0, d1}, context) // IMPL: AffineMap::get(3, 3, {d0, d1}, context)
// //
// IMPL: Test2Op::regionBuilder(Block &block, ValueRange captures) { // IMPL: Test2Op::regionBuilder(Block &block) {
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); // IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]); // IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);
@ -71,7 +71,7 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
// IMPL: AffineMap::get(4, 4, {d3, d2}, context) // IMPL: AffineMap::get(4, 4, {d3, d2}, context)
// IMPL: AffineMap::get(4, 4, {d0, d1, d2}, context) // IMPL: AffineMap::get(4, 4, {d0, d1, d2}, context)
// //
// IMPL: Test3Op::regionBuilder(Block &block, ValueRange captures) { // IMPL: Test3Op::regionBuilder(Block &block) {
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); // IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]); // IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);

View File

@ -1871,11 +1871,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
$_builder.getI32VectorAttr({{ $_builder.getI32VectorAttr({{
static_cast<int32_t>(inputs.size()), static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())})); static_cast<int32_t>(outputs.size())}));
createAndFillStructuredOpRegion<{0}>( buildNamedStructuredOpRegionAndAttributes<{0}>(
$_builder, $_builder,
$_state, $_state,
TypeRange(inputs), TypeRange(inputs),
TypeRange(outputs)/*, TODO: support captures*/); TypeRange(outputs));
}]>, }]>,
OpBuilderDAG< OpBuilderDAG<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
@ -1889,11 +1889,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
$_builder.getI32VectorAttr({{ $_builder.getI32VectorAttr({{
static_cast<int32_t>(inputs.size()), static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())})); static_cast<int32_t>(outputs.size())}));
createAndFillStructuredOpRegion<{0}>( buildNamedStructuredOpRegionAndAttributes<{0}>(
$_builder, $_builder,
$_state, $_state,
TypeRange(inputs), TypeRange(inputs),
TypeRange(outputs)/*, TODO: support captures*/); TypeRange(outputs));
}]>, }]>,
OpBuilderDAG< OpBuilderDAG<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
@ -1907,9 +1907,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
{6} {6}
]; ];
let printer = [{{ return ::printNamedStructuredOp(p, *this); }]; let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
let parser = [{{ let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }];
return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
}];
let hasFolder = 1; let hasFolder = 1;
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
@ -1917,8 +1915,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
// Auto-generated. // Auto-generated.
ArrayAttr iterator_types(); ArrayAttr iterator_types();
ArrayAttr indexing_maps(); ArrayAttr indexing_maps();
static void regionBuilder(Block &block, ValueRange captures); static void regionBuilder(Block &block);
static std::function<void(Block &, ValueRange)> getRegionBuilder() {{ static std::function<void(Block &)> getRegionBuilder() {{
return regionBuilder; return regionBuilder;
} }
@ -1982,11 +1980,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
$_builder.getI32VectorAttr({{ $_builder.getI32VectorAttr({{
static_cast<int32_t>(inputs.size()), static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())})); static_cast<int32_t>(outputs.size())}));
createAndFillStructuredOpRegion<{0}>( buildNamedStructuredOpRegionAndAttributes<{0}>(
$_builder, $_builder,
$_state, $_state,
TypeRange(inputs), TypeRange(inputs),
TypeRange(outputs)/*, TODO: support captures*/); TypeRange(outputs));
{2} {2}
}]> }]>
)FMT"; )FMT";
@ -2313,7 +2311,7 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
}; };
const char *regionBuilderFmt = R"FMT( const char *regionBuilderFmt = R"FMT(
void {0}::regionBuilder(Block &block, ValueRange captures) { void {0}::regionBuilder(Block &block) {
using namespace edsc; using namespace edsc;
using namespace intrinsics; using namespace intrinsics;
auto args = block.getArguments(); auto args = block.getArguments();