[mlir] [VectorOps] Framework for progressive lowering of vector.contract

Summary:
Lowers all free/batch dimensions in a vector.contract progressively
into simpler vector.contract operations until a direct vector.reduction
operation is reached. Then lowers 1-D reductions into vector.reduce.

Still TBD:
multi-dimensional contractions that remain after removing all the parallel dims

Reviewers: nicolasvasilache, andydavis1, rriddle

Reviewed By: andydavis1

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74797
This commit is contained in:
aartbik 2020-02-19 11:26:42 -08:00
parent 129c911efa
commit 0ba9ee9f0e
2 changed files with 372 additions and 23 deletions

View File

@ -14,6 +14,7 @@
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/Dialect/VectorOps/VectorTransforms.h"
#include "mlir/Dialect/VectorOps/VectorUtils.h"
@ -864,6 +865,19 @@ public:
};
/// Progressive lowering of ContractionOp.
/// One:
/// %x = vector.contract with at least one free/batch dimension
/// is replaced by:
/// %a = vector.contract with one less free/batch dimension
/// %b = vector.contract with one less free/batch dimension
/// ..
/// %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a fma/reduction op.
///
/// TODO(ajcbik): break down into transpose/reshape/cast ops
/// when they become available to avoid code dup
/// TODO(ajcbik): investigate lowering order impact on performance
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@ -874,16 +888,13 @@ public:
if (llvm::size(op.masks()) != 0)
return matchFailure();
auto loc = op.getLoc();
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
Type resType = op.getResultType();
// Find first batch dimension in lhs/rhs, and lower when found.
// Find first batch dimension in LHS/RHS, and lower when found.
std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
if (!batchDimMap.empty()) {
// TODO(ajcbik): implement batch
return matchFailure();
int64_t lhsIndex = batchDimMap[0].first;
int64_t rhsIndex = batchDimMap[0].second;
rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
return matchSuccess();
}
// Collect contracting dimensions.
@ -896,24 +907,35 @@ public:
rhsContractingDimSet.insert(dimPair.second);
}
// Find free dimension in lhs/rhs, and lower first when found.
for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
if (lhsContractingDimSet.count(i) == 0) {
// TODO(ajcbik): implement free
return matchFailure();
}
}
for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
if (rhsContractingDimSet.count(i) == 0) {
// TODO(ajcbik): implement free
return matchFailure();
// Find first free dimension in LHS, and lower when found.
VectorType lhsType = op.getLhsType();
for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e;
++lhsIndex) {
if (lhsContractingDimSet.count(lhsIndex) == 0) {
rewriter.replaceOp(
op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
return matchSuccess();
}
}
// Only contraction dimensions remain.
// Find first free dimension in RHS, and lower when found.
VectorType rhsType = op.getRhsType();
for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e;
++rhsIndex) {
if (rhsContractingDimSet.count(rhsIndex) == 0) {
rewriter.replaceOp(
op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
return matchSuccess();
}
}
// Lower the only remaining contraction dimensions.
// TODO(ajcbik): handle multi-dim reductions
auto loc = op.getLoc();
Type resType = op.getResultType();
if (!resType.isa<VectorType>() && lhsType.getRank() == 1 &&
rhsType.getRank() == 1) {
// Handle reduction into scalar.
Value zero = rewriter.create<ConstantOp>(loc, resType,
rewriter.getZeroAttr(resType));
Value splat = rewriter.create<SplatOp>(loc, lhsType, zero);
@ -924,9 +946,191 @@ public:
op.acc());
return matchSuccess();
}
// TODO(ajcbik): implement more contraction
return matchFailure();
}
private:
// Lower one parallel dimension.
// Unrolls the parallel (free or batch) dimension identified by
// lhsIndex/rhsIndex into a series of lower-dimensional vector.contract
// ops, one per element of that dimension. A negative lhsIndex or
// rhsIndex means the dimension does not occur in that operand (a free
// dimension of the other operand).
// Returns the combined result vector.
// TODO(ajcbik): consider reusing existing contract unrolling
Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
int64_t rhsIndex, PatternRewriter &rewriter) const {
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
VectorType resType = op.getResultType().cast<VectorType>();
// Find the iterator type index and result index.
SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
int64_t iterIndex = -1;
int64_t dimSize = -1;
if (lhsIndex >= 0) {
// The iteration-space index comes from the LHS map; when the dimension
// is a batch dimension it must map to the same iterator in the RHS.
iterIndex =
iMap[0].getResult(lhsIndex).cast<AffineDimExpr>().getPosition();
assert((rhsIndex < 0 || iterIndex == iMap[1]
.getResult(rhsIndex)
.cast<AffineDimExpr>()
.getPosition()) &&
"parallel index should be free in LHS or batch in LHS/RHS");
dimSize = lhsType.getDimSize(lhsIndex);
} else {
// Free dimension of the RHS only.
assert(rhsIndex >= 0 && "missing parallel index");
iterIndex =
iMap[1].getResult(rhsIndex).cast<AffineDimExpr>().getPosition();
dimSize = rhsType.getDimSize(rhsIndex);
}
assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
// A parallel dimension must also appear in the result mapping.
Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
assert(lookup.hasValue() && "parallel index not listed in reduction");
int64_t resIndex = lookup.getValue();
// Construct new iterator types: drop the unrolled (parallel) iterator.
ArrayAttr iteratorTypes = op.iterator_types();
SmallVector<Attribute, 4> lowIterTypes;
for (auto it : llvm::enumerate(iteratorTypes)) {
int64_t idx = it.index();
if (idx == iterIndex) {
assert(it.value().cast<StringAttr>().getValue() ==
getParallelIteratorTypeName() &&
"parallel index not marked as such");
continue;
}
lowIterTypes.push_back(it.value());
}
// Construct new affine map array attribute with the unrolled iterator
// removed from all three indexing maps.
SmallVector<AffineMap, 4> lowIndexingMaps;
lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
// Construct new iterator types array attribute.
auto lowIter = rewriter.getArrayAttr(lowIterTypes);
// Unroll into a series of lower dimensional vector.contract ops.
Location loc = op.getLoc();
Value result = zeroVector(loc, resType, rewriter);
for (int64_t d = 0; d < dimSize; ++d) {
// Extract the d-th slice of each operand (no-op when the index is -1),
// contract the slices, and insert the result back at position d.
auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
Value lowContract = rewriter.create<vector::ContractionOp>(
loc, lhs, rhs, acc, lowAffine, lowIter);
result = reshapeStore(loc, lowContract, result, resType, resIndex, d,
rewriter);
}
return result;
}
// Helper method to construct a zero vector.
// Materializes a scalar zero of the element type and splats it across
// the requested vector type.
static Value zeroVector(Location loc, VectorType vType,
PatternRewriter &rewriter) {
Type scalarType = vType.getElementType();
auto zeroAttr = rewriter.getZeroAttr(scalarType);
Value scalarZero = rewriter.create<ConstantOp>(loc, scalarType, zeroAttr);
return rewriter.create<SplatOp>(loc, vType, scalarZero);
}
// Helper to find an index in an affine map.
// Returns the position of the map result that is the dimension
// expression `index`, or None when no result matches.
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
int64_t numResults = map.getNumResults();
for (int64_t pos = 0; pos < numResults; ++pos) {
auto dimExpr = map.getResult(pos).cast<AffineDimExpr>();
if (dimExpr.getPosition() == index)
return pos;
}
return None;
}
// Helper to construct an affine map with one index removed.
// Drops the dimension expression `index` from the map results and
// renumbers every higher dimension down by one.
static AffineMap adjustMap(AffineMap map, int64_t index,
PatternRewriter &rewriter) {
SmallVector<AffineExpr, 4> exprs;
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
int64_t pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
if (pos == index)
continue;
// Dimensions past the removed index shift down by one.
int64_t newPos = pos < index ? pos : pos - 1;
exprs.push_back(getAffineDimExpr(newPos, rewriter.getContext()));
}
// Since (...) -> () cannot be represented properly,
// we resort to an empty map when this situation happens.
if (exprs.empty())
return AffineMap::get(rewriter.getContext());
return AffineMap::get(map.getNumDims() - 1, 0, exprs);
}
// Helper to drop dimension from vector type.
// Removes the dimension at `index` from the shape; dropping the only
// dimension of a 1-D vector yields the scalar element type.
static Type adjustType(VectorType tp, int64_t index) {
Type eltType = tp.getElementType();
int64_t rank = tp.getRank();
if (rank == 1) {
assert(index == 0 && "index for scalar result out of bounds");
return eltType;
}
// Copy every dimension except the one being removed.
SmallVector<int64_t, 4> newShape;
for (int64_t d = 0; d < rank; ++d)
if (d != index)
newShape.push_back(tp.getDimSize(d));
return VectorType::get(newShape, eltType);
}
// Helper method to possibly drop a dimension in a load.
// Extracts position `pos` along dimension `index` of `val` (which has
// vector type `type`), recursing through the leading dimensions so the
// result shape is `type` with dimension `index` removed. An `index` of
// -1 returns `val` unmodified.
// TODO(ajcbik): use a reshaping vector load (and share lowering code)
static Value reshapeLoad(Location loc, Value val, VectorType type,
int64_t index, int64_t pos,
PatternRewriter &rewriter) {
// Unmodified?
if (index == -1)
return val;
Type lowType = adjustType(type, 0);
// At extraction dimension?
if (index == 0) {
auto posAttr = rewriter.getI64ArrayAttr(pos);
return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
}
// Unroll leading dimensions.
VectorType vType = lowType.cast<VectorType>();
VectorType resType = adjustType(type, index).cast<VectorType>();
Value result = zeroVector(loc, resType, rewriter);
for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
auto posAttr = rewriter.getI64ArrayAttr(d);
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
// Recurse one dimension deeper, then insert the slice into the result.
Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
result = rewriter.create<vector::InsertOp>(loc, resType, load, result,
posAttr);
}
return result;
}
// Helper method to possibly drop a dimension in a store.
// Inserts `val` into `result` (of vector type `type`) at position `pos`
// along dimension `index`, recursing through the leading dimensions.
// An `index` of -1 returns `val` unmodified (the dual of reshapeLoad).
// TODO(ajcbik): use a reshaping vector store (and share lowering code)
static Value reshapeStore(Location loc, Value val, Value result,
VectorType type, int64_t index, int64_t pos,
PatternRewriter &rewriter) {
// Unmodified?
if (index == -1)
return val;
// At insertion dimension?
if (index == 0) {
auto posAttr = rewriter.getI64ArrayAttr(pos);
return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
}
// Unroll leading dimensions.
Type lowType = adjustType(type, 0);
VectorType vType = lowType.cast<VectorType>();
Type insType = adjustType(vType, 0);
for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
auto posAttr = rewriter.getI64ArrayAttr(d);
// Pair the d-th slice of the destination with the d-th slice of the
// value being stored, recurse, and write the combined slice back.
Value ext =
rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
Value ins =
rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
result =
rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
}
return result;
}
};
} // namespace

