!44737 Add util interface for gpu hash table.
Merge pull request !44737 from gaoyong10/dynamic_shape_04
This commit is contained in:
commit
288bd3bcc1
|
@ -23,7 +23,7 @@
|
||||||
#include "plugin/device/gpu/hal/device/gpu_device_manager.h"
|
#include "plugin/device/gpu/hal/device/gpu_device_manager.h"
|
||||||
#include "plugin/device/gpu/hal/device/gpu_memory_allocator.h"
|
#include "plugin/device/gpu/hal/device/gpu_memory_allocator.h"
|
||||||
#include "plugin/device/gpu/hal/hardware/gpu_device_context.h"
|
#include "plugin/device/gpu/hal/hardware/gpu_device_context.h"
|
||||||
#include "plugin/device/gpu/hal/device/gpu_hash_table.h"
|
#include "plugin/device/gpu/hal/device/gpu_hash_table_util.h"
|
||||||
#include "plugin/device/gpu/hal/device/gpu_common.h"
|
#include "plugin/device/gpu/hal/device/gpu_common.h"
|
||||||
#ifdef ENABLE_DEBUGGER
|
#ifdef ENABLE_DEBUGGER
|
||||||
#include "debug/debug_services.h"
|
#include "debug/debug_services.h"
|
||||||
|
@ -128,23 +128,20 @@ bool SyncUserDataToDevice(const UserDataPtr &user_data, const void *host_ptr, si
|
||||||
MS_EXCEPTION_IF_NULL(user_data_type);
|
MS_EXCEPTION_IF_NULL(user_data_type);
|
||||||
|
|
||||||
if (*user_data_type == UserDataType::kUserTypeHashTable) {
|
if (*user_data_type == UserDataType::kUserTypeHashTable) {
|
||||||
|
#if CUDA_VERSION > 11000 && defined(__linux__)
|
||||||
auto key_type = user_data->get<TypeId>(kHashTableKeyType);
|
auto key_type = user_data->get<TypeId>(kHashTableKeyType);
|
||||||
auto value_type = user_data->get<TypeId>(kHashTableValueType);
|
auto value_type = user_data->get<TypeId>(kHashTableValueType);
|
||||||
MS_EXCEPTION_IF_NULL(key_type);
|
MS_EXCEPTION_IF_NULL(key_type);
|
||||||
MS_EXCEPTION_IF_NULL(value_type);
|
MS_EXCEPTION_IF_NULL(value_type);
|
||||||
if (*key_type == TypeId::kNumberTypeInt32 && *value_type == TypeId::kNumberTypeFloat32) {
|
const auto &iter = hashtable_func_list.find({*key_type, *value_type});
|
||||||
#if CUDA_VERSION > 11000
|
if (iter != hashtable_func_list.end()) {
|
||||||
const auto &gpu_hash_table = user_data->get<GPUHashTable<int, float>>(kUserDataData);
|
return std::get<kSyncFuncIndex>(iter->second)(user_data, host_ptr, size);
|
||||||
MS_EXCEPTION_IF_NULL(gpu_hash_table);
|
|
||||||
if (!gpu_hash_table->Import({const_cast<void *>(host_ptr), size})) {
|
|
||||||
MS_LOG(EXCEPTION) << "Import for hash table failed.";
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
MS_LOG(EXCEPTION) << "Unsupported cuda version.";
|
|
||||||
#endif
|
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "Unsupported hash table type:" << *key_type << " and:" << *value_type;
|
MS_LOG(EXCEPTION) << "Unsupported hash table type:" << *key_type << " and:" << *value_type;
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid platform or cuda version for gpu hash table.";
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -247,6 +244,31 @@ void GPUDeviceAddress::ClearDeviceMemory() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GPUDeviceAddress::ClearUserData() {
|
||||||
|
if (user_data_ == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto user_data_type = user_data_->get<UserDataType>(kUserDataType);
|
||||||
|
MS_EXCEPTION_IF_NULL(user_data_type);
|
||||||
|
if (*user_data_type == UserDataType::kUserTypeHashTable) {
|
||||||
|
#if CUDA_VERSION > 11000 && defined(__linux__)
|
||||||
|
auto key_type = user_data_->get<TypeId>(kHashTableKeyType);
|
||||||
|
auto value_type = user_data_->get<TypeId>(kHashTableValueType);
|
||||||
|
MS_EXCEPTION_IF_NULL(key_type);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_type);
|
||||||
|
const auto &iter = hashtable_func_list.find({*key_type, *value_type});
|
||||||
|
if (iter != hashtable_func_list.end()) {
|
||||||
|
return std::get<kClearFuncIndex>(iter->second)(user_data_);
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "Unsupported hash table type:" << *key_type << " and:" << *value_type;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid platform or cuda version for gpu hash table.";
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
GPUDeviceAddress::~GPUDeviceAddress() { ClearDeviceMemory(); }
|
GPUDeviceAddress::~GPUDeviceAddress() { ClearDeviceMemory(); }
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
|
@ -66,6 +66,7 @@ class GPUDeviceAddress : public LoadableDeviceAddress {
|
||||||
|
|
||||||
// Asynchronously copy device memory to host side.
|
// Asynchronously copy device memory to host side.
|
||||||
bool AsyncDeviceToHost(const ShapeVector &, size_t size, TypeId, void *host_ptr, size_t stream_id) const override;
|
bool AsyncDeviceToHost(const ShapeVector &, size_t size, TypeId, void *host_ptr, size_t stream_id) const override;
|
||||||
|
void ClearUserData() override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DeviceAddressStatus status_{DeviceAddressStatus::kInDevice};
|
DeviceAddressStatus status_{DeviceAddressStatus::kInDevice};
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_
|
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_
|
||||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_
|
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_
|
||||||
|
|
||||||
|
#if defined(__linux__)
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#if CUDA_VERSION > 11000
|
#if CUDA_VERSION > 11000
|
||||||
#include <curand_kernel.h>
|
#include <curand_kernel.h>
|
||||||
|
@ -190,4 +191,5 @@ class GPUHashTable : public HashTable<Key, Value> {
|
||||||
} // namespace device
|
} // namespace device
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_
|
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_
|
||||||
|
|
|
@ -0,0 +1,97 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_UTIL_H_
|
||||||
|
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_UTIL_H_
|
||||||
|
|
||||||
|
#include "plugin/device/gpu/hal/device/gpu_hash_table.h"
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
|
||||||
|
#if CUDA_VERSION > 11000 && defined(__linux__)
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace device {
|
||||||
|
namespace gpu {
|
||||||
|
// Callable type used for the "set" entry of the hash table function table: takes the user data.
using SetHashTableFunc = std::function<void(const UserDataPtr &)>;
|
||||||
|
// Callable type used for the "sync" entry: takes the user data, a host pointer and a size,
// and returns a bool result.
using SyncHashTableFunc = std::function<bool(const UserDataPtr &, const void *, size_t)>;
|
||||||
|
// Callable type used for the "clear" entry: takes the user data (same signature as SetHashTableFunc).
using ClearHashTableFunc = std::function<void(const UserDataPtr &)>;
|
||||||
|
|
||||||
|
template <typename KeyType, typename ValueType>
|
||||||
|
void SetHashTable(const UserDataPtr &user_data) {
|
||||||
|
MS_EXCEPTION_IF_NULL(user_data);
|
||||||
|
auto shape_vector = user_data->get<ShapeVector>(kHashTableShapeVector);
|
||||||
|
auto default_value = user_data->get<Value>(kHashTableDefaultValue);
|
||||||
|
MS_EXCEPTION_IF_NULL(shape_vector);
|
||||||
|
MS_EXCEPTION_IF_NULL(default_value);
|
||||||
|
int32_t value_size = 1;
|
||||||
|
for (size_t i = 0; i < (*shape_vector).size(); ++i) {
|
||||||
|
value_size *= (*shape_vector)[i];
|
||||||
|
}
|
||||||
|
if (value_size <= 0) {
|
||||||
|
MS_LOG(WARNING) << "Invalid value size:" << value_size;
|
||||||
|
}
|
||||||
|
if (default_value->isa<StringImm>()) {
|
||||||
|
user_data->set<GPUHashTable<KeyType, ValueType>>(
|
||||||
|
kUserDataData,
|
||||||
|
std::make_shared<GPUHashTable<KeyType, ValueType>>(value_size, GetValue<std::string>(default_value)));
|
||||||
|
} else if (default_value->isa<FloatImm>()) {
|
||||||
|
user_data->set<GPUHashTable<KeyType, ValueType>>(
|
||||||
|
kUserDataData, std::make_shared<GPUHashTable<KeyType, float>>(value_size, GetValue<float>(default_value)));
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid default value:" << default_value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename KeyType, typename ValueType>
|
||||||
|
bool SyncHashTable(const UserDataPtr &user_data, const void *host_ptr, size_t size) {
|
||||||
|
MS_EXCEPTION_IF_NULL(user_data);
|
||||||
|
MS_EXCEPTION_IF_NULL(host_ptr);
|
||||||
|
const auto &gpu_hash_table = user_data->get<GPUHashTable<KeyType, ValueType>>(kUserDataData);
|
||||||
|
MS_EXCEPTION_IF_NULL(gpu_hash_table);
|
||||||
|
if (!gpu_hash_table->Import({const_cast<void *>(host_ptr), size})) {
|
||||||
|
MS_LOG(ERROR) << "Import for hash table failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename KeyType, typename ValueType>
|
||||||
|
void ClearHashTable(const UserDataPtr &user_data) {
|
||||||
|
MS_EXCEPTION_IF_NULL(user_data);
|
||||||
|
const auto &user_data_data = user_data->get<GPUHashTable<KeyType, ValueType>>(kUserDataData);
|
||||||
|
MS_EXCEPTION_IF_NULL(user_data_data);
|
||||||
|
if (!user_data_data->Clear()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Clear user data failed.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::map<std::pair<TypeId, TypeId>, std::tuple<SetHashTableFunc, SyncHashTableFunc, ClearHashTableFunc>>
|
||||||
|
hashtable_func_list = {
|
||||||
|
{std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32),
|
||||||
|
std::make_tuple(SetHashTable<int, float>, SyncHashTable<int, float>, ClearHashTable<int, float>)},
|
||||||
|
{std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeFloat32),
|
||||||
|
std::make_tuple(SetHashTable<int64_t, float>, SyncHashTable<int64_t, float>, ClearHashTable<int64_t, float>)}};
|
||||||
|
|
||||||
|
constexpr size_t kSetFuncIndex = 0;
|
||||||
|
constexpr size_t kSyncFuncIndex = 1;
|
||||||
|
constexpr size_t kClearFuncIndex = 2;
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace device
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif
|
||||||
|
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_UTIL_H_
|
|
@ -40,7 +40,7 @@
|
||||||
#include "backend/common/session/kernel_graph.h"
|
#include "backend/common/session/kernel_graph.h"
|
||||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||||
#include "plugin/device/gpu/hal/device/gpu_hash_table.h"
|
#include "plugin/device/gpu/hal/device/gpu_hash_table_util.h"
|
||||||
#include "plugin/device/gpu/optimizer/reg_gpu_const_input_to_attr.h"
|
#include "plugin/device/gpu/optimizer/reg_gpu_const_input_to_attr.h"
|
||||||
#include "backend/common/optimizer/common_backend_optimization.h"
|
#include "backend/common/optimizer/common_backend_optimization.h"
|
||||||
#include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
|
#include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
|
||||||
|
@ -262,38 +262,22 @@ void SetUserData(DeviceAddress *device_address, const UserDataPtr &user_data) {
|
||||||
const auto &user_data_type = user_data->get<UserDataType>(kUserDataType);
|
const auto &user_data_type = user_data->get<UserDataType>(kUserDataType);
|
||||||
MS_EXCEPTION_IF_NULL(user_data_type);
|
MS_EXCEPTION_IF_NULL(user_data_type);
|
||||||
if (*user_data_type == UserDataType::kUserTypeHashTable) {
|
if (*user_data_type == UserDataType::kUserTypeHashTable) {
|
||||||
auto shape_vector = user_data->get<ShapeVector>(kHashTableShapeVector);
|
#if CUDA_VERSION > 11000 && defined(__linux__)
|
||||||
auto key_type = user_data->get<TypeId>(kHashTableKeyType);
|
auto key_type = user_data->get<TypeId>(kHashTableKeyType);
|
||||||
auto value_type = user_data->get<TypeId>(kHashTableValueType);
|
auto value_type = user_data->get<TypeId>(kHashTableValueType);
|
||||||
auto default_value = user_data->get<Value>(kHashTableDefaultValue);
|
|
||||||
MS_EXCEPTION_IF_NULL(shape_vector);
|
|
||||||
MS_EXCEPTION_IF_NULL(key_type);
|
MS_EXCEPTION_IF_NULL(key_type);
|
||||||
MS_EXCEPTION_IF_NULL(value_type);
|
MS_EXCEPTION_IF_NULL(value_type);
|
||||||
MS_EXCEPTION_IF_NULL(default_value);
|
const auto &iter = hashtable_func_list.find({*key_type, *value_type});
|
||||||
if (*key_type == TypeId::kNumberTypeInt32 && *value_type == TypeId::kNumberTypeFloat32) {
|
if (iter != hashtable_func_list.end()) {
|
||||||
int32_t value_size = 1;
|
return std::get<kSetFuncIndex>(iter->second)(user_data);
|
||||||
for (size_t i = 0; i < (*shape_vector).size(); ++i) {
|
|
||||||
value_size *= (*shape_vector)[i];
|
|
||||||
}
|
|
||||||
if (value_size <= 0) {
|
|
||||||
MS_LOG(WARNING) << "Invalid value size:" << value_size;
|
|
||||||
}
|
|
||||||
#if CUDA_VERSION > 11000
|
|
||||||
if (default_value->isa<StringImm>()) {
|
|
||||||
user_data->set<GPUHashTable<int, float>>(
|
|
||||||
kUserDataData, std::make_shared<GPUHashTable<int, float>>(value_size, GetValue<std::string>(default_value)));
|
|
||||||
} else if (default_value->isa<FloatImm>()) {
|
|
||||||
user_data->set<GPUHashTable<int, float>>(
|
|
||||||
kUserDataData, std::make_shared<GPUHashTable<int, float>>(value_size, GetValue<float>(default_value)));
|
|
||||||
} else {
|
|
||||||
MS_LOG(EXCEPTION) << "Invalid default value:" << default_value;
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
MS_LOG(EXCEPTION) << "Invalid cuda version for gpu hash table.";
|
|
||||||
#endif
|
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "Unsupported key type:" << key_type << " value type:" << value_type;
|
MS_LOG(EXCEPTION) << "Unsupported hash table type:" << *key_type << " and:" << *value_type;
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid platform or cuda version for gpu hash table.";
|
||||||
|
#endif
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid user data type:" << *user_data_type;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -200,27 +200,7 @@ class DeviceAddress : public mindspore::DeviceSync {
|
||||||
UserDataPtr user_data() const { return user_data_; }
|
UserDataPtr user_data() const { return user_data_; }
|
||||||
void set_user_data(const UserDataPtr &user_data) { user_data_ = user_data; }
|
void set_user_data(const UserDataPtr &user_data) { user_data_ = user_data; }
|
||||||
// Free the ptr in user data when the ref count is 0.
|
// Free the ptr in user data when the ref count is 0.
|
||||||
virtual void ClearUserData() {
|
virtual void ClearUserData() {}
|
||||||
if (user_data_ == nullptr) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto user_data_type = user_data_->get<UserDataType>(kUserDataType);
|
|
||||||
MS_EXCEPTION_IF_NULL(user_data_type);
|
|
||||||
if (*user_data_type == UserDataType::kUserTypeHashTable) {
|
|
||||||
const auto &key_type = user_data_->get<TypeId>(kHashTableKeyType);
|
|
||||||
const auto &value_type = user_data_->get<TypeId>(kHashTableValueType);
|
|
||||||
MS_EXCEPTION_IF_NULL(key_type);
|
|
||||||
MS_EXCEPTION_IF_NULL(value_type);
|
|
||||||
if (*key_type == TypeId::kNumberTypeInt32 && *value_type == TypeId::kNumberTypeFloat32) {
|
|
||||||
const auto &user_data_data = user_data_->get<HashTable<int, float>>(kUserDataData);
|
|
||||||
MS_EXCEPTION_IF_NULL(user_data_data);
|
|
||||||
if (!user_data_data->Clear()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Clear user data failed.";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
const void *ptr() const { return ptr_; }
|
const void *ptr() const { return ptr_; }
|
||||||
|
|
Loading…
Reference in New Issue