[MLIR] Move AtomicRMW into MemRef dialect and enum into Arith

Per the discussion in https://reviews.llvm.org/D116345 it makes sense
to move AtomicRMWOp out of the standard dialect. This was accentuated by the
need to add a fold op with a memref::cast. The only dialect
that would permit this is the memref dialect (keeping it in the standard dialect
or moving it to the arithmetic dialect would require those dialects to have a
dependency on the memref dialect, which breaks linking).

As the AtomicRMWKind enum is used throughout, this has been moved to Arith.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D116392
This commit is contained in:
William S. Moses 2021-12-30 00:59:58 -05:00
parent 9d37d0ea34
commit a6a583dae4
29 changed files with 429 additions and 428 deletions

View File

@ -15,6 +15,7 @@
#ifndef MLIR_ANALYSIS_AFFINE_ANALYSIS_H
#define MLIR_ANALYSIS_AFFINE_ANALYSIS_H
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/Optional.h"
@ -32,7 +33,7 @@ class Operation;
/// A description of a (parallelizable) reduction in an affine loop.
struct LoopReduction {
/// Reduction kind.
AtomicRMWKind kind;
arith::AtomicRMWKind kind;
/// Position of the iteration argument that acts as accumulator.
unsigned iterArgPosition;

View File

@ -13,7 +13,7 @@
#ifndef AFFINE_OPS
#define AFFINE_OPS
include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
@ -691,9 +691,9 @@ def AffineParallelOp : Affine_Op<"parallel",
let builders = [
OpBuilder<(ins "TypeRange":$resultTypes,
"ArrayRef<AtomicRMWKind>":$reductions, "ArrayRef<int64_t>":$ranges)>,
"ArrayRef<arith::AtomicRMWKind>":$reductions, "ArrayRef<int64_t>":$ranges)>,
OpBuilder<(ins "TypeRange":$resultTypes,
"ArrayRef<AtomicRMWKind>":$reductions, "ArrayRef<AffineMap>":$lbMaps,
"ArrayRef<arith::AtomicRMWKind>":$reductions, "ArrayRef<AffineMap>":$lbMaps,
"ValueRange":$lbArgs, "ArrayRef<AffineMap>":$ubMaps, "ValueRange":$ubArgs,
"ArrayRef<int64_t>":$steps)>
];

View File

@ -109,6 +109,18 @@ bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs,
bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs,
const APFloat &rhs);
/// Returns the identity value attribute associated with an AtomicRMWKind op.
Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
OpBuilder &builder, Location loc);
/// Returns the identity value associated with an AtomicRMWKind op.
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder,
Location loc);
/// Returns the value obtained by applying the reduction operation kind
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
Value lhs, Value rhs);
} // namespace arith
} // namespace mlir

View File

@ -68,4 +68,28 @@ def Arith_CmpIPredicateAttr : I64EnumAttr<
let cppNamespace = "::mlir::arith";
}
def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
def ATOMIC_RMW_KIND_MAXF : I64EnumAttrCase<"maxf", 3>;
def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>;
def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>;
def ATOMIC_RMW_KIND_MINF : I64EnumAttrCase<"minf", 6>;
def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>;
def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>;
def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
def AtomicRMWKindAttr : I64EnumAttr<
"AtomicRMWKind", "",
[ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
ATOMIC_RMW_KIND_ANDI]> {
let cppNamespace = "::mlir::arith";
}
#endif // ARITHMETIC_BASE

View File

@ -11,6 +11,7 @@
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/CopyOpInterface.td"
@ -1673,4 +1674,51 @@ def MemRef_ViewOp : MemRef_Op<"view", [
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//
def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
AllTypesMatch<["value", "result"]>,
TypesMatchWith<"value type matches element type of memref",
"memref", "value",
"$_self.cast<MemRefType>().getElementType()">
]> {
let summary = "atomic read-modify-write operation";
let description = [{
The `atomic_rmw` operation provides a way to perform a read-modify-write
sequence that is free from data races. The kind enumeration specifies the
modification to perform. The value operand represents the new value to be
applied during the modification. The memref operand represents the buffer
that the read and write will be performed against, as accessed by the
specified indices. The arity of the indices is the rank of the memref. The
result represents the latest value that was stored.
Example:
```mlir
%x = arith.atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32
```
}];
let arguments = (ins
AtomicRMWKindAttr:$kind,
AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$value,
MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
Variadic<Index>:$indices);
let results = (outs AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);
let assemblyFormat = [{
$kind $value `,` $memref `[` $indices `]` attr-dict `:` `(` type($value) `,`
type($memref) `)` `->` type($result)
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return memref().getType().cast<MemRefType>();
}
}];
let hasFolder = 1;
}
#endif // MEMREF_OPS

View File

