forked from OSchip/llvm-project
422 lines
17 KiB
C++
422 lines
17 KiB
C++
//===- VectorOps.cpp - MLIR Super Vectorizer Operations -------------------===//
|
|
//
|
|
// Copyright 2019 The MLIR Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
// =============================================================================
|
|
//
|
|
// This file implements convenience types for working with super-vectorization
|
|
// operations, in particular super-vector loads and stores.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/VectorOps/VectorOps.h"
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// VectorOpsDialect
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Creates the dialect, passing "vector" and `context` to the base Dialect,
/// and registers the three operations implemented in this file.
VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
    : Dialect("vector", context) {
  addOperations<VectorTransferReadOp, VectorTransferWriteOp,
                VectorTypeCastOp>();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// VectorTransferReadOp
|
|
//===----------------------------------------------------------------------===//
|
|
template <typename EmitFun>
|
|
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
|
|
EmitFun emitOpError) {
|
|
SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
|
|
for (auto expr : permutationMap.getResults()) {
|
|
auto dim = expr.dyn_cast<AffineDimExpr>();
|
|
auto zero = expr.dyn_cast<AffineConstantExpr>();
|
|
if (zero) {
|
|
if (zero.getValue() != 0) {
|
|
return emitOpError(
|
|
"requires a projected permutation_map (at most one dim or the zero "
|
|
"constant can appear in each result)");
|
|
}
|
|
continue;
|
|
}
|
|
if (!dim) {
|
|
return emitOpError("requires a projected permutation_map (at most one "
|
|
"dim or the zero constant can appear in each result)");
|
|
}
|
|
if (seen[dim.getPosition()]) {
|
|
return emitOpError(
|
|
"requires a permutation_map that is a permutation (found one dim "
|
|
"used more than once)");
|
|
}
|
|
seen[dim.getPosition()] = true;
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Populates `result` for a transfer read.
///
/// Operands are appended in the order: memref, indices, then the padding value
/// when one is supplied. The permutation map is stored as an attribute under
/// getPermutationMapAttrName(), and `vectorType` becomes the result type.
void VectorTransferReadOp::build(Builder *builder, OperationState *result,
                                 VectorType vectorType, Value *srcMemRef,
                                 ArrayRef<Value *> srcIndices,
                                 AffineMap permutationMap,
                                 Optional<Value *> paddingValue) {
  result->addOperands(srcMemRef);
  result->addOperands(srcIndices);
  // The padding value is optional and, when present, is the last operand.
  if (paddingValue.hasValue())
    result->addOperands({paddingValue.getValue()});

  auto mapAttr = builder->getAffineMapAttr(permutationMap);
  result->addAttribute(getPermutationMapAttrName(), mapAttr);
  result->addTypes(vectorType);
}
|
|
|
|
/// Returns the operand range that starts at Offsets::FirstIndexOffset and
/// spans as many operands as the rank of the memref type.
auto VectorTransferReadOp::getIndices() -> operand_range {
  auto firstIndex =
      getOperation()->operand_begin() + Offsets::FirstIndexOffset;
  return {firstIndex, firstIndex + getMemRefType().getRank()};
}
|
|
|
|
/// Returns the operand that follows the indices, or None when the op has no
/// operand at that position.
Optional<Value *> VectorTransferReadOp::getPaddingValue() {
  // Position of the first operand past the memref and its `rank` indices.
  auto paddingPos = Offsets::FirstIndexOffset + getMemRefType().getRank();
  if (getNumOperands() > paddingPos)
    return Optional<Value *>(getOperand(paddingPos));
  return None;
}
|
|
|
|
/// Returns the affine map held by the attribute named
/// getPermutationMapAttrName(). The attribute is dereferenced without a null
/// check, so it must be present and be an AffineMapAttr.
AffineMap VectorTransferReadOp::getPermutationMap() {
  auto mapAttr = getAttrOfType<AffineMapAttr>(getPermutationMapAttrName());
  return mapAttr.getValue();
}
|
|
|
|
/// Prints the op as:
///   name memref `[` indices `]` (`, (` padding `)`)? attr-dict
///        ` : ` memref-type `, ` result-type
void VectorTransferReadOp::print(OpAsmPrinter *p) {
  *p << getOperationName() << " ";
  p->printOperand(getMemRef());

  *p << "[";
  p->printOperands(getIndices());
  *p << "]";

  // The padding value is only printed when the op carries one.
  if (auto padding = getPaddingValue()) {
    *p << ", (";
    p->printOperand(*padding);
    *p << ")";
  }

  p->printOptionalAttrDict(getAttrs());
  *p << " : " << getMemRefType() << ", " << getResultType();
}
|
|
|
|
/// Parses a transfer read: a memref operand, a square-bracketed index list, an
/// optional parenthesized padding operand, an optional attribute dictionary,
/// and a colon followed by exactly two types (memref type, then vector type).
///
/// On success the resolved operands are appended to `result` in the order
/// memref, indices, padding (if any), and the vector type is added as the
/// result type.
ParseResult VectorTransferReadOp::parse(OpAsmParser *parser,
                                        OperationState *result) {
  OpAsmParser::OperandType memrefOperand;
  SmallVector<OpAsmParser::OperandType, 8> indexOperands;
  SmallVector<OpAsmParser::OperandType, 8> paddingOperands;
  SmallVector<Type, 2> parsedTypes;

  // Syntax. The padding value is collected in its own operand list.
  if (parser->parseOperand(memrefOperand))
    return failure();
  if (parser->parseOperandList(indexOperands, -1,
                               OpAsmParser::Delimiter::Square))
    return failure();
  if (parser->parseTrailingOperandList(paddingOperands, -1,
                                       OpAsmParser::Delimiter::Paren))
    return failure();
  if (parser->parseOptionalAttributeDict(result->attributes))
    return failure();
  if (parser->parseColonTypeList(parsedTypes))
    return failure();

  // Type checks: exactly a memref type followed by a vector type.
  if (parsedTypes.size() != 2)
    return parser->emitError(parser->getNameLoc(), "expected 2 types");
  MemRefType memrefType = parsedTypes[0].dyn_cast<MemRefType>();
  if (!memrefType)
    return parser->emitError(parser->getNameLoc(), "memRef type expected");
  VectorType vectorType = parsedTypes[1].dyn_cast<VectorType>();
  if (!vectorType)
    return parser->emitError(parser->getNameLoc(), "vector type expected");

  // Operand counts: one index per memref dimension, at most one padding value.
  if (static_cast<int64_t>(indexOperands.size()) != memrefType.getRank())
    return parser->emitError(parser->getNameLoc(),
                             "expected " + Twine(memrefType.getRank()) +
                                 " indices to the memref");
  if (paddingOperands.size() > 1)
    return parser->emitError(parser->getNameLoc(),
                             "expected at most one padding value");

  // Resolution, in operand order. The padding value is resolved against the
  // element type of the vector.
  auto indexType = parser->getBuilder().getIndexType();
  if (parser->resolveOperand(memrefOperand, memrefType, result->operands))
    return failure();
  if (parser->resolveOperands(indexOperands, indexType, result->operands))
    return failure();
  if (!paddingOperands.empty()) {
    Type paddingType = vectorType.getElementType();
    if (parser->resolveOperand(paddingOperands[0], paddingType,
                               result->operands))
      return failure();
  }
  if (parser->addTypeToList(vectorType, result->types))
    return failure();
  return success();
}
|
|
|
|
/// Verifies a transfer read. Checks are performed in this order and the first
/// failure is reported:
///   1. there is at least one operand and the memref operand has memref type;
///   2. the result has vector type;
///   3. memref and vector share the same element type;
///   4. the operand count is the memref, `rank` indices, and at most one
///      padding value;
///   5. the padding value, if any, has the vector's element type;
///   6. every index has `index` type;
///   7. the permutation map attribute exists, has no symbols, takes `rank`
///      inputs, yields one result per vector dimension, and passes
///      verifyPermutationMap().
LogicalResult VectorTransferReadOp::verify() {
  // Consistency of memref type in function type.
  if (llvm::empty(getOperands())) {
    return emitOpError(
        "requires at least a memref operand followed by 'rank' indices");
  }
  if (!getMemRef()->getType().isa<MemRefType>()) {
    return emitOpError("requires a memref as first operand");
  }
  // Consistency of vector type in function type.
  if (!getResult()->getType().isa<VectorType>()) {
    return emitOpError("should have a vector result type in function type: "
                       "memref_type<...xelemental_type>, vector_type");
  }
  // Consistency of elemental types in memref and vector.
  MemRefType memrefType = getMemRefType();
  VectorType vectorType = getResultType();
  if (memrefType.getElementType() != vectorType.getElementType())
    return emitOpError(
        "requires memref and vector types of the same elemental type");
  // Consistency of number of input types.
  auto optionalPaddingValue = getPaddingValue();
  unsigned expectedNumOperands = Offsets::FirstIndexOffset +
                                 memrefType.getRank() +
                                 (optionalPaddingValue ? 1 : 0);
  // Checks on the actual operands and their types.
  if (getNumOperands() != expectedNumOperands) {
    return emitOpError("expects ")
           << expectedNumOperands << " operands (of which "
           << memrefType.getRank() << " indices)";
  }
  // Consistency of padding value with vector type.
  if (optionalPaddingValue) {
    auto paddingValue = *optionalPaddingValue;
    auto elementalType = paddingValue->getType();
    if (!VectorType::isValidElementType(elementalType)) {
      return emitOpError("requires valid padding vector elemental type");
    }
    if (elementalType != vectorType.getElementType()) {
      return emitOpError(
          "requires formal padding and vector of the same elemental type");
    }
  }
  // Consistency of indices types. The operand-count check above already
  // guarantees that getIndices() spans exactly `rank` operands, so no separate
  // count of the indices is needed here.
  for (auto *idx : getIndices()) {
    if (!idx->getType().isIndex()) {
      return emitOpError(
          "index to vector.transfer_read must have 'index' type");
    }
  }

  // Consistency of AffineMap attribute.
  if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
    return emitOpError("requires an AffineMapAttr named 'permutation_map'");
  }
  auto permutationMap = getPermutationMap();
  if (permutationMap.getNumSymbols() != 0) {
    return emitOpError("requires a permutation_map without symbols");
  }
  if (permutationMap.getNumInputs() != memrefType.getRank()) {
    return emitOpError("requires a permutation_map with input dims of the "
                       "same rank as the memref type");
  }
  if (permutationMap.getNumResults() != vectorType.getRank()) {
    // Close the parenthesis opened in the message so the diagnostic reads
    // "... (N vs M)".
    return emitOpError("requires a permutation_map with result dims of the "
                       "same rank as the vector type (")
           << permutationMap.getNumResults() << " vs " << vectorType.getRank()
           << ")";
  }
  return verifyPermutationMap(permutationMap,
                              [this](Twine t) { return emitOpError(t); });
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// VectorTransferWriteOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Populates `result` for a transfer write.
///
/// Operands are appended in the order: vector to store, destination memref,
/// then the indices. The permutation map is stored as an attribute under
/// getPermutationMapAttrName(). No result type is added.
void VectorTransferWriteOp::build(Builder *builder, OperationState *result,
                                  Value *srcVector, Value *dstMemRef,
                                  ArrayRef<Value *> dstIndices,
                                  AffineMap permutationMap) {
  result->addOperands({srcVector, dstMemRef});
  result->addOperands(dstIndices);

  auto mapAttr = builder->getAffineMapAttr(permutationMap);
  result->addAttribute(getPermutationMapAttrName(), mapAttr);
}
|
|
|
|
/// Returns the operand range that starts at Offsets::FirstIndexOffset and
/// spans as many operands as the rank of the memref type.
auto VectorTransferWriteOp::getIndices() -> operand_range {
  auto firstIndex =
      getOperation()->operand_begin() + Offsets::FirstIndexOffset;
  return {firstIndex, firstIndex + getMemRefType().getRank()};
}
|
|
|
|
/// Returns the affine map held by the attribute named
/// getPermutationMapAttrName(). The attribute is dereferenced without a null
/// check, so it must be present and be an AffineMapAttr.
AffineMap VectorTransferWriteOp::getPermutationMap() {
  auto mapAttr = getAttrOfType<AffineMapAttr>(getPermutationMapAttrName());
  return mapAttr.getValue();
}
|
|
|
|
/// Prints the op as:
///   name vector `, ` memref `[` indices `]` attr-dict
///        ` : ` vector-type `, ` memref-type
void VectorTransferWriteOp::print(OpAsmPrinter *p) {
  *p << getOperationName() << " " << *getVector() << ", " << *getMemRef()
     << "[";
  p->printOperands(getIndices());
  *p << "]";

  p->printOptionalAttrDict(getAttrs());

  // Types are listed in operand order: the vector first, then the memref.
  *p << " : ";
  p->printType(getVectorType());
  *p << ", ";
  p->printType(getMemRefType());
}
|
|
|
|
/// Parses a transfer write: the vector operand, a comma, the memref operand, a
/// square-bracketed index list, an optional attribute dictionary, and a colon
/// followed by exactly two types (vector type, then memref type).
///
/// On success the resolved operands are appended to `result` in the order
/// vector, memref, indices. The number of indices is not checked here.
ParseResult VectorTransferWriteOp::parse(OpAsmParser *parser,
                                         OperationState *result) {
  OpAsmParser::OperandType vectorOperand;
  OpAsmParser::OperandType memrefOperand;
  SmallVector<OpAsmParser::OperandType, 4> indexOperands;
  SmallVector<Type, 2> parsedTypes;

  // Syntax.
  if (parser->parseOperand(vectorOperand))
    return failure();
  if (parser->parseComma())
    return failure();
  if (parser->parseOperand(memrefOperand))
    return failure();
  if (parser->parseOperandList(indexOperands, -1,
                               OpAsmParser::Delimiter::Square))
    return failure();
  if (parser->parseOptionalAttributeDict(result->attributes))
    return failure();
  if (parser->parseColonTypeList(parsedTypes))
    return failure();

  // Type checks: exactly a vector type and a memref type, at the positions
  // given by the operand offsets.
  if (parsedTypes.size() != 2)
    return parser->emitError(parser->getNameLoc(), "expected 2 types");
  VectorType vectorType =
      parsedTypes[Offsets::VectorOffset].dyn_cast<VectorType>();
  if (!vectorType)
    return parser->emitError(parser->getNameLoc(), "vector type expected");
  MemRefType memrefType =
      parsedTypes[Offsets::MemRefOffset].dyn_cast<MemRefType>();
  if (!memrefType)
    return parser->emitError(parser->getNameLoc(), "memRef type expected");

  // Resolution, in operand order.
  auto indexType = parser->getBuilder().getIndexType();
  if (parser->resolveOperands(vectorOperand, vectorType, result->operands))
    return failure();
  if (parser->resolveOperands(memrefOperand, memrefType, result->operands))
    return failure();
  if (parser->resolveOperands(indexOperands, indexType, result->operands))
    return failure();
  return success();
}
|
|
|
|
/// Verifies a transfer write. Checks are performed in this order and the first
/// failure is reported:
///   1. both leading operands (vector and memref) are present;
///   2. the memref operand has memref type and the vector operand has vector
///      type;
///   3. memref and vector share the same element type;
///   4. the operand count is the two leading operands plus `rank` indices;
///   5. every index has `index` type;
///   6. the permutation map attribute exists, has no symbols, takes `rank`
///      inputs, yields one result per vector dimension, and passes
///      verifyPermutationMap().
LogicalResult VectorTransferWriteOp::verify() {
  // Consistency of memref type in function type.
  // build() places the vector and the memref ahead of the indices, so both
  // must be present before getMemRef()/getVector() are used below. Checking
  // only for an empty operand list would let a single-operand op through and
  // have the memref accessor read past the operand list.
  if (getNumOperands() < Offsets::FirstIndexOffset) {
    return emitOpError(
        "requires at least a memref operand followed by 'rank' indices");
  }
  if (!getMemRef()->getType().isa<MemRefType>()) {
    return emitOpError("requires a memref first operand");
  }
  // Consistency of vector type in function type.
  if (!getVector()->getType().isa<VectorType>()) {
    return emitOpError("should have a vector input type in function type: "
                       "(vector_type, memref_type [, elemental_type]) -> ()");
  }
  // Consistency of elemental types in memref and vector.
  MemRefType memrefType = getMemRefType();
  VectorType vectorType = getVectorType();
  if (memrefType.getElementType() != vectorType.getElementType())
    return emitOpError(
        "requires memref and vector types of the same elemental type");
  // Consistency of number of input types.
  unsigned expectedNumOperands =
      Offsets::FirstIndexOffset + memrefType.getRank();
  // Checks on the actual operands and their types.
  if (getNumOperands() != expectedNumOperands) {
    return emitOpError() << "expects " << expectedNumOperands
                         << " operands (of which " << memrefType.getRank()
                         << " indices)";
  }
  // Consistency of indices types. The operand-count check above already
  // guarantees that getIndices() spans exactly `rank` operands, so no separate
  // count of the indices is needed here.
  for (auto *idx : getIndices()) {
    if (!idx->getType().isIndex()) {
      return emitOpError(
          "index to vector.transfer_write must have 'index' type");
    }
  }

  // Consistency of AffineMap attribute.
  if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
    return emitOpError("requires an AffineMapAttr named 'permutation_map'");
  }
  auto permutationMap = getPermutationMap();
  if (permutationMap.getNumSymbols() != 0) {
    return emitOpError("requires a permutation_map without symbols");
  }
  if (permutationMap.getNumInputs() != memrefType.getRank()) {
    return emitOpError("requires a permutation_map with input dims of the "
                       "same rank as the memref type");
  }
  if (permutationMap.getNumResults() != vectorType.getRank()) {
    // Close the parenthesis opened in the message so the diagnostic reads
    // "... (N vs M)".
    return emitOpError("requires a permutation_map with result dims of the "
                       "same rank as the vector type (")
           << permutationMap.getNumResults() << " vs " << vectorType.getRank()
           << ")";
  }
  return verifyPermutationMap(permutationMap,
                              [this](Twine t) { return emitOpError(t); });
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// VectorTypeCastOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Populates `result` with `srcVector` as the single operand and `dstType` as
/// the single result type. `builder` is not used.
void VectorTypeCastOp::build(Builder *builder, OperationState *result,
                             Value *srcVector, Type dstType) {
  result->addOperands(srcVector);
  result->addTypes(dstType);
}
|
|
|
|
/// Parses a type cast: one operand, an optional attribute dictionary, a colon
/// and the source type, a comma, and the destination type.
///
/// The destination type is added to the result types before the operand is
/// resolved against the source type.
ParseResult VectorTypeCastOp::parse(OpAsmParser *parser,
                                    OperationState *result) {
  OpAsmParser::OperandType source;
  Type sourceType, targetType;

  // Syntax.
  if (parser->parseOperand(source))
    return failure();
  if (parser->parseOptionalAttributeDict(result->attributes))
    return failure();
  if (parser->parseColonType(sourceType))
    return failure();
  if (parser->parseComma())
    return failure();
  if (parser->parseType(targetType))
    return failure();

  // Result type and operand resolution.
  if (parser->addTypeToList(targetType, result->types))
    return failure();
  if (parser->resolveOperand(source, sourceType, result->operands))
    return failure();
  return success();
}
|
|
|
|
/// Prints the op as:
///   name operand attr-dict ` : ` operand-type `, ` result-type
///
/// The attribute dictionary is emitted between the operand and the colon,
/// which is where parse() accepts it; without it, attributes attached to the
/// op would be dropped when the op is printed and parsed again. Nothing is
/// printed for an op without attributes.
void VectorTypeCastOp::print(OpAsmPrinter *p) {
  *p << getOperationName() << ' ' << *getOperand();
  p->printOptionalAttrDict(getAttrs());
  *p << " : " << getOperand()->getType() << ", " << getType();
}
|
|
|
|
/// Verifies a type cast. The result must be a statically shaped memref whose
/// element type is a vector, and the operand must have memref type. Checks run
/// in that order and the first failure is reported.
LogicalResult VectorTypeCastOp::verify() {
  // Result side: memref of vector with a static shape.
  auto resultMemref = getType().dyn_cast<MemRefType>();
  if (!resultMemref)
    return emitOpError("expects target type to be a memref type");
  if (!resultMemref.getElementType().isa<VectorType>())
    return emitOpError(
        "expects vector as an element of the target memref type");
  if (!resultMemref.hasStaticShape())
    return emitOpError("does not support dynamic shapes");

  // Operand side: any memref.
  auto sourceType = getOperand()->getType();
  if (!sourceType.isa<MemRefType>())
    return emitOpError("expects source type to be a memref type");

  return success();
}
|