!44737 Add util interface for gpu hash table.
Merge pull request !44737 from gaoyong10/dynamic_shape_04
This commit is contained in:
commit
288bd3bcc1
|
@ -23,7 +23,7 @@
|
|||
#include "plugin/device/gpu/hal/device/gpu_device_manager.h"
|
||||
#include "plugin/device/gpu/hal/device/gpu_memory_allocator.h"
|
||||
#include "plugin/device/gpu/hal/hardware/gpu_device_context.h"
|
||||
#include "plugin/device/gpu/hal/device/gpu_hash_table.h"
|
||||
#include "plugin/device/gpu/hal/device/gpu_hash_table_util.h"
|
||||
#include "plugin/device/gpu/hal/device/gpu_common.h"
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
#include "debug/debug_services.h"
|
||||
|
@ -128,23 +128,20 @@ bool SyncUserDataToDevice(const UserDataPtr &user_data, const void *host_ptr, si
|
|||
MS_EXCEPTION_IF_NULL(user_data_type);
|
||||
|
||||
if (*user_data_type == UserDataType::kUserTypeHashTable) {
|
||||
#if CUDA_VERSION > 11000 && defined(__linux__)
|
||||
auto key_type = user_data->get<TypeId>(kHashTableKeyType);
|
||||
auto value_type = user_data->get<TypeId>(kHashTableValueType);
|
||||
MS_EXCEPTION_IF_NULL(key_type);
|
||||
MS_EXCEPTION_IF_NULL(value_type);
|
||||
if (*key_type == TypeId::kNumberTypeInt32 && *value_type == TypeId::kNumberTypeFloat32) {
|
||||
#if CUDA_VERSION > 11000
|
||||
const auto &gpu_hash_table = user_data->get<GPUHashTable<int, float>>(kUserDataData);
|
||||
MS_EXCEPTION_IF_NULL(gpu_hash_table);
|
||||
if (!gpu_hash_table->Import({const_cast<void *>(host_ptr), size})) {
|
||||
MS_LOG(EXCEPTION) << "Import for hash table failed.";
|
||||
}
|
||||
#else
|
||||
MS_LOG(EXCEPTION) << "Unsupported cuda version.";
|
||||
#endif
|
||||
const auto &iter = hashtable_func_list.find({*key_type, *value_type});
|
||||
if (iter != hashtable_func_list.end()) {
|
||||
return std::get<kSyncFuncIndex>(iter->second)(user_data, host_ptr, size);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported hash table type:" << *key_type << " and:" << *value_type;
|
||||
}
|
||||
#else
|
||||
MS_LOG(EXCEPTION) << "Invalid platform or cuda version for gpu hash table.";
|
||||
#endif
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -247,6 +244,31 @@ void GPUDeviceAddress::ClearDeviceMemory() {
|
|||
}
|
||||
}
|
||||
|
||||
void GPUDeviceAddress::ClearUserData() {
|
||||
if (user_data_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto user_data_type = user_data_->get<UserDataType>(kUserDataType);
|
||||
MS_EXCEPTION_IF_NULL(user_data_type);
|
||||
if (*user_data_type == UserDataType::kUserTypeHashTable) {
|
||||
#if CUDA_VERSION > 11000 && defined(__linux__)
|
||||
auto key_type = user_data_->get<TypeId>(kHashTableKeyType);
|
||||
auto value_type = user_data_->get<TypeId>(kHashTableValueType);
|
||||
MS_EXCEPTION_IF_NULL(key_type);
|
||||
MS_EXCEPTION_IF_NULL(value_type);
|
||||
const auto &iter = hashtable_func_list.find({*key_type, *value_type});
|
||||
if (iter != hashtable_func_list.end()) {
|
||||
return std::get<kClearFuncIndex>(iter->second)(user_data_);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported hash table type:" << *key_type << " and:" << *value_type;
|
||||
}
|
||||
#else
|
||||
MS_LOG(EXCEPTION) << "Invalid platform or cuda version for gpu hash table.";
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
GPUDeviceAddress::~GPUDeviceAddress() { ClearDeviceMemory(); }
|
||||
|
||||
/*
|
||||
|
|
|
@ -66,6 +66,7 @@ class GPUDeviceAddress : public LoadableDeviceAddress {
|
|||
|
||||
// Asynchronously copy device memory to host side.
|
||||
bool AsyncDeviceToHost(const ShapeVector &, size_t size, TypeId, void *host_ptr, size_t stream_id) const override;
|
||||
void ClearUserData() override;
|
||||
|
||||
private:
|
||||
DeviceAddressStatus status_{DeviceAddressStatus::kInDevice};
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_
|
||||
|
||||
#if defined(__linux__)
|
||||
#include <cuda.h>
|
||||
#if CUDA_VERSION > 11000
|
||||
#include <curand_kernel.h>
|
||||
|
@ -190,4 +191,5 @@ class GPUHashTable : public HashTable<Key, Value> {
|
|||
} // namespace device
|
||||
} // namespace mindspore
|
||||
#endif
|
||||
#endif
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_UTIL_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_UTIL_H_
|
||||
|
||||
#include "plugin/device/gpu/hal/device/gpu_hash_table.h"
|
||||
#include <map>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#if CUDA_VERSION > 11000 && defined(__linux__)
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace gpu {
|
||||
using SetHashTableFunc = std::function<void(const UserDataPtr &)>;
|
||||
using SyncHashTableFunc = std::function<bool(const UserDataPtr &, const void *, size_t)>;
|
||||
using ClearHashTableFunc = std::function<void(const UserDataPtr &)>;
|
||||
|
||||
template <typename KeyType, typename ValueType>
|
||||
void SetHashTable(const UserDataPtr &user_data) {
|
||||
MS_EXCEPTION_IF_NULL(user_data);
|
||||
auto shape_vector = user_data->get<ShapeVector>(kHashTableShapeVector);
|
||||
auto default_value = user_data->get<Value>(kHashTableDefaultValue);
|
||||
MS_EXCEPTION_IF_NULL(shape_vector);
|
||||
MS_EXCEPTION_IF_NULL(default_value);
|
||||
int32_t value_size = 1;
|
||||
for (size_t i = 0; i < (*shape_vector).size(); ++i) {
|
||||
value_size *= (*shape_vector)[i];
|
||||
}
|
||||
if (value_size <= 0) {
|
||||
MS_LOG(WARNING) << "Invalid value size:" << value_size;
|
||||
}
|
||||
if (default_value->isa<StringImm>()) {
|
||||
user_data->set<GPUHashTable<KeyType, ValueType>>(
|
||||
kUserDataData,
|
||||
std::make_shared<GPUHashTable<KeyType, ValueType>>(value_size, GetValue<std::string>(default_value)));
|
||||
} else if (default_value->isa<FloatImm>()) {
|
||||
user_data->set<GPUHashTable<KeyType, ValueType>>(
|
||||
kUserDataData, std::make_shared<GPUHashTable<KeyType, float>>(value_size, GetValue<float>(default_value)));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid default value:" << default_value;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename KeyType, typename ValueType>
|
||||
bool SyncHashTable(const UserDataPtr &user_data, const void *host_ptr, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(user_data);
|
||||
MS_EXCEPTION_IF_NULL(host_ptr);
|
||||
const auto &gpu_hash_table = user_data->get<GPUHashTable<KeyType, ValueType>>(kUserDataData);
|
||||
MS_EXCEPTION_IF_NULL(gpu_hash_table);
|
||||
if (!gpu_hash_table->Import({const_cast<void *>(host_ptr), size})) {
|
||||
MS_LOG(ERROR) << "Import for hash table failed.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename KeyType, typename ValueType>
|
||||
void ClearHashTable(const UserDataPtr &user_data) {
|
||||
MS_EXCEPTION_IF_NULL(user_data);
|
||||
const auto &user_data_data = user_data->get<GPUHashTable<KeyType, ValueType>>(kUserDataData);
|
||||
MS_EXCEPTION_IF_NULL(user_data_data);
|
||||
if (!user_data_data->Clear()) {
|
||||
MS_LOG(EXCEPTION) << "Clear user data failed.";
|
||||
}
|
||||
}
|
||||
|
||||
static std::map<std::pair<TypeId, TypeId>, std::tuple<SetHashTableFunc, SyncHashTableFunc, ClearHashTableFunc>>
|
||||
hashtable_func_list = {
|
||||
{std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32),
|
||||
std::make_tuple(SetHashTable<int, float>, SyncHashTable<int, float>, ClearHashTable<int, float>)},
|
||||
{std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeFloat32),
|
||||
std::make_tuple(SetHashTable<int64_t, float>, SyncHashTable<int64_t, float>, ClearHashTable<int64_t, float>)}};
|
||||
|
||||
constexpr size_t kSetFuncIndex = 0;
|
||||
constexpr size_t kSyncFuncIndex = 1;
|
||||
constexpr size_t kClearFuncIndex = 2;
|
||||
} // namespace gpu
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
#endif
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_UTIL_H_
|
|
@ -40,7 +40,7 @@
|
|||
#include "backend/common/session/kernel_graph.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/hal/device/gpu_hash_table.h"
|
||||
#include "plugin/device/gpu/hal/device/gpu_hash_table_util.h"
|
||||
#include "plugin/device/gpu/optimizer/reg_gpu_const_input_to_attr.h"
|
||||
#include "backend/common/optimizer/common_backend_optimization.h"
|
||||
#include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
|
||||
|
@ -262,38 +262,22 @@ void SetUserData(DeviceAddress *device_address, const UserDataPtr &user_data) {
|
|||
const auto &user_data_type = user_data->get<UserDataType>(kUserDataType);
|
||||
MS_EXCEPTION_IF_NULL(user_data_type);
|
||||
if (*user_data_type == UserDataType::kUserTypeHashTable) {
|
||||
auto shape_vector = user_data->get<ShapeVector>(kHashTableShapeVector);
|
||||
#if CUDA_VERSION > 11000 && defined(__linux__)
|
||||
auto key_type = user_data->get<TypeId>(kHashTableKeyType);
|
||||
auto value_type = user_data->get<TypeId>(kHashTableValueType);
|
||||
auto default_value = user_data->get<Value>(kHashTableDefaultValue);
|
||||
MS_EXCEPTION_IF_NULL(shape_vector);
|
||||
MS_EXCEPTION_IF_NULL(key_type);
|
||||
MS_EXCEPTION_IF_NULL(value_type);
|
||||
MS_EXCEPTION_IF_NULL(default_value);
|
||||
if (*key_type == TypeId::kNumberTypeInt32 && *value_type == TypeId::kNumberTypeFloat32) {
|
||||
int32_t value_size = 1;
|
||||
for (size_t i = 0; i < (*shape_vector).size(); ++i) {
|
||||
value_size *= (*shape_vector)[i];
|
||||
}
|
||||
if (value_size <= 0) {
|
||||
MS_LOG(WARNING) << "Invalid value size:" << value_size;
|
||||
}
|
||||
#if CUDA_VERSION > 11000
|
||||
if (default_value->isa<StringImm>()) {
|
||||
user_data->set<GPUHashTable<int, float>>(
|
||||
kUserDataData, std::make_shared<GPUHashTable<int, float>>(value_size, GetValue<std::string>(default_value)));
|
||||
} else if (default_value->isa<FloatImm>()) {
|
||||
user_data->set<GPUHashTable<int, float>>(
|
||||
kUserDataData, std::make_shared<GPUHashTable<int, float>>(value_size, GetValue<float>(default_value)));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid default value:" << default_value;
|
||||
}
|
||||
#else
|
||||
MS_LOG(EXCEPTION) << "Invalid cuda version for gpu hash table.";
|
||||
#endif
|
||||
const auto &iter = hashtable_func_list.find({*key_type, *value_type});
|
||||
if (iter != hashtable_func_list.end()) {
|
||||
return std::get<kSetFuncIndex>(iter->second)(user_data);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported key type:" << key_type << " value type:" << value_type;
|
||||
MS_LOG(EXCEPTION) << "Unsupported hash table type:" << *key_type << " and:" << *value_type;
|
||||
}
|
||||
#else
|
||||
MS_LOG(EXCEPTION) << "Invalid platform or cuda version for gpu hash table.";
|
||||
#endif
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid user data type:" << *user_data_type;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -200,27 +200,7 @@ class DeviceAddress : public mindspore::DeviceSync {
|
|||
UserDataPtr user_data() const { return user_data_; }
|
||||
void set_user_data(const UserDataPtr &user_data) { user_data_ = user_data; }
|
||||
// Free the ptr in user data when the ref count is 0.
|
||||
virtual void ClearUserData() {
|
||||
if (user_data_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto user_data_type = user_data_->get<UserDataType>(kUserDataType);
|
||||
MS_EXCEPTION_IF_NULL(user_data_type);
|
||||
if (*user_data_type == UserDataType::kUserTypeHashTable) {
|
||||
const auto &key_type = user_data_->get<TypeId>(kHashTableKeyType);
|
||||
const auto &value_type = user_data_->get<TypeId>(kHashTableValueType);
|
||||
MS_EXCEPTION_IF_NULL(key_type);
|
||||
MS_EXCEPTION_IF_NULL(value_type);
|
||||
if (*key_type == TypeId::kNumberTypeInt32 && *value_type == TypeId::kNumberTypeFloat32) {
|
||||
const auto &user_data_data = user_data_->get<HashTable<int, float>>(kUserDataData);
|
||||
MS_EXCEPTION_IF_NULL(user_data_data);
|
||||
if (!user_data_data->Clear()) {
|
||||
MS_LOG(EXCEPTION) << "Clear user data failed.";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
virtual void ClearUserData() {}
|
||||
|
||||
protected:
|
||||
const void *ptr() const { return ptr_; }
|
||||
|
|
Loading…
Reference in New Issue