[RISCV][NFC] Refactor RISC-V vector intrinsic utils.

This patch is preparation for D111617, use class/struct/enum rather than char/StringRef to present internal information as possible, that provide more compact way to store those info and also easier to serialize/deserialize.

And also that improve readability of the code, e.g. "v" vs TypeProfile::Vector.

Reviewed By: khchen

Differential Revision: https://reviews.llvm.org/D124730
This commit is contained in:
Kito Cheng 2022-05-11 23:39:13 +08:00
parent c71f6376eb
commit 7ff0bf576b
3 changed files with 691 additions and 221 deletions

View File

@ -9,7 +9,10 @@
#ifndef CLANG_SUPPORT_RISCVVINTRINSICUTILS_H
#define CLANG_SUPPORT_RISCVVINTRINSICUTILS_H
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <cstdint>
#include <string>
@ -18,9 +21,133 @@
namespace clang {
namespace RISCV {
using BasicType = char;
using VScaleVal = llvm::Optional<unsigned>;
// Modifier for vector type.
enum class VectorTypeModifier : uint8_t {
NoModifier,
Widening2XVector,
Widening4XVector,
Widening8XVector,
MaskVector,
Log2EEW3,
Log2EEW4,
Log2EEW5,
Log2EEW6,
FixedSEW8,
FixedSEW16,
FixedSEW32,
FixedSEW64,
LFixedLog2LMULN3,
LFixedLog2LMULN2,
LFixedLog2LMULN1,
LFixedLog2LMUL0,
LFixedLog2LMUL1,
LFixedLog2LMUL2,
LFixedLog2LMUL3,
SFixedLog2LMULN3,
SFixedLog2LMULN2,
SFixedLog2LMULN1,
SFixedLog2LMUL0,
SFixedLog2LMUL1,
SFixedLog2LMUL2,
SFixedLog2LMUL3,
};
// Similar to basic type but used to describe what's kind of type related to
// basic vector type, used to compute type info of arguments.
enum class BaseTypeModifier : uint8_t {
Invalid,
Scalar,
Vector,
Void,
SizeT,
Ptrdiff,
UnsignedLong,
SignedLong,
};
// Modifier for type, used for both scalar and vector types.
enum class TypeModifier : uint8_t {
NoModifier = 0,
Pointer = 1 << 0,
Const = 1 << 1,
Immediate = 1 << 2,
UnsignedInteger = 1 << 3,
SignedInteger = 1 << 4,
Float = 1 << 5,
// LMUL1 should be kind of VectorTypeModifier, but that might come with
// Widening2XVector for widening reduction.
// However that might require VectorTypeModifier become bitmask rather than
// simple enum, so we decide keek LMUL1 in TypeModifier for code size
// optimization of clang binary size.
LMUL1 = 1 << 6,
MaxOffset = 6,
LLVM_MARK_AS_BITMASK_ENUM(LMUL1),
};
// PrototypeDescriptor is used to compute type info of arguments or return
// value.
struct PrototypeDescriptor {
constexpr PrototypeDescriptor() = default;
constexpr PrototypeDescriptor(
BaseTypeModifier PT,
VectorTypeModifier VTM = VectorTypeModifier::NoModifier,
TypeModifier TM = TypeModifier::NoModifier)
: PT(static_cast<uint8_t>(PT)), VTM(static_cast<uint8_t>(VTM)),
TM(static_cast<uint8_t>(TM)) {}
constexpr PrototypeDescriptor(uint8_t PT, uint8_t VTM, uint8_t TM)
: PT(PT), VTM(VTM), TM(TM) {}
uint8_t PT = static_cast<uint8_t>(BaseTypeModifier::Invalid);
uint8_t VTM = static_cast<uint8_t>(VectorTypeModifier::NoModifier);
uint8_t TM = static_cast<uint8_t>(TypeModifier::NoModifier);
bool operator!=(const PrototypeDescriptor &PD) const {
return PD.PT != PT || PD.VTM != VTM || PD.TM != TM;
}
bool operator>(const PrototypeDescriptor &PD) const {
return !(PD.PT <= PT && PD.VTM <= VTM && PD.TM <= TM);
}
static const PrototypeDescriptor Mask;
static const PrototypeDescriptor Vector;
static const PrototypeDescriptor VL;
static llvm::Optional<PrototypeDescriptor>
parsePrototypeDescriptor(llvm::StringRef PrototypeStr);
};
llvm::SmallVector<PrototypeDescriptor>
parsePrototypes(llvm::StringRef Prototypes);
// Basic type of vector type.
enum class BasicType : uint8_t {
Unknown = 0,
Int8 = 1 << 0,
Int16 = 1 << 1,
Int32 = 1 << 2,
Int64 = 1 << 3,
Float16 = 1 << 4,
Float32 = 1 << 5,
Float64 = 1 << 6,
MaxOffset = 6,
LLVM_MARK_AS_BITMASK_ENUM(Float64),
};
// Type of vector type.
enum ScalarTypeKind : uint8_t {
Void,
Size_t,
Ptrdiff_t,
UnsignedLong,
SignedLong,
Boolean,
SignedInteger,
UnsignedInteger,
Float,
Invalid,
};
// Exponential LMUL
struct LMULType {
int Log2LMUL;
@ -32,20 +159,12 @@ struct LMULType {
LMULType &operator*=(uint32_t RHS);
};
class RVVType;
using RVVTypePtr = RVVType *;
using RVVTypes = std::vector<RVVTypePtr>;
// This class is compact representation of a valid and invalid RVVType.
class RVVType {
enum ScalarTypeKind : uint32_t {
Void,
Size_t,
Ptrdiff_t,
UnsignedLong,
SignedLong,
Boolean,
SignedInteger,
UnsignedInteger,
Float,
Invalid,
};
BasicType BT;
ScalarTypeKind ScalarType = Invalid;
LMULType LMUL;
@ -63,9 +182,11 @@ class RVVType {
std::string Str;
std::string ShortStr;
enum class FixedLMULType { LargerThan, SmallerThan };
public:
RVVType() : RVVType(BasicType(), 0, llvm::StringRef()) {}
RVVType(BasicType BT, int Log2LMUL, llvm::StringRef prototype);
RVVType() : BT(BasicType::Unknown), LMUL(0), Valid(false) {}
RVVType(BasicType BT, int Log2LMUL, const PrototypeDescriptor &Profile);
// Return the string representation of a type, which is an encoded string for
// passing to the BUILTIN() macro in Builtins.def.
@ -114,7 +235,11 @@ private:
// Applies a prototype modifier to the current type. The result maybe an
// invalid type.
void applyModifier(llvm::StringRef prototype);
void applyModifier(const PrototypeDescriptor &prototype);
void applyLog2EEW(unsigned Log2EEW);
void applyFixedSEW(unsigned NewSEW);
void applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type);
// Compute and record a string for legal type.
void initBuiltinStr();
@ -124,10 +249,19 @@ private:
void initTypeStr();
// Compute and record a short name of a type for C/C++ name suffix.
void initShortStr();
public:
/// Compute output and input types by applying different config (basic type
/// and LMUL with type transformers). It also record result of type in legal
/// or illegal set to avoid compute the same config again. The result maybe
/// have illegal RVVType.
static llvm::Optional<RVVTypes>
computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
llvm::ArrayRef<PrototypeDescriptor> PrototypeSeq);
static llvm::Optional<RVVTypePtr> computeType(BasicType BT, int Log2LMUL,
PrototypeDescriptor Proto);
};
using RVVTypePtr = RVVType *;
using RVVTypes = std::vector<RVVTypePtr>;
using RISCVPredefinedMacroT = uint8_t;
enum RISCVPredefinedMacro : RISCVPredefinedMacroT {
@ -206,6 +340,10 @@ public:
// Return the type string for a BUILTIN() macro in Builtins.def.
std::string getBuiltinTypeStr() const;
static std::string getSuffixStr(
BasicType Type, int Log2LMUL,
const llvm::SmallVector<PrototypeDescriptor> &PrototypeDescriptors);
};
} // end namespace RISCV

View File

@ -16,12 +16,25 @@
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
#include <set>
#include <unordered_map>
using namespace llvm;
namespace clang {
namespace RISCV {
const PrototypeDescriptor PrototypeDescriptor::Mask = PrototypeDescriptor(
BaseTypeModifier::Vector, VectorTypeModifier::MaskVector);
const PrototypeDescriptor PrototypeDescriptor::VL =
PrototypeDescriptor(BaseTypeModifier::SizeT);
const PrototypeDescriptor PrototypeDescriptor::Vector =
PrototypeDescriptor(BaseTypeModifier::Vector);
// Concat BasicType, LMUL and Proto as key
static std::unordered_map<uint64_t, RVVType> LegalTypes;
static std::set<uint64_t> IllegalTypes;
//===----------------------------------------------------------------------===//
// Type implementation
//===----------------------------------------------------------------------===//
@ -70,7 +83,8 @@ LMULType &LMULType::operator*=(uint32_t RHS) {
return *this;
}
RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype)
RVVType::RVVType(BasicType BT, int Log2LMUL,
const PrototypeDescriptor &prototype)
: BT(BT), LMUL(LMULType(Log2LMUL)) {
applyBasicType();
applyModifier(prototype);
@ -326,31 +340,31 @@ void RVVType::initShortStr() {
void RVVType::applyBasicType() {
switch (BT) {
case 'c':
case BasicType::Int8:
ElementBitwidth = 8;
ScalarType = ScalarTypeKind::SignedInteger;
break;
case 's':
case BasicType::Int16:
ElementBitwidth = 16;
ScalarType = ScalarTypeKind::SignedInteger;
break;
case 'i':
case BasicType::Int32:
ElementBitwidth = 32;
ScalarType = ScalarTypeKind::SignedInteger;
break;
case 'l':
case BasicType::Int64:
ElementBitwidth = 64;
ScalarType = ScalarTypeKind::SignedInteger;
break;
case 'x':
case BasicType::Float16:
ElementBitwidth = 16;
ScalarType = ScalarTypeKind::Float;
break;
case 'f':
case BasicType::Float32:
ElementBitwidth = 32;
ScalarType = ScalarTypeKind::Float;
break;
case 'd':
case BasicType::Float64:
ElementBitwidth = 64;
ScalarType = ScalarTypeKind::Float;
break;
@ -360,160 +374,481 @@ void RVVType::applyBasicType() {
assert(ElementBitwidth != 0 && "Bad element bitwidth!");
}
void RVVType::applyModifier(StringRef Transformer) {
if (Transformer.empty())
return;
// Handle primitive type transformer
auto PType = Transformer.back();
Optional<PrototypeDescriptor> PrototypeDescriptor::parsePrototypeDescriptor(
llvm::StringRef PrototypeDescriptorStr) {
PrototypeDescriptor PD;
BaseTypeModifier PT = BaseTypeModifier::Invalid;
VectorTypeModifier VTM = VectorTypeModifier::NoModifier;
if (PrototypeDescriptorStr.empty())
return PD;
// Handle base type modifier
auto PType = PrototypeDescriptorStr.back();
switch (PType) {
case 'e':
Scale = 0;
PT = BaseTypeModifier::Scalar;
break;
case 'v':
Scale = LMUL.getScale(ElementBitwidth);
PT = BaseTypeModifier::Vector;
break;
case 'w':
ElementBitwidth *= 2;
LMUL *= 2;
Scale = LMUL.getScale(ElementBitwidth);
PT = BaseTypeModifier::Vector;
VTM = VectorTypeModifier::Widening2XVector;
break;
case 'q':
ElementBitwidth *= 4;
LMUL *= 4;
Scale = LMUL.getScale(ElementBitwidth);
PT = BaseTypeModifier::Vector;
VTM = VectorTypeModifier::Widening4XVector;
break;
case 'o':
ElementBitwidth *= 8;
LMUL *= 8;
Scale = LMUL.getScale(ElementBitwidth);
PT = BaseTypeModifier::Vector;
VTM = VectorTypeModifier::Widening8XVector;
break;
case 'm':
ScalarType = ScalarTypeKind::Boolean;
Scale = LMUL.getScale(ElementBitwidth);
ElementBitwidth = 1;
PT = BaseTypeModifier::Vector;
VTM = VectorTypeModifier::MaskVector;
break;
case '0':
ScalarType = ScalarTypeKind::Void;
PT = BaseTypeModifier::Void;
break;
case 'z':
ScalarType = ScalarTypeKind::Size_t;
PT = BaseTypeModifier::SizeT;
break;
case 't':
ScalarType = ScalarTypeKind::Ptrdiff_t;
PT = BaseTypeModifier::Ptrdiff;
break;
case 'u':
ScalarType = ScalarTypeKind::UnsignedLong;
PT = BaseTypeModifier::UnsignedLong;
break;
case 'l':
ScalarType = ScalarTypeKind::SignedLong;
PT = BaseTypeModifier::SignedLong;
break;
default:
llvm_unreachable("Illegal primitive type transformers!");
}
Transformer = Transformer.drop_back();
PD.PT = static_cast<uint8_t>(PT);
PrototypeDescriptorStr = PrototypeDescriptorStr.drop_back();
// Extract and compute complex type transformer. It can only appear one time.
if (Transformer.startswith("(")) {
size_t Idx = Transformer.find(')');
// Compute the vector type transformers, it can only appear one time.
if (PrototypeDescriptorStr.startswith("(")) {
assert(VTM == VectorTypeModifier::NoModifier &&
"VectorTypeModifier should only have one modifier");
size_t Idx = PrototypeDescriptorStr.find(')');
assert(Idx != StringRef::npos);
StringRef ComplexType = Transformer.slice(1, Idx);
Transformer = Transformer.drop_front(Idx + 1);
assert(!Transformer.contains('(') &&
"Only allow one complex type transformer");
StringRef ComplexType = PrototypeDescriptorStr.slice(1, Idx);
PrototypeDescriptorStr = PrototypeDescriptorStr.drop_front(Idx + 1);
assert(!PrototypeDescriptorStr.contains('(') &&
"Only allow one vector type modifier");
auto UpdateAndCheckComplexProto = [&]() {
Scale = LMUL.getScale(ElementBitwidth);
const StringRef VectorPrototypes("vwqom");
if (!VectorPrototypes.contains(PType))
llvm_unreachable("Complex type transformer only supports vector type!");
if (Transformer.find_first_of("PCKWS") != StringRef::npos)
llvm_unreachable(
"Illegal type transformer for Complex type transformer");
};
auto ComputeFixedLog2LMUL =
[&](StringRef Value,
std::function<bool(const int32_t &, const int32_t &)> Compare) {
int32_t Log2LMUL;
Value.getAsInteger(10, Log2LMUL);
if (!Compare(Log2LMUL, LMUL.Log2LMUL)) {
ScalarType = Invalid;
return false;
}
// Update new LMUL
LMUL = LMULType(Log2LMUL);
UpdateAndCheckComplexProto();
return true;
};
auto ComplexTT = ComplexType.split(":");
if (ComplexTT.first == "Log2EEW") {
uint32_t Log2EEW;
ComplexTT.second.getAsInteger(10, Log2EEW);
// update new elmul = (eew/sew) * lmul
LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
// update new eew
ElementBitwidth = 1 << Log2EEW;
ScalarType = ScalarTypeKind::SignedInteger;
UpdateAndCheckComplexProto();
if (ComplexTT.second.getAsInteger(10, Log2EEW)) {
llvm_unreachable("Invalid Log2EEW value!");
return None;
}
switch (Log2EEW) {
case 3:
VTM = VectorTypeModifier::Log2EEW3;
break;
case 4:
VTM = VectorTypeModifier::Log2EEW4;
break;
case 5:
VTM = VectorTypeModifier::Log2EEW5;
break;
case 6:
VTM = VectorTypeModifier::Log2EEW6;
break;
default:
llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
return None;
}
} else if (ComplexTT.first == "FixedSEW") {
uint32_t NewSEW;
ComplexTT.second.getAsInteger(10, NewSEW);
// Set invalid type if src and dst SEW are same.
if (ElementBitwidth == NewSEW) {
ScalarType = Invalid;
return;
if (ComplexTT.second.getAsInteger(10, NewSEW)) {
llvm_unreachable("Invalid FixedSEW value!");
return None;
}
switch (NewSEW) {
case 8:
VTM = VectorTypeModifier::FixedSEW8;
break;
case 16:
VTM = VectorTypeModifier::FixedSEW16;
break;
case 32:
VTM = VectorTypeModifier::FixedSEW32;
break;
case 64:
VTM = VectorTypeModifier::FixedSEW64;
break;
default:
llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
return None;
}
// Update new SEW
ElementBitwidth = NewSEW;
UpdateAndCheckComplexProto();
} else if (ComplexTT.first == "LFixedLog2LMUL") {
// New LMUL should be larger than old
if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater<int32_t>()))
return;
int32_t Log2LMUL;
if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
llvm_unreachable("Invalid LFixedLog2LMUL value!");
return None;
}
switch (Log2LMUL) {
case -3:
VTM = VectorTypeModifier::LFixedLog2LMULN3;
break;
case -2:
VTM = VectorTypeModifier::LFixedLog2LMULN2;
break;
case -1:
VTM = VectorTypeModifier::LFixedLog2LMULN1;
break;
case 0:
VTM = VectorTypeModifier::LFixedLog2LMUL0;
break;
case 1:
VTM = VectorTypeModifier::LFixedLog2LMUL1;
break;
case 2:
VTM = VectorTypeModifier::LFixedLog2LMUL2;
break;
case 3:
VTM = VectorTypeModifier::LFixedLog2LMUL3;
break;
default:
llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
return None;
}
} else if (ComplexTT.first == "SFixedLog2LMUL") {
// New LMUL should be smaller than old
if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less<int32_t>()))
return;
int32_t Log2LMUL;
if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
llvm_unreachable("Invalid SFixedLog2LMUL value!");
return None;
}
switch (Log2LMUL) {
case -3:
VTM = VectorTypeModifier::SFixedLog2LMULN3;
break;
case -2:
VTM = VectorTypeModifier::SFixedLog2LMULN2;
break;
case -1:
VTM = VectorTypeModifier::SFixedLog2LMULN1;
break;
case 0:
VTM = VectorTypeModifier::SFixedLog2LMUL0;
break;
case 1:
VTM = VectorTypeModifier::SFixedLog2LMUL1;
break;
case 2:
VTM = VectorTypeModifier::SFixedLog2LMUL2;
break;
case 3:
VTM = VectorTypeModifier::SFixedLog2LMUL3;
break;
default:
llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
return None;
}
} else {
llvm_unreachable("Illegal complex type transformers!");
}
}
PD.VTM = static_cast<uint8_t>(VTM);
// Compute the remain type transformers
for (char I : Transformer) {
TypeModifier TM = TypeModifier::NoModifier;
for (char I : PrototypeDescriptorStr) {
switch (I) {
case 'P':
if (IsConstant)
if ((TM & TypeModifier::Const) == TypeModifier::Const)
llvm_unreachable("'P' transformer cannot be used after 'C'");
if (IsPointer)
if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer)
llvm_unreachable("'P' transformer cannot be used twice");
IsPointer = true;
TM |= TypeModifier::Pointer;
break;
case 'C':
if (IsConstant)
llvm_unreachable("'C' transformer cannot be used twice");
IsConstant = true;
TM |= TypeModifier::Const;
break;
case 'K':
IsImmediate = true;
TM |= TypeModifier::Immediate;
break;
case 'U':
ScalarType = ScalarTypeKind::UnsignedInteger;
TM |= TypeModifier::UnsignedInteger;
break;
case 'I':
ScalarType = ScalarTypeKind::SignedInteger;
TM |= TypeModifier::SignedInteger;
break;
case 'F':
ScalarType = ScalarTypeKind::Float;
TM |= TypeModifier::Float;
break;
case 'S':
LMUL = LMULType(0);
// Update ElementBitwidth need to update Scale too.
Scale = LMUL.getScale(ElementBitwidth);
TM |= TypeModifier::LMUL1;
break;
default:
llvm_unreachable("Illegal non-primitive type transformer!");
}
}
PD.TM = static_cast<uint8_t>(TM);
return PD;
}
void RVVType::applyModifier(const PrototypeDescriptor &Transformer) {
// Handle primitive type transformer
switch (static_cast<BaseTypeModifier>(Transformer.PT)) {
case BaseTypeModifier::Scalar:
Scale = 0;
break;
case BaseTypeModifier::Vector:
Scale = LMUL.getScale(ElementBitwidth);
break;
case BaseTypeModifier::Void:
ScalarType = ScalarTypeKind::Void;
break;
case BaseTypeModifier::SizeT:
ScalarType = ScalarTypeKind::Size_t;
break;
case BaseTypeModifier::Ptrdiff:
ScalarType = ScalarTypeKind::Ptrdiff_t;
break;
case BaseTypeModifier::UnsignedLong:
ScalarType = ScalarTypeKind::UnsignedLong;
break;
case BaseTypeModifier::SignedLong:
ScalarType = ScalarTypeKind::SignedLong;
break;
case BaseTypeModifier::Invalid:
ScalarType = ScalarTypeKind::Invalid;
return;
}
switch (static_cast<VectorTypeModifier>(Transformer.VTM)) {
case VectorTypeModifier::Widening2XVector:
ElementBitwidth *= 2;
LMUL *= 2;
Scale = LMUL.getScale(ElementBitwidth);
break;
case VectorTypeModifier::Widening4XVector:
ElementBitwidth *= 4;
LMUL *= 4;
Scale = LMUL.getScale(ElementBitwidth);
break;
case VectorTypeModifier::Widening8XVector:
ElementBitwidth *= 8;
LMUL *= 8;
Scale = LMUL.getScale(ElementBitwidth);
break;
case VectorTypeModifier::MaskVector:
ScalarType = ScalarTypeKind::Boolean;
Scale = LMUL.getScale(ElementBitwidth);
ElementBitwidth = 1;
break;
case VectorTypeModifier::Log2EEW3:
applyLog2EEW(3);
break;
case VectorTypeModifier::Log2EEW4:
applyLog2EEW(4);
break;
case VectorTypeModifier::Log2EEW5:
applyLog2EEW(5);
break;
case VectorTypeModifier::Log2EEW6:
applyLog2EEW(6);
break;
case VectorTypeModifier::FixedSEW8:
applyFixedSEW(8);
break;
case VectorTypeModifier::FixedSEW16:
applyFixedSEW(16);
break;
case VectorTypeModifier::FixedSEW32:
applyFixedSEW(32);
break;
case VectorTypeModifier::FixedSEW64:
applyFixedSEW(64);
break;
case VectorTypeModifier::LFixedLog2LMULN3:
applyFixedLog2LMUL(-3, FixedLMULType::LargerThan);
break;
case VectorTypeModifier::LFixedLog2LMULN2:
applyFixedLog2LMUL(-2, FixedLMULType::LargerThan);
break;
case VectorTypeModifier::LFixedLog2LMULN1:
applyFixedLog2LMUL(-1, FixedLMULType::LargerThan);
break;
case VectorTypeModifier::LFixedLog2LMUL0:
applyFixedLog2LMUL(0, FixedLMULType::LargerThan);
break;
case VectorTypeModifier::LFixedLog2LMUL1:
applyFixedLog2LMUL(1, FixedLMULType::LargerThan);
break;
case VectorTypeModifier::LFixedLog2LMUL2:
applyFixedLog2LMUL(2, FixedLMULType::LargerThan);
break;
case VectorTypeModifier::LFixedLog2LMUL3:
applyFixedLog2LMUL(3, FixedLMULType::LargerThan);
break;
case VectorTypeModifier::SFixedLog2LMULN3:
applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan);
break;
case VectorTypeModifier::SFixedLog2LMULN2:
applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan);
break;
case VectorTypeModifier::SFixedLog2LMULN1:
applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan);
break;
case VectorTypeModifier::SFixedLog2LMUL0:
applyFixedLog2LMUL(0, FixedLMULType::SmallerThan);
break;
case VectorTypeModifier::SFixedLog2LMUL1:
applyFixedLog2LMUL(1, FixedLMULType::SmallerThan);
break;
case VectorTypeModifier::SFixedLog2LMUL2:
applyFixedLog2LMUL(2, FixedLMULType::SmallerThan);
break;
case VectorTypeModifier::SFixedLog2LMUL3:
applyFixedLog2LMUL(3, FixedLMULType::SmallerThan);
break;
case VectorTypeModifier::NoModifier:
break;
}
for (unsigned TypeModifierMaskShift = 0;
TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset);
++TypeModifierMaskShift) {
unsigned TypeModifierMask = 1 << TypeModifierMaskShift;
if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) !=
TypeModifierMask)
continue;
switch (static_cast<TypeModifier>(TypeModifierMask)) {
case TypeModifier::Pointer:
IsPointer = true;
break;
case TypeModifier::Const:
IsConstant = true;
break;
case TypeModifier::Immediate:
IsImmediate = true;
IsConstant = true;
break;
case TypeModifier::UnsignedInteger:
ScalarType = ScalarTypeKind::UnsignedInteger;
break;
case TypeModifier::SignedInteger:
ScalarType = ScalarTypeKind::SignedInteger;
break;
case TypeModifier::Float:
ScalarType = ScalarTypeKind::Float;
break;
case TypeModifier::LMUL1:
LMUL = LMULType(0);
// Update ElementBitwidth need to update Scale too.
Scale = LMUL.getScale(ElementBitwidth);
break;
default:
llvm_unreachable("Unknown type modifier mask!");
}
}
}
void RVVType::applyLog2EEW(unsigned Log2EEW) {
// update new elmul = (eew/sew) * lmul
LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
// update new eew
ElementBitwidth = 1 << Log2EEW;
ScalarType = ScalarTypeKind::SignedInteger;
Scale = LMUL.getScale(ElementBitwidth);
}
void RVVType::applyFixedSEW(unsigned NewSEW) {
// Set invalid type if src and dst SEW are same.
if (ElementBitwidth == NewSEW) {
ScalarType = ScalarTypeKind::Invalid;
return;
}
// Update new SEW
ElementBitwidth = NewSEW;
Scale = LMUL.getScale(ElementBitwidth);
}
void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) {
switch (Type) {
case FixedLMULType::LargerThan:
if (Log2LMUL < LMUL.Log2LMUL) {
ScalarType = ScalarTypeKind::Invalid;
return;
}
break;
case FixedLMULType::SmallerThan:
if (Log2LMUL > LMUL.Log2LMUL) {
ScalarType = ScalarTypeKind::Invalid;
return;
}
break;
default:
llvm_unreachable("Unknown FixedLMULType??");
}
// Update new LMUL
LMUL = LMULType(Log2LMUL);
Scale = LMUL.getScale(ElementBitwidth);
}
Optional<RVVTypes>
RVVType::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
ArrayRef<PrototypeDescriptor> PrototypeSeq) {
// LMUL x NF must be less than or equal to 8.
if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8)
return llvm::None;
RVVTypes Types;
for (const PrototypeDescriptor &Proto : PrototypeSeq) {
auto T = computeType(BT, Log2LMUL, Proto);
if (!T.hasValue())
return llvm::None;
// Record legal type index
Types.push_back(T.getValue());
}
return Types;
}
// Compute the hash value of RVVType, used for cache the result of computeType.
static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL,
PrototypeDescriptor Proto) {
// Layout of hash value:
// 0 8 16 24 32 40
// | Log2LMUL + 3 | BT | Proto.PT | Proto.TM | Proto.VTM |
assert(Log2LMUL >= -3 && Log2LMUL <= 3);
return (Log2LMUL + 3) | (static_cast<uint64_t>(BT) & 0xff) << 8 |
((uint64_t)(Proto.PT & 0xff) << 16) |
((uint64_t)(Proto.TM & 0xff) << 24) |
((uint64_t)(Proto.VTM & 0xff) << 32);
}
Optional<RVVTypePtr> RVVType::computeType(BasicType BT, int Log2LMUL,
PrototypeDescriptor Proto) {
uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto);
// Search first
auto It = LegalTypes.find(Idx);
if (It != LegalTypes.end())
return &(It->second);
if (IllegalTypes.count(Idx))
return llvm::None;
// Compute type and record the result.
RVVType T(BT, Log2LMUL, Proto);
if (T.isValid()) {
// Record legal type index and value.
LegalTypes.insert({Idx, T});
return &(LegalTypes[Idx]);
}
// Record illegal type index.
IllegalTypes.insert(Idx);
return llvm::None;
}
//===----------------------------------------------------------------------===//
@ -593,5 +928,37 @@ std::string RVVIntrinsic::getBuiltinTypeStr() const {
return S;
}
std::string RVVIntrinsic::getSuffixStr(
BasicType Type, int Log2LMUL,
const llvm::SmallVector<PrototypeDescriptor> &PrototypeDescriptors) {
SmallVector<std::string> SuffixStrs;
for (auto PD : PrototypeDescriptors) {
auto T = RVVType::computeType(Type, Log2LMUL, PD);
SuffixStrs.push_back(T.getValue()->getShortStr());
}
return join(SuffixStrs, "_");
}
SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) {
SmallVector<PrototypeDescriptor> PrototypeDescriptors;
const StringRef Primaries("evwqom0ztul");
while (!Prototypes.empty()) {
size_t Idx = 0;
// Skip over complex prototype because it could contain primitive type
// character.
if (Prototypes[0] == '(')
Idx = Prototypes.find_first_of(')');
Idx = Prototypes.find_first_of(Primaries, Idx);
assert(Idx != StringRef::npos);
auto PD = PrototypeDescriptor::parsePrototypeDescriptor(
Prototypes.slice(0, Idx + 1));
if (!PD)
llvm_unreachable("Error during parsing prototype.");
PrototypeDescriptors.push_back(*PD);
Prototypes = Prototypes.drop_front(Idx + 1);
}
return PrototypeDescriptors;
}
} // end namespace RISCV
} // end namespace clang

