forked from mindspore-Ecosystem/mindspore
!23331 batch parallel operator replace input shape from construct
Merge pull request !23331 from yangzhenzhang/modify-batch-parallel-info
commit 2b4cdea48a
@@ -54,20 +54,39 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
 Status BatchParallelInfo::InferDevMatrixShape() {
   dev_matrix_shape_.push_back(stage_device_size_);
+
+  if (need_replace_input_ && !inputs_shape_.empty()) {
+    replace_shape_ = inputs_shape_[0];
+    if (!replace_shape_.empty()) {
+      replace_shape_[0] /= stage_device_size_;
+    }
+  }
+
   return SUCCESS;
 }
 
 Status BatchParallelInfo::InferForwardCommunication() { return SUCCESS; }
 
 Status BatchParallelInfo::InferTensorMap() {
-  if (strategy_->GetInputDim()[0][0] != stage_device_size_) {
-    MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy.";
+  auto strategy = strategy_->GetInputDim();
+  if (strategy.empty()) {
+    MS_LOG(INFO) << name_ << ": the strategy is empty";
+    return SUCCESS;
+  }
+
+  if (strategy[0].empty()) {
+    MS_LOG(INFO) << name_ << ": the first element of strategy is empty";
+    return FAILED;
+  }
+
+  if (strategy[0][0] != stage_device_size_) {
+    MS_LOG(ERROR) << name_ << ": It is not a valid data parallel strategy.";
     return FAILED;
   }
   for (size_t i = 0; i < inputs_shape_.size(); i++) {
     Shape tensor_map_index;
     for (size_t j = 0; j < inputs_shape_[i].size(); ++j) {
-      if (strategy_->GetInputDim()[i][j] == stage_device_size_ && j == 0) {
+      if (strategy[i][j] == stage_device_size_ && j == 0) {
         tensor_map_index.push_back(0);
       } else {
         tensor_map_index.push_back(MAP_NONE);
@@ -89,25 +108,23 @@ Status BatchParallelInfo::InferTensorMap() {
   return SUCCESS;
 }
 
-Strategys BatchParallelInfo::GetOutputsStrategy() {
-  Strategys outputs_strategy;
-
-  for (size_t i = 0; i < outputs_shape_.size(); ++i) {
-    Dimensions strategy;
-    for (size_t j = 0; j < outputs_shape_[i].size(); ++j) {
-      if (i == 0 && j == 0) {
-        strategy.push_back(stage_device_size_);
-      } else {
-        strategy.push_back(1);
-      }
-    }
-    outputs_strategy.push_back(strategy);
-  }
-
-  return outputs_strategy;
-}
-
-Status BatchParallelInfo::GetAttrs() { return SUCCESS; }
+Status BatchParallelInfo::GetAttrs() {
+  // if the operator's input is a shape (not a tensor), assign the shape value to inputs_shape_
+  if (!inputs_shape_.empty()) {
+    return SUCCESS;
+  }
+
+  if (input_value_.empty()) {
+    return SUCCESS;
+  }
+
+  auto shape_ptr = input_value_[0]->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(shape_ptr);
+
+  inputs_shape_.push_back(GetValue<Shape>(shape_ptr));
+  need_replace_input_ = true;
+  return SUCCESS;
+}
 
 Status BatchParallelInfo::Init(const StrategyPtr &strategy) {
   if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
@@ -158,5 +175,30 @@ Status BatchParallelInfo::InferAsLossDivisor() {
   as_loss_divisor_ = 1;
   return SUCCESS;
 }
+
+void BatchParallelInfo::ReplaceNodeInputOrAttrs() {
+  if (!need_replace_input_) {
+    return;
+  }
+
+  auto cnode = cnode_;
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (cnode->size() != 2) {
+    MS_LOG(EXCEPTION) << name_ << ": The size of the cnode's inputs must be 2";
+  }
+
+  if (!IsValueNode<ValueTuple>(cnode->input(1))) {
+    MS_LOG(EXCEPTION) << name_ << ": The input[1] of the cnode is not a ValueTuple.";
+  }
+
+  auto func_graph = cnode->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+
+  ValuePtr replace_shape = MakeValue(replace_shape_);
+  AnfNodePtr val = NewValueNode(replace_shape);
+  (void)manager->Replace(cnode->input(1), val);
+}
 } // namespace parallel
 } // namespace mindspore
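
Taken together, the new GetAttrs, InferDevMatrixShape, and ReplaceNodeInputOrAttrs adopt a shape-valued input (such as the size tuple of UniformReal), divide its batch dimension by the stage device count, and write the sliced tuple back into the graph. A standalone sketch of the arithmetic, on plain vectors with illustrative names rather than MindSpore types:

#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;  // stand-in for the Shape type used above
constexpr int64_t kMapNone = -1;     // stand-in for MAP_NONE

// Mirrors replace_shape_[0] /= stage_device_size_: only the batch
// dimension is sliced across the devices of the stage.
Shape SliceBatchDim(Shape shape, int64_t stage_device_size) {
  if (!shape.empty()) {
    shape[0] /= stage_device_size;
  }
  return shape;
}

// Mirrors the tensor map InferTensorMap builds for a data-parallel
// strategy: dimension 0 maps to the single device-matrix dimension,
// every other dimension stays unsplit.
Shape BatchParallelTensorMap(size_t rank) {
  Shape tensor_map;
  for (size_t j = 0; j < rank; ++j) {
    tensor_map.push_back(j == 0 ? 0 : kMapNone);
  }
  return tensor_map;
}

int main() {
  for (int64_t d : SliceBatchDim({128, 64, 32}, 16)) {
    std::cout << d << ' ';  // prints: 8 64 32
  }
  std::cout << '\n';
  for (int64_t d : BatchParallelTensorMap(3)) {
    std::cout << d << ' ';  // prints: 0 -1 -1
  }
  std::cout << '\n';
  return 0;
}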
@@ -41,6 +41,7 @@ class BatchParallelInfo : public OperatorInfo {
   Status InitForCostModel(const StrategyPtr &strategy) override;
   std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
+  void ReplaceNodeInputOrAttrs() override;
 
  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
@@ -48,11 +49,12 @@ class BatchParallelInfo : public OperatorInfo {
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
   Status GetAttrs() override;
-  Strategys GetOutputsStrategy();
   Status InferAsLossDivisor() override;
 
  private:
-  int64_t dev_num_;
+  int64_t dev_num_ = 1;
+  bool need_replace_input_ = false;
+  Shape replace_shape_;
 };
 
 class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
@@ -17,6 +17,7 @@
 #include "frontend/parallel/ops_info/conv2d_info.h"
 
 #include <algorithm>
+#include <functional>
 #include <cmath>
 #include <memory>
 #include <utility>
@@ -1041,7 +1042,8 @@ Status Conv2DBackpropInputInfo::InferMirrorOps() {
   return SUCCESS;
 }
 
-void Conv2DBackpropInputInfo::UpdateOutShape(const CNodePtr &cnode) {
+void Conv2DBackpropInputInfo::UpdateOutShape() {
+  auto cnode = cnode_;
   MS_EXCEPTION_IF_NULL(cnode);
   if (cnode->size() != 4) {
     MS_LOG(EXCEPTION) << name_ << ": The size of cnode's inputs must be 4, but got " << cnode->size();
@@ -1191,5 +1193,7 @@ void Conv2DBackpropInputInfo::InferNewPadList() {
   MS_LOG(INFO) << name_ << ": the new pad list is " << new_pad_list_ << ", the required size of current rank is "
                << current_rank_required_size << ", new pad all is " << pad_all;
 }
+
+void Conv2DBackpropInputInfo::ReplaceNodeInputOrAttrs() { UpdateOutShape(); }
 } // namespace parallel
 } // namespace mindspore
@@ -127,7 +127,8 @@ class Conv2DBackpropInputInfo : public Conv2DInfo {
                           const PrimitiveAttrs &attrs)
       : Conv2DInfo(name, inputs_shape, outputs_shape, attrs) {}
   ~Conv2DBackpropInputInfo() override = default;
-  void UpdateOutShape(const CNodePtr &cnode);
+  void UpdateOutShape();
+  void ReplaceNodeInputOrAttrs() override;
 
  protected:
   Status GetAttrs() override;
@@ -24,6 +24,8 @@
 #include "ir/value.h"
 #include "pipeline/jit/resource.h"
 #include "frontend/parallel/auto_parallel/costmodel.h"
+#include "frontend/parallel/graph_util/node_info.h"
+#include "frontend/parallel/step_parallel_utils.h"
 #include "frontend/parallel/device_matrix.h"
 #include "frontend/parallel/strategy.h"
 
@@ -213,7 +215,8 @@ void SetGenMaskShape(const CNodePtr &cnode, const Shape &input_slice_shape) {
 // split. Find the DropoutGenMask node in the anf graph according to DropoutDoMask node, and modify the input shape
 // of DropoutGenMask according to the strategy of DropoutDoMask. When the DropoutDoMask performs repeated calculation
 // and both seeds of DropoutGenMask are 0, two new seeds are automatically generated for DropoutGenMask.
-std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
+std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp() {
+  auto cnode = cnode_;
   std::vector<Operator> replace_ops;
   MS_EXCEPTION_IF_NULL(cnode);
   PrimitivePtr prim = GetDropoutGenMaskPrim(cnode);
@@ -262,5 +265,46 @@ std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodeP
   replace_ops.push_back(replace_op);
   return replace_ops;
 }
+
+static void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) {
+  FuncGraphPtr func_graph = node->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  FuncGraphManagerPtr manager = func_graph->manager();
+  if (manager == nullptr) {
+    MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
+  }
+  std::string instance_name = CreateInstanceName(node, 0);
+  std::vector<AnfNodePtr> replace_input;
+  replace_input = ReplaceOpInput(replace_op, instance_name, node);
+  if (node->inputs().size() == DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
+    replace_input.push_back(node->input(3));
+  }
+  CNodePtr replace_node = func_graph->NewCNode(replace_input);
+  MS_EXCEPTION_IF_NULL(replace_node);
+  ScopePtr scope = node->scope();
+  MS_EXCEPTION_IF_NULL(scope);
+  replace_node->set_scope(scope);
+  replace_node->set_in_forward_flag(true);
+  replace_input[0]->set_scope(scope);
+  PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
+  PrimitivePtr origin_prim = GetValueNode<PrimitivePtr>(node->input(0));
+  SetUserAttrs(origin_prim->attrs(), prim);
+  (void)manager->Replace(node, replace_node);
+}
+
+void DropoutDoMaskInfo::ReplaceNodeInputOrAttrs() {
+  auto cnode = cnode_;
+  MS_EXCEPTION_IF_NULL(cnode);
+  std::vector<Operator> replace_op = GetDropoutGenMaskReplaceOp();
+  if (replace_op.empty()) {
+    MS_LOG(DEBUG) << name_ << ": No need to replace dropout_gen_mask";
+    return;
+  }
+  if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
+    MS_LOG(EXCEPTION) << name_ << ": The size of drop out do mask cnode's input is not "
+                      << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
+  }
+  ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast<CNodePtr>());
+}
 } // namespace parallel
 } // namespace mindspore
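
Note that ReplaceOneOp becomes a file-local static copy here, while the original in step_parallel.cc is deleted further down; only CreateInstanceName and ReplaceOpInput move into step_parallel_utils as shared helpers.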
@@ -41,7 +41,8 @@ class DropoutDoMaskInfo : public OperatorInfo {
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
   Status InitForCostModel(const StrategyPtr &strategy) override;
   std::shared_ptr<Strategys> GenerateBatchStrategies() override;
-  std::vector<Operator> GetDropoutGenMaskReplaceOp(const CNodePtr &cnode);
+  std::vector<Operator> GetDropoutGenMaskReplaceOp();
+  void ReplaceNodeInputOrAttrs() override;
 
  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
@@ -1029,6 +1029,9 @@ Status PrepareStrategyBase(int64_t stage_id, size_t dev_num, const Shapes &input
 }
 
 std::shared_ptr<Strategys> OperatorInfo::GenerateBatchStrategies() {
+  if (inputs_shape_.empty() && InferAttrs() != SUCCESS) {
+    MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
+  }
   ComputeBatchSplitFlagList();
   return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
 }
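
The new guard appears to target shape-input operators such as UniformReal: their inputs_shape_ is only filled in by GetAttrs, so InferAttrs has to run before the batch split flags can be computed from the input shapes.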
@@ -182,6 +182,7 @@ class OperatorInfo {
   int32_t stage_id() const { return stage_id_; }
   Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);
   Status CreateGroupForOptShard(TensorLayout *const tensor_layout, std::vector<Group> *group);
+  virtual void ReplaceNodeInputOrAttrs() {}
 
   // Key for user data.
   constexpr static char key[] = "OpInfo";
@@ -325,6 +325,7 @@ constexpr char GATHERV2[] = "Gather";
 constexpr char SPARSE_GATHERV2[] = "SparseGatherV2";
 constexpr char STRIDEDSLICE[] = "StridedSlice";
 constexpr char SLICE[] = "Slice";
+constexpr char UNIFORM_REAL[] = "UniformReal";
 constexpr char BROADCAST[] = "Broadcast";
 constexpr char BROADCAST_TO[] = "BroadcastTo";
 constexpr char SQRT[] = "Sqrt";
@@ -155,7 +155,8 @@ Status TileInfo::InferMirrorOps() {
   return SUCCESS;
 }
 
-void TileInfo::UpdateMultiples(const CNodePtr &cnode) {
+void TileInfo::UpdateMultiples() {
+  auto cnode = cnode_;
   MS_EXCEPTION_IF_NULL(cnode);
   if (cnode->size() != 3) {
     MS_LOG(EXCEPTION) << name_ << ": The size of tile cnode's inputs must be 3";
@@ -175,6 +176,8 @@ void TileInfo::UpdateMultiples(const CNodePtr &cnode) {
   (void)manager->Replace(cnode->input(2), val);
 }
 
+void TileInfo::ReplaceNodeInputOrAttrs() { UpdateMultiples(); }
+
 std::shared_ptr<Strategys> TileInfo::GenerateBatchStrategies() {
   if (InferAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
@@ -42,7 +42,8 @@ class TileInfo : public OperatorInfo {
   std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
   Status SetCostUnderStrategy(const StrategyPtr &) override;
   std::shared_ptr<Strategys> GenerateBatchStrategies() override;
-  void UpdateMultiples(const CNodePtr &cnode);
+  void UpdateMultiples();
+  void ReplaceNodeInputOrAttrs() override;
 
  protected:
   Status GetAttrs() override;
@@ -59,33 +59,11 @@ namespace mindspore {
 namespace parallel {
 static const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
 static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS, LOAD, UPDATESTATE};
+static const std::set<std::string> NO_INPUT_TENSOR_OPS = {UNIFORM_REAL};
 // g_RefMap, for CNode B input i is a RefKey[Parameter C],
 // it will be one item in map with key: C, and value: (B, i)
 static std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
 
-void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
-  if (new_node_input.empty()) {
-    return;
-  }
-
-  auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
-  auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
-  MS_EXCEPTION_IF_NULL(prim);
-
-  auto attrs = prim->attrs();
-  auto iter = attrs.find(GROUP);
-  if (iter != attrs.end()) {
-    auto value = iter->second;
-    MS_EXCEPTION_IF_NULL(value);
-    if (value->isa<StringImm>()) {
-      std::string hash_name = value->cast<StringImmPtr>()->value();
-      MS_EXCEPTION_IF_NULL(g_device_manager);
-      std::string rank_list_name = g_device_manager->FindRankListNameByHashName(hash_name);
-      (void)prim->AddAttr(GROUP_RANKS, MakeValue(rank_list_name));
-    }
-  }
-}
-
 void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool accu_flag) {
   if (new_node_input.empty()) {
     return;
@@ -282,17 +260,6 @@ static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, cons
   return new_node;
 }
 
-std::string CreateInstanceName(const CNodePtr &node, size_t index) {
-  MS_EXCEPTION_IF_NULL(node);
-  if (!IsValueNode<Primitive>(node->input(0))) {
-    MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have primitive";
-  }
-  std::string name_base = node->fullname_with_scope();
-  std::string name = name_base + "_" + std::to_string(index);
-  std::string instance_name = HashInstanceName(name);
-  return instance_name;
-}
-
 void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   // step1:get graph manager distribute_operator
@@ -731,7 +698,8 @@ void StepSplitTensor(const AnfNodePtr &node, const FuncGraphManagerPtr &manager)
   MS_EXCEPTION_IF_NULL(prim_anf_node);
   PrimitivePtr use_cnode_prim = prim_anf_node->value()->cast<PrimitivePtr>();
   MS_EXCEPTION_IF_NULL(use_cnode_prim);
-  if (use_cnode_prim->name() == DEPEND && node_pair.second != 1) {
+  if ((use_cnode_prim->name() == DEPEND && node_pair.second != 1) ||
+      NO_INPUT_TENSOR_OPS.find(use_cnode_prim->name()) != NO_INPUT_TENSOR_OPS.end()) {
     continue;
   }
   if (IsParallelCareNode(use_cnode)) {
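
The extra condition skips users of NO_INPUT_TENSOR_OPS when splitting constant tensors: UniformReal's only input is a shape tuple, which presumably must stay whole here and be rewritten later by BatchParallelInfo::ReplaceNodeInputOrAttrs rather than be sliced as a value tensor.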
@@ -744,76 +712,6 @@ void StepSplitTensor(const AnfNodePtr &node, const FuncGraphManagerPtr &manager)
   }
 }
 
-std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
-                                       const CNodePtr &node) {
-  OperatorArgs arg_replace_op = replace_op.second;
-  ValuePtr pyop_instance = CreatOpInstance(arg_replace_op.first, replace_op.first, instance_name);
-  if (pyop_instance == nullptr) {
-    MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " CreatOpInstance failed";
-  }
-  OperatorParams params = arg_replace_op.second;
-  if (node->inputs().size() < 2) {
-    // the GetNext operator does not have an input
-    if (node->inputs().size() == 1) {
-      return {NewValueNode(pyop_instance)};
-    }
-    MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
-  }
-  std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
-
-  if (replace_op.first == EMBEDDING_LOOKUP) {
-    replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
-  }
-
-  if (!params.empty()) {
-    Param param_first = *(params.begin());
-    int64_t first_position = param_first.second;
-    if (first_position == 1) {
-      replace_input.pop_back();
-    }
-    for (auto &param : params) {
-      AnfNodePtr val = NewValueNode(param.first.second);
-      if (val == nullptr) {
-        MS_LOG(EXCEPTION) << "Failure:val is nullptr";
-      }
-      int64_t position = param.second;
-      (void)replace_input.insert(replace_input.begin() + position, val);
-    }
-  } else if (replace_op.first == SYNC_BATCH_NORM) {
-    for (size_t i = 2; i < node->inputs().size(); ++i) {
-      replace_input.push_back(node->input(i));
-    }
-  }
-  SetCommunicationOpGroupLabel(replace_input);
-  return replace_input;
-}
-
-void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) {
-  FuncGraphPtr func_graph = node->func_graph();
-  MS_EXCEPTION_IF_NULL(func_graph);
-  FuncGraphManagerPtr manager = func_graph->manager();
-  if (manager == nullptr) {
-    MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
-  }
-  std::string instance_name = CreateInstanceName(node, 0);
-  std::vector<AnfNodePtr> replace_input;
-  replace_input = ReplaceOpInput(replace_op, instance_name, node);
-  if (node->inputs().size() == DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
-    replace_input.push_back(node->input(3));
-  }
-  CNodePtr replace_node = func_graph->NewCNode(replace_input);
-  MS_EXCEPTION_IF_NULL(replace_node);
-  ScopePtr scope = node->scope();
-  MS_EXCEPTION_IF_NULL(scope);
-  replace_node->set_scope(scope);
-  replace_node->set_in_forward_flag(true);
-  replace_input[0]->set_scope(scope);
-  PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
-  PrimitivePtr origin_prim = GetValueNode<PrimitivePtr>(node->input(0));
-  SetUserAttrs(origin_prim->attrs(), prim);
-  (void)manager->Replace(node, replace_node);
-}
-
 void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
   // step1:get graph manager distribute_operator
   OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
@@ -2527,64 +2425,6 @@ void StepReplace(const OperatorInfoPtr &distribute_operator, const CNodePtr &cno
   }
 }
 
-void HandleDropoutNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
-  MS_EXCEPTION_IF_NULL(distribute_operator);
-  MS_EXCEPTION_IF_NULL(cnode);
-
-  std::string op_name = distribute_operator->name();
-  if (op_name.find(DROPOUT_DO_MASK) == std::string::npos) {
-    return;
-  }
-
-  DropoutDoMaskInfoPtr dropout_do_mask = std::dynamic_pointer_cast<DropoutDoMaskInfo>(distribute_operator);
-  MS_EXCEPTION_IF_NULL(dropout_do_mask);
-  std::vector<Operator> replace_op = dropout_do_mask->GetDropoutGenMaskReplaceOp(cnode);
-  if (replace_op.empty()) {
-    MS_LOG(DEBUG) << "No need to replace dropout_gen_mask";
-    return;
-  }
-  if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
-    MS_LOG(EXCEPTION) << "The size of drop out do mask cnode's input is not " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
-  }
-  ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast<CNodePtr>());
-}
-
-void HandleTileNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  if (cnode->size() < 3 || !IsValueNode<Primitive>(cnode->input(0))) {
-    return;
-  }
-  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
-  if (prim->name() != TILE) {
-    return;
-  }
-
-  TileInfoPtr tile = std::dynamic_pointer_cast<TileInfo>(distribute_operator);
-  MS_EXCEPTION_IF_NULL(tile);
-  tile->UpdateMultiples(cnode);
-}
-
-void HandleConv2dTransposeNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  if (cnode->size() != 4 || !IsValueNode<Primitive>(cnode->input(0))) {
-    return;
-  }
-  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
-  if (prim->name() != CONV2D_BACK_PROP_INPUT && prim->name() != CONV2D_TRANSPOSE) {
-    return;
-  }
-
-  Conv2DBackpropInputInfoPtr op_ptr = std::dynamic_pointer_cast<Conv2DBackpropInputInfo>(distribute_operator);
-  MS_EXCEPTION_IF_NULL(op_ptr);
-  op_ptr->UpdateOutShape(cnode);
-}
-
-void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
-  HandleDropoutNode(distribute_operator, cnode);
-  HandleTileNode(distribute_operator, cnode);
-  HandleConv2dTransposeNode(distribute_operator, cnode);
-}
-
 std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) {
   // J->CNode->Graph
   std::set<FuncGraphPtr> graph_set;
@@ -2724,7 +2564,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
         BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs);
       }
 
-      HandleSpecialNode(distribute_operator, cnode);
+      distribute_operator->ReplaceNodeInputOrAttrs();
     } else if (IsValueNode<Tensor>(node) || IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
       StepSplitTensor(node, manager);
     }
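
This is the payoff of the new virtual hook: the hardcoded HandleSpecialNode dispatch for dropout, tile, and Conv2DBackpropInput (removed above) collapses into a single polymorphic call, and any operator that needs to rewrite its inputs or attributes overrides ReplaceNodeInputOrAttrs instead.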
@@ -55,7 +55,6 @@ struct CommInfo {
 };
 
 std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name);
-std::string CreateInstanceName(const CNodePtr &node, size_t index);
 void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node);
 
 void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node,
@@ -79,9 +78,6 @@ bool FindCommunicationOp(const std::vector<AnfNodePtr> &all_nodes);
 void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node,
                         const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node);
 
-std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
-                                       const CNodePtr &node);
-
 void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node);
 
 void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node);
@@ -92,7 +92,10 @@ Shapes GetValueListShape(const AnfNodePtr &node) {
   }
   for (auto &ele : inputs_seq) {
     auto tensor = ele->cast<tensor::TensorPtr>();
-    MS_EXCEPTION_IF_NULL(tensor);
+    if (tensor == nullptr) {
+      MS_LOG(WARNING) << "The value node is not a tensor";
+      break;
+    }
     auto one_shape = tensor->shape();
     shapes.push_back(one_shape);
   }
@@ -145,5 +148,83 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
   }
   return shapes;
 }
+
+std::string CreateInstanceName(const CNodePtr &node, size_t index) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!IsValueNode<Primitive>(node->input(0))) {
+    MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have primitive";
+  }
+  std::string name_base = node->fullname_with_scope();
+  std::string name = name_base + "_" + std::to_string(index);
+  std::string instance_name = HashInstanceName(name);
+  return instance_name;
+}
+
+void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
+  if (new_node_input.empty()) {
+    return;
+  }
+
+  auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
+  auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+  MS_EXCEPTION_IF_NULL(prim);
+
+  auto attrs = prim->attrs();
+  auto iter = attrs.find(GROUP);
+  if (iter != attrs.end()) {
+    auto value = iter->second;
+    MS_EXCEPTION_IF_NULL(value);
+    if (value->isa<StringImm>()) {
+      std::string hash_name = value->cast<StringImmPtr>()->value();
+      MS_EXCEPTION_IF_NULL(g_device_manager);
+      std::string rank_list_name = g_device_manager->FindRankListNameByHashName(hash_name);
+      (void)prim->AddAttr(GROUP_RANKS, MakeValue(rank_list_name));
+    }
+  }
+}
+
+std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
+                                       const CNodePtr &node) {
+  OperatorArgs arg_replace_op = replace_op.second;
+  ValuePtr pyop_instance = CreatOpInstance(arg_replace_op.first, replace_op.first, instance_name);
+  if (pyop_instance == nullptr) {
+    MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " CreatOpInstance failed";
+  }
+  OperatorParams params = arg_replace_op.second;
+  if (node->inputs().size() < 2) {
+    // the GetNext operator does not have an input
+    if (node->inputs().size() == 1) {
+      return {NewValueNode(pyop_instance)};
+    }
+    MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
+  }
+  std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
+
+  if (replace_op.first == EMBEDDING_LOOKUP) {
+    replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
+  }
+
+  if (!params.empty()) {
+    Param param_first = *(params.begin());
+    int64_t first_position = param_first.second;
+    if (first_position == 1) {
+      replace_input.pop_back();
+    }
+    for (auto &param : params) {
+      AnfNodePtr val = NewValueNode(param.first.second);
+      if (val == nullptr) {
+        MS_LOG(EXCEPTION) << "Failure:val is nullptr";
+      }
+      int64_t position = param.second;
+      (void)replace_input.insert(replace_input.begin() + position, val);
+    }
+  } else if (replace_op.first == SYNC_BATCH_NORM) {
+    for (size_t i = 2; i < node->inputs().size(); ++i) {
+      replace_input.push_back(node->input(i));
+    }
+  }
+  SetCommunicationOpGroupLabel(replace_input);
+  return replace_input;
+}
 } // namespace parallel
 } // namespace mindspore
@@ -28,6 +28,10 @@ namespace parallel {
 bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
 bool IsParallelCareNode(const CNodePtr &cnode);
 Shapes GetNodeShape(const AnfNodePtr &node);
+std::string CreateInstanceName(const CNodePtr &node, size_t index);
+void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
+std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
+                                       const CNodePtr &node);
 } // namespace parallel
 } // namespace mindspore
@@ -0,0 +1,59 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+import mindspore as ms
+from mindspore import context, Tensor, Parameter
+from mindspore.common.api import _cell_graph_executor
+from mindspore.nn import Cell, TrainOneStepCell, Momentum
+from mindspore.ops import operations as P
+
+
+class Net(Cell):
+    def __init__(self, mul_weight, strategy1=None, strategy2=None):
+        super().__init__()
+        self.mul = P.Mul().shard(strategy1)
+        self.neg = P.Neg().shard(strategy2)
+        self.mul_weight = Parameter(mul_weight, "w1")
+        self.uniform_real = P.UniformReal()
+
+    def construct(self, x, b):
+        out = self.mul(x, self.mul_weight)
+        out = self.neg(out)
+        z = self.uniform_real((128, 64, 32))
+        out = out + z
+        return out
+
+
+_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_w1 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+
+
+def compile_net(net):
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
+    train_net.set_train()
+    _cell_graph_executor.compile(train_net, _x, _b)
+    context.reset_auto_parallel_context()
+
+
+def test_batch_parallel_replace_shape():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((16, 1, 1), (16, 1, 1))
+    strategy2 = ((16, 1, 1),)
+    net = Net(_w1, strategy1, strategy2)
+    compile_net(net)
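
With device_num=16 and (16, 1, 1) strategies, the batch dimension is split 16 ways, so the (128, 64, 32) size tuple passed to UniformReal in construct should be rewritten to (8, 64, 32) during compilation; the test exercises exactly the ReplaceNodeInputOrAttrs path added in this commit.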