[mlir][vector] Refactor linalg vectorization for reductions

Emit the reduction during op vectorization instead of doing it when creating
the transfer write. This allows us to avoid broadcasting output arguments for
the reduction initial value.

Differential Revision: https://reviews.llvm.org/D111825
commit afad0cdf31 (parent 8b31f07cdf)
Author: thomasraoux
Date:   2021-10-14 10:39:15 -07:00

2 changed files with 74 additions and 89 deletions
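For illustration (not part of the commit), the new path produces IR of
roughly the following shape for the sum_exp test updated below; value names
are hypothetical:

    func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>) -> tensor<4x16xf32> {
      %c0 = arith.constant 0 : index
      %f0 = arith.constant 0.0 : f32
      // Read the input at its full iteration shape.
      %in = vector.transfer_read %input[%c0, %c0, %c0], %f0 : tensor<4x16x8xf32>, vector<4x16x8xf32>
      %exp = math.exp %in : vector<4x16x8xf32>
      // The reduction is now emitted while vectorizing the region ops.
      %red = vector.multi_reduction #vector.kind<add>, %exp [2] : vector<4x16x8xf32> to vector<4x16xf32>
      // The initial value is read at its natural 2-D shape, with no broadcast.
      %init = vector.transfer_read %output[%c0, %c0], %f0 : tensor<4x16xf32>, vector<4x16xf32>
      %sum = arith.addf %red, %init : vector<4x16xf32>
      %res = vector.transfer_write %sum, %output[%c0, %c0] : vector<4x16xf32>, tensor<4x16xf32>
      return %res : tensor<4x16xf32>
    }

The initial value is combined after the multi-dimensional reduction instead
of being broadcast to the full vector<4x16x8xf32> shape and reduced at
transfer-write time.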


@@ -189,65 +189,18 @@ static Value buildVectorRead(OpBuilder &b, Value source, Type readType,
 }
 
 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
-/// assumes that `reductionOp` has tow operands and one of them is the reduction
+/// assumes that `reductionOp` has two operands and one of them is the reduction
 /// initial value.
 static Value buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
-                                 Value outputArg,
-                                 const SmallVector<bool> &reductionMask,
-                                 const BlockAndValueMapping &bvm) {
+                                 Value valueToReduce,
+                                 const SmallVector<bool> &reductionMask) {
   auto maybeKind = getKindForOp(reduceOp);
   assert(maybeKind && "Failed precondition: could not get reduction kind");
-  Value operandToReduce = reduceOp->getOperand(0) == outputArg
-                              ? reduceOp->getOperand(1)
-                              : reduceOp->getOperand(0);
-  Value vec = bvm.lookup(operandToReduce);
-  return b.create<vector::MultiDimReductionOp>(reduceOp->getLoc(), vec,
-                                               reductionMask, *maybeKind);
+  return b.create<vector::MultiDimReductionOp>(
+      reduceOp->getLoc(), valueToReduce, reductionMask, *maybeKind);
 }
 
-/// Read the initial value associated to the given `outputOperand`.
-static Value readInitialValue(OpBuilder &b, LinalgOp linalgOp,
-                              OpOperand *outputOperand) {
-  AffineMap map = inversePermutation(
-      reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand)));
-  Type readType;
-  if (linalgOp.getShape(outputOperand).empty()) {
-    readType = getElementTypeOrSelf(outputOperand->get());
-  } else {
-    readType = VectorType::get(map.compose(linalgOp.getShape(outputOperand)),
-                               getElementTypeOrSelf(outputOperand->get()));
-  }
-  Value vectorRead = buildVectorRead(b, outputOperand->get(), readType, map);
-  return vectorRead;
-}
-
-/// Assuming `outputOperand` is an output operand of a LinalgOp, determine
-/// whether a reduction is needed to produce a `targetType` and create that
-/// reduction if it is the case.
-static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value,
-                            OpOperand *outputOperand,
-                            const BlockAndValueMapping &bvm) {
-  LDBG("Reduce " << value << " to type " << targetType);
-  LDBG("In LinalgOp operand #" << outputOperand->getOperandNumber() << "\n"
-                               << *(outputOperand->getOwner()));
-  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
-  auto vecType = value.getType().dyn_cast<VectorType>();
-  VectorType targetVectorType = targetType.dyn_cast<VectorType>();
-  if (!vecType)
-    return value;
-  if (targetVectorType && vecType.getShape() == targetVectorType.getShape())
-    return value;
-  // At this point, we know we need to reduce. Detect the reduction operator.
-  unsigned pos = 0;
-  MLIRContext *ctx = b.getContext();
-  SmallVector<AffineExpr> exprs;
-  for (auto s : linalgOp.iterator_types())
-    if (isParallelIterator(s))
-      exprs.push_back(getAffineDimExpr(pos++, ctx));
-  Operation *reduceOp = matchLinalgReduction(outputOperand);
-  assert(reduceOp && "Failed precondition: could not math a reduction");
+static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
   unsigned idx = 0;
   SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false);
   for (auto attr : linalgOp.iterator_types()) {
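To make the new helper concrete (an illustrative sketch, not from the patch):
for a linalg op whose iterator_types are ["parallel", "parallel", "reduction"],
getReductionMask returns {false, false, true}, and buildMultiDimReduce turns
that mask into a reduction over dimension 2:

    // Dim 2 carries the sole 'reduction' iterator; dims 0 and 1 stay parallel.
    %red = vector.multi_reduction #vector.kind<add>, %vec [2] : vector<4x16x8xf32> to vector<4x16xf32>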
@@ -255,24 +208,7 @@ static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value,
       reductionMask[idx] = true;
     ++idx;
   }
-  assert(reduceOp->getNumOperands() == 2 &&
-         "Only support binary reduce op right now");
-  unsigned outputPos =
-      outputOperand->getOperandNumber() - linalgOp.getNumInputs();
-  Value outputArg = linalgOp.getRegionOutputArgs()[outputPos];
-  // Reduce across the iteration space.
-  Value reduce =
-      buildMultiDimReduce(b, reduceOp, outputArg, reductionMask, bvm);
-  // Read the original output value.
-  Value initialValue = readInitialValue(b, linalgOp, outputOperand);
-  // Combine the output argument with the reduced value.
-  OperationState state(reduceOp->getLoc(), reduceOp->getName());
-  state.addAttributes(reduceOp->getAttrs());
-  state.addOperands({reduce, initialValue});
-  state.addTypes(initialValue.getType());
-  return b.createOperation(state)->getResult(0);
+  return reductionMask;
 }
 
 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
@@ -280,8 +216,7 @@ static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value,
 /// currently being vectorized. If `dest` has null rank, build an memref.store.
 /// Return the produced value or null if no value is produced.
 static Value buildVectorWrite(OpBuilder &b, Value value,
-                              OpOperand *outputOperand,
-                              const BlockAndValueMapping &bvm) {
+                              OpOperand *outputOperand) {
   Operation *write;
   Location loc = value.getLoc();
   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
@@ -296,12 +231,9 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                                b.create<arith::ConstantIndexOp>(loc, 0));
     value = broadcastIfNeeded(b, value, vectorType.getShape());
-    value = reduceIfNeeded(b, vectorType, value, outputOperand, bvm);
     write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
                                               indices, map);
   } else {
-    value = reduceIfNeeded(b, getElementTypeOrSelf(value), value, outputOperand,
-                           bvm);
     write = vector::TransferWriteOp::createScalarOp(
         b, loc, value, outputOperand->get(), ValueRange{});
   }
@@ -336,7 +268,7 @@ vectorizeLinalgYield(OpBuilder &b, Operation *op,
     // TODO: use a map.
     Value vectorValue = bvm.lookup(outputs.value());
     Value newResult = buildVectorWrite(
-        b, vectorValue, linalgOp.getOutputOperand(outputs.index()), bvm);
+        b, vectorValue, linalgOp.getOutputOperand(outputs.index()));
     if (newResult)
       newResults.push_back(newResult);
   }
@@ -379,6 +311,36 @@ static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
   return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
 }
 
+/// Create a new vectorized version of `op` with the given operands and types.
+static Operation *createVectorizedOp(OpBuilder &b, Operation *op,
+                                     ValueRange newOperands,
+                                     ArrayRef<Type> types) {
+  OperationState state(op->getLoc(), op->getName());
+  state.addAttributes(op->getAttrs());
+  state.addOperands(newOperands);
+  state.addTypes(types);
+  return b.createOperation(state);
+}
+
+/// Emit reduction operations if the shape of the value to reduce differs from
+/// the result shape.
+static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
+                                 Value reduceValue, Value initialValue,
+                                 const BlockAndValueMapping &bvm) {
+  Value reduceVec = bvm.lookup(reduceValue);
+  Value outputVec = bvm.lookup(initialValue);
+  auto reduceType = reduceVec.getType().dyn_cast<VectorType>();
+  auto outputType = outputVec.getType().dyn_cast<VectorType>();
+  // Reduce only if needed, as the value may already have been reduced for
+  // contraction vectorization.
+  if (!reduceType ||
+      (outputType && reduceType.getShape() == outputType.getShape()))
+    return nullptr;
+  SmallVector<bool> reductionMask = getReductionMask(linalgOp);
+  Value reduce = buildMultiDimReduce(b, op, reduceVec, reductionMask);
+  return createVectorizedOp(b, op, {reduce, outputVec}, reduce.getType());
+}
+
 /// Generic vectorization for a single operation `op`, given already vectorized
 /// operands carried by `bvm`. Vectorization occurs as follows:
 /// 1. Try to apply any of the `customVectorizationHooks` and return its
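Concretely, the new reduceIfNeeded recreates the matched combiner at the
reduced vector type, taking the multi_reduction result and the unbroadcast
initial value as operands. For an add reduction the emitted pair looks roughly
like this (a sketch; value names are hypothetical):

    %red = vector.multi_reduction #vector.kind<add>, %vec [2] : vector<4x16x8xf32> to vector<4x16xf32>
    // Rebuilt from the original scalar combiner via createVectorizedOp.
    %acc = arith.addf %red, %init : vector<4x16xf32>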
@@ -399,7 +361,8 @@ static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
 /// This function does not update `bvm` but returns a VectorizationStatus that
 /// instructs the caller what `bvm` update needs to occur.
 static VectorizationResult
-vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
+vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
+               const BlockAndValueMapping &bvm,
                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
   LDBG("vectorize op " << *op);
@@ -422,7 +385,30 @@ vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
   if (!OpTrait::hasElementwiseMappableTraits(op))
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
 
-  // 4. Generic vectorization path for ElementwiseMappable ops.
+  // 4. Check if the operation is a reduction.
+  SmallVector<std::pair<Value, Value>> reductionOperands;
+  for (Value operand : op->getOperands()) {
+    auto arg = operand.dyn_cast<BlockArgument>();
+    if (!arg || arg.getArgNumber() < linalgOp.getNumInputs())
+      continue;
+    SmallVector<Operation *> reductionOps;
+    Value reduceValue = matchReduction(
+        linalgOp.getRegionOutputArgs(),
+        arg.getArgNumber() - linalgOp.getNumInputs(), reductionOps);
+    if (!reduceValue)
+      continue;
+    reductionOperands.push_back(std::make_pair(reduceValue, operand));
+  }
+  if (!reductionOperands.empty()) {
+    assert(reductionOperands.size() == 1);
+    Operation *reduceOp =
+        reduceIfNeeded(b, linalgOp, op, reductionOperands[0].first,
+                       reductionOperands[0].second, bvm);
+    if (reduceOp)
+      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
+  }
+
+  // 5. Generic vectorization path for ElementwiseMappable ops.
   // a. first get the first max ranked shape.
   SmallVector<int64_t, 4> firstMaxRankedShape;
   for (Value operand : op->getOperands()) {
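The new step 4 fires when one of `op`'s operands is an output block argument
whose value chains back through a combiner, which matchReduction detects. A
region of the following shape is the intended match (a sketch; %out is the
output argument and the math.exp result is the reduceValue):

    ^bb0(%arg_in: f32, %out: f32):
      %0 = math.exp %arg_in : f32
      // `op` here is the combiner; its %out operand marks the reduction.
      %1 = arith.addf %0, %out : f32
      linalg.yield %1 : f32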
@@ -444,12 +430,10 @@ vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
   });
 
   // Build and return the new op.
-  OperationState state(op->getLoc(), op->getName());
-  state.addAttributes(op->getAttrs());
-  state.addOperands(llvm::to_vector<4>(vectorizedOperands));
-  state.addTypes(llvm::to_vector<4>(returnTypes));
-  return VectorizationResult{VectorizationStatus::NewOp,
-                             b.createOperation(state)};
+  return VectorizationResult{
+      VectorizationStatus::NewOp,
+      createVectorizedOp(b, op, llvm::to_vector<4>(vectorizedOperands),
+                         llvm::to_vector<4>(returnTypes))};
 }
 
 /// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
@@ -539,7 +523,8 @@ LogicalResult vectorizeAsLinalgGeneric(
     if (linalgOp.getShape(opOperand).empty()) {
       readType = bbarg.getType();
     } else {
-      if (broadcastToMaximalCommonShape) {
+      if (broadcastToMaximalCommonShape &&
+          opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
         map = inverseAndBroadcastProjectedPermuation(
             linalgOp.getTiedIndexingMap(opOperand));
         readType = VectorType::get(commonVectorShape,
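With this guard, only input operands are broadcast to the maximal common
vector shape; an output operand is now read at the shape implied by its own
indexing map, for instance (an illustrative sketch):

    // Initial-value read keeps its native 2-D type instead of a broadcast 3-D one.
    %init = vector.transfer_read %output[%c0, %c0], %f0 : tensor<4x16xf32>, vector<4x16xf32>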
@@ -576,7 +561,7 @@ LogicalResult vectorizeAsLinalgGeneric(
   // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
   for (Operation &op : block->getOperations()) {
-    VectorizationResult result = vectorizeOneOp(b, &op, bvm, hooks);
+    VectorizationResult result = vectorizeOneOp(b, linalgOp, &op, bvm, hooks);
     if (result.status == VectorizationStatus::Failure) {
       LDBG("failed to vectorize: " << op);
       return failure();


@@ -749,9 +749,9 @@ func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>)
   -> tensor<4x16xf32>
 {
   // CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32>
-  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
   // CHECK: math.exp {{.*}} : vector<4x16x8xf32>
   // CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
+  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
   // CHECK: addf {{.*}} : vector<4x16xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32>
   // CHECK: return {{.*}} : tensor<4x16xf32>
@@ -782,11 +782,11 @@ func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output: tensor<5x2xf32>)
 {
   // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M1]]} : tensor<3x2xf32>, vector<2x3x4x5xf32>
   // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M2]]} : tensor<5x4xf32>, vector<2x3x4x5xf32>
-  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : tensor<5x2xf32>, vector<2x5xf32>
   // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
   // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
   // CHECK: addf {{.*}} : vector<2x3x4x5xf32>
   // CHECK: vector.multi_reduction #vector.kind<add>, {{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
+  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : tensor<5x2xf32>, vector<2x5xf32>
   // CHECK: addf {{.*}} : vector<2x5xf32>
   // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : vector<2x5xf32>, tensor<5x2xf32>
   // CHECK: return {{.*}} : tensor<5x2xf32>