forked from OSchip/llvm-project
325 lines
9.9 KiB
C++
325 lines
9.9 KiB
C++
//===- DXILOpBuilder.cpp - Helper class for building DXILOp functions -----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
///
|
|
/// \file This file contains a helper class for building DXIL op functions.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "DXILOpBuilder.h"
|
|
#include "DXILConstants.h"
|
|
#include "llvm/IR/IRBuilder.h"
|
|
#include "llvm/IR/Module.h"
|
|
#include "llvm/Support/DXILOperationCommon.h"
|
|
#include "llvm/Support/ErrorHandling.h"
|
|
|
|
using namespace llvm;
|
|
using namespace llvm::DXIL;
|
|
|
|
/// Prefix shared by the names of all DXIL operation functions
/// ("dx.op.<class>[.<type>]").
constexpr StringLiteral DXILOpNamePrefix{"dx.op."};
|
|
|
|
namespace {

/// One bit per type a DXIL operation can be overloaded on.
///
/// OpCodeProperty::OverloadTys stores a mask of these bits. Comparisons such
/// as `Kind < UserDefineType` elsewhere in this file rely on the scalar kinds
/// (VOID through I64) having the smallest values.
enum OverloadKind : uint16_t {
  VOID = 0x0001,
  HALF = 0x0002,
  FLOAT = 0x0004,
  DOUBLE = 0x0008,
  I1 = 0x0010,
  I8 = 0x0020,
  I16 = 0x0040,
  I32 = 0x0080,
  I64 = 0x0100,
  UserDefineType = 0x0200,
  ObjectType = 0x0400,
};

} // namespace
|
|
|
|
/// Return the short type suffix ("f16", "i32", ...) for a scalar overload
/// kind. VOID, UserDefineType and ObjectType have no such suffix; passing
/// any of them is a programming error.
static const char *getOverloadTypeName(OverloadKind Kind) {
  const char *Name = nullptr;
  switch (Kind) {
  case OverloadKind::HALF:
    Name = "f16";
    break;
  case OverloadKind::FLOAT:
    Name = "f32";
    break;
  case OverloadKind::DOUBLE:
    Name = "f64";
    break;
  case OverloadKind::I1:
    Name = "i1";
    break;
  case OverloadKind::I8:
    Name = "i8";
    break;
  case OverloadKind::I16:
    Name = "i16";
    break;
  case OverloadKind::I32:
    Name = "i32";
    break;
  case OverloadKind::I64:
    Name = "i64";
    break;
  // Listed explicitly (instead of via `default`) so -Wswitch keeps flagging
  // newly added kinds.
  case OverloadKind::VOID:
  case OverloadKind::ObjectType:
  case OverloadKind::UserDefineType:
    break;
  }

  if (Name != nullptr)
    return Name;

  llvm_unreachable("invalid overload type for name");
  return "void";
}
|
|
|
|
/// Classify \p Ty as the OverloadKind bit used to pick a DXIL overload.
///
/// Accepted types are void, half, float, double, i1/i8/i16/i32/i64, pointers
/// (reported as UserDefineType) and structs (reported as ObjectType). Any other
/// type, including integers of other widths, is a programming error.
static OverloadKind getOverloadKind(Type *Ty) {
  if (Ty->isVoidTy())
    return OverloadKind::VOID;
  if (Ty->isHalfTy())
    return OverloadKind::HALF;
  if (Ty->isFloatTy())
    return OverloadKind::FLOAT;
  if (Ty->isDoubleTy())
    return OverloadKind::DOUBLE;

  if (auto *IntTy = dyn_cast<IntegerType>(Ty)) {
    switch (IntTy->getBitWidth()) {
    case 1:
      return OverloadKind::I1;
    case 8:
      return OverloadKind::I8;
    case 16:
      return OverloadKind::I16;
    case 32:
      return OverloadKind::I32;
    case 64:
      return OverloadKind::I64;
    default:
      break;
    }
    llvm_unreachable("invalid overload type");
    return OverloadKind::VOID;
  }

  if (Ty->isPointerTy())
    return OverloadKind::UserDefineType;
  if (Ty->isStructTy())
    return OverloadKind::ObjectType;

  llvm_unreachable("invalid overload type");
  return OverloadKind::VOID;
}
|
|
|
|
/// Return the type suffix used when naming the \p Kind overload of a DXIL op
/// for type \p Ty.
///
/// Scalar kinds map to their short names ("f32", "i64", ...). Struct types
/// use the struct's name. Any other type falls back to its textual IR
/// spelling. This notably includes the pointer types that getOverloadKind()
/// classifies as UserDefineType.
static std::string getTypeName(OverloadKind Kind, Type *Ty) {
  if (Kind < OverloadKind::UserDefineType)
    return getOverloadTypeName(Kind);

  // UserDefineType is produced for *pointer* types, so Ty is not guaranteed
  // to be a StructType here. A blind cast<StructType> would assert in debug
  // builds and misinterpret the type object in release builds.
  if (auto *ST = dyn_cast<StructType>(Ty))
    return ST->getStructName().str();

  std::string Str;
  raw_string_ostream OS(Str);
  Ty->print(OS);
  return OS.str();
}
|
|
|
|
// Static properties.
|
|
struct OpCodeProperty {
|
|
DXIL::OpCode OpCode;
|
|
// Offset in DXILOpCodeNameTable.
|
|
unsigned OpCodeNameOffset;
|
|
DXIL::OpCodeClass OpCodeClass;
|
|
// Offset in DXILOpCodeClassNameTable.
|
|
unsigned OpCodeClassNameOffset;
|
|
uint16_t OverloadTys;
|
|
llvm::Attribute::AttrKind FuncAttr;
|
|
int OverloadParamIndex; // parameter index which control the overload.
|
|
// When < 0, should be only 1 overload type.
|
|
unsigned NumOfParameters; // Number of parameters include return value.
|
|
unsigned ParameterTableOffset; // Offset in ParameterTable.
|
|
};
|
|
|
|
// Include getOpCodeClassName, getOpCodeProperty, getOpCodeName and
// getOpCodeParameterKind, which are generated by TableGen.
|
|
#define DXIL_OP_OPERATION_TABLE
|
|
#include "DXILOperation.inc"
|
|
#undef DXIL_OP_OPERATION_TABLE
|
|
|
|
/// Build the function name for a DXIL op overload:
/// "dx.op.<class>" for VOID overloads, otherwise "dx.op.<class>.<type>".
static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
                                         const OpCodeProperty &Prop) {
  std::string Name =
      (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
  if (Kind != OverloadKind::VOID) {
    Name += '.';
    Name += getTypeName(Kind, Ty);
  }
  return Name;
}
|
|
|
|
/// Append the scalar suffix for \p Kind (e.g. "f32") to \p TypeName.
/// For VOID, \p TypeName is returned unchanged. Non-scalar kinds are not
/// allowed.
static std::string constructOverloadTypeName(OverloadKind Kind,
                                             StringRef TypeName) {
  std::string Result = TypeName.str();
  if (Kind != OverloadKind::VOID) {
    assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");
    Result += getOverloadTypeName(Kind);
  }
  return Result;
}
|
|
|
|
/// Return the identified struct named \p Name in \p Ctx, creating it with
/// body \p EltTys if it does not exist yet.
///
/// NOTE: an existing struct is returned as-is; its element types are not
/// compared against \p EltTys.
static StructType *getOrCreateStructType(StringRef Name,
                                         ArrayRef<Type *> EltTys,
                                         LLVMContext &Ctx) {
  if (StructType *Existing = StructType::getTypeByName(Ctx, Name))
    return Existing;
  return StructType::create(Ctx, EltTys, Name);
}
|
|
|
|
/// Return the "dx.types.ResRet.<type>" struct for \p OverloadTy.
/// Its layout is four \p OverloadTy fields followed by a trailing i32.
static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
  const std::string Name = constructOverloadTypeName(
      getOverloadKind(OverloadTy), "dx.types.ResRet.");

  SmallVector<Type *, 5> Fields(4, OverloadTy);
  Fields.push_back(Type::getInt32Ty(Ctx));
  return getOrCreateStructType(Name, Fields, Ctx);
}
|
|
|
|
/// Return the "dx.types.Handle" struct. Its single field is the pointer type
/// returned by Type::getInt8PtrTy.
static StructType *getHandleType(LLVMContext &Ctx) {
  Type *HandleField = Type::getInt8PtrTy(Ctx);
  return getOrCreateStructType("dx.types.Handle", HandleField, Ctx);
}
|
|
|
|
/// Map a table-generated ParameterKind to a concrete LLVM type.
///
/// OVERLOAD resolves to \p OverloadTy itself. RESOURCE_RET resolves to the
/// ResRet struct built around it. Fixed kinds use the context of
/// \p OverloadTy. Kinds not handled here are a programming error.
static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
  LLVMContext &C = OverloadTy->getContext();
  switch (Kind) {
  case ParameterKind::VOID:
    return Type::getVoidTy(C);
  case ParameterKind::HALF:
    return Type::getHalfTy(C);
  case ParameterKind::FLOAT:
    return Type::getFloatTy(C);
  case ParameterKind::DOUBLE:
    return Type::getDoubleTy(C);
  case ParameterKind::I1:
    return Type::getInt1Ty(C);
  case ParameterKind::I8:
    return Type::getInt8Ty(C);
  case ParameterKind::I16:
    return Type::getInt16Ty(C);
  case ParameterKind::I32:
    return Type::getInt32Ty(C);
  case ParameterKind::I64:
    return Type::getInt64Ty(C);
  case ParameterKind::OVERLOAD:
    return OverloadTy;
  case ParameterKind::RESOURCE_RET:
    return getResRetType(OverloadTy, C);
  case ParameterKind::DXIL_HANDLE:
    return getHandleType(C);
  default:
    llvm_unreachable("Invalid parameter kind");
  }
  return nullptr;
}
|
|
|
|
/// Build the LLVM function type of the DXIL op described by \p Prop,
/// instantiated with \p OverloadTy.
///
/// The op's parameter-kind list holds the return type at index 0, followed by
/// the parameters. Its length is Prop->NumOfParameters.
static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
                                           Type *OverloadTy) {
  SmallVector<Type *> ArgTys;
  ArgTys.reserve(Prop->NumOfParameters);

  auto ParamKinds = getOpCodeParameterKind(*Prop);
  for (unsigned I = 0; I < Prop->NumOfParameters; ++I) {
    ParameterKind Kind = ParamKinds[I];
    ArgTys.push_back(getTypeFromParameterKind(Kind, OverloadTy));
  }

  assert(!ArgTys.empty() && "DXIL op must at least have a return type");
  // Slice off the return type with ArrayRef rather than taking &ArgTys[1].
  // The latter is out of bounds (and trips SmallVector's index assert) when
  // the op has no parameters besides its return value.
  return FunctionType::get(ArgTys.front(),
                           ArrayRef<Type *>(ArgTys).drop_front(),
                           /*isVarArg=*/false);
}
|
|
|
|
/// Return the declaration of the DXIL op function for \p DXILOp overloaded on
/// \p OverloadTy, inserting it into \p M if it does not exist yet.
///
/// Declarations are named "dx.op.<class>[.<type>]" and deduplicated by that
/// name. Reports a fatal error if \p OverloadTy is not one of the overloads
/// the opcode's property table allows.
static FunctionCallee getOrCreateDXILOpFunction(DXIL::OpCode DXILOp,
                                                Type *OverloadTy, Module &M) {
  const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);

  OverloadKind Kind = getOverloadKind(OverloadTy);
  // FIXME: find the issue and report error in clang instead of check it in
  // backend.
  // This check validates input coming from the frontend, so it must survive
  // release builds. llvm_unreachable would let the optimizer assume the
  // condition can never be true and drop the check entirely.
  if ((Prop->OverloadTys & static_cast<uint16_t>(Kind)) == 0)
    report_fatal_error("invalid overload", /*gen_crash_diag=*/false);

  std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
  // Dependent on name to dedup.
  if (Function *Fn = M.getFunction(FnName))
    return FunctionCallee(Fn);

  FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, OverloadTy);
  return M.getOrInsertFunction(FnName, DXILOpFT);
}
|
|
|
|
namespace llvm {
|
|
namespace DXIL {
|
|
|
|
CallInst *DXILOpBuilder::createDXILOpCall(DXIL::OpCode OpCode, Type *OverloadTy,
|
|
llvm::iterator_range<Use *> Args) {
|
|
auto Fn = getOrCreateDXILOpFunction(OpCode, OverloadTy, M);
|
|
SmallVector<Value *> FullArgs;
|
|
FullArgs.emplace_back(B.getInt32((int32_t)OpCode));
|
|
FullArgs.append(Args.begin(), Args.end());
|
|
return B.CreateCall(Fn, FullArgs);
|
|
}
|
|
|
|
Type *DXILOpBuilder::getOverloadTy(DXIL::OpCode OpCode, FunctionType *FT,
|
|
bool NoOpCodeParam) {
|
|
|
|
const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
|
|
if (Prop->OverloadParamIndex < 0) {
|
|
auto &Ctx = FT->getContext();
|
|
// When only has 1 overload type, just return it.
|
|
switch (Prop->OverloadTys) {
|
|
case OverloadKind::VOID:
|
|
return Type::getVoidTy(Ctx);
|
|
case OverloadKind::HALF:
|
|
return Type::getHalfTy(Ctx);
|
|
case OverloadKind::FLOAT:
|
|
return Type::getFloatTy(Ctx);
|
|
case OverloadKind::DOUBLE:
|
|
return Type::getDoubleTy(Ctx);
|
|
case OverloadKind::I1:
|
|
return Type::getInt1Ty(Ctx);
|
|
case OverloadKind::I8:
|
|
return Type::getInt8Ty(Ctx);
|
|
case OverloadKind::I16:
|
|
return Type::getInt16Ty(Ctx);
|
|
case OverloadKind::I32:
|
|
return Type::getInt32Ty(Ctx);
|
|
case OverloadKind::I64:
|
|
return Type::getInt64Ty(Ctx);
|
|
default:
|
|
llvm_unreachable("invalid overload type");
|
|
return nullptr;
|
|
}
|
|
}
|
|
|
|
// Prop->OverloadParamIndex is 0, overload type is FT->getReturnType().
|
|
Type *OverloadType = FT->getReturnType();
|
|
if (Prop->OverloadParamIndex != 0) {
|
|
// Skip Return Type and Type for DXIL opcode.
|
|
const unsigned SkipedParam = NoOpCodeParam ? 2 : 1;
|
|
OverloadType = FT->getParamType(Prop->OverloadParamIndex - SkipedParam);
|
|
}
|
|
|
|
auto ParamKinds = getOpCodeParameterKind(*Prop);
|
|
auto Kind = ParamKinds[Prop->OverloadParamIndex];
|
|
// For ResRet and CBufferRet, OverloadTy is in field of StructType.
|
|
if (Kind == ParameterKind::CBUFFER_RET ||
|
|
Kind == ParameterKind::RESOURCE_RET) {
|
|
auto *ST = cast<StructType>(OverloadType);
|
|
OverloadType = ST->getElementType(0);
|
|
}
|
|
return OverloadType;
|
|
}
|
|
|
|
const char *DXILOpBuilder::getOpCodeName(DXIL::OpCode DXILOp) {
|
|
return ::getOpCodeName(DXILOp);
|
|
}
|
|
} // namespace DXIL
|
|
} // namespace llvm
|