!1152 [AutoParallel] dynamic output shape handling for Reduce series & Squeeze
Merge pull request !1152 from Chong/support_squeeze_and_reduce
Commit: b124bf38a1
@@ -67,6 +67,38 @@ std::vector<int32_t> PrepareMatMul(const std::shared_ptr<Graph> &graph,
  return s;
}

// std::vector<std::vector<int32_t>> PrepareVirtualDataset(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
//                                                         const size_t iter_ops) {
//   std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(ops, iter_ops);
//   strategies[1][0] = strategies[0][0];
//   return strategies;
// }

std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t iter_ops, std::vector<int32_t> s) {
  std::vector<std::vector<int32_t>> strategies;
  for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
    if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() == 1) {
      auto max = s[max_element(s.begin(), s.end()) - s.begin()];
      std::vector<int32_t> s_single;
      s_single.push_back(max);
      strategies.push_back(s_single);
      continue;
    }
    strategies.push_back(s);
  }
  return strategies;
}

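For a 1-D bias input, PrepareBiasAdd reuses the largest cut number already present in the candidate strategy s. A minimal standalone sketch of that selection, using only the standard library (BiasStrategyFrom is a hypothetical name, not part of this PR):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the 1-D branch above: the bias strategy is a single element
// holding the maximum split factor found in the incoming strategy s.
std::vector<int32_t> BiasStrategyFrom(const std::vector<int32_t> &s) {
  return {*std::max_element(s.begin(), s.end())};
}

int main() {
  // Incoming strategy [2, 4] -> bias strategy [4].
  assert(BiasStrategyFrom({2, 4}) == (std::vector<int32_t>{4}));
  return 0;
}
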
std::vector<std::vector<int32_t>> PrepareOneHot(std::vector<int32_t> s) {
  std::vector<std::vector<int32_t>> strategies;
  std::vector<int32_t> s_empty = {};
  strategies.push_back(s);
  strategies.push_back(s_empty);
  strategies.push_back(s_empty);
  return strategies;
}

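PrepareOneHot pins the whole candidate strategy on the first input and hands the remaining two inputs (presumably the scalar on/off values) empty strategies:

// Illustrative only: for s = {4, 2}, PrepareOneHot returns
// { {4, 2}, {}, {} }, one entry per OneHot input.
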
std::vector<int32_t> MakeRecSearchStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                           const std::shared_ptr<Graph> &graph, const size_t iter_ops,
                                           const size_t iter_op_inputs) {
@@ -163,5 +195,237 @@ std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> &graph,
  }
}

int FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names,
                                const size_t iter_ops) {
  int incoming_op_index = -1;
  for (size_t i = 1; i < (size_t)input_tensor_names[iter_ops].size(); i++) {
    for (size_t j = 0; j < (size_t)input_tensor_names.size(); j++) {
      if (input_tensor_names[iter_ops][i] == input_tensor_names[j][0]) {
        incoming_op_index = j;
        break;
      }
    }
    if (incoming_op_index != -1) {
      break;
    }
  }
  return incoming_op_index;
}

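Here input_tensor_names[j][0] appears to hold operator j's own (output) name, with entries [j][1..] naming its inputs, so the search returns the index of the first operator that produces one of operator iter_ops' inputs, or -1 if none is found. A self-contained sketch under that layout assumption (FindProducer is a hypothetical name):

#include <cassert>
#include <string>
#include <vector>

// names[j][0] is operator j's own name; names[j][1..] are its inputs.
int FindProducer(const std::vector<std::vector<std::string>> &names, size_t iter_ops) {
  for (size_t i = 1; i < names[iter_ops].size(); i++) {
    for (size_t j = 0; j < names.size(); j++) {
      if (names[iter_ops][i] == names[j][0]) return static_cast<int>(j);
    }
  }
  return -1;
}

int main() {
  // op0 produces "a"; op1 consumes "a" and "w", so op0 is its producer.
  std::vector<std::vector<std::string>> names = {{"a", "x"}, {"b", "a", "w"}};
  assert(FindProducer(names, 1) == 0);
  return 0;
}
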
std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> graph,
                                                        const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                        const size_t iter_ops, const size_t iter_graph) {
  std::vector<int32_t> s;
  for (auto input : ops[iter_ops]->inputs_tensor_info()) {
    auto input_stra_dim = input.shape().size();
    if (input_stra_dim == 0) {
      continue;
    }
    if (input_stra_dim == 1) {
      s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w);
    } else if (input_stra_dim == 2) {
      s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h);
      s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w);
    } else if (input_stra_dim == 4) {
      s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_n);
      s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_c);
      s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h);
      s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w);
    } else {
      MS_LOG(ERROR) << "Tensor's shape is unknown.";
    }
    break;
  }
  return s;
}

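The graph node's tensor_str fields store each dimension's stride as a fraction, apparently 1 / number_of_cuts, so dividing it back out recovers the split factor; note the loop derives the whole strategy from the first non-scalar input and then breaks. A standalone sketch of the 2-D branch under that reciprocal-stride assumption (StrategyFromStrides is a hypothetical name):

#include <cassert>
#include <cstdint>
#include <vector>

// Strides are assumed stored as 1 / number_of_cuts, so the reciprocal
// recovers the per-dimension split factor.
std::vector<int32_t> StrategyFromStrides(float str_h, float str_w) {
  return {static_cast<int32_t>(1.0f / str_h), static_cast<int32_t>(1.0f / str_w)};
}

int main() {
  // A tensor cut 2 ways along h and 4 ways along w.
  assert(StrategyFromStrides(0.5f, 0.25f) == (std::vector<int32_t>{2, 4}));
  return 0;
}
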
std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                          const int incoming_op_index) {
  std::vector<int32_t> s;
  auto strategy = ops[incoming_op_index]->selected_strategy();
  if (strategy->GetInputNumber() == 0) {
    return s;
  }
  for (size_t i = 0; i < (size_t)ops[incoming_op_index]->inputs_tensor_info().size(); i++) {
    if (ops[incoming_op_index]->inputs_tensor_info()[i].shape().size() == 0) {
      continue;
    }
    for (size_t j = 0; j < ops[incoming_op_index]->inputs_tensor_info()[i].shape().size(); ++j) {
      s.push_back(strategy->GetInputDim()[i][j]);
    }
    break;
  }
  return s;
}

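For example, if the incoming operator's selected strategy is [[8, 1], [1, 4]] over two 2-D inputs, the row for its first non-scalar input, [8, 1], is copied out as the candidate strategy s.
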
std::vector<int32_t> GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int iter_ops) {
  std::vector<int32_t> axis_list;
  auto axis_param = ops[iter_ops]->attrs().find(AXIS)->second;
  std::vector<ValuePtr> elements;
  if (axis_param->isa<ValueTuple>()) {
    elements = axis_param->cast<ValueTuplePtr>()->value();
  } else if (axis_param->isa<ValueList>()) {
    elements = axis_param->cast<ValueListPtr>()->value();
  } else {
    MS_LOG(EXCEPTION) << "Failure: Axis type is invalid, neither tuple nor list." << std::endl;
  }
  for (auto &element : elements) {
    if (!element->isa<Int32Imm>()) {
      MS_LOG(EXCEPTION) << "Failure: Dimension index is not Int32." << std::endl;
    }
    auto axis = element->cast<Int32ImmPtr>()->value();
    axis_list.push_back(axis);
  }
  return axis_list;
}

std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                     const int incoming_op_index, std::vector<int32_t> s) {
  std::vector<int32_t> s_Squeeze;
  std::vector<int32_t> stra_dim_list;
  for (size_t i = 0; i < s.size(); i++) {
    stra_dim_list.push_back(i);
  }
  auto axis_list = GetAxisList(ops, incoming_op_index);
  for (auto axis : axis_list) {
    auto it = find(stra_dim_list.begin(), stra_dim_list.end(), axis);
    if (it == stra_dim_list.end()) {
      MS_LOG(EXCEPTION) << "Failure: Cannot find dimension index in Axis." << std::endl;
    }
    if (ops[incoming_op_index]->inputs_tensor_info()[0].shape()[axis] != 1) {
      MS_LOG(EXCEPTION) << "Failure: Removed dimension's shape is not 1." << std::endl;
    }
    stra_dim_list.erase(it);
  }
  for (size_t i = 0; i < (size_t)stra_dim_list.size(); i++) {
    s_Squeeze.push_back(s[stra_dim_list[i]]);
  }
  return s_Squeeze;
}

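When the incoming operator is a Squeeze, the squeezed axes (which must have extent 1, hence a split factor of 1) are dropped from s so the strategy matches the squeezed output's rank. A self-contained sketch of that bookkeeping (DropSqueezedAxes is a hypothetical name):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Drop the squeezed axes from the incoming strategy; the remaining split
// factors keep their relative order.
std::vector<int32_t> DropSqueezedAxes(const std::vector<int32_t> &s, const std::vector<int32_t> &axes) {
  std::vector<int32_t> dims;
  for (size_t i = 0; i < s.size(); i++) dims.push_back(static_cast<int32_t>(i));
  for (auto axis : axes) {
    auto it = std::find(dims.begin(), dims.end(), axis);
    if (it != dims.end()) dims.erase(it);
  }
  std::vector<int32_t> out;
  for (auto d : dims) out.push_back(s[d]);
  return out;
}

int main() {
  // Squeeze with axis=(1,): strategy [8, 1, 4] -> [8, 4].
  assert(DropSqueezedAxes({8, 1, 4}, {1}) == (std::vector<int32_t>{8, 4}));
  return 0;
}
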
std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) {
  std::vector<int32_t> dim_list;
  bool keep_dims;
  if (!ops[iter_ops]->attrs().find(KEEP_DIMS)->second->isa<BoolImm>()) {
    MS_LOG(EXCEPTION) << "Failure: Parameter keep_dims is not a boolean value." << std::endl;
  }
  keep_dims = ops[iter_ops]->attrs().find(KEEP_DIMS)->second->cast<BoolImmPtr>()->value();
  if (keep_dims) {
    return dim_list;
  }
  auto input_value = ops[iter_ops]->input_value();
  auto input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
  if (input_value.back()->isa<ValueTuple>()) {
    auto attr_axis = GetValue<std::vector<int>>(input_value.back());
    if (attr_axis.empty()) {
      MS_LOG(EXCEPTION) << "Failure: This output is a 0-D tensor." << std::endl;
    }
    for (auto &axis : attr_axis) {
      axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis);
    }
  } else if (input_value.back()->isa<Int32Imm>()) {
    int axis = GetValue<int>(input_value.back());
    axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis);
  } else {
    MS_LOG(EXCEPTION) << "Failure: Axis type is invalid." << std::endl;
  }
  return dim_list;
}

std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                    const int incoming_op_index, std::vector<int32_t> s) {
  std::vector<int32_t> s_Reduce;
  std::vector<int32_t> axis_list;
  for (size_t i = 0; i < s.size(); i++) {
    axis_list.push_back(i + 1);
  }
  auto dim_list = GetDimList(ops, incoming_op_index);
  for (auto axis : dim_list) {
    auto it = find(axis_list.begin(), axis_list.end(), axis);
    if (it == axis_list.end()) {
      MS_LOG(EXCEPTION) << "Failure: Cannot find dimension index in Axis." << std::endl;
    }
    axis_list.erase(it);
  }
  for (size_t i = 0; i < (size_t)axis_list.size(); i++) {
    s_Reduce.push_back(s[axis_list[i] - 1]);
  }
  return s_Reduce;
}

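The Reduce case follows the same erase-then-copy pattern, except axis_list is built 1-based here: with dim_list = {2} and incoming strategy [8, 2, 4], axis_list {1, 2, 3} loses entry 2 and the result is [s[0], s[2]] = [8, 4]. (GetDimList itself normalizes negative axes to 0-based indexes, so the 1-based matching here rests on that convention.)
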
std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                                 const size_t iter_ops, std::vector<int32_t> s) {
  std::vector<int32_t> s_empty = {};
  std::vector<std::vector<int32_t>> stra;
  if (s.size() == 0) {
    return stra;
  }
  MS_EXCEPTION_IF_NULL(ops[iter_ops]);
  if (ops[iter_ops]->type() == BIAS_ADD) {
    return PrepareBiasAdd(ops, iter_ops, s);
  }
  if (ops[iter_ops]->type() == ONEHOT) {
    return PrepareOneHot(s);
  }
  for (size_t i = 0; i < (size_t)ops[iter_ops]->inputs_tensor_info().size(); i++) {
    if (ops[iter_ops]->inputs_tensor_info()[i].shape().size() == 0) {
      stra.push_back(s_empty);
      continue;
    }
    std::vector<int32_t> s_1 = s;
    bool modified = false;
    for (size_t j = 0; j < (size_t)ops[iter_ops]->inputs_tensor_info()[i].shape().size(); j++) {
      if (ops[iter_ops]->inputs_tensor_info()[i].shape()[j] == 1) {
        s_1[j] = 1;
        modified = true;
      }
    }
    if (modified) {
      stra.push_back(s_1);
    } else {
      stra.push_back(s);
    }
  }
  return stra;
}

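Broadcast handling: each input receives a copy of s with the split factor forced to 1 on every dimension of extent 1, and scalar inputs receive an empty strategy. For example, an input of shape (1, 64) under candidate strategy [8, 4] is stored as [1, 4], since a dimension of extent 1 cannot be cut.
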
std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                     const size_t iter_ops, std::vector<int32_t> s) {
  std::vector<int32_t> s_Squeeze;
  auto axis_list = GetAxisList(ops, iter_ops);
  size_t s_index = 0;
  size_t axis_list_index = 0;
  for (size_t i = 0; i < (size_t)(s.size() + axis_list.size()); i++) {
    // Guard against reading past the end of axis_list once every squeezed
    // axis has been re-inserted.
    if (axis_list_index < axis_list.size() && i == (size_t)axis_list[axis_list_index]) {
      s_Squeeze.push_back(1);
      axis_list_index++;
    } else {
      s_Squeeze.push_back(s[s_index]);
      s_index++;
    }
  }
  return s_Squeeze;
}

std::vector<int32_t> ModifyStrategyIfReduceOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                    const size_t iter_ops, std::vector<int32_t> s) {
  std::vector<int32_t> dim_list = GetDimList(ops, iter_ops);
  if (dim_list.size() == 0) {
    return s;
  }
  std::vector<int32_t> s_Reduce;
  size_t s_index = 0;
  size_t dim_list_index = 0;
  for (size_t i = 0; i < (size_t)(s.size() + dim_list.size()); i++) {
    // Guard against reading past the end of dim_list once every reduced
    // dimension has been re-inserted.
    if (dim_list_index < dim_list.size() && (i + 1) == (size_t)dim_list[dim_list_index]) {
      s_Reduce.push_back(1);
      dim_list_index++;
    } else {
      s_Reduce.push_back(s[s_index]);
      s_index++;
    }
  }
  return s_Reduce;
}

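The two outgoing variants run the mapping in reverse: they re-expand a strategy proposed for the operator's output back to the input rank by inserting a split factor of 1 at each removed dimension. A self-contained sketch of the re-expansion (ReinsertOnes is a hypothetical name; dims uses the same 1-based positions the loop above matches against):

#include <cassert>
#include <cstdint>
#include <vector>

// Walk the full input rank; removed dimensions get split factor 1, all
// other positions take the output strategy in order.
std::vector<int32_t> ReinsertOnes(const std::vector<int32_t> &s, const std::vector<int32_t> &dims) {
  std::vector<int32_t> out;
  size_t s_index = 0, dim_index = 0;
  for (size_t i = 0; i < s.size() + dims.size(); i++) {
    if (dim_index < dims.size() && i + 1 == (size_t)dims[dim_index]) {
      out.push_back(1);
      dim_index++;
    } else {
      out.push_back(s[s_index++]);
    }
  }
  return out;
}

int main() {
  // Removed dimension 2 (1-based): output strategy [8, 4] -> [8, 1, 4].
  assert(ReinsertOnes({8, 4}, {2}) == (std::vector<int32_t>{8, 1, 4}));
  return 0;
}
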
}  // namespace parallel
}  // namespace mindspore

@@ -31,6 +31,11 @@ void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::share
std::vector<int32_t> PrepareMatMul(const std::shared_ptr<Graph> &graph,
                                   const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_nodes,
                                   const size_t iter_op_inputs);
std::vector<std::vector<int32_t>> PrepareVirtualDataset(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                        const size_t iter_ops);
std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t iter_ops, std::vector<int32_t> s);
std::vector<std::vector<int32_t>> PrepareOneHot(std::vector<int32_t> s);
std::vector<int32_t> MakeRecSearchStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                           const std::shared_ptr<Graph> &graph, const size_t iter_ops,
                                           const size_t iter_op_inputs);

@@ -39,6 +44,24 @@ std::vector<int32_t> MakeDataParallelStrategy(const std::vector<std::shared_ptr<
std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> &graph,
                                     const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                                     const size_t iter_op_inputs);
int FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names, const size_t iter_ops);
std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> graph,
                                                        const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                        const size_t iter_ops, const size_t iter_graph);
std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                          const int incoming_op_index);
std::vector<int32_t> GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int iter_ops);
std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                     const int incoming_op_index, std::vector<int32_t> s);
std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                    const int incoming_op_index, std::vector<int32_t> s);
std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                                 const size_t iter_ops, std::vector<int32_t> s);
std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                     const size_t iter_ops, std::vector<int32_t> s);
std::vector<int32_t> ModifyStrategyIfReduceOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                    const size_t iter_ops, std::vector<int32_t> s);
}  // namespace parallel
}  // namespace mindspore
#endif  // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_

@@ -140,6 +140,7 @@ class OperatorInfo {
  CostPtr selected_cost() const { return selected_cost_; }
  Status InitSelectedStrategy(const StrategyPtr &s_strategy) { return Init(s_strategy); }
  void set_input_value(const std::vector<ValuePtr> &input_value) { input_value_ = input_value; }
  const std::vector<ValuePtr> &input_value() const { return input_value_; }
  void set_outputs_dtype(const TypePtr &dtype) { outputs_dtype_ = dtype; }
  void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; }
  bool is_alive() const { return is_alive_; }

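The input_value() accessor exposed here is what GetDimList above relies on to read a Reduce operator's axis argument from its recorded input values; set_input_value is its counterpart, presumably populated when the operator info is built.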