[mlir][vector] Refactor linalg vectorization for reductions

Emit reduction during op vectorization instead of doing it when creating the
transfer write. This allow us to not broadcast output arguments for reduction
initial value.

Differential Revision: https://reviews.llvm.org/D111825
This commit is contained in:
thomasraoux 2021-10-14 10:39:15 -07:00
parent 8b31f07cdf
commit afad0cdf31
2 changed files with 74 additions and 89 deletions

View File

@ -189,65 +189,18 @@ static Value buildVectorRead(OpBuilder &b, Value source, Type readType,
}
/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
/// assumes that `reductionOp` has tow operands and one of them is the reduction
/// assumes that `reductionOp` has two operands and one of them is the reduction
/// initial value.
static Value buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
Value outputArg,
const SmallVector<bool> &reductionMask,
const BlockAndValueMapping &bvm) {
Value valueToReduce,
const SmallVector<bool> &reductionMask) {
auto maybeKind = getKindForOp(reduceOp);
assert(maybeKind && "Failed precondition: could not get reduction kind");
Value operandToReduce = reduceOp->getOperand(0) == outputArg
? reduceOp->getOperand(1)
: reduceOp->getOperand(0);
Value vec = bvm.lookup(operandToReduce);
return b.create<vector::MultiDimReductionOp>(reduceOp->getLoc(), vec,
reductionMask, *maybeKind);
return b.create<vector::MultiDimReductionOp>(
reduceOp->getLoc(), valueToReduce, reductionMask, *maybeKind);
}
/// Read the initial value associated to the given `outputOperand`.
static Value readInitialValue(OpBuilder &b, LinalgOp linalgOp,
OpOperand *outputOperand) {
AffineMap map = inversePermutation(
reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand)));
Type readType;
if (linalgOp.getShape(outputOperand).empty()) {
readType = getElementTypeOrSelf(outputOperand->get());
} else {
readType = VectorType::get(map.compose(linalgOp.getShape(outputOperand)),
getElementTypeOrSelf(outputOperand->get()));
}
Value vectorRead = buildVectorRead(b, outputOperand->get(), readType, map);
return vectorRead;
}
/// Assuming `outputOperand` is an output operand of a LinalgOp, determine
/// whether a reduction is needed to produce a `targetType` and create that
/// reduction if it is the case.
static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value,
OpOperand *outputOperand,
const BlockAndValueMapping &bvm) {
LDBG("Reduce " << value << " to type " << targetType);
LDBG("In LinalgOp operand #" << outputOperand->getOperandNumber() << "\n"
<< *(outputOperand->getOwner()));
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
auto vecType = value.getType().dyn_cast<VectorType>();
VectorType targetVectorType = targetType.dyn_cast<VectorType>();
if (!vecType)
return value;
if (targetVectorType && vecType.getShape() == targetVectorType.getShape())
return value;
// At this point, we know we need to reduce. Detect the reduction operator.
unsigned pos = 0;
MLIRContext *ctx = b.getContext();
SmallVector<AffineExpr> exprs;
for (auto s : linalgOp.iterator_types())
if (isParallelIterator(s))
exprs.push_back(getAffineDimExpr(pos++, ctx));
Operation *reduceOp = matchLinalgReduction(outputOperand);
assert(reduceOp && "Failed precondition: could not math a reduction");
static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
unsigned idx = 0;
SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false);
for (auto attr : linalgOp.iterator_types()) {
@ -255,24 +208,7 @@ static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value,
reductionMask[idx] = true;
++idx;
}
assert(reduceOp->getNumOperands() == 2 &&
"Only support binary reduce op right now");
unsigned outputPos =
outputOperand->getOperandNumber() - linalgOp.getNumInputs();
Value outputArg = linalgOp.getRegionOutputArgs()[outputPos];
// Reduce across the iteration space.
Value reduce =
buildMultiDimReduce(b, reduceOp, outputArg, reductionMask, bvm);
// Read the original output value.
Value initialValue = readInitialValue(b, linalgOp, outputOperand);
// Combine the output argument with the reduced value.
OperationState state(reduceOp->getLoc(), reduceOp->getName());
state.addAttributes(reduceOp->getAttrs());
state.addOperands({reduce, initialValue});
state.addTypes(initialValue.getType());
return b.createOperation(state)->getResult(0);
return reductionMask;
}
/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
@ -280,8 +216,7 @@ static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value,
/// currently being vectorized. If `dest` has null rank, build an memref.store.
/// Return the produced value or null if no value is produced.
static Value buildVectorWrite(OpBuilder &b, Value value,
OpOperand *outputOperand,
const BlockAndValueMapping &bvm) {
OpOperand *outputOperand) {
Operation *write;
Location loc = value.getLoc();
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
@ -296,12 +231,9 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
SmallVector<Value> indices(linalgOp.getRank(outputOperand),
b.create<arith::ConstantIndexOp>(loc, 0));
value = broadcastIfNeeded(b, value, vectorType.getShape());
value = reduceIfNeeded(b, vectorType, value, outputOperand, bvm);
write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
indices, map);
} else {
value = reduceIfNeeded(b, getElementTypeOrSelf(value), value, outputOperand,
bvm);
write = vector::TransferWriteOp::createScalarOp(
b, loc, value, outputOperand->get(), ValueRange{});
}
@ -336,7 +268,7 @@ vectorizeLinalgYield(OpBuilder &b, Operation *op,
// TODO: use a map.
Value vectorValue = bvm.lookup(outputs.value());
Value newResult = buildVectorWrite(
b, vectorValue, linalgOp.getOutputOperand(outputs.index()), bvm);
b, vectorValue, linalgOp.getOutputOperand(outputs.index()));
if (newResult)
newResults.push_back(newResult);
}
@ -379,6 +311,36 @@ static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
}
/// Create a new vectorized verstion of `op` with the given operands and types.
static Operation *createVectorizedOp(OpBuilder &b, Operation *op,
ValueRange newOperands,
ArrayRef<Type> types) {
OperationState state(op->getLoc(), op->getName());
state.addAttributes(op->getAttrs());
state.addOperands(newOperands);
state.addTypes(types);
return b.createOperation(state);
}
/// Emit reduction operations if the shapes of the value to reduce is different
/// that the result shape.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
Value reduceValue, Value initialValue,
const BlockAndValueMapping &bvm) {
Value reduceVec = bvm.lookup(reduceValue);
Value outputVec = bvm.lookup(initialValue);
auto reduceType = reduceVec.getType().dyn_cast<VectorType>();
auto outputType = outputVec.getType().dyn_cast<VectorType>();
// Reduce only if needed as the value may already have been reduce for
// contraction vectorization.
if (!reduceType ||
(outputType && reduceType.getShape() == outputType.getShape()))
return nullptr;
SmallVector<bool> reductionMask = getReductionMask(linalgOp);
Value reduce = buildMultiDimReduce(b, op, reduceVec, reductionMask);
return createVectorizedOp(b, op, {reduce, outputVec}, reduce.getType());
}
/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`. Vectorization occurs as follows:
/// 1. Try to apply any of the `customVectorizationHooks` and return its
@ -399,7 +361,8 @@ static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
/// This function does not update `bvm` but returns a VectorizationStatus that
/// instructs the caller what `bvm` update needs to occur.
static VectorizationResult
vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
const BlockAndValueMapping &bvm,
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
LDBG("vectorize op " << *op);
@ -422,7 +385,30 @@ vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
if (!OpTrait::hasElementwiseMappableTraits(op))
return VectorizationResult{VectorizationStatus::Failure, nullptr};
// 4. Generic vectorization path for ElementwiseMappable ops.
// 4 . Check if the operation is a reduction.
SmallVector<std::pair<Value, Value>> reductionOperands;
for (Value operand : op->getOperands()) {
auto arg = operand.dyn_cast<BlockArgument>();
if (!arg || arg.getArgNumber() < linalgOp.getNumInputs())
continue;
SmallVector<Operation *> reductionOps;
Value reduceValue = matchReduction(
linalgOp.getRegionOutputArgs(),
arg.getArgNumber() - linalgOp.getNumInputs(), reductionOps);
if (!reduceValue)
continue;
reductionOperands.push_back(std::make_pair(reduceValue, operand));
}
if (!reductionOperands.empty()) {
assert(reductionOperands.size() == 1);
Operation *reduceOp =
reduceIfNeeded(b, linalgOp, op, reductionOperands[0].first,
reductionOperands[0].second, bvm);
if (reduceOp)
return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
}
// 5. Generic vectorization path for ElementwiseMappable ops.
// a. first get the first max ranked shape.
SmallVector<int64_t, 4> firstMaxRankedShape;
for (Value operand : op->getOperands()) {
@ -444,12 +430,10 @@ vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
});
// Build and return the new op.
OperationState state(op->getLoc(), op->getName());
state.addAttributes(op->getAttrs());
state.addOperands(llvm::to_vector<4>(vectorizedOperands));
state.addTypes(llvm::to_vector<4>(returnTypes));
return VectorizationResult{VectorizationStatus::NewOp,
b.createOperation(state)};
return VectorizationResult{
VectorizationStatus::NewOp,
createVectorizedOp(b, op, llvm::to_vector<4>(vectorizedOperands),
llvm::to_vector<4>(returnTypes))};
}
/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
@ -539,7 +523,8 @@ LogicalResult vectorizeAsLinalgGeneric(
if (linalgOp.getShape(opOperand).empty()) {
readType = bbarg.getType();
} else {
if (broadcastToMaximalCommonShape) {
if (broadcastToMaximalCommonShape &&
opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
map = inverseAndBroadcastProjectedPermuation(
linalgOp.getTiedIndexingMap(opOperand));
readType = VectorType::get(commonVectorShape,
@ -576,7 +561,7 @@ LogicalResult vectorizeAsLinalgGeneric(
// 5. Iteratively call `vectorizeOneOp` to each op in the slice.
for (Operation &op : block->getOperations()) {
VectorizationResult result = vectorizeOneOp(b, &op, bvm, hooks);
VectorizationResult result = vectorizeOneOp(b, linalgOp, &op, bvm, hooks);
if (result.status == VectorizationStatus::Failure) {
LDBG("failed to vectorize: " << op);
return failure();

View File

@ -749,9 +749,9 @@ func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>)
-> tensor<4x16xf32>
{
// CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32>
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
// CHECK: math.exp {{.*}} : vector<4x16x8xf32>
// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
// CHECK: addf {{.*}} : vector<4x16xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32>
// CHECK: return {{.*}} : tensor<4x16xf32>
@ -782,11 +782,11 @@ func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output: ten
{
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M1]]} : tensor<3x2xf32>, vector<2x3x4x5xf32>
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M2]]} : tensor<5x4xf32>, vector<2x3x4x5xf32>
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : tensor<5x2xf32>, vector<2x5xf32>
// CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
// CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
// CHECK: addf {{.*}} : vector<2x3x4x5xf32>
// CHECK: vector.multi_reduction #vector.kind<add>, {{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : tensor<5x2xf32>, vector<2x5xf32>
// CHECK: addf {{.*}} : vector<2x5xf32>
// CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : vector<2x5xf32>, tensor<5x2xf32>
// CHECK: return {{.*}} : tensor<5x2xf32>