!8868 Fix Code Format

From: @huangxinjing
Reviewed-by: @lichen666,@yangzhenzhang
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-11-27 09:55:48 +08:00 committed by Gitee
commit 41be7288a4
8 changed files with 8 additions and 20 deletions

View File

@ -913,14 +913,13 @@ double GatherV2PCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs,
double UniformCandidateSamplerCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, double UniformCandidateSamplerCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const { int64_t stage_id) const {
double result = 0.0;
Shape input0_slice_shape = inputs[0].slice_shape(); Shape input0_slice_shape = inputs[0].slice_shape();
if (inputs_type_lengths_.size() != inputs.size()) { if (inputs_type_lengths_.size() != inputs.size()) {
MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
<< " for UniformCandidateSampler cost"; << " for UniformCandidateSampler cost";
} }
result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]); double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
return result; return result;
} }

View File

@ -279,6 +279,5 @@ Status SliceInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
return SUCCESS; return SUCCESS;
} }
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore

View File

@ -30,7 +30,6 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
Status UniformCandidateSamplerInfo::GetUniformSamplerAttrInt64(const std::string &args, int64_t *value) { Status UniformCandidateSamplerInfo::GetUniformSamplerAttrInt64(const std::string &args, int64_t *value) {
auto iter = attrs_.find(args); auto iter = attrs_.find(args);
if (iter == attrs_.end()) { if (iter == attrs_.end()) {
@ -276,7 +275,6 @@ Status UniformCandidateSamplerInfo::InitForCostModel(const StrategyPtr &strategy
ReplaceGraphPtr UniformCandidateSamplerInfo::replace_graph(const CNodePtr &cnode) { ReplaceGraphPtr UniformCandidateSamplerInfo::replace_graph(const CNodePtr &cnode) {
auto input_strategy = strategy_->GetInputDim().at(0); auto input_strategy = strategy_->GetInputDim().at(0);
// Only when the axis-1 is sharded, we need to modify the attribute // Only when the axis-1 is sharded, we need to modify the attribute
if (input_strategy.size() == 2 && input_strategy[1] > 1) { if (input_strategy.size() == 2 && input_strategy[1] > 1) {
if (ComputeReplaceGraph(cnode) != SUCCESS) { if (ComputeReplaceGraph(cnode) != SUCCESS) {
@ -311,6 +309,5 @@ Status UniformCandidateSamplerInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
return SUCCESS; return SUCCESS;
} }
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore

View File

@ -331,7 +331,6 @@ Status UnsortedSegmentMinInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
return SUCCESS; return SUCCESS;
} }
// The UnsortedSegmentMaxInfo is almost same with UnsortedSegmentMinInfo // The UnsortedSegmentMaxInfo is almost same with UnsortedSegmentMinInfo
// Except the reduceMin op in the ComputeReplaceGraph is replaced with reduceMax op // Except the reduceMin op in the ComputeReplaceGraph is replaced with reduceMax op
ReplaceGraphPtr UnsortedSegmentMaxInfo::replace_graph(const CNodePtr &cnode) { ReplaceGraphPtr UnsortedSegmentMaxInfo::replace_graph(const CNodePtr &cnode) {
@ -351,9 +350,8 @@ Status UnsortedSegmentMaxInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
MS_LOG(ERROR) << "GenerateGraph Init failed"; MS_LOG(ERROR) << "GenerateGraph Init failed";
return FAILED; return FAILED;
} }
// Get the attributes of the UnsortedSegmentMin // Get the attributes of the UnsortedSegmentMax
auto num_segments = GetValue<int64_t>(input_value_.at(2)); auto num_segments = GetValue<int64_t>(input_value_.at(2));
// Step1: Output branch
auto segment_max = gen_g.PushBack({gen_g.NewOpInst(UNSORTED_SEGMENT_MAX), gen_g.virtual_input_node(), auto segment_max = gen_g.PushBack({gen_g.NewOpInst(UNSORTED_SEGMENT_MAX), gen_g.virtual_input_node(),
gen_g.virtual_input_node(), CreatInt64Imm(num_segments)}); gen_g.virtual_input_node(), CreatInt64Imm(num_segments)});
auto expandim_output = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), segment_max, CreatInt64Imm(0)}); auto expandim_output = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), segment_max, CreatInt64Imm(0)});

View File

@ -78,7 +78,6 @@ class UnsortedSegmentMinInfo : public UnsortedSegmentOpInfo {
protected: protected:
Status ComputeReplaceGraph(const CNodePtr &cnode); Status ComputeReplaceGraph(const CNodePtr &cnode);
}; };
class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo { class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo {
public: public:
UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,

View File

@ -22,7 +22,6 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include "frontend/parallel/pipeline_transformer/pipeline_transformer.h" #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/auto_parallel/graph_costmodel.h" #include "frontend/parallel/auto_parallel/graph_costmodel.h"
#include "frontend/parallel/ops_info/ops_utils.h" #include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/group_manager.h" #include "frontend/parallel/group_manager.h"
@ -33,8 +32,6 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
static std::unordered_map<AnfNodePtr, std::set<int>> parameter_color_map; static std::unordered_map<AnfNodePtr, std::set<int>> parameter_color_map;
static std::pair<bool, int> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users);
static bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
static int send_tag = 0; static int send_tag = 0;
static int recv_tag = 0; static int recv_tag = 0;
@ -239,7 +236,7 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
manager_->SetEdge(use_node, index, recv); manager_->SetEdge(use_node, index, recv);
} }
static std::pair<bool, int> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users) { std::pair<bool, int> PipelineTransformer::IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users) {
std::set<int> tag_set; std::set<int> tag_set;
auto node_stage = node->stage(); auto node_stage = node->stage();
int min_tag = node_stage; int min_tag = node_stage;
@ -371,7 +368,7 @@ void PipelineTransformer::ElimGraphStage() {
} }
} }
static bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) { bool PipelineTransformer::IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>(); ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>(); PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();

View File

@ -18,9 +18,11 @@
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_TRANSFORMER_H_ #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_TRANSFORMER_H_
#include <utility> #include <utility>
#include <string>
#include "ir/value.h" #include "ir/value.h"
#include "ir/graph_utils.h" #include "ir/graph_utils.h"
#include "base/base.h" #include "base/base.h"
#include "frontend/parallel/graph_util/generate_graph.h"
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
@ -49,6 +51,8 @@ class PipelineTransformer {
void ElimParameter(); void ElimParameter();
private: private:
std::pair<bool, int> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users);
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
void DoBroadCast(const FuncGraphPtr &func); void DoBroadCast(const FuncGraphPtr &func);
SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, const int &user_node_stage, SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, const int &user_node_stage,
const int &node_stage); const int &node_stage);

View File

@ -54,11 +54,6 @@ static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
// g_RefMap, for CNode B input i is a RefKey[Parameter C], // g_RefMap, for CNode B input i is a RefKey[Parameter C],
// it will be one item in map with key: C, and value: (B, i) // it will be one item in map with key: C, and value: (B, i)
static std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap; static std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
static void HandleNoUsedParameter(const FuncGraphPtr &root);
static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph,
const std::string &instance_name);
static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
const std::string &opt_shard_group);
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) { void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
if (new_node_input.empty()) { if (new_node_input.empty()) {