forked from mindspore-Ecosystem/mindspore
!3027 Adaptation for ps mode.
Merge pull request !3027 from ZPaC/adaptation-for-ps-mode
This commit is contained in:
commit
be9b3c53dc
|
@ -55,7 +55,7 @@ class CPUKernel : public kernel::KernelMod {
|
|||
public:
|
||||
CPUKernel() = default;
|
||||
~CPUKernel() override = default;
|
||||
void Init(const CNodePtr &kernel_node);
|
||||
virtual void Init(const CNodePtr &kernel_node);
|
||||
virtual void InitKernel(const CNodePtr &kernel_node) = 0;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void * /*stream_ptr*/) override {
|
||||
|
|
|
@ -62,10 +62,12 @@ class CPUKernelRegistrar {
|
|||
static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_reg(#OPNAME, ATTR, \
|
||||
[]() { return std::make_shared<OPCLASS>(); });
|
||||
|
||||
#define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) \
|
||||
#define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) MS_REG_CPU_KERNEL_T_(__COUNTER__, OPNAME, ATTR, OPCLASS, T)
|
||||
#define MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T)
|
||||
#define _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) \
|
||||
static_assert(std::is_base_of<CPUKernel, OPCLASS<T>>::value, " must be base of CPUKernel"); \
|
||||
static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_##T##_reg(#OPNAME, ATTR, \
|
||||
[]() { return std::make_shared<OPCLASS<T>>(); });
|
||||
static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_##OPNAME##_##T##_reg( \
|
||||
#OPNAME, ATTR, []() { return std::make_shared<OPCLASS<T>>(); });
|
||||
|
||||
#define MS_REG_CPU_KERNEL_T_S(OPNAME, ATTR, OPCLASS, T, S) \
|
||||
static_assert(std::is_base_of<CPUKernel, OPCLASS<T, S>>::value, " must be base of CPUKernel"); \
|
||||
|
|
|
@ -46,24 +46,10 @@ void SparseApplyFtrlPSKernel::InitKernel(
|
|||
if (grad_shape[0] != indices_size_) {
|
||||
MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices";
|
||||
}
|
||||
/*
|
||||
lr_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "lr");
|
||||
if (lr_ <= 0) {
|
||||
MS_LOG(EXCEPTION) << "lr should be a positive scalar";
|
||||
}
|
||||
l1_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "l1");
|
||||
if (l1_ < 0) {
|
||||
MS_LOG(EXCEPTION) << "l1 should be a non-negative scalar";
|
||||
}
|
||||
l2_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "l2");
|
||||
if (l2_ < 0) {
|
||||
MS_LOG(EXCEPTION) << "l2 should be a non-negative scalar";
|
||||
}
|
||||
lr_power_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "lr_power");
|
||||
if (lr_power_ > 0) {
|
||||
MS_LOG(EXCEPTION) << "lr_power should be a non-positive scalar";
|
||||
}
|
||||
*/
|
||||
lr_ = 0.01;
|
||||
l1_ = 1e-8;
|
||||
l2_ = 1e-8;
|
||||
lr_power_ = -0.5;
|
||||
workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
|
||||
workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pre_activate/pass/replace_node_by_proxy.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "device/kernel_info.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "kernel/kernel_build_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
kernel::KernelBuildInfoPtr ReplaceNodeByProxy::GenerateKernelBuildInfo(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::vector<std::string> inputs_device_format;
|
||||
std::vector<std::string> outputs_device_format;
|
||||
std::vector<TypeId> inputs_device_type;
|
||||
std::vector<TypeId> outputs_device_type;
|
||||
std::vector<std::vector<size_t>> outputs_shape;
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
|
||||
inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
|
||||
inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index));
|
||||
}
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) {
|
||||
outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index));
|
||||
outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index));
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
|
||||
}
|
||||
builder.SetFusionType(AnfAlgo::GetFusionType(cnode));
|
||||
builder.SetProcessor(AnfAlgo::GetProcessor(cnode));
|
||||
builder.SetKernelType(AnfAlgo::GetKernelType(cnode));
|
||||
|
||||
builder.SetInputsFormat(inputs_device_format);
|
||||
builder.SetOutputsFormat(outputs_device_format);
|
||||
builder.SetInputsDeviceType(inputs_device_type);
|
||||
builder.SetOutputsDeviceType(outputs_device_type);
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
bool ReplaceNodeByProxy::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
|
||||
for (auto node : node_list) {
|
||||
if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) {
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
auto prim = std::make_shared<Primitive>(kEmbeddingLookupProxyOpName);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> proxy_inputs = {NewValueNode(prim)};
|
||||
proxy_inputs.insert(proxy_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
|
||||
AnfNodePtr proxy_node = func_graph->NewCNode(proxy_inputs);
|
||||
MS_EXCEPTION_IF_NULL(proxy_node);
|
||||
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
proxy_node->set_kernel_info(kernel_info);
|
||||
|
||||
AbstractBasePtrList abstract_list;
|
||||
AnfAlgo::CopyNodeAttr(kAttrPsKey, cnode, proxy_node);
|
||||
AnfAlgo::CopyNodeAttr("reduce_scatter_flag", cnode, proxy_node);
|
||||
AnfAlgo::CopyNodeAttr("offset", cnode, proxy_node);
|
||||
abstract_list.push_back(cnode->abstract());
|
||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
MS_EXCEPTION_IF_NULL(abstract_tuple);
|
||||
proxy_node->set_abstract(abstract_tuple);
|
||||
|
||||
auto kernel_build_info = GenerateKernelBuildInfo(cnode);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, proxy_node.get());
|
||||
|
||||
if (!manager->Replace(cnode, proxy_node)) {
|
||||
MS_LOG(EXCEPTION) << "Replace node by proxy node failed.";
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "pre_activate/common/pass.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/anf.h"
|
||||
#include "utils/utils.h"
|
||||
#include "kernel/kernel_build_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ReplaceNodeByProxy : public Pass {
|
||||
public:
|
||||
explicit ReplaceNodeByProxy(const std::string &name) : Pass(name) {}
|
||||
~ReplaceNodeByProxy() override = default;
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
private:
|
||||
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CNodePtr &cnode);
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
"""comm_helper"""
|
||||
|
||||
|
||||
import os
|
||||
from ._hccl_management import load_lib as hccl_load_lib
|
||||
|
||||
_HCCL_AVAILABLE = False
|
||||
|
@ -44,7 +44,7 @@ else:
|
|||
|
||||
HCCL_WORLD_COMM_GROUP = "hccl_world_group"
|
||||
NCCL_WORLD_COMM_GROUP = "nccl_world_group"
|
||||
|
||||
MS_ROLE = os.getenv("MS_ROLE")
|
||||
|
||||
class Backend:
|
||||
"""
|
||||
|
@ -152,6 +152,9 @@ def _get_rank_helper(group, backend):
|
|||
Integer. The local rank id of the calling process.
|
||||
"""
|
||||
rank_id = None
|
||||
if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
|
||||
rank_id = 0
|
||||
return rank_id
|
||||
if backend == Backend.HCCL:
|
||||
if group == HCCL_WORLD_COMM_GROUP:
|
||||
rank_id = hccl.get_rank_id()
|
||||
|
@ -211,6 +214,9 @@ def _get_size_helper(group, backend):
|
|||
Integer. The rank size of specified group.
|
||||
"""
|
||||
size = None
|
||||
if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
|
||||
size = 1
|
||||
return size
|
||||
if backend == Backend.HCCL:
|
||||
if group == HCCL_WORLD_COMM_GROUP:
|
||||
size = hccl.get_rank_size()
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Communication management API"""
|
||||
import os
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
|
||||
_get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
|
||||
|
@ -28,6 +29,7 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
|
|||
|
||||
DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
|
||||
DEFAULT_BACKEND = Backend("hccl")
|
||||
MS_ROLE = os.getenv("MS_ROLE")
|
||||
|
||||
|
||||
def _get_group(group):
|
||||
|
@ -58,6 +60,8 @@ def init(backend_name="hccl"):
|
|||
TypeError: If backend name is not a string.
|
||||
RuntimeError: If backend is invalid or distributed init fails.
|
||||
"""
|
||||
if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
|
||||
return
|
||||
if not isinstance(backend_name, str):
|
||||
raise TypeError("Backend name must be a string, but got {}".format(type(backend_name)))
|
||||
|
||||
|
|
Loading…
Reference in New Issue