From c6262111efec5c39dd38f0c45924f9435e4444f2 Mon Sep 17 00:00:00 2001
From: cristoval
Date: Tue, 21 Jul 2020 19:23:56 +0800
Subject: [PATCH] add sparse lazy adam in ps

---
 .../backend/kernel_compiler/CMakeLists.txt     |   1 +
 .../cpu/ps/sparse_apply_adam_ps_kernel.h       |   2 +-
 .../ps/sparse_apply_lazy_adam_ps_kernel.cc     | 104 ++++++++++++++++++
 .../cpu/ps/sparse_apply_lazy_adam_ps_kernel.h  |  49 +++++++++
 .../cpu/sparse_apply_lazy_adam_cpu_kernel.h    |   2 +-
 mindspore/ccsrc/frontend/parallel/ps/common.h  |   1 +
 .../frontend/parallel/ps/optimizer_info.cc     |  10 ++
 .../frontend/parallel/ps/optimizer_info.h      |   2 +
 .../frontend/parallel/ps/parameter_server.h    |  26 +++--
 mindspore/ccsrc/frontend/parallel/ps/worker.h  |   6 +
 .../ccsrc/frontend/parallel/ps/worker_proxy.h  |  12 ++
 .../st/host_device/test_host_device_lenet.py   |   2 +-
 12 files changed, 203 insertions(+), 14 deletions(-)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h

diff --git a/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt
index aa1b81d8e42..4de32578a50 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt
+++ b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt
@@ -42,6 +42,7 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
     list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/push_kernel.cc")
     list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/sparse_apply_adam_ps_kernel.cc")
     list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/sparse_apply_ftrl_ps_kernel.cc")
+    list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc")
 endif()
 
 if (ENABLE_GPU)
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h
index e0a623ce839..0b35218abc3 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h
@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_ADAM_PS_KERNEL_H_
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_ADAM_PS_KERNEL_H_
 
 #include <vector>
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc
new file mode 100644
index 00000000000..50b7c0a1570
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc
@@ -0,0 +1,104 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h"
+#include <memory>
+#include "backend/kernel_compiler/common_utils.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+#include "frontend/parallel/ps/util.h"
+
+namespace mindspore {
+namespace kernel {
+namespace ps {
+void SparseApplyLazyAdamPSKernel::InitKernel(
+  const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
+  const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes;
+  std::vector<size_t> &var_shape = *(shape_vec[0]);
+  std::vector<size_t> &m_shape = *(shape_vec[1]);
+  std::vector<size_t> &v_shape = *(shape_vec[2]);
+  const std::vector<size_t> &grad_shape = *(shape_vec[9]);
+  const std::vector<size_t> &indices_shape = *(shape_vec[10]);
+
+  Shard(&var_shape, 0);
+  Shard(&m_shape, 0);
+  Shard(&v_shape, 0);
+
+  if (!IsSameShape(var_shape, m_shape)) {
+    MS_LOG(EXCEPTION) << "var and m should have the same shape";
+  }
+  if (!IsSameShape(var_shape, v_shape)) {
+    MS_LOG(EXCEPTION) << "var and v should have the same shape";
+  }
+  var_first_dim_size_ = var_shape[0];
+  for (size_t i = 1; i < var_shape.size(); ++i) {
+    if (var_shape[i] != grad_shape[i]) {
+      MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i;
+    }
+    var_outer_dim_size_ *= var_shape[i];
+  }
+  if (indices_shape.size() != 1) {
+    MS_LOG(EXCEPTION) << "indices must be 1D";
+  }
+  indices_size_ = indices_shape[0];
+  if (grad_shape[0] != indices_size_) {
+    MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices";
+  }
+  /*
+  if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) {
+    use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
+  }
+  */
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+  workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float));
+}
+
+void SparseApplyLazyAdamPSKernel::ReInit(
+  const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
+  const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes;
+  const std::vector<size_t> &indices_shape = *(shape_vec[0]);
+  indices_size_ = indices_shape[0];
+  workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float);
+  workspace_size_list_[1] = indices_size_ * sizeof(int);
+}
+
+void SparseApplyLazyAdamPSKernel::ReInit(const std::vector<AddressPtr> &inputs) {
+  const auto &indices_addr = inputs[10];
+  indices_size_ = indices_addr->size / sizeof(int);
+  workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float);
+  workspace_size_list_[1] = indices_size_ * sizeof(int);
+}
+
+bool SparseApplyLazyAdamPSKernel::Execute(const std::vector<AddressPtr> &inputs,
+                                          const std::vector<AddressPtr> &workspace,
+                                          const std::vector<AddressPtr> &outputs) {
+  ReInit(inputs);
+  int *indices = reinterpret_cast<int *>(inputs[10]->addr);
+  for (size_t i = 0; i < inputs[10]->size / sizeof(int); i++) {
+    indices[i] -= rank_id_ * var_first_dim_size_;
+  }
+  return Launch(inputs, workspace, outputs);
+}
+
+const std::vector<size_t> &SparseApplyLazyAdamPSKernel::input_sizes() const { return GetInputSizeList(); }
+
+const std::vector<size_t> &SparseApplyLazyAdamPSKernel::output_sizes() const { return GetOutputSizeList(); }
+
+const std::vector<size_t> &SparseApplyLazyAdamPSKernel::workspace_sizes() const { return GetWorkspaceSizeList(); }
+}  // namespace ps
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h
new file mode 100644
index 00000000000..9a8e3d0475a
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_LAZY_ADAM_PS_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_LAZY_ADAM_PS_KERNEL_H_
+
+#include <vector>
+#include <memory>
+#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h"
+#include "backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+namespace ps {
+using mindspore::kernel::SparseApplyLazyAdamCPUKernel;
+class SparseApplyLazyAdamPSKernel : public SparseApplyLazyAdamCPUKernel, public PServerKernel {
+ public:
+  SparseApplyLazyAdamPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {}
+  ~SparseApplyLazyAdamPSKernel() override = default;
+
+  void InitKernel(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
+  void ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
+  bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+               const std::vector<AddressPtr> &outputs) override;
+
+  const std::vector<size_t> &input_sizes() const override;
+  const std::vector<size_t> &output_sizes() const override;
+  const std::vector<size_t> &workspace_sizes() const override;
+
+ protected:
+  void ReInit(const std::vector<AddressPtr> &) override;
+};
+}  // namespace ps
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_LAZY_ADAM_PS_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h
index 0b7a3222687..2235c22ea5a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h
@@ -33,7 +33,7 @@ class SparseApplyLazyAdamCPUKernel : public CPUKernel {
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs) override;
 
- private:
+ protected:
   size_t indices_size_{0};
   size_t var_first_dim_size_{0};
   size_t var_outer_dim_size_{1};
diff --git a/mindspore/ccsrc/frontend/parallel/ps/common.h b/mindspore/ccsrc/frontend/parallel/ps/common.h
index 3921df2bd9c..dc79441d205 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/common.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/common.h
@@ -63,6 +63,7 @@ constexpr int kInitWeightToOptimIdCmd = 11;
 constexpr int kInitOptimInputsShapeCmd = 12;
 constexpr int kInitEmbeddingsCmd = 20;
 constexpr int kEmbeddingLookupCmd = 30;
+constexpr int kFinalizeCmd = 40;
 
 constexpr size_t kInvalidKey = UINT64_MAX;
 
diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc
index 6ec68b84c9e..5f25b79c233 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc
@@ -57,6 +57,16 @@ void DenseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
   }
 }
 
+void DenseOptimInfo::ComputeMean(size_t n) {
+  if (n > 1) {
+    float *accum_grad_data = reinterpret_cast<float *>(gradient()->addr);
+    size_t size = gradient()->size / sizeof(float);
+    for (size_t i = 0; i < size; i++) {
+      accum_grad_data[i] /= n;
+    }
+  }
+}
+
 void DenseOptimInfo::Reset() { memset_s(gradient()->addr, gradient()->size, 0x00, gradient()->size); }
 
 void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h
index 36d2b58a0e3..dc567e023cc 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h
@@ -33,6 +33,7 @@ class OptimizerInfo {
   virtual void Update(const Values &values, const Lengths &lengths) {}
   virtual void UpdateWeight(const WeightPtr &weight);
   virtual void Accumulate(const Values &values, const Lengths &lengths) = 0;
+  virtual void ComputeMean(size_t n) {}
   virtual void Reset() {}
   void AddWorkspace(const AddressPtr &workspace);
 
@@ -58,6 +59,7 @@ class DenseOptimInfo : public OptimizerInfo {
   ~DenseOptimInfo() override = default;
 
   void Accumulate(const Values &values, const Lengths &lens) override;
+  void ComputeMean(size_t n) override;
   void Reset() override;
 };
 
diff --git a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
index 7105b887db0..2a32904ab05 100755
--- a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
@@ -41,7 +41,7 @@
 #include "backend/kernel_compiler/kernel.h"
 #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
 #include "backend/kernel_compiler/cpu/ps/pserver_kernel.h"
-#include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h"
+#include "backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h"
 #include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h"
 #include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h"
 #include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h"
@@ -90,6 +90,7 @@ class ParameterServer {
   void HandleInitInputsShape(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
   void HandleInitEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
   void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
+  void HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
   ParameterServer *ps_;
   typedef void (ServerHandler::*RequestHandler)(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
                                                 ::ps::KVPairs<T> *res);
@@ -165,6 +166,7 @@ void ParameterServer<T>::ServerHandler::Init() {
   handlers_[kInitOptimInputsShapeCmd] = &ServerHandler::HandleInitInputsShape;
   handlers_[kInitEmbeddingsCmd] = &ServerHandler::HandleInitEmbeddings;
   handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup;
+  handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize;
 }
 
 template <typename T>
@@ -256,16 +258,16 @@ void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta
   ps_->DoEmbeddingLookup(key, req_data.keys.segment(1, req_data.keys.size()), res);
 }
 
+template <typename T>
+void ParameterServer<T>::ServerHandler::HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
+                                                       ::ps::KVPairs<T> *res) {
+  ::ps::Finalize(0, false);
+}
+
 template <typename T>
 bool ParameterServer<T>::Init(const FuncGraphPtr &func_graph) {
-  const char *server_num = getenv(kEnvPServerNum);
-  const char *worker_num = getenv(kEnvWorkerNum);
-  if (server_num != nullptr) {
-    pserver_num_ = *server_num - '0';
-  }
-  if (worker_num != nullptr) {
-    worker_num_ = *worker_num - '0';
-  }
+  pserver_num_ = ::ps::NumServers();
+  worker_num_ = ::ps::NumWorkers();
   func_graph_ = func_graph;
   rank_id_ = ::ps::MyRank();
   handler_.reset(new ServerHandler(this));
@@ -319,7 +321,7 @@ void ParameterServer<T>::InitOptimInputsShape(const Keys &keys, const Values &va
   if (optimizers_.count(key) == 0 && optim_inputs_shape_.count(key) > 0) {
     if (optim_name == kSparseAdam) {
       std::shared_ptr<PServerKernel> optimizer =
-        std::make_shared<SparseApplyAdamPSKernel>(rank_id_, pserver_num_);
+        std::make_shared<SparseApplyLazyAdamPSKernel>(rank_id_, pserver_num_);
       optimizer->InitKernel(optim_inputs_shape_[key]);
       optimizers_[key] = optimizer;
     } else if (optim_name == kApplyMomentum) {
@@ -368,10 +370,11 @@ void ParameterServer<T>::InitEmbeddingTable(
   }
 
   WeightPtr embedding = std::make_shared<Weight>(total_dims, 0);
+  T *embedding_data = embedding->data();
   std::default_random_engine engine;
   std::normal_distribution<float> random(0, 0.01);
   for (size_t i = 0; i < total_dims; i++) {
-    (*embedding)[i] = random(engine);
+    embedding_data[i] = random(engine);
   }
 
   weights_[key] = embedding;
@@ -402,6 +405,7 @@ void ParameterServer<T>::UpdateWeights() {
 
     const std::vector<kernel::AddressPtr> &workspaces = optim_info->workspaces();
     const std::vector<kernel::AddressPtr> &outputs = optim_info->outputs();
+    optim_info->ComputeMean(worker_num_);
     optimizer->Execute(inputs, workspaces, outputs);
     optim_info->Reset();
   }
diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker.h b/mindspore/ccsrc/frontend/parallel/ps/worker.h
index 411179ad900..abdc046ffe0 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/worker.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/worker.h
@@ -50,6 +50,7 @@ class Worker {
   void InitPSParamAndOptim(const std::string &param_name, void *param_data, size_t param_size);
   void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
                            const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int cmd);
+  void Finalize();
 
  private:
   Worker() : kv_worker_(nullptr), running_(false), key_cnt_(0) {}
@@ -118,6 +119,11 @@ void Worker<T>::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const :
   kv_worker_->EmbeddingLookup(keys, lookup_ids, lens, lookup_result, cmd);
 }
 
+template <typename T>
+void Worker<T>::Finalize() {
+  kv_worker_->Finalize();
+}
+
 template <typename T>
 void Worker<T>::InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size) {
   ::ps::SArray<T> addr(reinterpret_cast<T *>(origin_addr), size / sizeof(T));
diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
index f437f721831..c8bd27c067a 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
@@ -58,6 +58,7 @@ class WorkerProxy : public ::ps::KVWorker<T> {
                        const ::ps::SArray<int> &lens = {}, const Callback &cb = nullptr, int priority = 0);
   void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens = {},
                 int cmd = 0, int priority = 0);
+  void Finalize();
 
  private:
   template <typename C>
@@ -146,6 +147,17 @@ void WorkerProxy<T>::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::S
   obj_->WaitRequest(ts);
 }
 
+template <typename T>
+void WorkerProxy<T>::Finalize() {
+  int ts = obj_->NewRequest(::ps::kServerGroup);
+  ::ps::KVPairs<T> kvs;
+  kvs.keys.push_back(0);
+  kvs.vals.push_back(0.0f);
+  Send(obj_, ts, true, false, kFinalizeCmd, kvs, broadcast_slicer_);
+  obj_->WaitRequest(ts);
+  ::ps::Finalize(0, false);
+}
+
 template <typename T>
 template <typename C>
 int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
diff --git a/tests/st/host_device/test_host_device_lenet.py b/tests/st/host_device/test_host_device_lenet.py
index d1c49dc1e48..0a312a34221 100644
--- a/tests/st/host_device/test_host_device_lenet.py
+++ b/tests/st/host_device/test_host_device_lenet.py
@@ -75,7 +75,7 @@ def train(net, data, label):
     print(res)
     print("+++++++++++++++++++++++++++")
     diff = res.asnumpy()[0] - 2.3025851
-    assert np.all(diff < 1.e-7)
+    assert np.all(diff < 1.e-6)
 
 
 @pytest.mark.level0
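
Note (not part of the patch): as a reading aid, below is a minimal, self-contained C++ sketch of the row-sparse "lazy" Adam step that the new SparseApplyLazyAdamPSKernel ultimately runs on its shard, after Execute() above has rebased the incoming indices by rank_id_ * var_first_dim_size_. All names in the sketch (LazyAdamUpdate and its parameters) are illustrative and are not MindSpore APIs, and the learning rate is assumed to already carry the Adam bias correction.

#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative only: lazy Adam over a [num_rows x dim] shard stored row-major.
// Rows that do not appear in `indices` keep their moments untouched, which is
// what distinguishes lazy Adam from applying dense Adam to a sparse gradient.
void LazyAdamUpdate(std::vector<float> *var, std::vector<float> *m, std::vector<float> *v,
                    const std::vector<float> &grad, const std::vector<int> &indices, size_t dim,
                    float lr, float beta1, float beta2, float epsilon) {
  for (size_t i = 0; i < indices.size(); ++i) {
    const size_t row = static_cast<size_t>(indices[i]) * dim;  // offset of the touched parameter row
    const size_t g_row = i * dim;                              // matching row of the sparse gradient
    for (size_t j = 0; j < dim; ++j) {
      const float g = grad[g_row + j];
      (*m)[row + j] = beta1 * (*m)[row + j] + (1 - beta1) * g;
      (*v)[row + j] = beta2 * (*v)[row + j] + (1 - beta2) * g * g;
      (*var)[row + j] -= lr * (*m)[row + j] / (std::sqrt((*v)[row + j]) + epsilon);
    }
  }
}

For dense layers, the new DenseOptimInfo::ComputeMean() in the patch divides the accumulated gradient by worker_num_ before the optimizer kernel runs, so the server applies the average of the workers' gradients rather than their sum.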