implement parallel Pack

This commit is contained in:
Yi Huaijie 2020-09-17 15:37:27 +08:00
parent 9475f9a19a
commit 6066b16838
12 changed files with 560 additions and 29 deletions

View File

@ -199,6 +199,8 @@ class SoftmaxCost : public OperatorCost {
using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>; using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>;
using TileCost = SoftmaxCost; using TileCost = SoftmaxCost;
using TileCostPtr = std::shared_ptr<TileCost>; using TileCostPtr = std::shared_ptr<TileCost>;
using PackCost = TileCost;
using PackCostPtr = std::shared_ptr<PackCost>;
using ConcatCost = TileCost; using ConcatCost = TileCost;
using ConcatCostPtr = std::shared_ptr<ConcatCost>; using ConcatCostPtr = std::shared_ptr<ConcatCost>;
using SplitCost = TileCost; using SplitCost = TileCost;

View File

@ -178,6 +178,7 @@ REGISTER(EmbeddingLookupInfo);
REGISTER(TileInfo); REGISTER(TileInfo);
REGISTER(StridedSliceInfo); REGISTER(StridedSliceInfo);
REGISTER(DropoutInfo); REGISTER(DropoutInfo);
REGISTER(PackInfo);
REGISTER(ConcatInfo); REGISTER(ConcatInfo);
REGISTER(SplitInfo); REGISTER(SplitInfo);
} // namespace parallel } // namespace parallel

View File

@ -39,7 +39,6 @@ const std::set<std::string> BLACK_LIST = {TUPLE_GETITEM,
TILE_SHAPE, TILE_SHAPE,
TUPLE_DIV, TUPLE_DIV,
TUPLE_TO_ARRAY, TUPLE_TO_ARRAY,
MAKE_LIST,
MAKE_DICT, MAKE_DICT,
MAKE_SLICE, MAKE_SLICE,
MAKE_RECORD, MAKE_RECORD,

View File

@ -41,5 +41,6 @@
#include "frontend/parallel/ops_info/strided_slice_info.h" #include "frontend/parallel/ops_info/strided_slice_info.h"
#include "frontend/parallel/ops_info/concat_info.h" #include "frontend/parallel/ops_info/concat_info.h"
#include "frontend/parallel/ops_info/split_info.h" #include "frontend/parallel/ops_info/split_info.h"
#include "frontend/parallel/ops_info/pack_info.h"
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_ #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_

View File

@ -0,0 +1,253 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/ops_info/pack_info.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "pipeline/jit/resource.h"
namespace mindspore {
namespace parallel {
Status PackInfo::GetAttrs() {
int axis = 0;
auto axis_iter = attrs_.find(AXIS);
if (axis_iter != attrs_.end()) {
MS_EXCEPTION_IF_NULL(axis_iter->second);
if (axis_iter->second->isa<Int32Imm>()) {
axis = axis_iter->second->cast<Int32ImmPtr>()->value();
} else {
MS_LOG(ERROR) << name_ << ": The value of axis is not int";
return FAILED;
}
} else {
MS_LOG(ERROR) << name_ << ": Can not find the axis attr";
return FAILED;
}
if (inputs_shape_.empty()) {
MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
return FAILED;
}
int dim = SizeToInt(inputs_shape_[0].size());
if (axis < 0) {
axis = axis + dim;
}
axis_ = SizeToInt(axis);
return SUCCESS;
}
Status PackInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy);
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy";
return FAILED;
}
std::vector<Dimensions> stra = strategy->GetInputDim();
for (size_t i = 0; i < stra.size(); ++i) {
auto strategy_ele = stra[i];
if (axis_ > strategy_ele.size()) {
MS_LOG(ERROR) << name_ << ": The axis is out of range, the axis is " << axis_;
return FAILED;
}
for (size_t j = 0; j < strategy_ele.size(); ++j) {
if (strategy_ele[j] != stra[0][j]) {
MS_LOG(ERROR) << name_ << ": The strategy of each input tensor must be equal";
return FAILED;
}
}
}
return SUCCESS;
}
Status PackInfo::InferDevMatrixShape() {
MS_EXCEPTION_IF_NULL(strategy_);
std::vector<Dimensions> stra = strategy_->GetInputDim();
if (stra.empty()) {
MS_LOG(ERROR) << name_ << "The strategy is empty";
return FAILED;
}
dev_matrix_shape_ = stra[0];
return SUCCESS;
}
Status PackInfo::InferTensorMap() {
TensorMap in_tensor_map;
TensorMap out_tensor_map;
if (inputs_shape_.empty()) {
MS_LOG(ERROR) << name_ << "The inputs shape is empty";
return FAILED;
}
int32_t size = SizeToInt(inputs_shape_[0].size());
for (int i = 0; i < size; ++i) {
in_tensor_map.push_back(size - i - 1);
out_tensor_map.push_back(size - i - 1);
}
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
inputs_tensor_map_.push_back(in_tensor_map);
}
out_tensor_map.insert(out_tensor_map.begin() + axis_, MAP_NONE);
outputs_tensor_map_.push_back(out_tensor_map);
return SUCCESS;
}
Status PackInfo::InferMirrorOps() {
mirror_ops_.clear();
if (inputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty";
return FAILED;
}
Shape input_tensor_map = inputs_tensor_map_[0];
std::vector<Group> group;
if (CreateGroupByTensorMap(input_tensor_map, &group) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Create group for input failed.";
return FAILED;
}
if (group.empty()) {
MS_LOG(INFO) << name_ << ": The mirror group is empty.";
return SUCCESS;
}
OperatorVector input_op;
input_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
mirror_ops_.push_back(input_op);
}
return SUCCESS;
}
Status PackInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";
return FAILED;
}
TensorLayout input_layout, output_layout;
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
// infer tensor layout
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
return FAILED;
}
TensorInfo input_tensor_info(input_layout);
inputs_tensor_info_.push_back(input_tensor_info);
}
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
TensorInfo output_tensor_info(output_layout);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
void PackInfo::ReComputeBatchSplitFlagList() {
for (size_t i = 0; i < inputs_shape_.size(); i++) {
split_flag_list_[i] = true;
}
}
Status PackInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
Status PackInfo::GenerateStrategies(int32_t stage_id) {
if (InferAttrs() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer attrs failed";
return FAILED;
}
if (inputs_shape_.empty()) {
MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
return FAILED;
}
Shape input_split;
for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
input_split.push_back(1);
}
// to generate the first input's strategy
Shapes splittable_input = {input_split};
Shapes tmp_inputs_shape = {inputs_shape_[0]};
std::vector<StrategyPtr> sp_vector;
if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Generate strategies failed";
return FAILED;
}
// the others strategies are equal to the first input's strategy
for (auto &sp : sp_vector) {
if ((sp == nullptr) || sp->GetInputDim().empty()) {
MS_LOG(ERROR) << name_ << ": The strategy is null or empty";
return FAILED;
}
Strategys tmp_strategy;
Dimensions first_input_strategy = sp->GetInputDim()[0];
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
tmp_strategy.push_back(first_input_strategy);
}
sp->ResetInputs(tmp_strategy);
}
size_t success = 0;
for (auto &sp : sp_vector) {
PrintStrategy(sp);
if (SetCostUnderStrategy(sp) == SUCCESS) {
success++;
MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
PrintStrategy(sp);
}
}
return SUCCESS;
}
Status PackInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}
Status PackInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