View File

@ -14,7 +14,7 @@
// CHECK-SAME: %[[A:.*0]]: vector<4xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
// CHECK-SAME: %[[C:.*2]]: f32
// CHECK: %[[Z:.*]] = constant dense<0.000000e+00>
// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<4xf32>
// CHECK: %[[F:.*]] = vector.fma %[[A]], %[[B]], %[[Z]] : vector<4xf32>
// CHECK: %[[R:.*]] = vector.reductionv2 "add", %[[F]], %[[C]]
// CHECK: return %[[R]] : f32
@ -24,3 +24,148 @@ func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32)
: vector<4xf32>, vector<4xf32> into f32
return %0 : f32
}
#matvec_accesses = [
affine_map<(i, j) -> (i, j)>,
affine_map<(i, j) -> (j)>,
affine_map<(i, j) -> (i)>
]
#matvec_trait = {
indexing_maps = #matvec_accesses,
iterator_types = ["parallel", "reduction"]
}
// CHECK-LABEL: func @extract_contract2
// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32>
// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
// CHECK: %[[T2:.*]] = vector.fma %[[T0]], %[[B]], %[[Z]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.reductionv2 "add", %[[T2]], %[[T1]] : vector<3xf32>, f32 into f32
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
// CHECK: %[[T7:.*]] = vector.fma %[[T5]], %[[B]], %[[Z]] : vector<3xf32>
// CHECK: %[[T8:.*]] = vector.reductionv2 "add", %[[T7]], %[[T6]] : vector<3xf32>, f32 into f32
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
// CHECK: return %[[T9]] : vector<2xf32>
func @extract_contract2(%arg0: vector<2x3xf32>,
%arg1: vector<3xf32>,
%arg2: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
: vector<2x3xf32>, vector<3xf32> into vector<2xf32>
return %0 : vector<2xf32>
}
#vecmat_accesses = [
affine_map<(i, j) -> (j)>,
affine_map<(i, j) -> (i, j)>,
affine_map<(i, j) -> (i)>
]
#vecmat_trait = {
indexing_maps = #vecmat_accesses,
iterator_types = ["parallel", "reduction"]
}
// CHECK-LABEL: func @extract_contract3
// CHECK-SAME: %[[A:.*0]]: vector<3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32>
// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
// CHECK: %[[T2:.*]] = vector.fma %[[A]], %[[T0]], %[[Z]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.reductionv2 "add", %[[T2]], %[[T1]] : vector<3xf32>, f32 into f32
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
// CHECK: %[[T7:.*]] = vector.fma %[[A]], %[[T5]], %[[Z]] : vector<3xf32>
// CHECK: %[[T8:.*]] = vector.reductionv2 "add", %[[T7]], %[[T6]] : vector<3xf32>, f32 into f32
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
// CHECK: return %[[T9]] : vector<2xf32>
func @extract_contract3(%arg0: vector<3xf32>,
%arg1: vector<2x3xf32>,
%arg2: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #vecmat_trait %arg0, %arg1, %arg2
: vector<3xf32>, vector<2x3xf32> into vector<2xf32>
return %0 : vector<2xf32>
}
#matmat_accesses = [
affine_map<(i, j, k) -> (i, k)>,
affine_map<(i, j, k) -> (k, j)>,
affine_map<(i, j, k) -> (i, j)>
]
#matmat_trait = {
indexing_maps = #matmat_accesses,
iterator_types = ["parallel", "parallel", "reduction"]
}
// CHECK-LABEL: func @extract_contract4
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<2xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32>
// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2x2xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
// CHECK: %[[T3:.*]] = vector.extract %[[T2]][0] : vector<2xf32>
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[Z]] [0] : f32 into vector<2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[T5]][0] : vector<2xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T4]] [1] : f32 into vector<2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T9:.*]] = vector.fma %[[T0]], %[[T7]], %[[Z]] : vector<2xf32>
// CHECK: %[[T10:.*]] = vector.reductionv2 "add", %[[T9]], %[[T8]] : vector<2xf32>, f32 into f32
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[Z]] [0] : f32 into vector<2xf32>
// CHECK: %[[T12:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
// CHECK: %[[T13:.*]] = vector.extract %[[T12]][1] : vector<2xf32>
// CHECK: %[[T14:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<2xf32>
// CHECK: %[[T15:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
// CHECK: %[[T16:.*]] = vector.extract %[[T15]][1] : vector<2xf32>
// CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T14]] [1] : f32 into vector<2xf32>
// CHECK: %[[T18:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T19:.*]] = vector.fma %[[T0]], %[[T17]], %[[Z]] : vector<2xf32>
// CHECK: %[[T20:.*]] = vector.reductionv2 "add", %[[T19]], %[[T18]] : vector<2xf32>, f32 into f32
// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [1] : f32 into vector<2xf32>
// CHECK: %[[T22:.*]] = vector.insert %[[T21]], %[[R]] [0] : vector<2xf32> into vector<2x2xf32>
// CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32>
// CHECK: %[[T24:.*]] = vector.extract %[[C]][1] : vector<2x2xf32>
// CHECK: %[[T25:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
// CHECK: %[[T26:.*]] = vector.extract %[[T25]][0] : vector<2xf32>
// CHECK: %[[T27:.*]] = vector.insert %[[T26]], %[[Z]] [0] : f32 into vector<2xf32>
// CHECK: %[[T28:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
// CHECK: %[[T29:.*]] = vector.extract %[[T28]][0] : vector<2xf32>
// CHECK: %[[T30:.*]] = vector.insert %[[T29]], %[[T27]] [1] : f32 into vector<2xf32>
// CHECK: %[[T31:.*]] = vector.extract %[[T24]][0] : vector<2xf32>
// CHECK: %[[T32:.*]] = vector.fma %[[T23]], %[[T30]], %[[Z]] : vector<2xf32>
// CHECK: %[[T33:.*]] = vector.reductionv2 "add", %[[T32]], %[[T31]] : vector<2xf32>, f32 into f32
// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[Z]] [0] : f32 into vector<2xf32>
// CHECK: %[[T35:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
// CHECK: %[[T36:.*]] = vector.extract %[[T35]][1] : vector<2xf32>
// CHECK: %[[T37:.*]] = vector.insert %[[T36]], %[[Z]] [0] : f32 into vector<2xf32>
// CHECK: %[[T38:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
// CHECK: %[[T39:.*]] = vector.extract %[[T38]][1] : vector<2xf32>
// CHECK: %[[T40:.*]] = vector.insert %[[T39]], %[[T37]] [1] : f32 into vector<2xf32>
// CHECK: %[[T41:.*]] = vector.extract %[[T24]][1] : vector<2xf32>
// CHECK: %[[T42:.*]] = vector.fma %[[T23]], %[[T40]], %[[Z]] : vector<2xf32>
// CHECK: %[[T43:.*]] = vector.reductionv2 "add", %[[T42]], %[[T41]] : vector<2xf32>, f32 into f32
// CHECK: %[[T44:.*]] = vector.insert %[[T43]], %[[T34]] [1] : f32 into vector<2xf32>
// CHECK: %[[T45:.*]] = vector.insert %[[T44]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32>
// CHECK: return %[[T45]] : vector<2x2xf32>
func @extract_contract4(%arg0: vector<2x2xf32>,
%arg1: vector<2x2xf32>,
%arg2: vector<2x2xf32>) -> vector<2x2xf32> {
%0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
: vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
return %0 : vector<2x2xf32>
}