forked from OSchip/llvm-project
[SE] Platforms return Device values
Summary: Platforms were returning Device pointers, but a Device is now basically just a pointer to an underlying PlatformDevice, so we will now just pass it around as a value. Reviewers: jlebar Subscribers: jprice, jlebar, parallel_libs-commits Differential Revision: https://reviews.llvm.org/D24537 llvm-svn: 281422
This commit is contained in:
parent
6d5a29489a
commit
16a5352121
|
@ -108,25 +108,25 @@ int main() {
|
|||
if (Platform->getDeviceCount() == 0) {
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
se::Device *Device = getOrDie(Platform->getDevice(0));
|
||||
se::Device Device = getOrDie(Platform->getDevice(0));
|
||||
|
||||
// Load the kernel onto the device.
|
||||
cg::SaxpyKernel Kernel =
|
||||
getOrDie(Device->createKernel<cg::SaxpyKernel>(cg::SaxpyLoaderSpec));
|
||||
getOrDie(Device.createKernel<cg::SaxpyKernel>(cg::SaxpyLoaderSpec));
|
||||
|
||||
se::RegisteredHostMemory<float> RegisteredX =
|
||||
getOrDie(Device->registerHostMemory<float>(HostX));
|
||||
getOrDie(Device.registerHostMemory<float>(HostX));
|
||||
se::RegisteredHostMemory<float> RegisteredY =
|
||||
getOrDie(Device->registerHostMemory<float>(HostY));
|
||||
getOrDie(Device.registerHostMemory<float>(HostY));
|
||||
|
||||
// Allocate memory on the device.
|
||||
se::GlobalDeviceMemory<float> X =
|
||||
getOrDie(Device->allocateDeviceMemory<float>(ArraySize));
|
||||
getOrDie(Device.allocateDeviceMemory<float>(ArraySize));
|
||||
se::GlobalDeviceMemory<float> Y =
|
||||
getOrDie(Device->allocateDeviceMemory<float>(ArraySize));
|
||||
getOrDie(Device.allocateDeviceMemory<float>(ArraySize));
|
||||
|
||||
// Run operations on a stream.
|
||||
se::Stream Stream = getOrDie(Device->createStream());
|
||||
se::Stream Stream = getOrDie(Device.createStream());
|
||||
Stream.thenCopyH2D(RegisteredX, X)
|
||||
.thenCopyH2D(RegisteredY, Y)
|
||||
.thenLaunch(ArraySize, 1, Kernel, A, X, Y)
|
||||
|
|
|
@ -62,25 +62,25 @@ int main() {
|
|||
if (Platform->getDeviceCount() == 0) {
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
se::Device *Device = getOrDie(Platform->getDevice(0));
|
||||
se::Device Device = getOrDie(Platform->getDevice(0));
|
||||
|
||||
// Load the kernel onto the device.
|
||||
cg::SaxpyKernel Kernel =
|
||||
getOrDie(Device->createKernel<cg::SaxpyKernel>(cg::SaxpyLoaderSpec));
|
||||
getOrDie(Device.createKernel<cg::SaxpyKernel>(cg::SaxpyLoaderSpec));
|
||||
|
||||
se::RegisteredHostMemory<float> RegisteredX =
|
||||
getOrDie(Device->registerHostMemory<float>(HostX));
|
||||
getOrDie(Device.registerHostMemory<float>(HostX));
|
||||
se::RegisteredHostMemory<float> RegisteredY =
|
||||
getOrDie(Device->registerHostMemory<float>(HostY));
|
||||
getOrDie(Device.registerHostMemory<float>(HostY));
|
||||
|
||||
// Allocate memory on the device.
|
||||
se::GlobalDeviceMemory<float> X =
|
||||
getOrDie(Device->allocateDeviceMemory<float>(ArraySize));
|
||||
getOrDie(Device.allocateDeviceMemory<float>(ArraySize));
|
||||
se::GlobalDeviceMemory<float> Y =
|
||||
getOrDie(Device->allocateDeviceMemory<float>(ArraySize));
|
||||
getOrDie(Device.allocateDeviceMemory<float>(ArraySize));
|
||||
|
||||
// Run operations on a stream.
|
||||
se::Stream Stream = getOrDie(Device->createStream());
|
||||
se::Stream Stream = getOrDie(Device.createStream());
|
||||
Stream.thenCopyH2D(RegisteredX, X)
|
||||
.thenCopyH2D(RegisteredY, Y)
|
||||
.thenLaunch(1, 1, Kernel, A, X, Y, ArraySize)
|
||||
|
|
|
@ -31,10 +31,8 @@ public:
|
|||
/// Gets the number of devices available for this platform.
|
||||
virtual size_t getDeviceCount() const = 0;
|
||||
|
||||
/// Gets a pointer to a Device with the given index for this platform.
|
||||
///
|
||||
/// Ownership of the Device instance is NOT transferred to the caller.
|
||||
virtual Expected<Device *> getDevice(size_t DeviceIndex) = 0;
|
||||
/// Gets a Device with the given index for this platform.
|
||||
virtual Expected<Device> getDevice(size_t DeviceIndex) = 0;
|
||||
};
|
||||
|
||||
} // namespace streamexecutor
|
||||
|
|
|
@ -30,24 +30,21 @@ class HostPlatform : public Platform {
|
|||
public:
|
||||
size_t getDeviceCount() const override { return 1; }
|
||||
|
||||
Expected<Device *> getDevice(size_t DeviceIndex) override {
|
||||
Expected<Device> getDevice(size_t DeviceIndex) override {
|
||||
if (DeviceIndex != 0) {
|
||||
return make_error(
|
||||
"Requested device index " + llvm::Twine(DeviceIndex) +
|
||||
" from host platform which only supports device index 0");
|
||||
}
|
||||
llvm::sys::ScopedLock Lock(Mutex);
|
||||
if (!TheDevice) {
|
||||
if (!ThePlatformDevice)
|
||||
ThePlatformDevice = llvm::make_unique<HostPlatformDevice>();
|
||||
TheDevice = llvm::make_unique<Device>(ThePlatformDevice.get());
|
||||
}
|
||||
return TheDevice.get();
|
||||
return Device(ThePlatformDevice.get());
|
||||
}
|
||||
|
||||
private:
|
||||
llvm::sys::Mutex Mutex;
|
||||
std::unique_ptr<HostPlatformDevice> ThePlatformDevice;
|
||||
std::unique_ptr<Device> TheDevice;
|
||||
};
|
||||
|
||||
} // namespace host
|
||||
|
|
Loading…
Reference in New Issue