[mlir] Move std.generic_atomic_rmw to the memref dialect

This is part of splitting up the standard dialect. The move makes sense anyway, given that the memref dialect already holds memref.atomic_rmw, which is the non-region sibling operation of std.generic_atomic_rmw (the relationship is even clearer given that the two ops have nearly the same description, modulo how they represent the inner computation).

Differential Revision: https://reviews.llvm.org/D118209

This commit is contained in:
parent 480cd4cb85
commit 632a4f8829
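As context for the sibling relationship the message describes, here is a hedged sketch of the two forms (value names are illustrative, adapted from examples elsewhere in this diff): the non-region op selects the combining kind with an attribute, while the region form spells out the same update in a body.

```mlir
// Non-region form: the update is picked by the `addf` kind.
%x = memref.atomic_rmw addf %value, %I[%i] : (f32, memref<10xf32>) -> f32

// Region form: the equivalent update written out in the body.
%y = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%current_value : f32):
  %inc = arith.addf %value, %current_value : f32
  memref.atomic_yield %inc : f32
}
```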
@@ -698,6 +698,81 @@ def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
  let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//

def GenericAtomicRMWOp : MemRef_Op<"generic_atomic_rmw", [
      SingleBlockImplicitTerminator<"AtomicYieldOp">,
      TypesMatchWith<"result type matches element type of memref",
                     "memref", "result",
                     "$_self.cast<MemRefType>().getElementType()">
    ]> {
  let summary = "atomic read-modify-write operation with a region";
  let description = [{
    The `memref.generic_atomic_rmw` operation provides a way to perform a
    read-modify-write sequence that is free from data races. The memref operand
    represents the buffer that the read and write will be performed against, as
    accessed by the specified indices. The arity of the indices is the rank of
    the memref. The result represents the latest value that was stored. The
    region contains the code for the modification itself. The entry block has
    a single argument that represents the value stored in `memref[indices]`
    before the write is performed. No side-effecting ops are allowed in the
    body of `GenericAtomicRMWOp`.

    Example:

    ```mlir
    %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
      ^bb0(%current_value : f32):
        %c1 = arith.constant 1.0 : f32
        %inc = arith.addf %c1, %current_value : f32
        memref.atomic_yield %inc : f32
    }
    ```
  }];

  let arguments = (ins
      MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
      Variadic<Index>:$indices);

  let results = (outs
      AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);

  let regions = (region AnyRegion:$atomic_body);

  let skipDefaultBuilders = 1;
  let builders = [OpBuilder<(ins "Value":$memref, "ValueRange":$ivs)>];

  let extraClassDeclaration = [{
    // TODO: remove post migrating callers.
    Region &body() { return getRegion(); }

    // The value stored in memref[ivs].
    Value getCurrentValue() {
      return getRegion().getArgument(0);
    }
    MemRefType getMemRefType() {
      return memref().getType().cast<MemRefType>();
    }
  }];
}

def AtomicYieldOp : MemRef_Op<"atomic_yield", [
      HasParent<"GenericAtomicRMWOp">,
      NoSideEffect,
      Terminator
    ]> {
  let summary = "yield operation for GenericAtomicRMWOp";
  let description = [{
    "memref.atomic_yield" yields an SSA value from a
    GenericAtomicRMWOp region.
  }];

  let arguments = (ins AnyType:$result);
  let assemblyFormat = "$result attr-dict `:` type($result)";
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//
@@ -1687,7 +1762,7 @@ def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
    ]> {
  let summary = "atomic read-modify-write operation";
  let description = [{
-    The `atomic_rmw` operation provides a way to perform a read-modify-write
+    The `memref.atomic_rmw` operation provides a way to perform a read-modify-write
    sequence that is free from data races. The kind enumeration specifies the
    modification to perform. The value operand represents the new value to be
    applied during the modification. The memref operand represents the buffer
@@ -1698,7 +1773,7 @@ def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
    Example:

    ```mlir
-    %x = arith.atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32
+    %x = memref.atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32
    ```
  }];
@@ -178,76 +178,6 @@ def AssertOp : Std_Op<"assert"> {
  let hasCanonicalizeMethod = 1;
}

def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [
      SingleBlockImplicitTerminator<"AtomicYieldOp">,
      TypesMatchWith<"result type matches element type of memref",
                     "memref", "result",
                     "$_self.cast<MemRefType>().getElementType()">
    ]> {
  let summary = "atomic read-modify-write operation with a region";
  let description = [{
    The `generic_atomic_rmw` operation provides a way to perform a read-modify-write
    sequence that is free from data races. The memref operand represents the
    buffer that the read and write will be performed against, as accessed by
    the specified indices. The arity of the indices is the rank of the memref.
    The result represents the latest value that was stored. The region contains
    the code for the modification itself. The entry block has a single argument
    that represents the value stored in `memref[indices]` before the write is
    performed. No side-effecting ops are allowed in the body of
    `GenericAtomicRMWOp`.

    Example:

    ```mlir
    %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
      ^bb0(%current_value : f32):
        %c1 = arith.constant 1.0 : f32
        %inc = arith.addf %c1, %current_value : f32
        atomic_yield %inc : f32
    }
    ```
  }];

  let arguments = (ins
      MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
      Variadic<Index>:$indices);

  let results = (outs
      AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);

  let regions = (region AnyRegion:$atomic_body);

  let skipDefaultBuilders = 1;
  let builders = [OpBuilder<(ins "Value":$memref, "ValueRange":$ivs)>];

  let extraClassDeclaration = [{
    // TODO: remove post migrating callers.
    Region &body() { return getRegion(); }

    // The value stored in memref[ivs].
    Value getCurrentValue() {
      return getRegion().getArgument(0);
    }
    MemRefType getMemRefType() {
      return getMemref().getType().cast<MemRefType>();
    }
  }];
}

def AtomicYieldOp : Std_Op<"atomic_yield", [
      HasParent<"GenericAtomicRMWOp">,
      NoSideEffect,
      Terminator
    ]> {
  let summary = "yield operation for GenericAtomicRMWOp";
  let description = [{
    "atomic_yield" yields an SSA value from a GenericAtomicRMWOp region.
  }];

  let arguments = (ins AnyType:$result);
  let assemblyFormat = "$result attr-dict `:` type($result)";
}

//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//
@@ -412,6 +412,139 @@ private:
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   br loop(%loaded)              |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |----------         |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
                                        adaptor.indices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
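For orientation, a hedged sketch of the retry loop this pattern emits, written as commented pseudo-IR in the LLVM dialect (block and value names are illustrative and the syntax is abbreviated; the memref-to-llvm test further below carries the authoritative CHECK lines):

```mlir
// Illustrative only: cmpxchg retry loop for an i32 element.
//   %init = llvm.load %ptr                            // read memref[indices]
//   llvm.br ^loop(%init : i32)
// ^loop(%loaded: i32):
//   ... cloned generic_atomic_rmw body computes %new from %loaded ...
//   %pair = llvm.cmpxchg %ptr, %loaded, %new acq_rel monotonic : i32
//   %val  = llvm.extractvalue %pair[0]                // value actually in memory
//   %ok   = llvm.extractvalue %pair[1]                // did the exchange succeed?
//   llvm.cond_br %ok, ^end, ^loop(%val : i32)
// ^end:
//   // %val replaces the result of the original memref.generic_atomic_rmw
```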
@@ -520,21 +653,6 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  }
};

// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
@@ -1683,6 +1801,7 @@ void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
@@ -565,21 +565,6 @@ struct UnrealizedConversionCastOpLowering
  }
};

// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

// Base class for LLVM IR lowering terminator operations with successors.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
@@ -771,125 +756,6 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   br loop(%loaded)              |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |----------         |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};

} // namespace

void mlir::populateStdToLLVMFuncOpConversionPattern(
@@ -911,7 +777,6 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
      CallOpLowering,
      CondBranchOpLowering,
      ConstantOpLowering,
      GenericAtomicRMWOpLowering,
      ReturnOpLowering,
      SelectOpLowering,
      SplatOpLowering,
@@ -945,6 +945,89 @@ static LogicalResult verify(DmaWaitOp op) {
  return success();
}

//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//

void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
                               Value memref, ValueRange ivs) {
  result.addOperands(memref);
  result.addOperands(ivs);

  if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
    Type elementType = memrefType.getElementType();
    result.addTypes(elementType);

    Region *bodyRegion = result.addRegion();
    bodyRegion->push_back(new Block());
    bodyRegion->addArgument(elementType, memref.getLoc());
  }
}

static LogicalResult verify(GenericAtomicRMWOp op) {
  auto &body = op.getRegion();
  if (body.getNumArguments() != 1)
    return op.emitOpError("expected single number of entry block arguments");

  if (op.getResult().getType() != body.getArgument(0).getType())
    return op.emitOpError(
        "expected block argument of the same type result type");

  bool hasSideEffects =
      body.walk([&](Operation *nestedOp) {
            if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
              return WalkResult::advance();
            nestedOp->emitError(
                "body of 'memref.generic_atomic_rmw' should contain "
                "only operations with no side effects");
            return WalkResult::interrupt();
          })
          .wasInterrupted();
  return hasSideEffects ? failure() : success();
}

static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
                                           OperationState &result) {
  OpAsmParser::OperandType memref;
  Type memrefType;
  SmallVector<OpAsmParser::OperandType, 4> ivs;

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.parseOperand(memref) ||
      parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(memrefType) ||
      parser.resolveOperand(memref, memrefType, result.operands) ||
      parser.resolveOperands(ivs, indexType, result.operands))
    return failure();

  Region *body = result.addRegion();
  if (parser.parseRegion(*body, llvm::None, llvm::None) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  result.types.push_back(memrefType.cast<MemRefType>().getElementType());
  return success();
}

static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
  p << ' ' << op.memref() << "[" << op.indices()
    << "] : " << op.memref().getType() << ' ';
  p.printRegion(op.getRegion());
  p.printOptionalAttrDict(op->getAttrs());
}

//===----------------------------------------------------------------------===//
// AtomicYieldOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AtomicYieldOp op) {
  Type parentType = op->getParentOp()->getResultTypes().front();
  Type resultType = op.result().getType();
  if (parentType != resultType)
    return op.emitOpError() << "types mismatch between yield op: " << resultType
                            << " and its parent: " << parentType;
  return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
@@ -131,88 +131,6 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
  return failure();
}

//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//

void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
                               Value memref, ValueRange ivs) {
  result.addOperands(memref);
  result.addOperands(ivs);

  if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
    Type elementType = memrefType.getElementType();
    result.addTypes(elementType);

    Region *bodyRegion = result.addRegion();
    bodyRegion->push_back(new Block());
    bodyRegion->addArgument(elementType, memref.getLoc());
  }
}

static LogicalResult verify(GenericAtomicRMWOp op) {
  auto &body = op.getRegion();
  if (body.getNumArguments() != 1)
    return op.emitOpError("expected single number of entry block arguments");

  if (op.getResult().getType() != body.getArgument(0).getType())
    return op.emitOpError(
        "expected block argument of the same type result type");

  bool hasSideEffects =
      body.walk([&](Operation *nestedOp) {
            if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
              return WalkResult::advance();
            nestedOp->emitError("body of 'generic_atomic_rmw' should contain "
                                "only operations with no side effects");
            return WalkResult::interrupt();
          })
          .wasInterrupted();
  return hasSideEffects ? failure() : success();
}

static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
                                           OperationState &result) {
  OpAsmParser::OperandType memref;
  Type memrefType;
  SmallVector<OpAsmParser::OperandType, 4> ivs;

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.parseOperand(memref) ||
      parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(memrefType) ||
      parser.resolveOperand(memref, memrefType, result.operands) ||
      parser.resolveOperands(ivs, indexType, result.operands))
    return failure();

  Region *body = result.addRegion();
  if (parser.parseRegion(*body, llvm::None, llvm::None) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  result.types.push_back(memrefType.cast<MemRefType>().getElementType());
  return success();
}

static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
  p << ' ' << op.getMemref() << "[" << op.getIndices()
    << "] : " << op.getMemref().getType() << ' ';
  p.printRegion(op.getRegion());
  p.printOptionalAttrDict(op->getAttrs());
}

//===----------------------------------------------------------------------===//
// AtomicYieldOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AtomicYieldOp op) {
  Type parentType = op->getParentOp()->getResultTypes().front();
  Type resultType = op.getResult().getType();
  if (parentType != resultType)
    return op.emitOpError() << "types mismatch between yield op: " << resultType
                            << " and its parent: " << parentType;
  return success();
}

//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//
@@ -28,17 +28,17 @@ namespace {

/// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
/// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
-/// `generic_atomic_rmw` with the expanded code.
+/// `memref.generic_atomic_rmw` with the expanded code.
///
///   %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
///
/// will be lowered to
///
-///   %x = std.generic_atomic_rmw %F[%i] : memref<10xf32> {
+///   %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
///   ^bb0(%current: f32):
///     %cmp = arith.cmpf "ogt", %current, %fval : f32
///     %new_value = select %cmp, %current, %fval : f32
-///     atomic_yield %new_value : f32
+///     memref.atomic_yield %new_value : f32
///   }
struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
public:
@@ -59,8 +59,8 @@ public:
    }

    auto loc = op.getLoc();
-    auto genericOp =
-        rewriter.create<GenericAtomicRMWOp>(loc, op.memref(), op.indices());
+    auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
+        loc, op.memref(), op.indices());
    OpBuilder bodyBuilder =
        OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());
@@ -68,7 +68,7 @@ public:
    Value rhs = op.value();
    Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
    Value select = bodyBuilder.create<SelectOp>(loc, cmp, lhs, rhs);
-    bodyBuilder.create<AtomicYieldOp>(loc, select);
+    bodyBuilder.create<memref::AtomicYieldOp>(loc, select);

    rewriter.replaceOp(op, genericOp.getResult());
    return success();
@@ -892,3 +892,22 @@ func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, a
  %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 64 + s0 + d1 * 64 + d2 * 8 + d3)>> into memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
  return %1 : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
}

// -----

// CHECK-LABEL: func @generic_atomic_rmw
func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) {
  %x = memref.generic_atomic_rmw %I[%i] : memref<10xi32> {
    ^bb0(%old_value : i32):
      memref.atomic_yield %old_value : i32
  }
  // CHECK: [[init:%.*]] = llvm.load %{{.*}} : !llvm.ptr<i32>
  // CHECK-NEXT: llvm.br ^bb1([[init]] : i32)
  // CHECK-NEXT: ^bb1([[loaded:%.*]]: i32):
  // CHECK-NEXT: [[pair:%.*]] = llvm.cmpxchg %{{.*}}, [[loaded]], [[loaded]]
  // CHECK-SAME:                    acq_rel monotonic : i32
  // CHECK-NEXT: [[new:%.*]] = llvm.extractvalue [[pair]][0]
  // CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1]
  // CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : i32)
  llvm.return
}
@@ -486,33 +486,6 @@ func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {

// -----

// CHECK-LABEL: func @generic_atomic_rmw
func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) -> i32 {
  %x = generic_atomic_rmw %I[%i] : memref<10xi32> {
    ^bb0(%old_value : i32):
      %c1 = arith.constant 1 : i32
      atomic_yield %c1 : i32
  }
  // CHECK: [[init:%.*]] = llvm.load %{{.*}} : !llvm.ptr<i32>
  // CHECK-NEXT: llvm.br ^bb1([[init]] : i32)
  // CHECK-NEXT: ^bb1([[loaded:%.*]]: i32):
  // CHECK-NEXT: [[c1:%.*]] = llvm.mlir.constant(1 : i32)
  // CHECK-NEXT: [[pair:%.*]] = llvm.cmpxchg %{{.*}}, [[loaded]], [[c1]]
  // CHECK-SAME:                    acq_rel monotonic : i32
  // CHECK-NEXT: [[new:%.*]] = llvm.extractvalue [[pair]][0]
  // CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1]
  // CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : i32)
  // CHECK-NEXT: ^bb2:
  %c2 = arith.constant 2 : i32
  %add = arith.addi %c2, %x : i32
  return %add : i32
  // CHECK-NEXT: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
  // CHECK-NEXT: [[add:%.*]] = llvm.add [[c2]], [[new]] : i32
  // CHECK-NEXT: llvm.return [[add]]
}

// -----

// CHECK-LABEL: func @ceilf(
// CHECK-SAME: f32
func @ceilf(%arg0 : f32) {
@@ -910,3 +910,63 @@ func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) {
  %x = memref.atomic_rmw addi %val, %I[%i, %i] : (f32, memref<16x10xf32>) -> f32
  return
}

// -----

func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) {
  // expected-error@+1 {{expected single number of entry block arguments}}
  %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
    ^bb0(%arg0 : f32, %arg1 : f32):
      %c1 = arith.constant 1.0 : f32
      memref.atomic_yield %c1 : f32
  }
  return
}

// -----

func @generic_atomic_rmw_wrong_arg_type(%I: memref<10xf32>, %i : index) {
  // expected-error@+1 {{expected block argument of the same type result type}}
  %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
    ^bb0(%old_value : i32):
      %c1 = arith.constant 1.0 : f32
      memref.atomic_yield %c1 : f32
  }
  return
}

// -----

func @generic_atomic_rmw_result_type_mismatch(%I: memref<10xf32>, %i : index) {
  // expected-error@+1 {{failed to verify that result type matches element type of memref}}
  %0 = "memref.generic_atomic_rmw"(%I, %i) ({
    ^bb0(%old_value: f32):
      %c1 = arith.constant 1.0 : f32
      memref.atomic_yield %c1 : f32
  }) : (memref<10xf32>, index) -> i32
  return
}

// -----

func @generic_atomic_rmw_has_side_effects(%I: memref<10xf32>, %i : index) {
  // expected-error@+4 {{should contain only operations with no side effects}}
  %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
    ^bb0(%old_value : f32):
      %c1 = arith.constant 1.0 : f32
      %buf = memref.alloc() : memref<2048xf32>
      memref.atomic_yield %c1 : f32
  }
}

// -----

func @atomic_yield_type_mismatch(%I: memref<10xf32>, %i : index) {
  // expected-error@+4 {{op types mismatch between yield op: 'i32' and its parent: 'f32'}}
  %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
    ^bb0(%old_value : f32):
      %c1 = arith.constant 1 : i32
      memref.atomic_yield %c1 : i32
  }
  return
}
@@ -246,3 +246,17 @@ func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
  // CHECK: memref.atomic_rmw addf [[VAL]], [[BUF]]{{\[}}[[I]]]
  return
}

// CHECK-LABEL: func @generic_atomic_rmw
// CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index)
func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) {
  %x = memref.generic_atomic_rmw %I[%i, %j] : memref<1x2xf32> {
  // CHECK-NEXT: memref.generic_atomic_rmw [[BUF]]{{\[}}[[I]], [[J]]] : memref
    ^bb0(%old_value : f32):
      %c1 = arith.constant 1.0 : f32
      %out = arith.addf %c1, %old_value : f32
      memref.atomic_yield %out : f32
      // CHECK: index_attr = 8 : index
  } { index_attr = 8 : index }
  return
}
@@ -6,11 +6,11 @@ func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
  %x = memref.atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32
  return %x : f32
}
-// CHECK: %0 = generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
+// CHECK: %0 = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
// CHECK: [[CMP:%.*]] = arith.cmpf ogt, [[CUR_VAL]], [[f]] : f32
// CHECK: [[SELECT:%.*]] = select [[CMP]], [[CUR_VAL]], [[f]] : f32
-// CHECK: atomic_yield [[SELECT]] : f32
+// CHECK: memref.atomic_yield [[SELECT]] : f32
// CHECK: }
// CHECK: return %0 : f32
@@ -325,20 +325,6 @@ func @unranked_tensor_load_store(%0 : memref<*xi32>, %1 : tensor<*xi32>) {
  return
}

// CHECK-LABEL: func @generic_atomic_rmw
// CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index)
func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) {
  %x = generic_atomic_rmw %I[%i, %j] : memref<1x2xf32> {
  // CHECK-NEXT: generic_atomic_rmw [[BUF]]{{\[}}[[I]], [[J]]] : memref
    ^bb0(%old_value : f32):
      %c1 = arith.constant 1.0 : f32
      %out = arith.addf %c1, %old_value : f32
      atomic_yield %out : f32
      // CHECK: index_attr = 8 : index
  } { index_attr = 8 : index }
  return
}

// CHECK-LABEL: func @assume_alignment
// CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
func @assume_alignment(%0: memref<4x4xf16>) {
@@ -127,63 +127,3 @@ func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
  // expected-error@-1 {{expects different type than prior uses}}
  return
}

// -----

func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) {
  // expected-error@+1 {{expected single number of entry block arguments}}
  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
    ^bb0(%arg0 : f32, %arg1 : f32):
      %c1 = arith.constant 1.0 : f32
      atomic_yield %c1 : f32
  }
  return
}

// -----

func @generic_atomic_rmw_wrong_arg_type(%I: memref<10xf32>, %i : index) {
  // expected-error@+1 {{expected block argument of the same type result type}}
  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
    ^bb0(%old_value : i32):
      %c1 = arith.constant 1.0 : f32
      atomic_yield %c1 : f32
  }
  return
}

// -----

func @generic_atomic_rmw_result_type_mismatch(%I: memref<10xf32>, %i : index) {
  // expected-error@+1 {{failed to verify that result type matches element type of memref}}
  %0 = "std.generic_atomic_rmw"(%I, %i) ({
    ^bb0(%old_value: f32):
      %c1 = arith.constant 1.0 : f32
      atomic_yield %c1 : f32
  }) : (memref<10xf32>, index) -> i32
  return
}

// -----

func @generic_atomic_rmw_has_side_effects(%I: memref<10xf32>, %i : index) {
  // expected-error@+4 {{should contain only operations with no side effects}}
  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
    ^bb0(%old_value : f32):
      %c1 = arith.constant 1.0 : f32
      %buf = memref.alloc() : memref<2048xf32>
      atomic_yield %c1 : f32
  }
}

// -----

func @atomic_yield_type_mismatch(%I: memref<10xf32>, %i : index) {
  // expected-error@+4 {{op types mismatch between yield op: 'i32' and its parent: 'f32'}}
  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
    ^bb0(%old_value : f32):
      %c1 = arith.constant 1 : i32
      atomic_yield %c1 : i32
  }
  return
}