diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc
index a9c423973dd..25af4fd4f0f 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc
@@ -71,8 +71,8 @@ void SparseApplyAdamPSKernel::ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
   const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes;
   const std::vector<size_t> &indices_shape = *(shape_vec[0]);
   indices_size_ = indices_shape[0];
-  workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float);
-  workspace_size_list_[1] = indices_size_ * sizeof(int);
+  workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_;
+  workspace_size_list_[1] = indices_size_ * sizeof(int) * worker_num_;
 }

 void SparseApplyAdamPSKernel::ReInit(const std::vector<AddressPtr> &inputs) {
@@ -85,10 +85,6 @@ void SparseApplyAdamPSKernel::ReInit(const std::vector<AddressPtr> &inputs) {
 bool SparseApplyAdamPSKernel::Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                       const std::vector<AddressPtr> &outputs) {
   ReInit(inputs);
-  int *indices = reinterpret_cast<int *>(inputs[10]->addr);
-  for (size_t i = 0; i < inputs[10]->size / sizeof(int); i++) {
-    indices[i] -= row_offset_;
-  }
   return Launch(inputs, workspace, outputs);
 }

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc
index 0e0ca08518c..cebb157c3c4 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc
@@ -74,15 +74,15 @@ void SparseApplyFtrlPSKernel::ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
   const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes;
   std::vector<size_t> indices_shape = *(shape_vec[0]);
   indices_size_ = indices_shape[0];
-  workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float);
-  workspace_size_list_[1] = indices_size_ * sizeof(int);
+  workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_;
+  workspace_size_list_[1] = indices_size_ * sizeof(int) * worker_num_;
 }

 void SparseApplyFtrlPSKernel::ReInit(const std::vector<AddressPtr> &inputs) {
   const auto &indices_addr = inputs[4];
   indices_size_ = indices_addr->size / sizeof(int);
-  workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float);
-  workspace_size_list_[1] = indices_size_ * sizeof(int);
+  workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_;
+  workspace_size_list_[1] = indices_size_ * sizeof(int) * worker_num_;
 }

 bool SparseApplyFtrlPSKernel::Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc
