diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 95656ebd9983..c26f02208215 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -1056,6 +1056,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { //===------------------------------------------------------------------===// // Other static interface methods. //===------------------------------------------------------------------===// + StaticInterfaceMethod< + /*desc=*/[{ + Create an operation of the current type with the given location, + operands, and attributes. + }], + /*retTy=*/"Operation *", + /*methodName=*/"create", + (ins "OpBuilder &":$builder, "Location":$loc, "TypeRange":$resultTypes, + "ValueRange":$operands, + "ArrayRef":$attributes), [{ + return builder.create( + loc, resultTypes, operands, attributes); + }] + >, InterfaceMethod< /*desc=*/[{ Clone the current operation with the given location and operands. This @@ -1068,13 +1082,14 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, "ValueRange":$operands), [{ - BlockAndValueMapping bvm; - OperationState state( - loc, ConcreteOp::getOperationName(), operands, resultTypes, - $_op->getAttrs()); - for (Region &r : $_op->getRegions()) - r.cloneInto(state.addRegion(), bvm); - return b.createOperation(state); + BlockAndValueMapping map; + unsigned numRegions = $_op->getNumRegions(); + Operation *res = create(b, loc, resultTypes, operands, $_op->getAttrs()); + assert(res->getNumRegions() == numRegions && "inconsistent # regions"); + for (unsigned ridx = 0; ridx < numRegions; ++ridx) + $_op->getRegion(ridx).cloneInto( + &res->getRegion(ridx), map); + return res; }] >, StaticInterfaceMethod< @@ -1083,7 +1098,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { Returns a null function if this named op does not define a region builder. }], - /*retTy=*/"std::function", + /*retTy=*/"std::function", /*methodName=*/"getRegionBuilder", (ins), [{ return ConcreteOp::getRegionBuilder(); }] diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 05a6bb766dd0..8988a3a11efd 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -110,13 +110,14 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> { AnyStridedMemRef:$output, OptionalAttr:$inputPermutation, OptionalAttr:$outputPermutation); - let regions = (region AnyRegion:$region); - let builders = [ - OpBuilderDAG<(ins "Value":$input, "Value":$output, - CArg<"AffineMap", "AffineMap()">:$inputPermutation, - CArg<"AffineMap", "AffineMap()">:$outputPermutation, - CArg<"ArrayRef", "{}">:$attrs)>]; + // TODO: this should go away once the usage of OptionalAttr triggers emission + // of builders with default arguments left unspecified. + let builders = [OpBuilderDAG<(ins "Value":$input, "Value":$output), + [{ + return build( + $_builder, $_state, input, output, AffineMapAttr(), AffineMapAttr()); + }]>]; let extraClassDeclaration = structuredOpsDecls # [{ ValueRange inputs() { return getOperands().take_front(); } @@ -145,31 +146,24 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> { Value getSource() { return input();} Value getTarget() { return output(); } - static void regionBuilder(Block &block, ValueRange captures); - static std::function - getRegionBuilder() { - return ®ionBuilder; + static std::function getRegionBuilder() { + return nullptr; } - static unsigned getNumRegionArgs() { return 2; } }]; let verifier = [{ return ::verify(*this); }]; let assemblyFormat = [{ - `(` $input `,` $output `)` attr-dict `:` - type($input) `,` type($output) - custom($region, ref(type($input)), ref(type($input))) + `(` operands `)` attr-dict `:` type(operands) }]; let hasFolder = 1; let hasCanonicalizer = 1; - let skipDefaultBuilders = 1; } def FillOp : LinalgStructured_Op<"fill", []> { let arguments = (ins AnyShaped:$output, AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value); let results = (outs Optional:$result); - let regions = (region AnyRegion:$region); let extraClassDeclaration = structuredOpsDecls # [{ ValueRange inputs() { return {}; } ValueRange outputs() { return getOperands().take_front(); } @@ -189,18 +183,13 @@ def FillOp : LinalgStructured_Op<"fill", []> { extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)}); } - static void regionBuilder(Block &block, ValueRange captures); - static std::function - getRegionBuilder() { - return ®ionBuilder; + static std::function getRegionBuilder() { + return nullptr; } - static unsigned getNumRegionArgs() { return 1; } }]; let assemblyFormat = [{ - `(` $output `,` $value `)` attr-dict `:` - type($output) `,` type($value) (`->` type($result)^)? - custom($region, ref(type($output)), ref($value)) + `(` operands `)` attr-dict `:` type(operands) (`->` type($result)^)? }]; let builders = [ @@ -279,8 +268,7 @@ class PoolingBase_Op props> return padding().getValue().getValue({i, 1}); } - static std::function getRegionBuilder() - { + static std::function getRegionBuilder() { return nullptr; } }]; @@ -531,7 +519,7 @@ class GenericOpBase : LinalgStructuredBase_Opstr() : "op_has_no_registered_library_name"; } - static std::function getRegionBuilder() { + static std::function getRegionBuilder() { return nullptr; } }]; diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp index 276b124a9f10..8b53ecb74075 100644 --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -154,13 +154,7 @@ LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite( if (in == op.input() && out == op.output()) return failure(); - auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); - if (!libraryCallName) - return failure(); - - rewriter.replaceOpWithNewOp( - op, libraryCallName.getValue(), TypeRange(), - createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), {in, out})); + rewriter.replaceOpWithNewOp(op, in, out); return success(); } diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp index 46e42e26c5c8..8bb104df5bd6 100644 --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -27,6 +27,8 @@ Operation *mlir::edsc::makeGenericLinalgOp( ArrayRef outputs, TypeRange resultTensorTypes, function_ref regionBuilder, ArrayRef otherValues, ArrayRef otherAttributes) { + OpBuilder &builder = edsc::ScopedContext::getBuilderRef(); + // Build maps SmallVector, 4> exprsList; exprsList.reserve(inputs.size() + outputs.size()); @@ -52,10 +54,13 @@ Operation *mlir::edsc::makeGenericLinalgOp( resultTensorTypes, inputValues, outputValues, - maps, - iteratorStrTypes, - ""/*doc*/, - ""/*library_call*/) + builder.getAffineMapArrayAttr(maps), + builder.getStrArrayAttr(iteratorStrTypes), + StringAttr() /*doc*/, + StringAttr() /*library_call*/, + ArrayAttr() /*sparse*/ + /* TODO: other attributes in op */ + ) .getOperation(); // clang-format on diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 42a7900f23bb..3cc4a785bdae 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -33,53 +33,32 @@ using namespace mlir; using namespace mlir::linalg; /// Forward declarations. - -/// Generic entry point to create the block for the region of a LinalgOp. -/// This is used by both named structured ops created by ods-gen and by manually -/// defined C++ ops. -/// This is used by both builders and parsers. -/// This function creates the block in the region with arguments corresponding -/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted -/// to be ShapedType. template -static void fillStructuredOpRegion( - OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, - TypeRange outputTypes, ValueRange captures = {}, - std::function errorHandler = [](unsigned, - unsigned) {}); +static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder, + OperationState &result, + TypeRange inputTypes, + TypeRange outputTypes); -/// Generic entry point to create both the region and the block of a LinalgOp. -template -static void -createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result, - TypeRange inputTypes, TypeRange outputTypes, - ValueRange captures = {}); - -/// Common parsing and printing used for both named structured ops created by -/// ods-gen and by manually defined C++ ops. Does not handle regions. static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl &inputTypes, SmallVectorImpl &outputTypes); -template -static void printCommonStructuredOpParts(OpAsmPrinter &p, - NamedStructuredOpType op); -/// Specific parsing and printing for named structured ops created by ods-gen. template static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, - TypeRange inputTypes, TypeRange outputTypes, - ArrayRef captures = {}); - + TypeRange inputTypes, TypeRange outputTypes); static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl &resultTypes); template -static ParseResult -parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, - ArrayRef captures = {}); +static ParseResult parseNamedStructuredOp(OpAsmParser &parser, + OperationState &result); + +template +static void printCommonStructuredOpParts(OpAsmPrinter &p, + NamedStructuredOpType op); static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes); @@ -122,136 +101,14 @@ static LogicalResult foldMemRefCast(Operation *op) { return success(folded); } -//===----------------------------------------------------------------------===// -// CopyOp -//===----------------------------------------------------------------------===// -void CopyOp::regionBuilder(Block &block, ValueRange captures) { - using namespace edsc::intrinsics; - assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args"); - (linalg_yield(block.getArgument(0))); -} - -void CopyOp::build(OpBuilder &builder, OperationState &result, Value input, - Value output, AffineMap inputPermutation, - AffineMap outputPermutation, - ArrayRef namedAttrs) { - result.addOperands({input, output}); - result.addAttributes(namedAttrs); - if (inputPermutation) - result.addAttribute("inputPermutation", - AffineMapAttr::get(inputPermutation)); - if (outputPermutation) - result.addAttribute("outputPermutation", - AffineMapAttr::get(outputPermutation)); - result.addRegion(); - fillStructuredOpRegion(builder, *result.regions.front(), - TypeRange{input.getType()}, - TypeRange{output.getType()}); -} - -ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType, - Type outputType) { - OpBuilder opBuilder(parser.getBuilder().getContext()); - fillStructuredOpRegion(opBuilder, r, TypeRange{inputType}, - TypeRange{outputType}); - return success(); -} - -/// CopyOp region is elided when printing. -void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {} - -static LogicalResult verify(CopyOp op) { - auto outputViewType = op.getOutputShapedType(0); - auto inputViewType = op.getInputShapedType(0); - if (inputViewType.getElementType() != outputViewType.getElementType()) - return op.emitOpError("expects views of the same type"); - if (inputViewType.getRank() != outputViewType.getRank()) - return op.emitOpError("expects views of the same rank"); - auto rank = op.getNumParallelLoops(); - auto inputPermutationMap = op.inputPermutation(); - if (inputPermutationMap) { - if (inputPermutationMap->getNumInputs() != rank) - return op.emitOpError("expects optional input_permutation map of rank ") - << rank; - if (!inputPermutationMap->isPermutation()) - return op.emitOpError( - "expects optional input_permutation map to be a permutation"); - } - auto outputPermutationMap = op.outputPermutation(); - if (outputPermutationMap) { - if (outputPermutationMap->getNumInputs() != rank) - return op.emitOpError("expects optional output_permutation map of rank ") - << rank; - if (!outputPermutationMap->isPermutation()) - return op.emitOpError( - "expects optional output_permutation map to be a permutation"); - } - if (rank == 0 && inputPermutationMap) - return op.emitOpError("expected no input permutation when rank == 0"); - if (rank == 0 && outputPermutationMap) - return op.emitOpError("expected no output permutation when rank == 0"); - return success(); -} - -void CopyOp::getEffects( - SmallVectorImpl> - &effects) { - effects.emplace_back(MemoryEffects::Read::get(), input(), - SideEffects::DefaultResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), output(), - SideEffects::DefaultResource::get()); -} - //===----------------------------------------------------------------------===// // FillOp //===----------------------------------------------------------------------===// -void FillOp::regionBuilder(Block &block, ValueRange captures) { - using namespace edsc::intrinsics; - assert(captures.size() == 1 && "FillOp regionBuilder expects 1 capture"); - (linalg_yield(captures)); -} void FillOp::build(OpBuilder &builder, OperationState &result, Value output, Value value) { build(builder, result, output.getType().dyn_cast(), output, value); - fillStructuredOpRegion(builder, *result.regions.front(), TypeRange{}, - TypeRange{output.getType()}, value); -} - -ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type outputType, - OpAsmParser::OperandType valueRef) { - OpBuilder opBuilder(parser.getBuilder().getContext()); - // Resolve `valueRef` into `value` at parse time so we can build the region - // with captures. - SmallVector value; - parser.resolveOperand(valueRef, getElementTypeOrSelf(outputType), value); - fillStructuredOpRegion(opBuilder, r, TypeRange{}, - TypeRange{outputType}, value); - return success(); -} - -/// FillOp region is elided when printing. -void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {} - -static LogicalResult verify(FillOp op) { - auto viewType = op.getOutputShapedType(0); - auto fillType = op.value().getType(); - if (viewType.getElementType() != fillType) - return op.emitOpError("expects fill type to match view elemental type"); - if (!op.getNumResults() && !viewType.isa()) { - return op.emitOpError( - "expected fill op with no result value to use memref type"); - } - return success(); -} - -void FillOp::getEffects( - SmallVectorImpl> - &effects) { - if (output().getType().isa()) - effects.emplace_back(MemoryEffects::Write::get(), output(), - SideEffects::DefaultResource::get()); } //===----------------------------------------------------------------------===// @@ -576,6 +433,7 @@ void InitTensorOp::build(OpBuilder &b, OperationState &result, result.addAttributes(attrs); } + static LogicalResult verify(InitTensorOp op) { RankedTensorType resultType = op.getType(); SmallVector staticSizes = llvm::to_vector<4>(llvm::map_range( @@ -1556,6 +1414,68 @@ static LogicalResult verify(linalg::YieldOp op) { /////// Operations corresponding to library calls defined with Tablegen //////// +void FillOp::getEffects( + SmallVectorImpl> + &effects) { + if (output().getType().isa()) + effects.emplace_back(MemoryEffects::Write::get(), output(), + SideEffects::DefaultResource::get()); +} + +static LogicalResult verify(FillOp op) { + auto viewType = op.getOutputShapedType(0); + auto fillType = op.value().getType(); + if (viewType.getElementType() != fillType) + return op.emitOpError("expects fill type to match view elemental type"); + if (!op.getNumResults() && !viewType.isa()) { + return op.emitOpError( + "expected fill op with no result value to use memref type"); + } + return success(); +} + +void CopyOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), input(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), output(), + SideEffects::DefaultResource::get()); +} + +static LogicalResult verify(CopyOp op) { + auto outputViewType = op.getOutputShapedType(0); + auto inputViewType = op.getInputShapedType(0); + if (inputViewType.getElementType() != outputViewType.getElementType()) + return op.emitOpError("expects views of the same type"); + if (inputViewType.getRank() != outputViewType.getRank()) + return op.emitOpError("expects views of the same rank"); + auto rank = op.getNumParallelLoops(); + auto inputPermutationMap = op.inputPermutation(); + if (inputPermutationMap) { + if (inputPermutationMap->getNumInputs() != rank) + return op.emitOpError("expects optional input_permutation map of rank ") + << rank; + if (!inputPermutationMap->isPermutation()) + return op.emitOpError( + "expects optional input_permutation map to be a permutation"); + } + auto outputPermutationMap = op.outputPermutation(); + if (outputPermutationMap) { + if (outputPermutationMap->getNumInputs() != rank) + return op.emitOpError("expects optional output_permutation map of rank ") + << rank; + if (!outputPermutationMap->isPermutation()) + return op.emitOpError( + "expects optional output_permutation map to be a permutation"); + } + if (rank == 0 && inputPermutationMap) + return op.emitOpError("expected no input permutation when rank == 0"); + if (rank == 0 && outputPermutationMap) + return op.emitOpError("expected no output permutation when rank == 0"); + return success(); +} + template static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op, ArrayRef attrs, @@ -1788,25 +1708,14 @@ OpFoldResult TensorReshapeOp::fold(ArrayRef operands) { } //===----------------------------------------------------------------------===// -// Support for named Linalg ops defined in ods-gen. +// Auto-generated Linalg named ops. //===----------------------------------------------------------------------===// -/// Generic entry point to create the block for the region of a LinalgOp. -/// This is used by both named structured ops created by ods-gen and by manually -/// defined C++ ops. -/// This is used by both builders and parsers. -/// This function creates the block in the region with arguments corresponding -/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted -/// to be ShapedType. template -static void -fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, - TypeRange inputTypes, TypeRange outputTypes, - ValueRange captures, - std::function errorHandler) { - assert(llvm::all_of(inputTypes, [](Type t) { return t.isa(); })); - assert(llvm::all_of(outputTypes, [](Type t) { return t.isa(); })); - +static void buildNamedStructuredOpRegionAndAttributesImpl( + OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, + TypeRange outputTypes, + std::function errorHandler) { // TODO: atm all operands go through getElementTypeOrSelf, // reconsider when we have evidence we need to. SmallVector argTypes; @@ -1816,7 +1725,7 @@ fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, // RAII. OpBuilder::InsertionGuard guard(opBuilder); - Block *body = opBuilder.createBlock(®ion, /*insertPt=*/{}, argTypes); + Block *body = opBuilder.createBlock(®ion, {}, argTypes); unsigned actual = body->getNumArguments(); unsigned expected = NamedStructuredOpType::getNumRegionArgs(); if (expected != actual) @@ -1824,30 +1733,53 @@ fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, opBuilder.setInsertionPointToStart(body); mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc()); - NamedStructuredOpType::regionBuilder(*body, captures); + NamedStructuredOpType::regionBuilder(*body); // indexing_maps is an auto-generated method. // iterator_types is an auto-generated method. } -/// Generic entry point to create both the region and the block of a LinalgOp. template -void createAndFillStructuredOpRegion(OpBuilder &opBuilder, - OperationState &result, - TypeRange inputTypes, - TypeRange outputTypes, - ValueRange captures) { +void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder, + OperationState &result, + TypeRange inputTypes, + TypeRange outputTypes) { Region ®ion = *result.addRegion(); - fillStructuredOpRegion( - opBuilder, region, inputTypes, outputTypes, captures, + buildNamedStructuredOpRegionAndAttributesImpl( + opBuilder, region, inputTypes, outputTypes, [&](unsigned expected, unsigned actual) { + llvm::errs() << "region expects " << expected << " args, got " + << actual; assert(expected != actual && "incorrect number of arguments"); }); } -/// Common parsing used for both named structured ops created by ods-gen and by -/// manually defined C++ ops. Does not handle regions. +template +static ParseResult +parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, + TypeRange inputTypes, TypeRange outputTypes) { + ParseResult res = success(); + OpBuilder opBuilder(parser.getBuilder().getContext()); + buildNamedStructuredOpRegionAndAttributesImpl( + opBuilder, region, inputTypes, outputTypes, + [&](unsigned expected, unsigned actual) { + res = parser.emitError(parser.getCurrentLocation(), + llvm::formatv("region expects {0} args, got {1}", + expected, actual)); + }); + return res; +} + +static ParseResult +parseNamedStructuredOpResults(OpAsmParser &parser, + SmallVectorImpl &resultTypes) { + if (succeeded(parser.parseOptionalArrow())) + if (parser.parseTypeList(resultTypes)) + return failure(); + return success(); +} + static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl &inputTypes, @@ -1888,56 +1820,8 @@ parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, } template -static void printCommonStructuredOpParts(OpAsmPrinter &p, - NamedStructuredOpType op) { - if (!op.inputs().empty()) - p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")"; - if (!op.outputs().empty()) - p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")"; -} - -//===----------------------------------------------------------------------===// -// Specific parsing and printing for named structured ops created by ods-gen. -//===----------------------------------------------------------------------===// - -template -static ParseResult -parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, - TypeRange inputTypes, TypeRange outputTypes, - ArrayRef captures) { - ParseResult res = success(); - OpBuilder opBuilder(parser.getBuilder().getContext()); - // Resolve `captures` into `capturedValues` at parse time so we can build the - // region with captures. - SmallVector capturedValues; - fillStructuredOpRegion( - opBuilder, region, inputTypes, outputTypes, capturedValues, - [&](unsigned expected, unsigned actual) { - res = parser.emitError( - parser.getCurrentLocation(), - llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated " - "region expects {0} args, got {1}", - expected, actual)); - region.front().dump(); - }); - return res; -} - -static ParseResult -parseNamedStructuredOpResults(OpAsmParser &parser, - SmallVectorImpl &resultTypes) { - if (succeeded(parser.parseOptionalArrow())) - if (parser.parseTypeList(resultTypes)) - return failure(); - return success(); -} - -template -static ParseResult -parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, - ArrayRef captures) { - // TODO: Enable when ods-gen supports captures. - assert(captures.empty() && "unexpected captures for named structured ops"); +static ParseResult parseNamedStructuredOp(OpAsmParser &parser, + OperationState &result) { SmallVector inputTypes, outputTypes; if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) return failure(); @@ -1951,7 +1835,7 @@ parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, std::unique_ptr region = std::make_unique(); if (parseNamedStructuredOpRegion( - parser, *region, inputTypes, outputTypes, captures)) + parser, *region, inputTypes, outputTypes)) return failure(); result.addRegion(std::move(region)); @@ -1965,6 +1849,15 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p, p.printOptionalArrowTypeList(resultTypes); } +template +static void printCommonStructuredOpParts(OpAsmPrinter &p, + NamedStructuredOpType op) { + if (!op.inputs().empty()) + p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")"; + if (!op.outputs().empty()) + p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")"; +} + template static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) { p << op.getOperationName(); @@ -1986,10 +1879,6 @@ static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) { return verifyGenericOp(op); } -//===----------------------------------------------------------------------===// -// Canonicalizers and Folders. -//===----------------------------------------------------------------------===// - namespace { struct EraseDeadLinalgOp : public RewritePattern { EraseDeadLinalgOp(PatternBenefit benefit = 1) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp index 69de55c00cb7..0be1c55c1ea7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -49,7 +49,7 @@ static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp, indexingMaps, iterators, [®ionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) { edsc::ScopedContext scope(bodyBuilder, loc); - regionBuilder(*bodyBuilder.getBlock(), /*captures=*/{}); + regionBuilder(*bodyBuilder.getBlock()); }); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index d09d3e0b5edd..391562b032a7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -52,6 +52,14 @@ static SmallVector makeCanonicalAffineApplies(OpBuilder &b, return res; } +static SmallVector permuteIvs(ArrayRef ivs, + Optional permutation) { + return permutation ? applyMapToValues(ScopedContext::getBuilderRef(), + ScopedContext::getLocation(), + permutation.getValue(), ivs) + : SmallVector(ivs.begin(), ivs.end()); +} + template static void inlineRegionAndEmitStore(OpType op, ArrayRef indexedValues, ArrayRef> indexing, @@ -170,6 +178,40 @@ static void emitScalarImplementation(ArrayRef allIvs, outputBuffers); } +template +static void emitScalarImplementation(ArrayRef allIvs, CopyOp copyOp) { + assert(copyOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + auto nPar = copyOp.getNumParallelLoops(); + assert(nPar == allIvs.size()); + auto inputIvs = + permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation()); + auto outputIvs = + permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation()); + SmallVector iivs(inputIvs.begin(), inputIvs.end()); + SmallVector oivs(outputIvs.begin(), outputIvs.end()); + IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0)); + // Emit the proper scalar assignment, whether we are dealing with a 0-D or + // an n-D loop nest; with or without permutations. + // clang-format off + nPar > 0 ? O(oivs) = I(iivs) : + O() = I(); + // clang-format on +} + +template +static void emitScalarImplementation(ArrayRef allIvs, FillOp fillOp) { + assert(fillOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + auto nPar = fillOp.getNumParallelLoops(); + assert(nPar == allIvs.size()); + auto ivs = SmallVector(allIvs.begin(), allIvs.begin() + nPar); + IndexedValueType O(fillOp.getOutputBuffer(0)); + // Emit the proper scalar assignment, whether we are dealing with a 0-D or + // an n-D loop nest; with or without permutations. + nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value(); +} + // Create a padded view into the given `input` tensor using the 'indices' // to access the tensor. `skipPadding` lists the dimensions for which no padding // is needed e.g. the non-spatial dimensions for convolutions. @@ -491,8 +533,8 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder, assert(iterArgs.empty() && "unexpected iterArgs"); allIvs.append(ivs.begin(), ivs.end()); llvm::TypeSwitch(op) - .Case([&](auto op) { + .Case([&](auto op) { emitScalarImplementation(allIvs, op); }) .Default([&](Operation *op) { assert(false && "unexpected op"); }); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index bfd288464c68..49d323aebe92 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -267,7 +267,7 @@ static Optional vectorizeAsLinalgGeneric( llvm::map_range(linalgOp.getShapedOperandTypes(), [](ShapedType t) { return t.getElementType(); })); block->addArguments(elementTypes); - linalgOp.getRegionBuilder()(*block, /*captures=*/{}); + linalgOp.getRegionBuilder()(*block); } Block *block = ®ion->front(); @@ -333,26 +333,24 @@ static bool hasOnlyScalarElementwiseOp(Region &r) { // Return true if the op is an element-wise linalg op. static bool isElementwise(Operation *op) { - auto linalgOp = dyn_cast(op); - if (!linalgOp) + auto genericOp = dyn_cast(op); + if (!genericOp) return false; - if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops()) + if (genericOp.getNumLoops() != genericOp.getNumParallelLoops()) return false; // TODO: relax the restrictions on indexing map. - for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) { - if (!linalgOp.getOutputIndexingMap(i).isIdentity()) + for (unsigned i = 0, e = genericOp.getNumOutputs(); i < e; i++) { + if (!genericOp.getOutputIndexingMap(i).isIdentity()) return false; } // Currently bound the input indexing map to minor identity as other // permutations might require adding transpose ops to convert the vector read // to the right shape. - for (unsigned i = 0, e = linalgOp.getNumInputs(); i < e; i++) { - if (!linalgOp.getInputIndexingMap(i).isMinorIdentity()) + for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) { + if (!genericOp.getInputIndexingMap(i).isMinorIdentity()) return false; } - if (linalgOp->getNumRegions() != 1) - return false; - return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0)); + return hasOnlyScalarElementwiseOp(genericOp.getRegion()); } static Optional vectorizeContraction(OpBuilder &builder, @@ -395,6 +393,9 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { for (Type outputTensorType : linalgOp.getOutputTensorTypes()) if (!outputTensorType.cast().hasStaticShape()) return failure(); + + if (isa(op)) + return success(); if (isElementwise(op)) return success(); return success(isaContractionOpInterface(linalgOp)); @@ -406,12 +407,43 @@ Optional mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, return llvm::None; edsc::ScopedContext scope(builder, op->getLoc()); + // In the case of 0-D memrefs, return null and special case to scalar load or + // store later. + if (auto fillOp = dyn_cast(op)) { + // Vectorize fill as a vector.broadcast. + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " + << "Rewrite linalg.fill as vector.broadcast: " << *op); + VectorizedLinalgOp res; + if (Value v = buildVectorWrite(builder, fillOp.value(), fillOp.output())) + res.tensorResults.push_back(v); + return res; + } + if (auto copyOp = dyn_cast(op)) { + // Vectorize copy as a vector.transfer_read+vector.transfer_write. + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " + << "Rewrite linalg.copy as vector.transfer_read + " + "vector.transfer_write: " + << *op); + Value vector = buildVectorRead(builder, copyOp.input()); + VectorizedLinalgOp res; + if (Value v = buildVectorWrite(builder, vector, copyOp.output())) + res.tensorResults.push_back(v); + return res; + } if (isElementwise(op)) { LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " << "Vectorize linalg op as a generic: " << *op); return vectorizeAsLinalgGeneric(builder, cast(op)); } + // TODO: as soon as Copy and FillOp. get a region builder, replace all the + // above by: + // if (isa(op) || isElementwise(op)) { + // LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " + // << "Vectorize linalg op as a generic: " << *op); + // return vectorizeAsLinalgGeneric(builder, cast(op)); + // } + return vectorizeContraction(builder, cast(op)); } diff --git a/mlir/test/Transforms/copy-removal.mlir b/mlir/test/Transforms/copy-removal.mlir index 1432037f2110..a66006fd0af8 100644 --- a/mlir/test/Transforms/copy-removal.mlir +++ b/mlir/test/Transforms/copy-removal.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt -copy-removal -split-input-file %s | FileCheck %s +// RUN: mlir-opt -copy-removal -split-input-file %s +//| FileCheck %s // All linalg copies except the linalg.copy(%1, %9) must be removed since the // defining operation of %1 and its DeallocOp have been defined in another block. @@ -255,7 +256,7 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>, %result: memref<2xf32>) %tmp2 = math.exp %gen2_arg0 : f32 linalg.yield %tmp2 : f32 } - linalg.copy(%temp, %result) : memref<2xf32>, memref<2xf32> + "linalg.copy"(%temp, %result) : (memref<2xf32>, memref<2xf32>) -> () dealloc %temp : memref<2xf32> // CHECK: return return @@ -291,7 +292,7 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>){ linalg.yield %tmp2 : f32 } // CHECK: linalg.copy - linalg.copy(%temp, %to) : memref<2xf32>, memref<2xf32> + "linalg.copy"(%temp, %to) : (memref<2xf32>, memref<2xf32>) -> () dealloc %temp : memref<2xf32> return } @@ -354,7 +355,7 @@ func @check_with_affine_dialect(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg } // CHECK-NOT: linalg.copy // CHECK-NOT: dealloc - linalg.copy(%0, %arg2) : memref<4xf32>, memref<4xf32> + "linalg.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> () dealloc %0 : memref<4xf32> //CHECK: return return diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc index b197ba3da65d..a16a2b85a9ec 100644 --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -23,7 +23,7 @@ // IMPL-NEXT: map2 = simplifyAffineMap(map2); // IMPL-NEXT: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 }); // -// IMPL: void Test1Op::regionBuilder(Block &block, ValueRange captures) { +// IMPL: void Test1Op::regionBuilder(Block &block) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); // IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]); @@ -47,7 +47,7 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) { // IMPL: AffineMap::get(3, 3, {d2, d1}, context) // IMPL: AffineMap::get(3, 3, {d0, d1}, context) // -// IMPL: Test2Op::regionBuilder(Block &block, ValueRange captures) { +// IMPL: Test2Op::regionBuilder(Block &block) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); // IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]); @@ -71,7 +71,7 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) { // IMPL: AffineMap::get(4, 4, {d3, d2}, context) // IMPL: AffineMap::get(4, 4, {d0, d1, d2}, context) // -// IMPL: Test3Op::regionBuilder(Block &block, ValueRange captures) { +// IMPL: Test3Op::regionBuilder(Block &block) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); // IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]); diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp index 4f57322c8be6..0934967f516c 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1871,11 +1871,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, $_builder.getI32VectorAttr({{ static_cast(inputs.size()), static_cast(outputs.size())})); - createAndFillStructuredOpRegion<{0}>( + buildNamedStructuredOpRegionAndAttributes<{0}>( $_builder, $_state, TypeRange(inputs), - TypeRange(outputs)/*, TODO: support captures*/); + TypeRange(outputs)); }]>, OpBuilderDAG< (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, @@ -1889,11 +1889,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, $_builder.getI32VectorAttr({{ static_cast(inputs.size()), static_cast(outputs.size())})); - createAndFillStructuredOpRegion<{0}>( + buildNamedStructuredOpRegionAndAttributes<{0}>( $_builder, $_state, TypeRange(inputs), - TypeRange(outputs)/*, TODO: support captures*/); + TypeRange(outputs)); }]>, OpBuilderDAG< (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, @@ -1907,9 +1907,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, {6} ]; let printer = [{{ return ::printNamedStructuredOp(p, *this); }]; - let parser = [{{ - return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/); - }]; + let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }]; let hasFolder = 1; let hasCanonicalizer = 1; @@ -1917,8 +1915,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, // Auto-generated. ArrayAttr iterator_types(); ArrayAttr indexing_maps(); - static void regionBuilder(Block &block, ValueRange captures); - static std::function getRegionBuilder() {{ + static void regionBuilder(Block &block); + static std::function getRegionBuilder() {{ return regionBuilder; } @@ -1982,11 +1980,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, $_builder.getI32VectorAttr({{ static_cast(inputs.size()), static_cast(outputs.size())})); - createAndFillStructuredOpRegion<{0}>( + buildNamedStructuredOpRegionAndAttributes<{0}>( $_builder, $_state, TypeRange(inputs), - TypeRange(outputs)/*, TODO: support captures*/); + TypeRange(outputs)); {2} }]> )FMT"; @@ -2313,7 +2311,7 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName, }; const char *regionBuilderFmt = R"FMT( - void {0}::regionBuilder(Block &block, ValueRange captures) { + void {0}::regionBuilder(Block &block) { using namespace edsc; using namespace intrinsics; auto args = block.getArguments();