!4685 Export server embedding table ckpts.

Merge pull request !4685 from ZPaC/master-supports-export-server-ckpt
This commit is contained in:
mindspore-ci-bot 2020-08-19 20:08:00 +08:00 committed by Gitee
commit 748d7f1f1a
5 changed files with 69 additions and 6 deletions

View File

@ -28,7 +28,9 @@
#include <thread>
#include <cmath>
#include <random>
#include <utility>
#include <list>
#include <map>
#include "ir/func_graph.h"
#include "backend/session/session_basic.h"
#include "backend/session/anf_runtime_algorithm.h"
@ -52,6 +54,7 @@ namespace mindspore {
namespace parallel {
namespace ps {
using mindspore::kernel::ps::PServerKernel;
using AnfAlgo = session::AnfRuntimeAlgorithm;
template <typename T>
class ParameterServer {
public:
@ -127,6 +130,8 @@ class ParameterServer {
void ResetGradAccumCount();
const CNodePtr GetCNode(const std::string &name) const;
std::mutex &mutex();
void GetEmbeddingTableParamPtr();
void SyncEmbeddingTables();
size_t pserver_num_;
size_t worker_num_;
@ -155,6 +160,7 @@ class ParameterServer {
std::condition_variable apply_grads_cv_;
std::unique_ptr<std::thread> thread_;
std::map<Key, ParameterPtr> embedding_tables_;
friend class ServerHandler;
};
@ -332,6 +338,7 @@ bool ParameterServer<T>::Init(const FuncGraphPtr &func_graph) {
InitOptimInfoBuilders();
ps_->set_request_handle(*handler_);
thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this));
GetEmbeddingTableParamPtr();
return true;
}
@ -475,6 +482,7 @@ template <typename T>
void ParameterServer<T>::Finalize() {
running_ = false;
apply_grads_cv_.notify_one();
SyncEmbeddingTables();
}
template <typename T>
@ -658,6 +666,53 @@ inline std::mutex &ParameterServer<T>::mutex() {
return mutex_;
}
template <typename T>
void ParameterServer<T>::GetEmbeddingTableParamPtr() {
MS_EXCEPTION_IF_NULL(func_graph_);
auto cnodes = func_graph_->GetOrderedCnodes();
Key count = 0;
for (auto cnode : cnodes) {
std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
if (cnode_name == kEmbeddingLookupOpName) {
auto embedding_table = AnfAlgo::GetInputNode(cnode, 0);
MS_EXCEPTION_IF_NULL(embedding_table);
MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count;
embedding_tables_.insert(std::make_pair(count, embedding_table->cast<ParameterPtr>()));
count++;
}
}
}
template <typename T>
void ParameterServer<T>::SyncEmbeddingTables() {
for (auto embedding_table : embedding_tables_) {
Key key = embedding_table.first;
if (embedding_lookup_ops_.count(key) == 0) {
MS_LOG(EXCEPTION) << "Can't find look up PS kernel for key " << key;
}
auto lookup = embedding_lookup_ops_[key];
const std::vector<size_t> &input_shapes = lookup->input_sizes();
std::vector<int> new_tensor_shape(input_shapes.begin(), input_shapes.end());
tensor::TensorPtr new_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, new_tensor_shape);
float *new_tensor_data_ptr = reinterpret_cast<float *>(new_tensor->data_c());
size_t new_tensor_size = static_cast<size_t>(new_tensor->data().nbytes());
size_t embedding_table_size = weights_[key]->size() * sizeof(float);
if (new_tensor_size != embedding_table_size) {
MS_LOG(EXCEPTION) << "Shape of embedding table can't match. New tensor size:" << new_tensor_size
<< ", embedding_table size:" << embedding_table_size;
}
int ret = memcpy_s(new_tensor_data_ptr, new_tensor_size, weights_[key]->data(), embedding_table_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
}
auto paramter_tensor_ptr = embedding_table.second->default_param();
MS_EXCEPTION_IF_NULL(paramter_tensor_ptr);
paramter_tensor_ptr->cast<tensor::TensorPtr>()->AssignValue(*new_tensor);
}
}
template <typename T>
void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
::ps::Start(0);
@ -668,7 +723,6 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
Init(func_graph);
thread_->join();
::ps::Finalize(0, true);
exit(1);
}
} // namespace ps
} // namespace parallel

View File

@ -15,6 +15,7 @@
# limitations under the License.
# ============================================================================
"""Providing interface methods."""
import os
import types
from collections import OrderedDict
from functools import wraps
@ -468,7 +469,7 @@ class _Executor:
return self._executor.has_compiled(phase)
def __call__(self, obj, *args, phase='predict'):
if context.get_context("precompile_only"):
if context.get_context("precompile_only") or os.getenv("MS_ROLE") == "MS_PSERVER":
return None
return self.run(obj, *args, phase=phase)

View File

@ -16,6 +16,7 @@
from collections.abc import Iterable
import os
import sys
import math
import numpy as np
@ -496,6 +497,8 @@ class Model:
self._loss_scale_manager.update_loss_scale(overflow)
list_callback.step_end(run_context)
if os.getenv("MS_ROLE") == "MS_PSERVER":
sys.exit(0)
should_stop = should_stop or run_context.get_stop_requested()
if should_stop:
break

View File

@ -214,8 +214,8 @@ class WideDeepModel(nn.Cell):
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
self.embedding_table = self.deep_embeddinglookup.embedding_table
self.wide_w.set_param_ps()
self.embedding_table.set_param_ps()
self.deep_embeddinglookup.embedding_table.set_param_ps()
self.wide_embeddinglookup.embedding_table.set_param_ps()
else:
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target='DEVICE')
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target='DEVICE')

View File

@ -14,6 +14,7 @@
# ============================================================================
import os
import sys
import argparse
import numpy as np
@ -82,8 +83,12 @@ def do_sparse_embedding(ps=False):
for _ in range(epoch):
data = Tensor(np.random.randint(0, 15, (32, 3), np.int32))
label = Tensor(np.random.randint(0, 9, (32), np.int32))
loss = train_network(data, label).asnumpy()
losses.append(loss)
if envs.get("MS_ROLE") == "MS_PSERVER":
train_network(data, label)
sys.exit()
else:
loss = train_network(data, label).asnumpy()
losses.append(loss)
print(losses)
return losses