View File

@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_PACK_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_PACK_INFO_H_
#include <string>
#include <memory>
#include <unordered_map>
#include <vector>
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
namespace mindspore {
namespace parallel {
class PackInfo : public OperatorInfo {
public:
PackInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<PackCost>(false)) {}
~PackInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
Status GenerateStrategies(int32_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
void ReComputeBatchSplitFlagList() override;
protected:
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
private:
size_t axis_ = 0;
};
using PackInfoPtr = std::shared_ptr<PackInfo>;
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_PACK_INFO_H_

View File

@ -116,7 +116,8 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) { std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
std::vector<bool> is_parameter; std::vector<bool> is_parameter;
std::vector<AnfNodePtr> node_inputs{node->inputs()}; std::vector<AnfNodePtr> node_inputs{node->inputs()};
if ((node_inputs.size() == 2) && AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE)) { if ((node_inputs.size() == 2) &&
(AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs(); node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
} }
for (size_t i = 1; i < node_inputs.size(); ++i) { for (size_t i = 1; i < node_inputs.size(); ++i) {
@ -193,7 +194,8 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
std::vector<size_t> inputs_type_len; std::vector<size_t> inputs_type_len;
std::vector<AnfNodePtr> node_inputs{node->inputs()}; std::vector<AnfNodePtr> node_inputs{node->inputs()};
if ((node_inputs.size() == 2) && AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE)) { if ((node_inputs.size() == 2) &&
(AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs(); node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
} }
@ -259,7 +261,7 @@ bool IsSplittableOperator(const std::string &op_name) {
{MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU, {MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK, FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING, REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, PACK,
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT, LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT, STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
@ -281,7 +283,7 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) {
return false; return false;
} }
bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name()); bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name());
if (bool_result && (prim->name() != MAKE_TUPLE)) { if (bool_result && (prim->name() != MAKE_TUPLE) && (prim->name() != MAKE_LIST)) {
MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name(); MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name();
} else if (prim->name() == CAST) { } else if (prim->name() == CAST) {
if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) { if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) {

View File

@ -450,7 +450,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
AnfNodeIndexSet node_set = manager->node_users()[node]; AnfNodeIndexSet node_set = manager->node_users()[node];
CNodePtr insert_node_new; CNodePtr insert_node_new;
if (AnfNodeIsPrimitive(node, MAKE_TUPLE)) { if (AnfNodeIsPrimitive(node, MAKE_TUPLE) || AnfNodeIsPrimitive(node, MAKE_LIST)) {
MS_LOG(INFO) << "No need to insert redistribution op betweend make_tuple node and the next node"; MS_LOG(INFO) << "No need to insert redistribution op betweend make_tuple node and the next node";
return; return;
} }
@ -851,7 +851,8 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
FuncGraphManagerPtr manager = func_graph->manager(); FuncGraphManagerPtr manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
if ((node->inputs().size() == 2) && AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE)) { if ((node->inputs().size() == 2) &&
(AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE) || AnfNodeIsPrimitive(node->input(1), MAKE_LIST))) {
MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node"; MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node";
return; return;
} }
@ -1055,7 +1056,7 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is " MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is "
<< node->fullname_with_scope(); << node->fullname_with_scope();
} }
auto tuple_shape_ptr = dyn_cast<abstract::TupleShape>(base_shape_ptr); auto tuple_shape_ptr = dyn_cast<abstract::SequeueShape>(base_shape_ptr);
if (tuple_shape_ptr != nullptr) { if (tuple_shape_ptr != nullptr) {
auto tuple_shape = tuple_shape_ptr->shape(); auto tuple_shape = tuple_shape_ptr->shape();
for (auto &shape : tuple_shape) { for (auto &shape : tuple_shape) {
@ -1436,7 +1437,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
SetVirtualDatasetStrategy(cnode); SetVirtualDatasetStrategy(cnode);
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
if (prim->name() == MAKE_TUPLE) { if (prim->name() == MAKE_TUPLE || prim->name() == MAKE_LIST) {
continue; continue;
} }
auto attrs = prim->attrs(); auto attrs = prim->attrs();
@ -2459,9 +2460,9 @@ Status ParallelInit() {
return SUCCESS; return SUCCESS;
} }
void HandleForwardMakeTuple(const std::vector<AnfNodePtr> &all_nodes) { void HandleForwardMakeTupleAndMakeList(const std::vector<AnfNodePtr> &all_nodes) {
for (auto &node : all_nodes) { for (auto &node : all_nodes) {
if (!AnfNodeIsPrimitive(node, MAKE_TUPLE)) { if (!AnfNodeIsPrimitive(node, MAKE_TUPLE) && !AnfNodeIsPrimitive(node, MAKE_LIST)) {
continue; continue;
} }
@ -2473,25 +2474,28 @@ void HandleForwardMakeTuple(const std::vector<AnfNodePtr> &all_nodes) {
FuncGraphManagerPtr manager = cnode->func_graph()->manager(); FuncGraphManagerPtr manager = cnode->func_graph()->manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
auto make_tuple_user = manager->node_users()[cnode]; std::string op_type = AnfNodeIsPrimitive(node, MAKE_TUPLE) ? MAKE_TUPLE : MAKE_LIST;
if (make_tuple_user.size() != 1) {
MS_LOG(EXCEPTION) << "Now the make_tuple's user must be 1, but got " << make_tuple_user.size();
}
CNodePtr make_tuple_next_cnode = make_tuple_user.pop().first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple_next_cnode);
std::string make_tuple_user_prim_name = GetPrimName(make_tuple_next_cnode); auto make_tuple_list_user = manager->node_users()[cnode];
if (!IsParallelCareNode(make_tuple_next_cnode)) { if (make_tuple_list_user.size() != 1) {
MS_LOG(INFO) << "The make_tuple's user is " << make_tuple_user_prim_name << ", no need to set operator info"; MS_LOG(EXCEPTION) << "Now the " << op_type << "'s user must be 1, but got " << make_tuple_list_user.size();
}
CNodePtr make_tuple_list_next_cnode = make_tuple_list_user.pop().first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple_list_next_cnode);
std::string make_tuple__list_user_prim_name = GetPrimName(make_tuple_list_next_cnode);
if (!IsParallelCareNode(make_tuple_list_next_cnode)) {
MS_LOG(INFO) << "The " << op_type << "'s user is " << make_tuple__list_user_prim_name
<< ", no need to set operator info";
continue; continue;
} }
if (make_tuple_next_cnode->inputs().size() != 2) { if (make_tuple_list_next_cnode->inputs().size() != 2) {
MS_LOG(EXCEPTION) << "Now the make_tuple's user only support 1 input, but got " MS_LOG(EXCEPTION) << "Now the " << op_type << "'s user only support 1 input, but got "
<< make_tuple_next_cnode->inputs().size() - 1; << make_tuple_list_next_cnode->inputs().size() - 1;
} }
MS_LOG(INFO) << "Set the make_tuple's operator info, and the op name is " << make_tuple_user_prim_name; MS_LOG(INFO) << "Set the " << op_type << "'s operator info, and the op name is " << make_tuple__list_user_prim_name;
OperatorInfoPtr op_info = GetDistributeOperator(make_tuple_next_cnode); OperatorInfoPtr op_info = GetDistributeOperator(make_tuple_list_next_cnode);
MS_EXCEPTION_IF_NULL(op_info); MS_EXCEPTION_IF_NULL(op_info);
cnode->set_user_data<OperatorInfo>(op_info); cnode->set_user_data<OperatorInfo>(op_info);
} }
@ -2695,7 +2699,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
ReshapeInit(all_nodes); ReshapeInit(all_nodes);
} }
HandleForwardMakeTuple(all_nodes); HandleForwardMakeTupleAndMakeList(all_nodes);
// if the input or parameter has multiple users, check whether its split strategies are consistent. // if the input or parameter has multiple users, check whether its split strategies are consistent.
CheckParameterSplit(all_nodes); CheckParameterSplit(all_nodes);

View File

@ -348,6 +348,16 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
} }
auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list); auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
return tuple; return tuple;
} else if (py::isinstance<py::list>(shape_obj) && py::isinstance<py::list>(type_obj)) {
py::list shape_list = shape_obj.cast<py::list>();
py::list typeid_list = type_obj.cast<py::list>();
AbstractBasePtrList ptr_list;
for (size_t it = 0; it < shape_list.size(); ++it) {
auto tensor_it = PyListDtype2AbstractTensor(shape_list[it], typeid_list[it]);
ptr_list.push_back(tensor_it);
}
auto list = std::make_shared<abstract::AbstractList>(ptr_list);
return list;
} else if (shape_obj.is_none() && type_obj.is_none()) { } else if (shape_obj.is_none() && type_obj.is_none()) {
// AbstractNone indicates there is no output for this CNode node. // AbstractNone indicates there is no output for this CNode node.
auto abstract_none = std::make_shared<abstract::AbstractNone>(); auto abstract_none = std::make_shared<abstract::AbstractNone>();

View File

@ -228,11 +228,19 @@ def get_bprop_virtual_div_operator(self):
dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout))) dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout)))
return (dx,) return (dx,)
dx = () if F.issubclass_(F.typeof(dout), mstype.tuple_):
input_nums = F.tuple_len(dout) dx = ()
input_nums = F.tuple_len(dout)
for i in range(input_nums):
ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i])))
dx = dx + (ele_grad,)
return (dx,)
dx = []
input_nums = F.list_len(dout)
for i in range(input_nums): for i in range(input_nums):
ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i]))) ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i])))
dx = dx + (ele_grad,) dx.append(ele_grad)
return (dx,) return (dx,)
return bprop return bprop

