!3027 Adaptation for ps mode.

Merge pull request !3027 from ZPaC/adaptation-for-ps-mode
This commit is contained in:
mindspore-ci-bot 2020-07-13 19:25:22 +08:00 committed by Gitee
commit be9b3c53dc
7 changed files with 155 additions and 24 deletions

View File

@ -55,7 +55,7 @@ class CPUKernel : public kernel::KernelMod {
public:
CPUKernel() = default;
~CPUKernel() override = default;
void Init(const CNodePtr &kernel_node);
virtual void Init(const CNodePtr &kernel_node);
virtual void InitKernel(const CNodePtr &kernel_node) = 0;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void * /*stream_ptr*/) override {

View File

@ -62,10 +62,12 @@ class CPUKernelRegistrar {
static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_reg(#OPNAME, ATTR, \
[]() { return std::make_shared<OPCLASS>(); });
#define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) \
#define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) MS_REG_CPU_KERNEL_T_(__COUNTER__, OPNAME, ATTR, OPCLASS, T)
#define MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T)
#define _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) \
static_assert(std::is_base_of<CPUKernel, OPCLASS<T>>::value, " must be base of CPUKernel"); \
static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_##T##_reg(#OPNAME, ATTR, \
[]() { return std::make_shared<OPCLASS<T>>(); });
static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_##OPNAME##_##T##_reg( \
#OPNAME, ATTR, []() { return std::make_shared<OPCLASS<T>>(); });
#define MS_REG_CPU_KERNEL_T_S(OPNAME, ATTR, OPCLASS, T, S) \
static_assert(std::is_base_of<CPUKernel, OPCLASS<T, S>>::value, " must be base of CPUKernel"); \

View File

@ -46,24 +46,10 @@ void SparseApplyFtrlPSKernel::InitKernel(
if (grad_shape[0] != indices_size_) {
MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices";
}
/*
lr_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "lr");
if (lr_ <= 0) {
MS_LOG(EXCEPTION) << "lr should be a positive scalar";
}
l1_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "l1");
if (l1_ < 0) {
MS_LOG(EXCEPTION) << "l1 should be a non-negative scalar";
}
l2_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "l2");
if (l2_ < 0) {
MS_LOG(EXCEPTION) << "l2 should be a non-negative scalar";
}
lr_power_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "lr_power");
if (lr_power_ > 0) {
MS_LOG(EXCEPTION) << "lr_power should be a non-positive scalar";
}
*/
lr_ = 0.01;
l1_ = 1e-8;
l2_ = 1e-8;
lr_power_ = -0.5;
workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
}

View File

@ -0,0 +1,92 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/pass/replace_node_by_proxy.h"
#include <vector>
#include <memory>
#include "device/kernel_info.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/kernel_build_info.h"
namespace mindspore {
namespace opt {
kernel::KernelBuildInfoPtr ReplaceNodeByProxy::GenerateKernelBuildInfo(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
std::vector<std::string> inputs_device_format;
std::vector<std::string> outputs_device_format;
std::vector<TypeId> inputs_device_type;
std::vector<TypeId> outputs_device_type;
std::vector<std::vector<size_t>> outputs_shape;
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index));
}
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) {
outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index));
outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index));
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
}
builder.SetFusionType(AnfAlgo::GetFusionType(cnode));
builder.SetProcessor(AnfAlgo::GetProcessor(cnode));
builder.SetKernelType(AnfAlgo::GetKernelType(cnode));
builder.SetInputsFormat(inputs_device_format);
builder.SetOutputsFormat(outputs_device_format);
builder.SetInputsDeviceType(inputs_device_type);
builder.SetOutputsDeviceType(outputs_device_type);
return builder.Build();
}
bool ReplaceNodeByProxy::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
for (auto node : node_list) {
if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) {
CNodePtr cnode = node->cast<CNodePtr>();
auto prim = std::make_shared<Primitive>(kEmbeddingLookupProxyOpName);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> proxy_inputs = {NewValueNode(prim)};
proxy_inputs.insert(proxy_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
AnfNodePtr proxy_node = func_graph->NewCNode(proxy_inputs);
MS_EXCEPTION_IF_NULL(proxy_node);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
proxy_node->set_kernel_info(kernel_info);
AbstractBasePtrList abstract_list;
AnfAlgo::CopyNodeAttr(kAttrPsKey, cnode, proxy_node);
AnfAlgo::CopyNodeAttr("reduce_scatter_flag", cnode, proxy_node);
AnfAlgo::CopyNodeAttr("offset", cnode, proxy_node);
abstract_list.push_back(cnode->abstract());
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
MS_EXCEPTION_IF_NULL(abstract_tuple);
proxy_node->set_abstract(abstract_tuple);
auto kernel_build_info = GenerateKernelBuildInfo(cnode);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, proxy_node.get());
if (!manager->Replace(cnode, proxy_node)) {
MS_LOG(EXCEPTION) << "Replace node by proxy node failed.";
}
}
}
return true;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_
#include <utility>
#include <vector>
#include <string>
#include "pre_activate/common/pass.h"
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "utils/utils.h"
#include "kernel/kernel_build_info.h"
namespace mindspore {
namespace opt {
class ReplaceNodeByProxy : public Pass {
public:
explicit ReplaceNodeByProxy(const std::string &name) : Pass(name) {}
~ReplaceNodeByProxy() override = default;
bool Run(const FuncGraphPtr &graph) override;
private:
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CNodePtr &cnode);
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_

View File

@ -14,7 +14,7 @@
# ============================================================================
"""comm_helper"""
import os
from ._hccl_management import load_lib as hccl_load_lib
_HCCL_AVAILABLE = False
@ -44,7 +44,7 @@ else:
HCCL_WORLD_COMM_GROUP = "hccl_world_group"
NCCL_WORLD_COMM_GROUP = "nccl_world_group"
MS_ROLE = os.getenv("MS_ROLE")
class Backend:
"""
@ -152,6 +152,9 @@ def _get_rank_helper(group, backend):
Integer. The local rank id of the calling process.
"""
rank_id = None
if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
rank_id = 0
return rank_id
if backend == Backend.HCCL:
if group == HCCL_WORLD_COMM_GROUP:
rank_id = hccl.get_rank_id()
@ -211,6 +214,9 @@ def _get_size_helper(group, backend):
Integer. The rank size of specified group.
"""
size = None
if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
size = 1
return size
if backend == Backend.HCCL:
if group == HCCL_WORLD_COMM_GROUP:
size = hccl.get_rank_size()

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""Communication management API"""
import os
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
_get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
@ -28,6 +29,7 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
DEFAULT_BACKEND = Backend("hccl")
MS_ROLE = os.getenv("MS_ROLE")
def _get_group(group):
@ -58,6 +60,8 @@ def init(backend_name="hccl"):
TypeError: If backend name is not a string.
RuntimeError: If backend is invalid or distributed init fails.
"""
if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
return
if not isinstance(backend_name, str):
raise TypeError("Backend name must be a string, but got {}".format(type(backend_name)))