@ -42,31 +42,4 @@ class PatternRewriter;
#include "mlir/Dialect/StandardOps/IR/OpsDialect.h.inc"
namespace mlir {
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs,
const APInt &rhs);
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
/// comparison predicates.
bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs,
const APFloat &rhs);
/// Returns the identity value attribute associated with an AtomicRMWKind op.
Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
OpBuilder &builder, Location loc);
/// Returns the identity value associated with an AtomicRMWKind op.
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder,
Location loc);
/// Returns the value obtained by applying the reduction operation kind
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
Value lhs, Value rhs);
} // namespace mlir
#endif // MLIR_DIALECT_IR_STANDARDOPS_IR_OPS_H

View File

@ -13,7 +13,6 @@
#ifndef STANDARD_OPS
#define STANDARD_OPS
include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
@ -179,52 +178,6 @@ def AssertOp : Std_Op<"assert"> {
let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//
def AtomicRMWOp : Std_Op<"atomic_rmw", [
AllTypesMatch<["value", "result"]>,
TypesMatchWith<"value type matches element type of memref",
"memref", "value",
"$_self.cast<MemRefType>().getElementType()">
]> {
let summary = "atomic read-modify-write operation";
let description = [{
The `atomic_rmw` operation provides a way to perform a read-modify-write
sequence that is free from data races. The kind enumeration specifies the
modification to perform. The value operand represents the new value to be
applied during the modification. The memref operand represents the buffer
that the read and write will be performed against, as accessed by the
specified indices. The arity of the indices is the rank of the memref. The
result represents the latest value that was stored.
Example:
```mlir
%x = atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32
```
}];
let arguments = (ins
AtomicRMWKindAttr:$kind,
AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$value,
MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
Variadic<Index>:$indices);
let results = (outs AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);
let assemblyFormat = [{
$kind $value `,` $memref `[` $indices `]` attr-dict `:` `(` type($value) `,`
type($memref) `)` `->` type($result)
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return getMemref().getType().cast<MemRefType>();
}
}];
}
def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [
SingleBlockImplicitTerminator<"AtomicYieldOp">,
TypesMatchWith<"result type matches element type of memref",

View File

@ -1,42 +0,0 @@
//===- StandardOpsBase.td - Standard ops definitions -------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines base support for standard operations.
//
//===----------------------------------------------------------------------===//
#ifndef STANDARD_OPS_BASE
#define STANDARD_OPS_BASE
include "mlir/IR/OpBase.td"
def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
def ATOMIC_RMW_KIND_MAXF : I64EnumAttrCase<"maxf", 3>;
def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>;
def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>;
def ATOMIC_RMW_KIND_MINF : I64EnumAttrCase<"minf", 6>;
def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>;
def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>;
def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
def AtomicRMWKindAttr : I64EnumAttr<
"AtomicRMWKind", "",
[ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
ATOMIC_RMW_KIND_ANDI]> {
let cppNamespace = "::mlir";
}
#endif // STANDARD_OPS_BASE

View File

@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_VECTOR_VECTOROPS_H
#define MLIR_DIALECT_VECTOR_VECTOROPS_H
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
@ -145,8 +146,8 @@ ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values);
/// Returns the value obtained by reducing the vector into a scalar using the
/// operation kind associated with a binary AtomicRMWKind op.
Value getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
Value vector);
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder,
Location loc, Value vector);
/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.

View File

