forked from OSchip/llvm-project
[MLIR] Move AtomicRMW into MemRef dialect and enum into Arith
Per the discussion in https://reviews.llvm.org/D116345 it makes sense to move AtomicRMWOp out of the standard dialect. This was accentuated by the need to add a fold op with a memref::cast. The only dialect that would permit this is the memref dialect (keeping it in the standard dialect or moving it to the arithmetic dialect would require those dialects to have a dependency on the memref dialect, which breaks linking). As the AtomicRMWKind enum is used throughout, this has been moved to Arith. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D116392
This commit is contained in:
parent
9d37d0ea34
commit
a6a583dae4
|
@ -15,6 +15,7 @@
|
|||
#ifndef MLIR_ANALYSIS_AFFINE_ANALYSIS_H
|
||||
#define MLIR_ANALYSIS_AFFINE_ANALYSIS_H
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
|
@ -32,7 +33,7 @@ class Operation;
|
|||
/// A description of a (parallelizable) reduction in an affine loop.
|
||||
struct LoopReduction {
|
||||
/// Reduction kind.
|
||||
AtomicRMWKind kind;
|
||||
arith::AtomicRMWKind kind;
|
||||
|
||||
/// Position of the iteration argument that acts as accumulator.
|
||||
unsigned iterArgPosition;
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
#ifndef AFFINE_OPS
|
||||
#define AFFINE_OPS
|
||||
|
||||
include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
|
||||
include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
|
||||
include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/LoopLikeInterface.td"
|
||||
|
@ -691,9 +691,9 @@ def AffineParallelOp : Affine_Op<"parallel",
|
|||
|
||||
let builders = [
|
||||
OpBuilder<(ins "TypeRange":$resultTypes,
|
||||
"ArrayRef<AtomicRMWKind>":$reductions, "ArrayRef<int64_t>":$ranges)>,
|
||||
"ArrayRef<arith::AtomicRMWKind>":$reductions, "ArrayRef<int64_t>":$ranges)>,
|
||||
OpBuilder<(ins "TypeRange":$resultTypes,
|
||||
"ArrayRef<AtomicRMWKind>":$reductions, "ArrayRef<AffineMap>":$lbMaps,
|
||||
"ArrayRef<arith::AtomicRMWKind>":$reductions, "ArrayRef<AffineMap>":$lbMaps,
|
||||
"ValueRange":$lbArgs, "ArrayRef<AffineMap>":$ubMaps, "ValueRange":$ubArgs,
|
||||
"ArrayRef<int64_t>":$steps)>
|
||||
];
|
||||
|
|
|
@ -109,6 +109,18 @@ bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs,
|
|||
bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs,
|
||||
const APFloat &rhs);
|
||||
|
||||
/// Returns the identity value attribute associated with an AtomicRMWKind op.
|
||||
Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
|
||||
OpBuilder &builder, Location loc);
|
||||
|
||||
/// Returns the identity value associated with an AtomicRMWKind op.
|
||||
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder,
|
||||
Location loc);
|
||||
|
||||
/// Returns the value obtained by applying the reduction operation kind
|
||||
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
|
||||
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
|
||||
Value lhs, Value rhs);
|
||||
} // namespace arith
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -68,4 +68,28 @@ def Arith_CmpIPredicateAttr : I64EnumAttr<
|
|||
let cppNamespace = "::mlir::arith";
|
||||
}
|
||||
|
||||
def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
|
||||
def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
|
||||
def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
|
||||
def ATOMIC_RMW_KIND_MAXF : I64EnumAttrCase<"maxf", 3>;
|
||||
def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>;
|
||||
def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>;
|
||||
def ATOMIC_RMW_KIND_MINF : I64EnumAttrCase<"minf", 6>;
|
||||
def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>;
|
||||
def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>;
|
||||
def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
|
||||
def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
|
||||
def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
|
||||
def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
|
||||
|
||||
def AtomicRMWKindAttr : I64EnumAttr<
|
||||
"AtomicRMWKind", "",
|
||||
[ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
|
||||
ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
|
||||
ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
|
||||
ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
|
||||
ATOMIC_RMW_KIND_ANDI]> {
|
||||
let cppNamespace = "::mlir::arith";
|
||||
}
|
||||
|
||||
#endif // ARITHMETIC_BASE
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
|
||||
include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/CastInterfaces.td"
|
||||
include "mlir/Interfaces/CopyOpInterface.td"
|
||||
|
@ -1673,4 +1674,51 @@ def MemRef_ViewOp : MemRef_Op<"view", [
|
|||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtomicRMWOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
|
||||
AllTypesMatch<["value", "result"]>,
|
||||
TypesMatchWith<"value type matches element type of memref",
|
||||
"memref", "value",
|
||||
"$_self.cast<MemRefType>().getElementType()">
|
||||
]> {
|
||||
let summary = "atomic read-modify-write operation";
|
||||
let description = [{
|
||||
The `atomic_rmw` operation provides a way to perform a read-modify-write
|
||||
sequence that is free from data races. The kind enumeration specifies the
|
||||
modification to perform. The value operand represents the new value to be
|
||||
applied during the modification. The memref operand represents the buffer
|
||||
that the read and write will be performed against, as accessed by the
|
||||
specified indices. The arity of the indices is the rank of the memref. The
|
||||
result represents the latest value that was stored.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%x = arith.atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AtomicRMWKindAttr:$kind,
|
||||
AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$value,
|
||||
MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
|
||||
Variadic<Index>:$indices);
|
||||
let results = (outs AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$kind $value `,` $memref `[` $indices `]` attr-dict `:` `(` type($value) `,`
|
||||
type($memref) `)` `->` type($result)
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
MemRefType getMemRefType() {
|
||||
return memref().getType().cast<MemRefType>();
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
#endif // MEMREF_OPS
|
||||
|
|
|
@ -42,31 +42,4 @@ class PatternRewriter;
|
|||
|
||||
#include "mlir/Dialect/StandardOps/IR/OpsDialect.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
|
||||
/// comparison predicates.
|
||||
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs,
|
||||
const APInt &rhs);
|
||||
|
||||
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
|
||||
/// comparison predicates.
|
||||
bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs,
|
||||
const APFloat &rhs);
|
||||
|
||||
/// Returns the identity value attribute associated with an AtomicRMWKind op.
|
||||
Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
|
||||
OpBuilder &builder, Location loc);
|
||||
|
||||
/// Returns the identity value associated with an AtomicRMWKind op.
|
||||
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder,
|
||||
Location loc);
|
||||
|
||||
/// Returns the value obtained by applying the reduction operation kind
|
||||
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
|
||||
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
|
||||
Value lhs, Value rhs);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_IR_STANDARDOPS_IR_OPS_H
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
#ifndef STANDARD_OPS
|
||||
#define STANDARD_OPS
|
||||
|
||||
include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/CallInterfaces.td"
|
||||
|
@ -179,52 +178,6 @@ def AssertOp : Std_Op<"assert"> {
|
|||
let hasCanonicalizeMethod = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtomicRMWOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def AtomicRMWOp : Std_Op<"atomic_rmw", [
|
||||
AllTypesMatch<["value", "result"]>,
|
||||
TypesMatchWith<"value type matches element type of memref",
|
||||
"memref", "value",
|
||||
"$_self.cast<MemRefType>().getElementType()">
|
||||
]> {
|
||||
let summary = "atomic read-modify-write operation";
|
||||
let description = [{
|
||||
The `atomic_rmw` operation provides a way to perform a read-modify-write
|
||||
sequence that is free from data races. The kind enumeration specifies the
|
||||
modification to perform. The value operand represents the new value to be
|
||||
applied during the modification. The memref operand represents the buffer
|
||||
that the read and write will be performed against, as accessed by the
|
||||
specified indices. The arity of the indices is the rank of the memref. The
|
||||
result represents the latest value that was stored.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%x = atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AtomicRMWKindAttr:$kind,
|
||||
AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$value,
|
||||
MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
|
||||
Variadic<Index>:$indices);
|
||||
let results = (outs AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$kind $value `,` $memref `[` $indices `]` attr-dict `:` `(` type($value) `,`
|
||||
type($memref) `)` `->` type($result)
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
MemRefType getMemRefType() {
|
||||
return getMemref().getType().cast<MemRefType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [
|
||||
SingleBlockImplicitTerminator<"AtomicYieldOp">,
|
||||
TypesMatchWith<"result type matches element type of memref",
|
||||
|
|
|
@ -1,42 +0,0 @@
|
|||
//===- StandardOpsBase.td - Standard ops definitions -------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Defines base support for standard operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef STANDARD_OPS_BASE
|
||||
#define STANDARD_OPS_BASE
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
|
||||
def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
|
||||
def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
|
||||
def ATOMIC_RMW_KIND_MAXF : I64EnumAttrCase<"maxf", 3>;
|
||||
def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>;
|
||||
def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>;
|
||||
def ATOMIC_RMW_KIND_MINF : I64EnumAttrCase<"minf", 6>;
|
||||
def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>;
|
||||
def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>;
|
||||
def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
|
||||
def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
|
||||
def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
|
||||
def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
|
||||
|
||||
def AtomicRMWKindAttr : I64EnumAttr<
|
||||
"AtomicRMWKind", "",
|
||||
[ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
|
||||
ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
|
||||
ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
|
||||
ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
|
||||
ATOMIC_RMW_KIND_ANDI]> {
|
||||
let cppNamespace = "::mlir";
|
||||
}
|
||||
|
||||
#endif // STANDARD_OPS_BASE
|
|
@ -13,6 +13,7 @@
|
|||
#ifndef MLIR_DIALECT_VECTOR_VECTOROPS_H
|
||||
#define MLIR_DIALECT_VECTOR_VECTOROPS_H
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
|
@ -145,8 +146,8 @@ ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values);
|
|||
|
||||
/// Returns the value obtained by reducing the vector into a scalar using the
|
||||
/// operation kind associated with a binary AtomicRMWKind op.
|
||||
Value getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
|
||||
Value vector);
|
||||
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder,
|
||||
Location loc, Value vector);
|
||||
|
||||
/// Return true if the last dimension of the MemRefType has unit stride. Also
|
||||
/// return true for memrefs with no strides.
|
||||
|
|
|
@ -40,7 +40,7 @@ using llvm::dbgs;
|
|||
/// reduction kind suitable for use in affine parallel loop builder. If the
|
||||
/// reduction is not supported, returns null.
|
||||
static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
|
||||
AtomicRMWKind &kind) {
|
||||
arith::AtomicRMWKind &kind) {
|
||||
SmallVector<Operation *> combinerOps;
|
||||
Value reducedVal =
|
||||
matchReduction(forOp.getRegionIterArgs(), pos, combinerOps);
|
||||
|
@ -52,21 +52,21 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
|
|||
return nullptr;
|
||||
|
||||
Operation *combinerOp = combinerOps.back();
|
||||
Optional<AtomicRMWKind> maybeKind =
|
||||
TypeSwitch<Operation *, Optional<AtomicRMWKind>>(combinerOp)
|
||||
.Case([](arith::AddFOp) { return AtomicRMWKind::addf; })
|
||||
.Case([](arith::MulFOp) { return AtomicRMWKind::mulf; })
|
||||
.Case([](arith::AddIOp) { return AtomicRMWKind::addi; })
|
||||
.Case([](arith::AndIOp) { return AtomicRMWKind::andi; })
|
||||
.Case([](arith::OrIOp) { return AtomicRMWKind::ori; })
|
||||
.Case([](arith::MulIOp) { return AtomicRMWKind::muli; })
|
||||
.Case([](arith::MinFOp) { return AtomicRMWKind::minf; })
|
||||
.Case([](arith::MaxFOp) { return AtomicRMWKind::maxf; })
|
||||
.Case([](arith::MinSIOp) { return AtomicRMWKind::mins; })
|
||||
.Case([](arith::MaxSIOp) { return AtomicRMWKind::maxs; })
|
||||
.Case([](arith::MinUIOp) { return AtomicRMWKind::minu; })
|
||||
.Case([](arith::MaxUIOp) { return AtomicRMWKind::maxu; })
|
||||
.Default([](Operation *) -> Optional<AtomicRMWKind> {
|
||||
Optional<arith::AtomicRMWKind> maybeKind =
|
||||
TypeSwitch<Operation *, Optional<arith::AtomicRMWKind>>(combinerOp)
|
||||
.Case([](arith::AddFOp) { return arith::AtomicRMWKind::addf; })
|
||||
.Case([](arith::MulFOp) { return arith::AtomicRMWKind::mulf; })
|
||||
.Case([](arith::AddIOp) { return arith::AtomicRMWKind::addi; })
|
||||
.Case([](arith::AndIOp) { return arith::AtomicRMWKind::andi; })
|
||||
.Case([](arith::OrIOp) { return arith::AtomicRMWKind::ori; })
|
||||
.Case([](arith::MulIOp) { return arith::AtomicRMWKind::muli; })
|
||||
.Case([](arith::MinFOp) { return arith::AtomicRMWKind::minf; })
|
||||
.Case([](arith::MaxFOp) { return arith::AtomicRMWKind::maxf; })
|
||||
.Case([](arith::MinSIOp) { return arith::AtomicRMWKind::mins; })
|
||||
.Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
|
||||
.Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
|
||||
.Case([](arith::MaxUIOp) { return arith::AtomicRMWKind::maxu; })
|
||||
.Default([](Operation *) -> Optional<arith::AtomicRMWKind> {
|
||||
// TODO: AtomicRMW supports other kinds of reductions this is
|
||||
// currently not detecting, add those when the need arises.
|
||||
return llvm::None;
|
||||
|
@ -86,7 +86,7 @@ void mlir::getSupportedReductions(
|
|||
return;
|
||||
supportedReductions.reserve(numIterArgs);
|
||||
for (unsigned i = 0; i < numIterArgs; ++i) {
|
||||
AtomicRMWKind kind;
|
||||
arith::AtomicRMWKind kind;
|
||||
if (Value value = getSupportedReduction(forOp, i, kind))
|
||||
supportedReductions.emplace_back(LoopReduction{kind, i, value});
|
||||
}
|
||||
|
|
|
@ -430,13 +430,14 @@ public:
|
|||
// initialization of the result values.
|
||||
Attribute reduction = std::get<0>(pair);
|
||||
Type resultType = std::get<1>(pair);
|
||||
Optional<AtomicRMWKind> reductionOp = symbolizeAtomicRMWKind(
|
||||
static_cast<uint64_t>(reduction.cast<IntegerAttr>().getInt()));
|
||||
Optional<arith::AtomicRMWKind> reductionOp =
|
||||
arith::symbolizeAtomicRMWKind(
|
||||
static_cast<uint64_t>(reduction.cast<IntegerAttr>().getInt()));
|
||||
assert(reductionOp.hasValue() &&
|
||||
"Reduction operation cannot be of None Type");
|
||||
AtomicRMWKind reductionOpValue = reductionOp.getValue();
|
||||
arith::AtomicRMWKind reductionOpValue = reductionOp.getValue();
|
||||
identityVals.push_back(
|
||||
getIdentityValue(reductionOpValue, resultType, rewriter, loc));
|
||||
arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
|
||||
}
|
||||
parOp = rewriter.create<scf::ParallelOp>(
|
||||
loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
|
||||
|
@ -450,16 +451,17 @@ public:
|
|||
"Unequal number of reductions and operands.");
|
||||
for (unsigned i = 0, end = reductions.size(); i < end; i++) {
|
||||
// For each of the reduction operations get the respective mlir::Value.
|
||||
Optional<AtomicRMWKind> reductionOp =
|
||||
symbolizeAtomicRMWKind(reductions[i].cast<IntegerAttr>().getInt());
|
||||
Optional<arith::AtomicRMWKind> reductionOp =
|
||||
arith::symbolizeAtomicRMWKind(
|
||||
reductions[i].cast<IntegerAttr>().getInt());
|
||||
assert(reductionOp.hasValue() &&
|
||||
"Reduction Operation cannot be of None Type");
|
||||
AtomicRMWKind reductionOpValue = reductionOp.getValue();
|
||||
arith::AtomicRMWKind reductionOpValue = reductionOp.getValue();
|
||||
rewriter.setInsertionPoint(&parOp.getBody()->back());
|
||||
auto reduceOp = rewriter.create<scf::ReduceOp>(
|
||||
loc, affineParOpTerminator->getOperand(i));
|
||||
rewriter.setInsertionPointToEnd(&reduceOp.getReductionOperator().front());
|
||||
Value reductionResult = getReductionOp(
|
||||
Value reductionResult = arith::getReductionOp(
|
||||
reductionOpValue, rewriter, loc,
|
||||
reduceOp.getReductionOperator().front().getArgument(0),
|
||||
reduceOp.getReductionOperator().front().getArgument(1));
|
||||
|
|
|
@ -1553,6 +1553,62 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
|
|||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtomicRMWOpLowering
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Try to match the kind of a std.atomic_rmw to determine whether to use a
|
||||
/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
|
||||
static Optional<LLVM::AtomicBinOp>
|
||||
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
|
||||
switch (atomicOp.kind()) {
|
||||
case arith::AtomicRMWKind::addf:
|
||||
return LLVM::AtomicBinOp::fadd;
|
||||
case arith::AtomicRMWKind::addi:
|
||||
return LLVM::AtomicBinOp::add;
|
||||
case arith::AtomicRMWKind::assign:
|
||||
return LLVM::AtomicBinOp::xchg;
|
||||
case arith::AtomicRMWKind::maxs:
|
||||
return LLVM::AtomicBinOp::max;
|
||||
case arith::AtomicRMWKind::maxu:
|
||||
return LLVM::AtomicBinOp::umax;
|
||||
case arith::AtomicRMWKind::mins:
|
||||
return LLVM::AtomicBinOp::min;
|
||||
case arith::AtomicRMWKind::minu:
|
||||
return LLVM::AtomicBinOp::umin;
|
||||
case arith::AtomicRMWKind::ori:
|
||||
return LLVM::AtomicBinOp::_or;
|
||||
case arith::AtomicRMWKind::andi:
|
||||
return LLVM::AtomicBinOp::_and;
|
||||
default:
|
||||
return llvm::None;
|
||||
}
|
||||
llvm_unreachable("Invalid AtomicRMWKind");
|
||||
}
|
||||
|
||||
struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
|
||||
using Base::Base;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(match(atomicOp)))
|
||||
return failure();
|
||||
auto maybeKind = matchSimpleAtomicOp(atomicOp);
|
||||
if (!maybeKind)
|
||||
return failure();
|
||||
auto resultType = adaptor.value().getType();
|
||||
auto memRefType = atomicOp.getMemRefType();
|
||||
auto dataPtr =
|
||||
getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
|
||||
adaptor.indices(), rewriter);
|
||||
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
|
||||
atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
|
||||
LLVM::AtomicOrdering::acq_rel);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
|
@ -1561,6 +1617,7 @@ void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
|||
patterns.add<
|
||||
AllocaOpLowering,
|
||||
AllocaScopeOpLowering,
|
||||
AtomicRMWOpLowering,
|
||||
AssumeAlignmentOpLowering,
|
||||
DimOpLowering,
|
||||
GlobalMemrefOpLowering,
|
||||
|
|
|
@ -772,61 +772,6 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
|
|||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Try to match the kind of a std.atomic_rmw to determine whether to use a
|
||||
/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
|
||||
static Optional<LLVM::AtomicBinOp> matchSimpleAtomicOp(AtomicRMWOp atomicOp) {
|
||||
switch (atomicOp.getKind()) {
|
||||
case AtomicRMWKind::addf:
|
||||
return LLVM::AtomicBinOp::fadd;
|
||||
case AtomicRMWKind::addi:
|
||||
return LLVM::AtomicBinOp::add;
|
||||
case AtomicRMWKind::assign:
|
||||
return LLVM::AtomicBinOp::xchg;
|
||||
case AtomicRMWKind::maxs:
|
||||
return LLVM::AtomicBinOp::max;
|
||||
case AtomicRMWKind::maxu:
|
||||
return LLVM::AtomicBinOp::umax;
|
||||
case AtomicRMWKind::mins:
|
||||
return LLVM::AtomicBinOp::min;
|
||||
case AtomicRMWKind::minu:
|
||||
return LLVM::AtomicBinOp::umin;
|
||||
case AtomicRMWKind::ori:
|
||||
return LLVM::AtomicBinOp::_or;
|
||||
case AtomicRMWKind::andi:
|
||||
return LLVM::AtomicBinOp::_and;
|
||||
default:
|
||||
return llvm::None;
|
||||
}
|
||||
llvm_unreachable("Invalid AtomicRMWKind");
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
|
||||
using Base::Base;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(AtomicRMWOp atomicOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(match(atomicOp)))
|
||||
return failure();
|
||||
auto maybeKind = matchSimpleAtomicOp(atomicOp);
|
||||
if (!maybeKind)
|
||||
return failure();
|
||||
auto resultType = adaptor.getValue().getType();
|
||||
auto memRefType = atomicOp.getMemRefType();
|
||||
auto dataPtr =
|
||||
getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
|
||||
adaptor.getIndices(), rewriter);
|
||||
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
|
||||
atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(),
|
||||
LLVM::AtomicOrdering::acq_rel);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
|
||||
/// retried until it succeeds in atomically storing a new value into memory.
|
||||
///
|
||||
|
@ -962,7 +907,6 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
|||
// clang-format off
|
||||
patterns.add<
|
||||
AssertOpLowering,
|
||||
AtomicRMWOpLowering,
|
||||
BranchOpLowering,
|
||||
CallIndirectOpLowering,
|
||||
CallOpLowering,
|
||||
|
|
|
@ -2801,7 +2801,7 @@ LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
|
|||
|
||||
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
|
||||
TypeRange resultTypes,
|
||||
ArrayRef<AtomicRMWKind> reductions,
|
||||
ArrayRef<arith::AtomicRMWKind> reductions,
|
||||
ArrayRef<int64_t> ranges) {
|
||||
SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
|
||||
auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
|
||||
|
@ -2814,7 +2814,7 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
|
|||
|
||||
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
|
||||
TypeRange resultTypes,
|
||||
ArrayRef<AtomicRMWKind> reductions,
|
||||
ArrayRef<arith::AtomicRMWKind> reductions,
|
||||
ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
|
||||
ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
|
||||
ArrayRef<int64_t> steps) {
|
||||
|
@ -2843,7 +2843,7 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
|
|||
|
||||
// Convert the reductions to integer attributes.
|
||||
SmallVector<Attribute, 4> reductionAttrs;
|
||||
for (AtomicRMWKind reduction : reductions)
|
||||
for (arith::AtomicRMWKind reduction : reductions)
|
||||
reductionAttrs.push_back(
|
||||
builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
|
||||
result.addAttribute(getReductionsAttrName(),
|
||||
|
@ -3050,7 +3050,7 @@ static LogicalResult verify(AffineParallelOp op) {
|
|||
// Verify reduction ops are all valid
|
||||
for (Attribute attr : op.reductions()) {
|
||||
auto intAttr = attr.dyn_cast<IntegerAttr>();
|
||||
if (!intAttr || !symbolizeAtomicRMWKind(intAttr.getInt()))
|
||||
if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
|
||||
return op.emitOpError("invalid reduction attribute");
|
||||
}
|
||||
|
||||
|
@ -3150,9 +3150,9 @@ static void print(OpAsmPrinter &p, AffineParallelOp op) {
|
|||
if (op.getNumResults()) {
|
||||
p << " reduce (";
|
||||
llvm::interleaveComma(op.reductions(), p, [&](auto &attr) {
|
||||
AtomicRMWKind sym =
|
||||
*symbolizeAtomicRMWKind(attr.template cast<IntegerAttr>().getInt());
|
||||
p << "\"" << stringifyAtomicRMWKind(sym) << "\"";
|
||||
arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
|
||||
attr.template cast<IntegerAttr>().getInt());
|
||||
p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
|
||||
});
|
||||
p << ") -> (" << op.getResultTypes() << ")";
|
||||
}
|
||||
|
@ -3374,8 +3374,8 @@ static ParseResult parseAffineParallelOp(OpAsmParser &parser,
|
|||
if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
|
||||
attrStorage))
|
||||
return failure();
|
||||
llvm::Optional<AtomicRMWKind> reduction =
|
||||
symbolizeAtomicRMWKind(attrVal.getValue());
|
||||
llvm::Optional<arith::AtomicRMWKind> reduction =
|
||||
arith::symbolizeAtomicRMWKind(attrVal.getValue());
|
||||
if (!reduction)
|
||||
return parser.emitError(loc, "invalid reduction value: ") << attrVal;
|
||||
reductions.push_back(builder.getI64IntegerAttr(
|
||||
|
|
|
@ -971,7 +971,7 @@ static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp,
|
|||
/// Creates a constant vector filled with the neutral elements of the given
|
||||
/// reduction. The scalar type of vector elements will be taken from
|
||||
/// `oldOperand`.
|
||||
static arith::ConstantOp createInitialVector(AtomicRMWKind reductionKind,
|
||||
static arith::ConstantOp createInitialVector(arith::AtomicRMWKind reductionKind,
|
||||
Value oldOperand,
|
||||
VectorizationState &state) {
|
||||
Type scalarTy = oldOperand.getType();
|
||||
|
@ -1245,8 +1245,8 @@ static Operation *vectorizeAffineStore(AffineStoreOp storeOp,
|
|||
|
||||
/// Returns true if `value` is a constant equal to the neutral element of the
|
||||
/// given vectorizable reduction.
|
||||
static bool isNeutralElementConst(AtomicRMWKind reductionKind, Value value,
|
||||
VectorizationState &state) {
|
||||
static bool isNeutralElementConst(arith::AtomicRMWKind reductionKind,
|
||||
Value value, VectorizationState &state) {
|
||||
Type scalarTy = value.getType();
|
||||
if (!VectorType::isValidElementType(scalarTy))
|
||||
return false;
|
||||
|
@ -1361,7 +1361,8 @@ static Operation *vectorizeAffineForOp(AffineForOp forOp,
|
|||
Value origInit = forOp.getOperand(forOp.getNumControlOperands() + i);
|
||||
Value finalRes = reducedRes;
|
||||
if (!isNeutralElementConst(reductions[i].kind, origInit, state))
|
||||
finalRes = getReductionOp(reductions[i].kind, state.builder,
|
||||
finalRes =
|
||||
arith::getReductionOp(reductions[i].kind, state.builder,
|
||||
reducedRes.getLoc(), reducedRes, origInit);
|
||||
state.registerLoopResultScalarReplacement(forOp.getResult(i), finalRes);
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/CommonFolders.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
|
@ -1208,6 +1209,101 @@ OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
|
|||
return BoolAttr::get(getContext(), val);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Atomic Enum
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the identity value attribute associated with an AtomicRMWKind op.
|
||||
Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
|
||||
OpBuilder &builder, Location loc) {
|
||||
switch (kind) {
|
||||
case AtomicRMWKind::maxf:
|
||||
return builder.getFloatAttr(
|
||||
resultType,
|
||||
APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
|
||||
/*Negative=*/true));
|
||||
case AtomicRMWKind::addf:
|
||||
case AtomicRMWKind::addi:
|
||||
case AtomicRMWKind::maxu:
|
||||
case AtomicRMWKind::ori:
|
||||
return builder.getZeroAttr(resultType);
|
||||
case AtomicRMWKind::andi:
|
||||
return builder.getIntegerAttr(
|
||||
resultType,
|
||||
APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
|
||||
case AtomicRMWKind::maxs:
|
||||
return builder.getIntegerAttr(
|
||||
resultType,
|
||||
APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
|
||||
case AtomicRMWKind::minf:
|
||||
return builder.getFloatAttr(
|
||||
resultType,
|
||||
APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
|
||||
/*Negative=*/false));
|
||||
case AtomicRMWKind::mins:
|
||||
return builder.getIntegerAttr(
|
||||
resultType,
|
||||
APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
|
||||
case AtomicRMWKind::minu:
|
||||
return builder.getIntegerAttr(
|
||||
resultType,
|
||||
APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
|
||||
case AtomicRMWKind::muli:
|
||||
return builder.getIntegerAttr(resultType, 1);
|
||||
case AtomicRMWKind::mulf:
|
||||
return builder.getFloatAttr(resultType, 1);
|
||||
// TODO: Add remaining reduction operations.
|
||||
default:
|
||||
(void)emitOptionalError(loc, "Reduction operation type not supported");
|
||||
break;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Returns the identity value associated with an AtomicRMWKind op.
|
||||
Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
|
||||
OpBuilder &builder, Location loc) {
|
||||
Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
|
||||
return builder.create<arith::ConstantOp>(loc, attr);
|
||||
}
|
||||
|
||||
/// Return the value obtained by applying the reduction operation kind
|
||||
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
|
||||
Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
|
||||
Location loc, Value lhs, Value rhs) {
|
||||
switch (op) {
|
||||
case AtomicRMWKind::addf:
|
||||
return builder.create<arith::AddFOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::addi:
|
||||
return builder.create<arith::AddIOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::mulf:
|
||||
return builder.create<arith::MulFOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::muli:
|
||||
return builder.create<arith::MulIOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::maxf:
|
||||
return builder.create<arith::MaxFOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::minf:
|
||||
return builder.create<arith::MinFOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::maxs:
|
||||
return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::mins:
|
||||
return builder.create<arith::MinSIOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::maxu:
|
||||
return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::minu:
|
||||
return builder.create<arith::MinUIOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::ori:
|
||||
return builder.create<arith::OrIOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::andi:
|
||||
return builder.create<arith::AndIOp>(loc, lhs, rhs);
|
||||
// TODO: Add remaining reduction operations.
|
||||
default:
|
||||
(void)emitOptionalError(loc, "Reduction operation type not supported");
|
||||
break;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -2286,6 +2286,50 @@ void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|||
results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtomicRMWOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(AtomicRMWOp op) {
|
||||
if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
|
||||
return op.emitOpError(
|
||||
"expects the number of subscripts to be equal to memref rank");
|
||||
switch (op.kind()) {
|
||||
case arith::AtomicRMWKind::addf:
|
||||
case arith::AtomicRMWKind::maxf:
|
||||
case arith::AtomicRMWKind::minf:
|
||||
case arith::AtomicRMWKind::mulf:
|
||||
if (!op.value().getType().isa<FloatType>())
|
||||
return op.emitOpError()
|
||||
<< "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
|
||||
<< "' expects a floating-point type";
|
||||
break;
|
||||
case arith::AtomicRMWKind::addi:
|
||||
case arith::AtomicRMWKind::maxs:
|
||||
case arith::AtomicRMWKind::maxu:
|
||||
case arith::AtomicRMWKind::mins:
|
||||
case arith::AtomicRMWKind::minu:
|
||||
case arith::AtomicRMWKind::muli:
|
||||
case arith::AtomicRMWKind::ori:
|
||||
case arith::AtomicRMWKind::andi:
|
||||
if (!op.value().getType().isa<IntegerType>())
|
||||
return op.emitOpError()
|
||||
<< "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
|
||||
<< "' expects an integer type";
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
|
||||
/// atomicrmw(memrefcast) -> atomicrmw
|
||||
if (succeeded(foldMemRefCast(*this, value())))
|
||||
return getResult();
|
||||
return OpFoldResult();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -131,134 +131,6 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
|
|||
return failure();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtomicRMWOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(AtomicRMWOp op) {
|
||||
if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
|
||||
return op.emitOpError(
|
||||
"expects the number of subscripts to be equal to memref rank");
|
||||
switch (op.getKind()) {
|
||||
case AtomicRMWKind::addf:
|
||||
case AtomicRMWKind::maxf:
|
||||
case AtomicRMWKind::minf:
|
||||
case AtomicRMWKind::mulf:
|
||||
if (!op.getValue().getType().isa<FloatType>())
|
||||
return op.emitOpError()
|
||||
<< "with kind '" << stringifyAtomicRMWKind(op.getKind())
|
||||
<< "' expects a floating-point type";
|
||||
break;
|
||||
case AtomicRMWKind::addi:
|
||||
case AtomicRMWKind::maxs:
|
||||
case AtomicRMWKind::maxu:
|
||||
case AtomicRMWKind::mins:
|
||||
case AtomicRMWKind::minu:
|
||||
case AtomicRMWKind::muli:
|
||||
case AtomicRMWKind::ori:
|
||||
case AtomicRMWKind::andi:
|
||||
if (!op.getValue().getType().isa<IntegerType>())
|
||||
return op.emitOpError()
|
||||
<< "with kind '" << stringifyAtomicRMWKind(op.getKind())
|
||||
<< "' expects an integer type";
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Returns the identity value attribute associated with an AtomicRMWKind op.
|
||||
Attribute mlir::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
|
||||
OpBuilder &builder, Location loc) {
|
||||
switch (kind) {
|
||||
case AtomicRMWKind::maxf:
|
||||
return builder.getFloatAttr(
|
||||
resultType,
|
||||
APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
|
||||
/*Negative=*/true));
|
||||
case AtomicRMWKind::addf:
|
||||
case AtomicRMWKind::addi:
|
||||
case AtomicRMWKind::maxu:
|
||||
case AtomicRMWKind::ori:
|
||||
return builder.getZeroAttr(resultType);
|
||||
case AtomicRMWKind::andi:
|
||||
return builder.getIntegerAttr(
|
||||
resultType,
|
||||
APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
|
||||
case AtomicRMWKind::maxs:
|
||||
return builder.getIntegerAttr(
|
||||
resultType,
|
||||
APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
|
||||
case AtomicRMWKind::minf:
|
||||
return builder.getFloatAttr(
|
||||
resultType,
|
||||
APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
|
||||
/*Negative=*/false));
|
||||
case AtomicRMWKind::mins:
|
||||
return builder.getIntegerAttr(
|
||||
resultType,
|
||||
APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
|
||||
case AtomicRMWKind::minu:
|
||||
return builder.getIntegerAttr(
|
||||
resultType,
|
||||
APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
|
||||
case AtomicRMWKind::muli:
|
||||
return builder.getIntegerAttr(resultType, 1);
|
||||
case AtomicRMWKind::mulf:
|
||||
return builder.getFloatAttr(resultType, 1);
|
||||
// TODO: Add remaining reduction operations.
|
||||
default:
|
||||
(void)emitOptionalError(loc, "Reduction operation type not supported");
|
||||
break;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Returns the identity value associated with an AtomicRMWKind op.
|
||||
Value mlir::getIdentityValue(AtomicRMWKind op, Type resultType,
|
||||
OpBuilder &builder, Location loc) {
|
||||
Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
|
||||
return builder.create<arith::ConstantOp>(loc, attr);
|
||||
}
|
||||
|
||||
/// Return the value obtained by applying the reduction operation kind
|
||||
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
|
||||
Value mlir::getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
|
||||
Value lhs, Value rhs) {
|
||||
switch (op) {
|
||||
case AtomicRMWKind::addf:
|
||||
return builder.create<arith::AddFOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::addi:
|
||||
return builder.create<arith::AddIOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::mulf:
|
||||
return builder.create<arith::MulFOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::muli:
|
||||
return builder.create<arith::MulIOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::maxf:
|
||||
return builder.create<arith::MaxFOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::minf:
|
||||
return builder.create<arith::MinFOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::maxs:
|
||||
return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::mins:
|
||||
return builder.create<arith::MinSIOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::maxu:
|
||||
return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::minu:
|
||||
return builder.create<arith::MinUIOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::ori:
|
||||
return builder.create<arith::OrIOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::andi:
|
||||
return builder.create<arith::AndIOp>(loc, lhs, rhs);
|
||||
// TODO: Add remaining reduction operations.
|
||||
default:
|
||||
(void)emitOptionalError(loc, "Reduction operation type not supported");
|
||||
break;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GenericAtomicRMWOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -40,18 +40,18 @@ namespace {
|
|||
/// %new_value = select %cmp, %current, %fval : f32
|
||||
/// atomic_yield %new_value : f32
|
||||
/// }
|
||||
struct AtomicRMWOpConverter : public OpRewritePattern<AtomicRMWOp> {
|
||||
struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(AtomicRMWOp op,
|
||||
LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
arith::CmpFPredicate predicate;
|
||||
switch (op.getKind()) {
|
||||
case AtomicRMWKind::maxf:
|
||||
switch (op.kind()) {
|
||||
case arith::AtomicRMWKind::maxf:
|
||||
predicate = arith::CmpFPredicate::OGT;
|
||||
break;
|
||||
case AtomicRMWKind::minf:
|
||||
case arith::AtomicRMWKind::minf:
|
||||
predicate = arith::CmpFPredicate::OLT;
|
||||
break;
|
||||
default:
|
||||
|
@ -59,13 +59,13 @@ public:
|
|||
}
|
||||
|
||||
auto loc = op.getLoc();
|
||||
auto genericOp = rewriter.create<GenericAtomicRMWOp>(loc, op.getMemref(),
|
||||
op.getIndices());
|
||||
auto genericOp =
|
||||
rewriter.create<GenericAtomicRMWOp>(loc, op.memref(), op.indices());
|
||||
OpBuilder bodyBuilder =
|
||||
OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());
|
||||
|
||||
Value lhs = genericOp.getCurrentValue();
|
||||
Value rhs = op.getValue();
|
||||
Value rhs = op.value();
|
||||
Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
|
||||
Value select = bodyBuilder.create<SelectOp>(loc, cmp, lhs, rhs);
|
||||
bodyBuilder.create<AtomicYieldOp>(loc, select);
|
||||
|
@ -130,10 +130,11 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
|
|||
|
||||
target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
|
||||
StandardOpsDialect>();
|
||||
target.addDynamicallyLegalOp<AtomicRMWOp>([](AtomicRMWOp op) {
|
||||
return op.getKind() != AtomicRMWKind::maxf &&
|
||||
op.getKind() != AtomicRMWKind::minf;
|
||||
});
|
||||
target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
|
||||
[](memref::AtomicRMWOp op) {
|
||||
return op.kind() != arith::AtomicRMWKind::maxf &&
|
||||
op.kind() != arith::AtomicRMWKind::minf;
|
||||
});
|
||||
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
|
||||
return !op.shape().getType().cast<MemRefType>().hasStaticShape();
|
||||
});
|
||||
|
|
|
@ -359,41 +359,42 @@ static void print(OpAsmPrinter &p, ReductionOp op) {
|
|||
p << " : " << op.vector().getType() << " into " << op.dest().getType();
|
||||
}
|
||||
|
||||
Value mlir::vector::getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder,
|
||||
Location loc, Value vector) {
|
||||
Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
|
||||
OpBuilder &builder, Location loc,
|
||||
Value vector) {
|
||||
Type scalarType = vector.getType().cast<ShapedType>().getElementType();
|
||||
switch (op) {
|
||||
case AtomicRMWKind::addf:
|
||||
case AtomicRMWKind::addi:
|
||||
case arith::AtomicRMWKind::addf:
|
||||
case arith::AtomicRMWKind::addi:
|
||||
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
|
||||
builder.getStringAttr("add"),
|
||||
vector, ValueRange{});
|
||||
case AtomicRMWKind::mulf:
|
||||
case AtomicRMWKind::muli:
|
||||
case arith::AtomicRMWKind::mulf:
|
||||
case arith::AtomicRMWKind::muli:
|
||||
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
|
||||
builder.getStringAttr("mul"),
|
||||
vector, ValueRange{});
|
||||
case AtomicRMWKind::minf:
|
||||
case arith::AtomicRMWKind::minf:
|
||||
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
|
||||
builder.getStringAttr("minf"),
|
||||
vector, ValueRange{});
|
||||
case AtomicRMWKind::mins:
|
||||
case arith::AtomicRMWKind::mins:
|
||||
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
|
||||
builder.getStringAttr("minsi"),
|
||||
vector, ValueRange{});
|
||||
case AtomicRMWKind::minu:
|
||||
case arith::AtomicRMWKind::minu:
|
||||
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
|
||||
builder.getStringAttr("minui"),
|
||||
vector, ValueRange{});
|
||||
case AtomicRMWKind::maxf:
|
||||
case arith::AtomicRMWKind::maxf:
|
||||
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
|
||||
builder.getStringAttr("maxf"),
|
||||
vector, ValueRange{});
|
||||
case AtomicRMWKind::maxs:
|
||||
case arith::AtomicRMWKind::maxs:
|
||||
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
|
||||
builder.getStringAttr("maxsi"),
|
||||
vector, ValueRange{});
|
||||
case AtomicRMWKind::maxu:
|
||||
case arith::AtomicRMWKind::maxu:
|
||||
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
|
||||
builder.getStringAttr("maxui"),
|
||||
vector, ValueRange{});
|
||||
|
|
|
@ -1551,7 +1551,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
|
|||
for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
|
||||
rhs = forOp.getResult(i * oldNumResults + pos);
|
||||
// Create ops based on reduction type.
|
||||
lhs = getReductionOp(reduction.kind, builder, loc, lhs, rhs);
|
||||
lhs = arith::getReductionOp(reduction.kind, builder, loc, lhs, rhs);
|
||||
if (!lhs)
|
||||
return failure();
|
||||
Operation *op = lhs.getDefiningOp();
|
||||
|
|
|
@ -859,3 +859,28 @@ func @rank_of_ranked(%ranked: memref<?xi32>) {
|
|||
}
|
||||
// CHECK: llvm.mlir.constant(1 : index) : i64
|
||||
// CHECK32: llvm.mlir.constant(1 : index) : i32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @atomic_rmw
|
||||
func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval : f32, %i : index) {
|
||||
memref.atomic_rmw assign %fval, %F[%i] : (f32, memref<10xf32>) -> f32
|
||||
// CHECK: llvm.atomicrmw xchg %{{.*}}, %{{.*}} acq_rel
|
||||
memref.atomic_rmw addi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
|
||||
// CHECK: llvm.atomicrmw add %{{.*}}, %{{.*}} acq_rel
|
||||
memref.atomic_rmw maxs %ival, %I[%i] : (i32, memref<10xi32>) -> i32
|
||||
// CHECK: llvm.atomicrmw max %{{.*}}, %{{.*}} acq_rel
|
||||
memref.atomic_rmw mins %ival, %I[%i] : (i32, memref<10xi32>) -> i32
|
||||
// CHECK: llvm.atomicrmw min %{{.*}}, %{{.*}} acq_rel
|
||||
memref.atomic_rmw maxu %ival, %I[%i] : (i32, memref<10xi32>) -> i32
|
||||
// CHECK: llvm.atomicrmw umax %{{.*}}, %{{.*}} acq_rel
|
||||
memref.atomic_rmw minu %ival, %I[%i] : (i32, memref<10xi32>) -> i32
|
||||
// CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel
|
||||
memref.atomic_rmw addf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
|
||||
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel
|
||||
memref.atomic_rmw ori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
|
||||
// CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
|
||||
memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
|
||||
// CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
|
||||
return
|
||||
}
|
||||
|
|
|
@ -486,31 +486,6 @@ func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @atomic_rmw
|
||||
func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval : f32, %i : index) {
|
||||
atomic_rmw assign %fval, %F[%i] : (f32, memref<10xf32>) -> f32
|
||||
// CHECK: llvm.atomicrmw xchg %{{.*}}, %{{.*}} acq_rel
|
||||
atomic_rmw addi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
|
||||
// CHECK: llvm.atomicrmw add %{{.*}}, %{{.*}} acq_rel
|
||||
atomic_rmw maxs %ival, %I[%i] : (i32, memref<10xi32>) -> i32
|
||||
// CHECK: llvm.atomicrmw max %{{.*}}, %{{.*}} acq_rel
|
||||
atomic_rmw mins %ival, %I[%i] : (i32, memref<10xi32>) -> i32
|
||||
// CHECK: llvm.atomicrmw min %{{.*}}, %{{.*}} acq_rel
|
||||
atomic_rmw maxu %ival, %I[%i] : (i32, memref<10xi32>) -> i32
|
||||
// CHECK: llvm.atomicrmw umax %{{.*}}, %{{.*}} acq_rel
|
||||
atomic_rmw minu %ival, %I[%i] : (i32, memref<10xi32>) -> i32
|
||||
// CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel
|
||||
atomic_rmw addf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
|
||||
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel
|
||||
atomic_rmw ori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
|
||||
// CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
|
||||
atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
|
||||
// CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @generic_atomic_rmw
|
||||
func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) -> i32 {
|
||||
%x = generic_atomic_rmw %I[%i] : memref<10xi32> {
|
||||
|
|
|
@ -499,3 +499,14 @@ func @no_fold_dynamic_no_op_subview(%arg0 : memref<?x?xf32>) -> memref<?x?xf32,
|
|||
// CHECK-LABEL: func @no_fold_dynamic_no_op_subview(
|
||||
// CHECK: %[[SUBVIEW:.+]] = memref.subview
|
||||
// CHECK: return %[[SUBVIEW]]
|
||||
|
||||
// -----
|
||||
|
||||
func @atomicrmw_cast_fold(%arg0 : f32, %arg1 : memref<4xf32>, %c : index) {
|
||||
%v = memref.cast %arg1 : memref<4xf32> to memref<?xf32>
|
||||
%a = memref.atomic_rmw addf %arg0, %v[%c] : (f32, memref<?xf32>) -> f32
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @atomicrmw_cast_fold
|
||||
// CHECK-NEXT: memref.atomic_rmw addf %arg0, %arg1[%arg2] : (f32, memref<4xf32>) -> f32
|
||||
|
|
|
@ -848,3 +848,27 @@ func @illegal_num_offsets(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : inde
|
|||
// expected-error@+1 {{expected 3 offset values}}
|
||||
%0 = memref.subview %arg0[0, 0] [%arg1, %arg2] [1, 1] : memref<?x?x?xf32> to memref<?x?x?xf32, #map>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @atomic_rmw_idxs_rank_mismatch(%I: memref<16x10xf32>, %i : index, %val : f32) {
|
||||
// expected-error@+1 {{expects the number of subscripts to be equal to memref rank}}
|
||||
%x = memref.atomic_rmw addf %val, %I[%i] : (f32, memref<16x10xf32>) -> f32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @atomic_rmw_expects_float(%I: memref<16x10xi32>, %i : index, %val : i32) {
|
||||
// expected-error@+1 {{expects a floating-point type}}
|
||||
%x = memref.atomic_rmw addf %val, %I[%i, %i] : (i32, memref<16x10xi32>) -> i32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) {
|
||||
// expected-error@+1 {{expects an integer type}}
|
||||
%x = memref.atomic_rmw addi %val, %I[%i, %i] : (f32, memref<16x10xf32>) -> f32
|
||||
return
|
||||
}
|
||||
|
|
|
@ -227,3 +227,13 @@ func @rank(%t : memref<4x4x?xf32>) {
|
|||
%1 = memref.rank %t : memref<4x4x?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// ------
|
||||
|
||||
// CHECK-LABEL: func @atomic_rmw
|
||||
// CHECK-SAME: ([[BUF:%.*]]: memref<10xf32>, [[VAL:%.*]]: f32, [[I:%.*]]: index)
|
||||
func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
|
||||
%x = memref.atomic_rmw addf %val, %I[%i] : (f32, memref<10xf32>) -> f32
|
||||
// CHECK: memref.atomic_rmw addf [[VAL]], [[BUF]]{{\[}}[[I]]]
|
||||
return
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
// CHECK-LABEL: func @atomic_rmw_to_generic
|
||||
// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
|
||||
func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
|
||||
%x = atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32
|
||||
%x = memref.atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32
|
||||
return %x : f32
|
||||
}
|
||||
// CHECK: %0 = generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
|
||||
|
@ -18,7 +18,7 @@ func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
|
|||
|
||||
// CHECK-LABEL: func @atomic_rmw_no_conversion
|
||||
func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
|
||||
%x = atomic_rmw addf %f, %F[%i] : (f32, memref<10xf32>) -> f32
|
||||
%x = memref.atomic_rmw addf %f, %F[%i] : (f32, memref<10xf32>) -> f32
|
||||
return %x : f32
|
||||
}
|
||||
// CHECK-NOT: generic_atomic_rmw
|
||||
|
|
|
@ -325,14 +325,6 @@ func @unranked_tensor_load_store(%0 : memref<*xi32>, %1 : tensor<*xi32>) {
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @atomic_rmw
|
||||
// CHECK-SAME: ([[BUF:%.*]]: memref<10xf32>, [[VAL:%.*]]: f32, [[I:%.*]]: index)
|
||||
func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
|
||||
%x = atomic_rmw addf %val, %I[%i] : (f32, memref<10xf32>) -> f32
|
||||
// CHECK: atomic_rmw addf [[VAL]], [[BUF]]{{\[}}[[I]]]
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @generic_atomic_rmw
|
||||
// CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index)
|
||||
func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) {
|
||||
|
|
|
@ -130,30 +130,6 @@ func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
|
|||
|
||||
// -----
|
||||
|
||||
func @atomic_rmw_idxs_rank_mismatch(%I: memref<16x10xf32>, %i : index, %val : f32) {
|
||||
// expected-error@+1 {{expects the number of subscripts to be equal to memref rank}}
|
||||
%x = atomic_rmw addf %val, %I[%i] : (f32, memref<16x10xf32>) -> f32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @atomic_rmw_expects_float(%I: memref<16x10xi32>, %i : index, %val : i32) {
|
||||
// expected-error@+1 {{expects a floating-point type}}
|
||||
%x = atomic_rmw addf %val, %I[%i, %i] : (i32, memref<16x10xi32>) -> i32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) {
|
||||
// expected-error@+1 {{expects an integer type}}
|
||||
%x = atomic_rmw addi %val, %I[%i, %i] : (f32, memref<16x10xf32>) -> f32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) {
|
||||
// expected-error@+1 {{expected single number of entry block arguments}}
|
||||
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
|
||||
|
|
Loading…
Reference in New Issue