diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp index 6eb9ae6c5268..27095ec51df4 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp @@ -23,6 +23,7 @@ #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/CodeGen/Analysis.h" #include "llvm/CodeGen/AsmPrinter.h" #include "llvm/CodeGen/MachineConstantPool.h" #include "llvm/CodeGen/MachineInstr.h" @@ -66,7 +67,7 @@ private: const auto &Subtarget = MF.getSubtarget(); TII = Subtarget.getInstrInfo(); MRI = &MF.getRegInfo(); - NumArgs = MF.getInfo()->getNumArguments(); + NumArgs = MF.getInfo()->getParams().size(); return AsmPrinter::runOnMachineFunction(MF); } @@ -82,7 +83,7 @@ private: std::string getRegTypeName(unsigned RegNo) const; static std::string toString(const APFloat &APF); - const char *toString(Type *Ty) const; + const char *toString(MVT VT) const; std::string regToString(const MachineOperand &MO); std::string argToString(const MachineOperand &MO); }; @@ -167,40 +168,20 @@ std::string WebAssemblyAsmPrinter::argToString(const MachineOperand &MO) { return utostr(ArgNo); } -const char *WebAssemblyAsmPrinter::toString(Type *Ty) const { - switch (Ty->getTypeID()) { +const char *WebAssemblyAsmPrinter::toString(MVT VT) const { + switch (VT.SimpleTy) { default: break; - // Treat all pointers as the underlying integer into linear memory. - case Type::PointerTyID: - switch (getPointerSize()) { - case 4: - return "i32"; - case 8: - return "i64"; - default: - llvm_unreachable("unsupported pointer size"); - } - break; - case Type::FloatTyID: + case MVT::f32: return "f32"; - case Type::DoubleTyID: + case MVT::f64: return "f64"; - case Type::IntegerTyID: - switch (Ty->getIntegerBitWidth()) { - case 8: - return "i8"; - case 16: - return "i16"; - case 32: - return "i32"; - case 64: - return "i64"; - default: - break; - } + case MVT::i32: + return "i32"; + case MVT::i64: + return "i64"; } - DEBUG(dbgs() << "Invalid type "; Ty->print(dbgs()); dbgs() << '\n'); + DEBUG(dbgs() << "Invalid type " << EVT(VT).getEVTString() << '\n'); llvm_unreachable("invalid type"); return ""; } @@ -219,40 +200,37 @@ void WebAssemblyAsmPrinter::EmitJumpTableInfo() { } void WebAssemblyAsmPrinter::EmitFunctionBodyStart() { - const Function *F = MF->getFunction(); - Type *Rt = F->getReturnType(); SmallString<128> Str; raw_svector_ostream OS(Str); - bool First = true; - if (!Rt->isVoidTy() || !F->arg_empty()) { - for (const Argument &A : F->args()) { - OS << (First ? "" : "\n") << "\t.param " << toString(A.getType()); - First = false; - } - if (!Rt->isVoidTy()) { - OS << (First ? "" : "\n") << "\t.result " << toString(Rt); - First = false; - } - } + for (MVT VT : MF->getInfo()->getParams()) + OS << "\t" ".param " + << toString(VT) << '\n'; + for (MVT VT : MF->getInfo()->getResults()) + OS << "\t" ".result " + << toString(VT) << '\n'; bool FirstVReg = true; for (unsigned Idx = 0, IdxE = MRI->getNumVirtRegs(); Idx != IdxE; ++Idx) { unsigned VReg = TargetRegisterInfo::index2VirtReg(Idx); // FIXME: Don't skip dead virtual registers for now: that would require // remapping all locals' numbers. - //if (!MRI->use_empty(VReg)) { - if (FirstVReg) { - OS << (First ? "" : "\n") << "\t.local "; - First = false; - } - OS << (FirstVReg ? "" : ", ") << getRegTypeName(VReg); - FirstVReg = false; + // if (!MRI->use_empty(VReg)) { + if (FirstVReg) + OS << "\t" ".local "; + else + OS << ", "; + OS << getRegTypeName(VReg); + FirstVReg = false; //} } + if (!FirstVReg) + OS << '\n'; - if (!First) - OutStreamer->EmitRawText(OS.str()); + // EmitRawText appends a newline, so strip off the last newline. + StringRef Text = OS.str(); + if (!Text.empty()) + OutStreamer->EmitRawText(Text.substr(0, Text.size() - 1)); AsmPrinter::EmitFunctionBodyStart(); } @@ -334,27 +312,75 @@ void WebAssemblyAsmPrinter::EmitInstruction(const MachineInstr *MI) { } } +static void ComputeLegalValueVTs(LLVMContext &Context, + const WebAssemblyTargetLowering &TLI, + const DataLayout &DL, Type *Ty, + SmallVectorImpl &ValueVTs) { + SmallVector VTs; + ComputeValueVTs(TLI, DL, Ty, VTs); + + for (EVT VT : VTs) { + unsigned NumRegs = TLI.getNumRegisters(Context, VT); + MVT RegisterVT = TLI.getRegisterType(Context, VT); + for (unsigned i = 0; i != NumRegs; ++i) + ValueVTs.push_back(RegisterVT); + } +} + void WebAssemblyAsmPrinter::EmitEndOfAsmFile(Module &M) { + const DataLayout &DL = M.getDataLayout(); + SmallString<128> Str; raw_svector_ostream OS(Str); for (const Function &F : M) if (F.isDeclarationForLinker()) { assert(F.hasName() && "imported functions must have a name"); - if (F.getName().startswith("llvm.")) + if (F.isIntrinsic()) continue; if (Str.empty()) OS << "\t.imports\n"; - Type *Rt = F.getReturnType(); + OS << "\t.import " << toSymbol(F.getName()) << " \"\" \"" << F.getName() - << "\" (param"; - for (const Argument &A : F.args()) - OS << ' ' << toString(A.getType()); - OS << ')'; - if (!Rt->isVoidTy()) - OS << " (result " << toString(Rt) << ')'; + << "\""; + + const WebAssemblyTargetLowering &TLI = + *TM.getSubtarget(F).getTargetLowering(); + + // If we need to legalize the return type, it'll get converted into + // passing a pointer. + bool SawParam = false; + SmallVector ResultVTs; + ComputeLegalValueVTs(M.getContext(), TLI, DL, F.getReturnType(), + ResultVTs); + if (ResultVTs.size() > 1) { + ResultVTs.clear(); + OS << " (param " << toString(TLI.getPointerTy(DL)); + SawParam = true; + } + + for (const Argument &A : F.args()) { + SmallVector ParamVTs; + ComputeLegalValueVTs(M.getContext(), TLI, DL, A.getType(), ParamVTs); + for (EVT VT : ParamVTs) { + if (!SawParam) { + OS << " (param"; + SawParam = true; + } + OS << ' ' << toString(VT.getSimpleVT()); + } + } + if (SawParam) + OS << ')'; + + for (EVT VT : ResultVTs) + OS << " (result " << toString(VT.getSimpleVT()) << ')'; + OS << '\n'; } - OutStreamer->EmitRawText(OS.str()); + + StringRef Text = OS.str(); + if (!Text.empty()) + OutStreamer->EmitRawText(Text.substr(0, Text.size() - 1)); } // Force static initialization. diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp index d813367ea85a..899e768a0eb2 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -252,13 +252,8 @@ WebAssemblyTargetLowering::LowerCall(CallLoweringInfo &CLI, fail(DL, DAG, "WebAssembly doesn't support tail call yet"); CLI.IsTailCall = false; - SmallVectorImpl &Outs = CLI.Outs; SmallVectorImpl &OutVals = CLI.OutVals; - bool IsStructRet = (Outs.empty()) ? false : Outs[0].Flags.isSRet(); - if (IsStructRet) - fail(DL, DAG, "WebAssembly doesn't support struct return yet"); - SmallVectorImpl &Ins = CLI.Ins; if (Ins.size() > 1) fail(DL, DAG, "WebAssembly doesn't support more than 1 returned value yet"); @@ -316,6 +311,7 @@ SDValue WebAssemblyTargetLowering::LowerReturn( const SmallVectorImpl &Outs, const SmallVectorImpl &OutVals, SDLoc DL, SelectionDAG &DAG) const { + MachineFunction &MF = DAG.getMachineFunction(); assert(Outs.size() <= 1 && "WebAssembly can only return up to one value"); if (CallConv != CallingConv::C) @@ -327,6 +323,33 @@ SDValue WebAssemblyTargetLowering::LowerReturn( RetOps.append(OutVals.begin(), OutVals.end()); Chain = DAG.getNode(WebAssemblyISD::RETURN, DL, MVT::Other, RetOps); + // Record the number and types of the return values. + for (const ISD::OutputArg &Out : Outs) { + if (Out.Flags.isZExt()) + fail(DL, DAG, "WebAssembly hasn't implemented zext results"); + if (Out.Flags.isSExt()) + fail(DL, DAG, "WebAssembly hasn't implemented sext results"); + if (Out.Flags.isInReg()) + fail(DL, DAG, "WebAssembly hasn't implemented inreg results"); + if (Out.Flags.isSRet()) + fail(DL, DAG, "WebAssembly hasn't implemented sret results"); + if (Out.Flags.isByVal()) + fail(DL, DAG, "WebAssembly hasn't implemented byval results"); + if (Out.Flags.isInAlloca()) + fail(DL, DAG, "WebAssembly hasn't implemented inalloca results"); + if (Out.Flags.isNest()) + fail(DL, DAG, "WebAssembly hasn't implemented nest results"); + if (Out.Flags.isReturned()) + fail(DL, DAG, "WebAssembly hasn't implemented returned results"); + if (Out.Flags.isInConsecutiveRegs()) + fail(DL, DAG, "WebAssembly hasn't implemented cons regs results"); + if (Out.Flags.isInConsecutiveRegsLast()) + fail(DL, DAG, "WebAssembly hasn't implemented cons regs last results"); + if (!Out.IsFixed) + fail(DL, DAG, "WebAssembly doesn't support non-fixed results yet"); + MF.getInfo()->addResult(Out.VT); + } + return Chain; } @@ -340,8 +363,6 @@ SDValue WebAssemblyTargetLowering::LowerFormalArguments( fail(DL, DAG, "WebAssembly doesn't support non-C calling conventions"); if (IsVarArg) fail(DL, DAG, "WebAssembly doesn't support varargs yet"); - if (MF.getFunction()->hasStructRetAttr()) - fail(DL, DAG, "WebAssembly doesn't support struct return yet"); unsigned ArgNo = 0; for (const ISD::InputArg &In : Ins) { @@ -365,21 +386,18 @@ SDValue WebAssemblyTargetLowering::LowerFormalArguments( fail(DL, DAG, "WebAssembly hasn't implemented cons regs arguments"); if (In.Flags.isInConsecutiveRegsLast()) fail(DL, DAG, "WebAssembly hasn't implemented cons regs last arguments"); - if (In.Flags.isSplit()) - fail(DL, DAG, "WebAssembly hasn't implemented split arguments"); // FIXME Do something with In.getOrigAlign()? InVals.push_back( In.Used ? DAG.getNode(WebAssemblyISD::ARGUMENT, DL, In.VT, DAG.getTargetConstant(ArgNo, DL, MVT::i32)) : DAG.getNode(ISD::UNDEF, DL, In.VT)); + + // Record the number and types of arguments. + MF.getInfo()->addParam(In.VT); ++ArgNo; } - // Record the number of arguments, since argument indices and local variable - // indices are in the same index space. - MF.getInfo()->setNumArguments(ArgNo); - return Chain; } diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.h index a571e63d7f6a..bac0dfafcf31 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.h @@ -27,15 +27,19 @@ namespace llvm { class WebAssemblyFunctionInfo final : public MachineFunctionInfo { MachineFunction &MF; - unsigned NumArguments; + std::vector Params; + std::vector Results; public: explicit WebAssemblyFunctionInfo(MachineFunction &MF) - : MF(MF), NumArguments(0) {} + : MF(MF) {} ~WebAssemblyFunctionInfo() override; - void setNumArguments(unsigned N) { NumArguments = N; } - unsigned getNumArguments() const { return NumArguments; } + void addParam(MVT VT) { Params.push_back(VT); } + const std::vector &getParams() const { return Params; } + + void addResult(MVT VT) { Results.push_back(VT); } + const std::vector &getResults() const { return Results; } }; } // end namespace llvm diff --git a/llvm/test/CodeGen/WebAssembly/import.ll b/llvm/test/CodeGen/WebAssembly/import.ll index 6f1f8e0c3aee..09c7cefcd653 100644 --- a/llvm/test/CodeGen/WebAssembly/import.ll +++ b/llvm/test/CodeGen/WebAssembly/import.ll @@ -5,19 +5,28 @@ target triple = "wasm32-unknown-unknown" ; CHECK-LABEL: .text ; CHECK-LABEL: f: -define void @f(i32 %a, float %b) { +define void @f(i32 %a, float %b, i128 %c, i1 %d) { tail call i32 @printi(i32 %a) tail call float @printf(float %b) tail call void @printv() + tail call void @split_arg(i128 %c) + tail call void @expanded_arg(i1 %d) + tail call i1 @lowered_result() ret void } ; CHECK-LABEL: .imports -; CHECK-NEXT: .import $printi "" "printi" (param i32) (result i32) -; CHECK-NEXT: .import $printf "" "printf" (param f32) (result f32) -; CHECK-NEXT: .import $printv "" "printv" (param) -; CHECK-NEXT: .import $add2 "" "add2" (param i32 i32) (result i32) +; CHECK-NEXT: .import $printi "" "printi" (param i32) (result i32){{$}} +; CHECK-NEXT: .import $printf "" "printf" (param f32) (result f32){{$}} +; CHECK-NEXT: .import $printv "" "printv"{{$}} +; CHECK-NEXT: .import $add2 "" "add2" (param i32 i32) (result i32){{$}} +; CHECK-NEXT: .import $split_arg "" "split_arg" (param i64 i64){{$}} +; CHECK-NEXT: .import $expanded_arg "" "expanded_arg" (param i32){{$}} +; CHECK-NEXT: .import $lowered_result "" "lowered_result" (result i32){{$}} declare i32 @printi(i32) declare float @printf(float) declare void @printv() declare i32 @add2(i32, i32) +declare void @split_arg(i128) +declare void @expanded_arg(i1) +declare i1 @lowered_result()