forked from mindspore-Ecosystem/mindspore
Modify the names of the functions and variables used for the shared-parameter user strategy treatment
This commit is contained in:
parent 384e6ca851
commit fdfbe2dedc
@@ -191,11 +191,11 @@ class CostGraph {
     inputs_tensor_name_list_ = inputs_tensor_name_list;
   }
   // Needed by rec_parser 2
-  void add_shared_tensor(const std::vector<std::string> &shared_tensor_ops_names) {
-    shared_tensors_ops_name_list_.push_back(shared_tensor_ops_names);
+  void add_param_users_uniqueid(const std::vector<std::string> &param_users_uniqueid) {
+    param_users_uniqueid_list_.push_back(param_users_uniqueid);
   }
-  const std::vector<std::vector<std::string>> get_shared_tensors_ops_name_list() const {
-    return shared_tensors_ops_name_list_;
+  const std::vector<std::vector<std::string>> get_param_users_uniqueid_list() const {
+    return param_users_uniqueid_list_;
   }
   void add_tuple_getitem(const std::pair<std::string, std::string> &tuple_getitem) {
     auto ret = tuple_getitem_list_.insert(tuple_getitem);
@@ -213,7 +213,7 @@ class CostGraph {
   // Needed by rec_parser
   std::vector<std::vector<std::string>> inputs_tensor_name_list_;
   // Needed by rec_parser 2
-  std::vector<std::vector<std::string>> shared_tensors_ops_name_list_;
+  std::vector<std::vector<std::string>> param_users_uniqueid_list_;
   std::map<std::string, std::string> tuple_getitem_list_;
   std::vector<OperatorInfoPtr> ops_;
   std::map<std::pair<OperatorInfoPtr, OperatorInfoPtr>, std::vector<EdgePtr>> edges_;
@@ -33,7 +33,7 @@ void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std
                       const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
                       const std::vector<std::vector<std::string>> &input_tensor_names,
                       const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training,
-                      const std::vector<std::vector<size_t>> &shared_tensors_ops) {
+                      const std::vector<std::vector<size_t>> &param_users_ops_index) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(eli_list);
   MS_EXCEPTION_IF_NULL(index_list);
@@ -46,7 +46,7 @@ void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std
   GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list);
   GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list);
   GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list);
-  ModifySharingTensorOps(ops, shared_tensors_ops);
+  ModifyParamSharingOpsStrategy(ops, param_users_ops_index);

   for (auto &op : ops) {
     // Set user-defined strategy
@@ -603,21 +603,35 @@ void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph,
   }
 }

