[mlir][Linalg] Refactor Linalg vectorization for better reuse and extensibility.

This revision unifies Linalg vectorization and paves the way for vectorization of Linalg ops with mixed-precision operations.
The new algorithm traverses the ops in the linalg block in order and avoids recursion.
It uses a BlockAndValueMapping to keep track of vectorized operations.

The revision makes the following modifications but is otherwise NFC:
1. vector.transfer_read ops are created eagerly and may appear in a different order than in the original program.
2. vectorization to vector.contract is more progressive: only the multiply operation is converted to `vector.contract %a, %b, %zero`, where `%zero` is a
constant of the proper type. Later vector canonicalizations are assumed to rewrite `vector.contract %a, %b, %zero` followed by an add into the proper accumulate form.
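
For illustration (shapes borrowed from the updated matmul test below, attributes elided), the progressive lowering in 2. first produces:

  %zero = constant dense<0.000000e+00> : vector<8x12xf32>
  %tmp = vector.contract {...} %a, %b, %zero : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
  %res = addf %c, %tmp : vector<8x12xf32>

and a later canonicalization is expected to fold the addf into:

  %res = vector.contract {...} %a, %b, %c : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>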

Differential revision: https://reviews.llvm.org/D95797
Author: Nicolas Vasilache
Date:   2021-02-02 11:19:21 +00:00
Commit: 0a2a260aab (parent: 4d904776a7)
2 changed files with 321 additions and 191 deletions

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

@@ -23,6 +23,8 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
@@ -36,6 +38,275 @@ using llvm::dbgs;
#define DEBUG_TYPE "linalg-vectorization"
/// Helper data structure to represent the result of vectorization.
/// In certain specific cases, like terminators, we do not want to propagate.
enum VectorizationStatus {
/// Op failed to vectorize.
Failure = 0,
/// Op vectorized and the custom function took care of the replacement logic.
NoReplace,
/// Op vectorized into a new Op whose results will replace original Op's
/// results.
NewOp
// TODO: support the case where an op vectorizes to many ops whose results we
// need to aggregate for replacement.
};
struct VectorizationResult {
/// Return status from vectorizing the current op.
enum VectorizationStatus status = VectorizationStatus::Failure;
/// New vectorized operation to replace the current op.
/// Replacement behavior is specified by `status`.
Operation *newOp;
};
/// Return a vector type of the same shape and element type as the (assumed)
/// ShapedType of `v`.
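/// For example, memref<4x256xf32> yields vector<4x256xf32>; a rank-0 memref
/// yields a null VectorType, so callers fall back to std.load / std.store.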
static VectorType extractVectorTypeFromShapedValue(Value v) {
auto st = v.getType().cast<ShapedType>();
if (st.isa<MemRefType>() && st.getShape().empty())
return VectorType();
return VectorType::get(st.getShape(), st.getElementType());
}
/// Build a vector.transfer_read from `source` at indices set to all `0`.
/// If `source` has rank zero, build an std.load.
/// Return the produced value.
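/// For example (illustrative shape, %pad is the padding value), reading from a
/// memref<4x256xf32> source produces:
///   %c0 = constant 0 : index
///   %v = vector.transfer_read %source[%c0, %c0], %pad
///       : memref<4x256xf32>, vector<4x256xf32>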
static Value buildVectorRead(OpBuilder &builder, Value source) {
edsc::ScopedContext scope(builder);
auto shapedType = source.getType().cast<ShapedType>();
if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
return vector_transfer_read(vectorType, source, indices);
}
return std_load(source);
}
/// Build a vector.transfer_write of `value` into `dest` at indices set to all
/// `0`. If `dest` has rank zero, build an std.store.
/// Return the produced value or null if no value is produced.
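/// For example (illustrative types), writing a scalar %f : f32 into a
/// memref<4x256xf32> dest broadcasts first:
///   %v = vector.broadcast %f : f32 to vector<4x256xf32>
///   vector.transfer_write %v, %dest[%c0, %c0]
///       : vector<4x256xf32>, memref<4x256xf32>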
static Value buildVectorWrite(OpBuilder &builder, Value value, Value dest) {
edsc::ScopedContext scope(builder);
Operation *write;
auto shapedType = dest.getType().cast<ShapedType>();
if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
if (vectorType != value.getType())
value = vector_broadcast(vectorType, value);
write = vector_transfer_write(value, dest, indices);
} else {
write = std_store(value, dest);
}
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
if (!write->getResults().empty())
return write->getResult(0);
return Value();
}
/// If `value` of assumed VectorType has a shape different from `shape`, build
/// and return a new vector.broadcast to `shape`.
/// Otherwise, return `value` unchanged.
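/// For example, broadcasting a vector<256xf32> value to shape [4, 256] emits:
///   %b = vector.broadcast %v : vector<256xf32> to vector<4x256xf32>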
static Value broadcastIfNeeded(OpBuilder &builder, Value value,
ArrayRef<int64_t> shape) {
auto vecType = value.getType().dyn_cast<VectorType>();
if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
return value;
auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
: value.getType());
return builder.create<vector::BroadcastOp>(
builder.getInsertionPoint()->getLoc(), newVecType, value);
}
// Custom vectorization function type. Produce a vector form of Operation*
// assuming all its vectorized operands are already in the BlockAndValueMapping.
// Return nullptr if the Operation cannot be vectorized.
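// A hook signals VectorizationStatus::Failure to let other hooks, and then the
// generic path, try the op; see the `vectorizeContraction` hook further below,
// which matches the MulIOp/MulFOp of a contraction and emits a vector.contract.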
using CustomVectorizationHook = std::function<VectorizationResult(
Operation *, const BlockAndValueMapping &)>;
/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `results`.
/// Return VectorizationStatus::NoReplace to signal to the vectorization
/// algorithm that it should not try to map the produced operations: the
/// `results` argument instead captures the terminator's values and makes them
/// available for the final RAUW.
/// This function is meant to be used as a CustomVectorizationHook.
static VectorizationResult
vectorizeLinalgYield(OpBuilder &builder, Operation *op,
const BlockAndValueMapping &bvm, LinalgOp linalgOp,
SmallVectorImpl<Value> &results) {
auto yieldOp = dyn_cast<linalg::YieldOp>(op);
if (!yieldOp)
return VectorizationResult{VectorizationStatus::Failure, nullptr};
for (auto outputs : llvm::enumerate(yieldOp.values())) {
// TODO: Scan for an opportunity for reuse.
// TODO: use a map.
Value vectorValue = bvm.lookup(outputs.value());
Value result = buildVectorWrite(builder, vectorValue,
linalgOp.getOutput(outputs.index()));
if (result)
results.push_back(result);
}
return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}
/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`. Vectorization occurs as follows:
/// 1. Try to apply any of the `customVectorizationHooks` and return its
/// result on success.
/// 2. Clone any constant in the current scope without vectorization: each
/// consumer of the constant will later determine the shape to which the
/// constant needs to be broadcast.
/// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
/// of the `customVectorizationHooks` to cover such cases.
/// 4. Clone `op` in vector form to a vector of shape prescribed by the first
/// operand of maximal rank. Other operands have smaller rank and are
/// broadcast accordingly. It is assumed this broadcast is always legal,
/// otherwise, it means one of the `customVectorizationHooks` is incorrect.
///
/// This function assumes all operands of `op` have been vectorized and are in
/// the `bvm` mapping. As a consequence, this function is meant to be called on
/// a topologically-sorted list of ops.
/// This function does not update `bvm` but returns a VectorizationStatus that
/// instructs the caller what `bvm` update needs to occur.
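/// For example (illustrative), vectorizing `%r = addf %a, %b : f32` where `bvm`
/// maps %a to a vector<4x256xf32> and %b to a vector<256xf32> broadcasts %b to
/// vector<4x256xf32> and clones the addf with result type vector<4x256xf32>.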
static VectorizationResult
vectorizeOneOp(OpBuilder &builder, Operation *op,
const BlockAndValueMapping &bvm,
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op);
// 1. Try to apply any CustomVectorizationHook.
if (!customVectorizationHooks.empty()) {
for (auto &customFunc : customVectorizationHooks) {
VectorizationResult result = customFunc(op, bvm);
if (result.status == VectorizationStatus::Failure)
continue;
return result;
}
}
// 2. Constant ops don't get vectorized but are rather broadcast at their
// users. Clone so that the constant is not confined to the linalgOp block.
if (isa<ConstantOp>(op))
return VectorizationResult{VectorizationStatus::NewOp, builder.clone(*op)};
// 3. Only ElementwiseMappable are allowed in the generic vectorization.
if (!op->hasTrait<OpTrait::ElementwiseMappable>())
return VectorizationResult{VectorizationStatus::Failure, nullptr};
// 4. Generic vectorization path for ElementwiseMappable ops.
// a. Get the first max-ranked shape among the vectorized operands.
SmallVector<int64_t, 4> firstMaxRankedShape;
for (Value operand : op->getOperands()) {
auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
if (vt && firstMaxRankedShape.size() < vt.getShape().size())
firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
}
// b. Broadcast each operand to that shape, if needed.
auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
return firstMaxRankedShape.empty()
? bvm.lookup(v)
: broadcastIfNeeded(builder, bvm.lookup(v), firstMaxRankedShape);
});
// c. For elementwise ops, each result is a vector with firstMaxRankedShape.
auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
return firstMaxRankedShape.empty()
? t
: VectorType::get(firstMaxRankedShape, t);
});
// Build and return the new op.
OperationState state(op->getLoc(), op->getName());
state.addAttributes(op->getAttrs());
state.addOperands(llvm::to_vector<4>(vectorizedOperands));
state.addTypes(llvm::to_vector<4>(returnTypes));
return VectorizationResult{VectorizationStatus::NewOp,
builder.createOperation(state)};
}
/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form. Generic vectorization proceeds as follows:
/// 1. The region for the linalg op is created if necessary.
/// 2. Values defined above the region are mapped to themselves and will be
/// broadcast on a per-need basis by their consumers.
/// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
/// load).
/// TODO: Reuse opportunities for RAR dependencies.
/// 4. Register CustomVectorizationHook for YieldOp to capture the results.
/// 5. Iteratively call vectorizeOneOp on the region operations.
/// 6. RAUW the linalg op with the results captured by vectorizing the YieldOp.
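/// For example (illustrative, padding and indexing maps elided), an elementwise
/// linalg.generic whose body yields `addf %x, %y : f32` on 4x256 operands
/// vectorizes to:
///   %v0 = vector.transfer_read %in0[%c0, %c0] : memref<4x256xf32>, vector<4x256xf32>
///   %v1 = vector.transfer_read %in1[%c0, %c0] : memref<4x256xf32>, vector<4x256xf32>
///   %a = addf %v0, %v1 : vector<4x256xf32>
///   vector.transfer_write %a, %out[%c0, %c0] : vector<4x256xf32>, memref<4x256xf32>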
static LogicalResult vectorizeAsLinalgGeneric(
OpBuilder &builder, LinalgOp linalgOp,
ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
// 1. Certain Linalg ops do not have a region but only a region builder.
// If so, build the region so we can vectorize.
std::unique_ptr<Region> owningRegion;
Region *region;
if (linalgOp->getNumRegions() > 0) {
region = &linalgOp->getRegion(0);
} else {
// RAII guard to avoid the builder remaining in the newly created block.
OpBuilder::InsertionGuard g(builder);
owningRegion = std::make_unique<Region>();
region = owningRegion.get();
Block *block = builder.createBlock(region);
auto elementTypes = llvm::to_vector<4>(
llvm::map_range(linalgOp.getShapedOperandTypes(),
[](ShapedType t) { return t.getElementType(); }));
block->addArguments(elementTypes);
linalgOp.getRegionBuilder()(*block);
}
Block *block = &region->front();
BlockAndValueMapping bvm;
// 2. Values defined above the region can only be broadcast for now. Make them
// map to themselves.
llvm::SetVector<Value> valuesSet;
mlir::getUsedValuesDefinedAbove(*region, valuesSet);
bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
// 3. Turn all BBArgs into vector.transfer_read / load.
SmallVector<AffineMap> indexings;
for (auto bbarg : block->getArguments()) {
Value vectorArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
Value vectorRead = buildVectorRead(builder, vectorArg);
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
<< bbarg.getArgNumber() << "): " << vectorRead);
bvm.map(bbarg, vectorRead);
bvm.map(vectorArg, vectorRead);
}
// 4. Register CustomVectorizationHook for yieldOp.
SmallVector<Value> results;
CustomVectorizationHook vectorizeYield =
[&](Operation *op,
const BlockAndValueMapping &bvm) -> VectorizationResult {
return vectorizeLinalgYield(builder, op, bvm, linalgOp, results);
};
// Append the vectorizeYield hook.
auto hooks = llvm::to_vector<4>(customVectorizationHooks);
hooks.push_back(vectorizeYield);
// 5. Iteratively call `vectorizeOneOp` on each op in the block.
for (Operation &op : block->getOperations()) {
VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
if (result.status == VectorizationStatus::Failure) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
return failure();
}
if (result.status == VectorizationStatus::NewOp) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
<< *result.newOp);
bvm.map(op.getResults(), result.newOp->getResults());
}
}
// 6. RAUW the linalg op with the results captured by vectorizing the YieldOp.
if (!results.empty())
linalgOp->replaceAllUsesWith(results);
return success();
}
/// Detect whether `r` exactly computes a floating-point or integer
/// multiply-accumulate.
static bool hasMultiplyAddBody(Region &r) {
if (!llvm::hasSingleElement(r))
return false;
@@ -65,6 +336,7 @@ static bool hasMultiplyAddBody(Region &r) {
pattern7.match(&r.front().back()) || pattern8.match(&r.front().back());
}
/// Detect whether the LinalgOp `op` is a contraction.
// TODO: Should be Tablegen'd from a single source that generates the op itself.
static LogicalResult isContraction(Operation *op) {
// TODO: interface for named ops.
@@ -84,6 +356,7 @@ static LogicalResult isContraction(Operation *op) {
hasMultiplyAddBody(genericOp.region()));
}
/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
static bool hasOnlyScalarElementwiseOp(Region &r) {
if (!llvm::hasSingleElement(r))
return false;
@@ -119,171 +392,6 @@ static bool isElementwise(Operation *op) {
return hasOnlyScalarElementwiseOp(genericOp.getRegion());
}
static VectorType extractVectorTypeFromShapedValue(Value v) {
auto st = v.getType().cast<ShapedType>();
if (st.isa<MemRefType>() && st.getShape().empty())
return VectorType();
return VectorType::get(st.getShape(), st.getElementType());
}
static Value transferReadVector(OpBuilder &builder, Value source) {
edsc::ScopedContext scope(builder);
auto shapedType = source.getType().cast<ShapedType>();
if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
return vector_transfer_read(vectorType, source, indices);
}
return std_load(source);
}
static Value transferWriteVector(OpBuilder &builder, Value value, Value dest) {
edsc::ScopedContext scope(builder);
Operation *write;
auto shapedType = dest.getType().cast<ShapedType>();
if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
if (vectorType != value.getType())
value = vector_broadcast(vectorType, value);
write = vector_transfer_write(value, dest, indices);
} else {
write = std_store(value, dest);
}
if (!write->getResults().empty())
return write->getResult(0);
return Value();
}
namespace {
// Transforms scalar operations into their vectorized counterparts,
// while using the provided generic op to map:
// * Its arguments to transfer reads from the views of the generic op.
// * linalg.yield ops to transfer writes to the views of the generic op.
class GenericVectorizer {
public:
GenericVectorizer(OpBuilder &builder, linalg::GenericOp generic)
: builder(builder), generic(generic) {}
// Takes a scalar operation and builds its vectorized counterpart or
// counterparts using the underlying builder.
// If operands of the scalar operation are referring to previously vectorized
// operations, then in their vectorized form these operands will be referring
// to previous vectorization results.
void vectorize(Operation &scalarOp) {
auto yieldOp = dyn_cast<linalg::YieldOp>(scalarOp);
if (yieldOp) {
for (auto outputs : llvm::enumerate(yieldOp.values())) {
Value vectorValue = vectorize(outputs.value());
Value result = transferWriteVector(builder, vectorValue,
generic.getOutput(outputs.index()));
if (result)
results.push_back(result);
}
return;
}
Operation *vectorOp = uncachedVectorize(scalarOp);
assert(scalarOp.getNumResults() == vectorOp->getNumResults());
for (auto result :
llvm::zip(scalarOp.getResults(), vectorOp->getResults())) {
valueCache[std::get<0>(result)] = std::get<1>(result);
}
}
llvm::ArrayRef<Value> getResults() { return results; }
private:
// Transforms a scalar value into its vectorized counterpart, recursively
// vectorizing operations as necessary using the underlying builder.
// Keeps track of previously vectorized values and reuses vectorization
// results if these values come up again.
Value vectorize(Value scalarValue) {
// Don't vectorize values coming from outside the region.
if (scalarValue.getParentRegion() != &generic.region())
return scalarValue;
auto vectorValueIt = valueCache.find(scalarValue);
if (vectorValueIt != valueCache.end())
return vectorValueIt->second;
// If the value is from the region but not in the cache it means it is a
// block argument.
auto scalarArg = scalarValue.cast<BlockArgument>();
assert(scalarArg.getOwner() == &generic.region().front());
Value vectorArg = generic.getShapedOperand(scalarArg.getArgNumber());
Value vectorResult = transferReadVector(builder, vectorArg);
valueCache[scalarArg] = vectorResult;
return vectorResult;
}
// Return the largest shape of all the given values. Return an empty
// SmallVector if there are no vector value.
static SmallVector<int64_t, 4> getLargestShape(ArrayRef<Value> values) {
SmallVector<int64_t, 4> largestShape;
int64_t maxSize = 1;
for (Value value : values) {
auto vecType = value.getType().dyn_cast<VectorType>();
if (!vecType)
continue;
if (maxSize < vecType.getNumElements()) {
maxSize = vecType.getNumElements();
largestShape.assign(vecType.getShape().begin(),
vecType.getShape().end());
}
}
return largestShape;
}
// If the value's type doesn't have the given shape broadcast it.
Value broadcastIfNeeded(Value value, ArrayRef<int64_t> shape) {
auto vecType = value.getType().dyn_cast<VectorType>();
if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
return value;
auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
: value.getType());
return builder.create<vector::BroadcastOp>(
builder.getInsertionPoint()->getLoc(), newVecType, value);
}
// Takes a scalar operation and builds its vectorized counterpart or
// counterparts using underlying builder without involving any caches.
Operation *uncachedVectorize(Operation &base_scalarOp) {
SmallVector<Value, 4> vectorizedOperands;
for (Value operand : base_scalarOp.getOperands()) {
vectorizedOperands.push_back(vectorize(operand));
}
SmallVector<int64_t, 4> shape = getLargestShape(vectorizedOperands);
for (Value &operand : vectorizedOperands)
operand = broadcastIfNeeded(operand, shape);
OperationState state(base_scalarOp.getLoc(), base_scalarOp.getName());
state.addAttributes(base_scalarOp.getAttrs());
state.addOperands(vectorizedOperands);
if (shape.empty()) {
state.addTypes(base_scalarOp.getResultTypes());
} else {
SmallVector<VectorType, 4> vectorizedTypes;
for (auto Type : base_scalarOp.getResultTypes())
vectorizedTypes.push_back(VectorType::get(shape, Type));
state.addTypes(vectorizedTypes);
}
return builder.createOperation(state);
}
OpBuilder &builder;
linalg::GenericOp generic;
llvm::DenseMap<Value, Value> valueCache;
SmallVector<Value, 8> results;
};
} // namespace
// Replaces elementwise linalg.generic ops with their bodies with scalar
// operations from these bodies promoted to vector operations.
static void vectorizeElementwise(linalg::GenericOp op, OpBuilder &builder) {
GenericVectorizer vectorizer(builder, op);
for (Operation &scalarOp : op.region().front()) {
vectorizer.vectorize(scalarOp);
}
if (!op->getResults().empty())
op->replaceAllUsesWith(vectorizer.getResults());
}
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
auto linalgOp = cast<linalg::LinalgOp>(op);
// All types must be static shape to go to vector.
@@ -313,7 +421,7 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
// Vectorize fill as a vector.broadcast.
LLVM_DEBUG(dbgs() << dbgPref
<< "Rewrite linalg.fill as vector.broadcast: " << *op);
transferWriteVector(builder, fillOp.value(), fillOp.output());
buildVectorWrite(builder, fillOp.value(), fillOp.output());
return;
}
if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
@@ -322,17 +430,21 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
<< "Rewrite linalg.copy as vector.transfer_read + "
"vector.transfer_write: "
<< *op);
Value vector = transferReadVector(builder, copyOp.input());
transferWriteVector(builder, vector, copyOp.output());
Value vector = buildVectorRead(builder, copyOp.input());
buildVectorWrite(builder, vector, copyOp.output());
return;
}
auto linalgOp = cast<linalg::LinalgOp>(op);
Location loc = linalgOp.getLoc();
if (isElementwise(op)) {
LLVM_DEBUG(dbgs() << dbgPref
<< "Rewrite linalg op as vector.transfer_read + "
"vector_op + vector.transfer_write: "
<< *op);
return vectorizeElementwise(cast<linalg::GenericOp>(op), builder);
<< "Rewrite linalg op as vector.transfer_read + " << *op);
auto status = vectorizeAsLinalgGeneric(builder, linalgOp);
assert(succeeded(status) &&
"Unexpected vectorization failed despite preconditions");
return;
}
assert(succeeded(isContraction(op)) && "Expected contraction");
@@ -341,15 +453,28 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
// TODO: interface.
LLVM_DEBUG(dbgs() << dbgPref
<< "Rewrite linalg op as vector.contract: " << *op);
auto linalgOp = cast<linalg::LinalgOp>(op);
Value a = transferReadVector(builder, linalgOp.getInput(0));
Value b = transferReadVector(builder, linalgOp.getInput(1));
Value c = transferReadVector(builder, linalgOp.getOutput(0));
Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
linalgOp.iterator_types());
Value writeResult = transferWriteVector(builder, res, linalgOp.getOutput(0));
if (writeResult)
linalgOp->replaceAllUsesWith(ArrayRef<Value>(writeResult));
// Special function that describes how to vectorize the multiplication op in a
// linalg contraction.
CustomVectorizationHook vectorizeContraction =
[&](Operation *op,
const BlockAndValueMapping &bvm) -> VectorizationResult {
if (!isa<MulIOp, MulFOp>(op))
return VectorizationResult{VectorizationStatus::Failure, nullptr};
auto outShape = linalgOp.getOutputShapedType(0).getShape();
auto vType = outShape.empty()
? op->getResult(0).getType()
: VectorType::get(outShape, op->getResult(0).getType());
auto zero =
builder.create<ConstantOp>(loc, vType, builder.getZeroAttr(vType));
Operation *contract = builder.create<vector::ContractionOp>(
loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
linalgOp.indexing_maps(), linalgOp.iterator_types());
return VectorizationResult{VectorizationStatus::NewOp, contract};
};
auto status =
vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
assert(succeeded(status) &&
"Unexpected vectorization failed despite preconditions");
}
/// Check whether there is any interleaved use of any `values` between `firstOp`

mlir/test/Dialect/Linalg/vectorization.mlir

@@ -183,13 +183,13 @@ func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
// CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
// CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<256xf32>
// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
// CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
// CHECK: %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32>
// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
// CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
// CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
// CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
@@ -267,13 +267,13 @@ func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
// CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
// CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32>
// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
// CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
// CHECK: %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32>
// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
// CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
// CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
// CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
@@ -307,10 +307,15 @@ func @matmul_tensors(
// CHECK-LABEL: func @matmul_tensors
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
// CHECK-SAME: %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32>
// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[V2]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
// CHECK: %[[W:.*]] = vector.transfer_write %[[C]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0.000000e+00> : vector<8x12xf32>
// CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
// CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32>
// CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
//
// Linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
// A later canonicalization fuses the addf into vector.contract.
// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[VEC_C0]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
// CHECK: %[[C2:.*]] = addf %[[V2]], %[[C]] : vector<8x12xf32>
// CHECK: %[[W:.*]] = vector.transfer_write %[[C2]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32>
// CHECK: return %[[W]] : tensor<8x12xf32>