[mlir] [VectorOps] Framework for progressive lowering of vector.contract
Summary:
Lowers all free/batch dimensions in a vector.contract progressively into
simpler vector.contract operations until a direct vector.reduction operation
is reached. Then lowers 1-D reductions into vector.reduce.

Still TBD: multi-dimensional contractions that remain after removing all the
parallel dims.

Reviewers: nicolasvasilache, andydavis1, rriddle

Reviewed By: andydavis1

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74797
parent 129c911efa
commit 0ba9ee9f0e
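To illustrate the rewrite, consider the matrix-vector contraction exercised by the extract_contract2 test added below, with %A, %b and %acc standing in for the function arguments. The intermediate steps sketched here are transient and illustrative only; the test CHECK lines show just the final output:

  // Input: one free (parallel) dimension i remains.
  %0 = vector.contract #matvec_trait %A, %b, %acc
      : vector<2x3xf32>, vector<3xf32> into vector<2xf32>

  // Step 1: dimension i is unrolled; each slice becomes a 1-D vector.contract:
  //   %a0 = vector.extract %A[0], %acc0 = vector.extract %acc[0], ...
  // Step 2: each remaining pure 1-D reduction lowers to fma plus reduction:
  //   %f0 = vector.fma %a0, %b, %zero : vector<3xf32>
  //   %r0 = vector.reductionv2 "add", %f0, %acc0 : vector<3xf32>, f32 into f32
  // The scalar results are then re-inserted into a vector<2xf32>.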
@@ -14,6 +14,7 @@
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/Dialect/VectorOps/VectorTransforms.h"
#include "mlir/Dialect/VectorOps/VectorUtils.h"
@@ -864,6 +865,19 @@ public:
};

/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a fma/reduction op.
///
/// TODO(ajcbik): break down into transpose/reshape/cast ops
///               when they become available to avoid code dup
/// TODO(ajcbik): investigate lowering order impact on performance
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -874,16 +888,13 @@ public:
    if (llvm::size(op.masks()) != 0)
      return matchFailure();

-    auto loc = op.getLoc();
-    VectorType lhsType = op.getLhsType();
-    VectorType rhsType = op.getRhsType();
-    Type resType = op.getResultType();
-
-    // Find first batch dimension in lhs/rhs, and lower when found.
+    // Find first batch dimension in LHS/RHS, and lower when found.
    std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
    if (!batchDimMap.empty()) {
-      // TODO(ajcbik): implement batch
-      return matchFailure();
+      int64_t lhsIndex = batchDimMap[0].first;
+      int64_t rhsIndex = batchDimMap[0].second;
+      rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
+      return matchSuccess();
    }

    // Collect contracting dimensions.
@@ -896,24 +907,35 @@ public:
      rhsContractingDimSet.insert(dimPair.second);
    }

-    // Find free dimension in lhs/rhs, and lower first when found.
-    for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
-      if (lhsContractingDimSet.count(i) == 0) {
-        // TODO(ajcbik): implement free
-        return matchFailure();
-      }
-    }
-    for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
-      if (rhsContractingDimSet.count(i) == 0) {
-        // TODO(ajcbik): implement free
-        return matchFailure();
+    // Find first free dimension in LHS, and lower when found.
+    VectorType lhsType = op.getLhsType();
+    for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e;
+         ++lhsIndex) {
+      if (lhsContractingDimSet.count(lhsIndex) == 0) {
+        rewriter.replaceOp(
+            op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
+        return matchSuccess();
      }
    }

-    // Only contraction dimensions remain.
+    // Find first free dimension in RHS, and lower when found.
+    VectorType rhsType = op.getRhsType();
+    for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e;
+         ++rhsIndex) {
+      if (rhsContractingDimSet.count(rhsIndex) == 0) {
+        rewriter.replaceOp(
+            op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
+        return matchSuccess();
+      }
+    }
+
+    // Lower the only remaining contraction dimensions.
+    // TODO(ajcbik): handle multi-dim reductions
+    auto loc = op.getLoc();
+    Type resType = op.getResultType();
    if (!resType.isa<VectorType>() && lhsType.getRank() == 1 &&
        rhsType.getRank() == 1) {
      // Handle reduction into scalar.
      Value zero = rewriter.create<ConstantOp>(loc, resType,
                                               rewriter.getZeroAttr(resType));
      Value splat = rewriter.create<SplatOp>(loc, lhsType, zero);
@@ -924,9 +946,191 @@ public:
                                                       op.acc());
      return matchSuccess();
    }
    // TODO(ajcbik): implement more contraction

    return matchFailure();
  }

private:
  // Lower one parallel dimension.
  // TODO(ajcbik): consider reusing existing contract unrolling
  Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
                      int64_t rhsIndex, PatternRewriter &rewriter) const {
    VectorType lhsType = op.getLhsType();
    VectorType rhsType = op.getRhsType();
    VectorType resType = op.getResultType().cast<VectorType>();
    // Find the iterator type index and result index.
    SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
    int64_t iterIndex = -1;
    int64_t dimSize = -1;
    if (lhsIndex >= 0) {
      iterIndex =
          iMap[0].getResult(lhsIndex).cast<AffineDimExpr>().getPosition();
      assert((rhsIndex < 0 || iterIndex == iMap[1]
                                              .getResult(rhsIndex)
                                              .cast<AffineDimExpr>()
                                              .getPosition()) &&
             "parallel index should be free in LHS or batch in LHS/RHS");
      dimSize = lhsType.getDimSize(lhsIndex);
    } else {
      assert(rhsIndex >= 0 && "missing parallel index");
      iterIndex =
          iMap[1].getResult(rhsIndex).cast<AffineDimExpr>().getPosition();
      dimSize = rhsType.getDimSize(rhsIndex);
    }
    assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
    Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
    assert(lookup.hasValue() && "parallel index not listed in reduction");
    int64_t resIndex = lookup.getValue();
    // Construct new iterator types.
    ArrayAttr iteratorTypes = op.iterator_types();
    SmallVector<Attribute, 4> lowIterTypes;
    for (auto it : llvm::enumerate(iteratorTypes)) {
      int64_t idx = it.index();
      if (idx == iterIndex) {
        assert(it.value().cast<StringAttr>().getValue() ==
                   getParallelIteratorTypeName() &&
               "parallel index not marked as such");
        continue;
      }
      lowIterTypes.push_back(it.value());
    }
    // Construct new affine map array attribute.
    SmallVector<AffineMap, 4> lowIndexingMaps;
    lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
    lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
    lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
    auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
    // Construct new iterator types array attribute.
    auto lowIter = rewriter.getArrayAttr(lowIterTypes);
    // Unroll into a series of lower dimensional vector.contract ops.
    Location loc = op.getLoc();
    Value result = zeroVector(loc, resType, rewriter);
    for (int64_t d = 0; d < dimSize; ++d) {
      auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
      auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
      auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
      Value lowContract = rewriter.create<vector::ContractionOp>(
          loc, lhs, rhs, acc, lowAffine, lowIter);
      result = reshapeStore(loc, lowContract, result, resType, resIndex, d,
                            rewriter);
    }
    return result;
  }

  // Helper method to construct a zero vector.
  static Value zeroVector(Location loc, VectorType vType,
                          PatternRewriter &rewriter) {
    Type eltType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, eltType,
                                             rewriter.getZeroAttr(eltType));
    return rewriter.create<SplatOp>(loc, vType, zero);
  }

  // Helper to find an index in an affine map.
  static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
    for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
      int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
      if (idx == index)
        return i;
    }
    return None;
  }

  // Helper to construct an affine map with one index removed.
  static AffineMap adjustMap(AffineMap map, int64_t index,
                             PatternRewriter &rewriter) {
    SmallVector<AffineExpr, 4> results;
    for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
      int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
      if (idx == index)
        continue;
      // Re-insert remaining indices, but renamed when occurring
      // after the removed index.
      auto targetExpr =
          getAffineDimExpr(idx < index ? idx : idx - 1, rewriter.getContext());
      results.push_back(targetExpr);
    }
    // Since (...) -> () cannot be represented properly,
    // we resort to an empty map when this situation happens.
    return results.empty() ? AffineMap::get(rewriter.getContext())
                           : AffineMap::get(map.getNumDims() - 1, 0, results);
  }

  // Helper to drop dimension from vector type.
  static Type adjustType(VectorType tp, int64_t index) {
    int64_t rank = tp.getRank();
    Type eltType = tp.getElementType();
    if (rank == 1) {
      assert(index == 0 && "index for scalar result out of bounds");
      return eltType;
    }
    SmallVector<int64_t, 4> adjustedShape;
    for (int64_t i = 0; i < rank; ++i) {
      // Omit dimension at the given index.
      if (i == index)
        continue;
      // Otherwise, add dimension back.
      adjustedShape.push_back(tp.getDimSize(i));
    }
    return VectorType::get(adjustedShape, eltType);
  }

  // Helper method to possibly drop a dimension in a load.
  // TODO(ajcbik): use a reshaping vector load (and share lowering code)
  static Value reshapeLoad(Location loc, Value val, VectorType type,
                           int64_t index, int64_t pos,
                           PatternRewriter &rewriter) {
    if (index == -1)
      return val;
    Type lowType = adjustType(type, 0);
    // At extraction dimension?
    if (index == 0) {
      auto posAttr = rewriter.getI64ArrayAttr(pos);
      return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
    }
    // Unroll leading dimensions.
    VectorType vType = lowType.cast<VectorType>();
    VectorType resType = adjustType(type, index).cast<VectorType>();
    Value result = zeroVector(loc, resType, rewriter);
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
      auto posAttr = rewriter.getI64ArrayAttr(d);
      Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
      Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
      result = rewriter.create<vector::InsertOp>(loc, resType, load, result,
                                                 posAttr);
    }
    return result;
  }

  // Helper method to possibly drop a dimension in a store.
  // TODO(ajcbik): use a reshaping vector store (and share lowering code)
  static Value reshapeStore(Location loc, Value val, Value result,
                            VectorType type, int64_t index, int64_t pos,
                            PatternRewriter &rewriter) {
    // Unmodified?
    if (index == -1)
      return val;
    // At insertion dimension?
    if (index == 0) {
      auto posAttr = rewriter.getI64ArrayAttr(pos);
      return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
    }
    // Unroll leading dimensions.
    Type lowType = adjustType(type, 0);
    VectorType vType = lowType.cast<VectorType>();
    Type insType = adjustType(vType, 0);
    for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
      auto posAttr = rewriter.getI64ArrayAttr(d);
      Value ext =
          rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
      Value ins =
          rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
      Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
      result =
          rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
    }
    return result;
  }
};

} // namespace
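To make the affine-map adjustment concrete, here is how adjustMap rewrites the #matvec_accesses maps from the tests below when the parallel dimension i (iterIndex 0) is peeled off; a worked example, not part of the diff:

  adjustMap((i, j) -> (i, j), index = 0)  -->  (j) -> (j)   // i dropped, j renumbered to d0
  adjustMap((i, j) -> (j),    index = 0)  -->  (j) -> (j)
  adjustMap((i, j) -> (i),    index = 0)  -->  empty map    // (j) -> () is not representable

The lowered 1-D vector.contract thus carries plain dot-product maps, with the accumulator map collapsing to the empty-map case handled explicitly in adjustMap.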
@@ -14,7 +14,7 @@
// CHECK-SAME: %[[A:.*0]]: vector<4xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
// CHECK-SAME: %[[C:.*2]]: f32
-// CHECK: %[[Z:.*]] = constant dense<0.000000e+00>
+// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<4xf32>
// CHECK: %[[F:.*]] = vector.fma %[[A]], %[[B]], %[[Z]] : vector<4xf32>
// CHECK: %[[R:.*]] = vector.reductionv2 "add", %[[F]], %[[C]]
// CHECK: return %[[R]] : f32
@@ -24,3 +24,148 @@ func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32)
      : vector<4xf32>, vector<4xf32> into f32
  return %0 : f32
}

#matvec_accesses = [
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (i)>
]
#matvec_trait = {
  indexing_maps = #matvec_accesses,
  iterator_types = ["parallel", "reduction"]
}

// CHECK-LABEL: func @extract_contract2
// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32>
// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
// CHECK: %[[T2:.*]] = vector.fma %[[T0]], %[[B]], %[[Z]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.reductionv2 "add", %[[T2]], %[[T1]] : vector<3xf32>, f32 into f32
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
// CHECK: %[[T7:.*]] = vector.fma %[[T5]], %[[B]], %[[Z]] : vector<3xf32>
// CHECK: %[[T8:.*]] = vector.reductionv2 "add", %[[T7]], %[[T6]] : vector<3xf32>, f32 into f32
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
// CHECK: return %[[T9]] : vector<2xf32>

func @extract_contract2(%arg0: vector<2x3xf32>,
                        %arg1: vector<3xf32>,
                        %arg2: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
    : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

#vecmat_accesses = [
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (i)>
]
#vecmat_trait = {
  indexing_maps = #vecmat_accesses,
  iterator_types = ["parallel", "reduction"]
}

// CHECK-LABEL: func @extract_contract3
// CHECK-SAME: %[[A:.*0]]: vector<3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32>
// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
// CHECK: %[[T2:.*]] = vector.fma %[[A]], %[[T0]], %[[Z]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.reductionv2 "add", %[[T2]], %[[T1]] : vector<3xf32>, f32 into f32
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
// CHECK: %[[T7:.*]] = vector.fma %[[A]], %[[T5]], %[[Z]] : vector<3xf32>
// CHECK: %[[T8:.*]] = vector.reductionv2 "add", %[[T7]], %[[T6]] : vector<3xf32>, f32 into f32
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
// CHECK: return %[[T9]] : vector<2xf32>

func @extract_contract3(%arg0: vector<3xf32>,
                        %arg1: vector<2x3xf32>,
                        %arg2: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #vecmat_trait %arg0, %arg1, %arg2
    : vector<3xf32>, vector<2x3xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

#matmat_accesses = [
  affine_map<(i, j, k) -> (i, k)>,
  affine_map<(i, j, k) -> (k, j)>,
  affine_map<(i, j, k) -> (i, j)>
]
#matmat_trait = {
  indexing_maps = #matmat_accesses,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// CHECK-LABEL: func @extract_contract4
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<2xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32>
// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2x2xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
// CHECK: %[[T3:.*]] = vector.extract %[[T2]][0] : vector<2xf32>
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[Z]] [0] : f32 into vector<2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[T5]][0] : vector<2xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T4]] [1] : f32 into vector<2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T9:.*]] = vector.fma %[[T0]], %[[T7]], %[[Z]] : vector<2xf32>
// CHECK: %[[T10:.*]] = vector.reductionv2 "add", %[[T9]], %[[T8]] : vector<2xf32>, f32 into f32
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[Z]] [0] : f32 into vector<2xf32>
// CHECK: %[[T12:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
// CHECK: %[[T13:.*]] = vector.extract %[[T12]][1] : vector<2xf32>
// CHECK: %[[T14:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<2xf32>
// CHECK: %[[T15:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
// CHECK: %[[T16:.*]] = vector.extract %[[T15]][1] : vector<2xf32>
// CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T14]] [1] : f32 into vector<2xf32>
// CHECK: %[[T18:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T19:.*]] = vector.fma %[[T0]], %[[T17]], %[[Z]] : vector<2xf32>
// CHECK: %[[T20:.*]] = vector.reductionv2 "add", %[[T19]], %[[T18]] : vector<2xf32>, f32 into f32
// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [1] : f32 into vector<2xf32>
// CHECK: %[[T22:.*]] = vector.insert %[[T21]], %[[R]] [0] : vector<2xf32> into vector<2x2xf32>
// CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32>
// CHECK: %[[T24:.*]] = vector.extract %[[C]][1] : vector<2x2xf32>
// CHECK: %[[T25:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
// CHECK: %[[T26:.*]] = vector.extract %[[T25]][0] : vector<2xf32>
// CHECK: %[[T27:.*]] = vector.insert %[[T26]], %[[Z]] [0] : f32 into vector<2xf32>
// CHECK: %[[T28:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
// CHECK: %[[T29:.*]] = vector.extract %[[T28]][0] : vector<2xf32>
// CHECK: %[[T30:.*]] = vector.insert %[[T29]], %[[T27]] [1] : f32 into vector<2xf32>
// CHECK: %[[T31:.*]] = vector.extract %[[T24]][0] : vector<2xf32>
// CHECK: %[[T32:.*]] = vector.fma %[[T23]], %[[T30]], %[[Z]] : vector<2xf32>
// CHECK: %[[T33:.*]] = vector.reductionv2 "add", %[[T32]], %[[T31]] : vector<2xf32>, f32 into f32
// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[Z]] [0] : f32 into vector<2xf32>
// CHECK: %[[T35:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
// CHECK: %[[T36:.*]] = vector.extract %[[T35]][1] : vector<2xf32>
// CHECK: %[[T37:.*]] = vector.insert %[[T36]], %[[Z]] [0] : f32 into vector<2xf32>
// CHECK: %[[T38:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
// CHECK: %[[T39:.*]] = vector.extract %[[T38]][1] : vector<2xf32>
// CHECK: %[[T40:.*]] = vector.insert %[[T39]], %[[T37]] [1] : f32 into vector<2xf32>
// CHECK: %[[T41:.*]] = vector.extract %[[T24]][1] : vector<2xf32>
// CHECK: %[[T42:.*]] = vector.fma %[[T23]], %[[T40]], %[[Z]] : vector<2xf32>
// CHECK: %[[T43:.*]] = vector.reductionv2 "add", %[[T42]], %[[T41]] : vector<2xf32>, f32 into f32
// CHECK: %[[T44:.*]] = vector.insert %[[T43]], %[[T34]] [1] : f32 into vector<2xf32>
// CHECK: %[[T45:.*]] = vector.insert %[[T44]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32>
// CHECK: return %[[T45]] : vector<2x2xf32>

func @extract_contract4(%arg0: vector<2x2xf32>,
                        %arg1: vector<2x2xf32>,
                        %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
  return %0 : vector<2x2xf32>
}
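Reading the extract_contract4 CHECKs top-down, the progressive lowering of the 2x2 matmul proceeds in three conceptual steps; the intermediate vector.contract forms sketched here are transient and never appear in the final IR:

  // Step 1: peel parallel dimension i (free in LHS): for each row %ai of %A,
  //   %ci = vector.contract {"parallel" j, "reduction" k} %ai, %B, %ci_acc
  //       : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  // Step 2: peel parallel dimension j (free in RHS): reshapeLoad gathers
  //   column %bj of %B (the extract/insert chain T2-T7 above), leaving
  //   %cij = vector.contract {"reduction" k} %ai, %bj, %cij_acc
  //       : vector<2xf32>, vector<2xf32> into f32
  // Step 3: the pure 1-D reduction becomes
  //   %f = vector.fma %ai, %bj, %zero : vector<2xf32>
  //   %r = vector.reductionv2 "add", %f, %cij_acc : vector<2xf32>, f32 into f32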