View File

@ -92,6 +92,7 @@ dict_getitem = Primitive('dict_getitem')
dict_setitem = Primitive('dict_setitem') dict_setitem = Primitive('dict_setitem')
tuple_div = Primitive("tuple_div") tuple_div = Primitive("tuple_div")
tuple_len = Primitive("tuple_len") tuple_len = Primitive("tuple_len")
list_len = Primitive("list_len")
tuple_reversed = Primitive("tuple_reversed") tuple_reversed = Primitive("tuple_reversed")
make_range = Primitive("make_range") make_range = Primitive("make_range")
make_tuple = Primitive('make_tuple') make_tuple = Primitive('make_tuple')

View File

@ -0,0 +1,188 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.context as context
from mindspore import Tensor, Parameter
import mindspore.nn as nn
from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, Momentum
from mindspore.ops import operations as P
class Net(nn.Cell):
def __init__(self, weight1, weight2, axis=0, strategy1=None, strategy2=None, is_parameter=True):
super(Net, self).__init__()
self.pack = P.Pack(axis=axis).shard(strategy1)
self.mul = P.Mul().shard(strategy2)
if is_parameter:
self.weight1 = Parameter(weight1, "w1")
else:
self.weight1 = weight1
self.weight2 = Parameter(weight2, "w2")
def construct(self, x):
out = self.pack([self.weight1, self.weight2])
out = self.mul(x, out)
return out
class Net1(nn.Cell):
def __init__(self, weight1, weight2, axis=0, strategy1=None, strategy2=None):
super(Net1, self).__init__()
self.pack = P.Pack(axis=axis).shard(strategy1)
self.mul = P.Mul().shard(strategy2)
self.weight1 = Parameter(weight1, "w1")
self.weight2 = Parameter(weight2, "w2")
def construct(self, x):
out = self.mul(x, self.weight1)
out = self.pack([out, self.weight2])
return out
class Net2(nn.Cell):
def __init__(self, weight1, weight2, weight3, axis=0, strategy1=None, strategy2=None, is_parameter=True):
super(Net2, self).__init__()
self.pack = P.Pack(axis=axis).shard(strategy1)
self.mul = P.Mul().shard(strategy2)
if is_parameter:
self.weight1 = Parameter(weight1, "w1")
else:
self.weight1 = weight1
self.weight2 = Parameter(weight2, "w2")
self.weight3 = Parameter(weight2, "w3")
def construct(self, x):
out = self.pack([self.weight1, self.weight2, self.weight3])
out = self.mul(x, out)
return out
_w1 = Tensor(np.ones([48, 64]), dtype=ms.float32)
_w2 = Tensor(np.ones([48, 64]), dtype=ms.float32)
_w3 = Tensor(np.ones([48, 64]), dtype=ms.float32)
_x = Tensor(np.ones([2, 48, 64]), dtype=ms.float32)
_x1 = Tensor(np.ones([48, 64]), dtype=ms.float32)
_x2 = Tensor(np.ones([3, 48, 64]), dtype=ms.float32)
def compile_net(net):
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
_executor.compile(train_net, _x)
context.reset_auto_parallel_context()
def compile_net1(net):
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
_executor.compile(train_net, _x1)
context.reset_auto_parallel_context()
def compile_net2(net):
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
_executor.compile(train_net, _x2)
context.reset_auto_parallel_context()
def test_pack_parameter():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((4, 2), (4, 2))
strategy2 = ((1, 4, 2), (1, 4, 2))
net = Net(_w1, _w2, 0, strategy1, strategy2)
compile_net(net)
def test_pack_parameter_no_full_split():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2), (2, 2))
strategy2 = ((1, 4, 2), (1, 4, 2))
net = Net(_w1, _w2, 0, strategy1, strategy2)
compile_net(net)
def test_pack_tensor_and_parameter():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((4, 2), (4, 2))
strategy2 = ((1, 4, 2), (1, 4, 2))
net = Net(_w1, _w2, 0, strategy1, strategy2, False)
compile_net(net)
def test_pack_output():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((4, 2), (4, 2))
strategy2 = ((4, 2), (4, 2))
net = Net1(_w1, _w2, 0, strategy1, strategy2)
compile_net1(net)
def test_pack_output_axis1():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((4, 2), (4, 2))
strategy2 = ((4, 2), (4, 2))
net = Net1(_w1, _w2, 1, strategy1, strategy2)
compile_net1(net)
def test_pack_output_no_full_split():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2), (4, 2))
net = Net1(_w1, _w2, 0, strategy1, strategy2)
compile_net1(net)
def test_pack_no_strategy():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = None
strategy2 = ((4, 2), (4, 2))
net = Net1(_w1, _w2, 0, strategy1, strategy2)
compile_net1(net)
def test_pack_no_strategy_axis1():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = None
strategy2 = ((4, 2), (4, 2))
net = Net1(_w1, _w2, 1, strategy1, strategy2)
compile_net1(net)
def test_pack_auto_parallel():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net1(_w1, _w2, 0)
compile_net1(net)
def test_pack_auto_parallel_axis1():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net1(_w1, _w2, 1)
compile_net1(net)
def test_pack_auto_parallel_3_tensor():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net2(_w1, _w2, _w3)
compile_net2(net)