llvm-project/clang-tools-extra/clang-tidy/portability/SIMDIntrinsicsCheck.cpp

152 lines
4.6 KiB
C++

//===--- SIMDIntrinsicsCheck.cpp - clang-tidy------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "SIMDIntrinsicsCheck.h"
#include "clang/AST/ASTContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/Basic/TargetInfo.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/Triple.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Regex.h"
using namespace clang::ast_matchers;
namespace clang {
namespace tidy {
namespace portability {
namespace {
// If the callee has parameter of VectorType or pointer to VectorType,
// or the return type is VectorType, we consider it a vector function
// and a candidate for checking.
AST_MATCHER(FunctionDecl, isVectorFunction) {
bool IsVector = Node.getReturnType()->isVectorType();
for (const ParmVarDecl *Parm : Node.parameters()) {
QualType Type = Parm->getType();
if (Type->isPointerType())
Type = Type->getPointeeType();
if (Type->isVectorType())
IsVector = true;
}
return IsVector;
}
} // namespace
static StringRef trySuggestPpc(StringRef Name) {
if (!Name.consume_front("vec_"))
return {};
return llvm::StringSwitch<StringRef>(Name)
// [simd.alg]
.Case("max", "$std::max")
.Case("min", "$std::min")
// [simd.binary]
.Case("add", "operator+ on $simd objects")
.Case("sub", "operator- on $simd objects")
.Case("mul", "operator* on $simd objects")
.Default({});
}
static StringRef trySuggestX86(StringRef Name) {
if (!(Name.consume_front("_mm_") || Name.consume_front("_mm256_") ||
Name.consume_front("_mm512_")))
return {};
// [simd.alg]
if (Name.startswith("max_"))
return "$simd::max";
if (Name.startswith("min_"))
return "$simd::min";
// [simd.binary]
if (Name.startswith("add_"))
return "operator+ on $simd objects";
if (Name.startswith("sub_"))
return "operator- on $simd objects";
if (Name.startswith("mul_"))
return "operator* on $simd objects";
return {};
}
SIMDIntrinsicsCheck::SIMDIntrinsicsCheck(StringRef Name,
ClangTidyContext *Context)
: ClangTidyCheck(Name, Context), Std(Options.get("Std", "")),
Suggest(Options.get("Suggest", false)) {}
void SIMDIntrinsicsCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) {
Options.store(Opts, "Std", Std);
Options.store(Opts, "Suggest", Suggest);
}
void SIMDIntrinsicsCheck::registerMatchers(MatchFinder *Finder) {
// If Std is not specified, infer it from the language options.
// libcxx implementation backports it to C++11 std::experimental::simd.
if (Std.empty())
Std = getLangOpts().CPlusPlus20 ? "std" : "std::experimental";
Finder->addMatcher(callExpr(callee(functionDecl(
matchesName("^::(_mm_|_mm256_|_mm512_|vec_)"),
isVectorFunction())),
unless(isExpansionInSystemHeader()))
.bind("call"),
this);
}
void SIMDIntrinsicsCheck::check(const MatchFinder::MatchResult &Result) {
const auto *Call = Result.Nodes.getNodeAs<CallExpr>("call");
assert(Call != nullptr);
const FunctionDecl *Callee = Call->getDirectCallee();
if (!Callee)
return;
StringRef Old = Callee->getName();
StringRef New;
llvm::Triple::ArchType Arch =
Result.Context->getTargetInfo().getTriple().getArch();
// We warn or suggest if this SIMD intrinsic function has a std::simd
// replacement.
switch (Arch) {
default:
break;
case llvm::Triple::ppc:
case llvm::Triple::ppc64:
case llvm::Triple::ppc64le:
New = trySuggestPpc(Old);
break;
case llvm::Triple::x86:
case llvm::Triple::x86_64:
New = trySuggestX86(Old);
break;
}
// We have found a std::simd replacement.
if (!New.empty()) {
// If Suggest is true, give a P0214 alternative, otherwise point it out it
// is non-portable.
if (Suggest) {
static const llvm::Regex StdRegex("\\$std"), SimdRegex("\\$simd");
diag(Call->getExprLoc(), "'%0' can be replaced by %1")
<< Old
<< SimdRegex.sub(SmallString<32>({Std, "::simd"}),
StdRegex.sub(Std, New));
} else {
diag("'%0' is a non-portable %1 intrinsic function")
<< Old << llvm::Triple::getArchTypeName(Arch);
}
}
}
} // namespace portability
} // namespace tidy
} // namespace clang