From a5f57ce8a0a9c0452b0eb865d1dc6390e5961143 Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Tue, 8 Dec 2020 17:29:22 +0800 Subject: [PATCH] add ps cache --- .../kernel_compiler/aicpu/aicpu_util.h | 1 + .../gpu/cuda_impl/hash_impl.cu | 64 +++++ .../gpu/cuda_impl/hash_impl.cuh | 27 ++ .../ccsrc/backend/kernel_compiler/kernel.h | 2 + mindspore/ccsrc/ps/CMakeLists.txt | 11 + mindspore/ccsrc/ps/ps_cache/CMakeLists.txt | 7 + .../ps/ps_cache/ascend/ascend_ps_cache.cc | 253 ++++++++++++++++++ .../ps/ps_cache/ascend/ascend_ps_cache.h | 73 +++++ .../ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc | 92 +++++++ .../ccsrc/ps/ps_cache/gpu/gpu_ps_cache.h | 49 ++++ mindspore/ccsrc/ps/ps_cache/ps_cache_basic.h | 46 ++++ .../ccsrc/ps/ps_cache/ps_cache_factory.cc | 42 +++ .../ccsrc/ps/ps_cache/ps_cache_factory.h | 57 ++++ .../ps/ps_cache/ps_data/ps_data_channel.cc | 52 ++++ .../ps/ps_cache/ps_data/ps_data_channel.h | 58 ++++ .../ccsrc/runtime/device/gpu/gpu_common.h | 18 ++ tests/ut/cpp/CMakeLists.txt | 2 + 17 files changed, 854 insertions(+) create mode 100755 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hash_impl.cu create mode 100755 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hash_impl.cuh create mode 100644 mindspore/ccsrc/ps/ps_cache/CMakeLists.txt create mode 100644 mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc create mode 100644 mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.h create mode 100644 mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc create mode 100644 mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.h create mode 100644 mindspore/ccsrc/ps/ps_cache/ps_cache_basic.h create mode 100644 mindspore/ccsrc/ps/ps_cache/ps_cache_factory.cc create mode 100644 mindspore/ccsrc/ps/ps_cache/ps_cache_factory.h create mode 100644 mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_channel.cc create mode 100644 mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_channel.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h index 2e707c18de6..7bd5974a17f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h @@ -46,6 +46,7 @@ constexpr auto kTopKV2 = "TopKV2"; constexpr auto kEditDistance = "EditDistance"; constexpr auto kGatherD = "GatherD"; constexpr auto kIdentity = "Identity"; +constexpr auto kUpdateCache = "UpdateCache"; constexpr auto kCustRunApi = "RunCpuKernel"; const std::set kCustAiCpuKernelOps{kEditDistance, kIdentity}; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hash_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hash_impl.cu new file mode 100755 index 00000000000..0a20616c58f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hash_impl.cu @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "hash_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T>
+__global__ void HashSwapOut(const T *hash_table, T *swap_out_value, const int *swap_out_index, const int index_size,
+                            const int hash_dim) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < index_size; i += blockDim.x * gridDim.x) {
+    int hash_index = swap_out_index[i];
+    for (int j = 0; j < hash_dim; j++) {
+      swap_out_value[i * hash_dim + j] = hash_table[hash_index * hash_dim + j];
+    }
+  }
+  return;
+}
+
+template <typename T>
+__global__ void HashSwapIn(T *hash_table, const T *swap_in_value, const int *swap_in_index, const int index_size,
+                           const int hash_dim) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < index_size; i += blockDim.x * gridDim.x) {
+    int hash_index = swap_in_index[i];
+    for (int j = 0; j < hash_dim; j++) {
+      hash_table[hash_index * hash_dim + j] = swap_in_value[i * hash_dim + j];
+    }
+  }
+  return;
+}
+
+template <typename T>
+void DoHashSwapOut(const T *hash_table, T *swap_out_value, const int *swap_out_index, const int index_size,
+                   const int hash_dim, cudaStream_t cuda_stream) {
+  HashSwapOut<<<GET_BLOCKS(index_size), GET_THREADS, 0, cuda_stream>>>(hash_table, swap_out_value, swap_out_index,
+                                                                       index_size, hash_dim);
+  return;
+}
+
+template <typename T>
+void DoHashSwapIn(T *hash_table, const T *swap_in_value, const int *swap_in_index, const int index_size,
+                  const int hash_dim, cudaStream_t cuda_stream) {
+  HashSwapIn<<<GET_BLOCKS(index_size), GET_THREADS, 0, cuda_stream>>>(hash_table, swap_in_value, swap_in_index,
+                                                                      index_size, hash_dim);
+  return;
+}
+
+template void DoHashSwapOut<float>(const float *hash_table, float *swap_out_value, const int *swap_out_index,
+                                   const int index_size, const int hash_dim, cudaStream_t cuda_stream);
+
+template void DoHashSwapIn<float>(float *hash_table, const float *swap_in_value, const int *swap_in_index,
+                                  const int index_size, const int hash_dim, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hash_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hash_impl.cuh
new file mode 100755
index 00000000000..e13d2dc124b
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hash_impl.cuh
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_KERNEL_HASH_IMPL_H_
+#define MINDSPORE_CCSRC_PS_PS_CACHE_KERNEL_HASH_IMPL_H_
+
+template <typename T>
+void DoHashSwapOut(const T *hash_table, T *swap_out_value, const int *swap_out_index, const int index_size,
+                   const int hash_dim, cudaStream_t cuda_stream);
+
+template <typename T>
+void DoHashSwapIn(T *hash_table, const T *swap_in_value, const int *swap_in_index, const int index_size,
+                  const int hash_dim, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_PS_PS_CACHE_KERNEL_HASH_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel.h b/mindspore/ccsrc/backend/kernel_compiler/kernel.h
index 5bfbe4cd3f6..64b2f697c1b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/kernel.h
@@ -116,6 +116,8 @@ using KernelPackPtr = std::shared_ptr<KernelPack>;
  * @brief base class for autotensor kernel and cce kernel.
  */
 struct Address {
+  Address() {}
+  Address(void *address_addr, size_t address_size) : addr(address_addr), size(address_size) {}
   void *addr;
   size_t size;
 };
diff --git a/mindspore/ccsrc/ps/CMakeLists.txt b/mindspore/ccsrc/ps/CMakeLists.txt
index 6c7e313c0a9..a2b292c2260 100644
--- a/mindspore/ccsrc/ps/CMakeLists.txt
+++ b/mindspore/ccsrc/ps/CMakeLists.txt
@@ -16,5 +16,16 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
     list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc")
 endif ()
 
+if (NOT ENABLE_D)
+    list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ascend/ascend_ps_cache.cc")
+endif()
+
+if (NOT ENABLE_GPU)
+    list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc")
+endif()
+
+list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc")
+add_subdirectory(ps_cache)
+
 set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
 add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES})
diff --git a/mindspore/ccsrc/ps/ps_cache/CMakeLists.txt b/mindspore/ccsrc/ps/ps_cache/CMakeLists.txt
new file mode 100644
index 00000000000..03d5aa136eb
--- /dev/null
+++ b/mindspore/ccsrc/ps/ps_cache/CMakeLists.txt
@@ -0,0 +1,7 @@
+if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
+    file(GLOB_RECURSE _PS_CACHE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps_data/*.cc")
+    set_property(SOURCE ${_PS_CACHE_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
+    add_library(ps_cache SHARED ${_PS_CACHE_SRC_FILES})
+endif()
+
+
diff --git a/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc b/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc
new file mode 100644
index 00000000000..917b999db1f
--- /dev/null
+++ b/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc
@@ -0,0 +1,253 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ps/ps_cache/ascend/ascend_ps_cache.h"
+#include
+#include
+#include
+#include
+#include "ps/ps_cache/ps_cache_factory.h"
+#include "runtime/device/ascend/ascend_memory_pool.h"
+#include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h"
+#include "utils/ms_context.h"
+#include "proto/tensor.pb.h"
+#include "proto/tensor_shape.pb.h"
+#include "proto/attr.pb.h"
+#include "proto/node_def.pb.h"
+
+using mindspore::kernel::Address;
+using AddressPtr = std::shared_ptr<Address>;
+using AddressPtrList = std::vector<AddressPtr>;
+
+namespace mindspore {
+namespace ps {
+namespace ascend {
+MS_REG_PS_CACHE(kAscendDevice, AscendPsCache);
+namespace {
+void SetProtoInputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type,
+                    mindspore::NodeDef *proto) {
+  MS_EXCEPTION_IF_NULL(proto);
+  if (data_shape.size() != data_type.size()) {
+    MS_LOG(EXCEPTION) << "The size of data shape is not equal to the size of data type.";
+  }
+  for (size_t input_index = 0; input_index < data_shape.size(); input_index++) {
+    ::mindspore::Tensor *proto_inputs = proto->add_inputs();
+    MS_EXCEPTION_IF_NULL(proto_inputs);
+    auto input_shape = data_shape[input_index];
+    mindspore::TensorShape *tensorShape = proto_inputs->mutable_tensor_shape();
+    MS_EXCEPTION_IF_NULL(tensorShape);
+    for (auto item : input_shape) {
+      mindspore::TensorShape_Dim *dim = tensorShape->add_dim();
+      MS_EXCEPTION_IF_NULL(dim);
+      dim->set_size((::google::protobuf::int64)item);
+    }
+    auto input_type = kernel::AicpuOpUtil::MsTypeToProtoType(data_type[input_index]);
+    proto_inputs->set_tensor_type(input_type);
+    proto_inputs->set_mem_device("HBM");
+  }
+}
+
+void SetProtoOutputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type,
+                     mindspore::NodeDef *proto) {
+  MS_EXCEPTION_IF_NULL(proto);
+  if (data_shape.size() != data_type.size()) {
+    MS_LOG(EXCEPTION) << "The size of data shape is not equal to the size of data type.";
+  }
+  for (size_t output_index = 0; output_index < data_shape.size(); output_index++) {
+    ::mindspore::Tensor *proto_outputs = proto->add_outputs();
+    MS_EXCEPTION_IF_NULL(proto_outputs);
+    auto output_shape = data_shape[output_index];
+    mindspore::TensorShape *tensorShape = proto_outputs->mutable_tensor_shape();
+    MS_EXCEPTION_IF_NULL(tensorShape);
+    for (auto item : output_shape) {
+      mindspore::TensorShape_Dim *dim = tensorShape->add_dim();
+      MS_EXCEPTION_IF_NULL(dim);
+      dim->set_size((::google::protobuf::int64)item);
+    }
+    auto output_type = kernel::AicpuOpUtil::MsTypeToProtoType(data_type[output_index]);
+    proto_outputs->set_tensor_type(output_type);
+    proto_outputs->set_mem_device("HBM");
+  }
+}
+
+void SetNodedefProto(const std::shared_ptr<KernelNodeInfo> &op_info,
+                     const std::shared_ptr<kernel::AicpuOpKernelMod> &kernel_mod_ptr) {
+  MS_EXCEPTION_IF_NULL(op_info);
+  MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
+  mindspore::NodeDef proto;
+  proto.set_op(op_info->op_name_);
+  SetProtoInputs(op_info->input_data_shape_, op_info->input_data_type_, &proto);
+  SetProtoOutputs(op_info->output_data_shape_, op_info->output_data_type_, &proto);
+  std::string nodeDefStr;
+  if (!proto.SerializeToString(&nodeDefStr)) {
+    MS_LOG(EXCEPTION) << "Serialize nodeDef to string failed.";
+  }
+  MS_LOG(DEBUG) << "Set node def proto, node name:" << op_info->op_name_;
+  kernel_mod_ptr->SetNodeDef(nodeDefStr);
+}
+}  // namespace
+
+void AscendPsCache::InitDevice(uint32_t device_id, const void *context) {
+  MS_EXCEPTION_IF_NULL(context);
+  auto ret = rtSetDevice(device_id);
+  if (ret != RT_ERROR_NONE) {
+    MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << ret << "]";
+  }
+  auto rt_context = const_cast<rtContext_t>(context);
+  ret = rtCtxSetCurrent(rt_context);
+  if (ret != RT_ERROR_NONE) {
+    MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]";
+  }
+  ret = rtStreamCreate(&stream_, 0);
+  if (ret != RT_ERROR_NONE) {
+    MS_EXCEPTION(DeviceProcessError) << "Call rtStreamCreate, ret[" << ret << "]";
+  }
+}
+
+void *AscendPsCache::MallocMemory(size_t size) {
+  return device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(size);
+}
+
+void AscendPsCache::MallocConstantMemory(size_t constant_value) {
+  offset_addr_ = reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
+  MS_EXCEPTION_IF_NULL(offset_addr_);
+  rtMemset(offset_addr_, sizeof(int), 0, sizeof(int));
+  cache_vocab_size_addr_ =
+    reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
+  MS_EXCEPTION_IF_NULL(cache_vocab_size_addr_);
+  rtMemset(cache_vocab_size_addr_, sizeof(int), constant_value, sizeof(int));
+}
+
+void AscendPsCache::RecordEvent() {
+  event_.reset(new rtEvent_t());
+  auto ret = rtEventCreate(&(*event_));
+  if (ret != RT_ERROR_NONE) {
+    MS_EXCEPTION(DeviceProcessError) << "Create event failed";
+  }
+  ret = rtEventRecord(*event_, stream_);
+  if (ret != RT_ERROR_NONE) {
+    MS_EXCEPTION(DeviceProcessError) << "Record event failed";
+  }
+}
+
+void AscendPsCache::SynchronizeEvent() {
+  auto ret = rtEventSynchronize(*event_);
+  if (ret != RT_ERROR_NONE) {
+    MS_EXCEPTION(DeviceProcessError) << "rtEventSynchronize failed";
+  }
+  ret = rtEventDestroy(*event_);
+  if (ret != RT_ERROR_NONE) {
+    MS_EXCEPTION(DeviceProcessError) << "rtEventDestroy failed";
+  }
+}
+
+void AscendPsCache::SynchronizeStream() {
+  auto ret = rtStreamSynchronize(stream_);
+  if (ret != RT_ERROR_NONE) {
+    MS_EXCEPTION(DeviceProcessError) << "rtStreamSynchronize failed";
+  }
+}
+
+void AscendPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) {
+  MS_EXCEPTION_IF_NULL(dst);
+  MS_EXCEPTION_IF_NULL(src);
+  auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_HOST_TO_DEVICE, stream_);
+  if (ret != RT_ERROR_NONE) {
+    MS_EXCEPTION(DeviceProcessError) << "rtMemcpyAsync failed";
+  }
+}
+
+void AscendPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) {
+  MS_EXCEPTION_IF_NULL(dst);
+  MS_EXCEPTION_IF_NULL(src);
+  auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_DEVICE_TO_HOST, stream_);
+  if (ret != RT_ERROR_NONE) {
+    MS_EXCEPTION(DeviceProcessError) << "rtMemcpyAsync failed";
+  }
+}
+
+void AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr,
+                                size_t hash_table_size, size_t embedding_size, size_t swap_out_size) {
+  MS_EXCEPTION_IF_NULL(hash_table_addr);
+  MS_EXCEPTION_IF_NULL(swap_out_value_addr);
+  MS_EXCEPTION_IF_NULL(swap_out_index_addr);
+  auto hash_swap_out_mod = std::make_shared<kernel::AicpuOpKernelMod>();
+  MS_EXCEPTION_IF_NULL(hash_swap_out_mod);
+  hash_swap_out_mod->SetNodeName(kEmbeddingLookupOpName);
+  std::vector<std::vector<size_t>> input_shape;
+  std::vector<std::vector<size_t>> output_shape;
+  std::vector<TypeId> input_type = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32};
+  std::vector<TypeId> output_type = {TypeId::kNumberTypeFloat32};
+  input_shape.push_back({hash_table_size, embedding_size});
+  input_shape.push_back({swap_out_size});
+  input_shape.push_back({1});
+  output_shape.push_back({swap_out_size, embedding_size});
+  auto op_info =
+    std::make_shared<KernelNodeInfo>(kEmbeddingLookupOpName, input_shape, input_type, output_shape, output_type);
+  SetNodedefProto(op_info, hash_swap_out_mod);
+
+  AddressPtrList kernel_inputs;
+  AddressPtrList kernel_outputs = {
+    std::make_shared<Address>(swap_out_value_addr, swap_out_size * embedding_size * sizeof(float))};
+  AddressPtrList kernel_workspaces;
+  kernel_inputs.push_back(std::make_shared<Address>(hash_table_addr, hash_table_size * embedding_size * sizeof(float)));
+  kernel_inputs.push_back(std::make_shared<Address>(swap_out_index_addr, swap_out_size * sizeof(int)));
+  kernel_inputs.push_back(std::make_shared<Address>(offset_addr_, sizeof(int)));
+  auto ret = hash_swap_out_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
+  if (!ret) {
+    MS_LOG(EXCEPTION) << "Hash swap out launch failed.";
+  }
+}
+
+void AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr,
+                               size_t hash_table_size, size_t embedding_size, size_t swap_in_size) {
+  MS_EXCEPTION_IF_NULL(hash_table_addr);
+  MS_EXCEPTION_IF_NULL(swap_in_value_addr);
+  MS_EXCEPTION_IF_NULL(swap_in_index_addr);
+  auto hash_swap_in_mod = std::make_shared<kernel::AicpuOpKernelMod>();
+  MS_EXCEPTION_IF_NULL(hash_swap_in_mod);
+  hash_swap_in_mod->SetNodeName(kernel::kUpdateCache);
+  std::vector<std::vector<size_t>> input_shape;
+  std::vector<std::vector<size_t>> output_shape;
+  std::vector<TypeId> input_type = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32,
+                                    TypeId::kNumberTypeInt32};
+  std::vector<TypeId> output_type = {TypeId::kNumberTypeInt32};
+  input_shape.push_back({hash_table_size, embedding_size});
+  input_shape.push_back({swap_in_size});
+  input_shape.push_back({swap_in_size, embedding_size});
+  input_shape.push_back({1});
+  output_shape.push_back({1});
+  auto op_info =
+    std::make_shared<KernelNodeInfo>(kernel::kUpdateCache, input_shape, input_type, output_shape, output_type);
+  SetNodedefProto(op_info, hash_swap_in_mod);
+
+  AddressPtrList kernel_inputs;
+  AddressPtrList kernel_outputs;
+  AddressPtrList kernel_workspaces;
+  kernel_inputs.push_back(std::make_shared<Address>(hash_table_addr, hash_table_size * embedding_size * sizeof(float)));
+  kernel_inputs.push_back(std::make_shared<Address>(swap_in_index_addr, swap_in_size * sizeof(int)));
+  kernel_inputs.push_back(std::make_shared<Address>(swap_in_value_addr, swap_in_size * embedding_size * sizeof(float)));
+  kernel_inputs.push_back(std::make_shared<Address>(cache_vocab_size_addr_, sizeof(int)));
+  // The output of updateCache kernel is required but not useful, so any address can be assigned.
+  kernel_outputs.push_back(std::make_shared<Address>
(offset_addr_, sizeof(int))); + auto ret = hash_swap_in_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); + if (!ret) { + MS_LOG(EXCEPTION) << "Hash swap in launch failed."; + } +} +} // namespace ascend +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.h b/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.h new file mode 100644 index 00000000000..4f7c04d853a --- /dev/null +++ b/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.h @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_ASCEND_ASCEND_PS_CACHE_H_ +#define MINDSPORE_CCSRC_PS_PS_CACHE_ASCEND_ASCEND_PS_CACHE_H_ + +#include +#include +#include +#include +#include "ps/ps_cache/ps_cache_basic.h" +#include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h" +#include "ir/dtype.h" + +namespace mindspore { +namespace ps { +namespace ascend { +struct KernelNodeInfo { + KernelNodeInfo(const std::string &op_name, std::vector> input_data_shape, + std::vector input_data_type, std::vector> output_data_shape, + std::vector output_data_type) + : op_name_(op_name) { + input_data_shape_.swap(input_data_shape); + input_data_type_.swap(input_data_type); + output_data_shape_.swap(output_data_shape); + output_data_type_.swap(output_data_type); + } + std::string op_name_; + std::vector> input_data_shape_; + std::vector input_data_type_; + std::vector> output_data_shape_; + std::vector output_data_type_; +}; + +class AscendPsCache : public PsCacheBasic { + public: + AscendPsCache() = default; + ~AscendPsCache() override = default; + void InitDevice(uint32_t device_id, const void *context) override; + void *MallocMemory(size_t size) override; + void MallocConstantMemory(size_t constant_value) override; + void RecordEvent() override; + void SynchronizeEvent() override; + void SynchronizeStream() override; + void CopyHostMemToDevice(void *dst, void *src, size_t size) override; + void CopyDeviceMemToHost(void *dst, void *src, size_t size) override; + void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size, + size_t embedding_size, size_t swap_out_size) override; + void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size, + size_t embedding_size, size_t swap_in_size) override; + + private: + int *offset_addr_{nullptr}; + int *cache_vocab_size_addr_{nullptr}; + std::unique_ptr event_; +}; +} // namespace ascend +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_PS_CACHE_ASCEND_ASCEND_PS_CACHE_H_ diff --git a/mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc b/mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc new file mode 100644 index 00000000000..11008728787 --- /dev/null +++ b/mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache 
License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ps/ps_cache/gpu/gpu_ps_cache.h"
+#include "ps/ps_cache/ps_cache_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/hash_impl.cuh"
+#include "runtime/device/gpu/gpu_common.h"
+#include "runtime/device/gpu/gpu_memory_allocator.h"
+#include "utils/ms_context.h"
+
+namespace mindspore {
+namespace ps {
+namespace gpu {
+MS_REG_PS_CACHE(kGPUDevice, GPUPsCache);
+void GPUPsCache::InitDevice(uint32_t device_id, const void *) {
+  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaSetDevice(device_id), "Cuda set device failed")
+  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamCreate(reinterpret_cast<cudaStream_t *>(&stream_)),
+                                     "Cuda create stream failed");
+}
+
+void *GPUPsCache::MallocMemory(size_t size) {
+  return device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(size);
+}
+
+void GPUPsCache::RecordEvent() {
+  event_.reset(new cudaEvent_t());
+  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventCreate(&(*event_)), "Cuda create event failed");
+  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventRecord(*event_, reinterpret_cast<cudaStream_t>(stream_)),
+                                     "Cuda record event failed");
+}
+
+void GPUPsCache::SynchronizeEvent() {
+  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventSynchronize(*event_), "Cuda sync event failed");
+  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventDestroy(*event_), "Cuda destroy event failed");
+}
+
+void GPUPsCache::SynchronizeStream() {
+  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_)),
+                                     "Cuda sync stream failed");
+}
+
+void GPUPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) {
+  MS_EXCEPTION_IF_NULL(dst);
+  MS_EXCEPTION_IF_NULL(src);
+  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
+    cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_)),
+    "Cuda memcpy failed");
+}
+
+void GPUPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) {
+  MS_EXCEPTION_IF_NULL(dst);
+  MS_EXCEPTION_IF_NULL(src);
+  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
+    cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_)),
+    "Cuda memcpy failed");
+}
+
+void GPUPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t,
+                             size_t embedding_size, size_t swap_out_size) {
+  MS_EXCEPTION_IF_NULL(hash_table_addr);
+  MS_EXCEPTION_IF_NULL(swap_out_value_addr);
+  MS_EXCEPTION_IF_NULL(swap_out_index_addr);
+  DoHashSwapOut(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_out_value_addr),
+                reinterpret_cast<int *>(swap_out_index_addr), swap_out_size, embedding_size,
+                reinterpret_cast<cudaStream_t>(stream_));
+}
+
+void GPUPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t,
+                            size_t embedding_size, size_t swap_in_size) {
+  MS_EXCEPTION_IF_NULL(hash_table_addr);
+  MS_EXCEPTION_IF_NULL(swap_in_value_addr);
+  MS_EXCEPTION_IF_NULL(swap_in_index_addr);
+  DoHashSwapIn(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_in_value_addr),
+               reinterpret_cast<int *>(swap_in_index_addr), swap_in_size, embedding_size,
+               reinterpret_cast<cudaStream_t>(stream_));
+}
+} // namespace gpu +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.h b/mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.h new file mode 100644 index 00000000000..612382c2b3a --- /dev/null +++ b/mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_GPU_GPU_PS_CACHE_H_ +#define MINDSPORE_CCSRC_PS_PS_CACHE_GPU_GPU_PS_CACHE_H_ + +#include +#include +#include "ps/ps_cache/ps_cache_basic.h" + +namespace mindspore { +namespace ps { +namespace gpu { +class GPUPsCache : public PsCacheBasic { + public: + GPUPsCache() = default; + ~GPUPsCache() override = default; + void InitDevice(uint32_t device_id, const void *context) override; + void *MallocMemory(size_t size) override; + void RecordEvent() override; + void SynchronizeEvent() override; + void SynchronizeStream() override; + void CopyHostMemToDevice(void *dst, void *src, size_t size) override; + void CopyDeviceMemToHost(void *dst, void *src, size_t size) override; + void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size, + size_t embedding_size, size_t swap_out_size) override; + void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size, + size_t embedding_size, size_t swap_in_size) override; + + private: + std::unique_ptr event_; +}; +} // namespace gpu +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_PS_CACHE_GPU_GPU_PS_CACHE_H_ diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_basic.h b/mindspore/ccsrc/ps/ps_cache/ps_cache_basic.h new file mode 100644 index 00000000000..c9c7f703a94 --- /dev/null +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_basic.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_BASIC_H +#define MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_BASIC_H + +#include +#include + +namespace mindspore { +namespace ps { +class PsCacheBasic { + public: + PsCacheBasic() = default; + virtual ~PsCacheBasic() = default; + virtual void InitDevice(uint32_t device_id, const void *context) = 0; + virtual void *MallocMemory(size_t size) = 0; + virtual void MallocConstantMemory(size_t constant_value) {} + virtual void RecordEvent() = 0; + virtual void SynchronizeEvent() = 0; + virtual void SynchronizeStream() = 0; + virtual void CopyHostMemToDevice(void *dst, void *src, size_t size) = 0; + virtual void CopyDeviceMemToHost(void *dst, void *src, size_t size) = 0; + virtual void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, + size_t hash_table_size, size_t embedding_size, size_t swap_out_size) = 0; + virtual void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, + size_t hash_table_size, size_t embedding_size, size_t swap_in_size) = 0; + + protected: + void *stream_; +}; +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_BASIC_H diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_factory.cc b/mindspore/ccsrc/ps/ps_cache/ps_cache_factory.cc new file mode 100644 index 00000000000..41a1e05983a --- /dev/null +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_factory.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/ps_cache/ps_cache_factory.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace ps { +PsCacheFactory &PsCacheFactory::Get() { + static PsCacheFactory instance; + return instance; +} + +void PsCacheFactory::Register(const std::string &device_name, PsCacheCreator &&ps_cache_creator) { + if (ps_cache_creators_.end() == ps_cache_creators_.find(device_name)) { + (void)ps_cache_creators_.emplace(device_name, ps_cache_creator); + } +} + +std::shared_ptr PsCacheFactory::ps_cache(const std::string &device_name) { + auto iter = ps_cache_creators_.find(device_name); + if (ps_cache_creators_.end() != iter) { + MS_EXCEPTION_IF_NULL(iter->second); + return (iter->second)(); + } + return nullptr; +} +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_factory.h b/mindspore/ccsrc/ps/ps_cache/ps_cache_factory.h new file mode 100644 index 00000000000..fd06cfbbecc --- /dev/null +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_factory.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_FACTORY_H_
+#define MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_FACTORY_H_
+
+#include <functional>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include "ps/ps_cache/ps_cache_basic.h"
+#include "utils/ms_utils.h"
+
+namespace mindspore {
+namespace ps {
+using PsCacheCreator = std::function<std::shared_ptr<PsCacheBasic>()>;
+class PsCacheFactory {
+ public:
+  static PsCacheFactory &Get();
+  void Register(const std::string &device_name, PsCacheCreator &&ps_cache_creator);
+  std::shared_ptr<PsCacheBasic> ps_cache(const std::string &device_name);
+
+ private:
+  PsCacheFactory() = default;
+  ~PsCacheFactory() = default;
+  DISABLE_COPY_AND_ASSIGN(PsCacheFactory)
+  std::map<std::string, PsCacheCreator> ps_cache_creators_;
+};
+
+class PsCacheRegistrar {
+ public:
+  PsCacheRegistrar(const std::string &device_name, PsCacheCreator &&ps_cache_creator) {
+    PsCacheFactory::Get().Register(device_name, std::move(ps_cache_creator));
+  }
+  ~PsCacheRegistrar() = default;
+};
+
+#define MS_REG_PS_CACHE(DEVICE_NAME, PS_CACHE_CLASS)                          \
+  static const PsCacheRegistrar g_ps_cache_registrar__##DEVICE_NAME##_##_reg( \
+    DEVICE_NAME, []() { return std::make_shared<PS_CACHE_CLASS>(); });
+}  // namespace ps
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_FACTORY_H_
diff --git a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_channel.cc b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_channel.cc
new file mode 100644
index 00000000000..6ce70dcc6f7
--- /dev/null
+++ b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_channel.cc
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ps/ps_cache/ps_data/ps_data_channel.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace ps {
+void PsDataChannel::TryLockChannel() {
+  // The prefetch order of data needs to be consistent with the graph execution order.
+  // Example: if the graph execution order is graph1 --> graph2 --> graph1 --> graph2,
+  // then the data prefetch order needs to be channel1 --> channel2 --> channel1 --> channel2.
+ if ((current_data_step_ != 0) && (current_data_step_ % step_num_ == 0)) { + MS_LOG(INFO) << "Lock channel:" << channel_name_; + std::unique_lock locker(channel_mutex_); + channel_.wait(locker, [this] { return channel_open_; }); + channel_open_ = false; + } + current_data_step_++; +} + +void PsDataChannel::TryWakeChannel() { + if ((current_graph_step_ != 0) && (current_graph_step_ % step_num_ == 0)) { + MS_LOG(INFO) << "Wake up channel:" << channel_name_; + std::lock_guard locker(channel_mutex_); + channel_open_ = true; + channel_.notify_one(); + } + current_graph_step_++; +} + +void PsDataChannel::set_data(void *data, const size_t data_size) { + MS_EXCEPTION_IF_NULL(data); + TryLockChannel(); + data_ = data; + data_size_ = data_size; +} +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_channel.h b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_channel.h new file mode 100644 index 00000000000..83ffb132494 --- /dev/null +++ b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_channel.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_CHANNEL_H_ +#define MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_CHANNEL_H_ + +#include +#include +#include + +namespace mindspore { +namespace ps { +class PsDataChannel { + public: + PsDataChannel(const std::string &channel_name, size_t step_num) + : channel_name_(channel_name), + step_num_(step_num), + current_data_step_(0), + current_graph_step_(0), + channel_open_(false), + data_(nullptr), + data_size_(0) {} + virtual ~PsDataChannel() = default; + void set_data(void *data, const size_t data_size); + void *data() const { return data_; } + size_t data_size() const { return data_size_; } + void ResetData() { data_ = nullptr; } + void set_step_num(size_t step_num) { step_num_ = step_num; } + void TryWakeChannel(); + + private: + void TryLockChannel(); + std::string channel_name_; + // The step num of each epoch. 
+ size_t step_num_; + size_t current_data_step_; + size_t current_graph_step_; + bool channel_open_; + std::mutex channel_mutex_; + std::condition_variable channel_; + void *data_; + size_t data_size_; +}; +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_CHANNEL_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_common.h b/mindspore/ccsrc/runtime/device/gpu/gpu_common.h index 9ae4db2e782..d9cefc42753 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_common.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_common.h @@ -53,6 +53,15 @@ namespace gpu { } \ } +#define CHECK_CUDA_RET_WITH_ERROR_NOTRACE(expression, message) \ + { \ + cudaError_t status = (expression); \ + if (status != cudaSuccess) { \ + MS_LOG(ERROR) << "CUDA Error: " << message << " | Error Number: " << status << " " \ + << cudaGetErrorString(status); \ + } \ + } + #define CHECK_CUDA_RET_WITH_EXCEPT(node, expression, message) \ { \ cudaError_t status = (expression); \ @@ -62,6 +71,15 @@ namespace gpu { } \ } +#define CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(expression, message) \ + { \ + cudaError_t status = (expression); \ + if (status != cudaSuccess) { \ + MS_LOG(EXCEPTION) << "CUDA Error: " << message << " | Error Number: " << status << " " \ + << cudaGetErrorString(status); \ + } \ + } + #define CHECK_CUDNN_RET_WITH_EXCEPT(node, expression, message) \ { \ cudnnStatus_t status = (expression); \ diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index 6519f44f1d5..e6fd916db80 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -139,6 +139,8 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/util.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/scheduler.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info_builder.cc") +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc") +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc")
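---
Note (not part of the patch): the registration/lookup flow added in ps_cache_factory.h is easiest to follow in a small, self-contained example. The sketch below is illustrative only; the function name SwapOneBatch, its buffer arguments, and the literal device name "GPU" (assumed to match the value of kGPUDevice used by MS_REG_PS_CACHE) are made up for demonstration, and error handling is reduced to an early return.

```cpp
// Illustrative sketch -- not part of the patch.
#include <memory>
#include "ps/ps_cache/ps_cache_basic.h"
#include "ps/ps_cache/ps_cache_factory.h"

namespace mindspore {
namespace ps {
// Hypothetical caller that swaps one batch of embedding rows out of the device cache.
void SwapOneBatch(void *hash_table, void *values, void *indices, size_t vocab_cache_size, size_t embedding_size,
                  size_t swap_size) {
  // Look up the cache implementation registered for the current device target.
  std::shared_ptr<PsCacheBasic> cache = PsCacheFactory::Get().ps_cache("GPU");
  if (cache == nullptr) {
    return;  // No ps cache registered for this device target.
  }
  cache->InitDevice(0, nullptr);  // The GPU implementation ignores the context argument.
  // hash_table/values/indices are assumed to be device buffers, e.g. obtained via MallocMemory.
  cache->HashSwapOut(hash_table, values, indices, vocab_cache_size, embedding_size, swap_size);
  cache->SynchronizeStream();
}
}  // namespace ps
}  // namespace mindspore
```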
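Note (not part of the patch): a minimal sketch of the handshake PsDataChannel implements between a data-prefetch thread (set_data, which internally calls TryLockChannel and blocks at graph boundaries) and a graph-execution thread (TryWakeChannel). The thread bodies, step counts, and dummy buffer are assumptions for illustration; in MindSpore these calls come from the dataset prefetch and graph-execution code rather than from ad-hoc threads.

```cpp
// Illustrative sketch -- not part of the patch.
#include <thread>
#include <vector>
#include "ps/ps_cache/ps_data/ps_data_channel.h"

int main() {
  constexpr size_t steps_per_graph = 2;
  mindspore::ps::PsDataChannel channel("channel1", steps_per_graph);
  std::vector<float> batch(16, 0.0f);

  // Prefetch side: publishes one batch per step; every steps_per_graph steps it
  // blocks inside set_data until the execution side wakes the channel.
  std::thread prefetch([&] {
    for (int step = 0; step < 4; ++step) {
      channel.set_data(batch.data(), batch.size() * sizeof(float));
    }
  });

  // Execution side: signals the channel once per executed step so the prefetch
  // thread can move on to the next group of steps.
  std::thread execute([&] {
    for (int step = 0; step < 4; ++step) {
      channel.TryWakeChannel();
    }
  });

  prefetch.join();
  execute.join();
  return 0;
}
```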