!44737 Add util interface for gpu hash table.

Merge pull request !44737 from gaoyong10/dynamic_shape_04
This commit is contained in:
i-robot 2022-11-03 22:32:13 +00:00 committed by Gitee
commit 288bd3bcc1
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 145 additions and 59 deletions

View File

@ -23,7 +23,7 @@
#include "plugin/device/gpu/hal/device/gpu_device_manager.h" #include "plugin/device/gpu/hal/device/gpu_device_manager.h"
#include "plugin/device/gpu/hal/device/gpu_memory_allocator.h" #include "plugin/device/gpu/hal/device/gpu_memory_allocator.h"
#include "plugin/device/gpu/hal/hardware/gpu_device_context.h" #include "plugin/device/gpu/hal/hardware/gpu_device_context.h"
#include "plugin/device/gpu/hal/device/gpu_hash_table.h" #include "plugin/device/gpu/hal/device/gpu_hash_table_util.h"
#include "plugin/device/gpu/hal/device/gpu_common.h" #include "plugin/device/gpu/hal/device/gpu_common.h"
#ifdef ENABLE_DEBUGGER #ifdef ENABLE_DEBUGGER
#include "debug/debug_services.h" #include "debug/debug_services.h"
@ -128,23 +128,20 @@ bool SyncUserDataToDevice(const UserDataPtr &user_data, const void *host_ptr, si
MS_EXCEPTION_IF_NULL(user_data_type); MS_EXCEPTION_IF_NULL(user_data_type);
if (*user_data_type == UserDataType::kUserTypeHashTable) { if (*user_data_type == UserDataType::kUserTypeHashTable) {
#if CUDA_VERSION > 11000 && defined(__linux__)
auto key_type = user_data->get<TypeId>(kHashTableKeyType); auto key_type = user_data->get<TypeId>(kHashTableKeyType);
auto value_type = user_data->get<TypeId>(kHashTableValueType); auto value_type = user_data->get<TypeId>(kHashTableValueType);
MS_EXCEPTION_IF_NULL(key_type); MS_EXCEPTION_IF_NULL(key_type);
MS_EXCEPTION_IF_NULL(value_type); MS_EXCEPTION_IF_NULL(value_type);
if (*key_type == TypeId::kNumberTypeInt32 && *value_type == TypeId::kNumberTypeFloat32) { const auto &iter = hashtable_func_list.find({*key_type, *value_type});
#if CUDA_VERSION > 11000 if (iter != hashtable_func_list.end()) {
const auto &gpu_hash_table = user_data->get<GPUHashTable<int, float>>(kUserDataData); return std::get<kSyncFuncIndex>(iter->second)(user_data, host_ptr, size);
MS_EXCEPTION_IF_NULL(gpu_hash_table);
if (!gpu_hash_table->Import({const_cast<void *>(host_ptr), size})) {
MS_LOG(EXCEPTION) << "Import for hash table failed.";
}
#else
MS_LOG(EXCEPTION) << "Unsupported cuda version.";
#endif
} else { } else {
MS_LOG(EXCEPTION) << "Unsupported hash table type:" << *key_type << " and:" << *value_type; MS_LOG(EXCEPTION) << "Unsupported hash table type:" << *key_type << " and:" << *value_type;
} }
#else
MS_LOG(EXCEPTION) << "Invalid platform or cuda version for gpu hash table.";
#endif
} }
return true; return true;
} }
@ -247,6 +244,31 @@ void GPUDeviceAddress::ClearDeviceMemory() {
} }
} }
void GPUDeviceAddress::ClearUserData() {
if (user_data_ == nullptr) {
return;
}
auto user_data_type = user_data_->get<UserDataType>(kUserDataType);
MS_EXCEPTION_IF_NULL(user_data_type);
if (*user_data_type == UserDataType::kUserTypeHashTable) {
#if CUDA_VERSION > 11000 && defined(__linux__)
auto key_type = user_data_->get<TypeId>(kHashTableKeyType);
auto value_type = user_data_->get<TypeId>(kHashTableValueType);
MS_EXCEPTION_IF_NULL(key_type);
MS_EXCEPTION_IF_NULL(value_type);
const auto &iter = hashtable_func_list.find({*key_type, *value_type});
if (iter != hashtable_func_list.end()) {
return std::get<kClearFuncIndex>(iter->second)(user_data_);
} else {
MS_LOG(EXCEPTION) << "Unsupported hash table type:" << *key_type << " and:" << *value_type;
}
#else
MS_LOG(EXCEPTION) << "Invalid platform or cuda version for gpu hash table.";
#endif
}
}
GPUDeviceAddress::~GPUDeviceAddress() { ClearDeviceMemory(); } GPUDeviceAddress::~GPUDeviceAddress() { ClearDeviceMemory(); }
/* /*

View File

@ -66,6 +66,7 @@ class GPUDeviceAddress : public LoadableDeviceAddress {
// Asynchronously copy device memory to host side. // Asynchronously copy device memory to host side.
bool AsyncDeviceToHost(const ShapeVector &, size_t size, TypeId, void *host_ptr, size_t stream_id) const override; bool AsyncDeviceToHost(const ShapeVector &, size_t size, TypeId, void *host_ptr, size_t stream_id) const override;
void ClearUserData() override;
private: private:
DeviceAddressStatus status_{DeviceAddressStatus::kInDevice}; DeviceAddressStatus status_{DeviceAddressStatus::kInDevice};

View File

@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_ #ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_ #define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_
#if defined(__linux__)
#include <cuda.h> #include <cuda.h>
#if CUDA_VERSION > 11000 #if CUDA_VERSION > 11000
#include <curand_kernel.h> #include <curand_kernel.h>
@ -190,4 +191,5 @@ class GPUHashTable : public HashTable<Key, Value> {
} // namespace device } // namespace device
} // namespace mindspore } // namespace mindspore
#endif #endif
#endif
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_ #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_

View File

@ -0,0 +1,97 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_UTIL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_UTIL_H_
#include "plugin/device/gpu/hal/device/gpu_hash_table.h"
#include <map>
#include <tuple>
#include <utility>
#include <string>
#include <memory>
#if CUDA_VERSION > 11000 && defined(__linux__)
namespace mindspore {
namespace device {
namespace gpu {
using SetHashTableFunc = std::function<void(const UserDataPtr &)>;
using SyncHashTableFunc = std::function<bool(const UserDataPtr &, const void *, size_t)>;
using ClearHashTableFunc = std::function<void(const UserDataPtr &)>;
template <typename KeyType, typename ValueType>
void SetHashTable(const UserDataPtr &user_data) {
MS_EXCEPTION_IF_NULL(user_data);
auto shape_vector = user_data->get<ShapeVector>(kHashTableShapeVector);
auto default_value = user_data->get<Value>(kHashTableDefaultValue);
MS_EXCEPTION_IF_NULL(shape_vector);
MS_EXCEPTION_IF_NULL(default_value);
int32_t value_size = 1;
for (size_t i = 0; i < (*shape_vector).size(); ++i) {
value_size *= (*shape_vector)[i];
}
if (value_size <= 0) {
MS_LOG(WARNING) << "Invalid value size:" << value_size;
}
if (default_value->isa<StringImm>()) {
user_data->set<GPUHashTable<KeyType, ValueType>>(
kUserDataData,
std::make_shared<GPUHashTable<KeyType, ValueType>>(value_size, GetValue<std::string>(default_value)));
} else if (default_value->isa<FloatImm>()) {
user_data->set<GPUHashTable<KeyType, ValueType>>(
kUserDataData, std::make_shared<GPUHashTable<KeyType, float>>(value_size, GetValue<float>(default_value)));
} else {
MS_LOG(EXCEPTION) << "Invalid default value:" << default_value;
}
}
template <typename KeyType, typename ValueType>
bool SyncHashTable(const UserDataPtr &user_data, const void *host_ptr, size_t size) {
MS_EXCEPTION_IF_NULL(user_data);
MS_EXCEPTION_IF_NULL(host_ptr);
const auto &gpu_hash_table = user_data->get<GPUHashTable<KeyType, ValueType>>(kUserDataData);
MS_EXCEPTION_IF_NULL(gpu_hash_table);
if (!gpu_hash_table->Import({const_cast<void *>(host_ptr), size})) {
MS_LOG(ERROR) << "Import for hash table failed.";
return false;
}
return true;
}
template <typename KeyType, typename ValueType>
void ClearHashTable(const UserDataPtr &user_data) {
MS_EXCEPTION_IF_NULL(user_data);
const auto &user_data_data = user_data->get<GPUHashTable<KeyType, ValueType>>(kUserDataData);
MS_EXCEPTION_IF_NULL(user_data_data);
if (!user_data_data->Clear()) {
MS_LOG(EXCEPTION) << "Clear user data failed.";
}
}
static std::map<std::pair<TypeId, TypeId>, std::tuple<SetHashTableFunc, SyncHashTableFunc, ClearHashTableFunc>>
hashtable_func_list = {
{std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32),
std::make_tuple(SetHashTable<int, float>, SyncHashTable<int, float>, ClearHashTable<int, float>)},
{std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeFloat32),
std::make_tuple(SetHashTable<int64_t, float>, SyncHashTable<int64_t, float>, ClearHashTable<int64_t, float>)}};
constexpr size_t kSetFuncIndex = 0;
constexpr size_t kSyncFuncIndex = 1;
constexpr size_t kClearFuncIndex = 2;
} // namespace gpu
} // namespace device
} // namespace mindspore
#endif
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_UTIL_H_

View File

@ -40,7 +40,7 @@
#include "backend/common/session/kernel_graph.h" #include "backend/common/session/kernel_graph.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h" #include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" #include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/hal/device/gpu_hash_table.h" #include "plugin/device/gpu/hal/device/gpu_hash_table_util.h"
#include "plugin/device/gpu/optimizer/reg_gpu_const_input_to_attr.h" #include "plugin/device/gpu/optimizer/reg_gpu_const_input_to_attr.h"
#include "backend/common/optimizer/common_backend_optimization.h" #include "backend/common/optimizer/common_backend_optimization.h"
#include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h" #include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
@ -262,38 +262,22 @@ void SetUserData(DeviceAddress *device_address, const UserDataPtr &user_data) {
const auto &user_data_type = user_data->get<UserDataType>(kUserDataType); const auto &user_data_type = user_data->get<UserDataType>(kUserDataType);
MS_EXCEPTION_IF_NULL(user_data_type); MS_EXCEPTION_IF_NULL(user_data_type);
if (*user_data_type == UserDataType::kUserTypeHashTable) { if (*user_data_type == UserDataType::kUserTypeHashTable) {
auto shape_vector = user_data->get<ShapeVector>(kHashTableShapeVector); #if CUDA_VERSION > 11000 && defined(__linux__)
auto key_type = user_data->get<TypeId>(kHashTableKeyType); auto key_type = user_data->get<TypeId>(kHashTableKeyType);
auto value_type = user_data->get<TypeId>(kHashTableValueType); auto value_type = user_data->get<TypeId>(kHashTableValueType);
auto default_value = user_data->get<Value>(kHashTableDefaultValue);
MS_EXCEPTION_IF_NULL(shape_vector);
MS_EXCEPTION_IF_NULL(key_type); MS_EXCEPTION_IF_NULL(key_type);
MS_EXCEPTION_IF_NULL(value_type); MS_EXCEPTION_IF_NULL(value_type);
MS_EXCEPTION_IF_NULL(default_value); const auto &iter = hashtable_func_list.find({*key_type, *value_type});
if (*key_type == TypeId::kNumberTypeInt32 && *value_type == TypeId::kNumberTypeFloat32) { if (iter != hashtable_func_list.end()) {
int32_t value_size = 1; return std::get<kSetFuncIndex>(iter->second)(user_data);
for (size_t i = 0; i < (*shape_vector).size(); ++i) {
value_size *= (*shape_vector)[i];
}
if (value_size <= 0) {
MS_LOG(WARNING) << "Invalid value size:" << value_size;
}
#if CUDA_VERSION > 11000
if (default_value->isa<StringImm>()) {
user_data->set<GPUHashTable<int, float>>(
kUserDataData, std::make_shared<GPUHashTable<int, float>>(value_size, GetValue<std::string>(default_value)));
} else if (default_value->isa<FloatImm>()) {
user_data->set<GPUHashTable<int, float>>(
kUserDataData, std::make_shared<GPUHashTable<int, float>>(value_size, GetValue<float>(default_value)));
} else { } else {
MS_LOG(EXCEPTION) << "Invalid default value:" << default_value; MS_LOG(EXCEPTION) << "Unsupported hash table type:" << *key_type << " and:" << *value_type;
} }
#else #else
MS_LOG(EXCEPTION) << "Invalid cuda version for gpu hash table."; MS_LOG(EXCEPTION) << "Invalid platform or cuda version for gpu hash table.";
#endif #endif
} else { } else {
MS_LOG(EXCEPTION) << "Unsupported key type:" << key_type << " value type:" << value_type; MS_LOG(EXCEPTION) << "Invalid user data type:" << *user_data_type;
}
} }
} }
} // namespace } // namespace

View File

@ -200,27 +200,7 @@ class DeviceAddress : public mindspore::DeviceSync {
UserDataPtr user_data() const { return user_data_; } UserDataPtr user_data() const { return user_data_; }
void set_user_data(const UserDataPtr &user_data) { user_data_ = user_data; } void set_user_data(const UserDataPtr &user_data) { user_data_ = user_data; }
// Free the ptr in user data when the ref count is 0. // Free the ptr in user data when the ref count is 0.
virtual void ClearUserData() { virtual void ClearUserData() {}
if (user_data_ == nullptr) {
return;
}
auto user_data_type = user_data_->get<UserDataType>(kUserDataType);
MS_EXCEPTION_IF_NULL(user_data_type);
if (*user_data_type == UserDataType::kUserTypeHashTable) {
const auto &key_type = user_data_->get<TypeId>(kHashTableKeyType);
const auto &value_type = user_data_->get<TypeId>(kHashTableValueType);
MS_EXCEPTION_IF_NULL(key_type);
MS_EXCEPTION_IF_NULL(value_type);
if (*key_type == TypeId::kNumberTypeInt32 && *value_type == TypeId::kNumberTypeFloat32) {
const auto &user_data_data = user_data_->get<HashTable<int, float>>(kUserDataData);
MS_EXCEPTION_IF_NULL(user_data_data);
if (!user_data_data->Clear()) {
MS_LOG(EXCEPTION) << "Clear user data failed.";
}
}
}
}
protected: protected:
const void *ptr() const { return ptr_; } const void *ptr() const { return ptr_; }