forked from mindspore-Ecosystem/mindspore
move parallel-related black-list to core/ir, and fix the cloneCNode bug
This commit is contained in:
parent
ea98943848
commit
e78228603b
|
@ -23,68 +23,8 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
const std::set<std::string> BLACK_LIST = {TUPLE_GETITEM,
|
||||
J,
|
||||
LIST_GETITEM,
|
||||
ARRAY_GETITEM,
|
||||
TUPLE_SETITEM,
|
||||
DEPEND,
|
||||
LIST_SETITEM,
|
||||
ARRAY_SETITEM,
|
||||
DICT_GETITEM,
|
||||
LIST_APPEND,
|
||||
LIST_MAP,
|
||||
LIST_REDUCE,
|
||||
TUPLE_REVERSED,
|
||||
TILE_SHAPE,
|
||||
TUPLE_DIV,
|
||||
TUPLE_TO_ARRAY,
|
||||
MAKE_DICT,
|
||||
MAKE_SLICE,
|
||||
MAKE_RECORD,
|
||||
STRING_EQUAL,
|
||||
VIRTUALLOSS,
|
||||
RETURN,
|
||||
ENV_GETITEM,
|
||||
IDENTITY,
|
||||
PARTIAL,
|
||||
ENVSETITEM,
|
||||
ENVGETITEM,
|
||||
ENVADD,
|
||||
MAKEREFKEY,
|
||||
MAKEREF,
|
||||
GETREFKEY,
|
||||
GETREFVALUE,
|
||||
GETREFORIGIN,
|
||||
DOT,
|
||||
IM2COL,
|
||||
COL2IM,
|
||||
IM2COLV1,
|
||||
STATESETITEM,
|
||||
SCALARSUMMARY,
|
||||
IMAGESUMMARY,
|
||||
TENSORSUMMARY,
|
||||
DEBUG,
|
||||
HISTOGRAMSUMMARY,
|
||||
COL2IMV1,
|
||||
RESOLVE,
|
||||
BROADCASTGRADIENTARGS,
|
||||
INVERTPERMUTATION,
|
||||
CONTROLDEPEND,
|
||||
DROPOUT_GEN_MASK,
|
||||
EMBED,
|
||||
CREATINSTANCE,
|
||||
REF_TO_EMBED,
|
||||
STOP_GRADIENT,
|
||||
SEND};
|
||||
|
||||
const std::set<std::string> BATCH_PARALLEL_BLACK_LIST = {PACK, TENSOR_SCATTER_UPDATE, MIN_MAX_UPDATE_PER_LAYER};
|
||||
|
||||
bool IsInBlackList(const PrimitivePtr &prim) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
return (BLACK_LIST.find(prim->name()) != BLACK_LIST.end());
|
||||
}
|
||||
|
||||
bool IsInBatchParallelBlackList(const PrimitivePtr &prim) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
return (BATCH_PARALLEL_BLACK_LIST.find(prim->name()) != BATCH_PARALLEL_BLACK_LIST.end());
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
bool IsInBlackList(const PrimitivePtr &prim);
|
||||
bool IsInBatchParallelBlackList(const PrimitivePtr &prim);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
#include "base/core_ops.h"
|
||||
#include "utils/comm_manager.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "mindspore/core/utils/parallel_node_check.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
|
@ -99,7 +100,7 @@ bool PipelineTransformer::IsPipelineCareNode(const CNodePtr &cnode) {
|
|||
if (IsInWhiteList(cnode)) {
|
||||
return false;
|
||||
}
|
||||
if (IsInBlackList(prim)) {
|
||||
if (IsInParallelBlackList(prim)) {
|
||||
MS_LOG(INFO) << "PipelineSplit don't care node:" << prim->name();
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -44,6 +44,7 @@
|
|||
#include "utils/comm_manager.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/symbolic.h"
|
||||
#include "mindspore/core/utils/parallel_node_check.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#include "ps/util.h"
|
||||
#endif
|
||||
|
@ -439,7 +440,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) {
|
|||
if (prim == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (IsInBlackList(prim)) {
|
||||
if (IsInParallelBlackList(prim)) {
|
||||
MS_LOG(DEBUG) << "Parallel don't care node: " << prim->name();
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "utils/profile.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "ir/graph_utils.h"
|
||||
#include "utils/parallel_node_check.h"
|
||||
|
||||
// namespace to support intermediate representation definition
|
||||
namespace mindspore {
|
||||
|
@ -91,7 +92,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
|
|||
new_node->set_inputs_value(old_node->inputs_value());
|
||||
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
|
||||
new_node->set_scope(scope);
|
||||
if (IsPrimitiveCNode(old_node, nullptr) && new_node->scope() == kDefaultScope) {
|
||||
if (IsParallelCareCNode(old_node) && new_node->scope() == kDefaultScope) {
|
||||
new_node->set_fullname_with_scope(old_node->fullname_with_scope());
|
||||
}
|
||||
new_node->set_kernel_info(old_node->kernel_info_ptr());
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "utils/parallel_node_check.h"
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
// clang-format off
|
||||
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {"tuple_getitem", "J", "list_getitem",
|
||||
"array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
|
||||
"list_append", "list_map", "list_reduce", "tuple_reversed", "tile_shape", "tuple_div", "tuple_to_array",
|
||||
"make_dict", "make_slice", "make_record", "string_equal", "VirtualLoss", "return", "env_getitem",
|
||||
"identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key",
|
||||
"get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary",
|
||||
"ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
|
||||
"InvertPermutation", "ControlDepend", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
|
||||
"stop_gradient", "Send"};
|
||||
// clang-format on
|
||||
|
||||
bool IsInParallelBlackList(const PrimitivePtr &prim) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
return (PARALLEL_BLACK_LIST_.find(prim->name()) != PARALLEL_BLACK_LIST_.end());
|
||||
}
|
||||
|
||||
bool IsParallelCareCNode(const CNodePtr &cnode) {
|
||||
if (cnode == nullptr || cnode->size() == 0) {
|
||||
return false;
|
||||
}
|
||||
const auto &prim_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
if (prim_node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
const auto &prim = prim_node->value()->cast<PrimitivePtr>();
|
||||
if (prim == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (IsInParallelBlackList(prim)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_
|
||||
#define MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_
|
||||
|
||||
#include "ir/primitive.h"
|
||||
|
||||
namespace mindspore {
|
||||
bool IsInParallelBlackList(const PrimitivePtr &);
|
||||
bool IsParallelCareCNode(const CNodePtr &);
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_
|
|
@ -84,9 +84,9 @@ def test_double_star_graph():
|
|||
net.set_train()
|
||||
_executor.compile(net, x, y, z, w, phase='train')
|
||||
strategies = _executor._get_shard_strategy(net)
|
||||
expected_strategies = {'Default/network-Net/Cast-op5': [[8, 1]],
|
||||
'Default/network-Net/Cast-op7': [[1, 8]],
|
||||
'Default/network-Net/MatMul-op6': [[8, 1], [1, 1]],
|
||||
'Default/network-Net/MatMul-op8': [[1, 1], [1, 8]],
|
||||
'Default/network-Net/MatMul-op4': [[1, 8], [8, 1]]}
|
||||
expected_strategies = {'Default/network-Net/Cast-op2': [[8, 1]],
|
||||
'Default/network-Net/Cast-op4': [[1, 8]],
|
||||
'Default/network-Net/MatMul-op3': [[8, 1], [1, 1]],
|
||||
'Default/network-Net/MatMul-op5': [[1, 1], [1, 8]],
|
||||
'Default/network-Net/MatMul-op1': [[1, 8], [8, 1]]}
|
||||
assert strategies == expected_strategies
|
||||
|
|
|
@ -79,8 +79,8 @@ def test_two_matmul_transpose():
|
|||
net.set_train()
|
||||
_executor.compile(net, x, y, b, phase='train')
|
||||
strategies = _executor._get_shard_strategy(net)
|
||||
expected_strategies = {'Default/network-Net/Transpose-op4': [[1, 16]],
|
||||
'Default/network-Net/Transpose-op5': [[16, 1]],
|
||||
'Default/network-Net/MatMul-op7': [[16, 1], [1, 1]],
|
||||
'Default/network-Net/MatMul-op6': [[16, 1], [1, 1]]}
|
||||
expected_strategies = {'Default/network-Net/Transpose-op1': [[1, 16]],
|
||||
'Default/network-Net/Transpose-op2': [[16, 1]],
|
||||
'Default/network-Net/MatMul-op3': [[16, 1], [1, 1]],
|
||||
'Default/network-Net/MatMul-op4': [[16, 1], [1, 1]]}
|
||||
assert strategies == expected_strategies
|
||||
|
|
Loading…
Reference in New Issue