[mlir] [VectorOps] Improve SIMD compares with narrower indices

When allowed, use 32-bit rather than 64-bit indices in the
SIMD computation of masks. This runs up to 2x and 4x faster,
respectively, on a number of AVX2 and AVX512 microbenchmarks.
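
A minimal sketch of how a client might opt in programmatically, based on
the LowerVectorToLLVMOptions setter and pass constructor changed below
(the surrounding PassManager setup is assumed); the same knob is exposed
on the command line as --convert-vector-to-llvm='enable-index-optimizations=1':

  // Sketch only: assumes an existing mlir::PassManager named pm.
  mlir::LowerVectorToLLVMOptions options;
  options.setReassociateFPReductions(true)    // existing option
         .setEnableIndexOptimizations(true);  // new: allow 32-bit index math
  pm.addPass(mlir::createConvertVectorToLLVMPass(options));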

Reviewed By: bkramer

Differential Revision: https://reviews.llvm.org/D87116
aartbik 2020-09-03 15:57:25 -07:00
parent 0ac81333eb
commit 060c9dd1cc
7 changed files with 218 additions and 94 deletions

@@ -358,7 +358,10 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
let options = [
Option<"reassociateFPReductions", "reassociate-fp-reductions",
"bool", /*default=*/"false",
"Allows llvm to reassociate floating-point reductions for speed">
"Allows llvm to reassociate floating-point reductions for speed">,
Option<"enableIndexOptimizations", "enable-index-optimizations",
"bool", /*default=*/"false",
"Allows compiler to assume indices fit in 32-bit if that yields faster code">
];
}

@@ -22,8 +22,13 @@ class OperationPass;
/// ConvertVectorToLLVM pass in include/mlir/Conversion/Passes.td
struct LowerVectorToLLVMOptions {
bool reassociateFPReductions = false;
LowerVectorToLLVMOptions &setReassociateFPReductions(bool r) {
reassociateFPReductions = r;
bool enableIndexOptimizations = false;
LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) {
reassociateFPReductions = b;
return *this;
}
LowerVectorToLLVMOptions &setEnableIndexOptimizations(bool b) {
enableIndexOptimizations = b;
return *this;
}
};
@@ -37,7 +42,8 @@ void populateVectorToLLVMMatrixConversionPatterns(
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
bool reassociateFPReductions = false);
bool reassociateFPReductions = false,
bool enableIndexOptimizations = false);
/// Create a pass to convert vector operations to the LLVMIR dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass(

@@ -117,6 +117,49 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
return res;
}
// Helper that returns a vector comparison that constructs a mask:
// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
// much more compact, IR for this operation, but LLVM eventually
// generates more elaborate instructions for this intrinsic since it
// is very conservative on the boundary conditions.
static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
Operation *op, bool enableIndexOptimizations,
int64_t dim, Value b, Value *off = nullptr) {
auto loc = op->getLoc();
// If we can assume all indices fit in 32-bit, we perform the vector
// comparison in 32-bit to get a higher degree of SIMD parallelism.
// Otherwise we perform the vector comparison using 64-bit indices.
Value indices;
Type idxType;
if (enableIndexOptimizations) {
SmallVector<int32_t, 4> values(dim);
for (int64_t d = 0; d < dim; d++)
values[d] = d;
indices =
rewriter.create<ConstantOp>(loc, rewriter.getI32VectorAttr(values));
idxType = rewriter.getI32Type();
} else {
SmallVector<int64_t, 4> values(dim);
for (int64_t d = 0; d < dim; d++)
values[d] = d;
indices =
rewriter.create<ConstantOp>(loc, rewriter.getI64VectorAttr(values));
idxType = rewriter.getI64Type();
}
// Add in an offset if requested.
if (off) {
Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
indices = rewriter.create<AddIOp>(loc, ov, indices);
}
// Construct the vector comparison.
Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
}
// Helper that returns data layout alignment of an operation with memref.
template <typename T>
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
@@ -512,10 +555,10 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorReductionOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter,
bool reassociateFP)
bool reassociateFPRed)
: ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
typeConverter),
reassociateFPReductions(reassociateFP) {}
reassociateFPReductions(reassociateFPRed) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -589,6 +632,34 @@ private:
const bool reassociateFPReductions;
};
/// Conversion pattern for a vector.create_mask (1-D only).
class VectorCreateMaskOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorCreateMaskOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter,
bool enableIndexOpt)
: ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context,
typeConverter),
enableIndexOptimizations(enableIndexOpt) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto dstType = op->getResult(0).getType().cast<VectorType>();
int64_t rank = dstType.getRank();
if (rank == 1) {
rewriter.replaceOp(
op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
dstType.getDimSize(0), operands[0]));
return success();
}
return failure();
}
private:
const bool enableIndexOptimizations;
};
class VectorShuffleOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorShuffleOpConversion(MLIRContext *context,
@@ -1121,17 +1192,19 @@ public:
/// Conversion pattern that converts a 1-D vector transfer read/write op in a
/// sequence of:
/// 1. Bitcast or addrspacecast to vector form.
/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 3. Create a mask where offsetVector is compared against memref upper bound.
/// 4. Rewrite op as a masked read or write.
/// 1. Get the source/dst address as an LLVM vector pointer.
/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 4. Create a mask where offsetVector is compared against memref upper bound.
/// 5. Rewrite op as a masked read or write.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
explicit VectorTransferConversion(MLIRContext *context,
LLVMTypeConverter &typeConv)
: ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
typeConv) {}
LLVMTypeConverter &typeConv,
bool enableIndexOpt)
: ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv),
enableIndexOptimizations(enableIndexOpt) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -1155,7 +1228,6 @@ public:
auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
Location loc = op->getLoc();
Type i64Type = rewriter.getIntegerType(64);
MemRefType memRefType = xferOp.getMemRefType();
if (auto memrefVectorElementType =
@@ -1202,41 +1274,26 @@ public:
xferOp, operands, vectorDataPtr);
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
unsigned vecWidth = vecTy.getVectorNumElements();
VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
SmallVector<int64_t, 8> indices;
indices.reserve(vecWidth);
for (unsigned i = 0; i < vecWidth; ++i)
indices.push_back(i);
Value linearIndices = rewriter.create<ConstantOp>(
loc, vectorCmpType,
DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
linearIndices = rewriter.create<LLVM::DialectCastOp>(
loc, toLLVMTy(vectorCmpType), linearIndices);
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
// TODO: when the leaf transfer rank is k > 1 we need the last
// `k` dimensions here.
unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);
// 4. Let dim the memref dimension, compute the vector comparison mask:
// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
//
// TODO: when the leaf transfer rank is k > 1, we need the last `k`
// dimensions here.
unsigned vecWidth = vecTy.getVectorNumElements();
unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
Value off = *(xferOp.indices().begin() + lastIndex);
Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
Value mask =
rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
mask);
Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations,
vecWidth, dim, &off);
// 5. Rewrite as a masked read / write.
return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
operands, vectorDataPtr, mask);
}
private:
const bool enableIndexOptimizations;
};
class VectorPrintOpConversion : public ConvertToLLVMPattern {
@@ -1444,7 +1501,7 @@ public:
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
bool reassociateFPReductions) {
bool reassociateFPReductions, bool enableIndexOptimizations) {
MLIRContext *ctx = converter.getDialect()->getContext();
// clang-format off
patterns.insert<VectorFMAOpNDRewritePattern,
@@ -1453,6 +1510,10 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorExtractStridedSliceOpConversion>(ctx);
patterns.insert<VectorReductionOpConversion>(
ctx, converter, reassociateFPReductions);
patterns.insert<VectorCreateMaskOpConversion,
VectorTransferConversion<TransferReadOp>,
VectorTransferConversion<TransferWriteOp>>(
ctx, converter, enableIndexOptimizations);
patterns
.insert<VectorShuffleOpConversion,
VectorExtractElementOpConversion,
@@ -1461,8 +1522,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorInsertElementOpConversion,
VectorInsertOpConversion,
VectorPrintOpConversion,
VectorTransferConversion<TransferReadOp>,
VectorTransferConversion<TransferWriteOp>,
VectorTypeCastOpConversion,
VectorMaskedLoadOpConversion,
VectorMaskedStoreOpConversion,
@@ -1485,6 +1544,7 @@ struct LowerVectorToLLVMPass
: public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
this->reassociateFPReductions = options.reassociateFPReductions;
this->enableIndexOptimizations = options.enableIndexOptimizations;
}
void runOnOperation() override;
};
@@ -1505,16 +1565,15 @@ void LowerVectorToLLVMPass::runOnOperation() {
LLVMTypeConverter converter(&getContext());
OwningRewritePatternList patterns;
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(converter, patterns,
reassociateFPReductions);
populateVectorToLLVMConversionPatterns(
converter, patterns, reassociateFPReductions, enableIndexOptimizations);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns);
LLVMConversionTarget target(getContext());
if (failed(applyPartialConversion(getOperation(), target, patterns))) {
if (failed(applyPartialConversion(getOperation(), target, patterns)))
signalPassFailure();
}
}
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {

@@ -1347,7 +1347,8 @@ public:
auto eltType = dstType.getElementType();
auto dimSizes = op.mask_dim_sizes();
int64_t rank = dimSizes.size();
int64_t trueDim = dimSizes[0].cast<IntegerAttr>().getInt();
int64_t trueDim = std::min(dstType.getDimSize(0),
dimSizes[0].cast<IntegerAttr>().getInt());
if (rank == 1) {
// Express constant 1-D case in explicit vector form:
@@ -1402,21 +1403,8 @@ public:
int64_t rank = dstType.getRank();
Value idx = op.getOperand(0);
if (rank == 1) {
// Express dynamic 1-D case in explicit vector form:
// mask = [0,1,..,n-1] < [a,a,..,a]
SmallVector<int64_t, 4> values(dim);
for (int64_t d = 0; d < dim; d++)
values[d] = d;
Value indices =
rewriter.create<ConstantOp>(loc, rewriter.getI64VectorAttr(values));
Value bound =
rewriter.create<IndexCastOp>(loc, rewriter.getI64Type(), idx);
Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::slt, indices,
bounds);
return success();
}
if (rank == 1)
return failure(); // leave for lowering
VectorType lowType =
VectorType::get(dstType.getShape().drop_front(), eltType);

@@ -0,0 +1,48 @@
// RUN: mlir-opt %s --convert-vector-to-llvm='enable-index-optimizations=1' | FileCheck %s --check-prefix=CMP32
// RUN: mlir-opt %s --convert-vector-to-llvm='enable-index-optimizations=0' | FileCheck %s --check-prefix=CMP64
// CMP32-LABEL: llvm.func @genbool_var_1d(
// CMP32-SAME: %[[A:.*]]: !llvm.i64)
// CMP32: %[[T0:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi32>) : !llvm.vec<11 x i32>
// CMP32: %[[T1:.*]] = llvm.trunc %[[A]] : !llvm.i64 to !llvm.i32
// CMP32: %[[T2:.*]] = llvm.mlir.undef : !llvm.vec<11 x i32>
// CMP32: %[[T3:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CMP32: %[[T4:.*]] = llvm.insertelement %[[T1]], %[[T2]][%[[T3]] : !llvm.i32] : !llvm.vec<11 x i32>
// CMP32: %[[T5:.*]] = llvm.shufflevector %[[T4]], %[[T2]] [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm.vec<11 x i32>, !llvm.vec<11 x i32>
// CMP32: %[[T6:.*]] = llvm.icmp "slt" %[[T0]], %[[T5]] : !llvm.vec<11 x i32>
// CMP32: llvm.return %[[T6]] : !llvm.vec<11 x i1>
// CMP64-LABEL: llvm.func @genbool_var_1d(
// CMP64-SAME: %[[A:.*]]: !llvm.i64)
// CMP64: %[[T0:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi64>) : !llvm.vec<11 x i64>
// CMP64: %[[T1:.*]] = llvm.mlir.undef : !llvm.vec<11 x i64>
// CMP64: %[[T2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CMP64: %[[T3:.*]] = llvm.insertelement %[[A]], %[[T1]][%[[T2]] : !llvm.i32] : !llvm.vec<11 x i64>
// CMP64: %[[T4:.*]] = llvm.shufflevector %[[T3]], %[[T1]] [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm.vec<11 x i64>, !llvm.vec<11 x i64>
// CMP64: %[[T5:.*]] = llvm.icmp "slt" %[[T0]], %[[T4]] : !llvm.vec<11 x i64>
// CMP64: llvm.return %[[T5]] : !llvm.vec<11 x i1>
func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
%0 = vector.create_mask %arg0 : vector<11xi1>
return %0 : vector<11xi1>
}
// CMP32-LABEL: llvm.func @transfer_read_1d
// CMP32: %[[C:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>) : !llvm.vec<16 x i32>
// CMP32: %[[A:.*]] = llvm.add %{{.*}}, %[[C]] : !llvm.vec<16 x i32>
// CMP32: %[[M:.*]] = llvm.icmp "slt" %[[A]], %{{.*}} : !llvm.vec<16 x i32>
// CMP32: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %[[M]], %{{.*}}
// CMP32: llvm.return %[[L]] : !llvm.vec<16 x float>
// CMP64-LABEL: llvm.func @transfer_read_1d
// CMP64: %[[C:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi64>) : !llvm.vec<16 x i64>
// CMP64: %[[A:.*]] = llvm.add %{{.*}}, %[[C]] : !llvm.vec<16 x i64>
// CMP64: %[[M:.*]] = llvm.icmp "slt" %[[A]], %{{.*}} : !llvm.vec<16 x i64>
// CMP64: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %[[M]], %{{.*}}
// CMP64: llvm.return %[[L]] : !llvm.vec<16 x float>
func @transfer_read_1d(%A : memref<?xf32>, %i: index) -> vector<16xf32> {
%d = constant -1.0: f32
%f = vector.transfer_read %A[%i], %d {permutation_map = affine_map<(d0) -> (d0)>} : memref<?xf32>, vector<16xf32>
return %f : vector<16xf32>
}

@@ -749,10 +749,12 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
// CHECK-SAME: (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
// CHECK-SAME: !llvm.ptr<float> to !llvm.ptr<vec<17 x float>>
// CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 0] :
// CHECK-SAME: !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
//
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// CHECK: %[[linearIndex:.*]] = llvm.mlir.constant(
// CHECK-SAME: dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
// CHECK: %[[linearIndex:.*]] = llvm.mlir.constant(dense
// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
// CHECK-SAME: vector<17xi64>) : !llvm.vec<17 x i64>
//
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
@@ -770,8 +772,6 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
//
// 4. Let dim the memref dimension, compute the vector comparison mask:
// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
// CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 0] :
// CHECK-SAME: !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[dimVec:.*]] = llvm.mlir.undef : !llvm.vec<17 x i64>
// CHECK: %[[c01:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: %[[dimVec2:.*]] = llvm.insertelement %[[DIM]], %[[dimVec]][%[[c01]] :
@@ -799,8 +799,8 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
// CHECK-SAME: !llvm.ptr<float> to !llvm.ptr<vec<17 x float>>
//
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// CHECK: %[[linearIndex_b:.*]] = llvm.mlir.constant(
// CHECK-SAME: dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
// CHECK: %[[linearIndex_b:.*]] = llvm.mlir.constant(dense
// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
// CHECK-SAME: vector<17xi64>) : !llvm.vec<17 x i64>
//
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
@@ -832,6 +832,8 @@ func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index)
}
// CHECK-LABEL: func @transfer_read_2d_to_1d
// CHECK-SAME: %[[BASE_0:[a-zA-Z0-9]*]]: !llvm.i64, %[[BASE_1:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm.vec<17 x float>
// CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 1] :
// CHECK-SAME: !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
//
// Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
// CHECK: %[[offsetVec:.*]] = llvm.mlir.undef : !llvm.vec<17 x i64>
@@ -847,8 +849,6 @@ func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index)
// Let dim the memref dimension, compute the vector comparison mask:
// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
// Here we check we properly use %DIM[1]
// CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 1] :
// CHECK-SAME: !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[dimVec:.*]] = llvm.mlir.undef : !llvm.vec<17 x i64>
// CHECK: %[[c01:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: %[[dimVec2:.*]] = llvm.insertelement %[[DIM]], %[[dimVec]][%[[c01]] :

@@ -785,43 +785,63 @@ func @genbool_3d() -> vector<2x3x4xi1> {
return %v: vector<2x3x4xi1>
}
// CHECK-LABEL: func @genbool_var_1d
// CHECK-SAME: %[[A:.*]]: index
// CHECK: %[[C1:.*]] = constant dense<[0, 1, 2]> : vector<3xi64>
// CHECK: %[[T0:.*]] = index_cast %[[A]] : index to i64
// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi64>
// CHECK: %[[T2:.*]] = cmpi "slt", %[[C1]], %[[T1]] : vector<3xi64>
// CHECK: return %[[T2]] : vector<3xi1>
// CHECK-LABEL: func @genbool_var_1d(
// CHECK-SAME: %[[A:.*]]: index)
// CHECK: %[[T0:.*]] = vector.create_mask %[[A]] : vector<3xi1>
// CHECK: return %[[T0]] : vector<3xi1>
func @genbool_var_1d(%arg0: index) -> vector<3xi1> {
%0 = vector.create_mask %arg0 : vector<3xi1>
return %0 : vector<3xi1>
}
// CHECK-LABEL: func @genbool_var_2d
// CHECK-SAME: %[[A:.*0]]: index
// CHECK-SAME: %[[B:.*1]]: index
// CHECK: %[[CI:.*]] = constant dense<[0, 1, 2]> : vector<3xi64>
// CHECK: %[[CF:.*]] = constant dense<false> : vector<3xi1>
// CHECK-LABEL: func @genbool_var_2d(
// CHECK-SAME: %[[A:.*0]]: index,
// CHECK-SAME: %[[B:.*1]]: index)
// CHECK: %[[C1:.*]] = constant dense<false> : vector<3xi1>
// CHECK: %[[C2:.*]] = constant dense<false> : vector<2x3xi1>
// CHECK: %[[c0:.*]] = constant 0 : index
// CHECK: %[[c1:.*]] = constant 1 : index
// CHECK: %[[T0:.*]] = index_cast %[[B]] : index to i64
// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi64>
// CHECK: %[[T2:.*]] = cmpi "slt", %[[CI]], %[[T1]] : vector<3xi64>
// CHECK: %[[T3:.*]] = cmpi "slt", %[[c0]], %[[A]] : index
// CHECK: %[[T4:.*]] = select %[[T3]], %[[T2]], %[[CF]] : vector<3xi1>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1>
// CHECK: %[[T6:.*]] = cmpi "slt", %[[c1]], %[[A]] : index
// CHECK: %[[T7:.*]] = select %[[T6]], %[[T2]], %[[CF]] : vector<3xi1>
// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[T5]] [1] : vector<3xi1> into vector<2x3xi1>
// CHECK: return %[[T8]] : vector<2x3xi1>
// CHECK: %[[T0:.*]] = vector.create_mask %[[B]] : vector<3xi1>
// CHECK: %[[T1:.*]] = cmpi "slt", %[[c0]], %[[A]] : index
// CHECK: %[[T2:.*]] = select %[[T1]], %[[T0]], %[[C1]] : vector<3xi1>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1>
// CHECK: %[[T4:.*]] = cmpi "slt", %[[c1]], %[[A]] : index
// CHECK: %[[T5:.*]] = select %[[T4]], %[[T0]], %[[C1]] : vector<3xi1>
// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : vector<3xi1> into vector<2x3xi1>
// CHECK: return %[[T6]] : vector<2x3xi1>
func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> {
%0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1>
return %0 : vector<2x3xi1>
}
// CHECK-LABEL: func @genbool_var_3d(
// CHECK-SAME: %[[A:.*0]]: index,
// CHECK-SAME: %[[B:.*1]]: index,
// CHECK-SAME: %[[C:.*2]]: index)
// CHECK: %[[C1:.*]] = constant dense<false> : vector<7xi1>
// CHECK: %[[C2:.*]] = constant dense<false> : vector<1x7xi1>
// CHECK: %[[C3:.*]] = constant dense<false> : vector<2x1x7xi1>
// CHECK: %[[c0:.*]] = constant 0 : index
// CHECK: %[[c1:.*]] = constant 1 : index
// CHECK: %[[T0:.*]] = vector.create_mask %[[C]] : vector<7xi1>
// CHECK: %[[T1:.*]] = cmpi "slt", %[[c0]], %[[B]] : index
// CHECK: %[[T2:.*]] = select %[[T1]], %[[T0]], %[[C1]] : vector<7xi1>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<7xi1> into vector<1x7xi1>
// CHECK: %[[T4:.*]] = cmpi "slt", %[[c0]], %[[A]] : index
// CHECK: %[[T5:.*]] = select %[[T4]], %[[T3]], %[[C2]] : vector<1x7xi1>
// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[C3]] [0] : vector<1x7xi1> into vector<2x1x7xi1>
// CHECK: %[[T7:.*]] = cmpi "slt", %[[c1]], %[[A]] : index
// CHECK: %[[T8:.*]] = select %[[T7]], %[[T3]], %[[C2]] : vector<1x7xi1>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [1] : vector<1x7xi1> into vector<2x1x7xi1>
// CHECK: return %[[T9]] : vector<2x1x7xi1>
func @genbool_var_3d(%arg0: index, %arg1: index, %arg2: index) -> vector<2x1x7xi1> {
%0 = vector.create_mask %arg0, %arg1, %arg2 : vector<2x1x7xi1>
return %0 : vector<2x1x7xi1>
}
#matmat_accesses_0 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,