View File

@ -32,9 +32,6 @@ namespace {
class RVVEmitter {
private:
RecordKeeper &Records;
// Concat BasicType, LMUL and Proto as key
StringMap<RVVType> LegalTypes;
StringSet<> IllegalTypes;
public:
RVVEmitter(RecordKeeper &R) : Records(R) {}
@ -48,20 +45,11 @@ public:
/// Emit all the information needed to map builtin -> LLVM IR intrinsic.
void createCodeGen(raw_ostream &o);
std::string getSuffixStr(char Type, int Log2LMUL, StringRef Prototypes);
private:
/// Create all intrinsics and add them to \p Out
void createRVVIntrinsics(std::vector<std::unique_ptr<RVVIntrinsic>> &Out);
/// Print HeaderCode in RVVHeader Record to \p Out
void printHeaderCode(raw_ostream &OS);
/// Compute output and input types by applying different config (basic type
/// and LMUL with type transformers). It also record result of type in legal
/// or illegal set to avoid compute the same config again. The result maybe
/// have illegal RVVType.
Optional<RVVTypes> computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
ArrayRef<std::string> PrototypeSeq);
Optional<RVVTypePtr> computeType(BasicType BT, int Log2LMUL, StringRef Proto);
/// Emit Acrh predecessor definitions and body, assume the element of Defs are
/// sorted by extension.
@ -73,14 +61,39 @@ private:
// non-empty string.
bool emitMacroRestrictionStr(RISCVPredefinedMacroT PredefinedMacros,
raw_ostream &o);
// Slice Prototypes string into sub prototype string and process each sub
// prototype string individually in the Handler.
void parsePrototypes(StringRef Prototypes,
std::function<void(StringRef)> Handler);
};
} // namespace
static BasicType ParseBasicType(char c) {
switch (c) {
case 'c':
return BasicType::Int8;
break;
case 's':
return BasicType::Int16;
break;
case 'i':
return BasicType::Int32;
break;
case 'l':
return BasicType::Int64;
break;
case 'x':
return BasicType::Float16;
break;
case 'f':
return BasicType::Float32;
break;
case 'd':
return BasicType::Float64;
break;
default:
return BasicType::Unknown;
}
}
void emitCodeGenSwitchBody(const RVVIntrinsic *RVVI, raw_ostream &OS) {
if (!RVVI->getIRName().empty())
OS << " ID = Intrinsic::riscv_" + RVVI->getIRName() + ";\n";
@ -202,24 +215,31 @@ void RVVEmitter::createHeader(raw_ostream &OS) {
constexpr int Log2LMULs[] = {-3, -2, -1, 0, 1, 2, 3};
// Print RVV boolean types.
for (int Log2LMUL : Log2LMULs) {
auto T = computeType('c', Log2LMUL, "m");
auto T = RVVType::computeType(BasicType::Int8, Log2LMUL,
PrototypeDescriptor::Mask);
if (T.hasValue())
printType(T.getValue());
}
// Print RVV int/float types.
for (char I : StringRef("csil")) {
BasicType BT = ParseBasicType(I);
for (int Log2LMUL : Log2LMULs) {
auto T = computeType(I, Log2LMUL, "v");
auto T = RVVType::computeType(BT, Log2LMUL, PrototypeDescriptor::Vector);
if (T.hasValue()) {
printType(T.getValue());
auto UT = computeType(I, Log2LMUL, "Uv");
auto UT = RVVType::computeType(
BT, Log2LMUL,
PrototypeDescriptor(BaseTypeModifier::Vector,
VectorTypeModifier::NoModifier,
TypeModifier::UnsignedInteger));
printType(UT.getValue());
}
}
}
OS << "#if defined(__riscv_zvfh)\n";
for (int Log2LMUL : Log2LMULs) {
auto T = computeType('x', Log2LMUL, "v");
auto T = RVVType::computeType(BasicType::Float16, Log2LMUL,
PrototypeDescriptor::Vector);
if (T.hasValue())
printType(T.getValue());
}
@ -227,7 +247,8 @@ void RVVEmitter::createHeader(raw_ostream &OS) {
OS << "#if defined(__riscv_f)\n";
for (int Log2LMUL : Log2LMULs) {
auto T = computeType('f', Log2LMUL, "v");
auto T = RVVType::computeType(BasicType::Float32, Log2LMUL,
PrototypeDescriptor::Vector);
if (T.hasValue())
printType(T.getValue());
}
@ -235,7 +256,8 @@ void RVVEmitter::createHeader(raw_ostream &OS) {
OS << "#if defined(__riscv_d)\n";
for (int Log2LMUL : Log2LMULs) {
auto T = computeType('d', Log2LMUL, "v");
auto T = RVVType::computeType(BasicType::Float64, Log2LMUL,
PrototypeDescriptor::Vector);
if (T.hasValue())
printType(T.getValue());
}
@ -359,32 +381,6 @@ void RVVEmitter::createCodeGen(raw_ostream &OS) {
OS << "\n";
}
void RVVEmitter::parsePrototypes(StringRef Prototypes,
std::function<void(StringRef)> Handler) {
const StringRef Primaries("evwqom0ztul");
while (!Prototypes.empty()) {
size_t Idx = 0;
// Skip over complex prototype because it could contain primitive type
// character.
if (Prototypes[0] == '(')
Idx = Prototypes.find_first_of(')');
Idx = Prototypes.find_first_of(Primaries, Idx);
assert(Idx != StringRef::npos);
Handler(Prototypes.slice(0, Idx + 1));
Prototypes = Prototypes.drop_front(Idx + 1);
}
}
std::string RVVEmitter::getSuffixStr(char Type, int Log2LMUL,
StringRef Prototypes) {
SmallVector<std::string> SuffixStrs;
parsePrototypes(Prototypes, [&](StringRef Proto) {
auto T = computeType(Type, Log2LMUL, Proto);
SuffixStrs.push_back(T.getValue()->getShortStr());
});
return join(SuffixStrs, "_");
}
void RVVEmitter::createRVVIntrinsics(
std::vector<std::unique_ptr<RVVIntrinsic>> &Out) {
std::vector<Record *> RV = Records.getAllDerivedDefinitions("RVVBuiltin");
@ -419,13 +415,15 @@ void RVVEmitter::createRVVIntrinsics(
// Parse prototype and create a list of primitive type with transformers
// (operand) in ProtoSeq. ProtoSeq[0] is output operand.
SmallVector<std::string> ProtoSeq;
parsePrototypes(Prototypes, [&ProtoSeq](StringRef Proto) {
ProtoSeq.push_back(Proto.str());
});
SmallVector<PrototypeDescriptor> ProtoSeq = parsePrototypes(Prototypes);
SmallVector<PrototypeDescriptor> SuffixProtoSeq =
parsePrototypes(SuffixProto);
SmallVector<PrototypeDescriptor> MangledSuffixProtoSeq =
parsePrototypes(MangledSuffixProto);
// Compute Builtin types
SmallVector<std::string> ProtoMaskSeq = ProtoSeq;
SmallVector<PrototypeDescriptor> ProtoMaskSeq = ProtoSeq;
if (HasMasked) {
// If HasMaskedOffOperand, insert result type as first input operand.
if (HasMaskedOffOperand) {
@ -436,10 +434,10 @@ void RVVEmitter::createRVVIntrinsics(
// (void, op0 address, op1 address, ...)
// to
// (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
PrototypeDescriptor MaskoffType = ProtoSeq[1];
MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
for (unsigned I = 0; I < NF; ++I)
ProtoMaskSeq.insert(
ProtoMaskSeq.begin() + NF + 1,
ProtoSeq[1].substr(1)); // Use substr(1) to skip '*'
ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, MaskoffType);
}
}
if (HasMaskedOffOperand && NF > 1) {
@ -448,28 +446,34 @@ void RVVEmitter::createRVVIntrinsics(
// to
// (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
// ...)
ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, "m");
ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1,
PrototypeDescriptor::Mask);
} else {
// If HasMasked, insert 'm' as first input operand.
ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, "m");
// If HasMasked, insert PrototypeDescriptor:Mask as first input operand.
ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1,
PrototypeDescriptor::Mask);
}
}
// If HasVL, append 'z' to last operand
// If HasVL, append PrototypeDescriptor:VL to last operand
if (HasVL) {
ProtoSeq.push_back("z");
ProtoMaskSeq.push_back("z");
ProtoSeq.push_back(PrototypeDescriptor::VL);
ProtoMaskSeq.push_back(PrototypeDescriptor::VL);
}
// Create Intrinsics for each type and LMUL.
for (char I : TypeRange) {
for (int Log2LMUL : Log2LMULList) {
Optional<RVVTypes> Types = computeTypes(I, Log2LMUL, NF, ProtoSeq);
BasicType BT = ParseBasicType(I);
Optional<RVVTypes> Types =
RVVType::computeTypes(BT, Log2LMUL, NF, ProtoSeq);
// Ignored to create new intrinsic if there are any illegal types.
if (!Types.hasValue())
continue;
auto SuffixStr = getSuffixStr(I, Log2LMUL, SuffixProto);
auto MangledSuffixStr = getSuffixStr(I, Log2LMUL, MangledSuffixProto);
auto SuffixStr =
RVVIntrinsic::getSuffixStr(BT, Log2LMUL, SuffixProtoSeq);
auto MangledSuffixStr =
RVVIntrinsic::getSuffixStr(BT, Log2LMUL, MangledSuffixProtoSeq);
// Create a unmasked intrinsic
Out.push_back(std::make_unique<RVVIntrinsic>(
Name, SuffixStr, MangledName, MangledSuffixStr, IRName,
@ -480,7 +484,7 @@ void RVVEmitter::createRVVIntrinsics(
if (HasMasked) {
// Create a masked intrinsic
Optional<RVVTypes> MaskTypes =
computeTypes(I, Log2LMUL, NF, ProtoMaskSeq);
RVVType::computeTypes(BT, Log2LMUL, NF, ProtoMaskSeq);
Out.push_back(std::make_unique<RVVIntrinsic>(
Name, SuffixStr, MangledName, MangledSuffixStr, MaskedIRName,
/*IsMasked=*/true, HasMaskedOffOperand, HasVL, MaskedPolicy,
@ -501,45 +505,6 @@ void RVVEmitter::printHeaderCode(raw_ostream &OS) {
}
}
Optional<RVVTypes>
RVVEmitter::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
ArrayRef<std::string> PrototypeSeq) {
// LMUL x NF must be less than or equal to 8.
if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8)
return llvm::None;
RVVTypes Types;
for (const std::string &Proto : PrototypeSeq) {
auto T = computeType(BT, Log2LMUL, Proto);
if (!T.hasValue())
return llvm::None;
// Record legal type index
Types.push_back(T.getValue());
}
return Types;
}
Optional<RVVTypePtr> RVVEmitter::computeType(BasicType BT, int Log2LMUL,
StringRef Proto) {
std::string Idx = Twine(Twine(BT) + Twine(Log2LMUL) + Proto).str();
// Search first
auto It = LegalTypes.find(Idx);
if (It != LegalTypes.end())
return &(It->second);
if (IllegalTypes.count(Idx))
return llvm::None;
// Compute type and record the result.
RVVType T(BT, Log2LMUL, Proto);
if (T.isValid()) {
// Record legal type index and value.
LegalTypes.insert({Idx, T});
return &(LegalTypes[Idx]);
}
// Record illegal type index.
IllegalTypes.insert(Idx);
return llvm::None;
}
void RVVEmitter::emitArchMacroAndBody(
std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &OS,
std::function<void(raw_ostream &, const RVVIntrinsic &)> PrintBody) {