@ -40,7 +40,7 @@ using llvm::dbgs;
/// reduction kind suitable for use in affine parallel loop builder. If the
/// reduction is not supported, returns null.
static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
AtomicRMWKind &kind) {
arith::AtomicRMWKind &kind) {
SmallVector<Operation *> combinerOps;
Value reducedVal =
matchReduction(forOp.getRegionIterArgs(), pos, combinerOps);
@ -52,21 +52,21 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
return nullptr;
Operation *combinerOp = combinerOps.back();
Optional<AtomicRMWKind> maybeKind =
TypeSwitch<Operation *, Optional<AtomicRMWKind>>(combinerOp)
.Case([](arith::AddFOp) { return AtomicRMWKind::addf; })
.Case([](arith::MulFOp) { return AtomicRMWKind::mulf; })
.Case([](arith::AddIOp) { return AtomicRMWKind::addi; })
.Case([](arith::AndIOp) { return AtomicRMWKind::andi; })
.Case([](arith::OrIOp) { return AtomicRMWKind::ori; })
.Case([](arith::MulIOp) { return AtomicRMWKind::muli; })
.Case([](arith::MinFOp) { return AtomicRMWKind::minf; })
.Case([](arith::MaxFOp) { return AtomicRMWKind::maxf; })
.Case([](arith::MinSIOp) { return AtomicRMWKind::mins; })
.Case([](arith::MaxSIOp) { return AtomicRMWKind::maxs; })
.Case([](arith::MinUIOp) { return AtomicRMWKind::minu; })
.Case([](arith::MaxUIOp) { return AtomicRMWKind::maxu; })
.Default([](Operation *) -> Optional<AtomicRMWKind> {
Optional<arith::AtomicRMWKind> maybeKind =
TypeSwitch<Operation *, Optional<arith::AtomicRMWKind>>(combinerOp)
.Case([](arith::AddFOp) { return arith::AtomicRMWKind::addf; })
.Case([](arith::MulFOp) { return arith::AtomicRMWKind::mulf; })
.Case([](arith::AddIOp) { return arith::AtomicRMWKind::addi; })
.Case([](arith::AndIOp) { return arith::AtomicRMWKind::andi; })
.Case([](arith::OrIOp) { return arith::AtomicRMWKind::ori; })
.Case([](arith::MulIOp) { return arith::AtomicRMWKind::muli; })
.Case([](arith::MinFOp) { return arith::AtomicRMWKind::minf; })
.Case([](arith::MaxFOp) { return arith::AtomicRMWKind::maxf; })
.Case([](arith::MinSIOp) { return arith::AtomicRMWKind::mins; })
.Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
.Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
.Case([](arith::MaxUIOp) { return arith::AtomicRMWKind::maxu; })
.Default([](Operation *) -> Optional<arith::AtomicRMWKind> {
// TODO: AtomicRMW supports other kinds of reductions this is
// currently not detecting, add those when the need arises.
return llvm::None;
@ -86,7 +86,7 @@ void mlir::getSupportedReductions(
return;
supportedReductions.reserve(numIterArgs);
for (unsigned i = 0; i < numIterArgs; ++i) {
AtomicRMWKind kind;
arith::AtomicRMWKind kind;
if (Value value = getSupportedReduction(forOp, i, kind))
supportedReductions.emplace_back(LoopReduction{kind, i, value});
}

View File

@ -430,13 +430,14 @@ public:
// initialization of the result values.
Attribute reduction = std::get<0>(pair);
Type resultType = std::get<1>(pair);
Optional<AtomicRMWKind> reductionOp = symbolizeAtomicRMWKind(
static_cast<uint64_t>(reduction.cast<IntegerAttr>().getInt()));
Optional<arith::AtomicRMWKind> reductionOp =
arith::symbolizeAtomicRMWKind(
static_cast<uint64_t>(reduction.cast<IntegerAttr>().getInt()));
assert(reductionOp.hasValue() &&
"Reduction operation cannot be of None Type");
AtomicRMWKind reductionOpValue = reductionOp.getValue();
arith::AtomicRMWKind reductionOpValue = reductionOp.getValue();
identityVals.push_back(
getIdentityValue(reductionOpValue, resultType, rewriter, loc));
arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
}
parOp = rewriter.create<scf::ParallelOp>(
loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
@ -450,16 +451,17 @@ public:
"Unequal number of reductions and operands.");
for (unsigned i = 0, end = reductions.size(); i < end; i++) {
// For each of the reduction operations get the respective mlir::Value.
Optional<AtomicRMWKind> reductionOp =
symbolizeAtomicRMWKind(reductions[i].cast<IntegerAttr>().getInt());
Optional<arith::AtomicRMWKind> reductionOp =
arith::symbolizeAtomicRMWKind(
reductions[i].cast<IntegerAttr>().getInt());
assert(reductionOp.hasValue() &&
"Reduction Operation cannot be of None Type");
AtomicRMWKind reductionOpValue = reductionOp.getValue();
arith::AtomicRMWKind reductionOpValue = reductionOp.getValue();
rewriter.setInsertionPoint(&parOp.getBody()->back());
auto reduceOp = rewriter.create<scf::ReduceOp>(
loc, affineParOpTerminator->getOperand(i));
rewriter.setInsertionPointToEnd(&reduceOp.getReductionOperator().front());
Value reductionResult = getReductionOp(
Value reductionResult = arith::getReductionOp(
reductionOpValue, rewriter, loc,
reduceOp.getReductionOperator().front().getArgument(0),
reduceOp.getReductionOperator().front().getArgument(1));

View File

@ -1553,6 +1553,62 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
}
};
//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//
/// Try to match the kind of a std.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
static Optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
switch (atomicOp.kind()) {
case arith::AtomicRMWKind::addf:
return LLVM::AtomicBinOp::fadd;
case arith::AtomicRMWKind::addi:
return LLVM::AtomicBinOp::add;
case arith::AtomicRMWKind::assign:
return LLVM::AtomicBinOp::xchg;
case arith::AtomicRMWKind::maxs:
return LLVM::AtomicBinOp::max;
case arith::AtomicRMWKind::maxu:
return LLVM::AtomicBinOp::umax;
case arith::AtomicRMWKind::mins:
return LLVM::AtomicBinOp::min;
case arith::AtomicRMWKind::minu:
return LLVM::AtomicBinOp::umin;
case arith::AtomicRMWKind::ori:
return LLVM::AtomicBinOp::_or;
case arith::AtomicRMWKind::andi:
return LLVM::AtomicBinOp::_and;
default:
return llvm::None;
}
llvm_unreachable("Invalid AtomicRMWKind");
}
struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
using Base::Base;
LogicalResult
matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(match(atomicOp)))
return failure();
auto maybeKind = matchSimpleAtomicOp(atomicOp);
if (!maybeKind)
return failure();
auto resultType = adaptor.value().getType();
auto memRefType = atomicOp.getMemRefType();
auto dataPtr =
getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
adaptor.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
LLVM::AtomicOrdering::acq_rel);
return success();
}
};
} // namespace
void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
@ -1561,6 +1617,7 @@ void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
patterns.add<
AllocaOpLowering,
AllocaScopeOpLowering,
AtomicRMWOpLowering,
AssumeAlignmentOpLowering,
DimOpLowering,
GlobalMemrefOpLowering,

