[mlir][linalg] Fix generic reduction vectorization

We shouldn't broadcast the original value when doing reduction. Instead
we compute the reduction and then combine it with the original value.
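
For example, for the updated `@sum_exp` test below, where a `linalg.generic`
sums a `tensor<4x16x8xf32>` into a `tensor<4x16xf32>` over the innermost
dimension, vectorization now emits roughly (SSA names illustrative, `%pad`
is the read's padding value):

    %red  = vector.multi_reduction #vector.kind<add>, %v [2]
              : vector<4x16x8xf32> to vector<4x16xf32>
    %init = vector.transfer_read %output[%c0, %c0], %pad
              : tensor<4x16xf32>, vector<4x16xf32>
    %sum  = addf %red, %init : vector<4x16xf32>
    %res  = vector.transfer_write %sum, %output[%c0, %c0]
              : vector<4x16xf32>, tensor<4x16xf32>

instead of first broadcasting the original output value to the full
`vector<4x16x8xf32>` iteration-space shape and folding it into the reduction.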

Differential Revision: https://reviews.llvm.org/D111666
thomasraoux 2021-10-12 15:42:02 -07:00
parent b6a8c69554
commit 7c97e328b3
2 changed files with 102 additions and 57 deletions


@@ -134,14 +134,13 @@ getKindForOp(Operation *reductionOp) {
}
/// Check whether `outputOperand` is a reduction with a single combiner
/// operation. Return the combiner operation kind of the reduction, if
/// supported. Return llvm::None, otherwise. Multiple reduction operations would
/// impose an ordering between reduction dimensions and is currently unsupported
/// in Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
/// operation. Return the combiner operation of the reduction. Return
/// nullptr otherwise. Multiple reduction operations would impose an
/// ordering between reduction dimensions and are currently unsupported in
/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
/// max(min(X))
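/// For example (illustrative), given a region like
///   ^bb0(%in0: f32, %out0: f32):
///     %min = minf %out0, %in0 : f32
///     linalg.yield %min : f32
/// the `minf` is the single combiner operation that is returned.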
// TODO: use in LinalgOp verification, there is a circular dependency atm.
static llvm::Optional<vector::CombiningKind>
matchLinalgReduction(OpOperand *outputOperand) {
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
unsigned outputPos =
outputOperand->getOperandNumber() - linalgOp.getNumInputs();
@@ -149,10 +148,10 @@ matchLinalgReduction(OpOperand *outputOperand) {
SmallVector<Operation *, 4> combinerOps;
if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
combinerOps.size() != 1)
return llvm::None;
return nullptr;
// Return the combiner operation kind, if supported.
return getKindForOp(combinerOps[0]);
// Return the combiner operation.
return combinerOps[0];
}
/// Broadcast `value` to a vector of `shape` if possible. Return value
@@ -171,11 +170,60 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
}
/// Build a vector.transfer_read from `source` at indices set to all `0`.
/// If source has rank zero, build a `vector<1xt> transfer_read + extract`.
/// Return the produced value.
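/// For example (illustrative), reading a rank-2 `tensor<4x8xf32>` at a
/// matching vector type builds
///   %c0 = constant 0 : index
///   %r  = vector.transfer_read %source[%c0, %c0], %pad
///           : tensor<4x8xf32>, vector<4x8xf32>
/// with `%pad` the default padding value.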
static Value buildVectorRead(OpBuilder &b, Value source, Type readType,
AffineMap map) {
Location loc = source.getLoc();
auto shapedType = source.getType().cast<ShapedType>();
SmallVector<Value> indices(shapedType.getRank(),
b.create<ConstantIndexOp>(loc, 0));
if (auto vectorType = readType.dyn_cast<VectorType>())
return b.create<vector::TransferReadOp>(loc, vectorType, source, indices,
map);
return vector::TransferReadOp::createScalarOp(b, loc, source, indices);
}
/// Create a MultiDimReductionOp to compute the reduction for `reduceOp`. This
/// assumes that `reduceOp` has two operands and one of them is the reduction
/// initial value.
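/// For example (illustrative), an `addf` combiner reducing dimension 1 yields
///   %red = vector.multi_reduction #vector.kind<add>, %vec [1]
///            : vector<4x8xf32> to vector<4xf32>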
static Value buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
Value outputArg,
const SmallVector<bool> &reductionMask,
const BlockAndValueMapping &bvm) {
auto maybeKind = getKindForOp(reduceOp);
assert(maybeKind && "Failed precondition: could not get reduction kind");
Value operandToReduce = reduceOp->getOperand(0) == outputArg
? reduceOp->getOperand(1)
: reduceOp->getOperand(0);
Value vec = bvm.lookup(operandToReduce);
return b.create<vector::MultiDimReductionOp>(reduceOp->getLoc(), vec,
reductionMask, *maybeKind);
}
/// Read the initial value associated to the given `outputOperand`.
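/// For example (illustrative), for a `tensor<4x16xf32>` output with an
/// identity indexing map this reads
///   %init = vector.transfer_read %output[%c0, %c0], %pad
///             : tensor<4x16xf32>, vector<4x16xf32>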
static Value readInitialValue(OpBuilder &b, LinalgOp linalgOp,
OpOperand *outputOperand) {
AffineMap map = inversePermutation(
reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand)));
Type readType;
if (linalgOp.getShape(outputOperand).empty()) {
readType = getElementTypeOrSelf(outputOperand->get());
} else {
readType = VectorType::get(map.compose(linalgOp.getShape(outputOperand)),
getElementTypeOrSelf(outputOperand->get()));
}
Value vectorRead = buildVectorRead(b, outputOperand->get(), readType, map);
return vectorRead;
}
/// Assuming `outputOperand` is an output operand of a LinalgOp, determine
/// whether a reduction is needed to produce a `targetType`, and create that
/// reduction if so.
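/// For example (illustrative), an additive reduction over the trailing
/// dimension now becomes
///   %red = vector.multi_reduction #vector.kind<add>, %value [1]
///            : vector<4x4xf32> to vector<4xf32>
///   %res = addf %red, %init : vector<4xf32>
/// where `%init` is the original output value read by readInitialValue.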
static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value,
OpOperand *outputOperand) {
OpOperand *outputOperand,
const BlockAndValueMapping &bvm) {
LDBG("Reduce " << value << " to type " << targetType);
LDBG("In LinalgOp operand #" << outputOperand->getOperandNumber() << "\n"
<< *(outputOperand->getOwner()));
@@ -194,10 +242,9 @@ static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value,
for (auto s : linalgOp.iterator_types())
if (isParallelIterator(s))
exprs.push_back(getAffineDimExpr(pos++, ctx));
auto loc = value.getLoc();
auto maybeKind = matchLinalgReduction(outputOperand);
assert(maybeKind && "Failed precondition: could not get reduction kind");
Operation *reduceOp = matchLinalgReduction(outputOperand);
assert(reduceOp && "Failed precondition: could not match a reduction");
unsigned idx = 0;
SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false);
for (auto attr : linalgOp.iterator_types()) {
@@ -205,23 +252,24 @@ static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value,
reductionMask[idx] = true;
++idx;
}
return b.create<vector::MultiDimReductionOp>(loc, value, reductionMask,
*maybeKind);
}
assert(reduceOp->getNumOperands() == 2 &&
"Only support binary reduce op right now");
unsigned outputPos =
outputOperand->getOperandNumber() - linalgOp.getNumInputs();
Value outputArg = linalgOp.getRegionOutputArgs()[outputPos];
// Reduce across the iteration space.
Value reduce =
buildMultiDimReduce(b, reduceOp, outputArg, reductionMask, bvm);
/// Build a vector.transfer_read from `source` at indices set to all `0`.
/// If source has rank zero, build a `vector<1xt> transfer_read + extract`.
/// Return the produced value.
static Value buildVectorRead(OpBuilder &b, Value source, Type readType,
AffineMap map) {
Location loc = source.getLoc();
auto shapedType = source.getType().cast<ShapedType>();
SmallVector<Value> indices(shapedType.getRank(),
b.create<ConstantIndexOp>(loc, 0));
if (auto vectorType = readType.dyn_cast<VectorType>())
return b.create<vector::TransferReadOp>(loc, vectorType, source, indices,
map);
return vector::TransferReadOp::createScalarOp(b, loc, source, indices);
// Read the original output value.
Value initialValue = readInitialValue(b, linalgOp, outputOperand);
// Combine the output argument with the reduced value.
OperationState state(reduceOp->getLoc(), reduceOp->getName());
state.addAttributes(reduceOp->getAttrs());
state.addOperands({reduce, initialValue});
state.addTypes(initialValue.getType());
return b.createOperation(state)->getResult(0);
}
/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
@@ -229,7 +277,8 @@ static Value buildVectorRead(OpBuilder &b, Value source, Type readType,
/// currently being vectorized. If `dest` has rank 0, build a memref.store.
/// Return the produced value or null if no value is produced.
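/// For example (illustrative), writing a `vector<4x16xf32>` into a ranked
/// tensor destination builds
///   %w = vector.transfer_write %value, %dest[%c0, %c0]
///          : vector<4x16xf32>, tensor<4x16xf32>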
static Value buildVectorWrite(OpBuilder &b, Value value,
OpOperand *outputOperand) {
OpOperand *outputOperand,
const BlockAndValueMapping &bvm) {
Operation *write;
Location loc = value.getLoc();
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
@@ -244,12 +293,12 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
SmallVector<Value> indices(linalgOp.getRank(outputOperand),
b.create<ConstantIndexOp>(loc, 0));
value = broadcastIfNeeded(b, value, vectorType.getShape());
value = reduceIfNeeded(b, vectorType, value, outputOperand);
value = reduceIfNeeded(b, vectorType, value, outputOperand, bvm);
write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
indices, map);
} else {
value =
reduceIfNeeded(b, getElementTypeOrSelf(value), value, outputOperand);
value = reduceIfNeeded(b, getElementTypeOrSelf(value), value, outputOperand,
bvm);
write = vector::TransferWriteOp::createScalarOp(
b, loc, value, outputOperand->get(), ValueRange{});
}
@@ -284,7 +333,7 @@ vectorizeLinalgYield(OpBuilder &b, Operation *op,
// TODO: use a map.
Value vectorValue = bvm.lookup(outputs.value());
Value newResult = buildVectorWrite(
b, vectorValue, linalgOp.getOutputOperand(outputs.index()));
b, vectorValue, linalgOp.getOutputOperand(outputs.index()), bvm);
if (newResult)
newResults.push_back(newResult);
}
@@ -611,7 +660,8 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
return failure();
}
for (OpOperand *opOperand : op.getOutputOperands()) {
if (!matchLinalgReduction(opOperand)) {
Operation *reduceOp = matchLinalgReduction(opOperand);
if (!reduceOp || !getKindForOp(reduceOp)) {
LDBG("reduction precondition failed: reduction detection failed");
return failure();
}


@@ -744,17 +744,15 @@ func @pad_tensor_non_const_pad_value(%arg0: tensor<5x6xf32>) -> tensor<12x13xf32
// -----
// CHECK-DAG: #[[$M0:.*]] = affine_map<(d0, d1) -> (d0, d1, 0)>
// CHECK-LABEL: func @sum_exp
func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>)
-> tensor<4x16xf32>
{
// CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32>
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true], permutation_map = #[[$M0]]} : tensor<4x16xf32>, vector<4x16x8xf32>
// CHECK: math.exp {{.*}} : vector<4x16x8xf32>
// CHECK: addf {{.*}} : vector<4x16x8xf32>
// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
// CHECK: addf {{.*}} : vector<4x16xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32>
// CHECK: return {{.*}} : tensor<4x16xf32>
%0 = linalg.generic {
@@ -776,8 +774,7 @@ func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>)
// CHECK-DAG: #[[$M1:.*]] = affine_map<(d0, d1) -> (d1, d0, 0, 0)>
// CHECK-DAG: #[[$M2:.*]] = affine_map<(d0, d1) -> (0, 0, d1, d0)>
// CHECK-DAG: #[[$M3:.*]] = affine_map<(d0, d1) -> (d1, 0, 0, d0)>
// CHECK-DAG: #[[$M4:.*]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-DAG: #[[$M3:.*]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-LABEL: func @sum_exp_2
func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output: tensor<5x2xf32>)
@@ -785,13 +782,13 @@ func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output: ten
{
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M1]]} : tensor<3x2xf32>, vector<2x3x4x5xf32>
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M2]]} : tensor<5x4xf32>, vector<2x3x4x5xf32>
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M3]]} : tensor<5x2xf32>, vector<2x3x4x5xf32>
// CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
// CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
// CHECK: addf {{.*}} : vector<2x3x4x5xf32>
// CHECK: addf {{.*}} : vector<2x3x4x5xf32>
// CHECK: vector.multi_reduction #vector.kind<add>, {{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
// CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M4]]} : vector<2x5xf32>, tensor<5x2xf32>
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : tensor<5x2xf32>, vector<2x5xf32>
// CHECK: addf {{.*}} : vector<2x5xf32>
// CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : vector<2x5xf32>, tensor<5x2xf32>
// CHECK: return {{.*}} : tensor<5x2xf32>
%0 = linalg.generic {
indexing_maps = [
@@ -815,12 +812,11 @@ func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output: ten
// CHECK-LABEL: func @red_max_2d(
func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: %[[CMINF:.+]] = constant dense<-3.402820e+38> : vector<4xf32>
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
// CHECK: vector.transfer_read {{.*}} : tensor<4xf32>, vector<4x4xf32>
// CHECK: maxf {{.*}} : vector<4x4xf32>
// CHECK: vector.multi_reduction #vector.kind<maxf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: %[[R:.+]] = vector.multi_reduction #vector.kind<maxf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: maxf %[[R]], %[[CMINF]] : vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
%ident = constant -3.40282e+38 : f32
%init = linalg.init_tensor [4] : tensor<4xf32>
@@ -840,12 +836,12 @@ func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK-LABEL: func @red_min_2d(
func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: %[[CMAXF:.+]] = constant dense<3.402820e+38> : vector<4xf32>
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
// CHECK: vector.transfer_read {{.*}} : tensor<4xf32>, vector<4x4xf32>
// CHECK: minf {{.*}} : vector<4x4xf32>
// CHECK: vector.multi_reduction #vector.kind<minf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: %[[R:.+]] = vector.multi_reduction #vector.kind<minf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: minf %[[R]], %[[CMAXF]] : vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
%maxf32 = constant 3.40282e+38 : f32
%init = linalg.init_tensor [4] : tensor<4xf32>
@@ -855,7 +851,7 @@ func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
iterator_types = ["parallel", "reduction"]}
ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) {
^bb0(%in0: f32, %out0: f32): // no predecessors
%min = minf %in0, %out0 : f32
%min = minf %out0, %in0 : f32
linalg.yield %min : f32
} -> tensor<4xf32>
return %red : tensor<4xf32>
@@ -1026,7 +1022,7 @@ func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) ->
// CHECK-SAME: %[[A:.*]]: tensor<32xf32>
func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
// CHECK-DAG: %[[F0_v1:.*]] = constant dense<0.000000e+00> : vector<1xf32>
// CHECK-DAG: %[[F0_v32:.*]] = constant dense<0.000000e+00> : vector<32xf32>
// CHECK-DAG: %[[F0:.*]] = constant 0.000000e+00 : f32
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
%f0 = constant 0.000000e+00 : f32
@@ -1036,13 +1032,12 @@ func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
// CHECK: %[[f:.*]] = vector.transfer_write %[[F0_v1]], %[[init]][]
// CHECK-SAME: : vector<1xf32>, tensor<f32>
%1 = linalg.fill(%f0, %0) : f32, tensor<f32> -> tensor<f32>
// CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
// CHECK-SAME: : tensor<32xf32>, vector<32xf32>
// CHECK: %[[a:.*]] = addf %[[r]], %[[F0_v32]] : vector<32xf32>
// CHECK: %[[red:.*]] = vector.multi_reduction #vector.kind<add>, %[[a]] [0]
// CHECK: %[[red:.*]] = vector.multi_reduction #vector.kind<add>, %[[r]] [0]
// CHECK-SAME: : vector<32xf32> to f32
// CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<1xf32>
// CHECK: %[[a:.*]] = addf %[[red]], %[[F0]] : f32
// CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 to vector<1xf32>
// CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][]
// CHECK-SAME: : vector<1xf32>, tensor<f32>
%2 = linalg.generic {