forked from mindspore-Ecosystem/mindspore
Export server embedding table ckpt.
This commit is contained in:
parent
9001dc48e6
commit
5084557f67
|
@ -28,7 +28,9 @@
|
|||
#include <thread>
|
||||
#include <cmath>
|
||||
#include <random>
|
||||
#include <utility>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include "ir/func_graph.h"
|
||||
#include "backend/session/session_basic.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
@ -52,6 +54,7 @@ namespace mindspore {
|
|||
namespace parallel {
|
||||
namespace ps {
|
||||
using mindspore::kernel::ps::PServerKernel;
|
||||
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
||||
template <typename T>
|
||||
class ParameterServer {
|
||||
public:
|
||||
|
@ -126,6 +129,8 @@ class ParameterServer {
|
|||
void ResetGradAccumCount();
|
||||
const CNodePtr GetCNode(const std::string &name) const;
|
||||
std::mutex &mutex();
|
||||
void GetEmbeddingTableParamPtr();
|
||||
void SyncEmbeddingTables();
|
||||
|
||||
size_t pserver_num_;
|
||||
size_t worker_num_;
|
||||
|
@ -154,6 +159,7 @@ class ParameterServer {
|
|||
std::condition_variable apply_grads_cv_;
|
||||
|
||||
std::unique_ptr<std::thread> thread_;
|
||||
std::map<Key, ParameterPtr> embedding_tables_;
|
||||
|
||||
friend class ServerHandler;
|
||||
};
|
||||
|
@ -329,6 +335,7 @@ bool ParameterServer<T>::Init(const FuncGraphPtr &func_graph) {
|
|||
InitOptimInfoBuilders();
|
||||
ps_->set_request_handle(*handler_);
|
||||
thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this));
|
||||
GetEmbeddingTableParamPtr();
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -464,6 +471,7 @@ template <typename T>
|
|||
void ParameterServer<T>::Finalize() {
|
||||
running_ = false;
|
||||
apply_grads_cv_.notify_one();
|
||||
SyncEmbeddingTables();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -647,6 +655,53 @@ inline std::mutex &ParameterServer<T>::mutex() {
|
|||
return mutex_;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ParameterServer<T>::GetEmbeddingTableParamPtr() {
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
auto cnodes = func_graph_->GetOrderedCnodes();
|
||||
Key count = 0;
|
||||
for (auto cnode : cnodes) {
|
||||
std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
|
||||
if (cnode_name == kEmbeddingLookupOpName) {
|
||||
auto embedding_table = AnfAlgo::GetInputNode(cnode, 0);
|
||||
MS_EXCEPTION_IF_NULL(embedding_table);
|
||||
MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count;
|
||||
embedding_tables_.insert(std::make_pair(count, embedding_table->cast<ParameterPtr>()));
|
||||
count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ParameterServer<T>::SyncEmbeddingTables() {
|
||||
for (auto embedding_table : embedding_tables_) {
|
||||
Key key = embedding_table.first;
|
||||
if (embedding_lookup_ops_.count(key) == 0) {
|
||||
MS_LOG(EXCEPTION) << "Can't find look up PS kernel for key " << key;
|
||||
}
|
||||
auto lookup = embedding_lookup_ops_[key];
|
||||
const std::vector<size_t> &input_shapes = lookup->input_sizes();
|
||||
std::vector<int> new_tensor_shape(input_shapes.begin(), input_shapes.end());
|
||||
|
||||
tensor::TensorPtr new_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, new_tensor_shape);
|
||||
float *new_tensor_data_ptr = reinterpret_cast<float *>(new_tensor->data_c());
|
||||
size_t new_tensor_size = static_cast<size_t>(new_tensor->data().nbytes());
|
||||
size_t embedding_table_size = weights_[key]->size() * sizeof(float);
|
||||
if (new_tensor_size != embedding_table_size) {
|
||||
MS_LOG(EXCEPTION) << "Shape of embedding table can't match. New tensor size:" << new_tensor_size
|
||||
<< ", embedding_table size:" << embedding_table_size;
|
||||
}
|
||||
int ret = memcpy_s(new_tensor_data_ptr, new_tensor_size, weights_[key]->data(), embedding_table_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
auto paramter_tensor_ptr = embedding_table.second->default_param();
|
||||
MS_EXCEPTION_IF_NULL(paramter_tensor_ptr);
|
||||
paramter_tensor_ptr->cast<tensor::TensorPtr>()->AssignValue(*new_tensor);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
|
||||
::ps::Start(0);
|
||||
|
@ -657,7 +712,6 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
|
|||
Init(func_graph);
|
||||
thread_->join();
|
||||
::ps::Finalize(0, true);
|
||||
exit(1);
|
||||
}
|
||||
} // namespace ps
|
||||
} // namespace parallel
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Providing interface methods."""
|
||||
import os
|
||||
import types
|
||||
from collections import OrderedDict
|
||||
from functools import wraps
|
||||
|
@ -468,7 +469,7 @@ class _Executor:
|
|||
return self._executor.has_compiled(phase)
|
||||
|
||||
def __call__(self, obj, *args, phase='predict'):
|
||||
if context.get_context("precompile_only"):
|
||||
if context.get_context("precompile_only") or os.getenv("MS_ROLE") == "MS_PSERVER":
|
||||
return None
|
||||
return self.run(obj, *args, phase=phase)
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
from collections.abc import Iterable
|
||||
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
@ -496,6 +497,8 @@ class Model:
|
|||
self._loss_scale_manager.update_loss_scale(overflow)
|
||||
|
||||
list_callback.step_end(run_context)
|
||||
if os.getenv("MS_ROLE") == "MS_PSERVER":
|
||||
sys.exit(0)
|
||||
should_stop = should_stop or run_context.get_stop_requested()
|
||||
if should_stop:
|
||||
break
|
||||
|
|
|
@ -214,8 +214,8 @@ class WideDeepModel(nn.Cell):
|
|||
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
|
||||
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
|
||||
self.embedding_table = self.deep_embeddinglookup.embedding_table
|
||||
self.wide_w.set_param_ps()
|
||||
self.embedding_table.set_param_ps()
|
||||
self.deep_embeddinglookup.embedding_table.set_param_ps()
|
||||
self.wide_embeddinglookup.embedding_table.set_param_ps()
|
||||
else:
|
||||
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target='DEVICE')
|
||||
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target='DEVICE')
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
|
@ -82,6 +83,10 @@ def do_sparse_embedding(ps=False):
|
|||
for _ in range(epoch):
|
||||
data = Tensor(np.random.randint(0, 15, (32, 3), np.int32))
|
||||
label = Tensor(np.random.randint(0, 9, (32), np.int32))
|
||||
if envs.get("MS_ROLE") == "MS_PSERVER":
|
||||
train_network(data, label)
|
||||
sys.exit()
|
||||
else:
|
||||
loss = train_network(data, label).asnumpy()
|
||||
losses.append(loss)
|
||||
print(losses)
|
||||
|
|
Loading…
Reference in New Issue