-void ModifySharingTensorOps(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                            const std::vector<std::vector<size_t>> &shared_tensors_ops) {
-  for (auto tensor : shared_tensors_ops) {
+void ModifyParamSharingOpsStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                   const std::vector<std::vector<size_t>> &param_users_ops_index) {
+  for (auto tensor : param_users_ops_index) {
     for (auto op_i : tensor) {
-      Dimensions str_gather_a;
-      if (ops[op_i]->type() == GATHERV2) {  // It should be the operator to copy | main op > elemwise op
-        str_gather_a = ops[op_i]
-                         ->selected_strategy()
-                         ->GetInputDim()[0];  // Instead of 0 we should put the index of input sharing the tensor
+      if (ops[op_i]->type() == GATHERV2) {
         for (auto op_j : tensor) {
-          if (op_i != op_j && IsStrictElementWise(ops, op_j)) {
-            Strategys stra = GenerateStrategiesFromStrategy(ops, op_j, str_gather_a);
-            StrategyPtr sp = std::make_shared<Strategy>(0, stra);
-            MS_LOG(INFO) << "Changing strategy of " << ops[op_j]->name() << " with " << ops[op_i]->name();
-            ops[op_j]->SetSelectedStrategyAndCost(sp, ops[op_j]->selected_cost());
+          if (op_i != op_j) {
+            Dimensions str_j;
+            if (ops[op_j]->type() == CAST) {
+              str_j = ops[op_j]->selected_strategy()->GetInputDim()[0];
+            } else if (ops[op_j]->type() == MATMUL) {
+              str_j = ops[op_j]->selected_strategy()->GetInputDim()[1];
+            } else {
+              continue;
+            }
+            Strategys strategies;
+            Dimensions str1, str2;
+            str1 = str_j;
+            size_t num_device_used = 1;
+            for (size_t i = 0; i < str_j.size(); i++) {
+              num_device_used *= str_j[i];
+            }
+            str2.push_back(g_device_manager->DeviceNum() / num_device_used);
+            str2.push_back(1);
+            strategies.push_back(str1);
+            strategies.push_back(str2);
+            StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
+            MS_LOG(INFO) << "Changing strategy of " << ops[op_i]->name() << " with " << ops[op_j]->name();
+            ops[op_i]->SetSelectedStrategyAndCost(sp, ops[op_i]->selected_cost());
           }
         }
       }
@@ -117,8 +117,8 @@ void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> &graph,
                                        const std::vector<std::vector<std::string>> &input_tensor_names,
                                        const std::shared_ptr<std::vector<size_t>> &index_list,
                                        const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
-void ModifySharingTensorOps(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                            const std::vector<std::vector<size_t>> &shared_tensors_ops);
+void ModifyParamSharingOpsStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                   const std::vector<std::vector<size_t>> &shared_tensors_ops);
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_
@@ -282,14 +282,5 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
   }
   return new_graph;
 }
-bool IsStrictElementWise(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops) {
-  std::string op_type = ops[iter_ops]->type();
-  auto idx = DictOpType.find(op_type);
-  if (idx == DictOpType.end()) {
-    return false;
-  } else {
-    return StrictElementWiseOpType.find(DictOpType.at(op_type)) != ElementWiseOpType.end();
-  }
-}
 }  // namespace parallel
 }  // namespace mindspore
@@ -38,8 +38,6 @@ static const std::set<OperatorType> ElementWiseOpType = {
     OperatorType::kRecSoftmax,     OperatorType::kRecOneHot,     OperatorType::kRecExpandDims,
     OperatorType::kRecStridedSlice, OperatorType::kRecBatchMatMul};

-static const std::set<OperatorType> StrictElementWiseOpType = {OperatorType::kRecElmWiseOp, OperatorType::kRecCast};
-
 const std::map<std::string, OperatorType> DictOpType{
   {MATMUL, OperatorType::kRecMatMul},
   {BATCH_MATMUL, OperatorType::kRecBatchMatMul},
@@ -181,8 +179,6 @@ void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> &graph,
 std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
                                       const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
                                       const std::shared_ptr<std::vector<size_t>> &index_list);
-
-bool IsStrictElementWise(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops);
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_
@@ -429,17 +429,17 @@ bool IsFindWrong(const OperatorInfoPtr current_op_ptr, const std::string &prim_n
   return is_find_wrong;
 }

-void AddSharedTensorWhenMultiUsers(
+void AddUsersUniqueIdWhenSharingParameter(
   const std::pair<std::string, std::pair<AnfNodePtr, AnfNodeIndexSet>> &parameter_users_info) {
   auto users_set = parameter_users_info.second.second;
   if (users_set.size() > 1) {
     MS_LOG(INFO) << "Parameter " << parameter_users_info.first << " has " << users_set.size() << " users.";
-    std::vector<std::string> user_names;
+    std::vector<std::string> param_users_uniqueid;
     for (auto user : users_set) {
       MS_LOG(INFO) << "with ID: " << user.first->UniqueId();
-      user_names.push_back(user.first->UniqueId());
+      param_users_uniqueid.push_back(user.first->UniqueId());
     }
-    entire_costgraph->add_shared_tensor(user_names);
+    entire_costgraph->add_param_users_uniqueid(param_users_uniqueid);
   }
 }
@@ -539,7 +539,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
   for (auto &node : all_nodes) {
     if (node->isa<Parameter>()) {
       ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsParallelCareNode);
-      AddSharedTensorWhenMultiUsers(parameter_users_info);
+      AddUsersUniqueIdWhenSharingParameter(parameter_users_info);
     }
   }
@@ -658,7 +658,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
   for (auto &node : all_nodes) {
     if (node->isa<Parameter>()) {
       ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsParallelCareNode);
-      AddSharedTensorWhenMultiUsers(parameter_users_info);
+      AddUsersUniqueIdWhenSharingParameter(parameter_users_info);
     }
   }
@@ -1161,21 +1161,21 @@ size_t FindOperatorIndexById(const std::string &unique_id,
   return SIZE_MAX;
 }

-std::vector<std::vector<size_t>> GetSharedTensorsOps(
-  const std::vector<std::vector<std::string>> &shared_tensors_ops_names,
+std::vector<std::vector<size_t>> GetIndexOfOpsSharingInputTensor(
+  const std::vector<std::vector<std::string>> &param_users_uniqueid_list,
   const std::vector<std::vector<std::string>> &input_tensor_names) {
-  std::vector<std::vector<size_t>> shared_tensors_ops;
-  for (auto user_names : shared_tensors_ops_names) {
+  std::vector<std::vector<size_t>> param_users_ops_index;
+  for (auto users_uniqueid : param_users_uniqueid_list) {
     std::vector<size_t> users_index;
-    for (size_t i = 0; i < user_names.size(); i++) {
-      size_t user_index = FindOperatorIndexById(user_names[i], input_tensor_names);
+    for (size_t i = 0; i < users_uniqueid.size(); i++) {
+      size_t user_index = FindOperatorIndexById(users_uniqueid[i], input_tensor_names);
       if (user_index != SIZE_MAX) {
         users_index.push_back(user_index);
       }
     }
-    shared_tensors_ops.push_back(users_index);
+    param_users_ops_index.push_back(users_index);
   }
-  return shared_tensors_ops;
+  return param_users_ops_index;
 }

 Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
@@ -1201,14 +1201,14 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
   auto ops = entire_costgraph->GetOperators();
   std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
   // Needed by rec_parser 2
-  auto shared_tensors_ops_name_list = entire_costgraph->get_shared_tensors_ops_name_list();
+  auto param_users_uniqueid_list = entire_costgraph->get_param_users_uniqueid_list();
   auto tuple_getitem_list = entire_costgraph->get_tuple_getitem_list();
   for (auto it = tuple_getitem_list.begin(); it != tuple_getitem_list.end();) {
     input_tensor_names = RecInputTensorNames(it++, input_tensor_names);
   }
   std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names);
-  std::vector<std::vector<size_t>> shared_tensors_ops =
-    GetSharedTensorsOps(shared_tensors_ops_name_list, input_tensor_names);
+  std::vector<std::vector<size_t>> param_users_ops_index =
+    GetIndexOfOpsSharingInputTensor(param_users_uniqueid_list, input_tensor_names);

   std::shared_ptr<std::vector<std::vector<size_t>>> eli_list = std::make_shared<std::vector<std::vector<size_t>>>();
   std::shared_ptr<std::vector<size_t>> index_list = std::make_shared<std::vector<size_t>>();
@@ -1227,7 +1227,7 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
   if (!root->has_flag(kTraining)) {
     is_training = false;
   }
-  GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list, is_training, shared_tensors_ops);
+  GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list, is_training, param_users_ops_index);

   if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
     MS_LOG(INFO) << "Init selected strategy succeeded.";
