llvm-project/mlir/lib/VectorOps/VectorOps.cpp

422 lines
17 KiB
C++

//===- VectorOps.cpp - MLIR Super Vectorizer Operations -------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements convenience types for working with super-vectorization
// operations, in particular super-vector loads and stores.
//
//===----------------------------------------------------------------------===//
#include "mlir/VectorOps/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Support/LLVM.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// VectorOpsDialect
//===----------------------------------------------------------------------===//
/// Creates the dialect under the "vector" namespace in `context` and registers
/// the operations implemented in this file with it.
VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
    : Dialect("vector", context) {
  addOperations<VectorTransferReadOp,
                VectorTransferWriteOp,
                VectorTypeCastOp>();
}
//===----------------------------------------------------------------------===//
// VectorTransferReadOp
//===----------------------------------------------------------------------===//
/// Checks that `permutationMap` is a projected permutation: each result must
/// be either the constant zero or a single dimension, and no dimension may be
/// used by more than one result.
///
/// Diagnostics are produced by calling `emitOpError` with the message; its
/// return value is what this function returns on failure.
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                          EmitFun emitOpError) {
  // One flag per map input, raised once the matching dim appears in a result.
  SmallVector<bool, 8> dimAlreadyUsed(permutationMap.getNumInputs(), false);

  for (auto resultExpr : permutationMap.getResults()) {
    // Constants are accepted only when they are exactly zero.
    if (auto constantExpr = resultExpr.dyn_cast<AffineConstantExpr>()) {
      if (constantExpr.getValue() == 0)
        continue;
      return emitOpError(
          "requires a projected permutation_map (at most one dim or the zero "
          "constant can appear in each result)");
    }

    // Anything that is neither a constant nor a plain dim is rejected.
    auto dimExpr = resultExpr.dyn_cast<AffineDimExpr>();
    if (!dimExpr)
      return emitOpError(
          "requires a projected permutation_map (at most one dim or the zero "
          "constant can appear in each result)");

    // A dim may be referenced by a single result only.
    auto position = dimExpr.getPosition();
    if (dimAlreadyUsed[position])
      return emitOpError(
          "requires a permutation_map that is a permutation (found one dim "
          "used more than once)");
    dimAlreadyUsed[position] = true;
  }

  return success();
}
/// Fills `result` with the operands, attribute and result type of a transfer
/// read.
///
/// Operands are appended in the order: `srcMemRef`, every value in
/// `srcIndices`, then `paddingValue` when one is provided. `permutationMap` is
/// stored as an affine map attribute under getPermutationMapAttrName(), and
/// `vectorType` becomes the single result type.
void VectorTransferReadOp::build(Builder *builder, OperationState *result,
                                 VectorType vectorType, Value *srcMemRef,
                                 ArrayRef<Value *> srcIndices,
                                 AffineMap permutationMap,
                                 Optional<Value *> paddingValue) {
  result->addOperands(srcMemRef);
  result->addOperands(srcIndices);
  // The padding value is optional and always comes last.
  if (paddingValue.hasValue())
    result->addOperands({paddingValue.getValue()});

  auto permutationMapAttr = builder->getAffineMapAttr(permutationMap);
  result->addAttribute(getPermutationMapAttrName(), permutationMapAttr);

  result->addTypes(vectorType);
}
/// Returns the range of index operands: the memref-rank many operands that
/// start at position Offsets::FirstIndexOffset.
auto VectorTransferReadOp::getIndices() -> operand_range {
  auto firstIndex =
      getOperation()->operand_begin() + Offsets::FirstIndexOffset;
  return {firstIndex, firstIndex + getMemRefType().getRank()};
}
/// Returns the operand located right after the index operands, or None when
/// the operation has no operand at that position.
Optional<Value *> VectorTransferReadOp::getPaddingValue() {
  // Position just past the last index operand.
  auto paddingPosition = Offsets::FirstIndexOffset + getMemRefType().getRank();
  if (getNumOperands() > paddingPosition)
    return Optional<Value *>(getOperand(paddingPosition));
  return None;
}
/// Returns the affine map held by the attribute named
/// getPermutationMapAttrName(). The attribute is dereferenced without a null
/// check here; verify() is what reports a missing attribute.
AffineMap VectorTransferReadOp::getPermutationMap() {
  auto mapAttr = getAttrOfType<AffineMapAttr>(getPermutationMapAttrName());
  return mapAttr.getValue();
}
/// Prints the operation in its custom form:
///   <op-name> <memref>[<indices>] (`, (` <padding> `)`)? <attr-dict>?
///       : <memref-type>, <vector-type>
void VectorTransferReadOp::print(OpAsmPrinter *p) {
  *p << getOperationName() << " ";
  p->printOperand(getMemRef());

  // Index operands, wrapped in square brackets.
  *p << "[";
  p->printOperands(getIndices());
  *p << "]";

  // The padding value, when present, is printed in parentheses.
  if (auto padding = getPaddingValue()) {
    *p << ", (";
    p->printOperand(*padding);
    *p << ")";
  }

  p->printOptionalAttrDict(getAttrs());
  *p << " : " << getMemRefType() << ", " << getResultType();
}
/// Parses the custom form emitted by print():
///   <memref>[<indices>] (`, (` <padding> `)`)? <attr-dict>?
///       : <memref-type>, <vector-type>
///
/// On success `result` receives the memref, index and (optional) padding
/// operands, the parsed attributes, and the vector type as result type.
ParseResult VectorTransferReadOp::parse(OpAsmParser *parser,
                                        OperationState *result) {
  OpAsmParser::OperandType memrefOperand;
  SmallVector<OpAsmParser::OperandType, 8> indexOperands;
  SmallVector<OpAsmParser::OperandType, 8> paddingOperands;
  SmallVector<Type, 2> parsedTypes;

  // Syntax only. The optional padding value is collected in its own
  // parenthesized trailing list, separate from the bracketed indices.
  if (parser->parseOperand(memrefOperand) ||
      parser->parseOperandList(indexOperands, -1,
                               OpAsmParser::Delimiter::Square) ||
      parser->parseTrailingOperandList(paddingOperands, -1,
                                       OpAsmParser::Delimiter::Paren) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonTypeList(parsedTypes))
    return failure();

  // Exactly two types are expected: the memref type, then the vector type.
  auto nameLoc = parser->getNameLoc();
  if (parsedTypes.size() != 2)
    return parser->emitError(nameLoc, "expected 2 types");
  MemRefType memrefType = parsedTypes[0].dyn_cast<MemRefType>();
  if (!memrefType)
    return parser->emitError(nameLoc, "memRef type expected");
  VectorType vectorType = parsedTypes[1].dyn_cast<VectorType>();
  if (!vectorType)
    return parser->emitError(nameLoc, "vector type expected");

  // There must be one index per memref dimension and at most one padding
  // value.
  if (static_cast<int64_t>(indexOperands.size()) != memrefType.getRank())
    return parser->emitError(nameLoc,
                             "expected " + Twine(memrefType.getRank()) +
                                 " indices to the memref");
  if (paddingOperands.size() > 1)
    return parser->emitError(nameLoc, "expected at most one padding value");

  // Resolve operands in operand order: memref, indices, then padding. The
  // padding value is resolved against the vector element type.
  auto indexType = parser->getBuilder().getIndexType();
  if (parser->resolveOperand(memrefOperand, memrefType, result->operands) ||
      parser->resolveOperands(indexOperands, indexType, result->operands))
    return failure();
  if (!paddingOperands.empty() &&
      parser->resolveOperand(paddingOperands[0], vectorType.getElementType(),
                             result->operands))
    return failure();
  if (parser->addTypeToList(vectorType, result->types))
    return failure();
  return success();
}
/// Verifies the structural invariants of a transfer read:
///  - there is at least one operand and the memref operand has memref type;
///  - the result has vector type with the same element type as the memref;
///  - the operand count is memref + rank indices (+ one padding value);
///  - the padding value, if any, has the vector element type;
///  - every index operand has `index` type;
///  - the permutation map attribute exists, has no symbols, takes memref-rank
///    inputs, produces vector-rank results and is a projected permutation.
/// Returns success() when all checks pass, otherwise the emitted error.
LogicalResult VectorTransferReadOp::verify() {
  // Consistency of memref type in function type.
  if (llvm::empty(getOperands())) {
    return emitOpError(
        "requires at least a memref operand followed by 'rank' indices");
  }
  if (!getMemRef()->getType().isa<MemRefType>()) {
    return emitOpError("requires a memref as first operand");
  }
  // Consistency of vector type in function type.
  if (!getResult()->getType().isa<VectorType>()) {
    return emitOpError("should have a vector result type in function type: "
                       "memref_type<...xelemental_type>, vector_type");
  }
  // Consistency of elemental types in memref and vector.
  MemRefType memrefType = getMemRefType();
  VectorType vectorType = getResultType();
  if (memrefType.getElementType() != vectorType.getElementType())
    return emitOpError(
        "requires memref and vector types of the same elemental type");
  // Consistency of number of input types.
  auto optionalPaddingValue = getPaddingValue();
  unsigned expectedNumOperands = Offsets::FirstIndexOffset +
                                 memrefType.getRank() +
                                 (optionalPaddingValue ? 1 : 0);
  // Checks on the actual operands and their types.
  if (getNumOperands() != expectedNumOperands) {
    return emitOpError("expects ")
           << expectedNumOperands << " operands (of which "
           << memrefType.getRank() << " indices)";
  }
  // Consistency of padding value with vector type.
  if (optionalPaddingValue) {
    auto paddingValue = *optionalPaddingValue;
    auto elementalType = paddingValue->getType();
    if (!VectorType::isValidElementType(elementalType)) {
      return emitOpError("requires valid padding vector elemental type");
    }
    if (elementalType != vectorType.getElementType()) {
      return emitOpError(
          "requires formal padding and vector of the same elemental type");
    }
  }
  // Consistency of indices types.
  unsigned numIndices = 0;
  for (auto *idx : getIndices()) {
    if (!idx->getType().isIndex()) {
      return emitOpError(
          "index to vector.transfer_read must have 'index' type");
    }
    ++numIndices;
  }
  if (numIndices != memrefType.getRank()) {
    return emitOpError("requires at least a memref operand followed by ")
           << memrefType.getRank() << " indices";
  }
  // Consistency of AffineMap attribute.
  if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
    return emitOpError("requires an AffineMapAttr named 'permutation_map'");
  }
  auto permutationMap = getPermutationMap();
  if (permutationMap.getNumSymbols() != 0) {
    return emitOpError("requires a permutation_map without symbols");
  }
  if (permutationMap.getNumInputs() != memrefType.getRank()) {
    return emitOpError("requires a permutation_map with input dims of the "
                       "same rank as the memref type");
  }
  if (permutationMap.getNumResults() != vectorType.getRank()) {
    // The message opens a parenthesis before the two ranks; close it so the
    // diagnostic is well-formed.
    return emitOpError("requires a permutation_map with result dims of the "
                       "same rank as the vector type (")
           << permutationMap.getNumResults() << " vs " << vectorType.getRank()
           << ")";
  }
  return verifyPermutationMap(permutationMap,
                              [this](Twine t) { return emitOpError(t); });
}
//===----------------------------------------------------------------------===//
// VectorTransferWriteOp
//===----------------------------------------------------------------------===//
/// Fills `result` with the operands and attribute of a transfer write.
///
/// Operands are appended in the order: `srcVector`, `dstMemRef`, then every
/// value in `dstIndices`. `permutationMap` is stored as an affine map
/// attribute under getPermutationMapAttrName(). No result type is added.
void VectorTransferWriteOp::build(Builder *builder, OperationState *result,
                                  Value *srcVector, Value *dstMemRef,
                                  ArrayRef<Value *> dstIndices,
                                  AffineMap permutationMap) {
  result->addOperands({srcVector, dstMemRef});
  result->addOperands(dstIndices);

  auto permutationMapAttr = builder->getAffineMapAttr(permutationMap);
  result->addAttribute(getPermutationMapAttrName(), permutationMapAttr);
}
/// Returns the range of index operands: the memref-rank many operands that
/// start at position Offsets::FirstIndexOffset.
auto VectorTransferWriteOp::getIndices() -> operand_range {
  auto firstIndex =
      getOperation()->operand_begin() + Offsets::FirstIndexOffset;
  return {firstIndex, firstIndex + getMemRefType().getRank()};
}
/// Returns the affine map held by the attribute named
/// getPermutationMapAttrName(). The attribute is dereferenced without a null
/// check here; verify() is what reports a missing attribute.
AffineMap VectorTransferWriteOp::getPermutationMap() {
  auto mapAttr = getAttrOfType<AffineMapAttr>(getPermutationMapAttrName());
  return mapAttr.getValue();
}
/// Prints the operation in its custom form:
///   <op-name> <vector>, <memref>[<indices>] <attr-dict>?
///       : <vector-type>, <memref-type>
void VectorTransferWriteOp::print(OpAsmPrinter *p) {
  *p << getOperationName() << " " << *getVector() << ", " << *getMemRef();

  // Index operands, wrapped in square brackets.
  *p << "[";
  p->printOperands(getIndices());
  *p << "]";

  p->printOptionalAttrDict(getAttrs());

  // Types are listed in operand order: vector first, memref second.
  *p << " : ";
  p->printType(getVectorType());
  *p << ", ";
  p->printType(getMemRefType());
}
/// Parses the custom form emitted by print():
///   <vector>, <memref>[<indices>] <attr-dict>? : <vector-type>, <memref-type>
///
/// On success `result` receives the vector, memref and index operands (in
/// that order) plus the parsed attributes. The number of indices is not
/// checked against the memref rank here.
ParseResult VectorTransferWriteOp::parse(OpAsmParser *parser,
                                         OperationState *result) {
  OpAsmParser::OperandType vectorOperand;
  OpAsmParser::OperandType memrefOperand;
  SmallVector<OpAsmParser::OperandType, 4> indexOperands;
  SmallVector<Type, 2> parsedTypes;

  // Syntax only; operand types are resolved further down.
  if (parser->parseOperand(vectorOperand) || parser->parseComma() ||
      parser->parseOperand(memrefOperand) ||
      parser->parseOperandList(indexOperands, -1,
                               OpAsmParser::Delimiter::Square) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonTypeList(parsedTypes))
    return failure();

  // Exactly two types are expected, positioned like the leading operands.
  auto nameLoc = parser->getNameLoc();
  if (parsedTypes.size() != 2)
    return parser->emitError(nameLoc, "expected 2 types");
  VectorType vectorType =
      parsedTypes[Offsets::VectorOffset].dyn_cast<VectorType>();
  if (!vectorType)
    return parser->emitError(nameLoc, "vector type expected");
  MemRefType memrefType =
      parsedTypes[Offsets::MemRefOffset].dyn_cast<MemRefType>();
  if (!memrefType)
    return parser->emitError(nameLoc, "memRef type expected");

  // Resolve operands in operand order: vector, memref, then indices.
  auto indexType = parser->getBuilder().getIndexType();
  if (parser->resolveOperands(vectorOperand, vectorType, result->operands) ||
      parser->resolveOperands(memrefOperand, memrefType, result->operands) ||
      parser->resolveOperands(indexOperands, indexType, result->operands))
    return failure();
  return success();
}
/// Verifies the structural invariants of a transfer write:
///  - the leading operands that precede the indices are all present, the
///    memref operand has memref type and the vector operand has vector type;
///  - memref and vector share the same element type;
///  - the operand count is the leading operands plus memref-rank indices;
///  - every index operand has `index` type;
///  - the permutation map attribute exists, has no symbols, takes memref-rank
///    inputs, produces vector-rank results and is a projected permutation.
/// Returns success() when all checks pass, otherwise the emitted error.
LogicalResult VectorTransferWriteOp::verify() {
  // Consistency of memref type in function type.
  // build() places the vector operand first and the memref operand second, so
  // checking only for an empty operand list is not enough: with a single
  // operand there is no memref operand for getMemRef() to return. Require all
  // operands that precede the indices before touching any of them.
  if (getNumOperands() < Offsets::FirstIndexOffset) {
    return emitOpError(
        "requires at least a memref operand followed by 'rank' indices");
  }
  if (!getMemRef()->getType().isa<MemRefType>()) {
    return emitOpError("requires a memref first operand");
  }
  // Consistency of vector type in function type.
  if (!getVector()->getType().isa<VectorType>()) {
    return emitOpError("should have a vector input type in function type: "
                       "(vector_type, memref_type [, elemental_type]) -> ()");
  }
  // Consistency of elemental types in memref and vector.
  MemRefType memrefType = getMemRefType();
  VectorType vectorType = getVectorType();
  if (memrefType.getElementType() != vectorType.getElementType())
    return emitOpError(
        "requires memref and vector types of the same elemental type");
  // Consistency of number of input types.
  unsigned expectedNumOperands =
      Offsets::FirstIndexOffset + memrefType.getRank();
  // Checks on the actual operands and their types.
  if (getNumOperands() != expectedNumOperands) {
    return emitOpError() << "expects " << expectedNumOperands
                         << " operands (of which " << memrefType.getRank()
                         << " indices)";
  }
  // Consistency of indices types.
  unsigned numIndices = 0;
  for (auto *idx : getIndices()) {
    if (!idx->getType().isIndex()) {
      return emitOpError(
          "index to vector.transfer_write must have 'index' type");
    }
    numIndices++;
  }
  if (numIndices != memrefType.getRank()) {
    return emitOpError("requires at least a memref operand followed by ")
           << memrefType.getRank() << " indices";
  }
  // Consistency of AffineMap attribute.
  if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
    return emitOpError("requires an AffineMapAttr named 'permutation_map'");
  }
  auto permutationMap = getPermutationMap();
  if (permutationMap.getNumSymbols() != 0) {
    return emitOpError("requires a permutation_map without symbols");
  }
  if (permutationMap.getNumInputs() != memrefType.getRank()) {
    return emitOpError("requires a permutation_map with input dims of the "
                       "same rank as the memref type");
  }
  if (permutationMap.getNumResults() != vectorType.getRank()) {
    // The message opens a parenthesis before the two ranks; close it so the
    // diagnostic is well-formed.
    return emitOpError("requires a permutation_map with result dims of the "
                       "same rank as the vector type (")
           << permutationMap.getNumResults() << " vs " << vectorType.getRank()
           << ")";
  }
  return verifyPermutationMap(permutationMap,
                              [this](Twine t) { return emitOpError(t); });
}
//===----------------------------------------------------------------------===//
// VectorTypeCastOp
//===----------------------------------------------------------------------===//
/// Fills `result` for a type cast: `srcVector` is the only operand and
/// `dstType` the only result type. `builder` is not used.
void VectorTypeCastOp::build(Builder *builder, OperationState *result,
                             Value *srcVector, Type dstType) {
  // Single operand in, single result type out.
  result->addOperands(srcVector);
  result->addTypes(dstType);
}
/// Parses the custom form:
///   <operand> <attr-dict>? : <source-type>, <result-type>
///
/// On success `result` receives the parsed attributes, the result type and
/// the operand resolved against the source type.
ParseResult VectorTypeCastOp::parse(OpAsmParser *parser,
                                    OperationState *result) {
  OpAsmParser::OperandType source;
  Type sourceType, resultType;

  // Syntax first.
  if (parser->parseOperand(source) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(sourceType) || parser->parseComma() ||
      parser->parseType(resultType))
    return failure();

  // Then record the result type and resolve the operand, in that order.
  if (parser->addTypeToList(resultType, result->types) ||
      parser->resolveOperand(source, sourceType, result->operands))
    return failure();
  return success();
}
/// Prints the operation in its custom form:
///   <op-name> <operand> <attr-dict>? : <source-type>, <result-type>
void VectorTypeCastOp::print(OpAsmPrinter *p) {
  *p << getOperationName() << ' ' << *getOperand();
  // parse() accepts an optional attribute dictionary right after the operand;
  // emit it at the same position so attributes are not dropped when the
  // printed form is parsed back.
  p->printOptionalAttrDict(getAttrs());
  *p << " : " << getOperand()->getType() << ", " << getType();
}
/// Verifies the types of a type cast: the result must be a statically shaped
/// memref whose element type is a vector, and the operand must be a memref.
/// Returns success() when all checks pass, otherwise the emitted error.
LogicalResult VectorTypeCastOp::verify() {
  // Target side: a memref of vectors with a fully static shape.
  auto targetType = getType().dyn_cast<MemRefType>();
  if (!targetType) {
    return emitOpError("expects target type to be a memref type");
  }
  if (!targetType.getElementType().isa<VectorType>()) {
    return emitOpError(
        "expects vector as an element of the target memref type");
  }
  if (!targetType.hasStaticShape()) {
    return emitOpError("does not support dynamic shapes");
  }

  // Source side: any memref type is accepted.
  if (!getOperand()->getType().isa<MemRefType>()) {
    return emitOpError("expects source type to be a memref type");
  }
  return success();
}