!44812 Optimize the setting of pynative shard

Merge pull request !44812 from liuluobin/shard_identity
This commit is contained in:
i-robot 2022-11-01 08:22:34 +00:00 committed by Gitee
commit 2efc5fcd68
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 101 additions and 62 deletions

View File

@ -232,6 +232,7 @@ class SqrtCost : public CastCost {
using TanhCost = SqrtCost;
using EluCost = SqrtCost;
using ReLUCost = SqrtCost;
using identityCost = SqrtCost;
using SigmoidCost = SqrtCost;
using ReciprocalCost =
SqrtCost; // The derivative of 'Reciprocal' is different on 'Ascend' and 'GPU'. Here, 'Ascend' is chosen

View File

@ -655,6 +655,7 @@ REGISTER(CumSumInfo);
REGISTER(CumProdInfo);
REGISTER(EluInfo);
REGISTER(ReLUInfo);
REGISTER(identityInfo);
REGISTER(RepeatElementsInfo);
REGISTER(ReLU6Info);
REGISTER(SoftsignInfo);

View File

@ -184,6 +184,14 @@ class ReLUInfo : public ActivationOther {
~ReLUInfo() override = default;
};
class identityInfo : public ActivationOther {
public:
identityInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<identityCost>()) {}
~identityInfo() override = default;
};
class RepeatElementsInfo : public ActivationOther {
public:
RepeatElementsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,

View File

@ -48,9 +48,11 @@ static void GenerateDefaultStrategy(const ValueNodePtr &axes, const std::vector<
// Generate strategies like ((), (), ..., ())
Shapes GenerateEmptyStrategies(const CNodePtr &cnode) {
size_t input_size = cnode->size() - 1;
Shapes ret_strategies(input_size, Shape());
return ret_strategies;
auto shape_list = ExtractShape(cnode);
if (shape_list.empty()) {
MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << " failed to extract shape.";
}
return Shapes(shape_list[0].size(), Shape());
}
static bool CheckOneDimensionalIntTuple(const ValuePtr &value_ptr) {
@ -172,7 +174,8 @@ AnfNodeIndexSet FindAnfNodeIndexSetToInsertStrategy(const FuncGraphPtr &func_gra
bfs_list.pop();
CNodePtr cnode = user.first->cast<CNodePtr>();
// If the cnode is not a splittable operator, apply strategy to the next cnode
if (!IsSplittableOperator(GetPrimName(cnode)) || IsPrimitiveCNode(cnode, prim::kPrimVirtualDataset)) {
if (!IsSplittableOperator(GetPrimName(cnode)) || IsPrimitiveCNode(cnode, prim::kPrimVirtualDataset) ||
IsPrimitiveCNode(cnode, prim::kPrimCast)) {
auto tmp_users = manager->node_users()[cnode];
(void)std::for_each(tmp_users.begin(), tmp_users.end(),
[&bfs_list](const std::pair<AnfNodePtr, int> &user) { bfs_list.push(user); });
@ -183,6 +186,30 @@ AnfNodeIndexSet FindAnfNodeIndexSetToInsertStrategy(const FuncGraphPtr &func_gra
return ret_set;
}
// New a primitive for cnode and set in_strategy to it.
void SetStrategyToCNode(const CNodePtr &cnode, const Shapes &strategies) {
auto strategy = ShapesToValueTuplePtr(strategies);
PrimitivePtr prim = GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(prim);
PrimitivePtr new_prim;
if (prim->isa<PrimitivePy>()) {
PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>();
MS_EXCEPTION_IF_NULL(prim_py);
new_prim = std::make_shared<PrimitivePy>(*prim_py);
} else {
new_prim = std::make_shared<Primitive>(*prim);
}
auto attrs_temp = prim->attrs();
attrs_temp[parallel::IN_STRATEGY] = strategy;
(void)new_prim->SetAttrs(attrs_temp);
ValuePtr new_prim_value = MakeValue(new_prim);
ValueNodePtr new_prim_value_node = NewValueNode(new_prim_value);
AnfNodePtr new_prim_anf_node = new_prim_value_node->cast<AnfNodePtr>();
MS_EXCEPTION_IF_NULL(new_prim_anf_node);
cnode->set_input(0, new_prim_anf_node);
}
static std::set<CNodePtr> SetInputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &in_strategy,
const int64_t &device_num) {
auto in_strategy_tuple = in_strategy->cast<ValueNodePtr>();
@ -239,21 +266,7 @@ static std::set<CNodePtr> SetInputLayout(const FuncGraphPtr &func_graph, const A
for (auto &cnode : concerned_nodes) {
Shapes ret_strategy = GenerateDefaultStrategiesForCNode(cnode, input_strategy);
// Set in_strategy
auto strategy = ShapesToValueTuplePtr(ret_strategy);
PrimitivePtr prim = GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(prim);
auto attrs_temp = prim->attrs();
attrs_temp[parallel::IN_STRATEGY] = strategy;
PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>();
MS_EXCEPTION_IF_NULL(prim_py);
PrimitivePtr new_prim = std::make_shared<PrimitivePy>(*prim_py);
(void)new_prim->SetAttrs(attrs_temp);
ValuePtr new_prim_value = MakeValue(new_prim);
ValueNodePtr new_prim_value_node = NewValueNode(new_prim_value);
AnfNodePtr new_prim_anf_node = new_prim_value_node->cast<AnfNodePtr>();
MS_EXCEPTION_IF_NULL(new_prim_anf_node);
cnode->set_input(0, new_prim_anf_node);
SetStrategyToCNode(cnode, ret_strategy);
}
return concerned_nodes;
}
@ -295,7 +308,8 @@ AnfNodePtr SearchParamByName(const std::vector<AnfNodePtr> &parameter_list, std:
}
static std::set<CNodePtr> SetParameterLayout(const FuncGraphPtr &root, const FuncGraphPtr &func_graph,
const AnfNodePtr &parameter_plan) {
const AnfNodePtr &parameter_plan,
const std::set<CNodePtr> &input_concerned_node) {
auto parameter_plan_vnode = parameter_plan->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(parameter_plan_vnode);
if (parameter_plan_vnode->value()->isa<None>()) {
@ -309,48 +323,65 @@ static std::set<CNodePtr> SetParameterLayout(const FuncGraphPtr &root, const Fun
FuncGraphManagerPtr manager = func_graph->manager();
auto root_parameters = root->parameters();
std::set<CNodePtr> concerned_cnode;
std::unordered_map<AnfNodePtr, Shape> parameter_layout_setting;
for (auto p : parameter_plan_list) {
auto p_tuple = p->cast<ValueTuplePtr>()->value();
auto param_layout = GetValue<Shape>(p_tuple[kIndex1]);
auto param_name = GetValue<std::string>(p_tuple[0]);
auto parameter = SearchParamByName(root_parameters, param_name);
if (parameter == nullptr) {
MS_LOG(WARNING) << "Parameter \'" << param_name << "\' is not exist, ignore it.";
continue;
} else if (parameter_layout_setting.find(parameter) != parameter_layout_setting.end()) {
MS_LOG(WARNING) << "The layout of parameter '" << param_name << "' has been set to "
<< parameter_layout_setting[parameter] << ", current setting " << param_layout
<< " will be ignored.";
continue;
}
parameter_layout_setting.insert({parameter, param_layout});
AnfNodeIndexSet users = manager->node_users()[parameter];
auto to_insert_nodes_set = FindAnfNodeIndexSetToInsertStrategy(func_graph, users);
for (auto user : to_insert_nodes_set) {
CNodePtr cnode = user.first->cast<CNodePtr>();
PrimitivePtr prim = GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
if (!StrategyFound(attrs)) {
auto empty_strategies = GenerateEmptyStrategies(cnode);
attrs[parallel::IN_STRATEGY] = ShapesToValueTuplePtr(empty_strategies);
CNodePtr target_cnode = user.first->cast<CNodePtr>();
Shapes current_strategies;
if (input_concerned_node.find(target_cnode) == input_concerned_node.end()) {
// If target_cnode is not involve inputs, insert an identity between Load and target_cnode,
// and setting layout into identity.
// e.g Load(param) -> identity{in_strategy} -> target_cnode
auto pre_cnode = target_cnode->input(user.second)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(pre_cnode);
if (IsPrimitiveCNode(pre_cnode, prim::kPrimCast)) {
pre_cnode = pre_cnode->inputs().at(kIndex1)->cast<CNodePtr>();
}
if (!IsPrimitiveCNode(pre_cnode, prim::kPrimLoad)) {
MS_LOG(EXCEPTION) << "The operator type of the " << user.second << "-th input in "
<< target_cnode->fullname_with_scope() << " must be 'Load', but got "
<< GetCNodePrimitive(pre_cnode)->ToString();
}
auto identity_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimIdentity), pre_cnode});
auto pre_cnode_abstract = pre_cnode->abstract();
MS_EXCEPTION_IF_NULL(pre_cnode_abstract);
identity_cnode->set_abstract(pre_cnode_abstract->Clone());
manager->Replace(pre_cnode, identity_cnode);
target_cnode = identity_cnode;
current_strategies = {param_layout};
} else {
// Setting layout into target_cnode directly.
PrimitivePtr prim = GetCNodePrimitive(target_cnode);
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
if (StrategyFound(attrs)) {
current_strategies = ValueTuplePtrToShapes(attrs[parallel::IN_STRATEGY]->cast<ValueTuplePtr>());
} else {
current_strategies = GenerateEmptyStrategies(target_cnode);
}
current_strategies[user.second - 1] = param_layout;
(void)concerned_cnode.insert(target_cnode);
}
auto current_strategies = ValueTuplePtrToShapes(attrs[parallel::IN_STRATEGY]->cast<ValueTuplePtr>());
auto param_layout = GetValue<Shape>(p_tuple[kIndex1]);
// If a layout has been set, skip it.
if (current_strategies[user.second - 1] != Shape()) {
MS_LOG(WARNING) << "For " << cnode->fullname_with_scope() << ", the " << user.second
<< "th strategy has been set to " << current_strategies[user.second - 1] << ", current setting "
<< param_layout << " will be ignored.";
continue;
}
current_strategies[user.second - 1] = param_layout;
attrs[parallel::IN_STRATEGY] = ShapesToValueTuplePtr(current_strategies);
PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>();
MS_EXCEPTION_IF_NULL(prim_py);
PrimitivePtr new_prim = std::make_shared<PrimitivePy>(*prim_py);
(void)new_prim->SetAttrs(attrs);
ValueNodePtr new_prim_value_node = NewValueNode(MakeValue(new_prim));
AnfNodePtr new_prim_anf_node = new_prim_value_node->cast<AnfNodePtr>();
MS_EXCEPTION_IF_NULL(new_prim_anf_node);
cnode->set_input(0, new_prim_anf_node);
(void)concerned_cnode.insert(cnode);
MS_LOG(INFO) << "The layout of \"" << param_name << "\" has been set to the " << user.second << "th of "
<< cnode->fullname_with_scope() << "'s in_strategy. Current strategies is " << current_strategies;
SetStrategyToCNode(target_cnode, current_strategies);
MS_LOG(DEBUG) << "The layout of \"" << param_name << "\" has been set to the " << user.second << "th of "
<< target_cnode->fullname_with_scope() << "'s in_strategy. Current strategies is "
<< current_strategies;
}
}
return concerned_cnode;
@ -393,7 +424,7 @@ static bool SetStrategyForShard(const FuncGraphPtr &root, const std::vector<AnfN
}
std::set<CNodePtr> concerned_cnode;
auto input_concerned_cnode = SetInputLayout(func_graph, in_strategy, device_num);
auto parameter_concerned_cnode = SetParameterLayout(root, func_graph, parameter_plan);
auto parameter_concerned_cnode = SetParameterLayout(root, func_graph, parameter_plan, input_concerned_cnode);
(void)std::set_union(input_concerned_cnode.begin(), input_concerned_cnode.end(),
parameter_concerned_cnode.begin(), parameter_concerned_cnode.end(),
std::inserter(concerned_cnode, concerned_cnode.end()));

View File

@ -726,14 +726,12 @@ bool IsSplittableOperator(const std::string &op_name) {
SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT, GATHERD,
UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE,
MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, RESIZE_NEAREST_NEIGHBOR, FAST_GELU, IOU, BOUNDING_BOX_ENCODE,
RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE, ROI_ALIGN, REDUCE_PROD, REDUCE_ANY, REDUCE_ALL, ARGMAX, ARGMIN, ARGMINV2,
UNSORTED_SEGMENT_PROD, SQUARE_SUM_ALL, MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, UNIQUE_CONSECUTIVE,
RESIZE_NEAREST_NEIGHBOR, CUM_SUM, FAST_GELU, IOU, BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE,
ROI_ALIGN, IS_FINITE, RINT, HSHRINK, HSIGMOID, MISH, SELU, SOFT_SHRINK, XLOGY, XDIVY, CUM_PROD, BITWISE_AND,
BITWISE_OR, BITWISE_XOR, MUL_NO_NAN, TRUNCATE_DIV, TRUNCATE_MOD, INPLACE_ADD, INPLACE_SUB, L2_LOSS, LERP, ADDN,
CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL, SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE, CHECK_VALID, INVERT,
UNIQUE_WITH_PAD, POPULATION_COUNT};
MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, IOU, BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK,
CROP_AND_RESIZE, ROI_ALIGN, REDUCE_PROD, REDUCE_ANY, REDUCE_ALL, ARGMAX, ARGMIN, ARGMINV2, UNSORTED_SEGMENT_PROD,
SQUARE_SUM_ALL, UNIQUE_CONSECUTIVE, RESIZE_NEAREST_NEIGHBOR, CUM_SUM, IS_FINITE, RINT, HSHRINK, HSIGMOID, MISH,
SELU, SOFT_SHRINK, XLOGY, XDIVY, CUM_PROD, BITWISE_AND, BITWISE_OR, BITWISE_XOR, MUL_NO_NAN, TRUNCATE_DIV,
TRUNCATE_MOD, INPLACE_ADD, INPLACE_SUB, L2_LOSS, LERP, ADDN, CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL,
SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE, CHECK_VALID, INVERT, UNIQUE_WITH_PAD, POPULATION_COUNT, IDENTITY};
// clang-format on
auto iter = splittable_op.find(op_name);

View File

@ -27,7 +27,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
"array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
"list_append", "list_map", "list_reduce", "tuple_reversed", "tile_shape", "tuple_div", "tuple_to_array",
"make_dict", "make_slice", "string_eq", "VirtualLoss", "Return", "env_getitem",
"identity", "partial", "env_setitem", "env_getitem", "env_add",
"partial", "env_setitem", "env_getitem", "env_add",
"dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary",
"ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
"InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",

View File

@ -374,7 +374,7 @@ def train_feed(num_classes, expect_out):
model.train(3, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
loss_value = np.array(parallel_callback.loss_list)
print(loss_value)
assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
assert np.allclose(loss_value, expect_out, 0.0005, 0.0005)
def test_train_feed_ascend():
@ -391,7 +391,7 @@ def test_train_feed_ascend():
dataset_strategy="data_parallel")
np.random.seed(42)
set_seed(42)
train_feed(num_classes=65536, expect_out=[11.37239, 11.068878, 10.525374])
train_feed(num_classes=65536, expect_out=[11.372186, 11.068188, 10.518132])
def test_train_feed_gpu():