forked from mindspore-Ecosystem/mindspore
!4685 Export server embedding table ckpts.
Merge pull request !4685 from ZPaC/master-supports-export-server-ckpt
This commit is contained in:
commit
748d7f1f1a
|
@ -28,7 +28,9 @@
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
#include <utility>
|
||||||
#include <list>
|
#include <list>
|
||||||
|
#include <map>
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
#include "backend/session/session_basic.h"
|
#include "backend/session/session_basic.h"
|
||||||
#include "backend/session/anf_runtime_algorithm.h"
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
@ -52,6 +54,7 @@ namespace mindspore {
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
namespace ps {
|
namespace ps {
|
||||||
using mindspore::kernel::ps::PServerKernel;
|
using mindspore::kernel::ps::PServerKernel;
|
||||||
|
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class ParameterServer {
|
class ParameterServer {
|
||||||
public:
|
public:
|
||||||
|
@ -127,6 +130,8 @@ class ParameterServer {
|
||||||
void ResetGradAccumCount();
|
void ResetGradAccumCount();
|
||||||
const CNodePtr GetCNode(const std::string &name) const;
|
const CNodePtr GetCNode(const std::string &name) const;
|
||||||
std::mutex &mutex();
|
std::mutex &mutex();
|
||||||
|
void GetEmbeddingTableParamPtr();
|
||||||
|
void SyncEmbeddingTables();
|
||||||
|
|
||||||
size_t pserver_num_;
|
size_t pserver_num_;
|
||||||
size_t worker_num_;
|
size_t worker_num_;
|
||||||
|
@ -155,6 +160,7 @@ class ParameterServer {
|
||||||
std::condition_variable apply_grads_cv_;
|
std::condition_variable apply_grads_cv_;
|
||||||
|
|
||||||
std::unique_ptr<std::thread> thread_;
|
std::unique_ptr<std::thread> thread_;
|
||||||
|
std::map<Key, ParameterPtr> embedding_tables_;
|
||||||
|
|
||||||
friend class ServerHandler;
|
friend class ServerHandler;
|
||||||
};
|
};
|
||||||
|
@ -332,6 +338,7 @@ bool ParameterServer<T>::Init(const FuncGraphPtr &func_graph) {
|
||||||
InitOptimInfoBuilders();
|
InitOptimInfoBuilders();
|
||||||
ps_->set_request_handle(*handler_);
|
ps_->set_request_handle(*handler_);
|
||||||
thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this));
|
thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this));
|
||||||
|
GetEmbeddingTableParamPtr();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -475,6 +482,7 @@ template <typename T>
|
||||||
void ParameterServer<T>::Finalize() {
|
void ParameterServer<T>::Finalize() {
|
||||||
running_ = false;
|
running_ = false;
|
||||||
apply_grads_cv_.notify_one();
|
apply_grads_cv_.notify_one();
|
||||||
|
SyncEmbeddingTables();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -658,6 +666,53 @@ inline std::mutex &ParameterServer<T>::mutex() {
|
||||||
return mutex_;
|
return mutex_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void ParameterServer<T>::GetEmbeddingTableParamPtr() {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||||
|
auto cnodes = func_graph_->GetOrderedCnodes();
|
||||||
|
Key count = 0;
|
||||||
|
for (auto cnode : cnodes) {
|
||||||
|
std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
|
||||||
|
if (cnode_name == kEmbeddingLookupOpName) {
|
||||||
|
auto embedding_table = AnfAlgo::GetInputNode(cnode, 0);
|
||||||
|
MS_EXCEPTION_IF_NULL(embedding_table);
|
||||||
|
MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count;
|
||||||
|
embedding_tables_.insert(std::make_pair(count, embedding_table->cast<ParameterPtr>()));
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void ParameterServer<T>::SyncEmbeddingTables() {
|
||||||
|
for (auto embedding_table : embedding_tables_) {
|
||||||
|
Key key = embedding_table.first;
|
||||||
|
if (embedding_lookup_ops_.count(key) == 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "Can't find look up PS kernel for key " << key;
|
||||||
|
}
|
||||||
|
auto lookup = embedding_lookup_ops_[key];
|
||||||
|
const std::vector<size_t> &input_shapes = lookup->input_sizes();
|
||||||
|
std::vector<int> new_tensor_shape(input_shapes.begin(), input_shapes.end());
|
||||||
|
|
||||||
|
tensor::TensorPtr new_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, new_tensor_shape);
|
||||||
|
float *new_tensor_data_ptr = reinterpret_cast<float *>(new_tensor->data_c());
|
||||||
|
size_t new_tensor_size = static_cast<size_t>(new_tensor->data().nbytes());
|
||||||
|
size_t embedding_table_size = weights_[key]->size() * sizeof(float);
|
||||||
|
if (new_tensor_size != embedding_table_size) {
|
||||||
|
MS_LOG(EXCEPTION) << "Shape of embedding table can't match. New tensor size:" << new_tensor_size
|
||||||
|
<< ", embedding_table size:" << embedding_table_size;
|
||||||
|
}
|
||||||
|
int ret = memcpy_s(new_tensor_data_ptr, new_tensor_size, weights_[key]->data(), embedding_table_size);
|
||||||
|
if (ret != 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto paramter_tensor_ptr = embedding_table.second->default_param();
|
||||||
|
MS_EXCEPTION_IF_NULL(paramter_tensor_ptr);
|
||||||
|
paramter_tensor_ptr->cast<tensor::TensorPtr>()->AssignValue(*new_tensor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
|
void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
|
||||||
::ps::Start(0);
|
::ps::Start(0);
|
||||||
|
@ -668,7 +723,6 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
|
||||||
Init(func_graph);
|
Init(func_graph);
|
||||||
thread_->join();
|
thread_->join();
|
||||||
::ps::Finalize(0, true);
|
::ps::Finalize(0, true);
|
||||||
exit(1);
|
|
||||||
}
|
}
|
||||||
} // namespace ps
|
} // namespace ps
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Providing interface methods."""
|
"""Providing interface methods."""
|
||||||
|
import os
|
||||||
import types
|
import types
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
@ -468,7 +469,7 @@ class _Executor:
|
||||||
return self._executor.has_compiled(phase)
|
return self._executor.has_compiled(phase)
|
||||||
|
|
||||||
def __call__(self, obj, *args, phase='predict'):
|
def __call__(self, obj, *args, phase='predict'):
|
||||||
if context.get_context("precompile_only"):
|
if context.get_context("precompile_only") or os.getenv("MS_ROLE") == "MS_PSERVER":
|
||||||
return None
|
return None
|
||||||
return self.run(obj, *args, phase=phase)
|
return self.run(obj, *args, phase=phase)
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -496,6 +497,8 @@ class Model:
|
||||||
self._loss_scale_manager.update_loss_scale(overflow)
|
self._loss_scale_manager.update_loss_scale(overflow)
|
||||||
|
|
||||||
list_callback.step_end(run_context)
|
list_callback.step_end(run_context)
|
||||||
|
if os.getenv("MS_ROLE") == "MS_PSERVER":
|
||||||
|
sys.exit(0)
|
||||||
should_stop = should_stop or run_context.get_stop_requested()
|
should_stop = should_stop or run_context.get_stop_requested()
|
||||||
if should_stop:
|
if should_stop:
|
||||||
break
|
break
|
||||||
|
|
|
@ -214,8 +214,8 @@ class WideDeepModel(nn.Cell):
|
||||||
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
|
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
|
||||||
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
|
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
|
||||||
self.embedding_table = self.deep_embeddinglookup.embedding_table
|
self.embedding_table = self.deep_embeddinglookup.embedding_table
|
||||||
self.wide_w.set_param_ps()
|
self.deep_embeddinglookup.embedding_table.set_param_ps()
|
||||||
self.embedding_table.set_param_ps()
|
self.wide_embeddinglookup.embedding_table.set_param_ps()
|
||||||
else:
|
else:
|
||||||
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target='DEVICE')
|
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target='DEVICE')
|
||||||
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target='DEVICE')
|
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target='DEVICE')
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -82,8 +83,12 @@ def do_sparse_embedding(ps=False):
|
||||||
for _ in range(epoch):
|
for _ in range(epoch):
|
||||||
data = Tensor(np.random.randint(0, 15, (32, 3), np.int32))
|
data = Tensor(np.random.randint(0, 15, (32, 3), np.int32))
|
||||||
label = Tensor(np.random.randint(0, 9, (32), np.int32))
|
label = Tensor(np.random.randint(0, 9, (32), np.int32))
|
||||||
loss = train_network(data, label).asnumpy()
|
if envs.get("MS_ROLE") == "MS_PSERVER":
|
||||||
losses.append(loss)
|
train_network(data, label)
|
||||||
|
sys.exit()
|
||||||
|
else:
|
||||||
|
loss = train_network(data, label).asnumpy()
|
||||||
|
losses.append(loss)
|
||||||
print(losses)
|
print(losses)
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue