diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/range_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/range_info.cc index 88b61dd790..5365a0a9bd 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/range_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/range_info.cc @@ -134,59 +134,6 @@ Status RangeInfo::InitForCostModel(const StrategyPtr &strategy) { return SUCCESS; } -Status RangeInfo::InferNewAttr() { - CheckGlobalDeviceManager(); - int64_t rank = g_device_manager->rank_index_in_stage(); - - // If repeated calculation and repeated num as the last dimension of dev-matrix, - // the dev-matrix is [split_num_, repeated_calc_num_], so from rank 0 to rank repeated_calc_num_ - // are repeated calculation, and these rank have the same 'new_start_'. - // If repeated calculation and repeated num as the first dimension of dev-matrix, - // the dev-matrix is [repeated_calc_num_, split_num_], so rank 0 and rank split_num_ and so on - // are repeated calculation, and these rank have the same 'new_start_'. - float start_bias = inputs_shape_[0][0] / split_num_ * delta_; - if (repeated_num_in_dev_matrix_right_) { - new_start_ = start_ + start_bias * (rank / repeated_calc_num_); - } else { - new_start_ = start_ + start_bias * (rank % split_num_); - } - - new_limit_ = new_start_ + start_bias; - MS_LOG(INFO) << name_ << ": The new start is " << new_start_ << ", the new limit is " << new_limit_; - return SUCCESS; -} - -Status RangeInfo::ComputeReplaceGraph(const CNodePtr &cnode) { - GenerateGraph gen_g = GenerateGraph(); - if (gen_g.Init(cnode) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GenerateGraph Init failed"; - return FAILED; - } - - if (InferNewAttr() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer new attr failed"; - return FAILED; - } - - Attr attr_start = std::make_pair(START, MakeValue(new_start_)); - Attr attr_limit = std::make_pair(LIMIT, MakeValue(new_limit_)); - Attr attr_delta = std::make_pair(DELTA, MakeValue(delta_)); - OperatorAttrs attrs = {attr_start, attr_limit, attr_delta}; - auto new_range_op = gen_g.PushBack({gen_g.NewOpInst(RANGE, attrs), gen_g.virtual_input_node()}); - std::vector> input_nodes = {std::make_pair(new_range_op, 1)}; - replace_graph_ = std::make_shared>, AnfNodePtr>>( - std::make_pair(input_nodes, new_range_op)); - - return SUCCESS; -} - -ReplaceGraphPtr RangeInfo::replace_graph(const CNodePtr &cnode) { - if (ComputeReplaceGraph(cnode) != SUCCESS) { - MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed."; - } - return replace_graph_; -} - Status RangeInfo::GenerateStrategies(int64_t stage_id) { Shape input0_split(inputs_shape_[0].size(), 1); Shapes splittable_inputs = {input0_split}; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/range_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/range_info.h index 38b2bad89a..e3aa63dc17 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/range_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/range_info.h @@ -47,7 +47,6 @@ class RangeInfo : public OperatorInfo { Status GenerateStrategies(int64_t stage_id) override; Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; protected: Status CheckStrategy(const StrategyPtr &strategy) override; @@ -57,9 +56,7 @@ class RangeInfo : public OperatorInfo { Status InferDevMatrixShape() override; Status InferTensorMap() override; Status GetAttrs() override; - Status InferNewAttr(); float GetRangeAttr(const std::string &arg); - Status ComputeReplaceGraph(const CNodePtr &cnode); float start_ = 0.0; float limit_ = 0.0;