!3491 Delete parameter name hard code for embedding-lookup
Merge pull request !3491 from ZPaC/delete-param-name-hard-code
This commit is contained in:
commit
c9e43ffb85
|
@ -1194,6 +1194,7 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
|
||||||
}
|
}
|
||||||
auto ms_context = MsContext::GetInstance();
|
auto ms_context = MsContext::GetInstance();
|
||||||
MS_EXCEPTION_IF_NULL(ms_context);
|
MS_EXCEPTION_IF_NULL(ms_context);
|
||||||
|
std::vector<int> shape_init_in_server = {1};
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
auto tensor = inputs[i];
|
auto tensor = inputs[i];
|
||||||
MS_EXCEPTION_IF_NULL(tensor);
|
MS_EXCEPTION_IF_NULL(tensor);
|
||||||
|
@ -1201,8 +1202,13 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
|
||||||
MS_EXCEPTION_IF_NULL(input_node);
|
MS_EXCEPTION_IF_NULL(input_node);
|
||||||
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
|
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
|
||||||
auto pk_node = input_node->cast<ParameterPtr>();
|
auto pk_node = input_node->cast<ParameterPtr>();
|
||||||
|
bool init_in_server = false;
|
||||||
|
if (tensor->shape_c() == shape_init_in_server) {
|
||||||
|
MS_LOG(INFO) << "The param need to be initialized in server " << pk_node->fullname_with_scope();
|
||||||
|
init_in_server = true;
|
||||||
|
}
|
||||||
mindspore::parallel::ps::Worker<float>::GetInstance().InitPSParamAndOptim(
|
mindspore::parallel::ps::Worker<float>::GetInstance().InitPSParamAndOptim(
|
||||||
pk_node->fullname_with_scope(), tensor->data_c(), LongToSize(tensor->data().nbytes()));
|
pk_node->fullname_with_scope(), tensor->data_c(), LongToSize(tensor->data().nbytes()), init_in_server);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ps_init_ = true;
|
ps_init_ = true;
|
||||||
|
|
|
@ -542,6 +542,10 @@ inline bool ParameterServer<T>::ReadyForUpdateWeights() {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline bool ParameterServer<T>::ReadyForAccumGrads() {
|
inline bool ParameterServer<T>::ReadyForAccumGrads() {
|
||||||
|
if (weights_.empty()) {
|
||||||
|
MS_LOG(EXCEPTION) << "The weights in server is empty. Many reasons could cause this: 1.The Worker didn't send "
|
||||||
|
"kInitWeightsCmd command. 2.The Server failed to initialize weights.";
|
||||||
|
}
|
||||||
return grad_accum_count_ < weights_.size();
|
return grad_accum_count_ < weights_.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -47,7 +47,8 @@ class Worker {
|
||||||
void SetOptimInputShapes(size_t key, const std::vector<int> &shape);
|
void SetOptimInputShapes(size_t key, const std::vector<int> &shape);
|
||||||
void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
|
void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
|
||||||
void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes, const std::vector<int> &sizes);
|
void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes, const std::vector<int> &sizes);
|
||||||
void InitPSParamAndOptim(const std::string ¶m_name, void *param_data, size_t param_size);
|
void InitPSParamAndOptim(const std::string ¶m_name, void *param_data, size_t param_size,
|
||||||
|
bool init_in_server = false);
|
||||||
void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
|
void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
|
||||||
const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int cmd);
|
const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int cmd);
|
||||||
void Finalize();
|
void Finalize();
|
||||||
|
@ -240,7 +241,8 @@ void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vecto
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
// Initialize parameters and optimizer kernels of Parameter Server.
|
// Initialize parameters and optimizer kernels of Parameter Server.
|
||||||
void Worker<T>::InitPSParamAndOptim(const std::string ¶m_name, void *param_data, size_t param_size) {
|
void Worker<T>::InitPSParamAndOptim(const std::string ¶m_name, void *param_data, size_t param_size,
|
||||||
|
bool init_in_server) {
|
||||||
size_t param_key = GetParamKey(param_name);
|
size_t param_key = GetParamKey(param_name);
|
||||||
if (param_key == kInvalidKey) {
|
if (param_key == kInvalidKey) {
|
||||||
MS_LOG(INFO) << "Parameter " << param_name << " has no key assigned.";
|
MS_LOG(INFO) << "Parameter " << param_name << " has no key assigned.";
|
||||||
|
@ -248,9 +250,9 @@ void Worker<T>::InitPSParamAndOptim(const std::string ¶m_name, void *param_d
|
||||||
}
|
}
|
||||||
bool init = IsKeyInit(param_key);
|
bool init = IsKeyInit(param_key);
|
||||||
if (!init) {
|
if (!init) {
|
||||||
MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name;
|
MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name
|
||||||
// No need to push embedding table data to Parameter Server.
|
<< ", whether init in server: " << init_in_server;
|
||||||
if (param_name.find("embedding_table") == std::string::npos && param_name.find("wide_w") == std::string::npos) {
|
if (!init_in_server) {
|
||||||
InitPSParamData({param_key}, param_data, param_size);
|
InitPSParamData({param_key}, param_data, param_size);
|
||||||
}
|
}
|
||||||
InitPSOptimId(param_key);
|
InitPSOptimId(param_key);
|
||||||
|
|
Loading…
Reference in New Issue