View File

@ -772,61 +772,6 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
}
};
} // namespace
/// Try to match the kind of a std.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
static Optional<LLVM::AtomicBinOp> matchSimpleAtomicOp(AtomicRMWOp atomicOp) {
switch (atomicOp.getKind()) {
case AtomicRMWKind::addf:
return LLVM::AtomicBinOp::fadd;
case AtomicRMWKind::addi:
return LLVM::AtomicBinOp::add;
case AtomicRMWKind::assign:
return LLVM::AtomicBinOp::xchg;
case AtomicRMWKind::maxs:
return LLVM::AtomicBinOp::max;
case AtomicRMWKind::maxu:
return LLVM::AtomicBinOp::umax;
case AtomicRMWKind::mins:
return LLVM::AtomicBinOp::min;
case AtomicRMWKind::minu:
return LLVM::AtomicBinOp::umin;
case AtomicRMWKind::ori:
return LLVM::AtomicBinOp::_or;
case AtomicRMWKind::andi:
return LLVM::AtomicBinOp::_and;
default:
return llvm::None;
}
llvm_unreachable("Invalid AtomicRMWKind");
}
namespace {
struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
using Base::Base;
LogicalResult
matchAndRewrite(AtomicRMWOp atomicOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(match(atomicOp)))
return failure();
auto maybeKind = matchSimpleAtomicOp(atomicOp);
if (!maybeKind)
return failure();
auto resultType = adaptor.getValue().getType();
auto memRefType = atomicOp.getMemRefType();
auto dataPtr =
getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
adaptor.getIndices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(),
LLVM::AtomicOrdering::acq_rel);
return success();
}
};
/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
@ -962,7 +907,6 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
// clang-format off
patterns.add<
AssertOpLowering,
AtomicRMWOpLowering,
BranchOpLowering,
CallIndirectOpLowering,
CallOpLowering,

View File

