forked from OSSInnovation/mindspore
fix range start and deta error
This commit is contained in:
parent
a28d01f087
commit
1ff78eb167
|
@ -134,59 +134,6 @@ Status RangeInfo::InitForCostModel(const StrategyPtr &strategy) {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status RangeInfo::InferNewAttr() {
|
|
||||||
CheckGlobalDeviceManager();
|
|
||||||
int64_t rank = g_device_manager->rank_index_in_stage();
|
|
||||||
|
|
||||||
// If repeated calculation and repeated num as the last dimension of dev-matrix,
|
|
||||||
// the dev-matrix is [split_num_, repeated_calc_num_], so from rank 0 to rank repeated_calc_num_
|
|
||||||
// are repeated calculation, and these rank have the same 'new_start_'.
|
|
||||||
// If repeated calculation and repeated num as the first dimension of dev-matrix,
|
|
||||||
// the dev-matrix is [repeated_calc_num_, split_num_], so rank 0 and rank split_num_ and so on
|
|
||||||
// are repeated calculation, and these rank have the same 'new_start_'.
|
|
||||||
float start_bias = inputs_shape_[0][0] / split_num_ * delta_;
|
|
||||||
if (repeated_num_in_dev_matrix_right_) {
|
|
||||||
new_start_ = start_ + start_bias * (rank / repeated_calc_num_);
|
|
||||||
} else {
|
|
||||||
new_start_ = start_ + start_bias * (rank % split_num_);
|
|
||||||
}
|
|
||||||
|
|
||||||
new_limit_ = new_start_ + start_bias;
|
|
||||||
MS_LOG(INFO) << name_ << ": The new start is " << new_start_ << ", the new limit is " << new_limit_;
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status RangeInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
|
||||||
GenerateGraph gen_g = GenerateGraph();
|
|
||||||
if (gen_g.Init(cnode) != SUCCESS) {
|
|
||||||
MS_LOG(ERROR) << name_ << ": GenerateGraph Init failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (InferNewAttr() != SUCCESS) {
|
|
||||||
MS_LOG(ERROR) << name_ << ": Infer new attr failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
|
|
||||||
Attr attr_start = std::make_pair(START, MakeValue(new_start_));
|
|
||||||
Attr attr_limit = std::make_pair(LIMIT, MakeValue(new_limit_));
|
|
||||||
Attr attr_delta = std::make_pair(DELTA, MakeValue(delta_));
|
|
||||||
OperatorAttrs attrs = {attr_start, attr_limit, attr_delta};
|
|
||||||
auto new_range_op = gen_g.PushBack({gen_g.NewOpInst(RANGE, attrs), gen_g.virtual_input_node()});
|
|
||||||
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(new_range_op, 1)};
|
|
||||||
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
|
|
||||||
std::make_pair(input_nodes, new_range_op));
|
|
||||||
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
ReplaceGraphPtr RangeInfo::replace_graph(const CNodePtr &cnode) {
|
|
||||||
if (ComputeReplaceGraph(cnode) != SUCCESS) {
|
|
||||||
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
|
|
||||||
}
|
|
||||||
return replace_graph_;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status RangeInfo::GenerateStrategies(int64_t stage_id) {
|
Status RangeInfo::GenerateStrategies(int64_t stage_id) {
|
||||||
Shape input0_split(inputs_shape_[0].size(), 1);
|
Shape input0_split(inputs_shape_[0].size(), 1);
|
||||||
Shapes splittable_inputs = {input0_split};
|
Shapes splittable_inputs = {input0_split};
|
||||||
|
|
|
@ -47,7 +47,6 @@ class RangeInfo : public OperatorInfo {
|
||||||
|
|
||||||
Status GenerateStrategies(int64_t stage_id) override;
|
Status GenerateStrategies(int64_t stage_id) override;
|
||||||
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
|
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
|
||||||
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||||
|
@ -57,9 +56,7 @@ class RangeInfo : public OperatorInfo {
|
||||||
Status InferDevMatrixShape() override;
|
Status InferDevMatrixShape() override;
|
||||||
Status InferTensorMap() override;
|
Status InferTensorMap() override;
|
||||||
Status GetAttrs() override;
|
Status GetAttrs() override;
|
||||||
Status InferNewAttr();
|
|
||||||
float GetRangeAttr(const std::string &arg);
|
float GetRangeAttr(const std::string &arg);
|
||||||
Status ComputeReplaceGraph(const CNodePtr &cnode);
|
|
||||||
|
|
||||||
float start_ = 0.0;
|
float start_ = 0.0;
|
||||||
float limit_ = 0.0;
|
float limit_ = 0.0;
|
||||||
|
|
Loading…
Reference in New Issue