!23331 batch parallel operator replace input shape from construct

Merge pull request !23331 from yangzhenzhang/modify-batch-parallel-info
This commit is contained in:
i-robot 2021-09-14 06:32:45 +00:00 committed by Gitee
commit 2b4cdea48a
16 changed files with 279 additions and 196 deletions
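The operators targeted here take a shape tuple produced in construct as their first input instead of a tensor (UniformReal is the case exercised by the new test). Under semi-auto parallel they fall back to BatchParallelInfo, which now records that shape, divides its first dimension by the stage device size, and rewrites the shape value node on each rank. A minimal sketch of the scenario, condensed from the test case added at the end of this commit (shapes and device count taken from that test):

from mindspore import context
from mindspore.nn import Cell
from mindspore.ops import operations as P

class UniformRealNet(Cell):
    """Sketch: an operator whose first input is a shape tuple from construct."""
    def __init__(self):
        super().__init__()
        self.uniform_real = P.UniformReal()

    def construct(self, x):
        # (128, 64, 32) is a shape value, not a tensor; with device_num=16 the
        # parallel pass is expected to rewrite it to (8, 64, 32) on each rank
        z = self.uniform_real((128, 64, 32))
        return x + z

context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)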

View File

@ -54,20 +54,39 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
Status BatchParallelInfo::InferDevMatrixShape() {
dev_matrix_shape_.push_back(stage_device_size_);
if (need_replace_input_ && !inputs_shape_.empty()) {
replace_shape_ = inputs_shape_[0];
if (!replace_shape_.empty()) {
replace_shape_[0] /= stage_device_size_;
}
}
return SUCCESS;
}
Status BatchParallelInfo::InferForwardCommunication() { return SUCCESS; }
Status BatchParallelInfo::InferTensorMap() {
if (strategy_->GetInputDim()[0][0] != stage_device_size_) {
MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy.";
auto strategy = strategy_->GetInputDim();
if (strategy.empty()) {
MS_LOG(INFO) << name_ << ": the strategy is empty";
return SUCCESS;
}
if (strategy[0].empty()) {
MS_LOG(INFO) << name_ << ": the first element of strategy is empty";
return FAILED;
}
if (strategy[0][0] != stage_device_size_) {
MS_LOG(ERROR) << name_ << ": It is not a valid data parallel strategy.";
return FAILED;
}
for (size_t i = 0; i < inputs_shape_.size(); i++) {
Shape tensor_map_index;
for (size_t j = 0; j < inputs_shape_[i].size(); ++j) {
if (strategy_->GetInputDim()[i][j] == stage_device_size_ && j == 0) {
if (strategy[i][j] == stage_device_size_ && j == 0) {
tensor_map_index.push_back(0);
} else {
tensor_map_index.push_back(MAP_NONE);
@ -89,25 +108,23 @@ Status BatchParallelInfo::InferTensorMap() {
return SUCCESS;
}
Strategys BatchParallelInfo::GetOutputsStrategy() {
Strategys outputs_strategy;
for (size_t i = 0; i < outputs_shape_.size(); ++i) {
Dimensions strategy;
for (size_t j = 0; j < outputs_shape_[i].size(); ++j) {
if (i == 0 && j == 0) {
strategy.push_back(stage_device_size_);
} else {
strategy.push_back(1);
}
}
outputs_strategy.push_back(strategy);
Status BatchParallelInfo::GetAttrs() {
// If the operator's input is a shape (not a tensor), assign the shape value to inputs_shape_
if (!inputs_shape_.empty()) {
return SUCCESS;
}
return outputs_strategy;
}
if (input_value_.empty()) {
return SUCCESS;
}
Status BatchParallelInfo::GetAttrs() { return SUCCESS; }
auto shape_ptr = input_value_[0]->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(shape_ptr);
inputs_shape_.push_back(GetValue<Shape>(shape_ptr));
need_replace_input_ = true;
return SUCCESS;
}
Status BatchParallelInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
@ -158,5 +175,30 @@ Status BatchParallelInfo::InferAsLossDivisor() {
as_loss_divisor_ = 1;
return SUCCESS;
}
void BatchParallelInfo::ReplaceNodeInputOrAttrs() {
if (!need_replace_input_) {
return;
}
auto cnode = cnode_;
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() != 2) {
MS_LOG(EXCEPTION) << name_ << ": The size of the cnode's inputs must be 2";
}
if (!IsValueNode<ValueTuple>(cnode->input(1))) {
MS_LOG(EXCEPTION) << name_ << ": The input[1] of the cnode is not a ValueTuple.";
}
auto func_graph = cnode->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
ValuePtr replace_shape = MakeValue(replace_shape_);
AnfNodePtr val = NewValueNode(replace_shape);
(void)manager->Replace(cnode->input(1), val);
}
} // namespace parallel
} // namespace mindspore
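A worked example of the replacement computed above: with inputs_shape_[0] = (128, 64, 32) and stage_device_size_ = 16, InferDevMatrixShape stores replace_shape_ = (8, 64, 32), and ReplaceNodeInputOrAttrs substitutes that value for the original ValueTuple input. A standalone sketch of the arithmetic (hypothetical helper, mirroring the C++ members):

# Hypothetical sketch of the replace_shape_ computation in
# BatchParallelInfo::InferDevMatrixShape: only the batch dimension is divided.
def infer_replace_shape(input_shape, stage_device_size):
    replace_shape = list(input_shape)
    if replace_shape:
        replace_shape[0] //= stage_device_size
    return tuple(replace_shape)

assert infer_replace_shape((128, 64, 32), 16) == (8, 64, 32)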

View File

@ -41,6 +41,7 @@ class BatchParallelInfo : public OperatorInfo {
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
void ReplaceNodeInputOrAttrs() override;
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;
@ -48,11 +49,12 @@ class BatchParallelInfo : public OperatorInfo {
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status GetAttrs() override;
Strategys GetOutputsStrategy();
Status InferAsLossDivisor() override;
private:
int64_t dev_num_;
int64_t dev_num_ = 1;
bool need_replace_input_ = false;
Shape replace_shape_;
};
class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {

View File

@ -17,6 +17,7 @@
#include "frontend/parallel/ops_info/conv2d_info.h"
#include <algorithm>
#include <functional>
#include <cmath>
#include <memory>
#include <utility>
@ -1041,7 +1042,8 @@ Status Conv2DBackpropInputInfo::InferMirrorOps() {
return SUCCESS;
}
void Conv2DBackpropInputInfo::UpdateOutShape(const CNodePtr &cnode) {
void Conv2DBackpropInputInfo::UpdateOutShape() {
auto cnode = cnode_;
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() != 4) {
MS_LOG(EXCEPTION) << name_ << ": The size of cnode's inputs must be 4, but got " << cnode->size();
@ -1191,5 +1193,7 @@ void Conv2DBackpropInputInfo::InferNewPadList() {
MS_LOG(INFO) << name_ << ": the new pad list is " << new_pad_list_ << ", the required size of current rank is "
<< current_rank_required_size << ", new pad all is " << pad_all;
}
void Conv2DBackpropInputInfo::ReplaceNodeInputOrAttrs() { UpdateOutShape(); }
} // namespace parallel
} // namespace mindspore

View File

@ -127,7 +127,8 @@ class Conv2DBackpropInputInfo : public Conv2DInfo {
const PrimitiveAttrs &attrs)
: Conv2DInfo(name, inputs_shape, outputs_shape, attrs) {}
~Conv2DBackpropInputInfo() override = default;
void UpdateOutShape(const CNodePtr &cnode);
void UpdateOutShape();
void ReplaceNodeInputOrAttrs() override;
protected:
Status GetAttrs() override;

View File

@ -24,6 +24,8 @@
#include "ir/value.h"
#include "pipeline/jit/resource.h"
#include "frontend/parallel/auto_parallel/costmodel.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
@ -213,7 +215,8 @@ void SetGenMaskShape(const CNodePtr &cnode, const Shape &input_slice_shape) {
// split. Find the DropoutGenMask node in the anf graph according to DropoutDoMask node, and modify the input shape
// of DropoutGenMask according to the strategy of DropoutDoMask. When the DropoutDoMask performs repeated calculation
// and both seeds of DropoutGenMask are 0, two new seeds are automatically generated for DropoutGenMask.
std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp() {
auto cnode = cnode_;
std::vector<Operator> replace_ops;
MS_EXCEPTION_IF_NULL(cnode);
PrimitivePtr prim = GetDropoutGenMaskPrim(cnode);
@ -262,5 +265,46 @@ std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodeP
replace_ops.push_back(replace_op);
return replace_ops;
}
static void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) {
FuncGraphPtr func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
FuncGraphManagerPtr manager = func_graph->manager();
if (manager == nullptr) {
MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
}
std::string instance_name = CreateInstanceName(node, 0);
std::vector<AnfNodePtr> replace_input;
replace_input = ReplaceOpInput(replace_op, instance_name, node);
if (node->inputs().size() == DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
replace_input.push_back(node->input(3));
}
CNodePtr replace_node = func_graph->NewCNode(replace_input);
MS_EXCEPTION_IF_NULL(replace_node);
ScopePtr scope = node->scope();
MS_EXCEPTION_IF_NULL(scope);
replace_node->set_scope(scope);
replace_node->set_in_forward_flag(true);
replace_input[0]->set_scope(scope);
PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
PrimitivePtr origin_prim = GetValueNode<PrimitivePtr>(node->input(0));
SetUserAttrs(origin_prim->attrs(), prim);
(void)manager->Replace(node, replace_node);
}
void DropoutDoMaskInfo::ReplaceNodeInputOrAttrs() {
auto cnode = cnode_;
MS_EXCEPTION_IF_NULL(cnode);
std::vector<Operator> replace_op = GetDropoutGenMaskReplaceOp();
if (replace_op.empty()) {
MS_LOG(DEBUG) << name_ << ": No need to replace dropout_gen_mask";
return;
}
if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
MS_LOG(EXCEPTION) << name_ << ": The size of dropout do mask cnode's inputs is not "
<< DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
}
ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast<CNodePtr>());
}
} // namespace parallel
} // namespace mindspore
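Per the comment above, the DropoutGenMask node reached through DropoutDoMask gets its shape input rewritten to the per-rank slice of DropoutDoMask's input, i.e. each dimension divided by the corresponding strategy value. A sketch of that arithmetic (hypothetical helper, not the C++ implementation):

# Hypothetical sketch of the slice-shape arithmetic behind SetGenMaskShape:
# every dimension of DropoutDoMask's input is divided by its strategy value.
def gen_mask_slice_shape(full_shape, strategy):
    return tuple(dim // cut for dim, cut in zip(full_shape, strategy))

# e.g. a (128, 64) input split 16-way along the batch dimension
assert gen_mask_slice_shape((128, 64), (16, 1)) == (8, 64)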

View File

@ -41,7 +41,8 @@ class DropoutDoMaskInfo : public OperatorInfo {
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
std::vector<Operator> GetDropoutGenMaskReplaceOp(const CNodePtr &cnode);
std::vector<Operator> GetDropoutGenMaskReplaceOp();
void ReplaceNodeInputOrAttrs() override;
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;

View File

@ -1029,6 +1029,9 @@ Status PrepareStrategyBase(int64_t stage_id, size_t dev_num, const Shapes &input
}
std::shared_ptr<Strategys> OperatorInfo::GenerateBatchStrategies() {
if (inputs_shape_.empty() && InferAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
}
ComputeBatchSplitFlagList();
return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
}
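GenerateBatchStrategies now infers attributes first when inputs_shape_ is empty, so operators whose input is a shape value (see the new BatchParallelInfo::GetAttrs above) still obtain a batch strategy. The generated strategy splits only the leading dimension of each input across the devices; a sketch of that rule, assuming every input's batch dimension is splittable (hypothetical helper, not GenerateBatchStrategiesBySplitFlag itself):

# Hypothetical sketch: data-parallel batch strategy, splitting only dim 0.
def batch_strategies(inputs_shape, dev_num):
    return tuple(tuple(dev_num if j == 0 else 1 for j in range(len(shape)))
                 for shape in inputs_shape)

assert batch_strategies([[128, 64, 32], [128, 64, 32]], 16) == ((16, 1, 1), (16, 1, 1))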

View File

@ -182,6 +182,7 @@ class OperatorInfo {
int32_t stage_id() const { return stage_id_; }
Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);
Status CreateGroupForOptShard(TensorLayout *const tensor_layout, std::vector<Group> *group);
virtual void ReplaceNodeInputOrAttrs() {}
// Key for user data.
constexpr static char key[] = "OpInfo";

View File

@ -325,6 +325,7 @@ constexpr char GATHERV2[] = "Gather";
constexpr char SPARSE_GATHERV2[] = "SparseGatherV2";
constexpr char STRIDEDSLICE[] = "StridedSlice";
constexpr char SLICE[] = "Slice";
constexpr char UNIFORM_REAL[] = "UniformReal";
constexpr char BROADCAST[] = "Broadcast";
constexpr char BROADCAST_TO[] = "BroadcastTo";
constexpr char SQRT[] = "Sqrt";

View File

@ -155,7 +155,8 @@ Status TileInfo::InferMirrorOps() {
return SUCCESS;
}
void TileInfo::UpdateMultiples(const CNodePtr &cnode) {
void TileInfo::UpdateMultiples() {
auto cnode = cnode_;
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() != 3) {
MS_LOG(EXCEPTION) << name_ << ": The size of tile cnode's inputs must be 3";
@ -175,6 +176,8 @@ void TileInfo::UpdateMultiples(const CNodePtr &cnode) {
(void)manager->Replace(cnode->input(2), val);
}
void TileInfo::ReplaceNodeInputOrAttrs() { UpdateMultiples(); }
std::shared_ptr<Strategys> TileInfo::GenerateBatchStrategies() {
if (InferAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";

View File

@ -42,7 +42,8 @@ class TileInfo : public OperatorInfo {
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
void UpdateMultiples(const CNodePtr &cnode);
void UpdateMultiples();
void ReplaceNodeInputOrAttrs() override;
protected:
Status GetAttrs() override;

View File

@ -59,33 +59,11 @@ namespace mindspore {
namespace parallel {
static const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS, LOAD, UPDATESTATE};
static const std::set<std::string> NO_INPUT_TENSOR_OPS = {UNIFORM_REAL};
// g_RefMap: for a CNode B whose input i is a RefKey[Parameter C],
// there will be one item in the map with key C and value (B, i)
static std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
if (new_node_input.empty()) {
return;
}
auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
auto iter = attrs.find(GROUP);
if (iter != attrs.end()) {
auto value = iter->second;
MS_EXCEPTION_IF_NULL(value);
if (value->isa<StringImm>()) {
std::string hash_name = value->cast<StringImmPtr>()->value();
MS_EXCEPTION_IF_NULL(g_device_manager);
std::string rank_list_name = g_device_manager->FindRankListNameByHashName(hash_name);
(void)prim->AddAttr(GROUP_RANKS, MakeValue(rank_list_name));
}
}
}
void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool accu_flag) {
if (new_node_input.empty()) {
return;
@ -282,17 +260,6 @@ static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, cons
return new_node;
}
std::string CreateInstanceName(const CNodePtr &node, size_t index) {
MS_EXCEPTION_IF_NULL(node);
if (!IsValueNode<Primitive>(node->input(0))) {
MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have primitive";
}
std::string name_base = node->fullname_with_scope();
std::string name = name_base + "_" + std::to_string(index);
std::string instance_name = HashInstanceName(name);
return instance_name;
}
void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
// step1:get graph manager distribute_operator
@ -731,7 +698,8 @@ void StepSplitTensor(const AnfNodePtr &node, const FuncGraphManagerPtr &manager)
MS_EXCEPTION_IF_NULL(prim_anf_node);
PrimitivePtr use_cnode_prim = prim_anf_node->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(use_cnode_prim);
if (use_cnode_prim->name() == DEPEND && node_pair.second != 1) {
if ((use_cnode_prim->name() == DEPEND && node_pair.second != 1) ||
NO_INPUT_TENSOR_OPS.find(use_cnode_prim->name()) != NO_INPUT_TENSOR_OPS.end()) {
continue;
}
if (IsParallelCareNode(use_cnode)) {
@ -744,76 +712,6 @@ void StepSplitTensor(const AnfNodePtr &node, const FuncGraphManagerPtr &manager)
}
}
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
const CNodePtr &node) {
OperatorArgs arg_replace_op = replace_op.second;
ValuePtr pyop_instance = CreatOpInstance(arg_replace_op.first, replace_op.first, instance_name);
if (pyop_instance == nullptr) {
MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " CreatOpInstance failed";
}
OperatorParams params = arg_replace_op.second;
if (node->inputs().size() < 2) {
// GetNext operator does not have input
if (node->inputs().size() == 1) {
return {NewValueNode(pyop_instance)};
}
MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
}
std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
if (replace_op.first == EMBEDDING_LOOKUP) {
replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
}
if (!params.empty()) {
Param param_first = *(params.begin());
int64_t first_position = param_first.second;
if (first_position == 1) {
replace_input.pop_back();
}
for (auto &param : params) {
AnfNodePtr val = NewValueNode(param.first.second);
if (val == nullptr) {
MS_LOG(EXCEPTION) << "Failure:val is nullptr";
}
int64_t position = param.second;
(void)replace_input.insert(replace_input.begin() + position, val);
}
} else if (replace_op.first == SYNC_BATCH_NORM) {
for (size_t i = 2; i < node->inputs().size(); ++i) {
replace_input.push_back(node->input(i));
}
}
SetCommunicationOpGroupLabel(replace_input);
return replace_input;
}
void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) {
FuncGraphPtr func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
FuncGraphManagerPtr manager = func_graph->manager();
if (manager == nullptr) {
MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
}
std::string instance_name = CreateInstanceName(node, 0);
std::vector<AnfNodePtr> replace_input;
replace_input = ReplaceOpInput(replace_op, instance_name, node);
if (node->inputs().size() == DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
replace_input.push_back(node->input(3));
}
CNodePtr replace_node = func_graph->NewCNode(replace_input);
MS_EXCEPTION_IF_NULL(replace_node);
ScopePtr scope = node->scope();
MS_EXCEPTION_IF_NULL(scope);
replace_node->set_scope(scope);
replace_node->set_in_forward_flag(true);
replace_input[0]->set_scope(scope);
PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
PrimitivePtr origin_prim = GetValueNode<PrimitivePtr>(node->input(0));
SetUserAttrs(origin_prim->attrs(), prim);
(void)manager->Replace(node, replace_node);
}
void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
// step1:get graph manager distribute_operator
OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
@ -2527,64 +2425,6 @@ void StepReplace(const OperatorInfoPtr &distribute_operator, const CNodePtr &cno
}
}
void HandleDropoutNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(distribute_operator);
MS_EXCEPTION_IF_NULL(cnode);
std::string op_name = distribute_operator->name();
if (op_name.find(DROPOUT_DO_MASK) == std::string::npos) {
return;
}
DropoutDoMaskInfoPtr dropout_do_mask = std::dynamic_pointer_cast<DropoutDoMaskInfo>(distribute_operator);
MS_EXCEPTION_IF_NULL(dropout_do_mask);
std::vector<Operator> replace_op = dropout_do_mask->GetDropoutGenMaskReplaceOp(cnode);
if (replace_op.empty()) {
MS_LOG(DEBUG) << "No need to replace dropout_gen_mask";
return;
}
if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
MS_LOG(EXCEPTION) << "The size of drop out do mask cnode's input is not " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
}
ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast<CNodePtr>());
}
void HandleTileNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() < 3 || !IsValueNode<Primitive>(cnode->input(0))) {
return;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim->name() != TILE) {
return;
}
TileInfoPtr tile = std::dynamic_pointer_cast<TileInfo>(distribute_operator);
MS_EXCEPTION_IF_NULL(tile);
tile->UpdateMultiples(cnode);
}
void HandleConv2dTransposeNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() != 4 || !IsValueNode<Primitive>(cnode->input(0))) {
return;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim->name() != CONV2D_BACK_PROP_INPUT && prim->name() != CONV2D_TRANSPOSE) {
return;
}
Conv2DBackpropInputInfoPtr op_ptr = std::dynamic_pointer_cast<Conv2DBackpropInputInfo>(distribute_operator);
MS_EXCEPTION_IF_NULL(op_ptr);
op_ptr->UpdateOutShape(cnode);
}
void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
HandleDropoutNode(distribute_operator, cnode);
HandleTileNode(distribute_operator, cnode);
HandleConv2dTransposeNode(distribute_operator, cnode);
}
std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) {
// J->CNode->Graph
std::set<FuncGraphPtr> graph_set;
@ -2724,7 +2564,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs);
}
HandleSpecialNode(distribute_operator, cnode);
distribute_operator->ReplaceNodeInputOrAttrs();
} else if (IsValueNode<Tensor>(node) || IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
StepSplitTensor(node, manager);
}
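The per-operator handlers removed above (HandleDropoutNode, HandleTileNode, HandleConv2dTransposeNode and HandleSpecialNode) are now reached through the virtual ReplaceNodeInputOrAttrs hook declared in operator_info.h, so ParallelCommunication no longer needs operator-specific dispatch at this point.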

View File

@ -55,7 +55,6 @@ struct CommInfo {
};
std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name);
std::string CreateInstanceName(const CNodePtr &node, size_t index);
void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node);
void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node,
@ -79,9 +78,6 @@ bool FindCommunicationOp(const std::vector<AnfNodePtr> &all_nodes);
void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node,
const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node);
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
const CNodePtr &node);
void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node);
void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node);

View File

@ -92,7 +92,10 @@ Shapes GetValueListShape(const AnfNodePtr &node) {
}
for (auto &ele : inputs_seq) {
auto tensor = ele->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
if (tensor == nullptr) {
MS_LOG(WARNING) << "The value node is not a tensor";
break;
}
auto one_shape = tensor->shape();
shapes.push_back(one_shape);
}
@ -145,5 +148,83 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
}
return shapes;
}
std::string CreateInstanceName(const CNodePtr &node, size_t index) {
MS_EXCEPTION_IF_NULL(node);
if (!IsValueNode<Primitive>(node->input(0))) {
MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have primitive";
}
std::string name_base = node->fullname_with_scope();
std::string name = name_base + "_" + std::to_string(index);
std::string instance_name = HashInstanceName(name);
return instance_name;
}
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
if (new_node_input.empty()) {
return;
}
auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
auto iter = attrs.find(GROUP);
if (iter != attrs.end()) {
auto value = iter->second;
MS_EXCEPTION_IF_NULL(value);
if (value->isa<StringImm>()) {
std::string hash_name = value->cast<StringImmPtr>()->value();
MS_EXCEPTION_IF_NULL(g_device_manager);
std::string rank_list_name = g_device_manager->FindRankListNameByHashName(hash_name);
(void)prim->AddAttr(GROUP_RANKS, MakeValue(rank_list_name));
}
}
}
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
const CNodePtr &node) {
OperatorArgs arg_replace_op = replace_op.second;
ValuePtr pyop_instance = CreatOpInstance(arg_replace_op.first, replace_op.first, instance_name);
if (pyop_instance == nullptr) {
MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " CreatOpInstance failed";
}
OperatorParams params = arg_replace_op.second;
if (node->inputs().size() < 2) {
// GetNext operator does not have input
if (node->inputs().size() == 1) {
return {NewValueNode(pyop_instance)};
}
MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
}
std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
if (replace_op.first == EMBEDDING_LOOKUP) {
replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
}
if (!params.empty()) {
Param param_first = *(params.begin());
int64_t first_position = param_first.second;
if (first_position == 1) {
replace_input.pop_back();
}
for (auto &param : params) {
AnfNodePtr val = NewValueNode(param.first.second);
if (val == nullptr) {
MS_LOG(EXCEPTION) << "Failure:val is nullptr";
}
int64_t position = param.second;
(void)replace_input.insert(replace_input.begin() + position, val);
}
} else if (replace_op.first == SYNC_BATCH_NORM) {
for (size_t i = 2; i < node->inputs().size(); ++i) {
replace_input.push_back(node->input(i));
}
}
SetCommunicationOpGroupLabel(replace_input);
return replace_input;
}
} // namespace parallel
} // namespace mindspore

View File

@ -28,6 +28,10 @@ namespace parallel {
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
bool IsParallelCareNode(const CNodePtr &cnode);
Shapes GetNodeShape(const AnfNodePtr &node);
std::string CreateInstanceName(const CNodePtr &node, size_t index);
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
const CNodePtr &node);
} // namespace parallel
} // namespace mindspore

View File

@ -0,0 +1,59 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P
class Net(Cell):
    def __init__(self, mul_weight, strategy1=None, strategy2=None):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.neg = P.Neg().shard(strategy2)
        self.mul_weight = Parameter(mul_weight, "w1")
        self.uniform_real = P.UniformReal()

    def construct(self, x, b):
        out = self.mul(x, self.mul_weight)
        out = self.neg(out)
        z = self.uniform_real((128, 64, 32))
        out = out + z
        return out


_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_w1 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)


def compile_net(net):
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_auto_parallel()
    train_net.set_train()
    _cell_graph_executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()


def test_batch_parallel_replace_shape():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((16, 1, 1), (16, 1, 1))
    strategy2 = ((16, 1, 1),)
    net = Net(_w1, strategy1, strategy2)
    compile_net(net)
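If compilation succeeds with this configuration, the shape argument of UniformReal is expected to be rewritten from (128, 64, 32) to (8, 64, 32) on each rank (128 / 16 = 8), matching the replace_shape_ computed in batch_parallel_info.cc.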