@ -2801,7 +2801,7 @@ LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes,
ArrayRef<AtomicRMWKind> reductions,
ArrayRef<arith::AtomicRMWKind> reductions,
ArrayRef<int64_t> ranges) {
SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
@ -2814,7 +2814,7 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes,
ArrayRef<AtomicRMWKind> reductions,
ArrayRef<arith::AtomicRMWKind> reductions,
ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
ArrayRef<int64_t> steps) {
@ -2843,7 +2843,7 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
// Convert the reductions to integer attributes.
SmallVector<Attribute, 4> reductionAttrs;
for (AtomicRMWKind reduction : reductions)
for (arith::AtomicRMWKind reduction : reductions)
reductionAttrs.push_back(
builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
result.addAttribute(getReductionsAttrName(),
@ -3050,7 +3050,7 @@ static LogicalResult verify(AffineParallelOp op) {
// Verify reduction ops are all valid
for (Attribute attr : op.reductions()) {
auto intAttr = attr.dyn_cast<IntegerAttr>();
if (!intAttr || !symbolizeAtomicRMWKind(intAttr.getInt()))
if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
return op.emitOpError("invalid reduction attribute");
}
@ -3150,9 +3150,9 @@ static void print(OpAsmPrinter &p, AffineParallelOp op) {
if (op.getNumResults()) {
p << " reduce (";
llvm::interleaveComma(op.reductions(), p, [&](auto &attr) {
AtomicRMWKind sym =
*symbolizeAtomicRMWKind(attr.template cast<IntegerAttr>().getInt());
p << "\"" << stringifyAtomicRMWKind(sym) << "\"";
arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
attr.template cast<IntegerAttr>().getInt());
p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
});
p << ") -> (" << op.getResultTypes() << ")";
}
@ -3374,8 +3374,8 @@ static ParseResult parseAffineParallelOp(OpAsmParser &parser,
if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
attrStorage))
return failure();
llvm::Optional<AtomicRMWKind> reduction =
symbolizeAtomicRMWKind(attrVal.getValue());
llvm::Optional<arith::AtomicRMWKind> reduction =
arith::symbolizeAtomicRMWKind(attrVal.getValue());
if (!reduction)
return parser.emitError(loc, "invalid reduction value: ") << attrVal;
reductions.push_back(builder.getI64IntegerAttr(

View File

@ -971,7 +971,7 @@ static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp,
/// Creates a constant vector filled with the neutral elements of the given
/// reduction. The scalar type of vector elements will be taken from
/// `oldOperand`.
static arith::ConstantOp createInitialVector(AtomicRMWKind reductionKind,
static arith::ConstantOp createInitialVector(arith::AtomicRMWKind reductionKind,
Value oldOperand,
VectorizationState &state) {
Type scalarTy = oldOperand.getType();
@ -1245,8 +1245,8 @@ static Operation *vectorizeAffineStore(AffineStoreOp storeOp,
/// Returns true if `value` is a constant equal to the neutral element of the
/// given vectorizable reduction.
static bool isNeutralElementConst(AtomicRMWKind reductionKind, Value value,
VectorizationState &state) {
static bool isNeutralElementConst(arith::AtomicRMWKind reductionKind,
Value value, VectorizationState &state) {
Type scalarTy = value.getType();
if (!VectorType::isValidElementType(scalarTy))
return false;
@ -1361,7 +1361,8 @@ static Operation *vectorizeAffineForOp(AffineForOp forOp,
Value origInit = forOp.getOperand(forOp.getNumControlOperands() + i);
Value finalRes = reducedRes;
if (!isNeutralElementConst(reductions[i].kind, origInit, state))
finalRes = getReductionOp(reductions[i].kind, state.builder,
finalRes =
arith::getReductionOp(reductions[i].kind, state.builder,
reducedRes.getLoc(), reducedRes, origInit);
state.registerLoopResultScalarReplacement(forOp.getResult(i), finalRes);
}

View File

@ -8,6 +8,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
@ -1208,6 +1209,101 @@ OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
return BoolAttr::get(getContext(), val);
}
//===----------------------------------------------------------------------===//
// Atomic Enum
//===----------------------------------------------------------------------===//
/// Returns the identity value attribute associated with an AtomicRMWKind op.
Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
OpBuilder &builder, Location loc) {
switch (kind) {
case AtomicRMWKind::maxf:
return builder.getFloatAttr(
resultType,
APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
/*Negative=*/true));
case AtomicRMWKind::addf:
case AtomicRMWKind::addi:
case AtomicRMWKind::maxu:
case AtomicRMWKind::ori:
return builder.getZeroAttr(resultType);
case AtomicRMWKind::andi:
return builder.getIntegerAttr(
resultType,
APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
case AtomicRMWKind::maxs:
return builder.getIntegerAttr(
resultType,
APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
case AtomicRMWKind::minf:
return builder.getFloatAttr(
resultType,
APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
/*Negative=*/false));
case AtomicRMWKind::mins:
return builder.getIntegerAttr(
resultType,
APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
case AtomicRMWKind::minu:
return builder.getIntegerAttr(
resultType,
APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
case AtomicRMWKind::muli:
return builder.getIntegerAttr(resultType, 1);
case AtomicRMWKind::mulf:
return builder.getFloatAttr(resultType, 1);
// TODO: Add remaining reduction operations.
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");
break;
}
return nullptr;
}
/// Returns the identity value associated with an AtomicRMWKind op.
Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
OpBuilder &builder, Location loc) {
Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
return builder.create<arith::ConstantOp>(loc, attr);
}
/// Return the value obtained by applying the reduction operation kind
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
Location loc, Value lhs, Value rhs) {
switch (op) {
case AtomicRMWKind::addf:
return builder.create<arith::AddFOp>(loc, lhs, rhs);
case AtomicRMWKind::addi:
return builder.create<arith::AddIOp>(loc, lhs, rhs);
case AtomicRMWKind::mulf:
return builder.create<arith::MulFOp>(loc, lhs, rhs);
case AtomicRMWKind::muli:
return builder.create<arith::MulIOp>(loc, lhs, rhs);
case AtomicRMWKind::maxf:
return builder.create<arith::MaxFOp>(loc, lhs, rhs);
case AtomicRMWKind::minf:
return builder.create<arith::MinFOp>(loc, lhs, rhs);
case AtomicRMWKind::maxs:
return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
case AtomicRMWKind::mins:
return builder.create<arith::MinSIOp>(loc, lhs, rhs);
case AtomicRMWKind::maxu:
return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
case AtomicRMWKind::minu:
return builder.create<arith::MinUIOp>(loc, lhs, rhs);
case AtomicRMWKind::ori:
return builder.create<arith::OrIOp>(loc, lhs, rhs);
case AtomicRMWKind::andi:
return builder.create<arith::AndIOp>(loc, lhs, rhs);
// TODO: Add remaining reduction operations.
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");
break;
}
return nullptr;
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

View File

@ -2286,6 +2286,50 @@ void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}
//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(AtomicRMWOp op) {
if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
return op.emitOpError(
"expects the number of subscripts to be equal to memref rank");
switch (op.kind()) {
case arith::AtomicRMWKind::addf:
case arith::AtomicRMWKind::maxf:
case arith::AtomicRMWKind::minf:
case arith::AtomicRMWKind::mulf:
if (!op.value().getType().isa<FloatType>())
return op.emitOpError()
<< "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
<< "' expects a floating-point type";
break;
case arith::AtomicRMWKind::addi:
case arith::AtomicRMWKind::maxs:
case arith::AtomicRMWKind::maxu:
case arith::AtomicRMWKind::mins:
case arith::AtomicRMWKind::minu:
case arith::AtomicRMWKind::muli:
case arith::AtomicRMWKind::ori:
case arith::AtomicRMWKind::andi:
if (!op.value().getType().isa<IntegerType>())
return op.emitOpError()
<< "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
<< "' expects an integer type";
break;
default:
break;
}
return success();
}
OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
/// atomicrmw(memrefcast) -> atomicrmw
if (succeeded(foldMemRefCast(*this, value())))
return getResult();
return OpFoldResult();
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

View File

@ -131,134 +131,6 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
return failure();
}
//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(AtomicRMWOp op) {
if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
return op.emitOpError(
"expects the number of subscripts to be equal to memref rank");
switch (op.getKind()) {
case AtomicRMWKind::addf:
case AtomicRMWKind::maxf:
case AtomicRMWKind::minf:
case AtomicRMWKind::mulf:
if (!op.getValue().getType().isa<FloatType>())
return op.emitOpError()
<< "with kind '" << stringifyAtomicRMWKind(op.getKind())
<< "' expects a floating-point type";
break;
case AtomicRMWKind::addi:
case AtomicRMWKind::maxs:
case AtomicRMWKind::maxu:
case AtomicRMWKind::mins:
case AtomicRMWKind::minu:
case AtomicRMWKind::muli:
case AtomicRMWKind::ori:
case AtomicRMWKind::andi:
if (!op.getValue().getType().isa<IntegerType>())
return op.emitOpError()
<< "with kind '" << stringifyAtomicRMWKind(op.getKind())
<< "' expects an integer type";
break;
default:
break;
}
return success();
}
/// Returns the identity value attribute associated with an AtomicRMWKind op.
Attribute mlir::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
OpBuilder &builder, Location loc) {
switch (kind) {
case AtomicRMWKind::maxf:
return builder.getFloatAttr(
resultType,
APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
/*Negative=*/true));
case AtomicRMWKind::addf:
case AtomicRMWKind::addi:
case AtomicRMWKind::maxu:
case AtomicRMWKind::ori:
return builder.getZeroAttr(resultType);
case AtomicRMWKind::andi:
return builder.getIntegerAttr(
resultType,
APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
case AtomicRMWKind::maxs:
return builder.getIntegerAttr(
resultType,
APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
case AtomicRMWKind::minf:
return builder.getFloatAttr(
resultType,
APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
/*Negative=*/false));
case AtomicRMWKind::mins:
return builder.getIntegerAttr(
resultType,
APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
case AtomicRMWKind::minu:
return builder.getIntegerAttr(
resultType,
APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
case AtomicRMWKind::muli:
return builder.getIntegerAttr(resultType, 1);
case AtomicRMWKind::mulf:
return builder.getFloatAttr(resultType, 1);
// TODO: Add remaining reduction operations.
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");
break;
}
return nullptr;
}
/// Returns the identity value associated with an AtomicRMWKind op.
Value mlir::getIdentityValue(AtomicRMWKind op, Type resultType,
OpBuilder &builder, Location loc) {
Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
return builder.create<arith::ConstantOp>(loc, attr);
}
/// Return the value obtained by applying the reduction operation kind
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
Value mlir::getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
Value lhs, Value rhs) {
switch (op) {
case AtomicRMWKind::addf:
return builder.create<arith::AddFOp>(loc, lhs, rhs);
case AtomicRMWKind::addi:
return builder.create<arith::AddIOp>(loc, lhs, rhs);
case AtomicRMWKind::mulf:
return builder.create<arith::MulFOp>(loc, lhs, rhs);
case AtomicRMWKind::muli:
return builder.create<arith::MulIOp>(loc, lhs, rhs);
case AtomicRMWKind::maxf:
return builder.create<arith::MaxFOp>(loc, lhs, rhs);
case AtomicRMWKind::minf:
return builder.create<arith::MinFOp>(loc, lhs, rhs);
case AtomicRMWKind::maxs:
return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
case AtomicRMWKind::mins:
return builder.create<arith::MinSIOp>(loc, lhs, rhs);
case AtomicRMWKind::maxu:
return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
case AtomicRMWKind::minu:
return builder.create<arith::MinUIOp>(loc, lhs, rhs);
case AtomicRMWKind::ori:
return builder.create<arith::OrIOp>(loc, lhs, rhs);
case AtomicRMWKind::andi:
return builder.create<arith::AndIOp>(loc, lhs, rhs);
// TODO: Add remaining reduction operations.
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");
break;
}
return nullptr;
}
//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//

View File

@ -40,18 +40,18 @@ namespace {
/// %new_value = select %cmp, %current, %fval : f32
/// atomic_yield %new_value : f32
/// }
struct AtomicRMWOpConverter : public OpRewritePattern<AtomicRMWOp> {
struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtomicRMWOp op,
LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
PatternRewriter &rewriter) const final {
arith::CmpFPredicate predicate;
switch (op.getKind()) {
case AtomicRMWKind::maxf:
switch (op.kind()) {
case arith::AtomicRMWKind::maxf:
predicate = arith::CmpFPredicate::OGT;
break;
case AtomicRMWKind::minf:
case arith::AtomicRMWKind::minf:
predicate = arith::CmpFPredicate::OLT;
break;
default:
@ -59,13 +59,13 @@ public:
}
auto loc = op.getLoc();
auto genericOp = rewriter.create<GenericAtomicRMWOp>(loc, op.getMemref(),
op.getIndices());
auto genericOp =
rewriter.create<GenericAtomicRMWOp>(loc, op.memref(), op.indices());
OpBuilder bodyBuilder =
OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());
Value lhs = genericOp.getCurrentValue();
Value rhs = op.getValue();
Value rhs = op.value();
Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
Value select = bodyBuilder.create<SelectOp>(loc, cmp, lhs, rhs);
bodyBuilder.create<AtomicYieldOp>(loc, select);
@ -130,10 +130,11 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
StandardOpsDialect>();
target.addDynamicallyLegalOp<AtomicRMWOp>([](AtomicRMWOp op) {
return op.getKind() != AtomicRMWKind::maxf &&
op.getKind() != AtomicRMWKind::minf;
});
target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
[](memref::AtomicRMWOp op) {
return op.kind() != arith::AtomicRMWKind::maxf &&
op.kind() != arith::AtomicRMWKind::minf;
});
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
return !op.shape().getType().cast<MemRefType>().hasStaticShape();
});

View File

@ -359,41 +359,42 @@ static void print(OpAsmPrinter &p, ReductionOp op) {
p << " : " << op.vector().getType() << " into " << op.dest().getType();
}
Value mlir::vector::getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder,
Location loc, Value vector) {
Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
OpBuilder &builder, Location loc,
Value vector) {
Type scalarType = vector.getType().cast<ShapedType>().getElementType();
switch (op) {
case AtomicRMWKind::addf:
case AtomicRMWKind::addi:
case arith::AtomicRMWKind::addf:
case arith::AtomicRMWKind::addi:
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
builder.getStringAttr("add"),
vector, ValueRange{});
case AtomicRMWKind::mulf:
case AtomicRMWKind::muli:
case arith::AtomicRMWKind::mulf:
case arith::AtomicRMWKind::muli:
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
builder.getStringAttr("mul"),
vector, ValueRange{});
case AtomicRMWKind::minf:
case arith::AtomicRMWKind::minf:
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
builder.getStringAttr("minf"),
vector, ValueRange{});
case AtomicRMWKind::mins:
case arith::AtomicRMWKind::mins:
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
builder.getStringAttr("minsi"),
vector, ValueRange{});
case AtomicRMWKind::minu:
case arith::AtomicRMWKind::minu:
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
builder.getStringAttr("minui"),
vector, ValueRange{});
case AtomicRMWKind::maxf:
case arith::AtomicRMWKind::maxf:
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
builder.getStringAttr("maxf"),
vector, ValueRange{});
case AtomicRMWKind::maxs:
case arith::AtomicRMWKind::maxs:
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
builder.getStringAttr("maxsi"),
vector, ValueRange{});
case AtomicRMWKind::maxu:
case arith::AtomicRMWKind::maxu:
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
builder.getStringAttr("maxui"),
vector, ValueRange{});

