[mlir][Linalg] Add support for min/max reduction vectorization in linalg.generic
This patch extends Linalg core vectorization with support for min/max reductions in linalg.generic ops. It enables reduction detection for min/max combiner ops. It also renames the MIN/MAX combining kinds to MINS/MAXS to make the sign explicit for floating-point and signed integer types. MINU/MAXU should be introduced in the future for unsigned integer types.

Reviewed By: pifon2a, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D110854
parent 4e8efff53e
commit eaf2588a51
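For reference, a minimal sketch of the kind of op this patch makes vectorizable, adapted from the red_max_2d test added below (the function name and the pre-filled %fill output operand are illustrative); once vectorized, the single maxf combiner is detected and the reduction dimension is emitted as a vector.multi_reduction with kind maxf:

func @red_max(%arg0: tensor<4x4xf32>, %fill: tensor<4xf32>) -> tensor<4xf32> {
  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                          affine_map<(d0, d1) -> (d0)>],
                         iterator_types = ["parallel", "reduction"]}
      ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) {
    ^bb0(%in0: f32, %out0: f32):
      // The single maxf combiner op is what reduction matching maps to CombiningKind::MAXF.
      %max = maxf %in0, %out0 : f32
      linalg.yield %max : f32
  } -> tensor<4xf32>
  return %red : tensor<4xf32>
}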
@@ -38,20 +38,25 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
 }

 // The "kind" of combining function for contractions and reductions.
-def COMBINING_KIND_ADD : BitEnumAttrCase<"ADD", 0x1, "add">;
-def COMBINING_KIND_MUL : BitEnumAttrCase<"MUL", 0x2, "mul">;
-def COMBINING_KIND_MIN : BitEnumAttrCase<"MIN", 0x4, "min">;
-def COMBINING_KIND_MAX : BitEnumAttrCase<"MAX", 0x8, "max">;
-def COMBINING_KIND_AND : BitEnumAttrCase<"AND", 0x10, "and">;
-def COMBINING_KIND_OR : BitEnumAttrCase<"OR", 0x20, "or">;
-def COMBINING_KIND_XOR : BitEnumAttrCase<"XOR", 0x40, "xor">;
+def COMBINING_KIND_ADD : BitEnumAttrCase<"ADD", 0x1, "add">;
+def COMBINING_KIND_MUL : BitEnumAttrCase<"MUL", 0x2, "mul">;
+def COMBINING_KIND_MINUI : BitEnumAttrCase<"MINUI", 0x4, "minui">;
+def COMBINING_KIND_MINSI : BitEnumAttrCase<"MINSI", 0x8, "minsi">;
+def COMBINING_KIND_MINF : BitEnumAttrCase<"MINF", 0x10, "minf">;
+def COMBINING_KIND_MAXUI : BitEnumAttrCase<"MAXUI", 0x20, "maxui">;
+def COMBINING_KIND_MAXSI : BitEnumAttrCase<"MAXSI", 0x40, "maxsi">;
+def COMBINING_KIND_MAXF : BitEnumAttrCase<"MAXF", 0x80, "maxf">;
+def COMBINING_KIND_AND : BitEnumAttrCase<"AND", 0x100, "and">;
+def COMBINING_KIND_OR : BitEnumAttrCase<"OR", 0x200, "or">;
+def COMBINING_KIND_XOR : BitEnumAttrCase<"XOR", 0x400, "xor">;

 def CombiningKind : BitEnumAttr<
 "CombiningKind",
 "Kind of combining function for contractions and reductions",
-[COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MIN,
-COMBINING_KIND_MAX, COMBINING_KIND_AND, COMBINING_KIND_OR,
-COMBINING_KIND_XOR]> {
+[COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MINUI,
+COMBINING_KIND_MINSI, COMBINING_KIND_MINF, COMBINING_KIND_MAXUI,
+COMBINING_KIND_MAXSI, COMBINING_KIND_MAXF, COMBINING_KIND_AND,
+COMBINING_KIND_OR, COMBINING_KIND_XOR]> {
 let cppNamespace = "::mlir::vector";
 let genSpecializedAttr = 0;
 }
@@ -337,7 +342,7 @@ def Vector_MultiDimReductionOp :
 static SmallVector<int64_t> inferDestShape(
 ArrayRef<int64_t> shape, ArrayRef<bool> reducedDimsMask) {
-assert(shape.size() == reducedDimsMask.size() &&
+assert(shape.size() == reducedDimsMask.size() &&
 "shape and maks of different sizes");
 SmallVector<int64_t> res;
 for (auto it : llvm::zip(reducedDimsMask, shape))
@@ -434,18 +434,16 @@ public:
 else if (kind == "mul")
 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
 llvmType, operand);
-else if (kind == "min" &&
-(eltType.isIndex() || eltType.isUnsignedInteger()))
+else if (kind == "minui")
 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
 reductionOp, llvmType, operand);
-else if (kind == "min")
+else if (kind == "minsi")
 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
 reductionOp, llvmType, operand);
-else if (kind == "max" &&
-(eltType.isIndex() || eltType.isUnsignedInteger()))
+else if (kind == "maxui")
 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
 reductionOp, llvmType, operand);
-else if (kind == "max")
+else if (kind == "maxsi")
 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
 reductionOp, llvmType, operand);
 else if (kind == "and")
@@ -486,10 +484,14 @@ public:
 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
 reductionOp, llvmType, acc, operand,
 rewriter.getBoolAttr(reassociateFPReductions));
-} else if (kind == "min")
+} else if (kind == "minf")
+// FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
+// NaNs/-0.0/+0.0 in the same way.
 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
 llvmType, operand);
-else if (kind == "max")
+else if (kind == "maxf")
+// FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
+// NaNs/-0.0/+0.0 in the same way.
 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
 llvmType, operand);
 else
@@ -111,24 +111,40 @@ static VectorType extractVectorTypeFromShapedValue(Value v) {
 return VectorType::get(st.getShape(), st.getElementType());
 }

+static llvm::Optional<vector::CombiningKind>
+getKindForOp(Operation *reductionOp) {
+if (!reductionOp)
+return llvm::None;
+return llvm::TypeSwitch<Operation *, llvm::Optional<vector::CombiningKind>>(
+reductionOp)
+.Case<AddIOp, AddFOp>([&](auto op) { return vector::CombiningKind::ADD; })
+.Case<MaxSIOp>([&](auto op) { return vector::CombiningKind::MAXSI; })
+.Case<MaxFOp>([&](auto op) { return vector::CombiningKind::MAXF; })
+.Case<MinSIOp>([&](auto op) { return vector::CombiningKind::MINSI; })
+.Case<MinFOp>([&](auto op) { return vector::CombiningKind::MINF; })
+.Default([&](auto op) { return llvm::None; });
+}
+
 /// Check whether `outputOperand` is a reduction with a single combiner
-/// operation. Return the combiner operation of the reduction, which is assumed
-/// to be a binary operation. Multiple reduction operations would impose an
-/// ordering between reduction dimensions and is currently unsupported in
-/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
+/// operation. Return the combiner operation kind of the reduction, if
+/// supported. Return llvm::None, otherwise. Multiple reduction operations would
+/// impose an ordering between reduction dimensions and is currently unsupported
+/// in Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
 /// max(min(X))
 // TODO: use in LinalgOp verification, there is a circular dependency atm.
-static Operation *getSingleBinaryOpAssumedReduction(OpOperand *outputOperand) {
+static llvm::Optional<vector::CombiningKind>
+matchLinalgReduction(OpOperand *outputOperand) {
 auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
 unsigned outputPos =
 outputOperand->getOperandNumber() - linalgOp.getNumInputs();
 // Only single combiner operatios are supported for now.
 SmallVector<Operation *, 4> combinerOps;
 if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
 combinerOps.size() != 1)
-return nullptr;
+return llvm::None;
+
-// TODO: also assert no other subsequent ops break the reduction.
-return combinerOps[0];
+// Return the combiner operation kind, if supported.
+return getKindForOp(combinerOps[0]);
 }

 /// If `value` of assumed VectorType has a shape different than `shape`, try to
@@ -151,19 +167,6 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
 newVecType, value);
 }

-static llvm::Optional<vector::CombiningKind>
-getKindForOp(Operation *reductionOp) {
-if (!reductionOp)
-return llvm::None;
-return llvm::TypeSwitch<Operation *, llvm::Optional<vector::CombiningKind>>(
-reductionOp)
-.Case<AddIOp, AddFOp>([&](auto op) {
-return llvm::Optional<vector::CombiningKind>{
-vector::CombiningKind::ADD};
-})
-.Default([&](auto op) { return llvm::None; });
-}
-
 /// If value of assumed VectorType has a shape different than `shape`, build and
 /// return a new vector.broadcast to `shape`.
 /// Otherwise, just return value.
@@ -173,9 +176,7 @@ static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
 auto vecType = value.getType().dyn_cast<VectorType>();
 if (!vecType || vecType.getShape() == targetVectorType.getShape())
 return value;
-// At this point, we know we need to reduce. Detect the reduction operator.
-// TODO: Use the generic reduction detection util.
-Operation *reductionOp = getSingleBinaryOpAssumedReduction(outputOperand);
+
 unsigned pos = 0;
 MLIRContext *ctx = b.getContext();
 SmallVector<AffineExpr> exprs;
@@ -183,8 +184,9 @@ static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
 if (isParallelIterator(s))
 exprs.push_back(getAffineDimExpr(pos++, ctx));
 auto loc = value.getLoc();
-// TODO: reuse common CombiningKing logic and support more than add.
-auto maybeKind = getKindForOp(reductionOp);
+
+// At this point, we know we need to reduce. Detect the reduction operator.
+auto maybeKind = matchLinalgReduction(outputOperand);
 assert(maybeKind && "Failed precondition: could not get reduction kind");
 unsigned idx = 0;
 SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false);
@@ -597,8 +599,7 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
 if (llvm::none_of(op.iterator_types(), isReductionIterator))
 return failure();
 for (OpOperand *opOperand : op.getOutputOperands()) {
-Operation *reductionOp = getSingleBinaryOpAssumedReduction(opOperand);
-if (!getKindForOp(reductionOp))
+if (!matchLinalgReduction(opOperand))
 return failure();
 }
 return success();
@@ -92,13 +92,18 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
 switch (combiningKind) {
 case CombiningKind::ADD:
 case CombiningKind::MUL:
-case CombiningKind::MIN:
-case CombiningKind::MAX:
 return elementType.isIntOrIndexOrFloat();
+case CombiningKind::MINUI:
+case CombiningKind::MINSI:
+case CombiningKind::MAXUI:
+case CombiningKind::MAXSI:
 case CombiningKind::AND:
 case CombiningKind::OR:
 case CombiningKind::XOR:
 return elementType.isIntOrIndex();
+case CombiningKind::MINF:
+case CombiningKind::MAXF:
+return elementType.isa<FloatType>();
 }
 return false;
 }
@@ -151,8 +156,12 @@ static constexpr const CombiningKind combiningKindsList[] = {
 // clang-format off
 CombiningKind::ADD,
 CombiningKind::MUL,
-CombiningKind::MIN,
-CombiningKind::MAX,
+CombiningKind::MINUI,
+CombiningKind::MINSI,
+CombiningKind::MINF,
+CombiningKind::MAXUI,
+CombiningKind::MAXSI,
+CombiningKind::MAXF,
 CombiningKind::AND,
 CombiningKind::OR,
 CombiningKind::XOR,
@@ -291,22 +300,20 @@ static LogicalResult verify(ReductionOp op) {
 return op.emitOpError("unsupported reduction rank: ") << rank;

 // Verify supported reduction kind.
-auto kind = op.kind();
+StringRef strKind = op.kind();
+auto maybeKind = symbolizeCombiningKind(strKind);
+if (!maybeKind)
+return op.emitOpError("unknown reduction kind: ") << strKind;
+
 Type eltType = op.dest().getType();
-if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") {
-if (!eltType.isIntOrIndexOrFloat())
-return op.emitOpError("unsupported reduction type");
-} else if (kind == "and" || kind == "or" || kind == "xor") {
-if (!eltType.isIntOrIndex())
-return op.emitOpError("unsupported reduction type");
-} else {
-return op.emitOpError("unknown reduction kind: ") << kind;
-}
+if (!isSupportedCombiningKind(*maybeKind, eltType))
+return op.emitOpError("unsupported reduction type '")
+<< eltType << "' for kind '" << op.kind() << "'";

 // Verify optional accumulator.
 if (!op.acc().empty()) {
-if (kind != "add" && kind != "mul")
-return op.emitOpError("no accumulator for reduction kind: ") << kind;
+if (strKind != "add" && strKind != "mul")
+return op.emitOpError("no accumulator for reduction kind: ") << strKind;
 if (!eltType.isa<FloatType>())
 return op.emitOpError("no accumulator for type: ") << eltType;
 }
@@ -821,15 +821,17 @@ private:
 case CombiningKind::MUL:
 combinedResult = rewriter.create<MulIOp>(loc, mul, acc);
 break;
-case CombiningKind::MIN:
-combinedResult = rewriter.create<SelectOp>(
-loc, rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, mul, acc), mul,
-acc);
+case CombiningKind::MINUI:
+combinedResult = rewriter.create<MinUIOp>(loc, mul, acc);
 break;
-case CombiningKind::MAX:
-combinedResult = rewriter.create<SelectOp>(
-loc, rewriter.create<CmpIOp>(loc, CmpIPredicate::sge, mul, acc), mul,
-acc);
+case CombiningKind::MINSI:
+combinedResult = rewriter.create<MinSIOp>(loc, mul, acc);
 break;
+case CombiningKind::MAXUI:
+combinedResult = rewriter.create<MaxUIOp>(loc, mul, acc);
+break;
+case CombiningKind::MAXSI:
+combinedResult = rewriter.create<MaxSIOp>(loc, mul, acc);
+break;
 case CombiningKind::AND:
 combinedResult = rewriter.create<AndOp>(loc, mul, acc);
@@ -840,6 +842,9 @@ private:
 case CombiningKind::XOR:
 combinedResult = rewriter.create<XOrOp>(loc, mul, acc);
 break;
+case CombiningKind::MINF: // Only valid for floating point types.
+case CombiningKind::MAXF: // Only valid for floating point types.
+return Optional<Value>();
 }
 return Optional<Value>(combinedResult);
 }
@@ -864,18 +869,18 @@ private:
 case CombiningKind::MUL:
 combinedResult = rewriter.create<MulFOp>(loc, mul, acc);
 break;
-case CombiningKind::MIN:
-combinedResult = rewriter.create<SelectOp>(
-loc, rewriter.create<CmpFOp>(loc, CmpFPredicate::OLE, mul, acc), mul,
-acc);
+case CombiningKind::MINF:
+combinedResult = rewriter.create<MinFOp>(loc, mul, acc);
 break;
-case CombiningKind::MAX:
-combinedResult = rewriter.create<SelectOp>(
-loc, rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, mul, acc), mul,
-acc);
+case CombiningKind::MAXF:
+combinedResult = rewriter.create<MaxFOp>(loc, mul, acc);
 break;
 case CombiningKind::ADD: // Already handled this special case above.
 case CombiningKind::AND: // Only valid for integer types.
+case CombiningKind::MINUI: // Only valid for integer types.
+case CombiningKind::MINSI: // Only valid for integer types.
+case CombiningKind::MAXUI: // Only valid for integer types.
+case CombiningKind::MAXSI: // Only valid for integer types.
 case CombiningKind::OR: // Only valid for integer types.
 case CombiningKind::XOR: // Only valid for integer types.
 return Optional<Value>();
@@ -3697,23 +3702,23 @@ struct UnrollOuterMultiReduction
 else
 result = rewriter.create<MulFOp>(loc, operand, result);
 break;
-case vector::CombiningKind::MIN:
-if (elementType.isIntOrIndex())
-condition =
-rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, operand, result);
-else
-condition =
-rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, result);
-result = rewriter.create<SelectOp>(loc, condition, operand, result);
+case vector::CombiningKind::MINUI:
+result = rewriter.create<MinUIOp>(loc, operand, result);
 break;
-case vector::CombiningKind::MAX:
-if (elementType.isIntOrIndex())
-condition =
-rewriter.create<CmpIOp>(loc, CmpIPredicate::sge, operand, result);
-else
-condition =
-rewriter.create<CmpFOp>(loc, CmpFPredicate::OGE, operand, result);
-result = rewriter.create<SelectOp>(loc, condition, operand, result);
+case vector::CombiningKind::MINSI:
+result = rewriter.create<MinSIOp>(loc, operand, result);
 break;
+case vector::CombiningKind::MINF:
+result = rewriter.create<MinFOp>(loc, operand, result);
+break;
+case vector::CombiningKind::MAXUI:
+result = rewriter.create<MaxUIOp>(loc, operand, result);
+break;
+case vector::CombiningKind::MAXSI:
+result = rewriter.create<MaxSIOp>(loc, operand, result);
+break;
+case vector::CombiningKind::MAXF:
+result = rewriter.create<MaxFOp>(loc, operand, result);
+break;
 case vector::CombiningKind::AND:
 result = rewriter.create<AndOp>(loc, operand, result);
@@ -3771,10 +3776,18 @@ struct TwoDimMultiReductionToReduction
 return "add";
 case vector::CombiningKind::MUL:
 return "mul";
-case vector::CombiningKind::MIN:
-return "min";
-case vector::CombiningKind::MAX:
-return "max";
+case vector::CombiningKind::MINUI:
+return "minui";
+case vector::CombiningKind::MINSI:
+return "minsi";
+case vector::CombiningKind::MINF:
+return "minf";
+case vector::CombiningKind::MAXUI:
+return "maxui";
+case vector::CombiningKind::MAXSI:
+return "maxsi";
+case vector::CombiningKind::MAXF:
+return "maxf";
 case vector::CombiningKind::AND:
 return "and";
 case vector::CombiningKind::OR:
@@ -806,3 +806,54 @@ func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output: ten
 } -> tensor<5x2xf32>
 return %0 : tensor<5x2xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @red_max_2d(
+func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
+// CHECK: linalg.init_tensor [4] : tensor<4xf32>
+// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
+// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
+// CHECK: vector.transfer_read {{.*}} : tensor<4xf32>, vector<4x4xf32>
+// CHECK: maxf {{.*}} : vector<4x4xf32>
+// CHECK: vector.multi_reduction #vector.kind<maxf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
+// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
+%minf32 = constant -3.40282e+38 : f32
+%init = linalg.init_tensor [4] : tensor<4xf32>
+%fill = linalg.fill(%minf32, %init) : f32, tensor<4xf32> -> tensor<4xf32>
+%red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+affine_map<(d0, d1) -> (d0)>],
+iterator_types = ["parallel", "reduction"]}
+ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) {
+^bb0(%in0: f32, %out0: f32): // no predecessors
+%max = maxf %in0, %out0 : f32
+linalg.yield %max : f32
+} -> tensor<4xf32>
+return %red : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @red_min_2d(
+func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
+// CHECK: linalg.init_tensor [4] : tensor<4xf32>
+// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
+// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
+// CHECK: vector.transfer_read {{.*}} : tensor<4xf32>, vector<4x4xf32>
+// CHECK: minf {{.*}} : vector<4x4xf32>
+// CHECK: vector.multi_reduction #vector.kind<minf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
+// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
+%maxf32 = constant 3.40282e+38 : f32
+%init = linalg.init_tensor [4] : tensor<4xf32>
+%fill = linalg.fill(%maxf32, %init) : f32, tensor<4xf32> -> tensor<4xf32>
+%red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+affine_map<(d0, d1) -> (d0)>],
+iterator_types = ["parallel", "reduction"]}
+ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) {
+^bb0(%in0: f32, %out0: f32): // no predecessors
+%min = minf %in0, %out0 : f32
+linalg.yield %min : f32
+} -> tensor<4xf32>
+return %red : tensor<4xf32>
+}
@@ -1019,7 +1019,7 @@ func @reduce_unsupported_third_argument(%arg0: vector<16xf32>, %arg1: f32) -> f3

 func @reduce_unsupported_accumulator_kind(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
 // expected-error@+1 {{'vector.reduction' op no accumulator for reduction kind: min}}
-%0 = vector.reduction "min", %arg0, %arg1 : vector<16xf32> into f32
+%0 = vector.reduction "minf", %arg0, %arg1 : vector<16xf32> into f32
 }

 // -----
@@ -243,13 +243,13 @@ func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32
 #contraction_to_scalar_max_trait = {
 indexing_maps = #contraction_to_scalar_max_accesses,
 iterator_types = ["reduction"],
-kind = #vector.kind<max>
+kind = #vector.kind<maxf>
 }
 // CHECK-LABEL: @contraction_to_scalar_with_max
 func @contraction_to_scalar_with_max(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 {
 // CHECK: %[[C0:.*]] = constant 0.000000e+00 : f32
 %f0 = constant 0.0: f32
-// CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind<max>} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32
+// CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind<maxf>} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32
 %0 = vector.contract #contraction_to_scalar_max_trait %arg0, %arg1, %f0
 : vector<10xf32>, vector<10xf32> into f32
 // CHECK: return %[[X]] : f32
@@ -281,7 +281,7 @@ func @contraction_to_scalar_with_max(%arg0: vector<10xf32>, %arg1: vector<10xf32
 #contraction_trait2 = {
 indexing_maps = #contraction_accesses1,
 iterator_types = #iterator_types1,
-kind = #vector.kind<max>
+kind = #vector.kind<maxf>
 }
 // CHECK-LABEL: @contraction
 func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
@@ -309,7 +309,7 @@ func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
 %3 = vector.contract #contraction_trait1 %arg4, %arg5, %arg3
 : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32>
 // Test contraction with "max" instead of "add".
-// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<max>} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
+// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<maxf>} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
 %4 = vector.contract #contraction_trait2 %arg0, %arg1, %arg3
 : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
 return
@@ -432,10 +432,10 @@ func @reduce_fp(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
 vector.reduction "mul", %arg0 : vector<16xf32> into f32
 // CHECK: vector.reduction "mul", %{{.*}}, %{{.*}} : vector<16xf32> into f32
 vector.reduction "mul", %arg0, %arg1 : vector<16xf32> into f32
-// CHECK: vector.reduction "min", %{{.*}} : vector<16xf32> into f32
-vector.reduction "min", %arg0 : vector<16xf32> into f32
-// CHECK: %[[X:.*]] = vector.reduction "max", %{{.*}} : vector<16xf32> into f32
-%0 = vector.reduction "max", %arg0 : vector<16xf32> into f32
+// CHECK: vector.reduction "minf", %{{.*}} : vector<16xf32> into f32
+vector.reduction "minf", %arg0 : vector<16xf32> into f32
+// CHECK: %[[X:.*]] = vector.reduction "maxf", %{{.*}} : vector<16xf32> into f32
+%0 = vector.reduction "maxf", %arg0 : vector<16xf32> into f32
 // CHECK: return %[[X]] : f32
 return %0 : f32
 }
@@ -446,10 +446,14 @@ func @reduce_int(%arg0: vector<16xi32>) -> i32 {
 vector.reduction "add", %arg0 : vector<16xi32> into i32
 // CHECK: vector.reduction "mul", %{{.*}} : vector<16xi32> into i32
 vector.reduction "mul", %arg0 : vector<16xi32> into i32
-// CHECK: vector.reduction "min", %{{.*}} : vector<16xi32> into i32
-vector.reduction "min", %arg0 : vector<16xi32> into i32
-// CHECK: vector.reduction "max", %{{.*}} : vector<16xi32> into i32
-vector.reduction "max", %arg0 : vector<16xi32> into i32
+// CHECK: vector.reduction "minui", %{{.*}} : vector<16xi32> into i32
+vector.reduction "minui", %arg0 : vector<16xi32> into i32
+// CHECK: vector.reduction "minsi", %{{.*}} : vector<16xi32> into i32
+vector.reduction "minsi", %arg0 : vector<16xi32> into i32
+// CHECK: vector.reduction "maxui", %{{.*}} : vector<16xi32> into i32
+vector.reduction "maxui", %arg0 : vector<16xi32> into i32
+// CHECK: vector.reduction "maxsi", %{{.*}} : vector<16xi32> into i32
+vector.reduction "maxsi", %arg0 : vector<16xi32> into i32
 // CHECK: vector.reduction "and", %{{.*}} : vector<16xi32> into i32
 vector.reduction "and", %arg0 : vector<16xi32> into i32
 // CHECK: vector.reduction "or", %{{.*}} : vector<16xi32> into i32
@@ -12,7 +12,7 @@
 #matvecmax_trait = {
 indexing_maps = #matvec_accesses,
 iterator_types = ["parallel", "reduction"],
-kind = #vector.kind<max>
+kind = #vector.kind<maxf>
 }

 #mattransvec_accesses = [
@@ -91,10 +91,10 @@ func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
 // CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32>
 // CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<max>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32>
 // CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<max>} : vector<2xf32>, f32
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
 // CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
 // CHECK: return
 func @matvecmax2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
@@ -18,7 +18,7 @@ func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
 // CHECK: return %[[RESULT_VEC]] : vector<2xf32>

 func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> {
-%0 = vector.multi_reduction #vector.kind<min>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+%0 = vector.multi_reduction #vector.kind<minf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
 return %0 : vector<2xf32>
 }

@@ -27,18 +27,15 @@ func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> {
 // CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
 // CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
 // CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
-// CHECK: %[[C0:.+]] = cmpf olt, %[[V1]], %[[V0]] : vector<2xf32>
-// CHECK: %[[RV01:.+]] = select %[[C0]], %[[V1]], %[[V0]] : vector<2xi1>, vector<2xf32>
+// CHECK: %[[RV01:.+]] = minf %[[V1]], %[[V0]] : vector<2xf32>
 // CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
-// CHECK: %[[C1:.+]] = cmpf olt, %[[V2]], %[[RV01]] : vector<2xf32>
-// CHECK: %[[RV012:.+]] = select %[[C1]], %[[V2]], %[[RV01]] : vector<2xi1>, vector<2xf32>
+// CHECK: %[[RV012:.+]] = minf %[[V2]], %[[RV01]] : vector<2xf32>
 // CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
-// CHECK: %[[C2:.+]] = cmpf olt, %[[V3]], %[[RV012]] : vector<2xf32>
-// CHECK: %[[RESULT_VEC:.+]] = select %[[C2]], %[[V3]], %[[RV012]] : vector<2xi1>, vector<2xf32>
+// CHECK: %[[RESULT_VEC:.+]] = minf %[[V3]], %[[RV012]] : vector<2xf32>
 // CHECK: return %[[RESULT_VEC]] : vector<2xf32>

 func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> {
-%0 = vector.multi_reduction #vector.kind<max>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+%0 = vector.multi_reduction #vector.kind<maxf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
 return %0 : vector<2xf32>
 }

@@ -47,14 +44,11 @@ func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> {
 // CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
 // CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
 // CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
-// CHECK: %[[C0:.+]] = cmpf oge, %[[V1]], %[[V0]] : vector<2xf32>
-// CHECK: %[[RV01:.+]] = select %[[C0]], %[[V1]], %[[V0]] : vector<2xi1>, vector<2xf32>
+// CHECK: %[[RV01:.+]] = maxf %[[V1]], %[[V0]] : vector<2xf32>
 // CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
-// CHECK: %[[C1:.+]] = cmpf oge, %[[V2]], %[[RV01]] : vector<2xf32>
-// CHECK: %[[RV012:.+]] = select %[[C1]], %[[V2]], %[[RV01]] : vector<2xi1>, vector<2xf32>
+// CHECK: %[[RV012:.+]] = maxf %[[V2]], %[[RV01]] : vector<2xf32>
 // CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
-// CHECK: %[[C2:.+]] = cmpf oge, %[[V3]], %[[RV012]] : vector<2xf32>
-// CHECK: %[[RESULT_VEC:.+]] = select %[[C2]], %[[V3]], %[[RV012]] : vector<2xi1>, vector<2xf32>
+// CHECK: %[[RESULT_VEC:.+]] = maxf %[[V3]], %[[RV012]] : vector<2xf32>
 // CHECK: return %[[RESULT_VEC]] : vector<2xf32>

 func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> {
@@ -27,10 +27,10 @@ func @entry() {
 %1 = vector.reduction "mul", %v2 : vector<64xf32> into f32
 vector.print %1 : f32
 // CHECK: 6
-%2 = vector.reduction "min", %v2 : vector<64xf32> into f32
+%2 = vector.reduction "minf", %v2 : vector<64xf32> into f32
 vector.print %2 : f32
 // CHECK: 1
-%3 = vector.reduction "max", %v2 : vector<64xf32> into f32
+%3 = vector.reduction "maxf", %v2 : vector<64xf32> into f32
 vector.print %3 : f32
 // CHECK: 3

@@ -39,10 +39,10 @@ func @entry() {
 %1 = vector.reduction "mul", %v9 : vector<10xf32> into f32
 vector.print %1 : f32
 // CHECK: -5760
-%2 = vector.reduction "min", %v9 : vector<10xf32> into f32
+%2 = vector.reduction "minf", %v9 : vector<10xf32> into f32
 vector.print %2 : f32
 // CHECK: -16
-%3 = vector.reduction "max", %v9 : vector<10xf32> into f32
+%3 = vector.reduction "maxf", %v9 : vector<10xf32> into f32
 vector.print %3 : f32
 // CHECK: 5

@@ -27,10 +27,10 @@ func @entry() {
 %1 = vector.reduction "mul", %v2 : vector<64xf64> into f64
 vector.print %1 : f64
 // CHECK: 6
-%2 = vector.reduction "min", %v2 : vector<64xf64> into f64
+%2 = vector.reduction "minf", %v2 : vector<64xf64> into f64
 vector.print %2 : f64
 // CHECK: 1
-%3 = vector.reduction "max", %v2 : vector<64xf64> into f64
+%3 = vector.reduction "maxf", %v2 : vector<64xf64> into f64
 vector.print %3 : f64
 // CHECK: 3

@@ -39,10 +39,10 @@ func @entry() {
 %1 = vector.reduction "mul", %v9 : vector<10xf64> into f64
 vector.print %1 : f64
 // CHECK: -5760
-%2 = vector.reduction "min", %v9 : vector<10xf64> into f64
+%2 = vector.reduction "minf", %v9 : vector<10xf64> into f64
 vector.print %2 : f64
 // CHECK: -16
-%3 = vector.reduction "max", %v9 : vector<10xf64> into f64
+%3 = vector.reduction "maxf", %v9 : vector<10xf64> into f64
 vector.print %3 : f64
 // CHECK: 5

@@ -39,10 +39,10 @@ func @entry() {
 %1 = vector.reduction "mul", %v9 : vector<10xi32> into i32
 vector.print %1 : i32
 // CHECK: -1228800
-%2 = vector.reduction "min", %v9 : vector<10xi32> into i32
+%2 = vector.reduction "minsi", %v9 : vector<10xi32> into i32
 vector.print %2 : i32
 // CHECK: -80
-%3 = vector.reduction "max", %v9 : vector<10xi32> into i32
+%3 = vector.reduction "maxsi", %v9 : vector<10xi32> into i32
 vector.print %3 : i32
 // CHECK: 5
 %4 = vector.reduction "and", %v9 : vector<10xi32> into i32
@@ -20,11 +20,11 @@ func @entry() {
 vector.print %1 : i4
 // CHECK: 0

-%2 = vector.reduction "min", %v : vector<24xi4> into i4
+%2 = vector.reduction "minsi", %v : vector<24xi4> into i4
 vector.print %2 : i4
 // CHECK: -8

-%3 = vector.reduction "max", %v : vector<24xi4> into i4
+%3 = vector.reduction "maxsi", %v : vector<24xi4> into i4
 vector.print %3 : i4
 // CHECK: 7

@@ -39,10 +39,10 @@ func @entry() {
 %1 = vector.reduction "mul", %v9 : vector<10xi64> into i64
 vector.print %1 : i64
 // CHECK: -1228800
-%2 = vector.reduction "min", %v9 : vector<10xi64> into i64
+%2 = vector.reduction "minsi", %v9 : vector<10xi64> into i64
 vector.print %2 : i64
 // CHECK: -80
-%3 = vector.reduction "max", %v9 : vector<10xi64> into i64
+%3 = vector.reduction "maxsi", %v9 : vector<10xi64> into i64
 vector.print %3 : i64
 // CHECK: 5
 %4 = vector.reduction "and", %v9 : vector<10xi64> into i64
@@ -19,11 +19,11 @@ func @entry() {
 vector.print %1 : si4
 // CHECK: 0

-%2 = vector.reduction "min", %v : vector<16xsi4> into si4
+%2 = vector.reduction "minsi", %v : vector<16xsi4> into si4
 vector.print %2 : si4
 // CHECK: -8

-%3 = vector.reduction "max", %v : vector<16xsi4> into si4
+%3 = vector.reduction "maxsi", %v : vector<16xsi4> into si4
 vector.print %3 : si4
 // CHECK: 7

@@ -19,11 +19,11 @@ func @entry() {
 vector.print %1 : ui4
 // CHECK: 0

-%2 = vector.reduction "min", %v : vector<16xui4> into ui4
+%2 = vector.reduction "minui", %v : vector<16xui4> into ui4
 vector.print %2 : ui4
 // CHECK: 0

-%3 = vector.reduction "max", %v : vector<16xui4> into ui4
+%3 = vector.reduction "maxui", %v : vector<16xui4> into ui4
 vector.print %3 : ui4
 // CHECK: 15