From 632a4f8829425325c9d5b110a78bfc6e0bb855e9 Mon Sep 17 00:00:00 2001
From: River Riddle
Date: Tue, 25 Jan 2022 18:41:02 -0800
Subject: [PATCH] [mlir] Move std.generic_atomic_rmw to the memref dialect

This is part of splitting up the standard dialect. The move makes sense
anyway, given that the memref dialect already holds memref.atomic_rmw, the
non-region sibling operation of std.generic_atomic_rmw (the relationship is
even clearer given that the two ops have nearly the same description, modulo
how they represent the inner computation).

Differential Revision: https://reviews.llvm.org/D118209
---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  79 +++++++++-
 .../mlir/Dialect/StandardOps/IR/Ops.td        |  70 --------
 .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp  | 149 ++++++++++++++++--
 .../StandardToLLVM/StandardToLLVM.cpp         | 135 ----------------
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  83 ++++++++++
 mlir/lib/Dialect/StandardOps/IR/Ops.cpp       |  82 ----------
 .../StandardOps/Transforms/ExpandOps.cpp      |  12 +-
 .../MemRefToLLVM/memref-to-llvm.mlir          |  19 +++
 .../StandardToLLVM/standard-to-llvm.mlir      |  27 ----
 mlir/test/Dialect/MemRef/invalid.mlir         |  60 +++++++
 mlir/test/Dialect/MemRef/ops.mlir             |  14 ++
 mlir/test/Dialect/Standard/expand-ops.mlir    |   4 +-
 mlir/test/IR/core-ops.mlir                    |  14 --
 mlir/test/IR/invalid-ops.mlir                 |  60 ------
 14 files changed, 395 insertions(+), 413 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 49f716dbf9b2..b4a7002726ce 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -698,6 +698,81 @@ def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// GenericAtomicRMWOp
+//===----------------------------------------------------------------------===//
+
+def GenericAtomicRMWOp : MemRef_Op<"generic_atomic_rmw", [
+      SingleBlockImplicitTerminator<"AtomicYieldOp">,
+      TypesMatchWith<"result type matches element type of memref",
+                     "memref", "result",
+                     "$_self.cast<MemRefType>().getElementType()">
+    ]> {
+  let summary = "atomic read-modify-write operation with a region";
+  let description = [{
+    The `memref.generic_atomic_rmw` operation provides a way to perform a
+    read-modify-write sequence that is free from data races. The memref
+    operand represents the buffer that the read and write will be performed
+    against, as accessed by the specified indices. The arity of the indices
+    is the rank of the memref. The result represents the latest value that
+    was stored. The region contains the code for the modification itself.
+    The entry block has a single argument that represents the value stored
+    in `memref[indices]` before the write is performed. No side-effecting
+    ops are allowed in the body of `GenericAtomicRMWOp`.
+
+    Example:
+
+    ```mlir
+    %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
+      ^bb0(%current_value : f32):
+        %c1 = arith.constant 1.0 : f32
+        %inc = arith.addf %c1, %current_value : f32
+        memref.atomic_yield %inc : f32
+    }
+    ```
+  }];
+
+  let arguments = (ins
+      MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
+      Variadic<Index>:$indices);
+
+  let results = (outs
+      AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);
+
+  let regions = (region AnyRegion:$atomic_body);
+
+  let skipDefaultBuilders = 1;
+  let builders = [OpBuilder<(ins "Value":$memref, "ValueRange":$ivs)>];
+
+  let extraClassDeclaration = [{
+    // TODO: remove post migrating callers.
+    Region &body() { return getRegion(); }
+
+    // The value stored in memref[ivs].
+    Value getCurrentValue() {
+      return getRegion().getArgument(0);
+    }
+    MemRefType getMemRefType() {
+      return memref().getType().cast<MemRefType>();
+    }
+  }];
+}
+
+def AtomicYieldOp : MemRef_Op<"atomic_yield", [
+      HasParent<"GenericAtomicRMWOp">,
+      NoSideEffect,
+      Terminator
+    ]> {
+  let summary = "yield operation for GenericAtomicRMWOp";
+  let description = [{
+    "memref.atomic_yield" yields an SSA value from a GenericAtomicRMWOp
+    region.
+  }];
+
+  let arguments = (ins AnyType:$result);
+  let assemblyFormat = "$result attr-dict `:` type($result)";
+}
+
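For reference, the builder declared above is driven from C++ in the same shape the patch itself uses in ExpandOps.cpp further down. A minimal sketch, not part of the patch, assuming a value `buf` of type `memref<?xf32>` and an `index` value `idx`; the helper name is illustrative:

```cpp
// Minimal sketch: build a memref.generic_atomic_rmw that atomically adds
// 1.0f to buf[idx]. The custom builder already creates the single-block
// region whose entry argument is the value currently stored at buf[idx].
static Value createAtomicIncrement(OpBuilder &b, Location loc, Value buf,
                                   Value idx) {
  auto atomicOp =
      b.create<memref::GenericAtomicRMWOp>(loc, buf, ValueRange{idx});
  OpBuilder bodyBuilder = OpBuilder::atBlockEnd(atomicOp.getBody());
  Value current = atomicOp.getCurrentValue();
  Value one = bodyBuilder.create<arith::ConstantFloatOp>(
      loc, llvm::APFloat(1.0f), bodyBuilder.getF32Type());
  Value incremented = bodyBuilder.create<arith::AddFOp>(loc, one, current);
  // The yielded value is written back to buf[idx] and is also the result.
  bodyBuilder.create<memref::AtomicYieldOp>(loc, incremented);
  return atomicOp.getResult();
}
```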
 //===----------------------------------------------------------------------===//
 // GetGlobalOp
 //===----------------------------------------------------------------------===//
@@ -1687,7 +1762,7 @@ def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
     ]> {
   let summary = "atomic read-modify-write operation";
   let description = [{
-    The `atomic_rmw` operation provides a way to perform a read-modify-write
+    The `memref.atomic_rmw` operation provides a way to perform a read-modify-write
     sequence that is free from data races. The kind enumeration specifies the
     modification to perform. The value operand represents the new value to be
     applied during the modification. The memref operand represents the buffer
@@ -1698,7 +1773,7 @@ def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
     Example:
 
     ```mlir
-    %x = arith.atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32
+    %x = memref.atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32
     ```
   }];
 
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 794f0157ef14..deedafb66118 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -178,76 +178,6 @@ def AssertOp : Std_Op<"assert"> {
   let hasCanonicalizeMethod = 1;
 }
 
-def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [
-      SingleBlockImplicitTerminator<"AtomicYieldOp">,
-      TypesMatchWith<"result type matches element type of memref",
-                     "memref", "result",
-                     "$_self.cast<MemRefType>().getElementType()">
-    ]> {
-  let summary = "atomic read-modify-write operation with a region";
-  let description = [{
-    The `generic_atomic_rmw` operation provides a way to perform a
-    read-modify-write sequence that is free from data races. The memref
-    operand represents the buffer that the read and write will be performed
-    against, as accessed by the specified indices. The arity of the indices
-    is the rank of the memref. The result represents the latest value that
-    was stored. The region contains the code for the modification itself.
-    The entry block has a single argument that represents the value stored
-    in `memref[indices]` before the write is performed. No side-effecting
-    ops are allowed in the body of `GenericAtomicRMWOp`.
-
-    Example:
-
-    ```mlir
-    %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
-      ^bb0(%current_value : f32):
-        %c1 = arith.constant 1.0 : f32
-        %inc = arith.addf %c1, %current_value : f32
-        atomic_yield %inc : f32
-    }
-    ```
-  }];
-
-  let arguments = (ins
-      MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
-      Variadic<Index>:$indices);
-
-  let results = (outs
-      AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);
-
-  let regions = (region AnyRegion:$atomic_body);
-
-  let skipDefaultBuilders = 1;
-  let builders = [OpBuilder<(ins "Value":$memref, "ValueRange":$ivs)>];
-
-  let extraClassDeclaration = [{
-    // TODO: remove post migrating callers.
-    Region &body() { return getRegion(); }
-
-    // The value stored in memref[ivs].
-    Value getCurrentValue() {
-      return getRegion().getArgument(0);
-    }
-    MemRefType getMemRefType() {
-      return getMemref().getType().cast<MemRefType>();
-    }
-  }];
-}
-
-def AtomicYieldOp : Std_Op<"atomic_yield", [
-      HasParent<"GenericAtomicRMWOp">,
-      NoSideEffect,
-      Terminator
-    ]> {
-  let summary = "yield operation for GenericAtomicRMWOp";
-  let description = [{
-    "atomic_yield" yields an SSA value from a GenericAtomicRMWOp region.
-  }];
-
-  let arguments = (ins AnyType:$result);
-  let assemblyFormat = "$result attr-dict `:` type($result)";
-}
-
 //===----------------------------------------------------------------------===//
 // BranchOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 1a3384ca2c04..81f842e80e35 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -412,6 +412,139 @@ private:
   }
 };
 
+/// Common base for load and store operations on MemRefs. Restricts the match
+/// to supported MemRef types. Provides functionality to emit code accessing a
+/// specific element of the underlying data buffer.
+template <typename Derived>
+struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
+  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
+  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
+  using Base = LoadStoreOpLowering<Derived>;
+
+  LogicalResult match(Derived op) const override {
+    MemRefType type = op.getMemRefType();
+    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
+  }
+};
+
+/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
+/// retried until it succeeds in atomically storing a new value into memory.
+///
+///      +---------------------------------+
+///      |   <code before the AtomicRMWOp> |
+///      |   <compute initial %loaded>     |
+///      |   br loop(%loaded)              |
+///      +---------------------------------+
+///             |
+///  -------|   |
+///  |      v   v
+///  |   +--------------------------------+
+///  |   | loop(%loaded):                 |
+///  |   |   <body contents>              |
+///  |   |   %pair = cmpxchg              |
+///  |   |   %ok = %pair[0]               |
+///  |   |   %new = %pair[1]              |
+///  |   |   cond_br %ok, end, loop(%new) |
+///  |   +--------------------------------+
+///  |          |        |
+///  |-----------        |
+///                      v
+///      +--------------------------------+
+///      | end:                           |
+///      |   <code after the AtomicRMWOp> |
+///      +--------------------------------+
+///
+struct GenericAtomicRMWOpLowering
+    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
+  using Base::Base;
+
+  LogicalResult
+  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = atomicOp.getLoc();
+    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
+
+    // Split the block into initial, loop, and ending parts.
+    auto *initBlock = rewriter.getInsertionBlock();
+    auto *loopBlock = rewriter.createBlock(
+        initBlock->getParent(), std::next(Region::iterator(initBlock)),
+        valueType, loc);
+    auto *endBlock = rewriter.createBlock(
+        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
+
+    // Operations range to be moved to `endBlock`.
+    auto opsToMoveStart = atomicOp->getIterator();
+    auto opsToMoveEnd = initBlock->back().getIterator();
+
+    // Compute the loaded value and branch to the loop block.
+    rewriter.setInsertionPointToEnd(initBlock);
+    auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
+    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
+                                        adaptor.indices(), rewriter);
+    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
+    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
+
+    // Prepare the body of the loop block.
+    rewriter.setInsertionPointToStart(loopBlock);
+
+    // Clone the GenericAtomicRMWOp region and extract the result.
+    auto loopArgument = loopBlock->getArgument(0);
+    BlockAndValueMapping mapping;
+    mapping.map(atomicOp.getCurrentValue(), loopArgument);
+    Block &entryBlock = atomicOp.body().front();
+    for (auto &nestedOp : entryBlock.without_terminator()) {
+      Operation *clone = rewriter.clone(nestedOp, mapping);
+      mapping.map(nestedOp.getResults(), clone->getResults());
+    }
+    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
+
+    // Prepare the epilog of the loop block.
+    // Append the cmpxchg op to the end of the loop block.
+    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
+    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
+    auto boolType = IntegerType::get(rewriter.getContext(), 1);
+    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
+                                                     {valueType, boolType});
+    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
+        loc, pairType, dataPtr, loopArgument, result, successOrdering,
+        failureOrdering);
+    // Extract the %new_loaded and %ok values from the pair.
+    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
+        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
+    Value ok = rewriter.create<LLVM::ExtractValueOp>(
+        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
+
+    // Conditionally branch to the end or back to the loop depending on %ok.
+    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
+                                    loopBlock, newLoaded);
+
+    rewriter.setInsertionPointToEnd(endBlock);
+    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
+                 std::next(opsToMoveEnd), rewriter);
+
+    // The 'result' of the atomic_rmw op is the newly loaded value.
+    rewriter.replaceOp(atomicOp, {newLoaded});
+
+    return success();
+  }
+
+private:
+  // Clones a segment of ops [start, end) and erases the original.
+  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
+                    Block::iterator start, Block::iterator end,
+                    ConversionPatternRewriter &rewriter) const {
+    BlockAndValueMapping mapping;
+    mapping.map(oldResult, newResult);
+    SmallVector<Operation *> opsToErase;
+    for (auto it = start; it != end; ++it) {
+      rewriter.clone(*it, mapping);
+      opsToErase.push_back(&*it);
+    }
+    for (auto *it : opsToErase)
+      rewriter.eraseOp(it);
+  }
+};
+
 /// Returns the LLVM type of the global variable given the memref type `type`.
 static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                           LLVMTypeConverter &typeConverter) {
@@ -520,21 +653,6 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
   }
 };
 
-// Common base for load and store operations on MemRefs. Restricts the match
-// to supported MemRef types. Provides functionality to emit code accessing a
-// specific element of the underlying data buffer.
-template <typename Derived>
-struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
-  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
-  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
-  using Base = LoadStoreOpLowering<Derived>;
-
-  LogicalResult match(Derived op) const override {
-    MemRefType type = op.getMemRefType();
-    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
-  }
-};
-
 // Load operation is lowered to obtaining a pointer to the indexed element
 // and loading it.
 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
@@ -1683,6 +1801,7 @@ void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
       AtomicRMWOpLowering,
       AssumeAlignmentOpLowering,
       DimOpLowering,
+      GenericAtomicRMWOpLowering,
       GlobalMemrefOpLowering,
       GetGlobalMemrefOpLowering,
       LoadOpLowering,
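With the pattern registered above, consumers reach the relocated lowering through `populateMemRefToLLVMConversionPatterns` exactly as before the move. A minimal driver sketch, not part of the patch, assuming the standard dialect-conversion setup of this vintage; the function name is illustrative:

```cpp
// Minimal sketch: convert a module's memref ops (including
// memref.generic_atomic_rmw) to the LLVM dialect.
static LogicalResult runMemRefToLLVM(ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  LLVMTypeConverter typeConverter(ctx);
  RewritePatternSet patterns(ctx);
  populateMemRefToLLVMConversionPatterns(typeConverter, patterns);

  // Everything outside the LLVM dialect, e.g. a remaining
  // memref.generic_atomic_rmw, is illegal and must be rewritten.
  LLVMConversionTarget target(*ctx);
  target.addLegalOp<ModuleOp>();
  return applyPartialConversion(module, target, std::move(patterns));
}
```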
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 97a3be46b7ee..b68587d43c77 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -565,21 +565,6 @@ struct UnrealizedConversionCastOpLowering
   }
 };
 
-// Common base for load and store operations on MemRefs. Restricts the match
-// to supported MemRef types. Provides functionality to emit code accessing a
-// specific element of the underlying data buffer.
-template <typename Derived>
-struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
-  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
-  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
-  using Base = LoadStoreOpLowering<Derived>;
-
-  LogicalResult match(Derived op) const override {
-    MemRefType type = op.getMemRefType();
-    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
-  }
-};
-
 // Base class for LLVM IR lowering terminator operations with successors.
 template <typename SourceOp, typename TargetOp>
 struct OneToOneLLVMTerminatorLowering
@@ -771,125 +756,6 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
   }
 };
 
-/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
-/// retried until it succeeds in atomically storing a new value into memory.
-///
-///      +---------------------------------+
-///      |   <code before the AtomicRMWOp> |
-///      |   <compute initial %loaded>     |
-///      |   br loop(%loaded)              |
-///      +---------------------------------+
-///             |
-///  -------|   |
-///  |      v   v
-///  |   +--------------------------------+
-///  |   | loop(%loaded):                 |
-///  |   |   <body contents>              |
-///  |   |   %pair = cmpxchg              |
-///  |   |   %ok = %pair[0]               |
-///  |   |   %new = %pair[1]              |
-///  |   |   cond_br %ok, end, loop(%new) |
-///  |   +--------------------------------+
-///  |          |        |
-///  |-----------        |
-///                      v
-///      +--------------------------------+
-///      | end:                           |
-///      |   <code after the AtomicRMWOp> |
-///      +--------------------------------+
-///
-struct GenericAtomicRMWOpLowering
-    : public LoadStoreOpLowering<GenericAtomicRMWOp> {
-  using Base::Base;
-
-  LogicalResult
-  matchAndRewrite(GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    auto loc = atomicOp.getLoc();
-    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
-
-    // Split the block into initial, loop, and ending parts.
-    auto *initBlock = rewriter.getInsertionBlock();
-    auto *loopBlock = rewriter.createBlock(
-        initBlock->getParent(), std::next(Region::iterator(initBlock)),
-        valueType, loc);
-    auto *endBlock = rewriter.createBlock(
-        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
-
-    // Operations range to be moved to `endBlock`.
-    auto opsToMoveStart = atomicOp->getIterator();
-    auto opsToMoveEnd = initBlock->back().getIterator();
-
-    // Compute the loaded value and branch to the loop block.
-    rewriter.setInsertionPointToEnd(initBlock);
-    auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
-    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
-                                        adaptor.getIndices(), rewriter);
-    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
-    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
-
-    // Prepare the body of the loop block.
-    rewriter.setInsertionPointToStart(loopBlock);
-
-    // Clone the GenericAtomicRMWOp region and extract the result.
-    auto loopArgument = loopBlock->getArgument(0);
-    BlockAndValueMapping mapping;
-    mapping.map(atomicOp.getCurrentValue(), loopArgument);
-    Block &entryBlock = atomicOp.body().front();
-    for (auto &nestedOp : entryBlock.without_terminator()) {
-      Operation *clone = rewriter.clone(nestedOp, mapping);
-      mapping.map(nestedOp.getResults(), clone->getResults());
-    }
-    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
-
-    // Prepare the epilog of the loop block.
-    // Append the cmpxchg op to the end of the loop block.
-    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
-    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
-    auto boolType = IntegerType::get(rewriter.getContext(), 1);
-    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
-                                                     {valueType, boolType});
-    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
-        loc, pairType, dataPtr, loopArgument, result, successOrdering,
-        failureOrdering);
-    // Extract the %new_loaded and %ok values from the pair.
-    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
-        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
-    Value ok = rewriter.create<LLVM::ExtractValueOp>(
-        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
-
-    // Conditionally branch to the end or back to the loop depending on %ok.
-    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
-                                    loopBlock, newLoaded);
-
-    rewriter.setInsertionPointToEnd(endBlock);
-    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
-                 std::next(opsToMoveEnd), rewriter);
-
-    // The 'result' of the atomic_rmw op is the newly loaded value.
-    rewriter.replaceOp(atomicOp, {newLoaded});
-
-    return success();
-  }
-
-private:
-  // Clones a segment of ops [start, end) and erases the original.
-  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
-                    Block::iterator start, Block::iterator end,
-                    ConversionPatternRewriter &rewriter) const {
-    BlockAndValueMapping mapping;
-    mapping.map(oldResult, newResult);
-    SmallVector<Operation *> opsToErase;
-    for (auto it = start; it != end; ++it) {
-      rewriter.clone(*it, mapping);
-      opsToErase.push_back(&*it);
-    }
-    for (auto *it : opsToErase)
-      rewriter.eraseOp(it);
-  }
-};
-
 } // namespace
 
 void mlir::populateStdToLLVMFuncOpConversionPattern(
@@ -911,7 +777,6 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
       CallOpLowering,
       CondBranchOpLowering,
       ConstantOpLowering,
-      GenericAtomicRMWOpLowering,
       ReturnOpLowering,
       SelectOpLowering,
       SplatOpLowering,

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 211af3045b9d..0c29607d4601 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -945,6 +945,89 @@ static LogicalResult verify(DmaWaitOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GenericAtomicRMWOp
+//===----------------------------------------------------------------------===//
+
+void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
+                               Value memref, ValueRange ivs) {
+  result.addOperands(memref);
+  result.addOperands(ivs);
+
+  if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
+    Type elementType = memrefType.getElementType();
+    result.addTypes(elementType);
+
+    Region *bodyRegion = result.addRegion();
+    bodyRegion->push_back(new Block());
+    bodyRegion->addArgument(elementType, memref.getLoc());
+  }
+}
+
+static LogicalResult verify(GenericAtomicRMWOp op) {
+  auto &body = op.getRegion();
+  if (body.getNumArguments() != 1)
+    return op.emitOpError("expected single number of entry block arguments");
+
+  if (op.getResult().getType() != body.getArgument(0).getType())
+    return op.emitOpError(
+        "expected block argument of the same type result type");
+
+  bool hasSideEffects =
+      body.walk([&](Operation *nestedOp) {
+            if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
+              return WalkResult::advance();
+            nestedOp->emitError(
+                "body of 'memref.generic_atomic_rmw' should contain "
+                "only operations with no side effects");
+            return WalkResult::interrupt();
+          })
+          .wasInterrupted();
+  return hasSideEffects ? failure() : success();
+}
+
+static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
+                                           OperationState &result) {
+  OpAsmParser::OperandType memref;
+  Type memrefType;
+  SmallVector<OpAsmParser::OperandType> ivs;
+
+  Type indexType = parser.getBuilder().getIndexType();
+  if (parser.parseOperand(memref) ||
+      parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
+      parser.parseColonType(memrefType) ||
+      parser.resolveOperand(memref, memrefType, result.operands) ||
+      parser.resolveOperands(ivs, indexType, result.operands))
+    return failure();
+
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, llvm::None, llvm::None) ||
+      parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  result.types.push_back(memrefType.cast<MemRefType>().getElementType());
+  return success();
+}
+
+static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
+  p << ' ' << op.memref() << "[" << op.indices()
+    << "] : " << op.memref().getType() << ' ';
+  p.printRegion(op.getRegion());
+  p.printOptionalAttrDict(op->getAttrs());
+}
+
+//===----------------------------------------------------------------------===//
+// AtomicYieldOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(AtomicYieldOp op) {
+  Type parentType = op->getParentOp()->getResultTypes().front();
+  Type resultType = op.result().getType();
+  if (parentType != resultType)
+    return op.emitOpError() << "types mismatch between yield op: " << resultType
+                            << " and its parent: " << parentType;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GlobalOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 7487aea0902d..ffe155ad214e 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -131,88 +131,6 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
   return failure();
 }
 
-//===----------------------------------------------------------------------===//
-// GenericAtomicRMWOp
-//===----------------------------------------------------------------------===//
-
-void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
-                               Value memref, ValueRange ivs) {
-  result.addOperands(memref);
-  result.addOperands(ivs);
-
-  if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
-    Type elementType = memrefType.getElementType();
-    result.addTypes(elementType);
-
-    Region *bodyRegion = result.addRegion();
-    bodyRegion->push_back(new Block());
-    bodyRegion->addArgument(elementType, memref.getLoc());
-  }
-}
-
-static LogicalResult verify(GenericAtomicRMWOp op) {
-  auto &body = op.getRegion();
-  if (body.getNumArguments() != 1)
-    return op.emitOpError("expected single number of entry block arguments");
-
-  if (op.getResult().getType() != body.getArgument(0).getType())
-    return op.emitOpError(
-        "expected block argument of the same type result type");
-
-  bool hasSideEffects =
-      body.walk([&](Operation *nestedOp) {
-            if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
-              return WalkResult::advance();
-            nestedOp->emitError("body of 'generic_atomic_rmw' should contain "
-                                "only operations with no side effects");
-            return WalkResult::interrupt();
-          })
-          .wasInterrupted();
-  return hasSideEffects ? failure() : success();
-}
-
-static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
-                                           OperationState &result) {
-  OpAsmParser::OperandType memref;
-  Type memrefType;
-  SmallVector<OpAsmParser::OperandType> ivs;
-
-  Type indexType = parser.getBuilder().getIndexType();
-  if (parser.parseOperand(memref) ||
-      parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
-      parser.parseColonType(memrefType) ||
-      parser.resolveOperand(memref, memrefType, result.operands) ||
-      parser.resolveOperands(ivs, indexType, result.operands))
-    return failure();
-
-  Region *body = result.addRegion();
-  if (parser.parseRegion(*body, llvm::None, llvm::None) ||
-      parser.parseOptionalAttrDict(result.attributes))
-    return failure();
-  result.types.push_back(memrefType.cast<MemRefType>().getElementType());
-  return success();
-}
-
-static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
-  p << ' ' << op.getMemref() << "[" << op.getIndices()
-    << "] : " << op.getMemref().getType() << ' ';
-  p.printRegion(op.getRegion());
-  p.printOptionalAttrDict(op->getAttrs());
-}
-
-//===----------------------------------------------------------------------===//
-// AtomicYieldOp
-//===----------------------------------------------------------------------===//
-
-static LogicalResult verify(AtomicYieldOp op) {
-  Type parentType = op->getParentOp()->getResultTypes().front();
-  Type resultType = op.getResult().getType();
-  if (parentType != resultType)
-    return op.emitOpError() << "types mismatch between yield op: " << resultType
-                            << " and its parent: " << parentType;
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // BranchOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
index d7c7755e2402..e5cf08da4904 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
@@ -28,17 +28,17 @@ namespace {
 
 /// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
 /// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
-/// `generic_atomic_rmw` with the expanded code.
+/// `memref.generic_atomic_rmw` with the expanded code.
 ///
 /// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
 ///
 /// will be lowered to
 ///
-/// %x = std.generic_atomic_rmw %F[%i] : memref<10xf32> {
+/// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
 /// ^bb0(%current: f32):
 ///   %cmp = arith.cmpf "ogt", %current, %fval : f32
 ///   %new_value = select %cmp, %current, %fval : f32
-///   atomic_yield %new_value : f32
+///   memref.atomic_yield %new_value : f32
 /// }
 struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
 public:
@@ -59,8 +59,8 @@ public:
     }
 
     auto loc = op.getLoc();
-    auto genericOp =
-        rewriter.create<GenericAtomicRMWOp>(loc, op.memref(), op.indices());
+    auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
+        loc, op.memref(), op.indices());
     OpBuilder bodyBuilder =
         OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());
 
@@ -68,7 +68,7 @@ public:
     Value rhs = op.value();
     Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
    Value select = bodyBuilder.create<SelectOp>(loc, cmp, lhs, rhs);
-    bodyBuilder.create<AtomicYieldOp>(loc, select);
+    bodyBuilder.create<memref::AtomicYieldOp>(loc, select);
 
     rewriter.replaceOp(op, genericOp.getResult());
     return success();
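The converter above is normally exercised through the std expansion entry point rather than constructed directly. A minimal sketch, not part of the patch, assuming `populateStdExpandOpsPatterns` (the registration hook for this file's patterns) keeps its existing signature; the function name is illustrative:

```cpp
// Minimal sketch: greedily expand memref.atomic_rmw minf/maxf ops inside a
// function into memref.generic_atomic_rmw regions.
static void expandAtomicRMW(FuncOp func) {
  RewritePatternSet patterns(func.getContext());
  populateStdExpandOpsPatterns(patterns);
  // Best-effort greedy application; unmatched ops are left untouched.
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
```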
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 221f019db875..ee7d36052c4a 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -892,3 +892,22 @@ func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, a
   %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 64 + s0 + d1 * 64 + d2 * 8 + d3)>> into memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
   return %1 : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
 }
+
+// -----
+
+// CHECK-LABEL: func @generic_atomic_rmw
+func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) {
+  %x = memref.generic_atomic_rmw %I[%i] : memref<10xi32> {
+    ^bb0(%old_value : i32):
+      memref.atomic_yield %old_value : i32
+  }
+  // CHECK: [[init:%.*]] = llvm.load %{{.*}} : !llvm.ptr<i32>
+  // CHECK-NEXT: llvm.br ^bb1([[init]] : i32)
+  // CHECK-NEXT: ^bb1([[loaded:%.*]]: i32):
+  // CHECK-NEXT: [[pair:%.*]] = llvm.cmpxchg %{{.*}}, [[loaded]], [[loaded]]
+  // CHECK-SAME: acq_rel monotonic : i32
+  // CHECK-NEXT: [[new:%.*]] = llvm.extractvalue [[pair]][0]
+  // CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1]
+  // CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : i32)
+  llvm.return
+}

diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index 0dc6bf10dc5e..26f064fb0312 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -486,33 +486,6 @@ func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
 
 // -----
 
-// CHECK-LABEL: func @generic_atomic_rmw
-func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) -> i32 {
-  %x = generic_atomic_rmw %I[%i] : memref<10xi32> {
-    ^bb0(%old_value : i32):
-      %c1 = arith.constant 1 : i32
-      atomic_yield %c1 : i32
-  }
-  // CHECK: [[init:%.*]] = llvm.load %{{.*}} : !llvm.ptr<i32>
-  // CHECK-NEXT: llvm.br ^bb1([[init]] : i32)
-  // CHECK-NEXT: ^bb1([[loaded:%.*]]: i32):
-  // CHECK-NEXT: [[c1:%.*]] = llvm.mlir.constant(1 : i32)
-  // CHECK-NEXT: [[pair:%.*]] = llvm.cmpxchg %{{.*}}, [[loaded]], [[c1]]
-  // CHECK-SAME: acq_rel monotonic : i32
-  // CHECK-NEXT: [[new:%.*]] = llvm.extractvalue [[pair]][0]
-  // CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1]
-  // CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : i32)
-  // CHECK-NEXT: ^bb2:
-  %c2 = arith.constant 2 : i32
-  %add = arith.addi %c2, %x : i32
-  return %add : i32
-  // CHECK-NEXT: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
-  // CHECK-NEXT: [[add:%.*]] = llvm.add [[c2]], [[new]] : i32
-  // CHECK-NEXT: llvm.return [[add]]
-}
-
-// -----
-
 // CHECK-LABEL: func @ceilf(
 // CHECK-SAME: f32
 func @ceilf(%arg0 : f32) {

diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 6d0cef1efc7d..54e405b3f2ad 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -910,3 +910,63 @@ func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) {
   %x = memref.atomic_rmw addi %val, %I[%i, %i] : (f32, memref<16x10xf32>) -> f32
   return
 }
+
+// -----
+
+func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) {
+  // expected-error@+1 {{expected single number of entry block arguments}}
+  %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+      %c1 = arith.constant 1.0 : f32
+      memref.atomic_yield %c1 : f32
+  }
+  return
+}
+
+// -----
+
+func @generic_atomic_rmw_wrong_arg_type(%I: memref<10xf32>, %i : index) {
+  // expected-error@+1 {{expected block argument of the same type result type}}
+  %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
+    ^bb0(%old_value : i32):
+      %c1 = arith.constant 1.0 : f32
+      memref.atomic_yield %c1 : f32
+  }
+  return
+}
+
+// -----
+
+func @generic_atomic_rmw_result_type_mismatch(%I: memref<10xf32>, %i : index) {
+  // expected-error@+1 {{failed to verify that result type matches element type of memref}}
+  %0 = "memref.generic_atomic_rmw"(%I, %i) ({
+    ^bb0(%old_value: f32):
+      %c1 = arith.constant 1.0 : f32
+      memref.atomic_yield %c1 : f32
+  }) : (memref<10xf32>, index) -> i32
+  return
+}
+
+// -----
+
+func @generic_atomic_rmw_has_side_effects(%I: memref<10xf32>, %i : index) {
+  // expected-error@+4 {{should contain only operations with no side effects}}
+  %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
+    ^bb0(%old_value : f32):
+      %c1 = arith.constant 1.0 : f32
+      %buf = memref.alloc() : memref<2048xf32>
+      memref.atomic_yield %c1 : f32
+  }
+}
+
+// -----
+
+func @atomic_yield_type_mismatch(%I: memref<10xf32>, %i : index) {
+  // expected-error@+4 {{op types mismatch between yield op: 'i32' and its parent: 'f32'}}
+  %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
+    ^bb0(%old_value : f32):
+      %c1 = arith.constant 1 : i32
+      memref.atomic_yield %c1 : i32
+  }
+  return
+}

diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 1303a896e7ec..6191cdab02e2 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -246,3 +246,17 @@ func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
   // CHECK: memref.atomic_rmw addf [[VAL]], [[BUF]]{{\[}}[[I]]]
   return
 }
+
+// CHECK-LABEL: func @generic_atomic_rmw
+// CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index)
+func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) {
+  %x = memref.generic_atomic_rmw %I[%i, %j] : memref<1x2xf32> {
+    // CHECK-NEXT: memref.generic_atomic_rmw [[BUF]]{{\[}}[[I]], [[J]]] : memref<1x2xf32>
+    ^bb0(%old_value : f32):
+      %c1 = arith.constant 1.0 : f32
+      %out = arith.addf %c1, %old_value : f32
+      memref.atomic_yield %out : f32
+    // CHECK: index_attr = 8 : index
+  } { index_attr = 8 : index }
+  return
+}
diff --git a/mlir/test/Dialect/Standard/expand-ops.mlir b/mlir/test/Dialect/Standard/expand-ops.mlir
index cb650ffd11bd..2a1c367ff80f 100644
--- a/mlir/test/Dialect/Standard/expand-ops.mlir
+++ b/mlir/test/Dialect/Standard/expand-ops.mlir
@@ -6,11 +6,11 @@ func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
   %x = memref.atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32
   return %x : f32
 }
-// CHECK: %0 = generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
+// CHECK: %0 = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
 // CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
 // CHECK:   [[CMP:%.*]] = arith.cmpf ogt, [[CUR_VAL]], [[f]] : f32
 // CHECK:   [[SELECT:%.*]] = select [[CMP]], [[CUR_VAL]], [[f]] : f32
-// CHECK:   atomic_yield [[SELECT]] : f32
+// CHECK:   memref.atomic_yield [[SELECT]] : f32
 // CHECK: }
 // CHECK: return %0 : f32

diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 351e8a6b39c1..81fdac964f76 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -325,20 +325,6 @@ func @unranked_tensor_load_store(%0 : memref<*xi32>, %1 : tensor<*xi32>) {
   return
 }
 
-// CHECK-LABEL: func @generic_atomic_rmw
-// CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index)
-func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) {
-  %x = generic_atomic_rmw %I[%i, %j] : memref<1x2xf32> {
-    // CHECK-NEXT: generic_atomic_rmw [[BUF]]{{\[}}[[I]], [[J]]] : memref<1x2xf32>
-    ^bb0(%old_value : f32):
-      %c1 = arith.constant 1.0 : f32
-      %out = arith.addf %c1, %old_value : f32
-      atomic_yield %out : f32
-    // CHECK: index_attr = 8 : index
-  } { index_attr = 8 : index }
-  return
-}
-
 // CHECK-LABEL: func @assume_alignment
 // CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
 func @assume_alignment(%0: memref<4x4xf16>) {

diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 9950262a481a..6650330c7eb7 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -127,63 +127,3 @@ func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
   // expected-error@-1 {{expects different type than prior uses}}
   return
 }
-
-// -----
-
-func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) {
-  // expected-error@+1 {{expected single number of entry block arguments}}
-  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
-    ^bb0(%arg0 : f32, %arg1 : f32):
-      %c1 = arith.constant 1.0 : f32
-      atomic_yield %c1 : f32
-  }
-  return
-}
-
-// -----
-
-func @generic_atomic_rmw_wrong_arg_type(%I: memref<10xf32>, %i : index) {
-  // expected-error@+1 {{expected block argument of the same type result type}}
-  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
-    ^bb0(%old_value : i32):
-      %c1 = arith.constant 1.0 : f32
-      atomic_yield %c1 : f32
-  }
-  return
-}
-
-// -----
-
-func @generic_atomic_rmw_result_type_mismatch(%I: memref<10xf32>, %i : index) {
-  // expected-error@+1 {{failed to verify that result type matches element type of memref}}
-  %0 = "std.generic_atomic_rmw"(%I, %i) ({
-    ^bb0(%old_value: f32):
-      %c1 = arith.constant 1.0 : f32
-      atomic_yield %c1 : f32
-  }) : (memref<10xf32>, index) -> i32
-  return
-}
-
-// -----
-
-func @generic_atomic_rmw_has_side_effects(%I: memref<10xf32>, %i : index) {
-  // expected-error@+4 {{should contain only operations with no side effects}}
-  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
-    ^bb0(%old_value : f32):
-      %c1 = arith.constant 1.0 : f32
-      %buf = memref.alloc() : memref<2048xf32>
-      atomic_yield %c1 : f32
-  }
-}
-
-// -----
-
-func @atomic_yield_type_mismatch(%I: memref<10xf32>, %i : index) {
-  // expected-error@+4 {{op types mismatch between yield op: 'i32' and its parent: 'f32'}}
-  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
-    ^bb0(%old_value : f32):
-      %c1 = arith.constant 1 : i32
-      atomic_yield %c1 : i32
-  }
-  return
-}