View File

@ -1551,7 +1551,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
rhs = forOp.getResult(i * oldNumResults + pos);
// Create ops based on reduction type.
lhs = getReductionOp(reduction.kind, builder, loc, lhs, rhs);
lhs = arith::getReductionOp(reduction.kind, builder, loc, lhs, rhs);
if (!lhs)
return failure();
Operation *op = lhs.getDefiningOp();

View File

@ -859,3 +859,28 @@ func @rank_of_ranked(%ranked: memref<?xi32>) {
}
// CHECK: llvm.mlir.constant(1 : index) : i64
// CHECK32: llvm.mlir.constant(1 : index) : i32
// -----
// CHECK-LABEL: func @atomic_rmw
func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval : f32, %i : index) {
memref.atomic_rmw assign %fval, %F[%i] : (f32, memref<10xf32>) -> f32
// CHECK: llvm.atomicrmw xchg %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw addi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw add %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw maxs %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw max %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw mins %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw min %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw maxu %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw umax %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw minu %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw addf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw ori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
return
}

View File

@ -486,31 +486,6 @@ func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
// -----
// CHECK-LABEL: func @atomic_rmw
func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval : f32, %i : index) {
atomic_rmw assign %fval, %F[%i] : (f32, memref<10xf32>) -> f32
// CHECK: llvm.atomicrmw xchg %{{.*}}, %{{.*}} acq_rel
atomic_rmw addi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw add %{{.*}}, %{{.*}} acq_rel
atomic_rmw maxs %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw max %{{.*}}, %{{.*}} acq_rel
atomic_rmw mins %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw min %{{.*}}, %{{.*}} acq_rel
atomic_rmw maxu %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw umax %{{.*}}, %{{.*}} acq_rel
atomic_rmw minu %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel
atomic_rmw addf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel
atomic_rmw ori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
return
}
// -----
// CHECK-LABEL: func @generic_atomic_rmw
func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) -> i32 {
%x = generic_atomic_rmw %I[%i] : memref<10xi32> {

View File

@ -499,3 +499,14 @@ func @no_fold_dynamic_no_op_subview(%arg0 : memref<?x?xf32>) -> memref<?x?xf32,
// CHECK-LABEL: func @no_fold_dynamic_no_op_subview(
// CHECK: %[[SUBVIEW:.+]] = memref.subview
// CHECK: return %[[SUBVIEW]]
// -----
func @atomicrmw_cast_fold(%arg0 : f32, %arg1 : memref<4xf32>, %c : index) {
%v = memref.cast %arg1 : memref<4xf32> to memref<?xf32>
%a = memref.atomic_rmw addf %arg0, %v[%c] : (f32, memref<?xf32>) -> f32
return
}
// CHECK-LABEL: func @atomicrmw_cast_fold
// CHECK-NEXT: memref.atomic_rmw addf %arg0, %arg1[%arg2] : (f32, memref<4xf32>) -> f32

View File

@ -848,3 +848,27 @@ func @illegal_num_offsets(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : inde
// expected-error@+1 {{expected 3 offset values}}
%0 = memref.subview %arg0[0, 0] [%arg1, %arg2] [1, 1] : memref<?x?x?xf32> to memref<?x?x?xf32, #map>
}
// -----
func @atomic_rmw_idxs_rank_mismatch(%I: memref<16x10xf32>, %i : index, %val : f32) {
// expected-error@+1 {{expects the number of subscripts to be equal to memref rank}}
%x = memref.atomic_rmw addf %val, %I[%i] : (f32, memref<16x10xf32>) -> f32
return
}
// -----
func @atomic_rmw_expects_float(%I: memref<16x10xi32>, %i : index, %val : i32) {
// expected-error@+1 {{expects a floating-point type}}
%x = memref.atomic_rmw addf %val, %I[%i, %i] : (i32, memref<16x10xi32>) -> i32
return
}
// -----
func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) {
// expected-error@+1 {{expects an integer type}}
%x = memref.atomic_rmw addi %val, %I[%i, %i] : (f32, memref<16x10xf32>) -> f32
return
}

View File

@ -227,3 +227,13 @@ func @rank(%t : memref<4x4x?xf32>) {
%1 = memref.rank %t : memref<4x4x?xf32>
return
}
// ------
// CHECK-LABEL: func @atomic_rmw
// CHECK-SAME: ([[BUF:%.*]]: memref<10xf32>, [[VAL:%.*]]: f32, [[I:%.*]]: index)
func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
%x = memref.atomic_rmw addf %val, %I[%i] : (f32, memref<10xf32>) -> f32
// CHECK: memref.atomic_rmw addf [[VAL]], [[BUF]]{{\[}}[[I]]]
return
}

