runtime support cpu hash table

This commit is contained in:
lizhenyu 2023-01-28 15:33:26 +08:00
parent 6d49821c11
commit b3f9e6f2e7
7 changed files with 196 additions and 1 deletions

View File

@ -268,6 +268,7 @@ mindspore/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc:mindspore
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:conv3d
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_avx512_mask_fp32.c:GemmRowxColMaskKernelFp32
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/crop_and_resize_cpu_kernel.cc:mindspore::kernel::CropAndResizeCpuKernelMod::LaunchKernel
mindspore/mindspore/ccsrc/plugin/device/cpu/hal/device/cpu_device_address.cc:mindspore::device::cpu::CPUDeviceAddress::SyncHostToDevice
# AICPU migration
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/mediangrad.cc:aicpu::MedianGradCpuKernel::MedianGradCompute

View File

@ -18,6 +18,7 @@
#include <memory>
#include "runtime/device/convert_tensor_utils.h"
#include "plugin/device/cpu/hal/hardware/cpu_memory_pool.h"
#include "plugin/device/cpu/hal/device/cpu_hash_table_util.h"
#ifndef ENABLE_SECURITY
#include "debug/data_dump/dump_json_parser.h"
#endif
@ -45,6 +46,30 @@ bool CopySameTypeMem(void *dst_ptr, size_t dst_size, const void *src_ptr, size_t
return true;
}
}
// Synchronize user data from host to device.
bool SyncUserDataToDevice(const UserDataPtr &user_data, const void *host_ptr, size_t size) {
MS_EXCEPTION_IF_NULL(user_data);
MS_EXCEPTION_IF_NULL(host_ptr);
const auto &user_data_type = user_data->get<UserDataType>(kUserDataType);
MS_EXCEPTION_IF_NULL(user_data_type);
if (*user_data_type == UserDataType::kUserTypeHashTable) {
auto key_type = user_data->get<TypeId>(kHashTableKeyType);
auto value_type = user_data->get<TypeId>(kHashTableValueType);
MS_EXCEPTION_IF_NULL(key_type);
MS_EXCEPTION_IF_NULL(value_type);
const auto &iter = cpu_hash_table_funcs.find({*key_type, *value_type});
if (iter != cpu_hash_table_funcs.end()) {
// Import key, value, status tensors to CPU hash table.
return std::get<kImportFuncIndex>(iter->second)(user_data, host_ptr, size);
} else {
MS_LOG(EXCEPTION) << "Unsupported hash table type, key type:" << TypeIdLabel(*key_type)
<< ", value type:" << TypeIdLabel(*value_type);
}
}
return true;
}
} // namespace
CPUDeviceAddress::~CPUDeviceAddress() { DoClearDeviceMemory(); }
@ -60,6 +85,28 @@ void CPUDeviceAddress::DoClearDeviceMemory() {
void CPUDeviceAddress::ClearDeviceMemory() { DoClearDeviceMemory(); }
void CPUDeviceAddress::ClearUserData() {
if (user_data_ == nullptr) {
return;
}
auto user_data_type = user_data_->get<UserDataType>(kUserDataType);
MS_EXCEPTION_IF_NULL(user_data_type);
if (*user_data_type == UserDataType::kUserTypeHashTable) {
auto key_type = user_data_->get<TypeId>(kHashTableKeyType);
auto value_type = user_data_->get<TypeId>(kHashTableValueType);
MS_EXCEPTION_IF_NULL(key_type);
MS_EXCEPTION_IF_NULL(value_type);
const auto &iter = cpu_hash_table_funcs.find({*key_type, *value_type});
if (iter != cpu_hash_table_funcs.end()) {
// Clear CPU hash table.
return std::get<kClearFuncIndex>(iter->second)(user_data_);
} else {
MS_LOG(EXCEPTION) << "Unsupported hash table type:" << *key_type << " and:" << *value_type;
}
}
}
bool CPUDeviceAddress::DumpMemToFile(const std::string &filepath, const std::string &, const ShapeVector &host_shape,
TypeId host_type, bool) const {
bool ret = false;
@ -125,6 +172,10 @@ bool CPUDeviceAddress::SyncDeviceToHost(const ShapeVector &, size_t size, TypeId
bool CPUDeviceAddress::SyncHostToDevice(const ShapeVector &, size_t size, TypeId type, const void *host_ptr,
const std::string &) const {
if (user_data_ != nullptr) {
return SyncUserDataToDevice(user_data_, host_ptr, size);
}
// The input or output may be empty.
if ((size == 0) || (size_ == 0)) {
MS_LOG(INFO) << "No need sync, host size: " << size << ", device size: " << size_;

View File

@ -51,6 +51,8 @@ class BACKEND_EXPORT CPUDeviceAddress : public DeviceAddress {
bool DumpMemToFile(const std::string &filepath, const std::string &host_fmt, const ShapeVector &host_shape,
TypeId host_type, bool trans_flag) const override;
void ClearDeviceMemory() override;
void ClearUserData() override;
DeviceType GetDeviceType() const override { return DeviceType::kCPU; }
protected:

View File

@ -185,7 +185,7 @@ HashTableExportData CPUHashTable<Key, Value>::Export(bool) {
template <typename Key, typename Value>
size_t CPUHashTable<Key, Value>::capacity() const {
std::unique_lock<std::shared_mutex> lock(mutex_);
return values_.capacity();
return values_.size();
}
template <typename Key, typename Value>
@ -214,6 +214,9 @@ bool CPUHashTable<Key, Value>::Clear() {
}
return true;
}
template class CPUHashTable<int32_t, float>;
template class CPUHashTable<int64_t, float>;
} // namespace cpu
} // namespace device
} // namespace mindspore

View File

@ -0,0 +1,103 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_HAL_DEVICE_CPU_HASH_TABLE_UTIL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_HAL_DEVICE_CPU_HASH_TABLE_UTIL_H_
#include <map>
#include <tuple>
#include <utility>
#include <memory>
#include "plugin/device/cpu/hal/device/cpu_hash_table.h"
namespace mindspore {
namespace device {
namespace cpu {
using CreateHashTableFunc = std::function<void(const UserDataPtr &)>;
using ImportHashTableFunc = std::function<bool(const UserDataPtr &, const void *, size_t)>;
using ClearHashTableFunc = std::function<void(const UserDataPtr &)>;
constexpr size_t kCreateFuncIndex = 0;
constexpr size_t kImportFuncIndex = 1;
constexpr size_t kClearFuncIndex = 2;
/**
* @brief Create CPU hash table and set into `user_data`.
* @param[in] `user_data`: The input user data which contains meta information to create CPU hash table.
*/
template <typename KeyType, typename ValueType>
void CreateCPUHashTable(const UserDataPtr &user_data) {
MS_EXCEPTION_IF_NULL(user_data);
auto shape_vector = user_data->get<ShapeVector>(kHashTableShapeVector);
auto default_value = user_data->get<Value>(kHashTableDefaultValue);
MS_EXCEPTION_IF_NULL(shape_vector);
MS_EXCEPTION_IF_NULL(default_value);
int32_t value_size = 1;
for (size_t i = 0; i < (*shape_vector).size(); ++i) {
value_size *= (*shape_vector)[i];
}
if (value_size <= 0) {
MS_LOG(WARNING) << "Invalid value size:" << value_size;
}
user_data->set<CPUHashTable<KeyType, ValueType>>(kUserDataData,
std::make_shared<CPUHashTable<KeyType, ValueType>>(value_size));
}
/**
* @brief Import key, value, status tensors to CPU hash table.
* @param[in] `user_data`: The input user data which contains CPU hash table need to import.
* @param[in] `tensor_data`: The host pointer of tensor which need to be imported into CPU hash table.
* @param[in] `size`: The data length in bytes of tensor data which need to be imported into CPU hash table.
* @return Whether the function was successfully executed.
*/
template <typename KeyType, typename ValueType>
bool ImportCPUHashTable(const UserDataPtr &user_data, const void *tensor_data, size_t size) {
MS_EXCEPTION_IF_NULL(user_data);
MS_EXCEPTION_IF_NULL(tensor_data);
const auto &cpu_hash_table = user_data->get<CPUHashTable<KeyType, ValueType>>(kUserDataData);
MS_EXCEPTION_IF_NULL(cpu_hash_table);
if (!cpu_hash_table->Import({const_cast<void *>(tensor_data), size})) {
MS_LOG(ERROR) << "Import for hash table failed.";
return false;
}
return true;
}
/**
* @brief Clear all resource in CPU hash table and reset all statistics.
* @param[in] `user_data`: The input user data which contains CPU hash table need to clear.
*/
template <typename KeyType, typename ValueType>
void ClearCPUHashTable(const UserDataPtr &user_data) {
MS_EXCEPTION_IF_NULL(user_data);
const auto &user_data_data = user_data->get<CPUHashTable<KeyType, ValueType>>(kUserDataData);
MS_EXCEPTION_IF_NULL(user_data_data);
if (!user_data_data->Clear()) {
MS_LOG(EXCEPTION) << "Clear user data failed.";
}
}
static std::map<std::pair<TypeId, TypeId>, std::tuple<CreateHashTableFunc, ImportHashTableFunc, ClearHashTableFunc>>
cpu_hash_table_funcs = {
{std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32),
std::make_tuple(CreateCPUHashTable<int, float>, ImportCPUHashTable<int, float>, ClearCPUHashTable<int, float>)},
{std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeFloat32),
std::make_tuple(CreateCPUHashTable<int64_t, float>, ImportCPUHashTable<int64_t, float>,
ClearCPUHashTable<int64_t, float>)}};
} // namespace cpu
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_HAL_DEVICE_CPU_HASH_TABLE_UTIL_H_

View File

@ -22,6 +22,7 @@
#include "plugin/device/cpu/optimizer/reg_cpu_const_input_to_attr.h"
#include "plugin/device/cpu/optimizer/print_value_type.h"
#include "plugin/device/cpu/hal/hardware/cpu_somas.h"
#include "plugin/device/cpu/hal/device/cpu_hash_table_util.h"
#ifdef ENABLE_AKG
#include "plugin/device/cpu/kernel/akg/akg_cpu_kernel_build.h"
#endif
@ -120,6 +121,36 @@ std::vector<void *> CPUDeviceResManager::AllocateContinuousMemory(const std::vec
return mem_manager_->MallocContinuousMemFromMemPool(size_list);
}
namespace {
// Create user data content(such as CPU hash table) and set user data reference into device_address.
void FillUserData(const UserDataPtr &user_data, DeviceAddress *device_address) {
MS_EXCEPTION_IF_NULL(user_data);
MS_EXCEPTION_IF_NULL(device_address);
const auto &user_data_type = user_data->get<UserDataType>(kUserDataType);
MS_EXCEPTION_IF_NULL(user_data_type);
if (*user_data_type == UserDataType::kUserTypeHashTable) {
auto key_type = user_data->get<TypeId>(kHashTableKeyType);
auto value_type = user_data->get<TypeId>(kHashTableValueType);
MS_EXCEPTION_IF_NULL(key_type);
MS_EXCEPTION_IF_NULL(value_type);
const auto &iter = cpu_hash_table_funcs.find({*key_type, *value_type});
if (iter != cpu_hash_table_funcs.end()) {
// Create CPU hash table and set into `user_data`.
return std::get<kCreateFuncIndex>(iter->second)(user_data);
} else {
MS_LOG(EXCEPTION) << "Unsupported hash table type, key type:" << TypeIdLabel(*key_type)
<< ", value type:" << TypeIdLabel(*value_type);
}
} else {
MS_LOG(EXCEPTION) << "Invalid user data type:" << *user_data_type;
}
// Save reference of user data in device address.
device_address->set_user_data(user_data);
}
} // namespace
DeviceAddressPtr CPUDeviceResManager::CreateDeviceAddress(void *const device_ptr, size_t device_size,
const string &format, TypeId type_id,
const ShapeVector &shape,
@ -128,6 +159,9 @@ DeviceAddressPtr CPUDeviceResManager::CreateDeviceAddress(void *const device_ptr
device_context_->device_context_key().device_name_,
device_context_->device_context_key().device_id_);
device_address->set_host_shape(shape);
if (user_data != nullptr) {
FillUserData(user_data, device_address.get());
}
return device_address;
}

View File

@ -182,6 +182,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/plugin/device/cpu/hal/hardware/ms_collective_topo.cc"
"../../../mindspore/ccsrc/plugin/device/cpu/hal/hardware/cpu_memory_pool.cc"
"../../../mindspore/ccsrc/plugin/device/cpu/hal/device/cpu_device_address.cc"
"../../../mindspore/ccsrc/plugin/device/cpu/hal/device/cpu_hash_table.cc"
"../../../mindspore/ccsrc/plugin/device/cpu/optimizer/softmax_grad_fusion.cc"
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/cpu_kernel.cc"
"../../../mindspore/ccsrc/plugin/factory/ms_factory.h"