index a148f0f3d8f..da9f7463da0 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc
@@ -71,15 +71,15 @@ void SparseApplyLazyAdamPSKernel::ReInit(
   const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes;
   const std::vector<size_t> &indices_shape = *(shape_vec[0]);
   indices_size_ = indices_shape[0];
-  workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float);
-  workspace_size_list_[1] = indices_size_ * sizeof(int);
+  workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_;
+  workspace_size_list_[1] = indices_size_ * sizeof(int) * worker_num_;
 }

 void SparseApplyLazyAdamPSKernel::ReInit(const std::vector<AddressPtr> &inputs) {
   const auto &indices_addr = inputs[10];
   indices_size_ = indices_addr->size / sizeof(int);
-  workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float);
-  workspace_size_list_[1] = indices_size_ * sizeof(int);
+  workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_;
+  workspace_size_list_[1] = indices_size_ * sizeof(int) * worker_num_;
 }

 bool SparseApplyLazyAdamPSKernel::Execute(const std::vector<AddressPtr> &inputs,
diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc
index dd8fedee6d6..f1fa5c029dd 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc
@@ -16,6 +16,7 @@
 #include "frontend/parallel/ps/optimizer_info.h"
 #include <memory>
+#include "frontend/parallel/ps/util.h"

 namespace mindspore {
 namespace parallel {
@@ -30,6 +31,8 @@ const std::vector<AddressPtr> &OptimizerInfo::outputs() { return outputs_; }

 bool OptimizerInfo::IsSparse() const { return false; }

+const size_t OptimizerInfo::indice_size() const { return 0; }
+
 size_t OptimizerInfo::grad_index() { return 0; }

 size_t OptimizerInfo::indices_index() { return 0; }
@@ -57,7 +60,8 @@ void DenseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
   }
 }

-void DenseOptimInfo::ComputeMean(size_t n) {
+void DenseOptimInfo::ComputeMean(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &, size_t n,
+                                 size_t server_num, size_t rank_id) {
   if (n > 1) {
     float *accum_grad_data = reinterpret_cast<float *>(gradient()->addr);
     size_t size = gradient()->size / sizeof(float);
@@ -96,15 +100,88 @@ void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
   for (size_t i = 0; i < indices_index; i++) {
     indice_offset += lengths[i];
   }
-  int *incr_indice_data = reinterpret_cast<int *>(values.data() + indice_offset);
-  size_t incr_indice_size = lengths[indices_index] * sizeof(float);
+  float *incr_indice_data = values.data() + indice_offset;
+  size_t incr_indice_size = lengths[indices_index];
+  size_t incr_indice_data_size = incr_indice_size * sizeof(int);
+  int *converted_indices = new int[incr_indice_size];
+  for (size_t i = 0; i < incr_indice_size; i++) {
+    converted_indices[i] = static_cast<int>(incr_indice_data[i]);
+  }

-  auto ret2 = memcpy_s(accum_indices_data + indices_offset_, incr_indice_size, incr_indice_data, incr_indice_size);
+  auto ret2 =
+    memcpy_s(accum_indices_data + indices_offset_, incr_indice_data_size, converted_indices, incr_indice_data_size);
   if (ret2 != 0) {
     MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")";
   }
+  delete[] converted_indices;
   indices_offset_ += lengths[indices_index];
-  indices()->size += incr_indice_size;
+  indices()->size += incr_indice_data_size;
+}
+
+void SparseOptimInfo::ComputeMean(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
+ size_t n, size_t server_num, size_t rank_id) { + size_t indices_size = static_cast(indices()->size / sizeof(int)); + int segment_size = gradient()->size / indices()->size; + + float *new_grad = new float[indices_size * segment_size]; + int *new_indices = new int[indices_size]; + mindspore::kernel::SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size}); + + const std::vector>> &shape_vec = *shapes; + if (shape_vec.size() < 2 || shape_vec[1] == nullptr) { + MS_LOG(EXCEPTION) << "No input shape found"; + } + auto input_shapes = shape_vec.size() > 0 ? shape_vec[1] : nullptr; + MS_EXCEPTION_IF_NULL(input_shapes); + if (input_shapes->size() == 0) { + MS_LOG(EXCEPTION) << "Invalid input shapes"; + } + int first_dim_size = input_shapes->front(); + int outer_dim_size = segment_size; + + if (first_dim_size == 0 || outer_dim_size == 0) { + MS_LOG(ERROR) << "Invalid first dim size"; + } + + float *grad_data = reinterpret_cast(gradient()->addr); + int *indices_data = reinterpret_cast(indices()->addr); + + size_t original_row_count = input_shapes->front(); + if (original_row_count > 0) { + size_t offset = 0; + if ((original_row_count % server_num) == 0) { + offset = original_row_count / server_num * rank_id; + } else { + offset = std::round((static_cast(original_row_count)) / server_num) * rank_id; + } + for (size_t i = 0; i < indices_size; i++) { + indices_data[i] -= offset; + } + } + + Util::ReduceSparseGradient(grad_data, indices_data, indices_size, segment_size, first_dim_size, outer_dim_size, + &unique_sparse_grad); + + int reduced_grad_size = unique_sparse_grad.indices_size_ * segment_size * sizeof(float); + auto ret = memcpy_s(gradient()->addr, reduced_grad_size, unique_sparse_grad.value_, reduced_grad_size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } + int reduced_indice_size = unique_sparse_grad.indices_size_ * sizeof(int); + ret = memcpy_s(indices()->addr, reduced_indice_size, unique_sparse_grad.indices_, reduced_indice_size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } + + gradient()->size = reduced_grad_size; + indices()->size = reduced_indice_size; + + for (size_t i = 0; i < unique_sparse_grad.indices_size_ * segment_size; i++) { + grad_data[i] = grad_data[i] / n; + } + + delete[] new_grad; + delete[] new_indices; } void SparseOptimInfo::Reset() { @@ -135,6 +212,8 @@ void MomentumOptimInfo::Update(const Values &values, const Lengths &lens) { } } +const size_t SparseOptimInfo::indice_size() const { return indices_offset_; } + const AddressPtr &MomentumOptimInfo::gradient() { return inputs_[3]; } const AddressPtr &MomentumOptimInfo::indices() { return inputs_[3]; } diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h index f59d8ad6c1d..0edb2d21566 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h @@ -18,6 +18,7 @@ #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_OPTIMIZER_INFO_H_ #include +#include #include "backend/kernel_compiler/kernel.h" #include "frontend/parallel/ps/common.h" @@ -33,12 +34,14 @@ class OptimizerInfo { virtual void Update(const Values &values, const Lengths &lengths) {} virtual void UpdateWeight(const WeightPtr &weight); virtual void Accumulate(const Values &values, const Lengths &lengths) = 0; - virtual void ComputeMean(size_t n) {} + virtual void ComputeMean(const std::shared_ptr>>> &shapes, size_t n, + size_t server_num, 
size_t rank_id) {} virtual void Reset() {} void AddWorkspace(const AddressPtr &workspace); virtual const AddressPtr &gradient() = 0; virtual const AddressPtr &indices() = 0; + virtual const size_t indice_size() const; const std::vector &inputs(); const std::vector &workspaces(); const std::vector &outputs(); @@ -59,7 +62,8 @@ class DenseOptimInfo : public OptimizerInfo { ~DenseOptimInfo() override = default; void Accumulate(const Values &values, const Lengths &lens) override; - void ComputeMean(size_t n) override; + void ComputeMean(const std::shared_ptr>>> &shapes, size_t n, + size_t server_num, size_t rank_id) override; void Reset() override; }; @@ -69,7 +73,10 @@ class SparseOptimInfo : public OptimizerInfo { ~SparseOptimInfo() override = default; void Accumulate(const Values &values, const Lengths &lens) override; + void ComputeMean(const std::shared_ptr>>> &shapes, size_t n, + size_t server_num, size_t rank_id) override; void Reset() override; + const size_t indice_size() const override; protected: size_t grads_offset_{0}; diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc index 03b11c6f415..ace0dc0fb36 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc @@ -136,15 +136,21 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const std::shared_ptr> &indices_shape = (*inputs_shape)[10]; size_t total_indice_size = - std::accumulate((*indices_shape).begin(), (*indices_shape).end(), sizeof(float), std::multiplies()); + std::accumulate((*indices_shape).begin(), (*indices_shape).end(), sizeof(int), std::multiplies()); AddressPtr indices = std::make_shared(); - indices->addr = new float[total_indice_size * worker_num]; - ret = memcpy_s(indices->addr, lens[7] * sizeof(float), reinterpret_cast(epsilon->addr) + lens[5] + lens[6], - lens[7] * sizeof(float)); + indices->addr = new int[total_indice_size * worker_num]; + int *converted_indices = new int[lens[7]]; + size_t indices_data_size = lens[7] * sizeof(int); + float *indices_data = reinterpret_cast(epsilon->addr) + lens[5] + lens[6]; + for (int i = 0; i < lens[7]; i++) { + converted_indices[i] = static_cast(indices_data[i]); + } + ret = memcpy_s(indices->addr, indices_data_size, converted_indices, indices_data_size); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; } - indices->size = lens[7] * sizeof(int); + indices->size = indices_data_size; + delete[] converted_indices; return new SparseAdamOptimInfo(weight_addr, m, v, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, grad, indices); @@ -185,13 +191,19 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, size_t total_indice_size = std::accumulate((*indices_shape).begin(), (*indices_shape).end(), 1, std::multiplies()); AddressPtr indices = std::make_shared(); - indices->addr = new float[total_indice_size * worker_num]; - ret = memcpy_s(indices->addr, lens[1] * sizeof(float), reinterpret_cast(values.data()) + lens[0], - lens[1] * sizeof(float)); + indices->addr = new int[total_indice_size * worker_num]; + int *converted_indices = new int[lens[1]]; + size_t indices_data_size = lens[1] * sizeof(int); + float *indices_data = reinterpret_cast(values.data()) + lens[0]; + for (int i = 0; i < lens[1]; i++) { + converted_indices[i] = static_cast(indices_data[i]); + } + ret = memcpy_s(indices->addr, indices_data_size, 
converted_indices, indices_data_size); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; } - indices->size = lens[1] * sizeof(int); + indices->size = indices_data_size; + delete[] converted_indices; return new SparseFtrlOptimInfo(weight_addr, accum, linear, grad, indices); } diff --git a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h index 0de16b13f2e..2032e177f57 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h +++ b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h @@ -145,6 +145,7 @@ class ParameterServer { std::unordered_map> optimizers_; std::unordered_map optim_inputs_shape_; + std::unordered_map original_optim_inputs_shape_; std::unordered_map> optim_infos_; std::unordered_map> optim_info_builders_; std::unordered_map weight_key_to_optims_; @@ -366,19 +367,24 @@ void ParameterServer::InitWeightKeyToOptims(const Key &key, const int &optim_ template void ParameterServer::InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths) { InputsShapePtr inputs_shape = std::make_shared(); + InputsShapePtr original_inputs_shape = std::make_shared(); int val_idx = 0; const Key &key = keys[0]; MS_LOG(INFO) << "Initializing optimizer inputs shape for key:" << key; if (optim_inputs_shape_.count(key) == 0) { + original_optim_inputs_shape_[key] = original_inputs_shape; optim_inputs_shape_[key] = inputs_shape; } for (size_t i = 0; i < keys.size(); i++) { auto shape = std::make_shared>(); + auto original_shape = std::make_shared>(); inputs_shape->push_back(shape); + original_inputs_shape->push_back(original_shape); int len = lengths[i]; for (int j = 0; j < len; j++) { - shape->push_back(values[val_idx++]); + shape->push_back(values[val_idx]); + original_shape->push_back(values[val_idx++]); } } if (weight_key_to_optims_.count(key) > 0) { @@ -512,7 +518,19 @@ void ParameterServer::UpdateWeights() { const std::vector &workspaces = optim_info->workspaces(); const std::vector &outputs = optim_info->outputs(); - optim_info->ComputeMean(worker_num_); + std::shared_ptr>>> shapes = + std::make_shared>>>(); + std::shared_ptr> indices_shape = std::make_shared>(); + indices_shape->emplace_back(optim_info->indice_size()); + shapes->push_back(indices_shape); + + if (original_optim_inputs_shape_.count(key) != 0) { + for (auto &input_shapes : *(original_optim_inputs_shape_[key])) { + shapes->push_back(input_shapes); + } + } + optimizer->ReInit(shapes); + optim_info->ComputeMean(shapes, worker_num_, pserver_num_, rank_id_); optimizer->Execute(inputs, workspaces, outputs); optim_info->Reset(); if (!is_embedding_[key]) { diff --git a/mindspore/ccsrc/frontend/parallel/ps/util.cc b/mindspore/ccsrc/frontend/parallel/ps/util.cc index 2d96c3f1274..c7acb4a577f 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/util.cc +++ b/mindspore/ccsrc/frontend/parallel/ps/util.cc @@ -146,6 +146,32 @@ int Util::LocalShard(int first_dim, int rank_id, int server_num) { void Util::SetRankId(int rank_id) { rank_id_ = rank_id; } int Util::GetRankId() { return rank_id_; } + +void Util::ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size, + const size_t first_dim_size, const size_t outer_dim_size, + mindspore::kernel::SparseGradient *unique_sparse_grad) { + size_t slice_segment_size = indices_size * segment_size; + auto workspace_grad = new float[slice_segment_size]; + auto workspace_indices = new int[indices_size]; + + MS_EXCEPTION_IF_NULL(gradients); + 
MS_EXCEPTION_IF_NULL(indices); + MS_EXCEPTION_IF_NULL(workspace_grad); + MS_EXCEPTION_IF_NULL(workspace_indices); + + mindspore::kernel::SparseGradient workspace_sparse_grad({workspace_grad, workspace_indices, indices_size}); + mindspore::kernel::SparseGradient input_sparse_grad({gradients, indices, indices_size}); + mindspore::kernel::ReduceSparseGradientParam param; + param.input_grad_ = &input_sparse_grad; + param.workspace_grad_ = &workspace_sparse_grad; + param.output_grad_ = unique_sparse_grad; + param.max_index_ = first_dim_size; + param.value_stride_ = outer_dim_size; + + BucketReduceSparseGradient(param); + delete[] workspace_grad; + delete[] workspace_indices; +} } // namespace ps } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ps/util.h b/mindspore/ccsrc/frontend/parallel/ps/util.h index 242012912f8..5177e63b2b9 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/util.h +++ b/mindspore/ccsrc/frontend/parallel/ps/util.h @@ -21,6 +21,7 @@ #include #include #include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/common_utils.h" namespace mindspore { namespace parallel { @@ -39,6 +40,9 @@ class Util { static int LocalShard(int first_dim, int rank_id, int server_num); static void SetRankId(int rank_id); static int GetRankId(); + static void ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size, + const size_t first_dim_size, const size_t outer_dim_size, + mindspore::kernel::SparseGradient *unique_sparse_grad); private: static std::unordered_map optimizer_to_ids; diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker.h b/mindspore/ccsrc/frontend/parallel/ps/worker.h index 7c969387737..dc8bccf9f72 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/worker.h +++ b/mindspore/ccsrc/frontend/parallel/ps/worker.h @@ -95,6 +95,32 @@ void Worker::Run() { template void Worker::Push(const std::vector &keys, std::vector addrs, const std::vector &sizes) { + if (keys.size() == 0) { + MS_LOG(EXCEPTION) << "key size should be greater than zero"; + } + if (key_to_optimId_.count(keys[0]) == 0) { + MS_LOG(EXCEPTION) << "no optim id found for key" << keys[0]; + } + Key key = keys[0]; + int optim_id = key_to_optimId_[key]; + bool is_sparse = false; + if (optim_id == 1 || optim_id == 2 || optim_id == 3) { + is_sparse = true; + } + int grad_index = -1; + int indice_index = -1; + + // Sparse adam gradient + if (optim_id == 1 || optim_id == 2) { + grad_index = 6; + indice_index = 7; + + // Sparse ftrl gradient + } else if (optim_id == 3) { + grad_index = 0; + indice_index = 1; + } + size_t total_size = 0; for (auto size : sizes) { total_size += size; @@ -109,10 +135,22 @@ void Worker::Push(const std::vector &keys, std::vector add } offset += sizes[i] * sizeof(T); } + while (!kv_worker_->IsReadyForPush(keys[0])) { continue; } - kv_worker_->PushData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray(sizes)); + if (!is_sparse) { + kv_worker_->PushData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray(sizes)); + } else { + std::vector &var_shape = key_to_optim_shapes_[key][0]; + int first_dim_size = var_shape[0]; + int outer_dim_size = 1; + for (size_t i = 1; i < var_shape.size(); ++i) { + outer_dim_size *= var_shape[i]; + } + kv_worker_->PushSparseData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray(sizes), grad_index, + indice_index, first_dim_size, outer_dim_size); + } } template diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h 
b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h index f0a2adbc486..6809f137e00 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h +++ b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h @@ -17,14 +17,16 @@ #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_WORKER_PROXY_H_ #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_WORKER_PROXY_H_ +#include #include +#include #include #include #include #include -#include #include "ps/ps.h" #include "frontend/parallel/ps/util.h" +#include "backend/kernel_compiler/common_utils.h" namespace mindspore { namespace parallel { @@ -36,7 +38,7 @@ class WorkerProxy : public ::ps::KVWorker { using Callback = std::function; using SlicedKVs = std::vector>>; using Slicer = std::function &send, const std::vector<::ps::Range> &ranges, - SlicedKVs *sliced)>; + SlicedKVs *sliced, const std::map &attrs)>; using ::ps::SimpleApp::obj_; explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id, int general_customer_id) : Worker(app_id, customer_id) { @@ -46,14 +48,16 @@ class WorkerProxy : public ::ps::KVWorker { using std::placeholders::_2; using std::placeholders::_3; using std::placeholders::_4; + using std::placeholders::_5; lookup_customer_ = std::unique_ptr<::ps::Customer>( new ::ps::Customer(app_id, lookup_customer_id, std::bind(&WorkerProxy::ProcessLookupResult, this, _1))); general_customer_ = std::unique_ptr<::ps::Customer>( new ::ps::Customer(app_id, general_customer_id, std::bind(&WorkerProxy::ProcessResponse, this, _1))); - lookup_slicer_ = std::bind(&WorkerProxy::LookupIdSlicer, this, _1, _2, _3, _4); - broadcast_slicer_ = std::bind(&WorkerProxy::BroadcastSlicer, this, _1, _2, _3, _4); - round_robin_slicer_ = std::bind(&WorkerProxy::RoundRobinSlicer, this, _1, _2, _3, _4); - worker_init_embedding_slicer_ = std::bind(&WorkerProxy::WorkerInitEmbeddingSlicer, this, _1, _2, _3, _4); + lookup_slicer_ = std::bind(&WorkerProxy::LookupIdSlicer, this, _1, _2, _3, _4, _5); + sparse_slicer_ = std::bind(&WorkerProxy::SparseSlicer, this, _1, _2, _3, _4, _5); + broadcast_slicer_ = std::bind(&WorkerProxy::BroadcastSlicer, this, _1, _2, _3, _4, _5); + round_robin_slicer_ = std::bind(&WorkerProxy::RoundRobinSlicer, this, _1, _2, _3, _4, _5); + worker_init_embedding_slicer_ = std::bind(&WorkerProxy::WorkerInitEmbeddingSlicer, this, _1, _2, _3, _4, _5); } ~WorkerProxy() override = default; @@ -68,6 +72,8 @@ class WorkerProxy : public ::ps::KVWorker { bool IsReadyForPull(const Key &key); void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, const ::ps::SArray &lens = {}, int cmd = 0, int priority = 0); + void PushSparseData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, const ::ps::SArray &lens, + size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size); void PullData(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray *vals, ::ps::SArray *lens = nullptr, int cmd = 0, int priority = 0); void Finalize(); @@ -79,19 +85,28 @@ class WorkerProxy : public ::ps::KVWorker { int AddGeneralRspCB(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray *vals, ::ps::SArray *lens, int cmd, const Callback &cb); void LookupIdSlicer(int timestamp, const ::ps::KVPairs &send, const std::vector<::ps::Range> &, - std::vector>> *sliced); + std::vector>> *sliced, const std::map &attrs); + void SparseSlicer(int timestamp, const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced, const std::map &attrs); void BroadcastSlicer(int timestamp, const ::ps::KVPairs &send, const 
std::vector<::ps::Range> &, - std::vector>> *sliced); + std::vector>> *sliced, const std::map &attrs); void RoundRobinSlicer(int timestamp, const ::ps::KVPairs &send, const std::vector<::ps::Range> &, - std::vector>> *sliced); + std::vector>> *sliced, const std::map &attrs); void WorkerInitEmbeddingSlicer(int timestamp, const ::ps::KVPairs &send, const std::vector<::ps::Range> &, - std::vector>> *sliced); + std::vector>> *sliced, + const std::map &attrs); void ProcessLookupResult(const ::ps::Message &msg); void ProcessResponse(const ::ps::Message &msg); void Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, const ::ps::KVPairs &kvs, - const Slicer &slicer); + const Slicer &slicer, std::map attrs = {}); void AddKeyByHashMod(const ::ps::Key &key); + void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set &distinct_ids, + const std::vector> &indice_to_grad, const int *all_indice, + const size_t segment_size, T *gradient, int *indice); + void BuildSparseValue(const ::ps::SArray &lengths, const size_t grad_index, const size_t indice_index, + const T *original_data, const T *grads, int *indices, ::ps::SArray *reduced_data); + int server_num_; std::unique_ptr<::ps::Customer> lookup_customer_; std::unique_ptr<::ps::Customer> general_customer_; @@ -100,6 +115,7 @@ class WorkerProxy : public ::ps::KVWorker { std::unordered_map> gathered_response_; std::mutex mutex_; Slicer lookup_slicer_; + Slicer sparse_slicer_; Slicer broadcast_slicer_; Slicer round_robin_slicer_; Slicer worker_init_embedding_slicer_; @@ -221,6 +237,28 @@ void WorkerProxy::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::S general_customer_->WaitRequest(ts); } +template +void WorkerProxy::PushSparseData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, + const ::ps::SArray &lens, size_t grad_index, size_t indice_index, + size_t first_dim_size, size_t outer_dim_size) { + int ts = AddGeneralRspCB(keys, nullptr, nullptr, 0, nullptr); + ::ps::KVPairs kvs; + kvs.keys = keys; + kvs.vals = vals; + kvs.lens = lens; + int cmd = 0; + if (embedding_table_ranges_.count(keys[0])) { + std::map attrs{{0, grad_index}, {1, indice_index}, {2, first_dim_size}, {3, outer_dim_size}}; + Send(general_customer_.get(), ts, true, false, cmd, kvs, sparse_slicer_, attrs); + } else { + Send(general_customer_.get(), ts, true, false, cmd, kvs, round_robin_slicer_); + } + if (expected_result_count_[ts] < server_num_) { + general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]); + } + general_customer_->WaitRequest(ts); +} + template void WorkerProxy::PullData(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray *vals, ::ps::SArray *lens, int cmd, int priority) { @@ -320,7 +358,8 @@ int WorkerProxy::AddGeneralRspCB(const ::ps::SArray<::ps::Key> &keys, ::ps::S template void WorkerProxy::LookupIdSlicer(int timestamp, const ::ps::KVPairs &send, const std::vector<::ps::Range> &, - std::vector>> *sliced) { + std::vector>> *sliced, + const std::map &attrs) { int *lookup_ids = send.lens.data(); size_t id_size = send.lens.size(); @@ -358,9 +397,181 @@ void WorkerProxy::LookupIdSlicer(int timestamp, const ::ps::KVPairs &send, } } +template +void WorkerProxy::SparseSlicer(int timestamp, const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced, + const std::map &attrs) { + // Init variables + T *data = send.vals.data(); + + if (attrs.count(0) == 0 || attrs.count(1) == 0 || attrs.count(2) == 0 || attrs.count(3) == 0) { + MS_LOG(EXCEPTION) << 
"Invalid attrs keys"; + } + auto iter = attrs.find(0); + size_t grad_index = static_cast(iter->second); + iter = attrs.find(1); + size_t indice_index = static_cast(iter->second); + iter = attrs.find(2); + size_t first_dim_size = static_cast(iter->second); + iter = attrs.find(3); + size_t outer_dim_size = static_cast(iter->second); + + int grad_size = send.lens[grad_index]; + int indice_size = send.lens[indice_index]; + int segment_size = grad_size / indice_size; + + int grad_offset = 0; + int indice_offset = 0; + for (size_t i = 0; i < grad_index; i++) { + grad_offset += send.lens[i]; + } + for (size_t j = 0; j < indice_index; j++) { + indice_offset += send.lens[j]; + } + + T *grad_data = data + grad_offset; + int *indice_data = reinterpret_cast(data) + indice_offset; + + // Build the mappings of indice to gradient + std::vector> indice_to_grads; + for (int i = 0; i < indice_size; i++) { + int indice = indice_data[i]; + T *grad = grad_data + i * segment_size; + indice_to_grads.push_back(std::make_pair(indice, grad)); + } + + const Key &key = send.keys[0]; + const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); + sliced->resize(ranges.size()); + + // Construct reduced sparse data for each server + for (size_t i = 0; i < ranges.size(); i++) { + const ::ps::Range &range = ranges[i]; + const auto &begin = range.begin(); + const auto &end = range.end(); + auto &kvs = sliced->at(i).second; + kvs.keys = send.keys; + kvs.lens = send.lens; + + // Prepare the sparse gradient and indice + std::vector indice_ids; + std::unordered_set distinct_ids; + for (int j = 0; j < indice_size; j++) { + size_t indice = static_cast(indice_data[j]); + if (indice >= begin && indice <= end) { + indice_ids.push_back(indice); + distinct_ids.insert(indice); + } + } + size_t indices_size = indice_ids.size(); + int slice_segment_size = indices_size * segment_size; + T *src_grad_data = new T[slice_segment_size]; + int *src_indice_data = new int[indices_size]; + PrepareSparseGradient(begin, end, distinct_ids, indice_to_grads, indice_data, segment_size, src_grad_data, + src_indice_data); + + // Reduce the sparse gradient and indice + T *new_grad = new T[slice_segment_size]; + int *new_indices = new int[indices_size]; + mindspore::kernel::SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size}); + Util::ReduceSparseGradient(src_grad_data, src_indice_data, indices_size, segment_size, first_dim_size, + outer_dim_size, &unique_sparse_grad); + + // Update the length of reduce sparse gradient and indice + ::ps::SArray reduced_lens; + reduced_lens.CopyFrom(kvs.lens); + reduced_lens[grad_index] = unique_sparse_grad.indices_size_ * segment_size; + reduced_lens[indice_index] = unique_sparse_grad.indices_size_; + + // Build the sparse value to be sent + size_t total_size = 0; + for (auto size : reduced_lens) { + total_size += size; + } + ::ps::SArray reduced_data(total_size, 0); + BuildSparseValue(reduced_lens, grad_index, indice_index, data, unique_sparse_grad.value_, + unique_sparse_grad.indices_, &reduced_data); + + kvs.lens = reduced_lens; + kvs.vals = reduced_data; + + if (indices_size <= 0) { + sliced->at(i).first = false; + } else { + sliced->at(i).first = true; + expected_result_count_[timestamp] += 1; + } + } +} + +template +void WorkerProxy::PrepareSparseGradient(const size_t begin, const size_t end, + const std::unordered_set &distinct_ids, + const std::vector> &indice_to_grads, + const int *all_indice, const size_t segment_size, T *gradient, + int *indices) { + int offset = 0; + int index 
= 0; + size_t segment_data_size = segment_size * sizeof(T); + for (auto &pair : indice_to_grads) { + if (distinct_ids.count(pair.first) == 0) { + continue; + } + indices[index++] = pair.first; + auto ret = memcpy_s(gradient + offset, segment_data_size, pair.second, segment_data_size); + if (ret != 0) { + MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; + } + offset += segment_size; + } +} + +template +void WorkerProxy::BuildSparseValue(const ::ps::SArray &lengths, const size_t grad_index, + const size_t indice_index, const T *original_data, const T *grads, int *indices, + ::ps::SArray *reduced_data) { + int offset = 0; + for (size_t i = 0; i < lengths.size(); i++) { + if (i != grad_index && i != indice_index) { + int data_size = lengths[i] * sizeof(T); + auto ret = memcpy_s(reduced_data->data() + offset, data_size, original_data + offset, data_size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } + } + offset += lengths[i]; + } + + // Fill the reduced gradient + int grad_offset = 0; + for (size_t i = 0; i < grad_index; i++) { + grad_offset += lengths[i]; + } + int data_size = lengths[grad_index] * sizeof(T); + auto ret = memcpy_s(reduced_data->data() + grad_offset, data_size, grads, data_size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } + + // Fill the reduced indice + data_size = lengths[indice_index] * sizeof(T); + int indice_offset = grad_offset + data_size; + T *indice_data = reduced_data->data() + indice_offset; + T *convert = new T[lengths[indice_index]]; + for (int i = 0; i < lengths[indice_index]; i++) { + convert[i] = static_cast(indices[i]); + } + ret = memcpy_s(indice_data, data_size, convert, data_size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } + delete[] convert; +} + template void WorkerProxy::BroadcastSlicer(int timestamp, const ::ps::KVPairs &send, const std::vector<::ps::Range> &, - std::vector>> *sliced) { + std::vector>> *sliced, + const std::map &attr) { sliced->resize(server_num_); for (int i = 0; i < server_num_; i++) { sliced->at(i).first = true; @@ -371,7 +582,8 @@ void WorkerProxy::BroadcastSlicer(int timestamp, const ::ps::KVPairs &send template void WorkerProxy::RoundRobinSlicer(int timestamp, const ::ps::KVPairs &send, const std::vector<::ps::Range> &, - std::vector>> *sliced) { + std::vector>> *sliced, + const std::map &attr) { sliced->resize(server_num_); auto keys = send.keys; auto vals = send.vals; @@ -408,7 +620,8 @@ void WorkerProxy::RoundRobinSlicer(int timestamp, const ::ps::KVPairs &sen template void WorkerProxy::WorkerInitEmbeddingSlicer(int timestamp, const ::ps::KVPairs &send, const std::vector<::ps::Range> &, - std::vector>> *sliced) { + std::vector>> *sliced, + const std::map &attrs) { sliced->resize(server_num_); auto keys = send.keys; auto vals = send.vals; @@ -483,9 +696,9 @@ void WorkerProxy::ProcessResponse(const ::ps::Message &msg) { template void WorkerProxy::Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, - const ::ps::KVPairs &kvs, const Slicer &slicer) { + const ::ps::KVPairs &kvs, const Slicer &slicer, std::map attrs) { SlicedKVs sliced; - slicer(timestamp, kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced); + slicer(timestamp, kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced, attrs); for (size_t i = 0; i < sliced.size(); i++) { const auto &s = sliced[i];