!5812 Add PS context.
Merge pull request !5812 from ZPaC/master-context-for-ps
Commit: 2a9c458870
@@ -41,8 +41,7 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_LOG(INFO) << "Init embedding lookup proxy kernel, input shape:" << input_shape
                << ", indices_shape:" << indices_shape << ", output_shape:" << output_shape;
   std::vector<int> lens{SizeToInt(input_shape.size()), SizeToInt(indices_shape.size()), SizeToInt(output_shape.size())};
-  const char *env_role = getenv(mindspore::parallel::ps::kEnvRole);
-  if (env_role != nullptr && strcmp(env_role, mindspore::parallel::ps::kEnvRoleOfWorker) == 0) {
+  if (mindspore::parallel::ps::Util::IsRoleOfWorker()) {
     parallel::ps::Worker<float>::GetInstance().AddEmbeddingTable(key_, input_shape[axis]);
     parallel::ps::Worker<float>::GetInstance().InitPSEmbeddingTable(keys, values, lens);
   }
@@ -32,11 +32,6 @@ constexpr char kEnvWorkerNum[] = "MS_WORKER_NUM";
 constexpr char kEnvSchedulerHost[] = "MS_SCHED_HOST";
 constexpr char kEnvSchedulerPort[] = "MS_SCHED_PORT";
 
-constexpr char kEnvRole[] = "MS_ROLE";
-constexpr char kEnvRoleOfPServer[] = "MS_PSERVER";
-constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
-constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
-
 constexpr char kDmlcCommType[] = "DMLC_PS_VAN_TYPE";
 constexpr char kDmlcInterface[] = "DMLC_INTERFACE";
 constexpr char kDmlcPServerNum[] = "DMLC_NUM_SERVER";
@@ -39,6 +39,7 @@
 #include "frontend/parallel/ps/optimizer_info.h"
 #include "frontend/parallel/ps/optimizer_info_builder.h"
 #include "frontend/parallel/ps/util.h"
+#include "frontend/parallel/ps/ps_context.h"
 #include "runtime/device/cpu/kernel_select_cpu.h"
 #include "utils/ms_context.h"
 #include "backend/kernel_compiler/kernel.h"
@@ -741,7 +742,7 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
     return;
   }
   Init(func_graph);
-  Util::SetRankId(rank_id_);
+  PSContext::instance()->SetPSRankId(rank_id_);
   thread_->join();
   ::ps::Finalize(0, true);
 }
@@ -0,0 +1,86 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ps/ps_context.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"

namespace mindspore {
namespace parallel {
namespace ps {
std::shared_ptr<PSContext> PSContext::instance() {
  static std::shared_ptr<PSContext> ps_instance = nullptr;
  if (ps_instance == nullptr) {
    ps_instance.reset(new (std::nothrow) PSContext());
  }
  return ps_instance;
}

void PSContext::SetPSEnable(bool enabled) {
  ps_enabled_ = enabled;
  if (ps_enabled_) {
    std::string ms_role = common::GetEnv(kEnvRole);
    MS_LOG(INFO) << "PS mode is enabled. MS_ROLE is " << ms_role;
    if (ms_role == kEnvRoleOfWorker) {
      is_worker_ = true;
    } else if (ms_role == kEnvRoleOfPServer) {
      is_pserver_ = true;
    } else if (ms_role == kEnvRoleOfScheduler) {
      is_sched_ = true;
    } else {
      MS_LOG(WARNING) << "MS_ROLE is " << ms_role << ", which is invalid.";
    }
  } else {
    MS_LOG(INFO) << "PS mode is disabled.";
    is_worker_ = false;
    is_pserver_ = false;
    is_sched_ = false;
  }
}

bool PSContext::is_ps_enabled() const { return ps_enabled_; }

void PSContext::Reset() {
  ps_enabled_ = false;
  is_worker_ = false;
  is_pserver_ = false;
  is_sched_ = false;
}

std::string PSContext::ms_role() const {
  if (is_worker_) {
    return kEnvRoleOfWorker;
  } else if (is_pserver_) {
    return kEnvRoleOfPServer;
  } else if (is_sched_) {
    return kEnvRoleOfScheduler;
  } else {
    return kEnvRoleOfNotPS;
  }
}

bool PSContext::is_role_worker() const { return is_worker_; }

bool PSContext::is_role_pserver() const { return is_pserver_; }

bool PSContext::is_role_sched() const { return is_sched_; }

void PSContext::SetPSRankId(int rank_id) { rank_id_ = rank_id; }

int PSContext::ps_rank_id() const { return rank_id_; }
}  // namespace ps
}  // namespace parallel
}  // namespace mindspore
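For illustration only (not part of the diff): a minimal sketch of how SetPSEnable maps the MS_ROLE environment variable onto the three role flags, driven through the pybind11 binding that a later hunk of this PR registers under mindspore._c_expression.

# Sketch: role resolution done by PSContext::SetPSEnable, exercised from Python.
import os
from mindspore._c_expression import PSContext  # binding added later in this PR

ctx = PSContext.get_instance()                 # process-wide singleton
for role in ("MS_WORKER", "MS_PSERVER", "MS_SCHED", "BOGUS_ROLE"):
    os.environ["MS_ROLE"] = role               # read when set_ps_enable(True) runs
    ctx.set_ps_enable(True)                    # caches at most one role flag; warns on an invalid role
    print(role, ctx.is_role_worker(), ctx.is_role_pserver(), ctx.is_role_sched())
    ctx.reset()                                # clears the flags before the next case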
@@ -0,0 +1,61 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_

#include <string>
#include <memory>

namespace mindspore {
namespace parallel {
namespace ps {
constexpr char kEnvRole[] = "MS_ROLE";
constexpr char kEnvRoleOfPServer[] = "MS_PSERVER";
constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";

class PSContext {
 public:
  ~PSContext() = default;
  PSContext(PSContext const &) = delete;
  PSContext &operator=(const PSContext &) = delete;
  static std::shared_ptr<PSContext> instance();

  void SetPSEnable(bool enabled);
  bool is_ps_enabled() const;
  void Reset();
  std::string ms_role() const;
  bool is_role_worker() const;
  bool is_role_pserver() const;
  bool is_role_sched() const;
  void SetPSRankId(int rank_id);
  int ps_rank_id() const;

 private:
  PSContext() : ps_enabled_(false), is_worker_(false), is_pserver_(false), is_sched_(false), rank_id_(-1) {}
  bool ps_enabled_;
  bool is_worker_;
  bool is_pserver_;
  bool is_sched_;
  int rank_id_;
};
}  // namespace ps
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_
@@ -16,7 +16,9 @@
 
 #include "frontend/parallel/ps/util.h"
 #include <unordered_map>
+#include <vector>
 #include "frontend/parallel/ps/common.h"
+#include "frontend/parallel/ps/ps_context.h"
 #include "utils/ms_utils.h"
 
 namespace mindspore {
@@ -45,34 +47,13 @@ std::unordered_map<int, std::string> Util::id_to_optimizer_nodes{
   {3, kSparseFtrlOp},
 };
 
-bool Util::IsParamServerMode() { return IsRoleOfWorker() || IsRoleOfPServer() || IsRoleOfScheduler(); }
+bool Util::IsParamServerMode() { return PSContext::instance()->is_ps_enabled(); }
 
-bool Util::IsRoleOfWorker() {
-  auto role = common::GetEnv(kEnvRole);
-  if (strcmp(role.c_str(), kEnvRoleOfWorker) == 0) {
-    return true;
-  } else {
-    return false;
-  }
-}
+bool Util::IsRoleOfWorker() { return PSContext::instance()->is_role_worker(); }
 
-bool Util::IsRoleOfPServer() {
-  auto role = common::GetEnv(kEnvRole);
-  if (strcmp(role.c_str(), kEnvRoleOfPServer) == 0) {
-    return true;
-  } else {
-    return false;
-  }
-}
+bool Util::IsRoleOfPServer() { return PSContext::instance()->is_role_pserver(); }
 
-bool Util::IsRoleOfScheduler() {
-  auto role = common::GetEnv(kEnvRole);
-  if (strcmp(role.c_str(), kEnvRoleOfScheduler) == 0) {
-    return true;
-  } else {
-    return false;
-  }
-}
+bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_role_sched(); }
 
 void Util::SetInternalEnvVar() {
   if (IsParamServerMode()) {
@@ -163,10 +144,6 @@ std::map<int, int> Util::AllRankLocalShard(int first_dim, int rank_id, int server_num) {
   return shard_dims;
 }
 
-void Util::SetRankId(int rank_id) { rank_id_ = rank_id; }
-
-int Util::GetRankId() { return rank_id_; }
-
 void Util::ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
                                 const size_t first_dim_size, const size_t outer_dim_size,
                                 mindspore::kernel::SparseGradient<int> *unique_sparse_grad) {
@@ -40,8 +40,6 @@ class Util {
   static bool is_optimizer(std::string name);
   static int LocalShard(int first_dim, int rank_id, int server_num);
   static std::map<int, int> AllRankLocalShard(int first_dim, int rank_id, int server_num);
-  static void SetRankId(int rank_id);
-  static int GetRankId();
   static void ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
                                    const size_t first_dim_size, const size_t outer_dim_size,
                                    mindspore::kernel::SparseGradient<int> *unique_sparse_grad);
@@ -27,6 +27,7 @@
 #include "ps/ps.h"
 #include "frontend/parallel/ps/util.h"
 #include "backend/kernel_compiler/common_utils.h"
+#include "frontend/parallel/ps/ps_context.h"
 
 namespace mindspore {
 namespace parallel {
@@ -43,7 +44,7 @@ class WorkerProxy : public ::ps::KVWorker<T> {
   explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id, int general_customer_id)
       : Worker(app_id, customer_id) {
     server_num_ = ::ps::NumServers();
-    Util::SetRankId(::ps::MyRank());
+    PSContext::instance()->SetPSRankId(::ps::MyRank());
     using std::placeholders::_1;
     using std::placeholders::_2;
     using std::placeholders::_3;
@@ -36,6 +36,7 @@
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
 #include "frontend/parallel/ps/util.h"
 #endif
+#include "frontend/parallel/ps/ps_context.h"
 namespace py = pybind11;
 
 using EnvInstance = mindspore::EnvInstance;
@@ -49,6 +50,7 @@ using OpInfoLoaderPy = mindspore::kernel::OpInfoLoaderPy;
 using ParallelContext = mindspore::parallel::ParallelContext;
 using CostModelContext = mindspore::parallel::CostModelContext;
 using mindspore::MsCtxParam;
+using PSContext = mindspore::parallel::ps::PSContext;
 
 // Interface with python
 PYBIND11_MODULE(_c_expression, m) {
@@ -276,9 +278,15 @@
               "Finalize gpu collective communication mode.");
 #endif
 
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
-  (void)m.def("get_ps_mode_rank", &mindspore::parallel::ps::Util::GetRankId, "Get Worker and PServer rank id.");
-#endif
+  (void)py::class_<PSContext, std::shared_ptr<PSContext>>(m, "PSContext")
+    .def_static("get_instance", &PSContext::instance, "Get PS context instance.")
+    .def("set_ps_enable", &PSContext::SetPSEnable, "Set PS mode enabled or disabled.")
+    .def("is_ps_enabled", &PSContext::is_ps_enabled, "Get PS mode enable-disable status.")
+    .def("reset", &PSContext::Reset, "Reset PS context attributes.")
+    .def("is_role_worker", &PSContext::is_role_worker, "Get whether the role of this process is Worker.")
+    .def("is_role_pserver", &PSContext::is_role_pserver, "Get whether the role of this process is PServer.")
+    .def("is_role_sched", &PSContext::is_role_sched, "Get whether the role of this process is Scheduler.")
+    .def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id.");
 
   (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
     .def(py::init())
@@ -15,7 +15,6 @@
 # limitations under the License.
 # ============================================================================
 """Providing interface methods."""
-import os
 import types
 from collections import OrderedDict
 from functools import wraps
|
@ -25,6 +24,7 @@ from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, Pynativ
|
|||
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
|
||||
from .tensor import Tensor as MsTensor
|
||||
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_tensor
|
||||
from ..parallel._ps_context import _is_role_pserver
|
||||
# store ms_function class compiled pipeline cache
|
||||
ms_compile_cache = {}
|
||||
|
||||
|
@@ -469,7 +469,7 @@ class _Executor:
         return self._executor.has_compiled(phase)
 
     def __call__(self, obj, *args, phase='predict'):
-        if context.get_context("precompile_only") or os.getenv("MS_ROLE") == "MS_PSERVER":
+        if context.get_context("precompile_only") or _is_role_pserver():
             return None
         return self.run(obj, *args, phase=phase)
 
@@ -22,6 +22,7 @@ from .tensor import Tensor, MetaTensor
 from .._checkparam import _check_str_by_regular
 from ..parallel._tensor import _get_slice_index
 from ..parallel._auto_parallel_context import auto_parallel_context
+from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched
 
 __all__ = ['Parameter', 'ParameterTuple']
 
@@ -168,8 +169,13 @@ class Parameter(MetaTensor):
         """For parse check."""
 
     def set_param_ps(self, init_in_server=False):
-        self.is_param_ps = True
-        self.init_in_server = init_in_server
+        if _is_role_worker() or _is_role_pserver() or _is_role_sched():
+            self.is_param_ps = True
+            self.init_in_server = init_in_server
+        else:
+            raise RuntimeError("Must complete following two steps before calling set_param_ps: \
+                               1. set_ps_context(enable_ps=True) \
+                               2. export MS_ROLE environment variable.")
 
 
     @property
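A short sketch of the call order this change enforces (not part of the diff; the parameter below is hypothetical, and set_ps_context comes from a later hunk of this PR). MS_ROLE must be exported and PS mode enabled before set_param_ps is called, otherwise the RuntimeError above is raised.

# Sketch: set_param_ps() now requires an active PS role.
import os
import numpy as np

os.environ["MS_ROLE"] = "MS_WORKER"          # export the role before enabling PS mode

from mindspore import context, Parameter, Tensor

context.set_ps_context(enable_ps=True)       # resolves the role from MS_ROLE
weight = Parameter(Tensor(np.zeros((8, 4), np.float32)), name="embedding_table")  # hypothetical parameter
weight.set_param_ps(init_in_server=True)     # accepted: this process is a worker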
@@ -14,7 +14,7 @@
 # ============================================================================
 """comm_helper"""
 
-import os
+from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
 from ._hccl_management import load_lib as hccl_load_lib
 
 _HCCL_AVAILABLE = False
@@ -44,7 +44,6 @@ else:
 
 HCCL_WORLD_COMM_GROUP = "hccl_world_group"
 NCCL_WORLD_COMM_GROUP = "nccl_world_group"
-MS_ROLE = os.getenv("MS_ROLE")
 
 class Backend:
     """
@@ -113,7 +112,7 @@ def check_parameter_available(func):
         Wrapper. If not available, raise Error.
     """
     def wrapper(*args, **kargs):
-        if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
+        if _is_role_pserver() or _is_role_sched():
             return func(*args, **kargs)
         group = None
         if "group" in kargs.keys():
@@ -154,7 +153,7 @@ def _get_rank_helper(group, backend):
         Integer. The local rank id of the calling process.
     """
     rank_id = None
-    if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
+    if _is_role_pserver() or _is_role_sched():
         rank_id = 0
         return rank_id
     if backend == Backend.HCCL:
@@ -213,7 +212,7 @@ def _get_size_helper(group, backend):
         Integer. The rank size of specified group.
     """
     size = None
-    if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
+    if _is_role_pserver() or _is_role_sched():
         size = 1
         return size
     if backend == Backend.HCCL:
@@ -13,8 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """Communication management API"""
-import os
 from mindspore import context
+from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
 from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
     _get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
     _create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
@@ -29,7 +29,6 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
 
 DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
 DEFAULT_BACKEND = Backend("hccl")
-MS_ROLE = os.getenv("MS_ROLE")
 
 
 def _get_group(group):
@@ -61,7 +60,7 @@ def init(backend_name=None):
         RuntimeError: If device target is invalid.
         RuntimeError: If backend is invalid or distributed init fails.
     """
-    if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
+    if _is_role_pserver() or _is_role_sched():
         return
     if backend_name is None:
         device_target = context.get_context("device_target")
@@ -26,9 +26,11 @@ from mindspore._c_expression import MSContext, ms_ctx_param
 from mindspore._checkparam import args_type_check
 from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
     _reset_auto_parallel_context
+from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context
 
 __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
-           'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode']
+           'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode', 'set_ps_context',
+           'get_ps_context', 'reset_ps_context']
 
 GRAPH_MODE = 0
 PYNATIVE_MODE = 1
@@ -569,3 +571,58 @@
     SEMI_AUTO_PARALLEL = "semi_auto_parallel"
     AUTO_PARALLEL = "auto_parallel"
     MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]
+
+@args_type_check(enable_ps=bool)
+def set_ps_context(**kwargs):
+    """
+    Set parameter server training mode context.
+
+    Note:
+        Some other environment variables should also be set for parameter server training mode.
+        These environment variables are listed below:
+        MS_SERVER_NUM  # Server number
+        MS_WORKER_NUM  # Worker number
+        MS_SCHED_HOST  # Scheduler IP address
+        MS_SCHED_PORT  # Scheduler port
+        MS_ROLE  # The role of this process:
+                   MS_SCHED represents the scheduler,
+                   MS_WORKER represents the worker,
+                   MS_PSERVER represents the Server
+
+
+    Args:
+        enable_ps (bool): Whether to enable parameter server training mode.
+                          The environment variables take effect only after enable_ps is set to True.
+                          Default: False.
+
+    Raises:
+        ValueError: If input key is not the attribute in parameter server training mode context.
+
+    Examples:
+        >>> context.set_ps_context(enable_ps=True)
+    """
+    _set_ps_context(**kwargs)
+
+
+def get_ps_context(attr_key):
+    """
+    Get parameter server training mode context attribute value according to the key.
+
+    Args:
+        attr_key (str): The key of the attribute.
+
+    Returns:
+        Returns attribute value according to the key.
+
+    Raises:
+        ValueError: If input key is not an attribute in parameter server training mode context.
+    """
+    return _get_ps_context(attr_key)
+
+def reset_ps_context():
+    """
+    Reset parameter server training mode context attributes to the default values:
+
+    - enable_ps: False.
+    """
+    _reset_ps_context()
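As a usage sketch for the new public API (not part of the diff; the host/port values are placeholders that a launch script would normally export):

# Sketch: enabling PS mode for one process of a parameter-server job.
import os

os.environ["MS_ROLE"] = "MS_WORKER"        # or MS_PSERVER / MS_SCHED for the other processes
os.environ["MS_WORKER_NUM"] = "1"
os.environ["MS_SERVER_NUM"] = "1"
os.environ["MS_SCHED_HOST"] = "127.0.0.1"  # placeholder scheduler address
os.environ["MS_SCHED_PORT"] = "8081"       # placeholder scheduler port

from mindspore import context

context.set_ps_context(enable_ps=True)         # resolves the role from MS_ROLE
print(context.get_ps_context("enable_ps"))     # expected to be True
# ... build the network, call set_param_ps() on the PS-hosted parameters, train ...
context.reset_ps_context()                     # back to the default: enable_ps=False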
@@ -0,0 +1,115 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Context for parameter server training mode"""

from mindspore._c_expression import PSContext

_ps_context = None


def ps_context():
    """
    Get the global _ps_context, if it is not created, create a new one.

    Returns:
        _ps_context, the global parameter server training mode context.
    """
    global _ps_context
    if _ps_context is None:
        _ps_context = PSContext.get_instance()
    return _ps_context

_set_ps_context_func_map = {
    "enable_ps": ps_context().set_ps_enable
}

_get_ps_context_func_map = {
    "enable_ps": ps_context().is_ps_enabled
}

def _get_ps_mode_rank():
    ps_rank = ps_context().ps_rank_id()
    if ps_rank == -1:
        raise RuntimeError("The parameter server mode training is not enabled yet.")
    return ps_rank

def _set_ps_context(**kwargs):
    """
    Set parameter server training mode context.

    Note:
        Some other environment variables should also be set for parameter server training mode.
        These environment variables are listed below:
        MS_SERVER_NUM  # Server number
        MS_WORKER_NUM  # Worker number
        MS_SCHED_HOST  # Scheduler IP address
        MS_SCHED_PORT  # Scheduler port
        MS_ROLE  # The role of this process:
                   MS_SCHED represents the scheduler,
                   MS_WORKER represents the worker,
                   MS_PSERVER represents the Server


    Args:
        enable_ps (bool): Whether to enable parameter server training mode.
                          The environment variables take effect only after enable_ps is set to True.
                          Default: False.

    Raises:
        ValueError: If input key is not the attribute in parameter server training mode context.

    Examples:
        >>> context.set_ps_context(enable_ps=True)
    """
    for key, value in kwargs.items():
        if key not in _set_ps_context_func_map:
            raise ValueError("Set PS context keyword %s is not recognized!" % key)
        set_func = _set_ps_context_func_map[key]
        set_func(value)

def _get_ps_context(attr_key):
    """
    Get parameter server training mode context attribute value according to the key.

    Args:
        attr_key (str): The key of the attribute.

    Returns:
        Returns attribute value according to the key.

    Raises:
        ValueError: If input key is not an attribute in parameter server training mode context.
    """
    if attr_key not in _get_ps_context_func_map:
        raise ValueError("Get PS context keyword %s is not recognized!" % attr_key)
    get_func = _get_ps_context_func_map[attr_key]
    return get_func()

def _reset_ps_context():
    """
    Reset parameter server training mode context attributes to the default values:

    - enable_ps: False.
    """
    ps_context().reset()

def _is_role_worker():
    return ps_context().is_role_worker()

def _is_role_pserver():
    return ps_context().is_role_pserver()

def _is_role_sched():
    return ps_context().is_role_sched()
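One behavioural point worth illustrating (a sketch, not part of the diff): unlike the old os.getenv("MS_ROLE") checks that this PR replaces, the role helpers above only report a role after set_ps_context(enable_ps=True) has run, because that is when SetPSEnable reads MS_ROLE and caches the flags.

# Sketch: exporting MS_ROLE alone no longer marks the process as a worker.
import os
os.environ["MS_ROLE"] = "MS_WORKER"

from mindspore import context
from mindspore.parallel._ps_context import _is_role_worker

print(_is_role_worker())                 # False: PS mode has not been enabled yet
context.set_ps_context(enable_ps=True)   # SetPSEnable reads MS_ROLE and caches the role
print(_is_role_worker())                 # True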
@@ -1,23 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils for parameter server training mode"""

from mindspore._c_expression import get_ps_mode_rank

def _get_ps_mode_rank():
    ps_rank = get_ps_mode_rank()
    if ps_rank == -1:
        raise RuntimeError("The parameter server mode training is not launched yet.")
    return ps_rank
@@ -24,6 +24,7 @@ from mindspore import log as logger
 from mindspore._checkparam import check_bool, check_int_non_negative
 from mindspore.train._utils import _make_directory
 from mindspore.train.serialization import save_checkpoint, _save_graph
+from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
 from ._callback import Callback, set_cur_net
 
 
@@ -280,8 +281,7 @@ class ModelCheckpoint(Callback):
         if save_ckpt:
             cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
                 + str(step_num_in_epoch) + ".ckpt"
-            if os.getenv("MS_ROLE") == "MS_PSERVER":
-                from mindspore.parallel._ps_utils import _get_ps_mode_rank
+            if _is_role_pserver():
                 cur_ckpoint_file = "PServer_" + str(_get_ps_mode_rank()) + "_" + cur_ckpoint_file
             # update checkpoint file list.
             self._manager.update_ckpoint_filelist(self._directory, self._prefix)
@@ -27,6 +27,7 @@ from .callback import _InternalCallbackParam, RunContext, _CallbackManager
 from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
+from ..parallel._ps_context import _is_role_pserver, _is_role_sched
 from ..nn.metrics import Loss
 from .. import nn
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
@@ -378,8 +379,7 @@ class Model:
         cb_params.list_callback = self._transform_callbacks(callbacks)
         cb_params.train_dataset_element = None
         cb_params.network = self._network
-        ms_role = os.getenv("MS_ROLE")
-        if ms_role in ("MS_PSERVER", "MS_SCHED"):
+        if _is_role_pserver() or _is_role_sched():
             epoch = 1
 
         # build callback list
@@ -516,7 +516,7 @@ class Model:
                     self._loss_scale_manager.update_loss_scale(overflow)
 
                 list_callback.step_end(run_context)
-                if os.getenv("MS_ROLE") == "MS_PSERVER":
+                if _is_role_pserver():
                     os._exit(0)
                 should_stop = should_stop or run_context.get_stop_requested()
                 if should_stop:
@@ -70,6 +70,7 @@ if __name__ == '__main__':
 
     # init context
     context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
+    context.set_ps_context(enable_ps=True)
    if args_opt.run_distribute:
         if target == "Ascend":
             device_id = int(os.getenv('DEVICE_ID'))
@@ -14,7 +14,6 @@
 # ============================================================================
 """Model."""
 import math
-import os
 from collections.abc import Iterable
 
 import numpy as np
@@ -405,9 +404,6 @@ class Model:
         cb_params.list_callback = self._transform_callbacks(callbacks)
         cb_params.train_dataset_element = None
         cb_params.network = self._network
-        ms_role = os.getenv("MS_ROLE")
-        if ms_role in ("MS_PSERVER", "MS_SCHED"):
-            epoch = 1
 
         # build callback list
         with _CallbackManager(callbacks) as list_callback:
@@ -118,6 +118,7 @@ if __name__ == "__main__":
     wide_deep_config.argparse_init()
 
     context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target)
+    context.set_ps_context(enable_ps=True)
     init()
     context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                       device_num=get_group_size())
@@ -26,6 +26,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Adam
 from mindspore.ops import operations as P
 from mindspore.common.initializer import TruncatedNormal
+from mindspore.parallel._ps_context import _is_role_pserver, _is_role_worker
 
 parser = argparse.ArgumentParser(description="test_sparse_embedding")
 parser.add_argument("--device_target", type=str, default="Ascend")
@@ -34,6 +35,7 @@ device_target = args.device_target
 context.set_context(
     mode=context.GRAPH_MODE, device_target=device_target, enable_sparse=True
 )
+context.set_ps_context(enable_ps=True)
 
 
 def fc_with_initialize(input_channels, out_channels):
@@ -81,7 +83,7 @@ def do_sparse_embedding(ps=False):
     for _ in range(epoch):
         data = Tensor(np.random.randint(0, 15, (32, 3), np.int32))
         label = Tensor(np.random.randint(0, 9, (32), np.int32))
-        if envs.get("MS_ROLE") == "MS_PSERVER":
+        if _is_role_pserver():
             train_network(data, label)
             sys.exit()
         else:
@@ -96,10 +98,10 @@ if __name__ == "__main__":
     np.random.seed(0)
     ps_loss = do_sparse_embedding(True)
 
-    if envs.get("MS_ROLE") == "MS_WORKER":
-        envs["MS_ROLE"] = ""
+    if _is_role_worker():
+        context.reset_ps_context()
         np.random.seed(0)
         no_ps_loss = do_sparse_embedding()
-        envs["MS_ROLE"] = "MS_WORKER"
+        context.set_ps_context(enable_ps=True)
 
         assert np.allclose(ps_loss, no_ps_loss, rtol=1.0e-6, atol=1.0e-6)
@@ -35,6 +35,7 @@ args, _ = parser.parse_known_args()
 device_target = args.device_target
 dataset_path = args.dataset_path
 context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
+context.set_ps_context(enable_ps=True)
 
 def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
     """weight initial for conv layer"""
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 
+import sys
 import argparse
 import numpy as np
 
@@ -22,6 +23,7 @@ from mindspore.common.initializer import TruncatedNormal
 from mindspore import Tensor
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.communication.management import init, get_group_size
+from mindspore.parallel._ps_context import _is_role_pserver
 # from resnet import resnet50
 
 parser = argparse.ArgumentParser(description="test_ps_lenet")
@@ -29,6 +31,7 @@ parser.add_argument("--device_target", type=str, default="Ascend")
 args, _ = parser.parse_known_args()
 device_target = args.device_target
 context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
+context.set_ps_context(enable_ps=True)
 if device_target == "GPU":
     init()
 
@@ -106,6 +109,10 @@ if __name__ == "__main__":
     for _ in range(epoch):
         data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32))
         label = Tensor(np.random.randint(0, 9, (32)).astype(np.int32))
-        loss = train_network(data, label).asnumpy()
-        losses.append(loss)
+        if _is_role_pserver():
+            train_network(data, label)
+            sys.exit()
+        else:
+            loss = train_network(data, label).asnumpy()
+            losses.append(loss)
     print(losses)