forked from OSchip/llvm-project
[DirectX backend] Add pass to lower llvm intrinsic into dxil op function.
A new pass DXILOpLowering was added. It will scan all llvm intrinsics, create dxil op function if it can map to dxil op function. Then translate call instructions on the intrinsic into call on dxil op function. dxil op function will add i32 argument to the begining of args for dxil opcode. So cannot use setCalledFunction to update the call instruction on intrinsic. This commit only support sin to start the work. Reviewed By: kuhar, beanz Differential Revision: https://reviews.llvm.org/D124805
This commit is contained in:
parent
4537aae0d5
commit
85285be9c3
|
@ -9,6 +9,7 @@ add_public_tablegen_target(DirectXCommonTableGen)
|
|||
add_llvm_target(DirectXCodeGen
|
||||
DirectXSubtarget.cpp
|
||||
DirectXTargetMachine.cpp
|
||||
DXILOpLowering.cpp
|
||||
DXILPointerType.cpp
|
||||
DXILPrepare.cpp
|
||||
PointerTypeAnalysis.cpp
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
//===- DXILConstants.h - Essential DXIL constants -------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
///
|
||||
/// \file This file contains essential DXIL constants.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef LLVM_LIB_TARGET_DIRECTX_DXILCONSTANTS_H
|
||||
#define LLVM_LIB_TARGET_DIRECTX_DXILCONSTANTS_H
|
||||
|
||||
namespace llvm {
|
||||
namespace DXIL {
|
||||
// Enumeration for operations specified by DXIL
|
||||
enum class OpCode : unsigned {
|
||||
Sin = 13, // returns sine(theta) for theta in radians.
|
||||
};
|
||||
// Groups for DXIL operations with equivalent function templates
|
||||
enum class OpCodeClass : unsigned {
|
||||
Unary,
|
||||
};
|
||||
|
||||
} // namespace DXIL
|
||||
} // namespace llvm
|
||||
|
||||
#endif
|
|
@ -0,0 +1,279 @@
|
|||
//===- DXILOpLower.cpp - Lowering LLVM intrinsic to DIXLOp function -------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
///
|
||||
/// \file This file contains passes and utilities to lower llvm intrinsic call
|
||||
/// to DXILOp function call.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "DXILConstants.h"
|
||||
#include "DirectX.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/CodeGen/Passes.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Instruction.h"
|
||||
#include "llvm/IR/Intrinsics.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/PassManager.h"
|
||||
#include "llvm/Pass.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
#define DEBUG_TYPE "dxil-op-lower"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace llvm::DXIL;
|
||||
|
||||
constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
|
||||
|
||||
enum OverloadKind : uint16_t {
|
||||
VOID = 1,
|
||||
HALF = 1 << 1,
|
||||
FLOAT = 1 << 2,
|
||||
DOUBLE = 1 << 3,
|
||||
I1 = 1 << 4,
|
||||
I8 = 1 << 5,
|
||||
I16 = 1 << 6,
|
||||
I32 = 1 << 7,
|
||||
I64 = 1 << 8,
|
||||
UserDefineType = 1 << 9,
|
||||
ObjectType = 1 << 10,
|
||||
};
|
||||
|
||||
static const char *getOverloadTypeName(OverloadKind Kind) {
|
||||
switch (Kind) {
|
||||
case OverloadKind::HALF:
|
||||
return "f16";
|
||||
case OverloadKind::FLOAT:
|
||||
return "f32";
|
||||
case OverloadKind::DOUBLE:
|
||||
return "f64";
|
||||
case OverloadKind::I1:
|
||||
return "i1";
|
||||
case OverloadKind::I8:
|
||||
return "i8";
|
||||
case OverloadKind::I16:
|
||||
return "i16";
|
||||
case OverloadKind::I32:
|
||||
return "i32";
|
||||
case OverloadKind::I64:
|
||||
return "i64";
|
||||
case OverloadKind::VOID:
|
||||
case OverloadKind::ObjectType:
|
||||
case OverloadKind::UserDefineType:
|
||||
llvm_unreachable("invalid overload type for name");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static OverloadKind getOverloadKind(Type *Ty) {
|
||||
Type::TypeID T = Ty->getTypeID();
|
||||
switch (T) {
|
||||
case Type::VoidTyID:
|
||||
return OverloadKind::VOID;
|
||||
case Type::HalfTyID:
|
||||
return OverloadKind::HALF;
|
||||
case Type::FloatTyID:
|
||||
return OverloadKind::FLOAT;
|
||||
case Type::DoubleTyID:
|
||||
return OverloadKind::DOUBLE;
|
||||
case Type::IntegerTyID: {
|
||||
IntegerType *ITy = cast<IntegerType>(Ty);
|
||||
unsigned Bits = ITy->getBitWidth();
|
||||
switch (Bits) {
|
||||
case 1:
|
||||
return OverloadKind::I1;
|
||||
case 8:
|
||||
return OverloadKind::I8;
|
||||
case 16:
|
||||
return OverloadKind::I16;
|
||||
case 32:
|
||||
return OverloadKind::I32;
|
||||
case 64:
|
||||
return OverloadKind::I64;
|
||||
default:
|
||||
llvm_unreachable("invalid overload type");
|
||||
return OverloadKind::VOID;
|
||||
}
|
||||
}
|
||||
case Type::PointerTyID:
|
||||
return OverloadKind::UserDefineType;
|
||||
case Type::StructTyID:
|
||||
return OverloadKind::ObjectType;
|
||||
default:
|
||||
llvm_unreachable("invalid overload type");
|
||||
return OverloadKind::VOID;
|
||||
}
|
||||
}
|
||||
|
||||
static std::string getTypeName(OverloadKind Kind, Type *Ty) {
|
||||
if (Kind < OverloadKind::UserDefineType) {
|
||||
return getOverloadTypeName(Kind);
|
||||
} else if (Kind == OverloadKind::UserDefineType) {
|
||||
StructType *ST = cast<StructType>(Ty);
|
||||
return ST->getStructName().str();
|
||||
} else if (Kind == OverloadKind::ObjectType) {
|
||||
StructType *ST = cast<StructType>(Ty);
|
||||
return ST->getStructName().str();
|
||||
} else {
|
||||
std::string Str;
|
||||
raw_string_ostream OS(Str);
|
||||
Ty->print(OS);
|
||||
return OS.str();
|
||||
}
|
||||
}
|
||||
|
||||
// Static properties.
|
||||
struct OpCodeProperty {
|
||||
DXIL::OpCode OpCode;
|
||||
// FIXME: change OpCodeName into index to a large string constant when move to
|
||||
// tableGen.
|
||||
const char *OpCodeName;
|
||||
DXIL::OpCodeClass OpCodeClass;
|
||||
uint16_t OverloadTys;
|
||||
llvm::Attribute::AttrKind FuncAttr;
|
||||
};
|
||||
|
||||
static const char *getOpCodeClassName(const OpCodeProperty &Prop) {
|
||||
// FIXME: generate this table with tableGen.
|
||||
static const char *OpCodeClassNames[] = {
|
||||
"unary",
|
||||
};
|
||||
unsigned Index = static_cast<unsigned>(Prop.OpCodeClass);
|
||||
assert(Index < (sizeof(OpCodeClassNames) / sizeof(OpCodeClassNames[0])) &&
|
||||
"Out of bound OpCodeClass");
|
||||
return OpCodeClassNames[Index];
|
||||
}
|
||||
|
||||
static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
|
||||
const OpCodeProperty &Prop) {
|
||||
if (Kind == OverloadKind::VOID) {
|
||||
return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
|
||||
}
|
||||
return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
|
||||
getTypeName(Kind, Ty))
|
||||
.str();
|
||||
}
|
||||
|
||||
static const OpCodeProperty *getOpCodeProperty(DXIL::OpCode DXILOp) {
|
||||
// FIXME: generate this table with tableGen.
|
||||
static const OpCodeProperty OpCodeProps[] = {
|
||||
{DXIL::OpCode::Sin, "Sin", OpCodeClass::Unary,
|
||||
OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone},
|
||||
};
|
||||
// FIXME: change search to indexing with
|
||||
// DXILOp once all DXIL op is added.
|
||||
OpCodeProperty TmpProp;
|
||||
TmpProp.OpCode = DXILOp;
|
||||
const OpCodeProperty *Prop =
|
||||
llvm::lower_bound(OpCodeProps, TmpProp,
|
||||
[](const OpCodeProperty &A, const OpCodeProperty &B) {
|
||||
return A.OpCode < B.OpCode;
|
||||
});
|
||||
return Prop;
|
||||
}
|
||||
|
||||
static FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F,
|
||||
Module &M) {
|
||||
const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
|
||||
|
||||
// Get return type as overload type for DXILOp.
|
||||
// Only simple mapping case here, so return type is good enough.
|
||||
Type *OverloadTy = F.getReturnType();
|
||||
|
||||
OverloadKind Kind = getOverloadKind(OverloadTy);
|
||||
// FIXME: find the issue and report error in clang instead of check it in
|
||||
// backend.
|
||||
if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
|
||||
llvm_unreachable("invalid overload");
|
||||
}
|
||||
|
||||
std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
|
||||
assert(!M.getFunction(FnName) && "Function already exists");
|
||||
|
||||
auto &Ctx = M.getContext();
|
||||
Type *OpCodeTy = Type::getInt32Ty(Ctx);
|
||||
|
||||
SmallVector<Type *> ArgTypes;
|
||||
// DXIL has i32 opcode as first arg.
|
||||
ArgTypes.emplace_back(OpCodeTy);
|
||||
FunctionType *FT = F.getFunctionType();
|
||||
ArgTypes.append(FT->param_begin(), FT->param_end());
|
||||
FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false);
|
||||
return M.getOrInsertFunction(FnName, DXILOpFT);
|
||||
}
|
||||
|
||||
static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) {
|
||||
auto DXILOpFn = createDXILOpFunction(DXILOp, F, M);
|
||||
IRBuilder<> B(M.getContext());
|
||||
Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
|
||||
for (User *U : make_early_inc_range(F.users())) {
|
||||
CallInst *CI = dyn_cast<CallInst>(U);
|
||||
if (!CI)
|
||||
continue;
|
||||
|
||||
SmallVector<Value *> Args;
|
||||
Args.emplace_back(DXILOpArg);
|
||||
Args.append(CI->arg_begin(), CI->arg_end());
|
||||
B.SetInsertPoint(CI);
|
||||
CallInst *DXILCI = B.CreateCall(DXILOpFn, Args);
|
||||
CI->replaceAllUsesWith(DXILCI);
|
||||
CI->eraseFromParent();
|
||||
}
|
||||
if (F.user_empty())
|
||||
F.eraseFromParent();
|
||||
}
|
||||
|
||||
static bool lowerIntrinsics(Module &M) {
|
||||
bool Updated = false;
|
||||
static SmallDenseMap<Intrinsic::ID, DXIL::OpCode> LowerMap = {
|
||||
{Intrinsic::sin, DXIL::OpCode::Sin}};
|
||||
for (Function &F : make_early_inc_range(M.functions())) {
|
||||
if (!F.isDeclaration())
|
||||
continue;
|
||||
Intrinsic::ID ID = F.getIntrinsicID();
|
||||
auto LowerIt = LowerMap.find(ID);
|
||||
if (LowerIt == LowerMap.end())
|
||||
continue;
|
||||
lowerIntrinsic(LowerIt->second, F, M);
|
||||
Updated = true;
|
||||
}
|
||||
return Updated;
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// A pass that transforms external global definitions into declarations.
|
||||
class DXILOpLowering : public PassInfoMixin<DXILOpLowering> {
|
||||
public:
|
||||
PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
|
||||
if (lowerIntrinsics(M))
|
||||
return PreservedAnalyses::none();
|
||||
return PreservedAnalyses::all();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DXILOpLoweringLegacy : public ModulePass {
|
||||
public:
|
||||
bool runOnModule(Module &M) override { return lowerIntrinsics(M); }
|
||||
StringRef getPassName() const override { return "DXIL Op Lowering"; }
|
||||
DXILOpLoweringLegacy() : ModulePass(ID) {}
|
||||
|
||||
static char ID; // Pass identification.
|
||||
};
|
||||
char DXILOpLoweringLegacy::ID = 0;
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
|
||||
false, false)
|
||||
INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false,
|
||||
false)
|
||||
|
||||
ModulePass *llvm::createDXILOpLoweringLegacyPass() {
|
||||
return new DXILOpLoweringLegacy();
|
||||
}
|
|
@ -23,6 +23,13 @@ void initializeDXILPrepareModulePass(PassRegistry &);
|
|||
|
||||
/// Pass to convert modules into DXIL-compatable modules
|
||||
ModulePass *createDXILPrepareModulePass();
|
||||
|
||||
/// Initializer for DXILOpLowering
|
||||
void initializeDXILOpLoweringLegacyPass(PassRegistry &);
|
||||
|
||||
/// Pass to lowering LLVM intrinsic call to DXIL op function call.
|
||||
ModulePass *createDXILOpLoweringLegacyPass();
|
||||
|
||||
} // namespace llvm
|
||||
|
||||
#endif // LLVM_LIB_TARGET_DIRECTX_DIRECTX_H
|
||||
|
|
|
@ -34,6 +34,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
|
|||
RegisterTargetMachine<DirectXTargetMachine> X(getTheDirectXTarget());
|
||||
auto *PR = PassRegistry::getPassRegistry();
|
||||
initializeDXILPrepareModulePass(*PR);
|
||||
initializeDXILOpLoweringLegacyPass(*PR);
|
||||
}
|
||||
|
||||
class DXILTargetObjectFile : public TargetLoweringObjectFile {
|
||||
|
@ -84,6 +85,7 @@ bool DirectXTargetMachine::addPassesToEmitFile(
|
|||
PassManagerBase &PM, raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
|
||||
CodeGenFileType FileType, bool DisableVerify,
|
||||
MachineModuleInfoWrapperPass *MMIWP) {
|
||||
PM.add(createDXILOpLoweringLegacyPass());
|
||||
PM.add(createDXILPrepareModulePass());
|
||||
switch (FileType) {
|
||||
case CGFT_AssemblyFile:
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
|
||||
|
||||
; Make sure dxil operation function calls for sin are generated for float and half.
|
||||
; CHECK:call float @dx.op.unary.f32(i32 13, float %{{.*}})
|
||||
; CHECK:call half @dx.op.unary.f16(i32 13, half %{{.*}})
|
||||
|
||||
target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
|
||||
target triple = "dxil-pc-shadermodel6.7-library"
|
||||
|
||||
; Function Attrs: noinline nounwind optnone
|
||||
define noundef float @_Z3foof(float noundef %a) #0 {
|
||||
entry:
|
||||
%a.addr = alloca float, align 4
|
||||
store float %a, ptr %a.addr, align 4
|
||||
%0 = load float, ptr %a.addr, align 4
|
||||
%1 = call float @llvm.sin.f32(float %0)
|
||||
ret float %1
|
||||
}
|
||||
|
||||
; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
|
||||
declare float @llvm.sin.f32(float) #1
|
||||
|
||||
; Function Attrs: noinline nounwind optnone
|
||||
define noundef half @_Z3barDh(half noundef %a) #0 {
|
||||
entry:
|
||||
%a.addr = alloca half, align 2
|
||||
store half %a, ptr %a.addr, align 2
|
||||
%0 = load half, ptr %a.addr, align 2
|
||||
%1 = call half @llvm.sin.f16(half %0)
|
||||
ret half %1
|
||||
}
|
||||
|
||||
; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
|
||||
declare half @llvm.sin.f16(half) #1
|
||||
|
||||
attributes #0 = { noinline nounwind optnone "frame-pointer"="none" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
|
||||
attributes #1 = { nocallback nofree nosync nounwind readnone speculatable willreturn }
|
||||
|
||||
!llvm.module.flags = !{!0}
|
||||
!llvm.ident = !{!1}
|
||||
|
||||
!0 = !{i32 1, !"wchar_size", i32 4}
|
||||
!1 = !{!"clang version 15.0.0 (https://github.com/llvm/llvm-project.git 73417c517644db5c419c85c0b3cb6750172fcab5)"}
|
|
@ -476,7 +476,7 @@ static bool shouldPinPassToLegacyPM(StringRef Pass) {
|
|||
"x86-", "xcore-", "wasm-", "systemz-", "ppc-", "nvvm-",
|
||||
"nvptx-", "mips-", "lanai-", "hexagon-", "bpf-", "avr-",
|
||||
"thumb2-", "arm-", "si-", "gcn-", "amdgpu-", "aarch64-",
|
||||
"amdgcn-", "polly-", "riscv-"};
|
||||
"amdgcn-", "polly-", "riscv-", "dxil-"};
|
||||
std::vector<StringRef> PassNameContain = {"ehprepare"};
|
||||
std::vector<StringRef> PassNameExact = {
|
||||
"safe-stack", "cost-model",
|
||||
|
|
Loading…
Reference in New Issue