forked from OSchip/llvm-project
[mlir] VectorToSCF cleanup
Group functions/structs in namespaces for better code readability. Depends On D102123 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D102124
This commit is contained in:
parent
23a84e1c60
commit
a088bed4e3
mlir/lib/Conversion/VectorToSCF
|
@ -49,52 +49,6 @@ struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
|
||||||
VectorTransferToSCFOptions options;
|
VectorTransferToSCFOptions options;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Given a MemRefType with VectorType element type, unpack one dimension from
|
|
||||||
/// the VectorType into the MemRefType.
|
|
||||||
///
|
|
||||||
/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
|
|
||||||
static MemRefType unpackOneDim(MemRefType type) {
|
|
||||||
auto vectorType = type.getElementType().dyn_cast<VectorType>();
|
|
||||||
auto memrefShape = type.getShape();
|
|
||||||
SmallVector<int64_t, 8> newMemrefShape;
|
|
||||||
newMemrefShape.append(memrefShape.begin(), memrefShape.end());
|
|
||||||
newMemrefShape.push_back(vectorType.getDimSize(0));
|
|
||||||
return MemRefType::get(newMemrefShape,
|
|
||||||
VectorType::get(vectorType.getShape().drop_front(),
|
|
||||||
vectorType.getElementType()));
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Helper data structure for data and mask buffers.
|
|
||||||
struct BufferAllocs {
|
|
||||||
Value dataBuffer;
|
|
||||||
Value maskBuffer;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Allocate temporary buffers for data (vector) and mask (if present).
|
|
||||||
/// TODO: Parallelism and threadlocal considerations.
|
|
||||||
template <typename OpTy>
|
|
||||||
static BufferAllocs allocBuffers(OpTy xferOp) {
|
|
||||||
auto &b = ScopedContext::getBuilderRef();
|
|
||||||
OpBuilder::InsertionGuard guard(b);
|
|
||||||
Operation *scope =
|
|
||||||
xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
|
|
||||||
assert(scope && "Expected op to be inside automatic allocation scope");
|
|
||||||
b.setInsertionPointToStart(&scope->getRegion(0).front());
|
|
||||||
|
|
||||||
BufferAllocs result;
|
|
||||||
auto bufferType = MemRefType::get({}, xferOp.getVectorType());
|
|
||||||
result.dataBuffer = memref_alloca(bufferType).value;
|
|
||||||
|
|
||||||
if (xferOp.mask()) {
|
|
||||||
auto maskType = MemRefType::get({}, xferOp.mask().getType());
|
|
||||||
Value maskBuffer = memref_alloca(maskType);
|
|
||||||
memref_store(xferOp.mask(), maskBuffer);
|
|
||||||
result.maskBuffer = memref_load(maskBuffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Given a vector transfer op, calculate which dimension of the `source`
|
/// Given a vector transfer op, calculate which dimension of the `source`
|
||||||
/// memref should be unpacked in the next application of TransferOpConversion.
|
/// memref should be unpacked in the next application of TransferOpConversion.
|
||||||
/// A return value of None indicates a broadcast.
|
/// A return value of None indicates a broadcast.
|
||||||
|
@ -284,6 +238,54 @@ static void maybeApplyPassLabel(OpBuilder &builder, OpTy newXferOp,
|
||||||
newXferOp->setAttr(kPassLabel, builder.getUnitAttr());
|
newXferOp->setAttr(kPassLabel, builder.getUnitAttr());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace lowering_n_d {
|
||||||
|
|
||||||
|
/// Helper data structure for data and mask buffers.
|
||||||
|
struct BufferAllocs {
|
||||||
|
Value dataBuffer;
|
||||||
|
Value maskBuffer;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Allocate temporary buffers for data (vector) and mask (if present).
|
||||||
|
/// TODO: Parallelism and threadlocal considerations.
|
||||||
|
template <typename OpTy>
|
||||||
|
static BufferAllocs allocBuffers(OpTy xferOp) {
|
||||||
|
auto &b = ScopedContext::getBuilderRef();
|
||||||
|
OpBuilder::InsertionGuard guard(b);
|
||||||
|
Operation *scope =
|
||||||
|
xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
|
||||||
|
assert(scope && "Expected op to be inside automatic allocation scope");
|
||||||
|
b.setInsertionPointToStart(&scope->getRegion(0).front());
|
||||||
|
|
||||||
|
BufferAllocs result;
|
||||||
|
auto bufferType = MemRefType::get({}, xferOp.getVectorType());
|
||||||
|
result.dataBuffer = memref_alloca(bufferType).value;
|
||||||
|
|
||||||
|
if (xferOp.mask()) {
|
||||||
|
auto maskType = MemRefType::get({}, xferOp.mask().getType());
|
||||||
|
auto maskBuffer = memref_alloca(maskType).value;
|
||||||
|
memref_store(xferOp.mask(), maskBuffer);
|
||||||
|
result.maskBuffer = memref_load(maskBuffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Given a MemRefType with VectorType element type, unpack one dimension from
|
||||||
|
/// the VectorType into the MemRefType.
|
||||||
|
///
|
||||||
|
/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
|
||||||
|
static MemRefType unpackOneDim(MemRefType type) {
|
||||||
|
auto vectorType = type.getElementType().dyn_cast<VectorType>();
|
||||||
|
auto memrefShape = type.getShape();
|
||||||
|
SmallVector<int64_t, 8> newMemrefShape;
|
||||||
|
newMemrefShape.append(memrefShape.begin(), memrefShape.end());
|
||||||
|
newMemrefShape.push_back(vectorType.getDimSize(0));
|
||||||
|
return MemRefType::get(newMemrefShape,
|
||||||
|
VectorType::get(vectorType.getShape().drop_front(),
|
||||||
|
vectorType.getElementType()));
|
||||||
|
}
|
||||||
|
|
||||||
/// Given a transfer op, find the memref from which the mask is loaded. This
|
/// Given a transfer op, find the memref from which the mask is loaded. This
|
||||||
/// is similar to Strategy<TransferWriteOp>::getBuffer.
|
/// is similar to Strategy<TransferWriteOp>::getBuffer.
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
|
@ -688,6 +690,10 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
} // namespace lowering_n_d
|
||||||
|
|
||||||
|
namespace lowering_n_d_unrolled {
|
||||||
|
|
||||||
/// If the original transfer op has a mask, compute the mask of the new transfer
|
/// If the original transfer op has a mask, compute the mask of the new transfer
|
||||||
/// op (for the current iteration `i`) and assign it.
|
/// op (for the current iteration `i`) and assign it.
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
|
@ -954,6 +960,10 @@ struct UnrollTransferWriteConversion
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
} // namespace lowering_n_d_unrolled
|
||||||
|
|
||||||
|
namespace lowering_1_d {
|
||||||
|
|
||||||
/// Compute the indices into the memref for the LoadOp/StoreOp generated as
|
/// Compute the indices into the memref for the LoadOp/StoreOp generated as
|
||||||
/// part of TransferOp1dConversion. Return the memref dimension on which
|
/// part of TransferOp1dConversion. Return the memref dimension on which
|
||||||
/// the transfer is operating. A return value of None indicates a broadcast.
|
/// the transfer is operating. A return value of None indicates a broadcast.
|
||||||
|
@ -1114,6 +1124,7 @@ struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
} // namespace lowering_1_d
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
@ -1121,19 +1132,21 @@ namespace mlir {
|
||||||
void populateVectorToSCFConversionPatterns(
|
void populateVectorToSCFConversionPatterns(
|
||||||
RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
|
RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
|
||||||
if (options.unroll) {
|
if (options.unroll) {
|
||||||
patterns.add<UnrollTransferReadConversion, UnrollTransferWriteConversion>(
|
patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
|
||||||
|
lowering_n_d_unrolled::UnrollTransferWriteConversion>(
|
||||||
patterns.getContext(), options);
|
patterns.getContext(), options);
|
||||||
} else {
|
} else {
|
||||||
patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion,
|
patterns.add<lowering_n_d::PrepareTransferReadConversion,
|
||||||
TransferOpConversion<TransferReadOp>,
|
lowering_n_d::PrepareTransferWriteConversion,
|
||||||
TransferOpConversion<TransferWriteOp>>(patterns.getContext(),
|
lowering_n_d::TransferOpConversion<TransferReadOp>,
|
||||||
options);
|
lowering_n_d::TransferOpConversion<TransferWriteOp>>(
|
||||||
|
patterns.getContext(), options);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (options.targetRank == 1) {
|
if (options.targetRank == 1) {
|
||||||
patterns.add<TransferOp1dConversion<TransferReadOp>,
|
patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
|
||||||
TransferOp1dConversion<TransferWriteOp>>(patterns.getContext(),
|
lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
|
||||||
options);
|
patterns.getContext(), options);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue