forked from mindspore-Ecosystem/mindspore
!8868 Fix Code Format
From: @huangxinjing Reviewed-by: @lichen666,@yangzhenzhang Signed-off-by:
This commit is contained in:
commit
41be7288a4
|
@ -913,14 +913,13 @@ double GatherV2PCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs,
|
||||||
double UniformCandidateSamplerCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
|
double UniformCandidateSamplerCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
|
||||||
const std::vector<TensorInfo> &outputs,
|
const std::vector<TensorInfo> &outputs,
|
||||||
int64_t stage_id) const {
|
int64_t stage_id) const {
|
||||||
double result = 0.0;
|
|
||||||
Shape input0_slice_shape = inputs[0].slice_shape();
|
Shape input0_slice_shape = inputs[0].slice_shape();
|
||||||
if (inputs_type_lengths_.size() != inputs.size()) {
|
if (inputs_type_lengths_.size() != inputs.size()) {
|
||||||
MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
|
MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
|
||||||
<< " for UniformCandidateSampler cost";
|
<< " for UniformCandidateSampler cost";
|
||||||
}
|
}
|
||||||
|
|
||||||
result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
|
double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
@ -279,6 +279,5 @@ Status SliceInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||||
|
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -30,7 +30,6 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
|
|
||||||
Status UniformCandidateSamplerInfo::GetUniformSamplerAttrInt64(const std::string &args, int64_t *value) {
|
Status UniformCandidateSamplerInfo::GetUniformSamplerAttrInt64(const std::string &args, int64_t *value) {
|
||||||
auto iter = attrs_.find(args);
|
auto iter = attrs_.find(args);
|
||||||
if (iter == attrs_.end()) {
|
if (iter == attrs_.end()) {
|
||||||
|
@ -276,7 +275,6 @@ Status UniformCandidateSamplerInfo::InitForCostModel(const StrategyPtr &strategy
|
||||||
|
|
||||||
ReplaceGraphPtr UniformCandidateSamplerInfo::replace_graph(const CNodePtr &cnode) {
|
ReplaceGraphPtr UniformCandidateSamplerInfo::replace_graph(const CNodePtr &cnode) {
|
||||||
auto input_strategy = strategy_->GetInputDim().at(0);
|
auto input_strategy = strategy_->GetInputDim().at(0);
|
||||||
|
|
||||||
// Only when the axis-1 is sharded, we need to modify the attribute
|
// Only when the axis-1 is sharded, we need to modify the attribute
|
||||||
if (input_strategy.size() == 2 && input_strategy[1] > 1) {
|
if (input_strategy.size() == 2 && input_strategy[1] > 1) {
|
||||||
if (ComputeReplaceGraph(cnode) != SUCCESS) {
|
if (ComputeReplaceGraph(cnode) != SUCCESS) {
|
||||||
|
@ -311,6 +309,5 @@ Status UniformCandidateSamplerInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||||
|
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -331,7 +331,6 @@ Status UnsortedSegmentMinInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||||
|
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
// The UnsortedSegmentMaxInfo is almost same with UnsortedSegmentMinInfo
|
// The UnsortedSegmentMaxInfo is almost same with UnsortedSegmentMinInfo
|
||||||
// Except the reduceMin op in the ComputeReplaceGraph is replaced with reduceMax op
|
// Except the reduceMin op in the ComputeReplaceGraph is replaced with reduceMax op
|
||||||
ReplaceGraphPtr UnsortedSegmentMaxInfo::replace_graph(const CNodePtr &cnode) {
|
ReplaceGraphPtr UnsortedSegmentMaxInfo::replace_graph(const CNodePtr &cnode) {
|
||||||
|
@ -351,9 +350,8 @@ Status UnsortedSegmentMaxInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||||
MS_LOG(ERROR) << "GenerateGraph Init failed";
|
MS_LOG(ERROR) << "GenerateGraph Init failed";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
// Get the attributes of the UnsortedSegmentMin
|
// Get the attributes of the UnsortedSegmentMax
|
||||||
auto num_segments = GetValue<int64_t>(input_value_.at(2));
|
auto num_segments = GetValue<int64_t>(input_value_.at(2));
|
||||||
// Step1: Output branch
|
|
||||||
auto segment_max = gen_g.PushBack({gen_g.NewOpInst(UNSORTED_SEGMENT_MAX), gen_g.virtual_input_node(),
|
auto segment_max = gen_g.PushBack({gen_g.NewOpInst(UNSORTED_SEGMENT_MAX), gen_g.virtual_input_node(),
|
||||||
gen_g.virtual_input_node(), CreatInt64Imm(num_segments)});
|
gen_g.virtual_input_node(), CreatInt64Imm(num_segments)});
|
||||||
auto expandim_output = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), segment_max, CreatInt64Imm(0)});
|
auto expandim_output = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), segment_max, CreatInt64Imm(0)});
|
||||||
|
|
|
@ -78,7 +78,6 @@ class UnsortedSegmentMinInfo : public UnsortedSegmentOpInfo {
|
||||||
protected:
|
protected:
|
||||||
Status ComputeReplaceGraph(const CNodePtr &cnode);
|
Status ComputeReplaceGraph(const CNodePtr &cnode);
|
||||||
};
|
};
|
||||||
|
|
||||||
class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo {
|
class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo {
|
||||||
public:
|
public:
|
||||||
UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||||
|
|
|
@ -22,7 +22,6 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
|
#include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
|
||||||
#include "frontend/parallel/graph_util/generate_graph.h"
|
|
||||||
#include "frontend/parallel/auto_parallel/graph_costmodel.h"
|
#include "frontend/parallel/auto_parallel/graph_costmodel.h"
|
||||||
#include "frontend/parallel/ops_info/ops_utils.h"
|
#include "frontend/parallel/ops_info/ops_utils.h"
|
||||||
#include "frontend/parallel/group_manager.h"
|
#include "frontend/parallel/group_manager.h"
|
||||||
|
@ -33,8 +32,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
static std::unordered_map<AnfNodePtr, std::set<int>> parameter_color_map;
|
static std::unordered_map<AnfNodePtr, std::set<int>> parameter_color_map;
|
||||||
static std::pair<bool, int> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users);
|
|
||||||
static bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
|
|
||||||
static int send_tag = 0;
|
static int send_tag = 0;
|
||||||
static int recv_tag = 0;
|
static int recv_tag = 0;
|
||||||
|
|
||||||
|
@ -239,7 +236,7 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
|
||||||
manager_->SetEdge(use_node, index, recv);
|
manager_->SetEdge(use_node, index, recv);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::pair<bool, int> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users) {
|
std::pair<bool, int> PipelineTransformer::IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users) {
|
||||||
std::set<int> tag_set;
|
std::set<int> tag_set;
|
||||||
auto node_stage = node->stage();
|
auto node_stage = node->stage();
|
||||||
int min_tag = node_stage;
|
int min_tag = node_stage;
|
||||||
|
@ -371,7 +368,7 @@ void PipelineTransformer::ElimGraphStage() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
|
bool PipelineTransformer::IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
|
||||||
ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
|
PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
|
||||||
|
|
|
@ -18,9 +18,11 @@
|
||||||
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_TRANSFORMER_H_
|
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_TRANSFORMER_H_
|
||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <string>
|
||||||
#include "ir/value.h"
|
#include "ir/value.h"
|
||||||
#include "ir/graph_utils.h"
|
#include "ir/graph_utils.h"
|
||||||
#include "base/base.h"
|
#include "base/base.h"
|
||||||
|
#include "frontend/parallel/graph_util/generate_graph.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
|
@ -49,6 +51,8 @@ class PipelineTransformer {
|
||||||
void ElimParameter();
|
void ElimParameter();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
std::pair<bool, int> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users);
|
||||||
|
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
|
||||||
void DoBroadCast(const FuncGraphPtr &func);
|
void DoBroadCast(const FuncGraphPtr &func);
|
||||||
SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr ¶meter, const int &user_node_stage,
|
SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr ¶meter, const int &user_node_stage,
|
||||||
const int &node_stage);
|
const int &node_stage);
|
||||||
|
|
|
@ -54,11 +54,6 @@ static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
|
||||||
// g_RefMap, for CNode B input i is a RefKey[Parameter C],
|
// g_RefMap, for CNode B input i is a RefKey[Parameter C],
|
||||||
// it will be one item in map with key: C, and value: (B, i)
|
// it will be one item in map with key: C, and value: (B, i)
|
||||||
static std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
|
static std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
|
||||||
static void HandleNoUsedParameter(const FuncGraphPtr &root);
|
|
||||||
static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph,
|
|
||||||
const std::string &instance_name);
|
|
||||||
static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr ¶meter,
|
|
||||||
const std::string &opt_shard_group);
|
|
||||||
|
|
||||||
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
|
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
|
||||||
if (new_node_input.empty()) {
|
if (new_node_input.empty()) {
|
||||||
|
|
Loading…
Reference in New Issue