View File

@ -3,7 +3,7 @@
// CHECK-LABEL: func @atomic_rmw_to_generic
// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
%x = atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32
%x = memref.atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32
return %x : f32
}
// CHECK: %0 = generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
@ -18,7 +18,7 @@ func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
// CHECK-LABEL: func @atomic_rmw_no_conversion
func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
%x = atomic_rmw addf %f, %F[%i] : (f32, memref<10xf32>) -> f32
%x = memref.atomic_rmw addf %f, %F[%i] : (f32, memref<10xf32>) -> f32
return %x : f32
}
// CHECK-NOT: generic_atomic_rmw

View File

@ -325,14 +325,6 @@ func @unranked_tensor_load_store(%0 : memref<*xi32>, %1 : tensor<*xi32>) {
return
}
// CHECK-LABEL: func @atomic_rmw
// CHECK-SAME: ([[BUF:%.*]]: memref<10xf32>, [[VAL:%.*]]: f32, [[I:%.*]]: index)
func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
%x = atomic_rmw addf %val, %I[%i] : (f32, memref<10xf32>) -> f32
// CHECK: atomic_rmw addf [[VAL]], [[BUF]]{{\[}}[[I]]]
return
}
// CHECK-LABEL: func @generic_atomic_rmw
// CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index)
func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) {

View File

@ -130,30 +130,6 @@ func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
// -----
func @atomic_rmw_idxs_rank_mismatch(%I: memref<16x10xf32>, %i : index, %val : f32) {
// expected-error@+1 {{expects the number of subscripts to be equal to memref rank}}
%x = atomic_rmw addf %val, %I[%i] : (f32, memref<16x10xf32>) -> f32
return
}
// -----
func @atomic_rmw_expects_float(%I: memref<16x10xi32>, %i : index, %val : i32) {
// expected-error@+1 {{expects a floating-point type}}
%x = atomic_rmw addf %val, %I[%i, %i] : (i32, memref<16x10xi32>) -> i32
return
}
// -----
func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) {
// expected-error@+1 {{expects an integer type}}
%x = atomic_rmw addi %val, %I[%i, %i] : (f32, memref<16x10xf32>) -> f32
return
}
// -----
func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) {
// expected-error@+1 {{expected single number of entry block arguments}}
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {