move parallel-related black-list to core/ir, and fix the cloneCNode bug

This commit is contained in:
Xiaoda Zhang 2020-12-09 16:44:49 +08:00
parent ea98943848
commit e78228603b
9 changed files with 98 additions and 73 deletions

View File

@ -23,68 +23,8 @@
namespace mindspore {
namespace parallel {
const std::set<std::string> BLACK_LIST = {TUPLE_GETITEM,
J,
LIST_GETITEM,
ARRAY_GETITEM,
TUPLE_SETITEM,
DEPEND,
LIST_SETITEM,
ARRAY_SETITEM,
DICT_GETITEM,
LIST_APPEND,
LIST_MAP,
LIST_REDUCE,
TUPLE_REVERSED,
TILE_SHAPE,
TUPLE_DIV,
TUPLE_TO_ARRAY,
MAKE_DICT,
MAKE_SLICE,
MAKE_RECORD,
STRING_EQUAL,
VIRTUALLOSS,
RETURN,
ENV_GETITEM,
IDENTITY,
PARTIAL,
ENVSETITEM,
ENVGETITEM,
ENVADD,
MAKEREFKEY,
MAKEREF,
GETREFKEY,
GETREFVALUE,
GETREFORIGIN,
DOT,
IM2COL,
COL2IM,
IM2COLV1,
STATESETITEM,
SCALARSUMMARY,
IMAGESUMMARY,
TENSORSUMMARY,
DEBUG,
HISTOGRAMSUMMARY,
COL2IMV1,
RESOLVE,
BROADCASTGRADIENTARGS,
INVERTPERMUTATION,
CONTROLDEPEND,
DROPOUT_GEN_MASK,
EMBED,
CREATINSTANCE,
REF_TO_EMBED,
STOP_GRADIENT,
SEND};
const std::set<std::string> BATCH_PARALLEL_BLACK_LIST = {PACK, TENSOR_SCATTER_UPDATE, MIN_MAX_UPDATE_PER_LAYER};
bool IsInBlackList(const PrimitivePtr &prim) {
MS_EXCEPTION_IF_NULL(prim);
return (BLACK_LIST.find(prim->name()) != BLACK_LIST.end());
}
bool IsInBatchParallelBlackList(const PrimitivePtr &prim) {
MS_EXCEPTION_IF_NULL(prim);
return (BATCH_PARALLEL_BLACK_LIST.find(prim->name()) != BATCH_PARALLEL_BLACK_LIST.end());

View File

@ -21,7 +21,6 @@
namespace mindspore {
namespace parallel {
bool IsInBlackList(const PrimitivePtr &prim);
bool IsInBatchParallelBlackList(const PrimitivePtr &prim);
} // namespace parallel
} // namespace mindspore

View File

@ -32,6 +32,7 @@
#include "base/core_ops.h"
#include "utils/comm_manager.h"
#include "utils/ms_context.h"
#include "mindspore/core/utils/parallel_node_check.h"
namespace mindspore {
namespace parallel {
@ -99,7 +100,7 @@ bool PipelineTransformer::IsPipelineCareNode(const CNodePtr &cnode) {
if (IsInWhiteList(cnode)) {
return false;
}
if (IsInBlackList(prim)) {
if (IsInParallelBlackList(prim)) {
MS_LOG(INFO) << "PipelineSplit don't care node:" << prim->name();
return false;
}

View File

@ -44,6 +44,7 @@
#include "utils/comm_manager.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "mindspore/core/utils/parallel_node_check.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/util.h"
#endif
@ -439,7 +440,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) {
if (prim == nullptr) {
return false;
}
if (IsInBlackList(prim)) {
if (IsInParallelBlackList(prim)) {
MS_LOG(DEBUG) << "Parallel don't care node: " << prim->name();
return false;
}

View File

@ -26,6 +26,7 @@
#include "utils/profile.h"
#include "utils/ms_context.h"
#include "ir/graph_utils.h"
#include "utils/parallel_node_check.h"
// namespace to support intermediate representation definition
namespace mindspore {
@ -91,7 +92,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
new_node->set_inputs_value(old_node->inputs_value());
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
new_node->set_scope(scope);
if (IsPrimitiveCNode(old_node, nullptr) && new_node->scope() == kDefaultScope) {
if (IsParallelCareCNode(old_node) && new_node->scope() == kDefaultScope) {
new_node->set_fullname_with_scope(old_node->fullname_with_scope());
}
new_node->set_kernel_info(old_node->kernel_info_ptr());

View File

@ -0,0 +1,57 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/parallel_node_check.h"
#include <set>
#include <string>
namespace mindspore {
// clang-format off
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {"tuple_getitem", "J", "list_getitem",
"array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
"list_append", "list_map", "list_reduce", "tuple_reversed", "tile_shape", "tuple_div", "tuple_to_array",
"make_dict", "make_slice", "make_record", "string_equal", "VirtualLoss", "return", "env_getitem",
"identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key",
"get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary",
"ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
"InvertPermutation", "ControlDepend", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
"stop_gradient", "Send"};
// clang-format on
bool IsInParallelBlackList(const PrimitivePtr &prim) {
MS_EXCEPTION_IF_NULL(prim);
return (PARALLEL_BLACK_LIST_.find(prim->name()) != PARALLEL_BLACK_LIST_.end());
}
bool IsParallelCareCNode(const CNodePtr &cnode) {
if (cnode == nullptr || cnode->size() == 0) {
return false;
}
const auto &prim_node = cnode->input(0)->cast<ValueNodePtr>();
if (prim_node == nullptr) {
return false;
}
const auto &prim = prim_node->value()->cast<PrimitivePtr>();
if (prim == nullptr) {
return false;
}
if (IsInParallelBlackList(prim)) {
return false;
}
return true;
}
} // namespace mindspore

View File

@ -0,0 +1,26 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_
#define MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_
#include "ir/primitive.h"
namespace mindspore {
bool IsInParallelBlackList(const PrimitivePtr &);
bool IsParallelCareCNode(const CNodePtr &);
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_

View File

@ -84,9 +84,9 @@ def test_double_star_graph():
net.set_train()
_executor.compile(net, x, y, z, w, phase='train')
strategies = _executor._get_shard_strategy(net)
expected_strategies = {'Default/network-Net/Cast-op5': [[8, 1]],
'Default/network-Net/Cast-op7': [[1, 8]],
'Default/network-Net/MatMul-op6': [[8, 1], [1, 1]],
'Default/network-Net/MatMul-op8': [[1, 1], [1, 8]],
'Default/network-Net/MatMul-op4': [[1, 8], [8, 1]]}
expected_strategies = {'Default/network-Net/Cast-op2': [[8, 1]],
'Default/network-Net/Cast-op4': [[1, 8]],
'Default/network-Net/MatMul-op3': [[8, 1], [1, 1]],
'Default/network-Net/MatMul-op5': [[1, 1], [1, 8]],
'Default/network-Net/MatMul-op1': [[1, 8], [8, 1]]}
assert strategies == expected_strategies

View File

@ -79,8 +79,8 @@ def test_two_matmul_transpose():
net.set_train()
_executor.compile(net, x, y, b, phase='train')
strategies = _executor._get_shard_strategy(net)
expected_strategies = {'Default/network-Net/Transpose-op4': [[1, 16]],
'Default/network-Net/Transpose-op5': [[16, 1]],
'Default/network-Net/MatMul-op7': [[16, 1], [1, 1]],
'Default/network-Net/MatMul-op6': [[16, 1], [1, 1]]}
expected_strategies = {'Default/network-Net/Transpose-op1': [[1, 16]],
'Default/network-Net/Transpose-op2': [[16, 1]],
'Default/network-Net/MatMul-op3': [[16, 1], [1, 1]],
'Default/network-Net/MatMul-op4': [[16, 1], [1, 1]]}
assert strategies == expected_strategies