runtime support cpu hash table
This commit is contained in:
parent
6d49821c11
commit
b3f9e6f2e7
|
@ -268,6 +268,7 @@ mindspore/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc:mindspore
|
|||
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:conv3d
|
||||
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_avx512_mask_fp32.c:GemmRowxColMaskKernelFp32
|
||||
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/crop_and_resize_cpu_kernel.cc:mindspore::kernel::CropAndResizeCpuKernelMod::LaunchKernel
|
||||
mindspore/mindspore/ccsrc/plugin/device/cpu/hal/device/cpu_device_address.cc:mindspore::device::cpu::CPUDeviceAddress::SyncHostToDevice
|
||||
|
||||
# AICPU migration
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/mediangrad.cc:aicpu::MedianGradCpuKernel::MedianGradCompute
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <memory>
|
||||
#include "runtime/device/convert_tensor_utils.h"
|
||||
#include "plugin/device/cpu/hal/hardware/cpu_memory_pool.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_hash_table_util.h"
|
||||
#ifndef ENABLE_SECURITY
|
||||
#include "debug/data_dump/dump_json_parser.h"
|
||||
#endif
|
||||
|
@ -45,6 +46,30 @@ bool CopySameTypeMem(void *dst_ptr, size_t dst_size, const void *src_ptr, size_t
|
|||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Synchronize user data from host to device.
|
||||
bool SyncUserDataToDevice(const UserDataPtr &user_data, const void *host_ptr, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(user_data);
|
||||
MS_EXCEPTION_IF_NULL(host_ptr);
|
||||
const auto &user_data_type = user_data->get<UserDataType>(kUserDataType);
|
||||
MS_EXCEPTION_IF_NULL(user_data_type);
|
||||
|
||||
if (*user_data_type == UserDataType::kUserTypeHashTable) {
|
||||
auto key_type = user_data->get<TypeId>(kHashTableKeyType);
|
||||
auto value_type = user_data->get<TypeId>(kHashTableValueType);
|
||||
MS_EXCEPTION_IF_NULL(key_type);
|
||||
MS_EXCEPTION_IF_NULL(value_type);
|
||||
const auto &iter = cpu_hash_table_funcs.find({*key_type, *value_type});
|
||||
if (iter != cpu_hash_table_funcs.end()) {
|
||||
// Import key, value, status tensors to CPU hash table.
|
||||
return std::get<kImportFuncIndex>(iter->second)(user_data, host_ptr, size);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported hash table type, key type:" << TypeIdLabel(*key_type)
|
||||
<< ", value type:" << TypeIdLabel(*value_type);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
CPUDeviceAddress::~CPUDeviceAddress() { DoClearDeviceMemory(); }
|
||||
|
||||
|
@ -60,6 +85,28 @@ void CPUDeviceAddress::DoClearDeviceMemory() {
|
|||
|
||||
void CPUDeviceAddress::ClearDeviceMemory() { DoClearDeviceMemory(); }
|
||||
|
||||
void CPUDeviceAddress::ClearUserData() {
|
||||
if (user_data_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto user_data_type = user_data_->get<UserDataType>(kUserDataType);
|
||||
MS_EXCEPTION_IF_NULL(user_data_type);
|
||||
if (*user_data_type == UserDataType::kUserTypeHashTable) {
|
||||
auto key_type = user_data_->get<TypeId>(kHashTableKeyType);
|
||||
auto value_type = user_data_->get<TypeId>(kHashTableValueType);
|
||||
MS_EXCEPTION_IF_NULL(key_type);
|
||||
MS_EXCEPTION_IF_NULL(value_type);
|
||||
const auto &iter = cpu_hash_table_funcs.find({*key_type, *value_type});
|
||||
if (iter != cpu_hash_table_funcs.end()) {
|
||||
// Clear CPU hash table.
|
||||
return std::get<kClearFuncIndex>(iter->second)(user_data_);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported hash table type:" << *key_type << " and:" << *value_type;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool CPUDeviceAddress::DumpMemToFile(const std::string &filepath, const std::string &, const ShapeVector &host_shape,
|
||||
TypeId host_type, bool) const {
|
||||
bool ret = false;
|
||||
|
@ -125,6 +172,10 @@ bool CPUDeviceAddress::SyncDeviceToHost(const ShapeVector &, size_t size, TypeId
|
|||
|
||||
bool CPUDeviceAddress::SyncHostToDevice(const ShapeVector &, size_t size, TypeId type, const void *host_ptr,
|
||||
const std::string &) const {
|
||||
if (user_data_ != nullptr) {
|
||||
return SyncUserDataToDevice(user_data_, host_ptr, size);
|
||||
}
|
||||
|
||||
// The input or output may be empty.
|
||||
if ((size == 0) || (size_ == 0)) {
|
||||
MS_LOG(INFO) << "No need sync, host size: " << size << ", device size: " << size_;
|
||||
|
|
|
@ -51,6 +51,8 @@ class BACKEND_EXPORT CPUDeviceAddress : public DeviceAddress {
|
|||
bool DumpMemToFile(const std::string &filepath, const std::string &host_fmt, const ShapeVector &host_shape,
|
||||
TypeId host_type, bool trans_flag) const override;
|
||||
void ClearDeviceMemory() override;
|
||||
void ClearUserData() override;
|
||||
|
||||
DeviceType GetDeviceType() const override { return DeviceType::kCPU; }
|
||||
|
||||
protected:
|
||||
|
|
|
@ -185,7 +185,7 @@ HashTableExportData CPUHashTable<Key, Value>::Export(bool) {
|
|||
template <typename Key, typename Value>
|
||||
size_t CPUHashTable<Key, Value>::capacity() const {
|
||||
std::unique_lock<std::shared_mutex> lock(mutex_);
|
||||
return values_.capacity();
|
||||
return values_.size();
|
||||
}
|
||||
|
||||
template <typename Key, typename Value>
|
||||
|
@ -214,6 +214,9 @@ bool CPUHashTable<Key, Value>::Clear() {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template class CPUHashTable<int32_t, float>;
|
||||
template class CPUHashTable<int64_t, float>;
|
||||
} // namespace cpu
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_HAL_DEVICE_CPU_HASH_TABLE_UTIL_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_HAL_DEVICE_CPU_HASH_TABLE_UTIL_H_
|
||||
|
||||
#include <map>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include "plugin/device/cpu/hal/device/cpu_hash_table.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace cpu {
|
||||
using CreateHashTableFunc = std::function<void(const UserDataPtr &)>;
|
||||
using ImportHashTableFunc = std::function<bool(const UserDataPtr &, const void *, size_t)>;
|
||||
using ClearHashTableFunc = std::function<void(const UserDataPtr &)>;
|
||||
|
||||
constexpr size_t kCreateFuncIndex = 0;
|
||||
constexpr size_t kImportFuncIndex = 1;
|
||||
constexpr size_t kClearFuncIndex = 2;
|
||||
|
||||
/**
|
||||
* @brief Create CPU hash table and set into `user_data`.
|
||||
* @param[in] `user_data`: The input user data which contains meta information to create CPU hash table.
|
||||
*/
|
||||
template <typename KeyType, typename ValueType>
|
||||
void CreateCPUHashTable(const UserDataPtr &user_data) {
|
||||
MS_EXCEPTION_IF_NULL(user_data);
|
||||
auto shape_vector = user_data->get<ShapeVector>(kHashTableShapeVector);
|
||||
auto default_value = user_data->get<Value>(kHashTableDefaultValue);
|
||||
MS_EXCEPTION_IF_NULL(shape_vector);
|
||||
MS_EXCEPTION_IF_NULL(default_value);
|
||||
|
||||
int32_t value_size = 1;
|
||||
for (size_t i = 0; i < (*shape_vector).size(); ++i) {
|
||||
value_size *= (*shape_vector)[i];
|
||||
}
|
||||
if (value_size <= 0) {
|
||||
MS_LOG(WARNING) << "Invalid value size:" << value_size;
|
||||
}
|
||||
user_data->set<CPUHashTable<KeyType, ValueType>>(kUserDataData,
|
||||
std::make_shared<CPUHashTable<KeyType, ValueType>>(value_size));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Import key, value, status tensors to CPU hash table.
|
||||
* @param[in] `user_data`: The input user data which contains CPU hash table need to import.
|
||||
* @param[in] `tensor_data`: The host pointer of tensor which need to be imported into CPU hash table.
|
||||
* @param[in] `size`: The data length in bytes of tensor data which need to be imported into CPU hash table.
|
||||
* @return Whether the function was successfully executed.
|
||||
*/
|
||||
template <typename KeyType, typename ValueType>
|
||||
bool ImportCPUHashTable(const UserDataPtr &user_data, const void *tensor_data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(user_data);
|
||||
MS_EXCEPTION_IF_NULL(tensor_data);
|
||||
const auto &cpu_hash_table = user_data->get<CPUHashTable<KeyType, ValueType>>(kUserDataData);
|
||||
MS_EXCEPTION_IF_NULL(cpu_hash_table);
|
||||
if (!cpu_hash_table->Import({const_cast<void *>(tensor_data), size})) {
|
||||
MS_LOG(ERROR) << "Import for hash table failed.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Clear all resource in CPU hash table and reset all statistics.
|
||||
* @param[in] `user_data`: The input user data which contains CPU hash table need to clear.
|
||||
*/
|
||||
template <typename KeyType, typename ValueType>
|
||||
void ClearCPUHashTable(const UserDataPtr &user_data) {
|
||||
MS_EXCEPTION_IF_NULL(user_data);
|
||||
const auto &user_data_data = user_data->get<CPUHashTable<KeyType, ValueType>>(kUserDataData);
|
||||
MS_EXCEPTION_IF_NULL(user_data_data);
|
||||
if (!user_data_data->Clear()) {
|
||||
MS_LOG(EXCEPTION) << "Clear user data failed.";
|
||||
}
|
||||
}
|
||||
|
||||
static std::map<std::pair<TypeId, TypeId>, std::tuple<CreateHashTableFunc, ImportHashTableFunc, ClearHashTableFunc>>
|
||||
cpu_hash_table_funcs = {
|
||||
{std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32),
|
||||
std::make_tuple(CreateCPUHashTable<int, float>, ImportCPUHashTable<int, float>, ClearCPUHashTable<int, float>)},
|
||||
{std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeFloat32),
|
||||
std::make_tuple(CreateCPUHashTable<int64_t, float>, ImportCPUHashTable<int64_t, float>,
|
||||
ClearCPUHashTable<int64_t, float>)}};
|
||||
} // namespace cpu
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_HAL_DEVICE_CPU_HASH_TABLE_UTIL_H_
|
|
@ -22,6 +22,7 @@
|
|||
#include "plugin/device/cpu/optimizer/reg_cpu_const_input_to_attr.h"
|
||||
#include "plugin/device/cpu/optimizer/print_value_type.h"
|
||||
#include "plugin/device/cpu/hal/hardware/cpu_somas.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_hash_table_util.h"
|
||||
#ifdef ENABLE_AKG
|
||||
#include "plugin/device/cpu/kernel/akg/akg_cpu_kernel_build.h"
|
||||
#endif
|
||||
|
@ -120,6 +121,36 @@ std::vector<void *> CPUDeviceResManager::AllocateContinuousMemory(const std::vec
|
|||
return mem_manager_->MallocContinuousMemFromMemPool(size_list);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Create user data content(such as CPU hash table) and set user data reference into device_address.
|
||||
void FillUserData(const UserDataPtr &user_data, DeviceAddress *device_address) {
|
||||
MS_EXCEPTION_IF_NULL(user_data);
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
|
||||
const auto &user_data_type = user_data->get<UserDataType>(kUserDataType);
|
||||
MS_EXCEPTION_IF_NULL(user_data_type);
|
||||
if (*user_data_type == UserDataType::kUserTypeHashTable) {
|
||||
auto key_type = user_data->get<TypeId>(kHashTableKeyType);
|
||||
auto value_type = user_data->get<TypeId>(kHashTableValueType);
|
||||
MS_EXCEPTION_IF_NULL(key_type);
|
||||
MS_EXCEPTION_IF_NULL(value_type);
|
||||
const auto &iter = cpu_hash_table_funcs.find({*key_type, *value_type});
|
||||
if (iter != cpu_hash_table_funcs.end()) {
|
||||
// Create CPU hash table and set into `user_data`.
|
||||
return std::get<kCreateFuncIndex>(iter->second)(user_data);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported hash table type, key type:" << TypeIdLabel(*key_type)
|
||||
<< ", value type:" << TypeIdLabel(*value_type);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid user data type:" << *user_data_type;
|
||||
}
|
||||
|
||||
// Save reference of user data in device address.
|
||||
device_address->set_user_data(user_data);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
DeviceAddressPtr CPUDeviceResManager::CreateDeviceAddress(void *const device_ptr, size_t device_size,
|
||||
const string &format, TypeId type_id,
|
||||
const ShapeVector &shape,
|
||||
|
@ -128,6 +159,9 @@ DeviceAddressPtr CPUDeviceResManager::CreateDeviceAddress(void *const device_ptr
|
|||
device_context_->device_context_key().device_name_,
|
||||
device_context_->device_context_key().device_id_);
|
||||
device_address->set_host_shape(shape);
|
||||
if (user_data != nullptr) {
|
||||
FillUserData(user_data, device_address.get());
|
||||
}
|
||||
return device_address;
|
||||
}
|
||||
|
||||
|
|
|
@ -182,6 +182,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
"../../../mindspore/ccsrc/plugin/device/cpu/hal/hardware/ms_collective_topo.cc"
|
||||
"../../../mindspore/ccsrc/plugin/device/cpu/hal/hardware/cpu_memory_pool.cc"
|
||||
"../../../mindspore/ccsrc/plugin/device/cpu/hal/device/cpu_device_address.cc"
|
||||
"../../../mindspore/ccsrc/plugin/device/cpu/hal/device/cpu_hash_table.cc"
|
||||
"../../../mindspore/ccsrc/plugin/device/cpu/optimizer/softmax_grad_fusion.cc"
|
||||
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/cpu_kernel.cc"
|
||||
"../../../mindspore/ccsrc/plugin/factory/ms_factory.h"
|
||||
|
|
Loading…
Reference in New Issue