[WRAPPER HELPER] Improve Record Parsing (#880)

This commit is contained in:
wannacu 2023-07-12 15:43:08 +08:00 committed by GitHub
parent a6b231ce56
commit b1c09acb0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 75 additions and 29 deletions

View File

@ -25,11 +25,15 @@
#include "gen.h"
#include "utils.h"
static void ParseParameter(clang::ASTContext* AST, WrapperGenerator* Gen, clang::ParmVarDecl* Decl, FuncInfo* Func) {
static void ParseParameter(clang::ASTContext* AST, WrapperGenerator* Gen, clang::QualType ParmType, FuncInfo* Func) {
using namespace clang;
(void)AST; (void)Func;
auto ParmType = Decl->getType();
if (ParmType->isPointerType()) {
if (ParmType->isFunctionPointerType()) {
auto ProtoType = ParmType->getPointeeType()->getAs<FunctionProtoType>();
for (unsigned i = 0; i < ProtoType->getNumParams(); i++) {
ParseParameter(AST, Gen, ProtoType->getParamType(i), Func);
}
} else if (ParmType->isPointerType()) {
auto PointeeType = ParmType->getPointeeType();
if (PointeeType->isRecordType()) {
if (Gen->records.find(StripTypedef(PointeeType)) == Gen->records.end()) {
@ -91,7 +95,7 @@ static void ParseFunction(clang::ASTContext* AST, WrapperGenerator* Gen, clang::
} else {
FuncInfo->callback_args[i] = nullptr;
}
ParseParameter(AST, Gen, ParmDecl, FuncInfo);
ParseParameter(AST, Gen, ParmDecl->getType(), FuncInfo);
}
}

View File

@ -289,6 +289,7 @@ std::string WrapperGenerator::GenDeclare(ASTContext *Ctx,
const RecordInfo &Record) {
(void)Ctx;
std::string RecordStr;
std::string PreDecl;
RecordStr += "\ntypedef ";
RecordStr +=
(Record.is_union ? "union " : "struct ") + Record.type_name + " {\n";
@ -327,10 +328,10 @@ std::string WrapperGenerator::GenDeclare(ASTContext *Ctx,
FieldStr += Name;
RecordStr += FieldStr;
} else {
RecordStr += TypeStringify(StripTypedef(Type), Field, nullptr);
RecordStr += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
}
} else {
RecordStr += TypeStringify(StripTypedef(Type), Field, nullptr);
RecordStr += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
}
RecordStr += ";\n";
}
@ -547,6 +548,7 @@ std::string WrapperGenerator::GenDeclareDiffTriple(
(void)Ctx;
std::string GuestRecord;
std::string HostRecord;
std::string PreDecl;
std::vector<uint64_t> GuestFieldOff;
std::vector<uint64_t> HostFieldOff;
GuestRecord += "typedef ";
@ -599,7 +601,7 @@ std::string WrapperGenerator::GenDeclareDiffTriple(
std::cout << "Err: unknown type size " << typeSize << std::endl;
break;
}
HostRecord += TypeStringify(StripTypedef(Type), Field, nullptr);
HostRecord += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
} else if (Type->isFunctionPointerType()) {
auto FuncType = StripTypedef(Type->getPointeeType());
if (callbacks.count(FuncType)) {
@ -634,12 +636,12 @@ std::string WrapperGenerator::GenDeclareDiffTriple(
GuestRecord += FieldStr;
HostRecord += "host_" + FieldStr;
} else {
GuestRecord += TypeStringify(StripTypedef(Type), Field, nullptr);
HostRecord += TypeStringify(StripTypedef(Type), Field, nullptr);
GuestRecord += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
HostRecord += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
}
} else {
HostRecord += TypeStringify(StripTypedef(Type), Field, nullptr);
GuestRecord += TypeStringify(StripTypedef(Type), Field, nullptr);
HostRecord += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
GuestRecord += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
}
GuestRecord += ";\n";
HostRecord += ";\n";
@ -685,12 +687,12 @@ std::string WrapperGenerator::GenRecordConvert(const RecordInfo &Record) {
if (!AlignDiffFields.size()) {
return res;
}
res += "void g2h_" + Record.type_name + "(" + "struct host_" + Record.type_name + "*d, struct" + Record.type_name + "*s) {\n";
res += "void g2h_" + Record.type_name + "(" + "struct host_" + Record.type_name + " *d, struct " + Record.type_name + " *s) {\n";
std::string body = " memcpy(d, s, offsetof(struct " + Record.type_name +
", " + AlignDiffFields[0]->getNameAsString() + "));\n";
std::string offstr = "offsetof(struct " + Record.type_name + ", " +
AlignDiffFields[0]->getNameAsString() + ")";
for (size_t i = 1; i < AlignDiffFields.size() - 1; i++) {
for (size_t i = 0; i < AlignDiffFields.size() - 1; i++) {
body += " memcpy(d->" + AlignDiffFields[i]->getNameAsString() + ", " +
"s->" + AlignDiffFields[i]->getNameAsString() + ", " +
"offsetof(struct " + Record.type_name + ", " +
@ -707,7 +709,7 @@ std::string WrapperGenerator::GenRecordConvert(const RecordInfo &Record) {
" - " + offstr + ");\n";
res += body + "}\n";
res += "void h2g_" + Record.type_name + "(struct" + Record.type_name + "*d, " + "struct host_" + Record.type_name + "*s) {\n";
res += "void h2g_" + Record.type_name + "(struct " + Record.type_name + " *d, " + "struct host_" + Record.type_name + " *s) {\n";
res += body;
res += "}\n";
}
@ -775,6 +777,7 @@ void WrapperGenerator::ParseRecordRecursive(
std::string WrapperGenerator::TypeStringify(const Type *Type,
FieldDecl *FieldDecl,
ParmVarDecl *ParmDecl,
std::string& PreDecl,
std::string indent,
std::string Name) {
std::string res;
@ -807,7 +810,7 @@ std::string WrapperGenerator::TypeStringify(const Type *Type,
res += records[StripTypedef(Type->getCanonicalTypeInternal())].type_name;
res += " ";
} else {
res += AnonRecordDecl(Type->getAs<RecordType>(), indent);
res += AnonRecordDecl(Type->getAs<RecordType>(), PreDecl, indent + " ");
}
res += name;
} else if (Type->isConstantArrayType()) {
@ -816,6 +819,21 @@ std::string WrapperGenerator::TypeStringify(const Type *Type,
int EleSize = ArrayType->getSize().getZExtValue();
if (ArrayType->getElementType()->isPointerType()) {
res += "void *";
} else if (ArrayType->getElementType()->isEnumeralType()) {
res += "int ";
} else if (ArrayType->getElementType()->isRecordType()) {
auto RecordType = ArrayType->getElementType()->getAs<clang::RecordType>();
auto RecordDecl = RecordType->getDecl();
if (RecordDecl->isCompleteDefinition()) {
auto& Ctx = RecordDecl->getDeclContext()->getParentASTContext();
PreDecl += "#include \"";
PreDecl += GetDeclHeaderFile(Ctx, RecordDecl);
PreDecl += "\"";
PreDecl += "\n";
}
res += StripTypedef(ArrayType->getElementType())
->getCanonicalTypeInternal()
.getAsString();
} else {
res += StripTypedef(ArrayType->getElementType())
->getCanonicalTypeInternal()
@ -887,13 +905,13 @@ std::string WrapperGenerator::SimpleTypeStringify(const Type *Type,
return indent + res;
}
std::string WrapperGenerator::AnonRecordDecl(const RecordType *Type, std::string indent) {
std::string WrapperGenerator::AnonRecordDecl(const RecordType *Type, std::string& PreDecl, std::string indent) {
auto RecordDecl = Type->getDecl();
std::string res;
res += Type->isUnionType() ? "union {\n" : "struct {\n";
for (const auto &field : RecordDecl->fields()) {
auto FieldType = field->getType();
res += TypeStringify(StripTypedef(FieldType), field, nullptr, indent + " ");
res += TypeStringify(StripTypedef(FieldType), field, nullptr, PreDecl, indent + " ");
res += ";\n";
}
res += indent + "} ";
@ -907,7 +925,7 @@ WrapperGenerator::SimpleAnonRecordDecl(const RecordType *Type, std::string inden
res += Type->isUnionType() ? "union {\n" : "struct {\n";
for (const auto &field : RecordDecl->fields()) {
auto FieldType = field->getType();
res += SimpleTypeStringify(StripTypedef(FieldType), field, nullptr, indent + " ");
res += SimpleTypeStringify(StripTypedef(FieldType), field, nullptr, indent + " ");
res += ";\n";
}
res += indent + "} ";
@ -917,15 +935,16 @@ WrapperGenerator::SimpleAnonRecordDecl(const RecordType *Type, std::string inden
// Get func info from FunctionType
FuncDefinition WrapperGenerator::GetFuncDefinition(const Type *Type) {
FuncDefinition res;
std::string PreDecl;
auto ProtoType = Type->getAs<FunctionProtoType>();
res.ret = StripTypedef(ProtoType->getReturnType());
res.ret_str =
TypeStringify(StripTypedef(ProtoType->getReturnType()), nullptr, nullptr);
TypeStringify(StripTypedef(ProtoType->getReturnType()), nullptr, nullptr, PreDecl);
for (unsigned i = 0; i < ProtoType->getNumParams(); i++) {
auto ParamType = ProtoType->getParamType(i);
res.arg_types.push_back(StripTypedef(ParamType));
res.arg_types_str.push_back(
TypeStringify(StripTypedef(ParamType), nullptr, nullptr));
TypeStringify(StripTypedef(ParamType), nullptr, nullptr, PreDecl));
res.arg_names.push_back(std::string("a") + std::to_string(i));
}
if (ProtoType->isVariadic()) {
@ -938,15 +957,16 @@ FuncDefinition WrapperGenerator::GetFuncDefinition(const Type *Type) {
// Get funcdecl info from FunctionDecl
FuncDefinition WrapperGenerator::GetFuncDefinition(FunctionDecl *Decl) {
FuncDefinition res;
std::string PreDecl;
auto RetType = Decl->getReturnType();
res.ret = RetType.getTypePtr();
res.ret_str = TypeStringify(StripTypedef(RetType), nullptr, nullptr);
res.ret_str = TypeStringify(StripTypedef(RetType), nullptr, nullptr, PreDecl);
for (unsigned i = 0; i < Decl->getNumParams(); i++) {
auto ParamDecl = Decl->getParamDecl(i);
auto ParamType = ParamDecl->getType();
res.arg_types.push_back(ParamType.getTypePtr());
res.arg_types_str.push_back(
TypeStringify(StripTypedef(ParamType), nullptr, nullptr));
TypeStringify(StripTypedef(ParamType), nullptr, nullptr, PreDecl));
res.arg_names.push_back(ParamDecl->getNameAsString());
}
if (Decl->isVariadic()) {
@ -960,7 +980,8 @@ std::vector<uint64_t> WrapperGenerator::GetRecordFieldOffDiff(
const Type *Type, const std::string &GuestTriple,
const std::string &HostTriple, std::vector<uint64_t> &GuestFieldOff,
std::vector<uint64_t> &HostFieldOff) {
std::string Code = TypeStringify(Type, nullptr, nullptr, "", "dummy;");
std::string PreDecl;
std::string Code = TypeStringify(Type, nullptr, nullptr, PreDecl, "", "dummy;");
std::vector<uint64_t> OffsetDiff;
GuestFieldOff = GetRecordFieldOff(Code, GuestTriple);
HostFieldOff = GetRecordFieldOff(Code, HostTriple);
@ -978,15 +999,18 @@ std::vector<uint64_t> WrapperGenerator::GetRecordFieldOffDiff(
// Get the size under a specific triple
uint64_t WrapperGenerator::GetRecordSize(const Type *Type,
const std::string &Triple) {
std::string Code = TypeStringify(Type, nullptr, nullptr, "", "dummy;");
return ::GetRecordSize(Code, Triple);
std::string PreDecl;
std::string Code = TypeStringify(Type, nullptr, nullptr, PreDecl, "", "dummy;");
auto Size = ::GetRecordSize(PreDecl + Code, Triple);
return Size;
}
// Get the align under a specific triple
CharUnits::QuantityType WrapperGenerator::GetRecordAlign(const Type *Type,
const std::string &Triple) {
std::string Code = TypeStringify(Type, nullptr, nullptr, "", "dummy;");
return ::GetRecordAlign(Code, Triple);
std::string PreDecl{};
std::string Code = TypeStringify(Type, nullptr, nullptr, PreDecl, "", "dummy;");
return ::GetRecordAlign(PreDecl + Code, Triple);
}
// Generate the func sig by type, used for export func

View File

@ -129,9 +129,9 @@ private:
std::string GenCallbackWrap(clang::ASTContext* Ctx, const RecordInfo& Struct);
void ParseRecordRecursive(clang::ASTContext* Ctx, const clang::Type* Type, bool& Special, std::set<const clang::Type*>& Visited);
std::string TypeStringify(const clang::Type* Type, clang::FieldDecl* FieldDecl, clang::ParmVarDecl* ParmDecl, std::string indent = "", std::string Name = "");
std::string TypeStringify(const clang::Type* Type, clang::FieldDecl* FieldDecl, clang::ParmVarDecl* ParmDecl, std::string& PreDecl, std::string indent = "", std::string Name = "");
std::string SimpleTypeStringify(const clang::Type* Type, clang::FieldDecl* FieldDecl, clang::ParmVarDecl* ParmDecl, std::string indent = "", std::string Name = "");
std::string AnonRecordDecl(const clang::RecordType* Type, std::string indent);
std::string AnonRecordDecl(const clang::RecordType* Type, std::string& PreDecl, std::string indent);
std::string SimpleAnonRecordDecl(const clang::RecordType* Type, std::string indent);
FuncDefinition GetFuncDefinition(const clang::Type* Type);
FuncDefinition GetFuncDefinition(clang::FunctionDecl* Decl);

View File

@ -59,5 +59,11 @@ int main(int argc, const char* argv[]) {
std::string err;
auto compile_db = clang::tooling::FixedCompilationDatabase::loadFromCommandLine(argc, argv, err);
clang::tooling::ClangTool Tool(*compile_db, {argv[1]});
Tool.appendArgumentsAdjuster([&guest_triple](const clang::tooling::CommandLineArguments &args, clang::StringRef) {
clang::tooling::CommandLineArguments adjusted_args = args;
adjusted_args.push_back(std::string{"-target"});
adjusted_args.push_back(guest_triple);
return adjusted_args;
});
return Tool.run(std::make_unique<MyFrontendActionFactory>(libname, host_triple, guest_triple).get());
}

View File

@ -1,4 +1,5 @@
#pragma once
#include <clang/AST/ASTContext.h>
#include <clang/AST/Decl.h>
#include <clang/AST/Type.h>
#include <clang/Tooling/Tooling.h>
@ -19,3 +20,14 @@ static const clang::Type* StripTypedef(clang::QualType type) {
return type.getTypePtr();
}
}
// FIXME: Need to support other triple except default target triple
static std::string GetDeclHeaderFile(clang::ASTContext& Ctx, clang::Decl* Decl) {
const auto& SourceManager = Ctx.getSourceManager();
const clang::FileID FileID = SourceManager.getFileID(Decl->getBeginLoc());
const clang::FileEntry *FileEntry = SourceManager.getFileEntryForID(FileID);
if (FileEntry) {
return FileEntry->getName().str();
}
return "";
}