[mlir] Move std.generic_atomic_rmw to the memref dialect

This is part of splitting up the standard dialect. The move makes sense anyway,
given that the memref dialect already holds memref.atomic_rmw, the non-region
sibling of std.generic_atomic_rmw (the relationship is even clearer given that
the two ops have nearly the same description, modulo how they represent the
inner computation).

Differential Revision: https://reviews.llvm.org/D118209
River Riddle 2022-01-25 18:41:02 -08:00
parent 480cd4cb85
commit 632a4f8829
14 changed files with 395 additions and 413 deletions
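
For context, a minimal sketch of the two sibling forms side by side, adapted from the examples in this diff (the `addf` body is illustrative):

```mlir
// Non-region form: the modification is picked by the rmw kind keyword.
%x = memref.atomic_rmw addf %value, %I[%i] : (f32, memref<10xf32>) -> f32

// Region form: the same update spelled out in the body.
%y = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%current : f32):
  %new = arith.addf %current, %value : f32
  memref.atomic_yield %new : f32
}
```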

@@ -698,6 +698,81 @@ def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//
def GenericAtomicRMWOp : MemRef_Op<"generic_atomic_rmw", [
SingleBlockImplicitTerminator<"AtomicYieldOp">,
TypesMatchWith<"result type matches element type of memref",
"memref", "result",
"$_self.cast<MemRefType>().getElementType()">
]> {
let summary = "atomic read-modify-write operation with a region";
let description = [{
The `memref.generic_atomic_rmw` operation provides a way to perform a
read-modify-write sequence that is free from data races. The memref operand
represents the buffer that the read and write will be performed against, as
accessed by the specified indices. The arity of the indices is the rank of
the memref. The result represents the latest value that was stored. The
region contains the code for the modification itself. The entry block has
a single argument that represents the value stored in `memref[indices]`
before the write is performed. No side-effecting ops are allowed in the
body of `GenericAtomicRMWOp`.
Example:
```mlir
%x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%current_value : f32):
%c1 = arith.constant 1.0 : f32
%inc = arith.addf %c1, %current_value : f32
memref.atomic_yield %inc : f32
}
```
}];
let arguments = (ins
MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
Variadic<Index>:$indices);
let results = (outs
AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);
let regions = (region AnyRegion:$atomic_body);
let skipDefaultBuilders = 1;
let builders = [OpBuilder<(ins "Value":$memref, "ValueRange":$ivs)>];
let extraClassDeclaration = [{
// TODO: remove post migrating callers.
Region &body() { return getRegion(); }
// The value stored in memref[ivs].
Value getCurrentValue() {
return getRegion().getArgument(0);
}
MemRefType getMemRefType() {
return memref().getType().cast<MemRefType>();
}
}];
}
def AtomicYieldOp : MemRef_Op<"atomic_yield", [
HasParent<"GenericAtomicRMWOp">,
NoSideEffect,
Terminator
]> {
let summary = "yield operation for GenericAtomicRMWOp";
let description = [{
"memref.atomic_yield" yields an SSA value from a
GenericAtomicRMWOp region.
}];
let arguments = (ins AnyType:$result);
let assemblyFormat = "$result attr-dict `:` type($result)";
}
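
Since the op accepts any signless integer or float memref, an integer body works the same way; the bitwise-or update below is a hypothetical illustration (the `%mask` constant and `arith.ori` body are not taken from this patch):

```mlir
%x = memref.generic_atomic_rmw %I[%i] : memref<10xi32> {
^bb0(%current : i32):
  %mask = arith.constant 8 : i32
  %new = arith.ori %current, %mask : i32
  memref.atomic_yield %new : i32
}
```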
//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//
@@ -1687,7 +1762,7 @@ def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
]> {
let summary = "atomic read-modify-write operation";
let description = [{
-    The `atomic_rmw` operation provides a way to perform a read-modify-write
+    The `memref.atomic_rmw` operation provides a way to perform a read-modify-write
sequence that is free from data races. The kind enumeration specifies the
modification to perform. The value operand represents the new value to be
applied during the modification. The memref operand represents the buffer
@@ -1698,7 +1773,7 @@ def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
Example:
```mlir
-    %x = arith.atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32
+    %x = memref.atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32
```
}];

@@ -178,76 +178,6 @@ def AssertOp : Std_Op<"assert"> {
let hasCanonicalizeMethod = 1;
}
def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [
SingleBlockImplicitTerminator<"AtomicYieldOp">,
TypesMatchWith<"result type matches element type of memref",
"memref", "result",
"$_self.cast<MemRefType>().getElementType()">
]> {
let summary = "atomic read-modify-write operation with a region";
let description = [{
The `generic_atomic_rmw` operation provides a way to perform a read-modify-write
sequence that is free from data races. The memref operand represents the
buffer that the read and write will be performed against, as accessed by
the specified indices. The arity of the indices is the rank of the memref.
The result represents the latest value that was stored. The region contains
the code for the modification itself. The entry block has a single argument
that represents the value stored in `memref[indices]` before the write is
performed. No side-effecting ops are allowed in the body of
`GenericAtomicRMWOp`.
Example:
```mlir
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%current_value : f32):
%c1 = arith.constant 1.0 : f32
%inc = arith.addf %c1, %current_value : f32
atomic_yield %inc : f32
}
```
}];
let arguments = (ins
MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
Variadic<Index>:$indices);
let results = (outs
AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);
let regions = (region AnyRegion:$atomic_body);
let skipDefaultBuilders = 1;
let builders = [OpBuilder<(ins "Value":$memref, "ValueRange":$ivs)>];
let extraClassDeclaration = [{
// TODO: remove post migrating callers.
Region &body() { return getRegion(); }
// The value stored in memref[ivs].
Value getCurrentValue() {
return getRegion().getArgument(0);
}
MemRefType getMemRefType() {
return getMemref().getType().cast<MemRefType>();
}
}];
}
def AtomicYieldOp : Std_Op<"atomic_yield", [
HasParent<"GenericAtomicRMWOp">,
NoSideEffect,
Terminator
]> {
let summary = "yield operation for GenericAtomicRMWOp";
let description = [{
"atomic_yield" yields an SSA value from a GenericAtomicRMWOp region.
}];
let arguments = (ins AnyType:$result);
let assemblyFormat = "$result attr-dict `:` type($result)";
}
//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//

@@ -412,6 +412,139 @@ private:
}
};
/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
using Base = LoadStoreOpLowering<Derived>;
LogicalResult match(Derived op) const override {
MemRefType type = op.getMemRefType();
return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
}
};
/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   br loop(%loaded)              |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   |   loop(%loaded):               |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %new = %pair[0]              |
///  |   |   %ok = %pair[1]               |
///  |   |   cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |----------         |
///                      v
///      +--------------------------------+
///      |   end:                         |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
: public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
using Base::Base;
LogicalResult
matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = atomicOp.getLoc();
Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
// Split the block into initial, loop, and ending parts.
auto *initBlock = rewriter.getInsertionBlock();
auto *loopBlock = rewriter.createBlock(
initBlock->getParent(), std::next(Region::iterator(initBlock)),
valueType, loc);
auto *endBlock = rewriter.createBlock(
loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
// Range of operations to be moved to `endBlock`.
auto opsToMoveStart = atomicOp->getIterator();
auto opsToMoveEnd = initBlock->back().getIterator();
// Compute the loaded value and branch to the loop block.
rewriter.setInsertionPointToEnd(initBlock);
auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
adaptor.indices(), rewriter);
Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
// Prepare the body of the loop block.
rewriter.setInsertionPointToStart(loopBlock);
// Clone the GenericAtomicRMWOp region and extract the result.
auto loopArgument = loopBlock->getArgument(0);
BlockAndValueMapping mapping;
mapping.map(atomicOp.getCurrentValue(), loopArgument);
Block &entryBlock = atomicOp.body().front();
for (auto &nestedOp : entryBlock.without_terminator()) {
Operation *clone = rewriter.clone(nestedOp, mapping);
mapping.map(nestedOp.getResults(), clone->getResults());
}
Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
// Prepare the epilog of the loop block.
// Append the cmpxchg op to the end of the loop block.
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
auto boolType = IntegerType::get(rewriter.getContext(), 1);
auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
{valueType, boolType});
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
loc, pairType, dataPtr, loopArgument, result, successOrdering,
failureOrdering);
// Extract the %new_loaded and %ok values from the pair.
Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
Value ok = rewriter.create<LLVM::ExtractValueOp>(
loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
// Conditionally branch to the end or back to the loop depending on %ok.
rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
loopBlock, newLoaded);
rewriter.setInsertionPointToEnd(endBlock);
moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
std::next(opsToMoveEnd), rewriter);
// The 'result' of the atomic_rmw op is the newly loaded value.
rewriter.replaceOp(atomicOp, {newLoaded});
return success();
}
private:
// Clones a segment of ops [start, end) and erases the original.
void moveOpsRange(ValueRange oldResult, ValueRange newResult,
Block::iterator start, Block::iterator end,
ConversionPatternRewriter &rewriter) const {
BlockAndValueMapping mapping;
mapping.map(oldResult, newResult);
SmallVector<Operation *, 2> opsToErase;
for (auto it = start; it != end; ++it) {
rewriter.clone(*it, mapping);
opsToErase.push_back(&*it);
}
for (auto *it : opsToErase)
rewriter.eraseOp(it);
}
};
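
For reference, a rough sketch of the LLVM dialect IR this pattern produces, mirroring the CHECK lines in the tests further below; the block labels and value names (`%ptr`, `%new_value`, etc.) are illustrative, not verbatim output:

```mlir
  %init = llvm.load %ptr : !llvm.ptr<i32>
  llvm.br ^loop(%init : i32)
^loop(%loaded: i32):
  // ... cloned body ops compute %new_value from %loaded ...
  %pair = llvm.cmpxchg %ptr, %loaded, %new_value acq_rel monotonic : i32
  %new = llvm.extractvalue %pair[0] : !llvm.struct<(i32, i1)>
  %ok = llvm.extractvalue %pair[1] : !llvm.struct<(i32, i1)>
  llvm.cond_br %ok, ^end, ^loop(%new : i32)
^end:  // ops that followed the atomic op now use %new
```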
/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
LLVMTypeConverter &typeConverter) {
@@ -520,21 +653,6 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
}
};
// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
using Base = LoadStoreOpLowering<Derived>;
LogicalResult match(Derived op) const override {
MemRefType type = op.getMemRefType();
return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
}
};
// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
@@ -1683,6 +1801,7 @@ void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
AtomicRMWOpLowering,
AssumeAlignmentOpLowering,
DimOpLowering,
+    GenericAtomicRMWOpLowering,
GlobalMemrefOpLowering,
GetGlobalMemrefOpLowering,
LoadOpLowering,

@@ -565,21 +565,6 @@ struct UnrealizedConversionCastOpLowering
}
};
// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
using Base = LoadStoreOpLowering<Derived>;
LogicalResult match(Derived op) const override {
MemRefType type = op.getMemRefType();
return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
}
};
// Base class for LLVM IR lowering terminator operations with successors.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
@@ -771,125 +756,6 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
}
};
/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   br loop(%loaded)              |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   |   loop(%loaded):               |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %new = %pair[0]              |
///  |   |   %ok = %pair[1]               |
///  |   |   cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |----------         |
///                      v
///      +--------------------------------+
///      |   end:                         |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
: public LoadStoreOpLowering<GenericAtomicRMWOp> {
using Base::Base;
LogicalResult
matchAndRewrite(GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = atomicOp.getLoc();
Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
// Split the block into initial, loop, and ending parts.
auto *initBlock = rewriter.getInsertionBlock();
auto *loopBlock = rewriter.createBlock(
initBlock->getParent(), std::next(Region::iterator(initBlock)),
valueType, loc);
auto *endBlock = rewriter.createBlock(
loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
// Range of operations to be moved to `endBlock`.
auto opsToMoveStart = atomicOp->getIterator();
auto opsToMoveEnd = initBlock->back().getIterator();
// Compute the loaded value and branch to the loop block.
rewriter.setInsertionPointToEnd(initBlock);
auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
adaptor.getIndices(), rewriter);
Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
// Prepare the body of the loop block.
rewriter.setInsertionPointToStart(loopBlock);
// Clone the GenericAtomicRMWOp region and extract the result.
auto loopArgument = loopBlock->getArgument(0);
BlockAndValueMapping mapping;
mapping.map(atomicOp.getCurrentValue(), loopArgument);
Block &entryBlock = atomicOp.body().front();
for (auto &nestedOp : entryBlock.without_terminator()) {
Operation *clone = rewriter.clone(nestedOp, mapping);
mapping.map(nestedOp.getResults(), clone->getResults());
}
Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
// Prepare the epilog of the loop block.
// Append the cmpxchg op to the end of the loop block.
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
auto boolType = IntegerType::get(rewriter.getContext(), 1);
auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
{valueType, boolType});
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
loc, pairType, dataPtr, loopArgument, result, successOrdering,
failureOrdering);
// Extract the %new_loaded and %ok values from the pair.
Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
Value ok = rewriter.create<LLVM::ExtractValueOp>(
loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
// Conditionally branch to the end or back to the loop depending on %ok.
rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
loopBlock, newLoaded);
rewriter.setInsertionPointToEnd(endBlock);
moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
std::next(opsToMoveEnd), rewriter);
// The 'result' of the atomic_rmw op is the newly loaded value.
rewriter.replaceOp(atomicOp, {newLoaded});
return success();
}
private:
// Clones a segment of ops [start, end) and erases the original.
void moveOpsRange(ValueRange oldResult, ValueRange newResult,
Block::iterator start, Block::iterator end,
ConversionPatternRewriter &rewriter) const {
BlockAndValueMapping mapping;
mapping.map(oldResult, newResult);
SmallVector<Operation *, 2> opsToErase;
for (auto it = start; it != end; ++it) {
rewriter.clone(*it, mapping);
opsToErase.push_back(&*it);
}
for (auto *it : opsToErase)
rewriter.eraseOp(it);
}
};
} // namespace
void mlir::populateStdToLLVMFuncOpConversionPattern(
@@ -911,7 +777,6 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
CallOpLowering,
CondBranchOpLowering,
ConstantOpLowering,
-    GenericAtomicRMWOpLowering,
ReturnOpLowering,
SelectOpLowering,
SplatOpLowering,

@@ -945,6 +945,89 @@ static LogicalResult verify(DmaWaitOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//
void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
Value memref, ValueRange ivs) {
result.addOperands(memref);
result.addOperands(ivs);
if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
Type elementType = memrefType.getElementType();
result.addTypes(elementType);
Region *bodyRegion = result.addRegion();
bodyRegion->push_back(new Block());
bodyRegion->addArgument(elementType, memref.getLoc());
}
}
static LogicalResult verify(GenericAtomicRMWOp op) {
auto &body = op.getRegion();
if (body.getNumArguments() != 1)
return op.emitOpError("expected single number of entry block arguments");
if (op.getResult().getType() != body.getArgument(0).getType())
return op.emitOpError(
"expected block argument of the same type result type");
bool hasSideEffects =
body.walk([&](Operation *nestedOp) {
if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
return WalkResult::advance();
nestedOp->emitError(
"body of 'memref.generic_atomic_rmw' should contain "
"only operations with no side effects");
return WalkResult::interrupt();
})
.wasInterrupted();
return hasSideEffects ? failure() : success();
}
static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType memref;
Type memrefType;
SmallVector<OpAsmParser::OperandType, 4> ivs;
Type indexType = parser.getBuilder().getIndexType();
if (parser.parseOperand(memref) ||
parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
parser.parseColonType(memrefType) ||
parser.resolveOperand(memref, memrefType, result.operands) ||
parser.resolveOperands(ivs, indexType, result.operands))
return failure();
Region *body = result.addRegion();
if (parser.parseRegion(*body, llvm::None, llvm::None) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();
result.types.push_back(memrefType.cast<MemRefType>().getElementType());
return success();
}
static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
p << ' ' << op.memref() << "[" << op.indices()
<< "] : " << op.memref().getType() << ' ';
p.printRegion(op.getRegion());
p.printOptionalAttrDict(op->getAttrs());
}
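
Note that the parser and printer above agree on placing the optional attribute dictionary after the region, so an op carrying extra attributes round-trips in this form (taken from the round-trip test further below):

```mlir
%x = memref.generic_atomic_rmw %I[%i, %j] : memref<1x2xf32> {
^bb0(%old_value : f32):
  %c1 = arith.constant 1.0 : f32
  %out = arith.addf %c1, %old_value : f32
  memref.atomic_yield %out : f32
} { index_attr = 8 : index }
```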
//===----------------------------------------------------------------------===//
// AtomicYieldOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(AtomicYieldOp op) {
Type parentType = op->getParentOp()->getResultTypes().front();
Type resultType = op.result().getType();
if (parentType != resultType)
return op.emitOpError() << "types mismatch between yield op: " << resultType
<< " and its parent: " << parentType;
return success();
}
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

@@ -131,88 +131,6 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
return failure();
}
//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//
void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
Value memref, ValueRange ivs) {
result.addOperands(memref);
result.addOperands(ivs);
if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
Type elementType = memrefType.getElementType();
result.addTypes(elementType);
Region *bodyRegion = result.addRegion();
bodyRegion->push_back(new Block());
bodyRegion->addArgument(elementType, memref.getLoc());
}
}
static LogicalResult verify(GenericAtomicRMWOp op) {
auto &body = op.getRegion();
if (body.getNumArguments() != 1)
return op.emitOpError("expected single number of entry block arguments");
if (op.getResult().getType() != body.getArgument(0).getType())
return op.emitOpError(
"expected block argument of the same type result type");
bool hasSideEffects =
body.walk([&](Operation *nestedOp) {
if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
return WalkResult::advance();
nestedOp->emitError("body of 'generic_atomic_rmw' should contain "
"only operations with no side effects");
return WalkResult::interrupt();
})
.wasInterrupted();
return hasSideEffects ? failure() : success();
}
static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType memref;
Type memrefType;
SmallVector<OpAsmParser::OperandType, 4> ivs;
Type indexType = parser.getBuilder().getIndexType();
if (parser.parseOperand(memref) ||
parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
parser.parseColonType(memrefType) ||
parser.resolveOperand(memref, memrefType, result.operands) ||
parser.resolveOperands(ivs, indexType, result.operands))
return failure();
Region *body = result.addRegion();
if (parser.parseRegion(*body, llvm::None, llvm::None) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();
result.types.push_back(memrefType.cast<MemRefType>().getElementType());
return success();
}
static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
p << ' ' << op.getMemref() << "[" << op.getIndices()
<< "] : " << op.getMemref().getType() << ' ';
p.printRegion(op.getRegion());
p.printOptionalAttrDict(op->getAttrs());
}
//===----------------------------------------------------------------------===//
// AtomicYieldOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(AtomicYieldOp op) {
Type parentType = op->getParentOp()->getResultTypes().front();
Type resultType = op.getResult().getType();
if (parentType != resultType)
return op.emitOpError() << "types mismatch between yield op: " << resultType
<< " and its parent: " << parentType;
return success();
}
//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//

@@ -28,17 +28,17 @@ namespace {
/// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
/// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
-/// `generic_atomic_rmw` with the expanded code.
+/// `memref.generic_atomic_rmw` with the expanded code.
///
/// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
///
/// will be lowered to
///
-///    %x = std.generic_atomic_rmw %F[%i] : memref<10xf32> {
+///    %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
/// ^bb0(%current: f32):
/// %cmp = arith.cmpf "ogt", %current, %fval : f32
/// %new_value = select %cmp, %current, %fval : f32
-///      atomic_yield %new_value : f32
+///      memref.atomic_yield %new_value : f32
/// }
struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
public:
@@ -59,8 +59,8 @@ public:
}
auto loc = op.getLoc();
-    auto genericOp =
-        rewriter.create<GenericAtomicRMWOp>(loc, op.memref(), op.indices());
+    auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
+        loc, op.memref(), op.indices());
OpBuilder bodyBuilder =
OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());
@@ -68,7 +68,7 @@ public:
Value rhs = op.value();
Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
Value select = bodyBuilder.create<SelectOp>(loc, cmp, lhs, rhs);
-    bodyBuilder.create<AtomicYieldOp>(loc, select);
+    bodyBuilder.create<memref::AtomicYieldOp>(loc, select);
rewriter.replaceOp(op, genericOp.getResult());
return success();

@@ -892,3 +892,22 @@ func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, a
%1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 64 + s0 + d1 * 64 + d2 * 8 + d3)>> into memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
return %1 : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
}
// -----
// CHECK-LABEL: func @generic_atomic_rmw
func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) {
%x = memref.generic_atomic_rmw %I[%i] : memref<10xi32> {
^bb0(%old_value : i32):
memref.atomic_yield %old_value : i32
}
// CHECK: [[init:%.*]] = llvm.load %{{.*}} : !llvm.ptr<i32>
// CHECK-NEXT: llvm.br ^bb1([[init]] : i32)
// CHECK-NEXT: ^bb1([[loaded:%.*]]: i32):
// CHECK-NEXT: [[pair:%.*]] = llvm.cmpxchg %{{.*}}, [[loaded]], [[loaded]]
// CHECK-SAME: acq_rel monotonic : i32
// CHECK-NEXT: [[new:%.*]] = llvm.extractvalue [[pair]][0]
// CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1]
// CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : i32)
llvm.return
}

@@ -486,33 +486,6 @@ func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
// -----
// CHECK-LABEL: func @generic_atomic_rmw
func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) -> i32 {
%x = generic_atomic_rmw %I[%i] : memref<10xi32> {
^bb0(%old_value : i32):
%c1 = arith.constant 1 : i32
atomic_yield %c1 : i32
}
// CHECK: [[init:%.*]] = llvm.load %{{.*}} : !llvm.ptr<i32>
// CHECK-NEXT: llvm.br ^bb1([[init]] : i32)
// CHECK-NEXT: ^bb1([[loaded:%.*]]: i32):
// CHECK-NEXT: [[c1:%.*]] = llvm.mlir.constant(1 : i32)
// CHECK-NEXT: [[pair:%.*]] = llvm.cmpxchg %{{.*}}, [[loaded]], [[c1]]
// CHECK-SAME: acq_rel monotonic : i32
// CHECK-NEXT: [[new:%.*]] = llvm.extractvalue [[pair]][0]
// CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1]
// CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : i32)
// CHECK-NEXT: ^bb2:
%c2 = arith.constant 2 : i32
%add = arith.addi %c2, %x : i32
return %add : i32
// CHECK-NEXT: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
// CHECK-NEXT: [[add:%.*]] = llvm.add [[c2]], [[new]] : i32
// CHECK-NEXT: llvm.return [[add]]
}
// -----
// CHECK-LABEL: func @ceilf(
// CHECK-SAME: f32
func @ceilf(%arg0 : f32) {

@@ -910,3 +910,63 @@ func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) {
%x = memref.atomic_rmw addi %val, %I[%i, %i] : (f32, memref<16x10xf32>) -> f32
return
}
// -----
func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) {
// expected-error@+1 {{expected single number of entry block arguments}}
%x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%arg0 : f32, %arg1 : f32):
%c1 = arith.constant 1.0 : f32
memref.atomic_yield %c1 : f32
}
return
}
// -----
func @generic_atomic_rmw_wrong_arg_type(%I: memref<10xf32>, %i : index) {
// expected-error@+1 {{expected block argument of the same type result type}}
%x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%old_value : i32):
%c1 = arith.constant 1.0 : f32
memref.atomic_yield %c1 : f32
}
return
}
// -----
func @generic_atomic_rmw_result_type_mismatch(%I: memref<10xf32>, %i : index) {
// expected-error@+1 {{failed to verify that result type matches element type of memref}}
%0 = "memref.generic_atomic_rmw"(%I, %i) ({
^bb0(%old_value: f32):
%c1 = arith.constant 1.0 : f32
memref.atomic_yield %c1 : f32
}) : (memref<10xf32>, index) -> i32
return
}
// -----
func @generic_atomic_rmw_has_side_effects(%I: memref<10xf32>, %i : index) {
// expected-error@+4 {{should contain only operations with no side effects}}
%x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%old_value : f32):
%c1 = arith.constant 1.0 : f32
%buf = memref.alloc() : memref<2048xf32>
memref.atomic_yield %c1 : f32
}
}
// -----
func @atomic_yield_type_mismatch(%I: memref<10xf32>, %i : index) {
// expected-error@+4 {{op types mismatch between yield op: 'i32' and its parent: 'f32'}}
%x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%old_value : f32):
%c1 = arith.constant 1 : i32
memref.atomic_yield %c1 : i32
}
return
}

@@ -246,3 +246,17 @@ func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
// CHECK: memref.atomic_rmw addf [[VAL]], [[BUF]]{{\[}}[[I]]]
return
}
// CHECK-LABEL: func @generic_atomic_rmw
// CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index)
func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) {
%x = memref.generic_atomic_rmw %I[%i, %j] : memref<1x2xf32> {
// CHECK-NEXT: memref.generic_atomic_rmw [[BUF]]{{\[}}[[I]], [[J]]] : memref
^bb0(%old_value : f32):
%c1 = arith.constant 1.0 : f32
%out = arith.addf %c1, %old_value : f32
memref.atomic_yield %out : f32
// CHECK: index_attr = 8 : index
} { index_attr = 8 : index }
return
}

@@ -6,11 +6,11 @@ func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
%x = memref.atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32
return %x : f32
}
-// CHECK: %0 = generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
+// CHECK: %0 = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
// CHECK: [[CMP:%.*]] = arith.cmpf ogt, [[CUR_VAL]], [[f]] : f32
// CHECK: [[SELECT:%.*]] = select [[CMP]], [[CUR_VAL]], [[f]] : f32
-// CHECK: atomic_yield [[SELECT]] : f32
+// CHECK: memref.atomic_yield [[SELECT]] : f32
// CHECK: }
// CHECK: return %0 : f32

@@ -325,20 +325,6 @@ func @unranked_tensor_load_store(%0 : memref<*xi32>, %1 : tensor<*xi32>) {
return
}
// CHECK-LABEL: func @generic_atomic_rmw
// CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index)
func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) {
%x = generic_atomic_rmw %I[%i, %j] : memref<1x2xf32> {
// CHECK-NEXT: generic_atomic_rmw [[BUF]]{{\[}}[[I]], [[J]]] : memref
^bb0(%old_value : f32):
%c1 = arith.constant 1.0 : f32
%out = arith.addf %c1, %old_value : f32
atomic_yield %out : f32
// CHECK: index_attr = 8 : index
} { index_attr = 8 : index }
return
}
// CHECK-LABEL: func @assume_alignment
// CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
func @assume_alignment(%0: memref<4x4xf16>) {

@@ -127,63 +127,3 @@ func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
// expected-error@-1 {{expects different type than prior uses}}
return
}
// -----
func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) {
// expected-error@+1 {{expected single number of entry block arguments}}
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%arg0 : f32, %arg1 : f32):
%c1 = arith.constant 1.0 : f32
atomic_yield %c1 : f32
}
return
}
// -----
func @generic_atomic_rmw_wrong_arg_type(%I: memref<10xf32>, %i : index) {
// expected-error@+1 {{expected block argument of the same type result type}}
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%old_value : i32):
%c1 = arith.constant 1.0 : f32
atomic_yield %c1 : f32
}
return
}
// -----
func @generic_atomic_rmw_result_type_mismatch(%I: memref<10xf32>, %i : index) {
// expected-error@+1 {{failed to verify that result type matches element type of memref}}
%0 = "std.generic_atomic_rmw"(%I, %i) ({
^bb0(%old_value: f32):
%c1 = arith.constant 1.0 : f32
atomic_yield %c1 : f32
}) : (memref<10xf32>, index) -> i32
return
}
// -----
func @generic_atomic_rmw_has_side_effects(%I: memref<10xf32>, %i : index) {
// expected-error@+4 {{should contain only operations with no side effects}}
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%old_value : f32):
%c1 = arith.constant 1.0 : f32
%buf = memref.alloc() : memref<2048xf32>
atomic_yield %c1 : f32
}
}
// -----
func @atomic_yield_type_mismatch(%I: memref<10xf32>, %i : index) {
// expected-error@+4 {{op types mismatch between yield op: 'i32' and its parent: 'f32'}}
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%old_value : f32):
%c1 = arith.constant 1 : i32
atomic_yield %c1 : i32
}
return
}