@@ -63,11 +63,10 @@ void ModifyInputsTensorNameListIfOperatorInfoCreated(const std::string &name, co
 size_t FindOperatorIndexById(const std::string &unique_id,
                              const std::vector<std::vector<std::string>> &input_tensor_names);

-void AddSharedTensorWhenMultiUsers(
+void AddUsersUniqueIdWhenSharingParameter(
   const std::pair<std::string, std::pair<AnfNodePtr, AnfNodeIndexSet>> &parameter_users_info);

-std::vector<std::vector<size_t>> GetSharedTensorsOps(
-  const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+std::vector<std::vector<size_t>> GetIndexOfOpsSharingInputTensor(
   const std::vector<std::vector<std::string>> &shared_tensors_ops_names,
   const std::vector<std::vector<std::string>> &input_tensor_names);
 }  // namespace parallel
@@ -0,0 +1,105 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import numpy as np
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.train import Model
+from mindspore.common.initializer import initializer
+from mindspore.common.parameter import Parameter
+import mindspore.common.dtype as mstype
+from mindspore.ops import operations as P
+from mindspore.parallel import set_algo_parameters
+from mindspore.common.api import _cell_graph_executor
+from tests.dataset_mock import MindData
+
+
+class Dataset(MindData):
+    def __init__(self, input_ids, length=3):
+        super(Dataset, self).__init__(size=length)
+        self.input_ids = input_ids
+        self.index = 0
+        self.length = length
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.index >= self.length:
+            raise StopIteration
+        self.index += 1
+        return self.input_ids
+
+    def reset(self):
+        self.index = 0
+
+
+class Net(nn.Cell):
+    def __init__(self,
+                 param_init='normal',
+                 height=40000,
+                 width=5120,
+                 compute_type=mstype.float16):
+        super().__init__()
+        self.param = Parameter(initializer(param_init, [height, width]),
+                               name='param', parallel_optimizer=False)
+        self.param_two = Parameter(initializer(param_init, [height, width]),
+                                   name='param_two', parallel_optimizer=False)
+        self.matmul = P.MatMul(transpose_b=True)
+        self.cast = P.Cast()
+        self.add = P.Add()
+        self.gather = P.GatherV2()
+        self.dtype = compute_type
+        self.width = width
+
+    def construct(self, input_ids):
+        input_ids = self.add(input_ids, input_ids)
+        output_g = self.gather(self.param, input_ids, 0)
+        output_r = P.Reshape()(output_g, (-1, self.width))
+        output_gt = self.gather(self.param_two, input_ids, 0)
+        output_rt = P.Reshape()(output_gt, (-1, self.width))
+        output_m = self.matmul(self.cast(output_r, self.dtype), self.cast(self.param, self.dtype))
+        output_mt = self.matmul(output_rt, self.param_two)
+        output = self.add(output_m, output_mt)
+        return output
+
+
+def test_rec_shared_param_strmodif():
+    '''
+    Feature: auto_parallel_recursive_programming strategy modification when two operators share the same parameter
+    Description: Modify the strategy of Gather following MatMul/Cast
+    Expectation: Get expected strategies by check key op
+    '''
+    context.reset_auto_parallel_context()
+    context.set_context(save_graphs=True)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8,
+                                      search_mode="recursive_programming", full_batch=False)
+    set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
+    net = Net()
+    learning_rate = 0.1
+    momentum = 0.9
+    epoch_size = 2
+    input_ids = Tensor(np.ones((2, 1024)), mstype.int32)
+    dataset = Dataset(input_ids)
+    opt = nn.Momentum(net.trainable_params(), learning_rate, momentum)
+    model = Model(net, optimizer=opt)
+    model.train(epoch_size, dataset, dataset_sink_mode=False)
+    stras = _cell_graph_executor._get_shard_strategy(model._train_network)
+    for (k, v) in stras.items():
+        if re.search("Gather", k) is not None:
+            assert v == [[4, 1], [2, 1]]
+    context.reset_auto_parallel_context()