forked from OSchip/llvm-project
496 lines
20 KiB
C++
496 lines
20 KiB
C++
//===- SuperVectorOps.cpp - MLIR Super Vectorizer Operations---------------===//
|
|
//
|
|
// Copyright 2019 The MLIR Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
// =============================================================================
|
|
//
|
|
// This file implements convenience types for working with super-vectorization
|
|
// operations, in particular super-vector loads and stores.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/SuperVectorOps/SuperVectorOps.h"
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/Types.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SuperVectorOpsDialect
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Constructs the super-vector dialect with an empty operation prefix and
/// registers the three operations implemented in this file.
SuperVectorOpsDialect::SuperVectorOpsDialect(MLIRContext *context)
    : Dialect(/*opPrefix=*/"", context) {
  addOperations<VectorTransferReadOp,  // super-vector load
                VectorTransferWriteOp, // super-vector store
                VectorTypeCastOp>();   // memref-to-memref-of-vector cast
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// VectorTransferReadOp
|
|
//===----------------------------------------------------------------------===//
|
|
template <typename EmitFun>
|
|
static bool verifyPermutationMap(AffineMap permutationMap,
|
|
EmitFun emitOpError) {
|
|
SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
|
|
for (auto expr : permutationMap.getResults()) {
|
|
auto dim = expr.dyn_cast<AffineDimExpr>();
|
|
auto zero = expr.dyn_cast<AffineConstantExpr>();
|
|
if (zero) {
|
|
if (zero.getValue() != 0) {
|
|
return emitOpError(
|
|
"requires a projected permutation_map (at most one dim or the zero "
|
|
"constant can appear in each result)");
|
|
}
|
|
continue;
|
|
}
|
|
if (!dim) {
|
|
return emitOpError("requires a projected permutation_map (at most one "
|
|
"dim or the zero constant can appear in each result)");
|
|
}
|
|
if (seen[dim.getPosition()]) {
|
|
return emitOpError(
|
|
"requires a permutation_map that is a permutation (found one dim "
|
|
"used more than once)");
|
|
}
|
|
seen[dim.getPosition()] = true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Populates `result` for a vector_transfer_read.
///
/// Operands are appended in the order: source memref, the indices, then the
/// padding value when one is supplied. The permutation map is stored as an
/// affine-map attribute and `vectorType` becomes the single result type.
void VectorTransferReadOp::build(Builder *builder, OperationState *result,
                                 VectorType vectorType, Value *srcMemRef,
                                 ArrayRef<Value *> srcIndices,
                                 AffineMap permutationMap,
                                 Optional<Value *> paddingValue) {
  // Operand layout: memref, indices..., [padding].
  result->addOperands(srcMemRef);
  result->addOperands(srcIndices);
  if (paddingValue)
    result->addOperands({*paddingValue});

  // Attribute and result type.
  auto permutationMapAttr = builder->getAffineMapAttr(permutationMap);
  result->addAttribute(getPermutationMapAttrName(), permutationMapAttr);
  result->addTypes(vectorType);
}
|
|
|
|
/// Returns the range of index operands: `rank(memref)` consecutive operands
/// starting at position Offsets::FirstIndexOffset.
llvm::iterator_range<OperationInst::operand_iterator>
VectorTransferReadOp::getIndices() {
  auto firstIndex =
      getInstruction()->operand_begin() + Offsets::FirstIndexOffset;
  return {firstIndex, firstIndex + getMemRefType().getRank()};
}
|
|
|
|
/// Const overload: returns the range of index operands, i.e. `rank(memref)`
/// consecutive operands starting at position Offsets::FirstIndexOffset.
llvm::iterator_range<OperationInst::const_operand_iterator>
VectorTransferReadOp::getIndices() const {
  auto firstIndex =
      getInstruction()->operand_begin() + Offsets::FirstIndexOffset;
  return {firstIndex, firstIndex + getMemRefType().getRank()};
}
|
|
|
|
/// Returns the optional padding operand, which (when present) sits right
/// after the memref and its `rank` indices. Returns None when the operation
/// has no operand at that position.
Optional<Value *> VectorTransferReadOp::getPaddingValue() {
  auto paddingPosition = Offsets::FirstIndexOffset + getMemRefType().getRank();
  if (paddingPosition < getNumOperands())
    return Optional<Value *>(getOperand(paddingPosition));
  return None;
}
|
|
|
|
/// Const overload: returns the optional padding operand, which (when present)
/// sits right after the memref and its `rank` indices. Returns None when the
/// operation has no operand at that position.
Optional<const Value *> VectorTransferReadOp::getPaddingValue() const {
  auto paddingPosition = Offsets::FirstIndexOffset + getMemRefType().getRank();
  if (paddingPosition < getNumOperands())
    return Optional<const Value *>(getOperand(paddingPosition));
  return None;
}
|
|
|
|
/// Returns the affine map stored in the permutation-map attribute.
/// The attribute is dereferenced unconditionally; verify() is what checks
/// that it is present.
AffineMap VectorTransferReadOp::getPermutationMap() const {
  auto permutationMapAttr =
      getAttrOfType<AffineMapAttr>(getPermutationMapAttrName());
  return permutationMapAttr.getValue();
}
|
|
|
|
/// Prints the operation as:
///   <name> memref (`,` index)* (`,` padding)? attr-dict? `:` function-type
/// where the function type lists the memref type, one type per index, the
/// optional padding type, and the vector result type.
///
/// Every index is printed with its own leading separator and its own type, so
/// a memref of rank 0 (no indices) is printed without dereferencing a
/// non-existent first index and without leaving a dangling comma in the
/// operand list. For operations with at least one index whose indices all
/// share one type, the output is the same as printing the index list followed
/// by `rank` copies of the first index's type.
void VectorTransferReadOp::print(OpAsmPrinter *p) const {
  *p << getOperationName() << " ";
  p->printOperand(getMemRef());
  for (auto *idx : getIndices()) {
    *p << ", ";
    p->printOperand(idx);
  }
  auto optionalPaddingValue = getPaddingValue();
  if (optionalPaddingValue) {
    *p << ", ";
    p->printOperand(*optionalPaddingValue);
  }
  p->printOptionalAttrDict(getAttrs());

  // Construct the FunctionType and print it.
  MemRefType memrefType = getMemRefType();
  llvm::SmallVector<Type, 8> inputs{memrefType};
  for (auto *idx : getIndices()) {
    inputs.push_back(idx->getType());
  }
  if (optionalPaddingValue) {
    inputs.push_back((*optionalPaddingValue)->getType());
  }
  // The context is taken from the memref type, which exists even when there
  // are no indices.
  *p << " : "
     << FunctionType::get(inputs, {getResultType()}, memrefType.getContext());
}
|
|
|
|
/// Parses:
///   operand-list attr-dict? `:` function-type
/// where the function type is (memref, index..., [padding]) -> vector.
///
/// The function type must have a memref as its first input and exactly one
/// result, which must be a vector. The number of parsed operands must equal
/// the number of function inputs, and that number must be the memref plus
/// `rank` indices plus at most one padding value. Returns true (after
/// emitting a diagnostic where one is produced here) on failure.
bool VectorTransferReadOp::parse(OpAsmParser *parser, OperationState *result) {
  SmallVector<OpAsmParser::OperandType, 8> parsedOperands;
  Type type;

  // Parsing with support for optional paddingValue.
  auto fail = parser->parseOperandList(parsedOperands) ||
              parser->parseOptionalAttributeDict(result->attributes) ||
              parser->parseColonType(type);
  if (fail) {
    return true;
  }

  // Resolution.
  auto funType = type.dyn_cast<FunctionType>();
  if (!funType)
    return parser->emitError(parser->getNameLoc(), "Function type expected");
  if (funType.getNumInputs() < 1)
    return parser->emitError(parser->getNameLoc(),
                             "Function type expects at least one input");
  MemRefType memrefType =
      funType.getInput(Offsets::MemRefOffset).dyn_cast<MemRefType>();
  if (!memrefType)
    return parser->emitError(parser->getNameLoc(),
                             "MemRef type expected for first input");
  // Reject both zero results and surplus results: only the first result is
  // ever added to the operation, so extra ones would be silently dropped.
  if (funType.getNumResults() != 1)
    return parser->emitError(parser->getNameLoc(),
                             "Function type expects exactly one vector result");
  VectorType vectorType = funType.getResult(0).dyn_cast<VectorType>();
  if (!vectorType)
    return parser->emitError(parser->getNameLoc(),
                             "Vector type expected for first result");
  if (parsedOperands.size() != funType.getNumInputs())
    return parser->emitError(parser->getNameLoc(),
                             "requires " + Twine(funType.getNumInputs()) +
                                 " operands");

  // Extract optional paddingValue.
  // parsedOperands is non-empty here: its size equals the number of function
  // inputs, which was checked to be at least one.
  OpAsmParser::OperandType memrefInfo = parsedOperands[0];
  // At this point, indexInfo may contain the optional paddingValue, pop it out.
  SmallVector<OpAsmParser::OperandType, 8> indexInfo{
      parsedOperands.begin() + Offsets::FirstIndexOffset, parsedOperands.end()};
  Type paddingType;
  OpAsmParser::OperandType paddingValue;
  bool hasPaddingValue = indexInfo.size() > memrefType.getRank();
  unsigned expectedNumOperands = Offsets::FirstIndexOffset +
                                 memrefType.getRank() +
                                 (hasPaddingValue ? 1 : 0);
  if (hasPaddingValue) {
    paddingType = funType.getInputs().back();
    paddingValue = indexInfo.pop_back_val();
  }
  // Catches both too few indices and more than one trailing operand.
  if (funType.getNumInputs() != expectedNumOperands)
    return parser->emitError(
        parser->getNameLoc(),
        "requires actual number of operands to match function type");

  // Indices are always resolved against the builtin index type, regardless of
  // the types spelled in the function type.
  auto indexType = parser->getBuilder().getIndexType();
  return parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
         parser->resolveOperands(indexInfo, indexType, result->operands) ||
         (hasPaddingValue && parser->resolveOperand(paddingValue, paddingType,
                                                    result->operands)) ||
         parser->addTypeToList(vectorType, result->types);
}
|
|
|
|
bool VectorTransferReadOp::verify() const {
|
|
// Consistency of memref type in function type.
|
|
if (llvm::empty(getOperands())) {
|
|
return emitOpError(
|
|
"requires at least a memref operand followed by 'rank' indices");
|
|
}
|
|
if (!getMemRef()->getType().isa<MemRefType>()) {
|
|
return emitOpError("requires a memref as first operand");
|
|
}
|
|
// Consistency of vector type in function type.
|
|
if (!getResult()->getType().isa<VectorType>()) {
|
|
return emitOpError("should have a vector result type in function type: "
|
|
"(memref_type [, elemental_type]) -> vector_type");
|
|
}
|
|
// Consistency of elemental types in memref and vector.
|
|
MemRefType memrefType = getMemRefType();
|
|
VectorType vectorType = getResultType();
|
|
if (memrefType.getElementType() != vectorType.getElementType())
|
|
return emitOpError(
|
|
"requires memref and vector types of the same elemental type");
|
|
// Consistency of number of input types.
|
|
auto optionalPaddingValue = getPaddingValue();
|
|
unsigned expectedNumOperands = Offsets::FirstIndexOffset +
|
|
memrefType.getRank() +
|
|
(optionalPaddingValue ? 1 : 0);
|
|
// Checks on the actual operands and their types.
|
|
if (getNumOperands() != expectedNumOperands) {
|
|
return emitOpError("expects " + Twine(expectedNumOperands) +
|
|
" operands to match the types");
|
|
}
|
|
// Consistency of padding value with vector type.
|
|
if (optionalPaddingValue) {
|
|
auto paddingValue = *optionalPaddingValue;
|
|
auto elementalType = paddingValue->getType();
|
|
if (!VectorType::isValidElementType(elementalType)) {
|
|
return emitOpError("requires valid padding vector elemental type");
|
|
}
|
|
if (elementalType != vectorType.getElementType()) {
|
|
return emitOpError(
|
|
"requires formal padding and vector of the same elemental type");
|
|
}
|
|
}
|
|
// Consistency of indices types.
|
|
unsigned numIndices = 0;
|
|
for (auto *idx : getIndices()) {
|
|
if (!idx->getType().isIndex()) {
|
|
return emitOpError(
|
|
"index to vector_transfer_read must have 'index' type");
|
|
}
|
|
++numIndices;
|
|
}
|
|
if (numIndices != memrefType.getRank()) {
|
|
return emitOpError("requires at least a memref operand followed by " +
|
|
Twine(memrefType.getRank()) + " indices");
|
|
}
|
|
|
|
// Consistency of AffineMap attribute.
|
|
if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
|
|
return emitOpError("requires an AffineMapAttr named 'permutation_map'");
|
|
}
|
|
auto permutationMap = getPermutationMap();
|
|
if (!permutationMap.getRangeSizes().empty()) {
|
|
return emitOpError("requires an unbounded permutation_map");
|
|
}
|
|
if (permutationMap.getNumSymbols() != 0) {
|
|
return emitOpError("requires a permutation_map without symbols");
|
|
}
|
|
if (permutationMap.getNumInputs() != memrefType.getRank()) {
|
|
return emitOpError("requires a permutation_map with input dims of the "
|
|
"same rank as the memref type");
|
|
}
|
|
if (permutationMap.getNumResults() != vectorType.getRank()) {
|
|
return emitOpError("requires a permutation_map with result dims of the "
|
|
"same rank as the vector type (" +
|
|
Twine(permutationMap.getNumResults()) + " vs " +
|
|
Twine(vectorType.getRank()));
|
|
}
|
|
return verifyPermutationMap(permutationMap,
|
|
[this](Twine t) { return emitOpError(t); });
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// VectorTransferWriteOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Populates `result` for a vector_transfer_write.
///
/// Operands are appended in the order: source vector, destination memref,
/// then the indices. The permutation map is stored as an affine-map
/// attribute. No result type is added.
void VectorTransferWriteOp::build(Builder *builder, OperationState *result,
                                  Value *srcVector, Value *dstMemRef,
                                  ArrayRef<Value *> dstIndices,
                                  AffineMap permutationMap) {
  // Operand layout: vector, memref, indices...
  result->addOperands({srcVector, dstMemRef});
  result->addOperands(dstIndices);

  auto permutationMapAttr = builder->getAffineMapAttr(permutationMap);
  result->addAttribute(getPermutationMapAttrName(), permutationMapAttr);
}
|
|
|
|
/// Returns the range of index operands: `rank(memref)` consecutive operands
/// starting at position Offsets::FirstIndexOffset.
llvm::iterator_range<OperationInst::operand_iterator>
VectorTransferWriteOp::getIndices() {
  auto firstIndex =
      getInstruction()->operand_begin() + Offsets::FirstIndexOffset;
  return {firstIndex, firstIndex + getMemRefType().getRank()};
}
|
|
|
|
/// Const overload: returns the range of index operands, i.e. `rank(memref)`
/// consecutive operands starting at position Offsets::FirstIndexOffset.
llvm::iterator_range<OperationInst::const_operand_iterator>
VectorTransferWriteOp::getIndices() const {
  auto firstIndex =
      getInstruction()->operand_begin() + Offsets::FirstIndexOffset;
  return {firstIndex, firstIndex + getMemRefType().getRank()};
}
|
|
|
|
/// Returns the affine map stored in the permutation-map attribute.
/// The attribute is dereferenced unconditionally; verify() is what checks
/// that it is present.
AffineMap VectorTransferWriteOp::getPermutationMap() const {
  auto permutationMapAttr =
      getAttrOfType<AffineMapAttr>(getPermutationMapAttrName());
  return permutationMapAttr.getValue();
}
|
|
|
|
/// Prints the operation as:
///   <name> vector `,` memref (`,` index)* attr-dict? `:`
///       vector-type `,` memref-type (`,` index-type)*
///
/// Every index is printed with its own leading separator and its own type, so
/// a memref of rank 0 (no indices) is printed without dereferencing a
/// non-existent first index and without leaving a dangling comma after the
/// memref operand. For operations with at least one index whose indices all
/// share one type, the output is the same as printing the index list followed
/// by `rank` copies of the first index's type.
void VectorTransferWriteOp::print(OpAsmPrinter *p) const {
  *p << getOperationName();
  *p << " " << *getVector();
  *p << ", " << *getMemRef();
  for (auto *idx : getIndices()) {
    *p << ", " << *idx;
  }
  p->printOptionalAttrDict(getAttrs());
  *p << " : ";
  p->printType(getVectorType());
  *p << ", ";
  p->printType(getMemRefType());
  for (auto *idx : getIndices()) {
    *p << ", ";
    p->printType(idx->getType());
  }
}
|
|
|
|
/// Parses:
///   operand-list attr-dict? `:` type-list
/// where the operands are (vector, memref, index...) and the type list has
/// one entry per operand. The first type must be a vector, the second a
/// memref, and the operand count must be exactly two plus the memref rank.
/// Returns true on failure.
bool VectorTransferWriteOp::parse(OpAsmParser *parser, OperationState *result) {
  SmallVector<OpAsmParser::OperandType, 8> operandInfos;
  SmallVector<Type, 8> operandTypes;

  // Shorthand: report `message` at the operation name location.
  auto emitParseError = [parser](const Twine &message) {
    return parser->emitError(parser->getNameLoc(), message);
  };

  // Raw syntax: operands, optional attributes, then the colon type list.
  if (parser->parseOperandList(operandInfos) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonTypeList(operandTypes))
    return true;

  // Shape checks; these also make the indexed accesses below safe.
  if (operandInfos.size() != operandTypes.size())
    return emitParseError(
        "requires number of operands and input types to match");
  if (operandInfos.size() < Offsets::FirstIndexOffset)
    return emitParseError("requires at least vector and memref operands");

  // Leading types: the vector being written, then the destination memref.
  VectorType vectorType =
      operandTypes[Offsets::VectorOffset].dyn_cast<VectorType>();
  if (!vectorType)
    return emitParseError("Vector type expected for first input type");
  MemRefType memrefType =
      operandTypes[Offsets::MemRefOffset].dyn_cast<MemRefType>();
  if (!memrefType)
    return emitParseError("MemRef type expected for second input type");

  // One index per memref dimension must follow the vector and the memref.
  unsigned requiredNumOperands =
      Offsets::FirstIndexOffset + memrefType.getRank();
  if (operandInfos.size() != requiredNumOperands)
    return emitParseError("requires " + Twine(requiredNumOperands) +
                          " operands");

  // Resolve in operand order. Indices are resolved against the builtin index
  // type rather than the types spelled in the type list.
  OpAsmParser::OperandType vectorInfo = operandInfos[Offsets::VectorOffset];
  OpAsmParser::OperandType memrefInfo = operandInfos[Offsets::MemRefOffset];
  SmallVector<OpAsmParser::OperandType, 8> indexInfos{
      operandInfos.begin() + Offsets::FirstIndexOffset, operandInfos.end()};
  auto indexType = parser->getBuilder().getIndexType();
  if (parser->resolveOperand(vectorInfo, vectorType, result->operands))
    return true;
  if (parser->resolveOperand(memrefInfo, memrefType, result->operands))
    return true;
  return parser->resolveOperands(indexInfos, indexType, result->operands);
}
|
|
|
|
bool VectorTransferWriteOp::verify() const {
|
|
// Consistency of memref type in function type.
|
|
if (llvm::empty(getOperands())) {
|
|
return emitOpError(
|
|
"requires at least a memref operand followed by 'rank' indices");
|
|
}
|
|
if (!getMemRef()->getType().isa<MemRefType>()) {
|
|
return emitOpError("requires a memref first operand");
|
|
}
|
|
// Consistency of vector type in function type.
|
|
if (!getVector()->getType().isa<VectorType>()) {
|
|
return emitOpError("should have a vector input type in function type: "
|
|
"(vector_type, memref_type [, elemental_type]) -> ()");
|
|
}
|
|
// Consistency of elemental types in memref and vector.
|
|
MemRefType memrefType = getMemRefType();
|
|
VectorType vectorType = getVectorType();
|
|
if (memrefType.getElementType() != vectorType.getElementType())
|
|
return emitOpError(
|
|
"requires memref and vector types of the same elemental type");
|
|
// Consistency of number of input types.
|
|
unsigned expectedNumOperands =
|
|
Offsets::FirstIndexOffset + memrefType.getRank();
|
|
// Checks on the actual operands and their types.
|
|
if (getNumOperands() != expectedNumOperands) {
|
|
return emitOpError("expects " + Twine(expectedNumOperands) +
|
|
" operands to match the types");
|
|
}
|
|
// Consistency of indices types.
|
|
unsigned numIndices = 0;
|
|
for (auto *idx : getIndices()) {
|
|
if (!idx->getType().isIndex()) {
|
|
return emitOpError(
|
|
"index to vector_transfer_write must have 'index' type");
|
|
}
|
|
numIndices++;
|
|
}
|
|
if (numIndices != memrefType.getRank()) {
|
|
return emitOpError("requires at least a memref operand followed by " +
|
|
Twine(memrefType.getRank()) + " indices");
|
|
}
|
|
|
|
// Consistency of AffineMap attribute.
|
|
if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
|
|
return emitOpError("requires an AffineMapAttr named 'permutation_map'");
|
|
}
|
|
auto permutationMap = getPermutationMap();
|
|
if (!permutationMap.getRangeSizes().empty()) {
|
|
return emitOpError("requires an unbounded permutation_map");
|
|
}
|
|
if (permutationMap.getNumSymbols() != 0) {
|
|
return emitOpError("requires a permutation_map without symbols");
|
|
}
|
|
if (permutationMap.getNumInputs() != memrefType.getRank()) {
|
|
return emitOpError("requires a permutation_map with input dims of the "
|
|
"same rank as the memref type");
|
|
}
|
|
if (permutationMap.getNumResults() != vectorType.getRank()) {
|
|
return emitOpError("requires a permutation_map with result dims of the "
|
|
"same rank as the vector type (" +
|
|
Twine(permutationMap.getNumResults()) + " vs " +
|
|
Twine(vectorType.getRank()));
|
|
}
|
|
return verifyPermutationMap(permutationMap,
|
|
[this](Twine t) { return emitOpError(t); });
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// VectorTypeCastOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Populates `result` for a vector type cast: `srcVector` is the only operand
/// and `dstType` is the only result type. `builder` is not used.
void VectorTypeCastOp::build(Builder *builder, OperationState *result,
                             Value *srcVector, Type dstType) {
  // Single operand in, single type out.
  result->addOperands(srcVector);
  result->addTypes(dstType);
}
|
|
|
|
/// Parses:
///   operand attr-dict? `:` source-type `,` destination-type
/// The destination type becomes the result type and the operand is resolved
/// against the source type. Returns true on failure.
bool VectorTypeCastOp::parse(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType sourceInfo;
  Type sourceType, resultType;

  // Syntactic part: stop at the first element that fails to parse.
  if (parser->parseOperand(sourceInfo) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(sourceType) || parser->parseComma() ||
      parser->parseType(resultType))
    return true;

  // Record the result type first, then resolve the operand.
  if (parser->addTypeToList(resultType, result->types))
    return true;
  return parser->resolveOperand(sourceInfo, sourceType, result->operands);
}
|
|
|
|
/// Prints the operation as:
///   <name> operand attr-dict? `:` source-type `,` destination-type
/// The attribute dictionary is emitted in the position where parse() accepts
/// it, so attributes attached to the operation are not lost when the printed
/// form is parsed back.
void VectorTypeCastOp::print(OpAsmPrinter *p) const {
  *p << getOperationName() << ' ' << *getOperand();
  p->printOptionalAttrDict(getAttrs());
  *p << " : " << getOperand()->getType() << ", " << getType();
}
|
|
|
|
bool VectorTypeCastOp::verify() const {
|
|
auto dstMemrefType = getType().dyn_cast<MemRefType>();
|
|
if (!dstMemrefType)
|
|
return emitOpError("expects target type to be a memref type");
|
|
auto dstVectorType = dstMemrefType.getElementType().dyn_cast<VectorType>();
|
|
if (!dstVectorType)
|
|
return emitOpError(
|
|
"expects vector as an element of the target memref type");
|
|
if (llvm::any_of(dstMemrefType.getShape(), [](int s) { return s == -1; }))
|
|
return emitOpError("does not support dynamic shapes");
|
|
|
|
if (!getOperand()->getType().isa<MemRefType>())
|
|
return emitOpError("expects source type to be a memref type");
|
|
|
|
return false;
|
|
}
|