forked from mindspore-Ecosystem/mindspore
!44812 Optimize the setting of pynative shard
Merge pull request !44812 from liuluobin/shard_identity
This commit is contained in:
commit
2efc5fcd68
|
@ -232,6 +232,7 @@ class SqrtCost : public CastCost {
|
|||
// Activation-like operators whose parallel cost behaves the same as Sqrt:
// element-wise, shape-preserving, so they share SqrtCost directly.
using TanhCost = SqrtCost;
using EluCost = SqrtCost;
using ReLUCost = SqrtCost;
// Cost model for the 'identity' primitive. NOTE(review): lower-case name mirrors
// the primitive name 'identity'; identityInfo's constructor depends on this exact identifier.
using identityCost = SqrtCost;
using SigmoidCost = SqrtCost;
using ReciprocalCost =
  SqrtCost;  // The derivative of 'Reciprocal' is different on 'Ascend' and 'GPU'. Here, 'Ascend' is chosen
|
||||
|
|
|
@ -655,6 +655,7 @@ REGISTER(CumSumInfo);
|
|||
// Register operator-info classes so the auto-parallel framework can look them
// up by primitive name when building distributed strategies.
REGISTER(CumProdInfo);
REGISTER(EluInfo);
REGISTER(ReLUInfo);
// 'identity' now participates in strategy propagation (see identityInfo).
REGISTER(identityInfo);
REGISTER(RepeatElementsInfo);
REGISTER(ReLU6Info);
REGISTER(SoftsignInfo);
|
||||
|
|
|
@ -184,6 +184,14 @@ class ReLUInfo : public ActivationOther {
|
|||
~ReLUInfo() override = default;
|
||||
};
|
||||
|
||||
// Parallel operator info for the 'identity' primitive: element-wise and
// shape-preserving, so it reuses the generic ActivationOther sharding logic
// with identityCost (an alias of SqrtCost) as its cost model.
// NOTE(review): the lower-case class name matches the primitive name 'identity';
// REGISTER(identityInfo) relies on this exact identifier — do not rename.
class identityInfo : public ActivationOther {
 public:
  identityInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<identityCost>()) {}
  ~identityInfo() override = default;
};
|
||||
|
||||
class RepeatElementsInfo : public ActivationOther {
|
||||
public:
|
||||
RepeatElementsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
|
|
|
@ -48,9 +48,11 @@ static void GenerateDefaultStrategy(const ValueNodePtr &axes, const std::vector<
|
|||
|
||||
// Generate strategies like ((), (), ..., ())
|
||||
Shapes GenerateEmptyStrategies(const CNodePtr &cnode) {
|
||||
size_t input_size = cnode->size() - 1;
|
||||
Shapes ret_strategies(input_size, Shape());
|
||||
return ret_strategies;
|
||||
auto shape_list = ExtractShape(cnode);
|
||||
if (shape_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << " failed to extract shape.";
|
||||
}
|
||||
return Shapes(shape_list[0].size(), Shape());
|
||||
}
|
||||
|
||||
static bool CheckOneDimensionalIntTuple(const ValuePtr &value_ptr) {
|
||||
|
@ -172,7 +174,8 @@ AnfNodeIndexSet FindAnfNodeIndexSetToInsertStrategy(const FuncGraphPtr &func_gra
|
|||
bfs_list.pop();
|
||||
CNodePtr cnode = user.first->cast<CNodePtr>();
|
||||
// If the cnode is not a splittable operator, apply strategy to the next cnode
|
||||
if (!IsSplittableOperator(GetPrimName(cnode)) || IsPrimitiveCNode(cnode, prim::kPrimVirtualDataset)) {
|
||||
if (!IsSplittableOperator(GetPrimName(cnode)) || IsPrimitiveCNode(cnode, prim::kPrimVirtualDataset) ||
|
||||
IsPrimitiveCNode(cnode, prim::kPrimCast)) {
|
||||
auto tmp_users = manager->node_users()[cnode];
|
||||
(void)std::for_each(tmp_users.begin(), tmp_users.end(),
|
||||
[&bfs_list](const std::pair<AnfNodePtr, int> &user) { bfs_list.push(user); });
|
||||
|
@ -183,6 +186,30 @@ AnfNodeIndexSet FindAnfNodeIndexSetToInsertStrategy(const FuncGraphPtr &func_gra
|
|||
return ret_set;
|
||||
}
|
||||
|
||||
// New a primitive for cnode and set in_strategy to it.
|
||||
void SetStrategyToCNode(const CNodePtr &cnode, const Shapes &strategies) {
|
||||
auto strategy = ShapesToValueTuplePtr(strategies);
|
||||
PrimitivePtr prim = GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
PrimitivePtr new_prim;
|
||||
if (prim->isa<PrimitivePy>()) {
|
||||
PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim_py);
|
||||
new_prim = std::make_shared<PrimitivePy>(*prim_py);
|
||||
} else {
|
||||
new_prim = std::make_shared<Primitive>(*prim);
|
||||
}
|
||||
auto attrs_temp = prim->attrs();
|
||||
attrs_temp[parallel::IN_STRATEGY] = strategy;
|
||||
(void)new_prim->SetAttrs(attrs_temp);
|
||||
|
||||
ValuePtr new_prim_value = MakeValue(new_prim);
|
||||
ValueNodePtr new_prim_value_node = NewValueNode(new_prim_value);
|
||||
AnfNodePtr new_prim_anf_node = new_prim_value_node->cast<AnfNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(new_prim_anf_node);
|
||||
cnode->set_input(0, new_prim_anf_node);
|
||||
}
|
||||
|
||||
static std::set<CNodePtr> SetInputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &in_strategy,
|
||||
const int64_t &device_num) {
|
||||
auto in_strategy_tuple = in_strategy->cast<ValueNodePtr>();
|
||||
|
@ -239,21 +266,7 @@ static std::set<CNodePtr> SetInputLayout(const FuncGraphPtr &func_graph, const A
|
|||
|
||||
for (auto &cnode : concerned_nodes) {
|
||||
Shapes ret_strategy = GenerateDefaultStrategiesForCNode(cnode, input_strategy);
|
||||
// Set in_strategy
|
||||
auto strategy = ShapesToValueTuplePtr(ret_strategy);
|
||||
PrimitivePtr prim = GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto attrs_temp = prim->attrs();
|
||||
attrs_temp[parallel::IN_STRATEGY] = strategy;
|
||||
PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim_py);
|
||||
PrimitivePtr new_prim = std::make_shared<PrimitivePy>(*prim_py);
|
||||
(void)new_prim->SetAttrs(attrs_temp);
|
||||
ValuePtr new_prim_value = MakeValue(new_prim);
|
||||
ValueNodePtr new_prim_value_node = NewValueNode(new_prim_value);
|
||||
AnfNodePtr new_prim_anf_node = new_prim_value_node->cast<AnfNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(new_prim_anf_node);
|
||||
cnode->set_input(0, new_prim_anf_node);
|
||||
SetStrategyToCNode(cnode, ret_strategy);
|
||||
}
|
||||
return concerned_nodes;
|
||||
}
|
||||
|
@ -295,7 +308,8 @@ AnfNodePtr SearchParamByName(const std::vector<AnfNodePtr> ¶meter_list, std:
|
|||
}
|
||||
|
||||
static std::set<CNodePtr> SetParameterLayout(const FuncGraphPtr &root, const FuncGraphPtr &func_graph,
|
||||
const AnfNodePtr ¶meter_plan) {
|
||||
const AnfNodePtr ¶meter_plan,
|
||||
const std::set<CNodePtr> &input_concerned_node) {
|
||||
auto parameter_plan_vnode = parameter_plan->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(parameter_plan_vnode);
|
||||
if (parameter_plan_vnode->value()->isa<None>()) {
|
||||
|
@ -309,48 +323,65 @@ static std::set<CNodePtr> SetParameterLayout(const FuncGraphPtr &root, const Fun
|
|||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
auto root_parameters = root->parameters();
|
||||
std::set<CNodePtr> concerned_cnode;
|
||||
std::unordered_map<AnfNodePtr, Shape> parameter_layout_setting;
|
||||
for (auto p : parameter_plan_list) {
|
||||
auto p_tuple = p->cast<ValueTuplePtr>()->value();
|
||||
auto param_layout = GetValue<Shape>(p_tuple[kIndex1]);
|
||||
auto param_name = GetValue<std::string>(p_tuple[0]);
|
||||
auto parameter = SearchParamByName(root_parameters, param_name);
|
||||
if (parameter == nullptr) {
|
||||
MS_LOG(WARNING) << "Parameter \'" << param_name << "\' is not exist, ignore it.";
|
||||
continue;
|
||||
} else if (parameter_layout_setting.find(parameter) != parameter_layout_setting.end()) {
|
||||
MS_LOG(WARNING) << "The layout of parameter '" << param_name << "' has been set to "
|
||||
<< parameter_layout_setting[parameter] << ", current setting " << param_layout
|
||||
<< " will be ignored.";
|
||||
continue;
|
||||
}
|
||||
parameter_layout_setting.insert({parameter, param_layout});
|
||||
AnfNodeIndexSet users = manager->node_users()[parameter];
|
||||
auto to_insert_nodes_set = FindAnfNodeIndexSetToInsertStrategy(func_graph, users);
|
||||
|
||||
for (auto user : to_insert_nodes_set) {
|
||||
CNodePtr cnode = user.first->cast<CNodePtr>();
|
||||
PrimitivePtr prim = GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto attrs = prim->attrs();
|
||||
if (!StrategyFound(attrs)) {
|
||||
auto empty_strategies = GenerateEmptyStrategies(cnode);
|
||||
attrs[parallel::IN_STRATEGY] = ShapesToValueTuplePtr(empty_strategies);
|
||||
CNodePtr target_cnode = user.first->cast<CNodePtr>();
|
||||
Shapes current_strategies;
|
||||
if (input_concerned_node.find(target_cnode) == input_concerned_node.end()) {
|
||||
// If target_cnode is not involve inputs, insert an identity between Load and target_cnode,
|
||||
// and setting layout into identity.
|
||||
// e.g Load(param) -> identity{in_strategy} -> target_cnode
|
||||
auto pre_cnode = target_cnode->input(user.second)->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(pre_cnode);
|
||||
if (IsPrimitiveCNode(pre_cnode, prim::kPrimCast)) {
|
||||
pre_cnode = pre_cnode->inputs().at(kIndex1)->cast<CNodePtr>();
|
||||
}
|
||||
if (!IsPrimitiveCNode(pre_cnode, prim::kPrimLoad)) {
|
||||
MS_LOG(EXCEPTION) << "The operator type of the " << user.second << "-th input in "
|
||||
<< target_cnode->fullname_with_scope() << " must be 'Load', but got "
|
||||
<< GetCNodePrimitive(pre_cnode)->ToString();
|
||||
}
|
||||
auto identity_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimIdentity), pre_cnode});
|
||||
auto pre_cnode_abstract = pre_cnode->abstract();
|
||||
MS_EXCEPTION_IF_NULL(pre_cnode_abstract);
|
||||
identity_cnode->set_abstract(pre_cnode_abstract->Clone());
|
||||
manager->Replace(pre_cnode, identity_cnode);
|
||||
target_cnode = identity_cnode;
|
||||
current_strategies = {param_layout};
|
||||
} else {
|
||||
// Setting layout into target_cnode directly.
|
||||
PrimitivePtr prim = GetCNodePrimitive(target_cnode);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto attrs = prim->attrs();
|
||||
if (StrategyFound(attrs)) {
|
||||
current_strategies = ValueTuplePtrToShapes(attrs[parallel::IN_STRATEGY]->cast<ValueTuplePtr>());
|
||||
} else {
|
||||
current_strategies = GenerateEmptyStrategies(target_cnode);
|
||||
}
|
||||
current_strategies[user.second - 1] = param_layout;
|
||||
(void)concerned_cnode.insert(target_cnode);
|
||||
}
|
||||
auto current_strategies = ValueTuplePtrToShapes(attrs[parallel::IN_STRATEGY]->cast<ValueTuplePtr>());
|
||||
auto param_layout = GetValue<Shape>(p_tuple[kIndex1]);
|
||||
// If a layout has been set, skip it.
|
||||
if (current_strategies[user.second - 1] != Shape()) {
|
||||
MS_LOG(WARNING) << "For " << cnode->fullname_with_scope() << ", the " << user.second
|
||||
<< "th strategy has been set to " << current_strategies[user.second - 1] << ", current setting "
|
||||
<< param_layout << " will be ignored.";
|
||||
continue;
|
||||
}
|
||||
current_strategies[user.second - 1] = param_layout;
|
||||
attrs[parallel::IN_STRATEGY] = ShapesToValueTuplePtr(current_strategies);
|
||||
PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim_py);
|
||||
PrimitivePtr new_prim = std::make_shared<PrimitivePy>(*prim_py);
|
||||
(void)new_prim->SetAttrs(attrs);
|
||||
ValueNodePtr new_prim_value_node = NewValueNode(MakeValue(new_prim));
|
||||
AnfNodePtr new_prim_anf_node = new_prim_value_node->cast<AnfNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(new_prim_anf_node);
|
||||
cnode->set_input(0, new_prim_anf_node);
|
||||
(void)concerned_cnode.insert(cnode);
|
||||
MS_LOG(INFO) << "The layout of \"" << param_name << "\" has been set to the " << user.second << "th of "
|
||||
<< cnode->fullname_with_scope() << "'s in_strategy. Current strategies is " << current_strategies;
|
||||
SetStrategyToCNode(target_cnode, current_strategies);
|
||||
MS_LOG(DEBUG) << "The layout of \"" << param_name << "\" has been set to the " << user.second << "th of "
|
||||
<< target_cnode->fullname_with_scope() << "'s in_strategy. Current strategies is "
|
||||
<< current_strategies;
|
||||
}
|
||||
}
|
||||
return concerned_cnode;
|
||||
|
@ -393,7 +424,7 @@ static bool SetStrategyForShard(const FuncGraphPtr &root, const std::vector<AnfN
|
|||
}
|
||||
std::set<CNodePtr> concerned_cnode;
|
||||
auto input_concerned_cnode = SetInputLayout(func_graph, in_strategy, device_num);
|
||||
auto parameter_concerned_cnode = SetParameterLayout(root, func_graph, parameter_plan);
|
||||
auto parameter_concerned_cnode = SetParameterLayout(root, func_graph, parameter_plan, input_concerned_cnode);
|
||||
(void)std::set_union(input_concerned_cnode.begin(), input_concerned_cnode.end(),
|
||||
parameter_concerned_cnode.begin(), parameter_concerned_cnode.end(),
|
||||
std::inserter(concerned_cnode, concerned_cnode.end()));
|
||||
|
|
|
@ -726,14 +726,12 @@ bool IsSplittableOperator(const std::string &op_name) {
|
|||
SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
|
||||
UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT, GATHERD,
|
||||
UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE,
|
||||
MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, RESIZE_NEAREST_NEIGHBOR, FAST_GELU, IOU, BOUNDING_BOX_ENCODE,
|
||||
RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE, ROI_ALIGN, REDUCE_PROD, REDUCE_ANY, REDUCE_ALL, ARGMAX, ARGMIN, ARGMINV2,
|
||||
UNSORTED_SEGMENT_PROD, SQUARE_SUM_ALL, MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, UNIQUE_CONSECUTIVE,
|
||||
RESIZE_NEAREST_NEIGHBOR, CUM_SUM, FAST_GELU, IOU, BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE,
|
||||
ROI_ALIGN, IS_FINITE, RINT, HSHRINK, HSIGMOID, MISH, SELU, SOFT_SHRINK, XLOGY, XDIVY, CUM_PROD, BITWISE_AND,
|
||||
BITWISE_OR, BITWISE_XOR, MUL_NO_NAN, TRUNCATE_DIV, TRUNCATE_MOD, INPLACE_ADD, INPLACE_SUB, L2_LOSS, LERP, ADDN,
|
||||
CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL, SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE, CHECK_VALID, INVERT,
|
||||
UNIQUE_WITH_PAD, POPULATION_COUNT};
|
||||
MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, IOU, BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK,
|
||||
CROP_AND_RESIZE, ROI_ALIGN, REDUCE_PROD, REDUCE_ANY, REDUCE_ALL, ARGMAX, ARGMIN, ARGMINV2, UNSORTED_SEGMENT_PROD,
|
||||
SQUARE_SUM_ALL, UNIQUE_CONSECUTIVE, RESIZE_NEAREST_NEIGHBOR, CUM_SUM, IS_FINITE, RINT, HSHRINK, HSIGMOID, MISH,
|
||||
SELU, SOFT_SHRINK, XLOGY, XDIVY, CUM_PROD, BITWISE_AND, BITWISE_OR, BITWISE_XOR, MUL_NO_NAN, TRUNCATE_DIV,
|
||||
TRUNCATE_MOD, INPLACE_ADD, INPLACE_SUB, L2_LOSS, LERP, ADDN, CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL,
|
||||
SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE, CHECK_VALID, INVERT, UNIQUE_WITH_PAD, POPULATION_COUNT, IDENTITY};
|
||||
// clang-format on
|
||||
|
||||
auto iter = splittable_op.find(op_name);
|
||||
|
|
|
@ -27,7 +27,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
|
|||
"array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
|
||||
"list_append", "list_map", "list_reduce", "tuple_reversed", "tile_shape", "tuple_div", "tuple_to_array",
|
||||
"make_dict", "make_slice", "string_eq", "VirtualLoss", "Return", "env_getitem",
|
||||
"identity", "partial", "env_setitem", "env_getitem", "env_add",
|
||||
"partial", "env_setitem", "env_getitem", "env_add",
|
||||
"dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary",
|
||||
"ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
|
||||
"InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
|
||||
|
|
|
@ -374,7 +374,7 @@ def train_feed(num_classes, expect_out):
|
|||
model.train(3, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
|
||||
loss_value = np.array(parallel_callback.loss_list)
|
||||
print(loss_value)
|
||||
assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
|
||||
assert np.allclose(loss_value, expect_out, 0.0005, 0.0005)
|
||||
|
||||
|
||||
def test_train_feed_ascend():
|
||||
|
@ -391,7 +391,7 @@ def test_train_feed_ascend():
|
|||
dataset_strategy="data_parallel")
|
||||
np.random.seed(42)
|
||||
set_seed(42)
|
||||
train_feed(num_classes=65536, expect_out=[11.37239, 11.068878, 10.525374])
|
||||
train_feed(num_classes=65536, expect_out=[11.372186, 11.068188, 10.518132])
|
||||
|
||||
|
||||
def test_train_feed_gpu():
|
||||
|
|
Loading…
Reference in New Issue