forked from mindspore-Ecosystem/mindspore

commit 41be7288a4

!8868 Fix Code Format

From: @huangxinjing
Reviewed-by: @lichen666, @yangzhenzhang
Signed-off-by:

@@ -913,14 +913,13 @@ double GatherV2PCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs,
 double UniformCandidateSamplerCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                                               const std::vector<TensorInfo> &outputs,
                                                               int64_t stage_id) const {
-  double result = 0.0;
   Shape input0_slice_shape = inputs[0].slice_shape();
   if (inputs_type_lengths_.size() != inputs.size()) {
     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
                       << " for UniformCandidateSampler cost";
   }
 
-  result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
+  double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
 
   return result;
 }

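For context on the hunk above: the cost model treats the forward computation cost of UniformCandidateSampler as the byte volume of the first input's slice on one device. A minimal free-standing sketch of that formula (plain C++ with illustrative names, not MindSpore's actual helpers):

#include <cstddef>
#include <cstdint>
#include <vector>

// Product of all dimensions of a shape, as a double (mirrors the role of ListProduct).
double ShapeProduct(const std::vector<int64_t> &shape) {
  double product = 1.0;
  for (int64_t dim : shape) product *= static_cast<double>(dim);
  return product;
}

// Forward computation cost: slice volume times bytes per element.
double ForwardComputationCost(const std::vector<int64_t> &slice_shape, size_t type_length) {
  return ShapeProduct(slice_shape) * static_cast<double>(type_length);
}

// Example: a [512, 128] float32 slice -> 512 * 128 * 4 = 262144.
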
@@ -279,6 +279,5 @@ Status SliceInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
 
   return SUCCESS;
 }
-
 }  // namespace parallel
 }  // namespace mindspore

@@ -30,7 +30,6 @@
 
 namespace mindspore {
 namespace parallel {
-
 Status UniformCandidateSamplerInfo::GetUniformSamplerAttrInt64(const std::string &args, int64_t *value) {
   auto iter = attrs_.find(args);
   if (iter == attrs_.end()) {

@@ -276,7 +275,6 @@ Status UniformCandidateSamplerInfo::InitForCostModel(const StrategyPtr &strategy
 
 ReplaceGraphPtr UniformCandidateSamplerInfo::replace_graph(const CNodePtr &cnode) {
   auto input_strategy = strategy_->GetInputDim().at(0);
-
   // Only when the axis-1 is sharded, we need to modify the attribute
   if (input_strategy.size() == 2 && input_strategy[1] > 1) {
     if (ComputeReplaceGraph(cnode) != SUCCESS) {

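A note on the condition in the hunk above: the strategy for a 2-D input is a vector of per-dimension device counts, and axis 1 counts as sharded only when its entry exceeds 1. A small sketch with hypothetical strategy values (not MindSpore API):

#include <cstdint>
#include <vector>

// Axis 1 of a 2-D input is sharded iff its device count exceeds 1.
bool Axis1IsSharded(const std::vector<int64_t> &input_strategy) {
  return input_strategy.size() == 2 && input_strategy[1] > 1;
}

// {4, 1}: only axis 0 is split -> sampler attribute left untouched.
// {2, 4}: axis 1 is split across 4 devices -> the replace graph must adjust the attribute.
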
@@ -311,6 +309,5 @@ Status UniformCandidateSamplerInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
 
   return SUCCESS;
 }
-
 }  // namespace parallel
 }  // namespace mindspore

@@ -331,7 +331,6 @@ Status UnsortedSegmentMinInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
 
   return SUCCESS;
 }
-
 // The UnsortedSegmentMaxInfo is almost same with UnsortedSegmentMinInfo
 // Except the reduceMin op in the ComputeReplaceGraph is replaced with reduceMax op
 ReplaceGraphPtr UnsortedSegmentMaxInfo::replace_graph(const CNodePtr &cnode) {

@@ -351,9 +350,8 @@ Status UnsortedSegmentMaxInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
     MS_LOG(ERROR) << "GenerateGraph Init failed";
     return FAILED;
   }
-  // Get the attributes of the UnsortedSegmentMin
   // Get the attributes of the UnsortedSegmentMax
   auto num_segments = GetValue<int64_t>(input_value_.at(2));
   // Step1: Output branch
   auto segment_max = gen_g.PushBack({gen_g.NewOpInst(UNSORTED_SEGMENT_MAX), gen_g.virtual_input_node(),
                                      gen_g.virtual_input_node(), CreatInt64Imm(num_segments)});
   auto expandim_output = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), segment_max, CreatInt64Imm(0)});

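For reference, the op wired into the replace graph above computes a segment-wise maximum. A plain C++ sketch of the 1-D semantics (an illustration, not the MindSpore kernel):

#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// out[s] = max of data[i] over all i with segment_ids[i] == s;
// segments that receive no element keep the lowest float value.
std::vector<float> UnsortedSegmentMax1D(const std::vector<float> &data,
                                        const std::vector<int64_t> &segment_ids,
                                        int64_t num_segments) {
  std::vector<float> out(static_cast<size_t>(num_segments), std::numeric_limits<float>::lowest());
  for (size_t i = 0; i < data.size(); ++i) {
    const int64_t s = segment_ids[i];
    if (s >= 0 && s < num_segments && data[i] > out[static_cast<size_t>(s)]) {
      out[static_cast<size_t>(s)] = data[i];
    }
  }
  return out;
}
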
@@ -78,7 +78,6 @@ class UnsortedSegmentMinInfo : public UnsortedSegmentOpInfo {
  protected:
   Status ComputeReplaceGraph(const CNodePtr &cnode);
 };
-
 class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo {
  public:
   UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,

@@ -22,7 +22,6 @@
 #include <algorithm>
 #include <memory>
 #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
-#include "frontend/parallel/graph_util/generate_graph.h"
 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
 #include "frontend/parallel/ops_info/ops_utils.h"
 #include "frontend/parallel/group_manager.h"

@@ -33,8 +32,6 @@
 namespace mindspore {
 namespace parallel {
 static std::unordered_map<AnfNodePtr, std::set<int>> parameter_color_map;
-static std::pair<bool, int> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users);
-static bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
 static int send_tag = 0;
 static int recv_tag = 0;
 

@@ -239,7 +236,7 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
   manager_->SetEdge(use_node, index, recv);
 }
 
-static std::pair<bool, int> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users) {
+std::pair<bool, int> PipelineTransformer::IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users) {
   std::set<int> tag_set;
   auto node_stage = node->stage();
   int min_tag = node_stage;

@@ -371,7 +368,7 @@ void PipelineTransformer::ElimGraphStage() {
   }
 }
 
-static bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
+bool PipelineTransformer::IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
   ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>();
   MS_EXCEPTION_IF_NULL(anf_node);
   PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();

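The two hunks above follow one pattern: helpers that were translation-unit statics in the .cc become private members of PipelineTransformer, so the file-scope forward declarations disappear and the definitions gain the class qualifier. A minimal sketch of that pattern with illustrative names (not MindSpore code):

#include <string>

class TransformerLike {
 public:
  bool Matches(const std::string &s) const { return HasPrefix(s, prefix_); }

 private:
  // Was: a file-scope "static bool HasPrefix(...)" in the .cc;
  // now declared here so it can use member state.
  bool HasPrefix(const std::string &s, const std::string &prefix) const;
  std::string prefix_ = "Send";
};

// The definition keeps the same body, now with the class qualifier.
bool TransformerLike::HasPrefix(const std::string &s, const std::string &prefix) const {
  return s.size() >= prefix.size() && s.compare(0, prefix.size(), prefix) == 0;
}
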
@@ -18,9 +18,11 @@
 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_TRANSFORMER_H_
 
 #include <utility>
+#include <string>
 #include "ir/value.h"
 #include "ir/graph_utils.h"
 #include "base/base.h"
+#include "frontend/parallel/graph_util/generate_graph.h"
 
 namespace mindspore {
 namespace parallel {

@@ -49,6 +51,8 @@ class PipelineTransformer {
   void ElimParameter();
 
  private:
+  std::pair<bool, int> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users);
+  bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
   void DoBroadCast(const FuncGraphPtr &func);
   SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, const int &user_node_stage,
                       const int &node_stage);

@@ -54,11 +54,6 @@ static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
 // g_RefMap, for CNode B input i is a RefKey[Parameter C],
 // it will be one item in map with key: C, and value: (B, i)
 static std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
-static void HandleNoUsedParameter(const FuncGraphPtr &root);
-static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph,
-                            const std::string &instance_name);
-static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
-                                    const std::string &opt_shard_group);
 
 void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
   if (new_node_input.empty()) {

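Finally, the g_RefMap comment kept as context above describes a small bookkeeping structure; a self-contained sketch of what it records, using a stand-in node type rather than AnfNodePtr:

#include <cstdint>
#include <map>
#include <utility>

struct Node {};  // stand-in for an ANF node

int main() {
  // If CNode B takes RefKey[Parameter C] as its i-th input,
  // the map stores key C with value (B, i).
  std::map<const Node *, std::pair<const Node *, int64_t>> ref_map;
  Node param_c, cnode_b;
  ref_map[&param_c] = {&cnode_b, 2};  // B's input 2 refers to C
  return 0;
}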