forked from OSchip/llvm-project
[StreamExecutor] Simplify Kernel classes
Summary: Make the Kernel class follow the pattern of the other classes. It now has a type-safe user wrapper and a typeless, platform-specific handle. Reviewers: jlebar Subscribers: jprice, parallel_libs-commits Differential Revision: https://reviews.llvm.org/D24043 llvm-svn: 280176
This commit is contained in:
parent
ddb53dd080
commit
90ce6e1e64
|
@ -15,13 +15,14 @@
|
|||
#ifndef STREAMEXECUTOR_DEVICE_H
|
||||
#define STREAMEXECUTOR_DEVICE_H
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "streamexecutor/KernelSpec.h"
|
||||
#include "streamexecutor/PlatformInterfaces.h"
|
||||
#include "streamexecutor/Utils/Error.h"
|
||||
|
||||
namespace streamexecutor {
|
||||
|
||||
class KernelInterface;
|
||||
class Stream;
|
||||
|
||||
class Device {
|
||||
|
@ -29,11 +30,24 @@ public:
|
|||
explicit Device(PlatformDevice *PDevice);
|
||||
virtual ~Device();
|
||||
|
||||
/// Gets the kernel implementation for the underlying platform.
|
||||
virtual Expected<std::unique_ptr<KernelInterface>>
|
||||
getKernelImplementation(const MultiKernelLoaderSpec &Spec) {
|
||||
// TODO(jhen): Implement this.
|
||||
return nullptr;
|
||||
/// Creates a kernel object for this device.
|
||||
///
|
||||
/// If the return value is not an error, the returned pointer will never be
|
||||
/// null.
|
||||
///
|
||||
/// See \ref CompilerGeneratedKernelExample "Kernel.h" for an example of how
|
||||
/// this method is used.
|
||||
template <typename KernelT>
|
||||
Expected<std::unique_ptr<typename std::enable_if<
|
||||
std::is_base_of<KernelBase, KernelT>::value, KernelT>::type>>
|
||||
createKernel(const MultiKernelLoaderSpec &Spec) {
|
||||
Expected<std::unique_ptr<PlatformKernelHandle>> MaybeKernelHandle =
|
||||
PDevice->createKernel(Spec);
|
||||
if (!MaybeKernelHandle) {
|
||||
return MaybeKernelHandle.takeError();
|
||||
}
|
||||
return llvm::make_unique<KernelT>(Spec.getKernelName(),
|
||||
std::move(*MaybeKernelHandle));
|
||||
}
|
||||
|
||||
Expected<std::unique_ptr<Stream>> createStream();
|
||||
|
|
|
@ -11,62 +11,64 @@
|
|||
/// Types to represent device kernels (code compiled to run on GPU or other
|
||||
/// accelerator).
|
||||
///
|
||||
/// The TypedKernel class is used to provide type safety to the user API's
|
||||
/// launch functions, and the KernelBase class is used like a void* function
|
||||
/// pointer to perform type-unsafe operations inside StreamExecutor.
|
||||
///
|
||||
/// With the kernel parameter types recorded in the TypedKernel template
|
||||
/// parameters, type-safe kernel launch functions can be written with signatures
|
||||
/// like the following:
|
||||
/// With the kernel parameter types recorded in the Kernel template parameters,
|
||||
/// type-safe kernel launch functions can be written with signatures like the
|
||||
/// following:
|
||||
/// \code
|
||||
/// template <typename... ParameterTs>
|
||||
/// void Launch(
|
||||
/// const TypedKernel<ParameterTs...> &Kernel, ParamterTs... Arguments);
|
||||
/// const Kernel<ParameterTs...> &Kernel, ParamterTs... Arguments);
|
||||
/// \endcode
|
||||
/// and the compiler will check that the user passes in arguments with types
|
||||
/// matching the corresponding kernel parameters.
|
||||
///
|
||||
/// A problem is that a TypedKernel template specialization with the right
|
||||
/// parameter types must be passed as the first argument to the Launch function,
|
||||
/// and it's just as hard to get the types right in that template specialization
|
||||
/// as it is to get them right for the kernel arguments.
|
||||
/// A problem is that a Kernel template specialization with the right parameter
|
||||
/// types must be passed as the first argument to the Launch function, and it's
|
||||
/// just as hard to get the types right in that template specialization as it is
|
||||
/// to get them right for the kernel arguments.
|
||||
///
|
||||
/// With this problem in mind, it is not recommended for users to specialize the
|
||||
/// TypedKernel template class themselves, but instead to let the compiler do it
|
||||
/// for them. When the compiler encounters a device kernel function, it can
|
||||
/// create a TypedKernel template specialization in the host code that has the
|
||||
/// right parameter types for that kernel and which has a type name based on the
|
||||
/// name of the kernel function.
|
||||
/// Kernel template class themselves, but instead to let the compiler do it for
|
||||
/// them. When the compiler encounters a device kernel function, it can create a
|
||||
/// Kernel template specialization in the host code that has the right parameter
|
||||
/// types for that kernel and which has a type name based on the name of the
|
||||
/// kernel function.
|
||||
///
|
||||
/// \anchor CompilerGeneratedKernelExample
|
||||
/// For example, if a CUDA device kernel function with the following signature
|
||||
/// has been defined:
|
||||
/// \code
|
||||
/// void Saxpy(float *A, float *X, float *Y);
|
||||
/// void Saxpy(float A, float *X, float *Y);
|
||||
/// \endcode
|
||||
/// the compiler can insert the following declaration in the host code:
|
||||
/// \code
|
||||
/// namespace compiler_cuda_namespace {
|
||||
/// namespace se = streamexecutor;
|
||||
/// using SaxpyKernel =
|
||||
/// streamexecutor::TypedKernel<float *, float *, float *>;
|
||||
/// se::Kernel<
|
||||
/// float,
|
||||
/// se::GlobalDeviceMemory<float>,
|
||||
/// se::GlobalDeviceMemory<float>>;
|
||||
/// } // namespace compiler_cuda_namespace
|
||||
/// \endcode
|
||||
/// and then the user can launch the kernel by calling the StreamExecutor launch
|
||||
/// function as follows:
|
||||
/// \code
|
||||
/// namespace ccn = compiler_cuda_namespace;
|
||||
/// using KernelPtr = std::unique_ptr<cnn::SaxpyKernel>;
|
||||
/// // Assumes Device is a pointer to the Device on which to launch the
|
||||
/// // kernel.
|
||||
/// //
|
||||
/// // See KernelSpec.h for details on how the compiler can create a
|
||||
/// // MultiKernelLoaderSpec instance like SaxpyKernelLoaderSpec below.
|
||||
/// Expected<ccn::SaxpyKernel> MaybeKernel =
|
||||
/// ccn::SaxpyKernel::create(Device, ccn::SaxpyKernelLoaderSpec);
|
||||
/// Expected<KernelPtr> MaybeKernel =
|
||||
/// Device->createKernel<ccn::SaxpyKernel>(ccn::SaxpyKernelLoaderSpec);
|
||||
/// if (!MaybeKernel) { /* Handle error */ }
|
||||
/// ccn::SaxpyKernel SaxpyKernel = *MaybeKernel;
|
||||
/// Launch(SaxpyKernel, A, X, Y);
|
||||
/// KernelPtr SaxpyKernel = std::move(*MaybeKernel);
|
||||
/// Launch(*SaxpyKernel, A, X, Y);
|
||||
/// \endcode
|
||||
///
|
||||
/// With the compiler's help in specializing TypedKernel for each device kernel
|
||||
/// With the compiler's help in specializing Kernel for each device kernel
|
||||
/// function (and generating a MultiKernelLoaderSpec instance for each kernel),
|
||||
/// the user can safely launch the device kernel from the host and get an error
|
||||
/// message at compile time if the argument types don't match the kernel
|
||||
|
@ -84,73 +86,37 @@
|
|||
|
||||
namespace streamexecutor {
|
||||
|
||||
class Device;
|
||||
class KernelInterface;
|
||||
class PlatformKernelHandle;
|
||||
|
||||
/// The base class for device kernel functions.
|
||||
/// The base class for all kernel types.
|
||||
///
|
||||
/// This class has no information about the types of the parameters taken by the
|
||||
/// kernel, so it is analogous to a void* pointer to a device function.
|
||||
///
|
||||
/// See the TypedKernel class below for the subclass which does have information
|
||||
/// about parameter types.
|
||||
/// Stores the name of the kernel in both mangled and demangled forms.
|
||||
class KernelBase {
|
||||
public:
|
||||
KernelBase(KernelBase &&) = default;
|
||||
KernelBase &operator=(KernelBase &&) = default;
|
||||
~KernelBase();
|
||||
|
||||
/// Creates a kernel object from a Device and a MultiKernelLoaderSpec.
|
||||
///
|
||||
/// The Device knows which platform it belongs to and the
|
||||
/// MultiKernelLoaderSpec knows how to find the kernel code for different
|
||||
/// platforms, so the combined information is enough to get the kernel code
|
||||
/// for the appropriate platform.
|
||||
static Expected<KernelBase> create(Device *Dev,
|
||||
const MultiKernelLoaderSpec &Spec);
|
||||
KernelBase(llvm::StringRef Name);
|
||||
|
||||
const std::string &getName() const { return Name; }
|
||||
const std::string &getDemangledName() const { return DemangledName; }
|
||||
|
||||
/// Gets a pointer to the platform-specific implementation of this kernel.
|
||||
KernelInterface *getImplementation() { return Implementation.get(); }
|
||||
|
||||
private:
|
||||
KernelBase(Device *Dev, const std::string &Name,
|
||||
const std::string &DemangledName,
|
||||
std::unique_ptr<KernelInterface> Implementation);
|
||||
|
||||
Device *TheDevice;
|
||||
std::string Name;
|
||||
std::string DemangledName;
|
||||
std::unique_ptr<KernelInterface> Implementation;
|
||||
|
||||
KernelBase(const KernelBase &) = delete;
|
||||
KernelBase &operator=(const KernelBase &) = delete;
|
||||
};
|
||||
|
||||
/// A device kernel function with specified parameter types.
|
||||
template <typename... ParameterTs> class TypedKernel : public KernelBase {
|
||||
/// A StreamExecutor kernel.
|
||||
///
|
||||
/// The template parameters are the types of the parameters to the kernel
|
||||
/// function.
|
||||
template <typename... ParameterTs> class Kernel : public KernelBase {
|
||||
public:
|
||||
TypedKernel(TypedKernel &&) = default;
|
||||
TypedKernel &operator=(TypedKernel &&) = default;
|
||||
Kernel(llvm::StringRef Name, std::unique_ptr<PlatformKernelHandle> PHandle)
|
||||
: KernelBase(Name), PHandle(std::move(PHandle)) {}
|
||||
|
||||
/// Parameters here have the same meaning as in KernelBase::create.
|
||||
static Expected<TypedKernel> create(Device *Dev,
|
||||
const MultiKernelLoaderSpec &Spec) {
|
||||
auto MaybeBase = KernelBase::create(Dev, Spec);
|
||||
if (!MaybeBase) {
|
||||
return MaybeBase.takeError();
|
||||
}
|
||||
TypedKernel Instance(std::move(*MaybeBase));
|
||||
return std::move(Instance);
|
||||
}
|
||||
/// Gets the underlying platform-specific handle for this kernel.
|
||||
PlatformKernelHandle *getPlatformHandle() const { return PHandle.get(); }
|
||||
|
||||
private:
|
||||
TypedKernel(KernelBase &&Base) : KernelBase(std::move(Base)) {}
|
||||
|
||||
TypedKernel(const TypedKernel &) = delete;
|
||||
TypedKernel &operator=(const TypedKernel &) = delete;
|
||||
std::unique_ptr<PlatformKernelHandle> PHandle;
|
||||
};
|
||||
|
||||
} // namespace streamexecutor
|
||||
|
|
|
@ -33,9 +33,17 @@ namespace streamexecutor {
|
|||
|
||||
class PlatformDevice;
|
||||
|
||||
/// Methods supported by device kernel function objects on all platforms.
|
||||
class KernelInterface {
|
||||
// TODO(jhen): Add methods.
|
||||
/// Platform-specific kernel handle.
|
||||
class PlatformKernelHandle {
|
||||
public:
|
||||
explicit PlatformKernelHandle(PlatformDevice *PDevice) : PDevice(PDevice) {}
|
||||
|
||||
virtual ~PlatformKernelHandle();
|
||||
|
||||
PlatformDevice *getDevice() { return PDevice; }
|
||||
|
||||
private:
|
||||
PlatformDevice *PDevice;
|
||||
};
|
||||
|
||||
/// Platform-specific stream handle.
|
||||
|
@ -64,12 +72,20 @@ public:
|
|||
|
||||
virtual std::string getName() const = 0;
|
||||
|
||||
/// Creates a platform-specific kernel.
|
||||
virtual Expected<std::unique_ptr<PlatformKernelHandle>>
|
||||
createKernel(const MultiKernelLoaderSpec &Spec) {
|
||||
return make_error("createKernel not implemented for platform " + getName());
|
||||
}
|
||||
|
||||
/// Creates a platform-specific stream.
|
||||
virtual Expected<std::unique_ptr<PlatformStreamHandle>> createStream() = 0;
|
||||
virtual Expected<std::unique_ptr<PlatformStreamHandle>> createStream() {
|
||||
return make_error("createStream not implemented for platform " + getName());
|
||||
}
|
||||
|
||||
/// Launches a kernel on the given stream.
|
||||
virtual Error launch(PlatformStreamHandle *S, BlockDimensions BlockSize,
|
||||
GridDimensions GridSize, const KernelBase &Kernel,
|
||||
GridDimensions GridSize, PlatformKernelHandle *K,
|
||||
const PackedKernelArgumentArrayBase &ArgumentArray) {
|
||||
return make_error("launch not implemented for platform " + getName());
|
||||
}
|
||||
|
|
|
@ -86,15 +86,15 @@ public:
|
|||
/// These arguments can be device memory types like GlobalDeviceMemory<T> and
|
||||
/// SharedDeviceMemory<T>, or they can be primitive types such as int. The
|
||||
/// allowable argument types are determined by the template parameters to the
|
||||
/// TypedKernel argument.
|
||||
/// Kernel argument.
|
||||
template <typename... ParameterTs>
|
||||
Stream &thenLaunch(BlockDimensions BlockSize, GridDimensions GridSize,
|
||||
const TypedKernel<ParameterTs...> &Kernel,
|
||||
const Kernel<ParameterTs...> &K,
|
||||
const ParameterTs &... Arguments) {
|
||||
auto ArgumentArray =
|
||||
make_kernel_argument_pack<ParameterTs...>(Arguments...);
|
||||
setError(PDevice->launch(ThePlatformStream.get(), BlockSize, GridSize,
|
||||
Kernel, ArgumentArray));
|
||||
K.getPlatformHandle(), ArgumentArray));
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,26 +20,8 @@
|
|||
|
||||
namespace streamexecutor {
|
||||
|
||||
KernelBase::KernelBase(Device *Dev, const std::string &Name,
|
||||
const std::string &DemangledName,
|
||||
std::unique_ptr<KernelInterface> Implementation)
|
||||
: TheDevice(Dev), Name(Name), DemangledName(DemangledName),
|
||||
Implementation(std::move(Implementation)) {}
|
||||
|
||||
KernelBase::~KernelBase() = default;
|
||||
|
||||
Expected<KernelBase> KernelBase::create(Device *Dev,
|
||||
const MultiKernelLoaderSpec &Spec) {
|
||||
auto MaybeImplementation = Dev->getKernelImplementation(Spec);
|
||||
if (!MaybeImplementation) {
|
||||
return MaybeImplementation.takeError();
|
||||
}
|
||||
std::string Name = Spec.getKernelName();
|
||||
std::string DemangledName =
|
||||
llvm::symbolize::LLVMSymbolizer::DemangleName(Name, nullptr);
|
||||
KernelBase Instance(Dev, Name, DemangledName,
|
||||
std::move(*MaybeImplementation));
|
||||
return std::move(Instance);
|
||||
}
|
||||
KernelBase::KernelBase(llvm::StringRef Name)
|
||||
: Name(Name), DemangledName(llvm::symbolize::LLVMSymbolizer::DemangleName(
|
||||
Name, nullptr)) {}
|
||||
|
||||
} // namespace streamexecutor
|
||||
|
|
|
@ -8,16 +8,6 @@ target_link_libraries(
|
|||
${CMAKE_THREAD_LIBS_INIT})
|
||||
add_test(DeviceTest device_test)
|
||||
|
||||
add_executable(
|
||||
kernel_test
|
||||
KernelTest.cpp)
|
||||
target_link_libraries(
|
||||
kernel_test
|
||||
streamexecutor
|
||||
${GTEST_BOTH_LIBRARIES}
|
||||
${CMAKE_THREAD_LIBS_INIT})
|
||||
add_test(KernelTest kernel_test)
|
||||
|
||||
add_executable(
|
||||
kernel_spec_test
|
||||
KernelSpecTest.cpp)
|
||||
|
|
|
@ -1,93 +0,0 @@
|
|||
//===-- KernelTest.cpp - Tests for Kernel objects -------------------------===//
|
||||
//
|
||||
// The LLVM Compiler Infrastructure
|
||||
//
|
||||
// This file is distributed under the University of Illinois Open Source
|
||||
// License. See LICENSE.TXT for details.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
///
|
||||
/// \file
|
||||
/// This file contains the unit tests for the code in Kernel.
|
||||
///
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "streamexecutor/Device.h"
|
||||
#include "streamexecutor/Kernel.h"
|
||||
#include "streamexecutor/KernelSpec.h"
|
||||
#include "streamexecutor/PlatformInterfaces.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace {
|
||||
|
||||
namespace se = ::streamexecutor;
|
||||
|
||||
// A Device that returns a dummy KernelInterface.
|
||||
//
|
||||
// During construction it creates a unique_ptr to a dummy KernelInterface and it
|
||||
// also stores a separate copy of the raw pointer that is stored by that
|
||||
// unique_ptr.
|
||||
//
|
||||
// The expectation is that the code being tested will call the
|
||||
// getKernelImplementation method and will thereby take ownership of the
|
||||
// unique_ptr, but the copy of the raw pointer will stay behind in this mock
|
||||
// object. The raw pointer copy can then be used to identify the unique_ptr in
|
||||
// its new location (by comparing the raw pointer with unique_ptr::get), to
|
||||
// verify that the unique_ptr ended up where it was supposed to be.
|
||||
class MockDevice : public se::Device {
|
||||
public:
|
||||
MockDevice()
|
||||
: se::Device(nullptr), Unique(llvm::make_unique<se::KernelInterface>()),
|
||||
Raw(Unique.get()) {}
|
||||
|
||||
// Moves the unique pointer into the returned se::Expected instance.
|
||||
//
|
||||
// Asserts that it is not called again after the unique pointer has been moved
|
||||
// out.
|
||||
se::Expected<std::unique_ptr<se::KernelInterface>>
|
||||
getKernelImplementation(const se::MultiKernelLoaderSpec &) override {
|
||||
assert(Unique && "MockDevice getKernelImplementation should not be "
|
||||
"called more than once");
|
||||
return std::move(Unique);
|
||||
}
|
||||
|
||||
// Gets the copy of the raw pointer from the original unique pointer.
|
||||
const se::KernelInterface *getRaw() const { return Raw; }
|
||||
|
||||
private:
|
||||
std::unique_ptr<se::KernelInterface> Unique;
|
||||
const se::KernelInterface *Raw;
|
||||
};
|
||||
|
||||
// Test fixture class for typed tests for KernelBase.getImplementation.
|
||||
//
|
||||
// The only purpose of this class is to provide a name that types can be bound
|
||||
// to in the gtest infrastructure.
|
||||
template <typename T> class GetImplementationTest : public ::testing::Test {};
|
||||
|
||||
// Types used with the GetImplementationTest fixture class.
|
||||
typedef ::testing::Types<se::KernelBase, se::TypedKernel<>,
|
||||
se::TypedKernel<int>>
|
||||
GetImplementationTypes;
|
||||
|
||||
TYPED_TEST_CASE(GetImplementationTest, GetImplementationTypes);
|
||||
|
||||
// Tests that the kernel create functions properly fetch the implementation
|
||||
// pointers for the kernel objects they construct from the passed-in
|
||||
// Device objects.
|
||||
TYPED_TEST(GetImplementationTest, SetImplementationDuringCreate) {
|
||||
se::MultiKernelLoaderSpec Spec;
|
||||
MockDevice Dev;
|
||||
|
||||
auto MaybeKernel = TypeParam::create(&Dev, Spec);
|
||||
EXPECT_TRUE(static_cast<bool>(MaybeKernel));
|
||||
se::KernelInterface *Implementation = MaybeKernel->getImplementation();
|
||||
EXPECT_EQ(Dev.getRaw(), Implementation);
|
||||
}
|
||||
|
||||
} // namespace
|
Loading…
Reference in New Issue