[StreamExecutor] Simplify Kernel classes

Summary:
Make the Kernel class follow the pattern of the other classes. It now
has a type-safe user wrapper and a typeless, platform-specific handle.
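
The wrapper/handle split described above can be sketched in isolation. This
is a minimal illustration, not the StreamExecutor code: the class names mirror
the diff, but the launch function, the launchForwardsHandle driver, and the
std::string-based constructors are assumptions made so the example stands
alone (C++14 or later).

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-in for the typeless, platform-specific handle.
class PlatformKernelHandle {
public:
  virtual ~PlatformKernelHandle() = default;
};

// Untyped base: knows the kernel name but nothing about parameter types.
class KernelBase {
public:
  explicit KernelBase(std::string Name) : Name(std::move(Name)) {}
  const std::string &getName() const { return Name; }

private:
  std::string Name;
};

// Type-safe user wrapper: the parameter types live in the template
// arguments, while the owned handle stays typeless.
template <typename... ParameterTs> class Kernel : public KernelBase {
public:
  Kernel(std::string Name, std::unique_ptr<PlatformKernelHandle> PHandle)
      : KernelBase(std::move(Name)), PHandle(std::move(PHandle)) {}

  PlatformKernelHandle *getPlatformHandle() const { return PHandle.get(); }

private:
  std::unique_ptr<PlatformKernelHandle> PHandle;
};

// Hypothetical launch function. ParameterTs is deduced from both the kernel
// and the arguments, so mismatched argument types fail to compile. Only the
// typeless handle would be passed on to the platform layer; here it is
// simply returned so the caller can inspect it.
template <typename... ParameterTs>
PlatformKernelHandle *launch(const Kernel<ParameterTs...> &K,
                             const ParameterTs &...) {
  return K.getPlatformHandle();
}

// Hypothetical driver: builds a Kernel<float, int> and checks that launch
// sees the same handle that was handed to the wrapper.
inline bool launchForwardsHandle() {
  auto Handle = std::make_unique<PlatformKernelHandle>();
  PlatformKernelHandle *Raw = Handle.get();
  Kernel<float, int> K("saxpy", std::move(Handle));
  return launch(K, 1.0f, 2) == Raw && K.getName() == "saxpy";
}
```

In this sketch a call such as launch(K, 1.0f, 2.0) is rejected at compile
time, because the pack deduced from the arguments (float, double) conflicts
with the pack deduced from Kernel<float, int>.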

Reviewers: jlebar

Subscribers: jprice, parallel_libs-commits

Differential Revision: https://reviews.llvm.org/D24043

llvm-svn: 280176
Jason Henline 2016-08-30 23:35:24 +00:00
parent ddb53dd080
commit 90ce6e1e64
7 changed files with 87 additions and 212 deletions

@@ -15,13 +15,14 @@
#ifndef STREAMEXECUTOR_DEVICE_H
#define STREAMEXECUTOR_DEVICE_H
#include <type_traits>
#include "streamexecutor/KernelSpec.h"
#include "streamexecutor/PlatformInterfaces.h"
#include "streamexecutor/Utils/Error.h"
namespace streamexecutor {
class KernelInterface;
class Stream;
class Device {
@@ -29,11 +30,24 @@ public:
explicit Device(PlatformDevice *PDevice);
virtual ~Device();
/// Gets the kernel implementation for the underlying platform.
virtual Expected<std::unique_ptr<KernelInterface>>
getKernelImplementation(const MultiKernelLoaderSpec &Spec) {
// TODO(jhen): Implement this.
return nullptr;
/// Creates a kernel object for this device.
///
/// If the return value is not an error, the returned pointer will never be
/// null.
///
/// See \ref CompilerGeneratedKernelExample "Kernel.h" for an example of how
/// this method is used.
template <typename KernelT>
Expected<std::unique_ptr<typename std::enable_if<
std::is_base_of<KernelBase, KernelT>::value, KernelT>::type>>
createKernel(const MultiKernelLoaderSpec &Spec) {
Expected<std::unique_ptr<PlatformKernelHandle>> MaybeKernelHandle =
PDevice->createKernel(Spec);
if (!MaybeKernelHandle) {
return MaybeKernelHandle.takeError();
}
return llvm::make_unique<KernelT>(Spec.getKernelName(),
std::move(*MaybeKernelHandle));
}
Expected<std::unique_ptr<Stream>> createStream();

@@ -11,62 +11,64 @@
/// Types to represent device kernels (code compiled to run on GPU or other
/// accelerator).
///
/// The TypedKernel class is used to provide type safety to the user API's
/// launch functions, and the KernelBase class is used like a void* function
/// pointer to perform type-unsafe operations inside StreamExecutor.
///
/// With the kernel parameter types recorded in the TypedKernel template
/// parameters, type-safe kernel launch functions can be written with signatures
/// like the following:
/// With the kernel parameter types recorded in the Kernel template parameters,
/// type-safe kernel launch functions can be written with signatures like the
/// following:
/// \code
/// template <typename... ParameterTs>
/// void Launch(
/// const TypedKernel<ParameterTs...> &Kernel, ParameterTs... Arguments);
/// const Kernel<ParameterTs...> &Kernel, ParameterTs... Arguments);
/// \endcode
/// and the compiler will check that the user passes in arguments with types
/// matching the corresponding kernel parameters.
///
/// A problem is that a TypedKernel template specialization with the right
/// parameter types must be passed as the first argument to the Launch function,
/// and it's just as hard to get the types right in that template specialization
/// as it is to get them right for the kernel arguments.
/// A problem is that a Kernel template specialization with the right parameter
/// types must be passed as the first argument to the Launch function, and it's
/// just as hard to get the types right in that template specialization as it is
/// to get them right for the kernel arguments.
///
/// With this problem in mind, it is not recommended for users to specialize the
/// TypedKernel template class themselves, but instead to let the compiler do it
/// for them. When the compiler encounters a device kernel function, it can
/// create a TypedKernel template specialization in the host code that has the
/// right parameter types for that kernel and which has a type name based on the
/// name of the kernel function.
/// Kernel template class themselves, but instead to let the compiler do it for
/// them. When the compiler encounters a device kernel function, it can create a
/// Kernel template specialization in the host code that has the right parameter
/// types for that kernel and which has a type name based on the name of the
/// kernel function.
///
/// \anchor CompilerGeneratedKernelExample
/// For example, if a CUDA device kernel function with the following signature
/// has been defined:
/// \code
/// void Saxpy(float *A, float *X, float *Y);
/// void Saxpy(float A, float *X, float *Y);
/// \endcode
/// the compiler can insert the following declaration in the host code:
/// \code
/// namespace compiler_cuda_namespace {
/// namespace se = streamexecutor;
/// using SaxpyKernel =
/// streamexecutor::TypedKernel<float *, float *, float *>;
/// se::Kernel<
/// float,
/// se::GlobalDeviceMemory<float>,
/// se::GlobalDeviceMemory<float>>;
/// } // namespace compiler_cuda_namespace
/// \endcode
/// and then the user can launch the kernel by calling the StreamExecutor launch
/// function as follows:
/// \code
/// namespace ccn = compiler_cuda_namespace;
/// using KernelPtr = std::unique_ptr<ccn::SaxpyKernel>;
/// // Assumes Device is a pointer to the Device on which to launch the
/// // kernel.
/// //
/// // See KernelSpec.h for details on how the compiler can create a
/// // MultiKernelLoaderSpec instance like SaxpyKernelLoaderSpec below.
/// Expected<ccn::SaxpyKernel> MaybeKernel =
/// ccn::SaxpyKernel::create(Device, ccn::SaxpyKernelLoaderSpec);
/// Expected<KernelPtr> MaybeKernel =
/// Device->createKernel<ccn::SaxpyKernel>(ccn::SaxpyKernelLoaderSpec);
/// if (!MaybeKernel) { /* Handle error */ }
/// ccn::SaxpyKernel SaxpyKernel = *MaybeKernel;
/// Launch(SaxpyKernel, A, X, Y);
/// KernelPtr SaxpyKernel = std::move(*MaybeKernel);
/// Launch(*SaxpyKernel, A, X, Y);
/// \endcode
///
/// With the compiler's help in specializing TypedKernel for each device kernel
/// With the compiler's help in specializing Kernel for each device kernel
/// function (and generating a MultiKernelLoaderSpec instance for each kernel),
/// the user can safely launch the device kernel from the host and get an error
/// message at compile time if the argument types don't match the kernel
@@ -84,73 +86,37 @@
namespace streamexecutor {
class Device;
class KernelInterface;
class PlatformKernelHandle;
/// The base class for device kernel functions.
/// The base class for all kernel types.
///
/// This class has no information about the types of the parameters taken by the
/// kernel, so it is analogous to a void* pointer to a device function.
///
/// See the TypedKernel class below for the subclass which does have information
/// about parameter types.
/// Stores the name of the kernel in both mangled and demangled forms.
class KernelBase {
public:
KernelBase(KernelBase &&) = default;
KernelBase &operator=(KernelBase &&) = default;
~KernelBase();
/// Creates a kernel object from a Device and a MultiKernelLoaderSpec.
///
/// The Device knows which platform it belongs to and the
/// MultiKernelLoaderSpec knows how to find the kernel code for different
/// platforms, so the combined information is enough to get the kernel code
/// for the appropriate platform.
static Expected<KernelBase> create(Device *Dev,
const MultiKernelLoaderSpec &Spec);
KernelBase(llvm::StringRef Name);
const std::string &getName() const { return Name; }
const std::string &getDemangledName() const { return DemangledName; }
/// Gets a pointer to the platform-specific implementation of this kernel.
KernelInterface *getImplementation() { return Implementation.get(); }
private:
KernelBase(Device *Dev, const std::string &Name,
const std::string &DemangledName,
std::unique_ptr<KernelInterface> Implementation);
Device *TheDevice;
std::string Name;
std::string DemangledName;
std::unique_ptr<KernelInterface> Implementation;
KernelBase(const KernelBase &) = delete;
KernelBase &operator=(const KernelBase &) = delete;
};
/// A device kernel function with specified parameter types.
template <typename... ParameterTs> class TypedKernel : public KernelBase {
/// A StreamExecutor kernel.
///
/// The template parameters are the types of the parameters to the kernel
/// function.
template <typename... ParameterTs> class Kernel : public KernelBase {
public:
TypedKernel(TypedKernel &&) = default;
TypedKernel &operator=(TypedKernel &&) = default;
Kernel(llvm::StringRef Name, std::unique_ptr<PlatformKernelHandle> PHandle)
: KernelBase(Name), PHandle(std::move(PHandle)) {}
/// Parameters here have the same meaning as in KernelBase::create.
static Expected<TypedKernel> create(Device *Dev,
const MultiKernelLoaderSpec &Spec) {
auto MaybeBase = KernelBase::create(Dev, Spec);
if (!MaybeBase) {
return MaybeBase.takeError();
}
TypedKernel Instance(std::move(*MaybeBase));
return std::move(Instance);
}
/// Gets the underlying platform-specific handle for this kernel.
PlatformKernelHandle *getPlatformHandle() const { return PHandle.get(); }
private:
TypedKernel(KernelBase &&Base) : KernelBase(std::move(Base)) {}
TypedKernel(const TypedKernel &) = delete;
TypedKernel &operator=(const TypedKernel &) = delete;
std::unique_ptr<PlatformKernelHandle> PHandle;
};
} // namespace streamexecutor

@@ -33,9 +33,17 @@ namespace streamexecutor {
class PlatformDevice;
/// Methods supported by device kernel function objects on all platforms.
class KernelInterface {
// TODO(jhen): Add methods.
/// Platform-specific kernel handle.
class PlatformKernelHandle {
public:
explicit PlatformKernelHandle(PlatformDevice *PDevice) : PDevice(PDevice) {}
virtual ~PlatformKernelHandle();
PlatformDevice *getDevice() { return PDevice; }
private:
PlatformDevice *PDevice;
};
/// Platform-specific stream handle.
@@ -64,12 +72,20 @@ public:
virtual std::string getName() const = 0;
/// Creates a platform-specific kernel.
virtual Expected<std::unique_ptr<PlatformKernelHandle>>
createKernel(const MultiKernelLoaderSpec &Spec) {
return make_error("createKernel not implemented for platform " + getName());
}
/// Creates a platform-specific stream.
virtual Expected<std::unique_ptr<PlatformStreamHandle>> createStream() = 0;
virtual Expected<std::unique_ptr<PlatformStreamHandle>> createStream() {
return make_error("createStream not implemented for platform " + getName());
}
/// Launches a kernel on the given stream.
virtual Error launch(PlatformStreamHandle *S, BlockDimensions BlockSize,
GridDimensions GridSize, const KernelBase &Kernel,
GridDimensions GridSize, PlatformKernelHandle *K,
const PackedKernelArgumentArrayBase &ArgumentArray) {
return make_error("launch not implemented for platform " + getName());
}

@@ -86,15 +86,15 @@ public:
/// These arguments can be device memory types like GlobalDeviceMemory<T> and
/// SharedDeviceMemory<T>, or they can be primitive types such as int. The
/// allowable argument types are determined by the template parameters to the
/// TypedKernel argument.
/// Kernel argument.
template <typename... ParameterTs>
Stream &thenLaunch(BlockDimensions BlockSize, GridDimensions GridSize,
const TypedKernel<ParameterTs...> &Kernel,
const Kernel<ParameterTs...> &K,
const ParameterTs &... Arguments) {
auto ArgumentArray =
make_kernel_argument_pack<ParameterTs...>(Arguments...);
setError(PDevice->launch(ThePlatformStream.get(), BlockSize, GridSize,
Kernel, ArgumentArray));
K.getPlatformHandle(), ArgumentArray));
return *this;
}

@@ -20,26 +20,8 @@
namespace streamexecutor {
KernelBase::KernelBase(Device *Dev, const std::string &Name,
const std::string &DemangledName,
std::unique_ptr<KernelInterface> Implementation)
: TheDevice(Dev), Name(Name), DemangledName(DemangledName),
Implementation(std::move(Implementation)) {}
KernelBase::~KernelBase() = default;
Expected<KernelBase> KernelBase::create(Device *Dev,
const MultiKernelLoaderSpec &Spec) {
auto MaybeImplementation = Dev->getKernelImplementation(Spec);
if (!MaybeImplementation) {
return MaybeImplementation.takeError();
}
std::string Name = Spec.getKernelName();
std::string DemangledName =
llvm::symbolize::LLVMSymbolizer::DemangleName(Name, nullptr);
KernelBase Instance(Dev, Name, DemangledName,
std::move(*MaybeImplementation));
return std::move(Instance);
}
KernelBase::KernelBase(llvm::StringRef Name)
: Name(Name), DemangledName(llvm::symbolize::LLVMSymbolizer::DemangleName(
Name, nullptr)) {}
} // namespace streamexecutor

@@ -8,16 +8,6 @@ target_link_libraries(
${CMAKE_THREAD_LIBS_INIT})
add_test(DeviceTest device_test)
add_executable(
kernel_test
KernelTest.cpp)
target_link_libraries(
kernel_test
streamexecutor
${GTEST_BOTH_LIBRARIES}
${CMAKE_THREAD_LIBS_INIT})
add_test(KernelTest kernel_test)
add_executable(
kernel_spec_test
KernelSpecTest.cpp)

@@ -1,93 +0,0 @@
//===-- KernelTest.cpp - Tests for Kernel objects -------------------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file contains the unit tests for the code in Kernel.
///
//===----------------------------------------------------------------------===//
#include <cassert>
#include "streamexecutor/Device.h"
#include "streamexecutor/Kernel.h"
#include "streamexecutor/KernelSpec.h"
#include "streamexecutor/PlatformInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "gtest/gtest.h"
namespace {
namespace se = ::streamexecutor;
// A Device that returns a dummy KernelInterface.
//
// During construction it creates a unique_ptr to a dummy KernelInterface and it
// also stores a separate copy of the raw pointer that is stored by that
// unique_ptr.
//
// The expectation is that the code being tested will call the
// getKernelImplementation method and will thereby take ownership of the
// unique_ptr, but the copy of the raw pointer will stay behind in this mock
// object. The raw pointer copy can then be used to identify the unique_ptr in
// its new location (by comparing the raw pointer with unique_ptr::get), to
// verify that the unique_ptr ended up where it was supposed to be.
class MockDevice : public se::Device {
public:
MockDevice()
: se::Device(nullptr), Unique(llvm::make_unique<se::KernelInterface>()),
Raw(Unique.get()) {}
// Moves the unique pointer into the returned se::Expected instance.
//
// Asserts that it is not called again after the unique pointer has been moved
// out.
se::Expected<std::unique_ptr<se::KernelInterface>>
getKernelImplementation(const se::MultiKernelLoaderSpec &) override {
assert(Unique && "MockDevice getKernelImplementation should not be "
"called more than once");
return std::move(Unique);
}
// Gets the copy of the raw pointer from the original unique pointer.
const se::KernelInterface *getRaw() const { return Raw; }
private:
std::unique_ptr<se::KernelInterface> Unique;
const se::KernelInterface *Raw;
};
// Test fixture class for typed tests for KernelBase.getImplementation.
//
// The only purpose of this class is to provide a name that types can be bound
// to in the gtest infrastructure.
template <typename T> class GetImplementationTest : public ::testing::Test {};
// Types used with the GetImplementationTest fixture class.
typedef ::testing::Types<se::KernelBase, se::TypedKernel<>,
se::TypedKernel<int>>
GetImplementationTypes;
TYPED_TEST_CASE(GetImplementationTest, GetImplementationTypes);
// Tests that the kernel create functions properly fetch the implementation
// pointers for the kernel objects they construct from the passed-in
// Device objects.
TYPED_TEST(GetImplementationTest, SetImplementationDuringCreate) {
se::MultiKernelLoaderSpec Spec;
MockDevice Dev;
auto MaybeKernel = TypeParam::create(&Dev, Spec);
EXPECT_TRUE(static_cast<bool>(MaybeKernel));
se::KernelInterface *Implementation = MaybeKernel->getImplementation();
EXPECT_EQ(Dev.getRaw(), Implementation);
}
} // namespace