forked from mindspore-Ecosystem/mindspore
refactor OperatorCostPtr in OperatorInfo
This commit is contained in:
parent
cc53ddaeca
commit
b413638f23
|
@ -514,60 +514,6 @@ double ArithmeticCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
double L2NormalizeCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
|
|
||||||
const int32_t& stage_id) const {
|
|
||||||
double result = 0.0;
|
|
||||||
if (is_parameter_[0]) {
|
|
||||||
TensorInfo input_tensor_info = inputs[0];
|
|
||||||
CheckGlobalDeviceManager();
|
|
||||||
MS_EXCEPTION_IF_NULL(g_device_manager);
|
|
||||||
auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
|
|
||||||
|
|
||||||
Shape input_shape = input_tensor_info.shape();
|
|
||||||
Shape input_slice_shape = input_tensor_info.slice_shape();
|
|
||||||
int32_t used_device_num = 1;
|
|
||||||
for (size_t i = 0; i < input_shape.size(); ++i) {
|
|
||||||
used_device_num *= input_shape[i] / input_slice_shape[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (total_device_num != IntToSize(used_device_num))
|
|
||||||
result += ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
double L2NormalizeCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
|
|
||||||
const int32_t&) const {
|
|
||||||
TensorInfo input0_info = inputs[0];
|
|
||||||
Shape input0_slice_shape = input0_info.slice_shape();
|
|
||||||
return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
double L2NormalizeCost::GetBackwardComputationCost(const std::vector<TensorInfo>& inputs,
|
|
||||||
const std::vector<TensorInfo>&, const int32_t& stage_id) const {
|
|
||||||
double result = 0.0;
|
|
||||||
|
|
||||||
if (is_parameter_[0]) {
|
|
||||||
TensorInfo input_tensor_info = inputs[0];
|
|
||||||
CheckGlobalDeviceManager();
|
|
||||||
MS_EXCEPTION_IF_NULL(g_device_manager);
|
|
||||||
auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
|
|
||||||
|
|
||||||
Shape input_shape = input_tensor_info.shape();
|
|
||||||
Shape input_slice_shape = input_tensor_info.slice_shape();
|
|
||||||
int32_t used_device_num = 1;
|
|
||||||
for (size_t i = 0; i < input_shape.size(); ++i) {
|
|
||||||
used_device_num *= input_shape[i] / input_slice_shape[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (total_device_num != IntToSize(used_device_num))
|
|
||||||
result += ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsDataParallel(const Shape& shape, const Shape& slice_shape, const int32_t& stage_id) {
|
bool IsDataParallel(const Shape& shape, const Shape& slice_shape, const int32_t& stage_id) {
|
||||||
CheckGlobalDeviceManager();
|
CheckGlobalDeviceManager();
|
||||||
MS_EXCEPTION_IF_NULL(g_device_manager);
|
MS_EXCEPTION_IF_NULL(g_device_manager);
|
||||||
|
|
|
@ -132,6 +132,8 @@ class ActivationCost : public OperatorCost {
|
||||||
};
|
};
|
||||||
|
|
||||||
using ActivationCostPtr = std::shared_ptr<ActivationCost>;
|
using ActivationCostPtr = std::shared_ptr<ActivationCost>;
|
||||||
|
using TransposeCost = ActivationCost;
|
||||||
|
using TransposeCostPtr = std::shared_ptr<TransposeCost>;
|
||||||
|
|
||||||
class SoftmaxCost : public OperatorCost {
|
class SoftmaxCost : public OperatorCost {
|
||||||
public:
|
public:
|
||||||
|
@ -415,32 +417,8 @@ class ArithmeticCost : public OperatorCost {
|
||||||
const int32_t& stage_id) const override;
|
const int32_t& stage_id) const override;
|
||||||
};
|
};
|
||||||
using ArithmeticCostPtr = std::shared_ptr<ArithmeticCost>;
|
using ArithmeticCostPtr = std::shared_ptr<ArithmeticCost>;
|
||||||
|
using BiasAddCost = ArithmeticCost;
|
||||||
class L2NormalizeCost : public OperatorCost {
|
using BiasAddCostPtr = std::shared_ptr<BiasAddCost>;
|
||||||
public:
|
|
||||||
L2NormalizeCost() = default;
|
|
||||||
~L2NormalizeCost() override = default;
|
|
||||||
|
|
||||||
double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
|
|
||||||
const int32_t& stage_id) const override {
|
|
||||||
return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
|
|
||||||
}
|
|
||||||
double GetForwardCommCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
|
|
||||||
const int32_t&) const override {
|
|
||||||
return 0.0;
|
|
||||||
}
|
|
||||||
double GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
|
|
||||||
const int32_t& stage_id) const override;
|
|
||||||
double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
|
|
||||||
const int32_t& stage_id) const override {
|
|
||||||
return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
|
|
||||||
}
|
|
||||||
double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
|
|
||||||
const int32_t& stage_id) const override;
|
|
||||||
double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
|
|
||||||
const int32_t& stage_id) const override;
|
|
||||||
};
|
|
||||||
using L2NormalizeCostPtr = std::shared_ptr<L2NormalizeCost>;
|
|
||||||
|
|
||||||
class ReduceMethodCost : public OperatorCost {
|
class ReduceMethodCost : public OperatorCost {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -32,8 +32,8 @@ namespace parallel {
|
||||||
class ActivationBase : public OperatorInfo {
|
class ActivationBase : public OperatorInfo {
|
||||||
public:
|
public:
|
||||||
ActivationBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
ActivationBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
||||||
const PrimitiveAttrs& attrs)
|
const PrimitiveAttrs& attrs, OperatorCostPtr cost)
|
||||||
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs) {}
|
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {}
|
||||||
~ActivationBase() override = default;
|
~ActivationBase() override = default;
|
||||||
|
|
||||||
Status Init(const StrategyPtr& strategy) override;
|
Status Init(const StrategyPtr& strategy) override;
|
||||||
|
@ -51,19 +51,13 @@ class Activation : public ActivationBase {
|
||||||
public:
|
public:
|
||||||
Activation(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
Activation(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
||||||
const PrimitiveAttrs& attrs)
|
const PrimitiveAttrs& attrs)
|
||||||
: ActivationBase(name, inputs_shape, outputs_shape, attrs) {
|
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>()) {}
|
||||||
ac_cost_ptr_ = std::make_shared<ActivationCost>();
|
|
||||||
}
|
|
||||||
~Activation() override = default;
|
~Activation() override = default;
|
||||||
Status GenerateStrategies(int32_t stage_id) override;
|
Status GenerateStrategies(int32_t stage_id) override;
|
||||||
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
||||||
OperatorCostPtr GetOperatorCost() const override { return ac_cost_ptr_; }
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Status CheckStrategy(const StrategyPtr& strategy) override;
|
Status CheckStrategy(const StrategyPtr& strategy) override;
|
||||||
|
|
||||||
private:
|
|
||||||
ActivationCostPtr ac_cost_ptr_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class ActivationInfo : public Activation {
|
class ActivationInfo : public Activation {
|
||||||
|
@ -108,13 +102,10 @@ class Softmax : public ActivationBase {
|
||||||
public:
|
public:
|
||||||
explicit Softmax(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
explicit Softmax(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
||||||
const PrimitiveAttrs& attrs)
|
const PrimitiveAttrs& attrs)
|
||||||
: ActivationBase(name, inputs_shape, outputs_shape, attrs) {
|
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>()) {}
|
||||||
sm_cost_ptr_ = std::make_shared<SoftmaxCost>();
|
|
||||||
}
|
|
||||||
~Softmax() override = default;
|
~Softmax() override = default;
|
||||||
Status GenerateStrategies(int32_t stage_id) override;
|
Status GenerateStrategies(int32_t stage_id) override;
|
||||||
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
||||||
OperatorCostPtr GetOperatorCost() const override { return sm_cost_ptr_; }
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Status CheckStrategy(const StrategyPtr& strategy) override;
|
Status CheckStrategy(const StrategyPtr& strategy) override;
|
||||||
|
@ -122,7 +113,6 @@ class Softmax : public ActivationBase {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<int32_t> axis_;
|
std::vector<int32_t> axis_;
|
||||||
SoftmaxCostPtr sm_cost_ptr_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class SoftmaxInfo : public Softmax {
|
class SoftmaxInfo : public Softmax {
|
||||||
|
|
|
@ -33,15 +33,12 @@ class ArithmeticBase : public OperatorInfo {
|
||||||
public:
|
public:
|
||||||
ArithmeticBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
ArithmeticBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
||||||
const PrimitiveAttrs& attrs)
|
const PrimitiveAttrs& attrs)
|
||||||
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs) {
|
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>()) {}
|
||||||
arithmeticcost_ptr_ = std::make_shared<ArithmeticCost>();
|
|
||||||
}
|
|
||||||
~ArithmeticBase() override = default;
|
~ArithmeticBase() override = default;
|
||||||
Status Init(const StrategyPtr& strategy) override;
|
Status Init(const StrategyPtr& strategy) override;
|
||||||
Status InitForCostModel(const StrategyPtr& strategy) override;
|
Status InitForCostModel(const StrategyPtr& strategy) override;
|
||||||
Status GenerateStrategies(int32_t) override;
|
Status GenerateStrategies(int32_t) override;
|
||||||
Status SetCostUnderStrategy(const StrategyPtr&) override;
|
Status SetCostUnderStrategy(const StrategyPtr&) override;
|
||||||
OperatorCostPtr GetOperatorCost() const override { return arithmeticcost_ptr_; }
|
|
||||||
void ReComputeBatchSplitFlagList() override;
|
void ReComputeBatchSplitFlagList() override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -54,7 +51,6 @@ class ArithmeticBase : public OperatorInfo {
|
||||||
Status InferTensorMap() override;
|
Status InferTensorMap() override;
|
||||||
Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout, const Shape& dev_matrix_array);
|
Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout, const Shape& dev_matrix_array);
|
||||||
Shapes InferExpendShape();
|
Shapes InferExpendShape();
|
||||||
ArithmeticCostPtr arithmeticcost_ptr_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class SubInfo : public ArithmeticBase {
|
class SubInfo : public ArithmeticBase {
|
||||||
|
|
|
@ -31,16 +31,13 @@ class BatchParallelInfo : public OperatorInfo {
|
||||||
public:
|
public:
|
||||||
BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
||||||
const PrimitiveAttrs& attrs)
|
const PrimitiveAttrs& attrs)
|
||||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs), dev_num_(1) {
|
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()), dev_num_(1) {}
|
||||||
bp_cost_ptr_ = std::make_shared<BatchParallelCost>();
|
|
||||||
}
|
|
||||||
|
|
||||||
~BatchParallelInfo() override = default;
|
~BatchParallelInfo() override = default;
|
||||||
Status Init(const StrategyPtr& strategy) override;
|
Status Init(const StrategyPtr& strategy) override;
|
||||||
Status InitForCostModel(const StrategyPtr& strategy) override;
|
Status InitForCostModel(const StrategyPtr& strategy) override;
|
||||||
Status GenerateStrategies(int32_t stage_id) override;
|
Status GenerateStrategies(int32_t stage_id) override;
|
||||||
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
||||||
OperatorCostPtr GetOperatorCost() const override { return bp_cost_ptr_; }
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Status CheckStrategy(const StrategyPtr& strategy) override;
|
Status CheckStrategy(const StrategyPtr& strategy) override;
|
||||||
|
@ -55,7 +52,6 @@ class BatchParallelInfo : public OperatorInfo {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int32_t dev_num_;
|
int32_t dev_num_;
|
||||||
BatchParallelCostPtr bp_cost_ptr_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
|
class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
|
||||||
|
|
|
@ -34,16 +34,13 @@ class BiasAddInfo : public OperatorInfo {
|
||||||
public:
|
public:
|
||||||
BiasAddInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
BiasAddInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
||||||
const PrimitiveAttrs& attrs)
|
const PrimitiveAttrs& attrs)
|
||||||
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs) {
|
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>()) {}
|
||||||
biasaddcost_ptr_ = std::make_shared<ArithmeticCost>();
|
|
||||||
}
|
|
||||||
~BiasAddInfo() override = default;
|
~BiasAddInfo() override = default;
|
||||||
|
|
||||||
Status Init(const StrategyPtr& strategy) override;
|
Status Init(const StrategyPtr& strategy) override;
|
||||||
Status InitForCostModel(const StrategyPtr& strategy) override;
|
Status InitForCostModel(const StrategyPtr& strategy) override;
|
||||||
Status GenerateStrategies(int32_t) override;
|
Status GenerateStrategies(int32_t) override;
|
||||||
Status SetCostUnderStrategy(const StrategyPtr&) override;
|
Status SetCostUnderStrategy(const StrategyPtr&) override;
|
||||||
OperatorCostPtr GetOperatorCost() const override { return biasaddcost_ptr_; }
|
|
||||||
void ReComputeBatchSplitFlagList() override;
|
void ReComputeBatchSplitFlagList() override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -55,7 +52,6 @@ class BiasAddInfo : public OperatorInfo {
|
||||||
Status InferDevMatrixShape() override;
|
Status InferDevMatrixShape() override;
|
||||||
Status InferTensorMap() override;
|
Status InferTensorMap() override;
|
||||||
Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout, const Shape& dev_matrix_array);
|
Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout, const Shape& dev_matrix_array);
|
||||||
ArithmeticCostPtr biasaddcost_ptr_;
|
|
||||||
};
|
};
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -33,15 +33,12 @@ class DropoutDoMaskInfo : public OperatorInfo {
|
||||||
public:
|
public:
|
||||||
DropoutDoMaskInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
DropoutDoMaskInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
||||||
const PrimitiveAttrs& attrs)
|
const PrimitiveAttrs& attrs)
|
||||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
|
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()) {}
|
||||||
bpcost_ptr_ = std::make_shared<BatchParallelCost>();
|
|
||||||
}
|
|
||||||
~DropoutDoMaskInfo() override = default;
|
~DropoutDoMaskInfo() override = default;
|
||||||
|
|
||||||
Status Init(const StrategyPtr& strategy) override;
|
Status Init(const StrategyPtr& strategy) override;
|
||||||
Status GenerateStrategies(int32_t stage_id) override;
|
Status GenerateStrategies(int32_t stage_id) override;
|
||||||
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
||||||
OperatorCostPtr GetOperatorCost() const override { return bpcost_ptr_; }
|
|
||||||
Status InitForCostModel(const StrategyPtr& strategy) override;
|
Status InitForCostModel(const StrategyPtr& strategy) override;
|
||||||
std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;
|
std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;
|
||||||
|
|
||||||
|
@ -53,9 +50,6 @@ class DropoutDoMaskInfo : public OperatorInfo {
|
||||||
Status GetAttrs() override { return SUCCESS; }
|
Status GetAttrs() override { return SUCCESS; }
|
||||||
Status InferTensorInfo() override;
|
Status InferTensorInfo() override;
|
||||||
Status InferDevMatrixShape() override;
|
Status InferDevMatrixShape() override;
|
||||||
|
|
||||||
private:
|
|
||||||
BatchParallelCostPtr bpcost_ptr_;
|
|
||||||
};
|
};
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -32,15 +32,12 @@ class GeneratorBase : public OperatorInfo {
|
||||||
public:
|
public:
|
||||||
GeneratorBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
GeneratorBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||||
const PrimitiveAttrs &attrs)
|
const PrimitiveAttrs &attrs)
|
||||||
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs) {
|
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GeneratorBaseCost>()) {}
|
||||||
generatorbasecost_ptr_ = std::make_shared<GeneratorBaseCost>();
|
|
||||||
}
|
|
||||||
|
|
||||||
~GeneratorBase() override = default;
|
~GeneratorBase() override = default;
|
||||||
|
|
||||||
Status Init(const StrategyPtr &strategy) override;
|
Status Init(const StrategyPtr &strategy) override;
|
||||||
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
|
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
|
||||||
OperatorCostPtr GetOperatorCost() const override { return generatorbasecost_ptr_; }
|
|
||||||
Status InitForCostModel(const StrategyPtr &strategy) override;
|
Status InitForCostModel(const StrategyPtr &strategy) override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -52,7 +49,6 @@ class GeneratorBase : public OperatorInfo {
|
||||||
Status InferMirrorOps() override { return SUCCESS; }
|
Status InferMirrorOps() override { return SUCCESS; }
|
||||||
Status InferForwardCommunication() override { return SUCCESS; }
|
Status InferForwardCommunication() override { return SUCCESS; }
|
||||||
virtual Status InferReplaceOps(const StrategyPtr &strategy) = 0;
|
virtual Status InferReplaceOps(const StrategyPtr &strategy) = 0;
|
||||||
GeneratorBaseCostPtr generatorbasecost_ptr_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class DropoutGenMaskInfo : public GeneratorBase {
|
class DropoutGenMaskInfo : public GeneratorBase {
|
||||||
|
|
|
@ -32,14 +32,11 @@ class GetNextInfo : public OperatorInfo {
|
||||||
public:
|
public:
|
||||||
GetNextInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
GetNextInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||||
const PrimitiveAttrs &attrs)
|
const PrimitiveAttrs &attrs)
|
||||||
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs) {
|
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>()) {}
|
||||||
getnextcost_ptr_ = std::make_shared<GetNextCost>();
|
|
||||||
}
|
|
||||||
~GetNextInfo() override = default;
|
~GetNextInfo() override = default;
|
||||||
|
|
||||||
Status Init(const StrategyPtr &strategy) override;
|
Status Init(const StrategyPtr &strategy) override;
|
||||||
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
|
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
|
||||||
OperatorCostPtr GetOperatorCost() const override { return getnextcost_ptr_; }
|
|
||||||
Status InitForCostModel(const StrategyPtr &strategy) override;
|
Status InitForCostModel(const StrategyPtr &strategy) override;
|
||||||
Status GenerateStrategies(int32_t stage_id) override;
|
Status GenerateStrategies(int32_t stage_id) override;
|
||||||
|
|
||||||
|
@ -65,7 +62,6 @@ class GetNextInfo : public OperatorInfo {
|
||||||
Shapes shapes_;
|
Shapes shapes_;
|
||||||
int32_t output_num_ = 0;
|
int32_t output_num_ = 0;
|
||||||
std::string shared_name_;
|
std::string shared_name_;
|
||||||
GetNextCostPtr getnextcost_ptr_;
|
|
||||||
};
|
};
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -33,12 +33,9 @@ class L2NormalizeInfo : public Activation {
|
||||||
public:
|
public:
|
||||||
L2NormalizeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
L2NormalizeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
||||||
const PrimitiveAttrs& attrs)
|
const PrimitiveAttrs& attrs)
|
||||||
: Activation(name, inputs_shape, outputs_shape, attrs) {
|
: Activation(name, inputs_shape, outputs_shape, attrs) {}
|
||||||
l2normalizecost_ptr_ = std::make_shared<L2NormalizeCost>();
|
|
||||||
}
|
|
||||||
~L2NormalizeInfo() override = default;
|
~L2NormalizeInfo() override = default;
|
||||||
Status GenerateStrategies(int32_t stage_id) override;
|
Status GenerateStrategies(int32_t stage_id) override;
|
||||||
OperatorCostPtr GetOperatorCost() const override { return l2normalizecost_ptr_; }
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Status GetAttrs() override;
|
Status GetAttrs() override;
|
||||||
|
@ -47,7 +44,6 @@ class L2NormalizeInfo : public Activation {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int32_t axis_ = 0; // Default value = 0
|
int32_t axis_ = 0; // Default value = 0
|
||||||
L2NormalizeCostPtr l2normalizecost_ptr_;
|
|
||||||
};
|
};
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -36,16 +36,13 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
|
||||||
public:
|
public:
|
||||||
SoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
SoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
||||||
const PrimitiveAttrs& attrs)
|
const PrimitiveAttrs& attrs)
|
||||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
|
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCrossEntropyWithLogitsCost>()) {}
|
||||||
softmax_loss_cost_ptr_ = std::make_shared<SoftmaxCrossEntropyWithLogitsCost>();
|
|
||||||
}
|
|
||||||
~SoftmaxCrossEntropyWithLogitsInfo() override = default;
|
~SoftmaxCrossEntropyWithLogitsInfo() override = default;
|
||||||
Status Init(const StrategyPtr& strategy) override;
|
Status Init(const StrategyPtr& strategy) override;
|
||||||
Status InitForCostModel(const StrategyPtr& strategy) override;
|
Status InitForCostModel(const StrategyPtr& strategy) override;
|
||||||
|
|
||||||
Status GenerateStrategies(int32_t stage_id) override;
|
Status GenerateStrategies(int32_t stage_id) override;
|
||||||
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
||||||
OperatorCostPtr GetOperatorCost() const override { return softmax_loss_cost_ptr_; }
|
|
||||||
void ReComputeBatchSplitFlagList() override;
|
void ReComputeBatchSplitFlagList() override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -59,7 +56,6 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
|
||||||
// There are two outputs for SoftmaxCrossEntropyWithLogits, and outputs[1] is used for grad and overload
|
// There are two outputs for SoftmaxCrossEntropyWithLogits, and outputs[1] is used for grad and overload
|
||||||
// the InferAsLossDivisor.
|
// the InferAsLossDivisor.
|
||||||
Status InferAsLossDivisor() override;
|
Status InferAsLossDivisor() override;
|
||||||
SoftmaxCrossEntropyWithLogitsCostPtr softmax_loss_cost_ptr_;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int32_t axis_ = -1; // default -1
|
int32_t axis_ = -1; // default -1
|
||||||
|
|
|
@ -593,11 +593,11 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr&
|
||||||
// Here, we use the origin outputs_, because we only use the slice size of the output tensor.
|
// Here, we use the origin outputs_, because we only use the slice size of the output tensor.
|
||||||
// It does not matter whether the output tensor is transposed or not.
|
// It does not matter whether the output tensor is transposed or not.
|
||||||
double computation_cost =
|
double computation_cost =
|
||||||
matmulcost_ptr->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
|
cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
|
||||||
double communication_cost = matmulcost_ptr->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
|
double communication_cost = cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
|
||||||
std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
|
std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
|
||||||
result->communication_without_parameter_ =
|
result->communication_without_parameter_ =
|
||||||
matmulcost_ptr->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
|
cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
|
||||||
result->communication_with_partial_para_ =
|
result->communication_with_partial_para_ =
|
||||||
result->communication_without_parameter_ +
|
result->communication_without_parameter_ +
|
||||||
COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
|
COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
|
||||||
|
|
|
@ -34,9 +34,7 @@ class MatMulBase : public OperatorInfo {
|
||||||
public:
|
public:
|
||||||
MatMulBase(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
MatMulBase(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
||||||
const PrimitiveAttrs& attrs)
|
const PrimitiveAttrs& attrs)
|
||||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
|
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>()) {}
|
||||||
matmulcost_ptr = std::make_shared<MatMulCost>();
|
|
||||||
}
|
|
||||||
~MatMulBase() override = default;
|
~MatMulBase() override = default;
|
||||||
|
|
||||||
Status Init(const StrategyPtr& strategy) override;
|
Status Init(const StrategyPtr& strategy) override;
|
||||||
|
@ -48,7 +46,6 @@ class MatMulBase : public OperatorInfo {
|
||||||
Status PrepareStrategy(int32_t stage_id, size_t dev_num, Dimensions combined_partitions, size_t input0_shape_size,
|
Status PrepareStrategy(int32_t stage_id, size_t dev_num, Dimensions combined_partitions, size_t input0_shape_size,
|
||||||
size_t input1_shape_size, StrategyPtr* sp);
|
size_t input1_shape_size, StrategyPtr* sp);
|
||||||
|
|
||||||
OperatorCostPtr GetOperatorCost() const override { return matmulcost_ptr; }
|
|
||||||
Status SwapLastTwoElements(Shape* shape);
|
Status SwapLastTwoElements(Shape* shape);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -66,8 +63,6 @@ class MatMulBase : public OperatorInfo {
|
||||||
bool transpose_b_ = false;
|
bool transpose_b_ = false;
|
||||||
size_t mat_a_dimension_ = 0;
|
size_t mat_a_dimension_ = 0;
|
||||||
size_t mat_b_dimension_ = 0;
|
size_t mat_b_dimension_ = 0;
|
||||||
|
|
||||||
MatMulCostPtr matmulcost_ptr;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class MatMul : public MatMulBase {
|
class MatMul : public MatMulBase {
|
||||||
|
|
|
@ -33,16 +33,13 @@ class OneHotInfo : public OperatorInfo {
|
||||||
public:
|
public:
|
||||||
OneHotInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
OneHotInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
||||||
const PrimitiveAttrs& attrs)
|
const PrimitiveAttrs& attrs)
|
||||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
|
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>()) {}
|
||||||
onehot_cost_ptr_ = std::make_shared<OneHotCost>();
|
|
||||||
}
|
|
||||||
~OneHotInfo() override = default;
|
~OneHotInfo() override = default;
|
||||||
Status Init(const StrategyPtr& strategy) override;
|
Status Init(const StrategyPtr& strategy) override;
|
||||||
Status InitForCostModel(const StrategyPtr& strategy) override;
|
Status InitForCostModel(const StrategyPtr& strategy) override;
|
||||||
|
|
||||||
Status GenerateStrategies(int32_t stage_id) override;
|
Status GenerateStrategies(int32_t stage_id) override;
|
||||||
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
||||||
OperatorCostPtr GetOperatorCost() const override { return onehot_cost_ptr_; }
|
|
||||||
ReplaceGraphPtr replace_graph(const CNodePtr& cnode) override;
|
ReplaceGraphPtr replace_graph(const CNodePtr& cnode) override;
|
||||||
std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;
|
std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;
|
||||||
|
|
||||||
|
@ -60,7 +57,6 @@ class OneHotInfo : public OperatorInfo {
|
||||||
Status ComputeReplaceGraph(const CNodePtr& cnode);
|
Status ComputeReplaceGraph(const CNodePtr& cnode);
|
||||||
|
|
||||||
int axis_ = -1;
|
int axis_ = -1;
|
||||||
OneHotCostPtr onehot_cost_ptr_;
|
|
||||||
int32_t rank_ = 0;
|
int32_t rank_ = 0;
|
||||||
int32_t total_class_number_ = 1;
|
int32_t total_class_number_ = 1;
|
||||||
int32_t classes_each_device_ = 1;
|
int32_t classes_each_device_ = 1;
|
||||||
|
|
|
@ -1034,12 +1034,11 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
int32_t stage_id = strategy->GetInputStage();
|
int32_t stage_id = strategy->GetInputStage();
|
||||||
double computation_cost =
|
double computation_cost = cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
|
||||||
GetOperatorCost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
|
double communication_cost = cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
|
||||||
double communication_cost = GetOperatorCost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
|
|
||||||
std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
|
std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
|
||||||
result->communication_without_parameter_ =
|
result->communication_without_parameter_ =
|
||||||
GetOperatorCost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
|
cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
|
||||||
result->communication_with_partial_para_ =
|
result->communication_with_partial_para_ =
|
||||||
result->communication_without_parameter_ +
|
result->communication_without_parameter_ +
|
||||||
COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
|
COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
|
||||||
|
@ -1096,7 +1095,7 @@ Status OperatorInfo::set_is_parameter(const std::vector<bool>& is_parameter) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
is_parameter_ = is_parameter;
|
is_parameter_ = is_parameter;
|
||||||
GetOperatorCost()->set_is_parameter(is_parameter);
|
cost()->set_is_parameter(is_parameter);
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1193,7 +1192,7 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t>& inpu
|
||||||
}
|
}
|
||||||
inputs_type_lengths_ = input_lengths;
|
inputs_type_lengths_ = input_lengths;
|
||||||
outputs_type_lengths_ = output_lengths;
|
outputs_type_lengths_ = output_lengths;
|
||||||
GetOperatorCost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
|
cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1211,7 +1210,7 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr& stra
|
||||||
}
|
}
|
||||||
|
|
||||||
double OperatorInfo::GetForwardMemoryCostFromCNode() {
|
double OperatorInfo::GetForwardMemoryCostFromCNode() {
|
||||||
return GetOperatorCost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
|
return cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
|
|
|
@ -53,12 +53,13 @@ class Edge;
|
||||||
|
|
||||||
class OperatorInfo {
|
class OperatorInfo {
|
||||||
public:
|
public:
|
||||||
OperatorInfo(std::string name, Shapes inputs_shape, Shapes outputs_shape, PrimitiveAttrs attrs)
|
OperatorInfo(std::string name, Shapes inputs_shape, Shapes outputs_shape, PrimitiveAttrs attrs, OperatorCostPtr cost)
|
||||||
: name_(std::move(name)),
|
: name_(std::move(name)),
|
||||||
inputs_shape_(std::move(inputs_shape)),
|
inputs_shape_(std::move(inputs_shape)),
|
||||||
outputs_shape_(std::move(outputs_shape)),
|
outputs_shape_(std::move(outputs_shape)),
|
||||||
attrs_(std::move(attrs)),
|
attrs_(std::move(attrs)),
|
||||||
is_alive_(true) {
|
is_alive_(true),
|
||||||
|
cost_(cost) {
|
||||||
std::vector<bool> not_parameteter(inputs_shape_.size(), false);
|
std::vector<bool> not_parameteter(inputs_shape_.size(), false);
|
||||||
is_parameter_ = not_parameteter;
|
is_parameter_ = not_parameteter;
|
||||||
refkey_parameter_name_ = "";
|
refkey_parameter_name_ = "";
|
||||||
|
@ -75,7 +76,8 @@ class OperatorInfo {
|
||||||
// Given the stage_id (which indicates the number of devices),
|
// Given the stage_id (which indicates the number of devices),
|
||||||
// generate all strategies for this operator
|
// generate all strategies for this operator
|
||||||
virtual Status GenerateStrategies(int32_t stage_id) = 0;
|
virtual Status GenerateStrategies(int32_t stage_id) = 0;
|
||||||
virtual OperatorCostPtr GetOperatorCost() const = 0;
|
const OperatorCostPtr& cost() const { return cost_; }
|
||||||
|
void set_cost(const OperatorCostPtr& cost) { cost_ = cost; }
|
||||||
virtual Status SetCostUnderStrategy(const StrategyPtr& strategy) = 0;
|
virtual Status SetCostUnderStrategy(const StrategyPtr& strategy) = 0;
|
||||||
|
|
||||||
virtual std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies();
|
virtual std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies();
|
||||||
|
@ -115,7 +117,7 @@ class OperatorInfo {
|
||||||
void ReplaceSuccEdge(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
|
void ReplaceSuccEdge(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
|
||||||
void ReplacePreEdges(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
|
void ReplacePreEdges(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
|
||||||
void ReplaceSuccEdges(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
|
void ReplaceSuccEdges(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
|
||||||
std::vector<size_t> GetOutputTypeLengths() const { return GetOperatorCost()->outputs_type_lengths(); }
|
std::vector<size_t> GetOutputTypeLengths() const { return cost()->outputs_type_lengths(); }
|
||||||
void SetSelectedStrategyAndCost(const StrategyPtr& s_strategy, const CostPtr& cost) {
|
void SetSelectedStrategyAndCost(const StrategyPtr& s_strategy, const CostPtr& cost) {
|
||||||
selected_strategy_ = s_strategy;
|
selected_strategy_ = s_strategy;
|
||||||
selected_cost_ = cost;
|
selected_cost_ = cost;
|
||||||
|
@ -221,6 +223,9 @@ class OperatorInfo {
|
||||||
std::string refkey_parameter_name_;
|
std::string refkey_parameter_name_;
|
||||||
CNodePtr cnode_;
|
CNodePtr cnode_;
|
||||||
int32_t used_devices_ = -1;
|
int32_t used_devices_ = -1;
|
||||||
|
|
||||||
|
private:
|
||||||
|
OperatorCostPtr cost_;
|
||||||
};
|
};
|
||||||
|
|
||||||
Shape GetSliceShape(const Shape& tensor_shape, const Dimensions& strategy);
|
Shape GetSliceShape(const Shape& tensor_shape, const Dimensions& strategy);
|
||||||
|
|
|
@ -35,15 +35,12 @@ class PReLUInfo : public OperatorInfo {
|
||||||
public:
|
public:
|
||||||
PReLUInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
PReLUInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
||||||
const PrimitiveAttrs& attrs)
|
const PrimitiveAttrs& attrs)
|
||||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
|
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>()) {}
|
||||||
prelucost_ptr = std::make_shared<PReLUCost>();
|
|
||||||
}
|
|
||||||
~PReLUInfo() override = default;
|
~PReLUInfo() override = default;
|
||||||
Status Init(const StrategyPtr& strategy) override;
|
Status Init(const StrategyPtr& strategy) override;
|
||||||
Status InitForCostModel(const StrategyPtr& strategy) override;
|
Status InitForCostModel(const StrategyPtr& strategy) override;
|
||||||
|
|
||||||
Status GenerateStrategies(int32_t stage_id) override;
|
Status GenerateStrategies(int32_t stage_id) override;
|
||||||
OperatorCostPtr GetOperatorCost() const override { return prelucost_ptr; }
|
|
||||||
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -59,7 +56,6 @@ class PReLUInfo : public OperatorInfo {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Dimensions input_strategy_;
|
Dimensions input_strategy_;
|
||||||
PReLUCostPtr prelucost_ptr;
|
|
||||||
};
|
};
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -109,8 +109,12 @@ Status ReduceMethod::GetAttrs() {
|
||||||
}
|
}
|
||||||
cross_batch_ = cross_batch_iter->second->cast<BoolImmPtr>()->value();
|
cross_batch_ = cross_batch_iter->second->cast<BoolImmPtr>()->value();
|
||||||
}
|
}
|
||||||
reducemethodcost_ptr_->set_cross_batch(cross_batch_);
|
auto reducemethodcost = std::dynamic_pointer_cast<ReduceMethodCost>(cost());
|
||||||
|
if (reducemethodcost == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
reducemethodcost->set_cross_batch(cross_batch_);
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -34,9 +34,7 @@ class ReduceMethod : public OperatorInfo {
|
||||||
public:
|
public:
|
||||||
ReduceMethod(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
ReduceMethod(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||||
const PrimitiveAttrs &attrs)
|
const PrimitiveAttrs &attrs)
|
||||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
|
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMethodCost>()) {}
|
||||||
reducemethodcost_ptr_ = std::make_shared<ReduceMethodCost>();
|
|
||||||
}
|
|
||||||
~ReduceMethod() override = default;
|
~ReduceMethod() override = default;
|
||||||
|
|
||||||
Status Init(const StrategyPtr &strategy) override;
|
Status Init(const StrategyPtr &strategy) override;
|
||||||
|
@ -44,13 +42,11 @@ class ReduceMethod : public OperatorInfo {
|
||||||
|
|
||||||
Status GenerateStrategies(int32_t stage_id) override;
|
Status GenerateStrategies(int32_t stage_id) override;
|
||||||
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
|
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
|
||||||
OperatorCostPtr GetOperatorCost() const override { return reducemethodcost_ptr_; }
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::string reduce_method_;
|
std::string reduce_method_;
|
||||||
bool keepdims_ = false;
|
bool keepdims_ = false;
|
||||||
bool cross_batch_ = false;
|
bool cross_batch_ = false;
|
||||||
ReduceMethodCostPtr reducemethodcost_ptr_;
|
|
||||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||||
Status GetAttrs() override;
|
Status GetAttrs() override;
|
||||||
Dimensions InferOutputStrategy();
|
Dimensions InferOutputStrategy();
|
||||||
|
@ -110,7 +106,7 @@ class ReduceMeanInfo : public ReduceMethod {
|
||||||
ReduceMeanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
ReduceMeanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||||
const PrimitiveAttrs &attrs)
|
const PrimitiveAttrs &attrs)
|
||||||
: ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
|
: ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
|
||||||
reducemethodcost_ptr_ = std::make_shared<ReduceMeanCost>();
|
set_cost(std::make_shared<ReduceMeanCost>());
|
||||||
}
|
}
|
||||||
|
|
||||||
~ReduceMeanInfo() override = default;
|
~ReduceMeanInfo() override = default;
|
||||||
|
|
|
@ -36,12 +36,10 @@ class ReshapeInfo : public OperatorInfo {
|
||||||
public:
|
public:
|
||||||
ReshapeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
ReshapeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
||||||
const PrimitiveAttrs& attrs)
|
const PrimitiveAttrs& attrs)
|
||||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs),
|
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>()),
|
||||||
dev_num_(0),
|
dev_num_(0),
|
||||||
input_layout_set_flag_(false),
|
input_layout_set_flag_(false),
|
||||||
output_layout_set_flag_(false) {
|
output_layout_set_flag_(false) {}
|
||||||
reshape_cost_ptr_ = std::make_shared<ReshapeCost>();
|
|
||||||
}
|
|
||||||
~ReshapeInfo() override = default;
|
~ReshapeInfo() override = default;
|
||||||
Status Init(const StrategyPtr& strategy) override;
|
Status Init(const StrategyPtr& strategy) override;
|
||||||
void SetInputLayout(const TensorLayout& input_layout) {
|
void SetInputLayout(const TensorLayout& input_layout) {
|
||||||
|
@ -55,7 +53,6 @@ class ReshapeInfo : public OperatorInfo {
|
||||||
Status InitForCostModel(const StrategyPtr& strategy) override;
|
Status InitForCostModel(const StrategyPtr& strategy) override;
|
||||||
Status GenerateStrategies(int32_t stage_id) override;
|
Status GenerateStrategies(int32_t stage_id) override;
|
||||||
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
||||||
OperatorCostPtr GetOperatorCost() const override { return reshape_cost_ptr_; }
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Status CheckStrategy(const StrategyPtr& strategy) override;
|
Status CheckStrategy(const StrategyPtr& strategy) override;
|
||||||
|
@ -67,7 +64,6 @@ class ReshapeInfo : public OperatorInfo {
|
||||||
Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout);
|
Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout);
|
||||||
Status GetAttrs() override;
|
Status GetAttrs() override;
|
||||||
Strategys GetOutputsStrategy();
|
Strategys GetOutputsStrategy();
|
||||||
ReshapeCostPtr reshape_cost_ptr_;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Status GetParameterInput();
|
Status GetParameterInput();
|
||||||
|
|
|
@ -34,9 +34,7 @@ class TmpIdentityInfo : public OperatorInfo {
|
||||||
public:
|
public:
|
||||||
TmpIdentityInfo(const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs,
|
TmpIdentityInfo(const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs,
|
||||||
const std::string& name = IDENTITY_INFO)
|
const std::string& name = IDENTITY_INFO)
|
||||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
|
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>()) {}
|
||||||
id_cost_ptr_ = std::make_shared<TmpIdentityCost>();
|
|
||||||
}
|
|
||||||
~TmpIdentityInfo() override = default;
|
~TmpIdentityInfo() override = default;
|
||||||
|
|
||||||
Status Init(const StrategyPtr& strategy) override;
|
Status Init(const StrategyPtr& strategy) override;
|
||||||
|
@ -44,7 +42,6 @@ class TmpIdentityInfo : public OperatorInfo {
|
||||||
|
|
||||||
Status GenerateStrategies(int32_t stage_id) override;
|
Status GenerateStrategies(int32_t stage_id) override;
|
||||||
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
||||||
OperatorCostPtr GetOperatorCost() const override { return id_cost_ptr_; }
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Status CheckStrategy(const StrategyPtr& strategy) override;
|
Status CheckStrategy(const StrategyPtr& strategy) override;
|
||||||
|
@ -54,9 +51,6 @@ class TmpIdentityInfo : public OperatorInfo {
|
||||||
Status InferTensorInfo() override;
|
Status InferTensorInfo() override;
|
||||||
Status InferDevMatrixShape() override;
|
Status InferDevMatrixShape() override;
|
||||||
Status InferTensorMap() override;
|
Status InferTensorMap() override;
|
||||||
|
|
||||||
private:
|
|
||||||
TmpIdentityCostPtr id_cost_ptr_;
|
|
||||||
};
|
};
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -35,15 +35,12 @@ class TransposeInfo : public OperatorInfo {
|
||||||
public:
|
public:
|
||||||
TransposeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
TransposeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
||||||
const PrimitiveAttrs& attrs)
|
const PrimitiveAttrs& attrs)
|
||||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
|
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>()) {}
|
||||||
transpose_cost_ptr_ = std::make_shared<ActivationCost>();
|
|
||||||
}
|
|
||||||
~TransposeInfo() override = default;
|
~TransposeInfo() override = default;
|
||||||
Status Init(const StrategyPtr& strategy) override;
|
Status Init(const StrategyPtr& strategy) override;
|
||||||
Status InitForCostModel(const StrategyPtr& strategy) override;
|
Status InitForCostModel(const StrategyPtr& strategy) override;
|
||||||
Status GenerateStrategies(int32_t stage_id) override;
|
Status GenerateStrategies(int32_t stage_id) override;
|
||||||
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
||||||
OperatorCostPtr GetOperatorCost() const override { return transpose_cost_ptr_; }
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Status CheckStrategy(const StrategyPtr& strategy) override;
|
Status CheckStrategy(const StrategyPtr& strategy) override;
|
||||||
|
@ -60,7 +57,6 @@ class TransposeInfo : public OperatorInfo {
|
||||||
Status ComputeAxis();
|
Status ComputeAxis();
|
||||||
std::vector<int32_t> axis_v_;
|
std::vector<int32_t> axis_v_;
|
||||||
Dimensions input_strategy_;
|
Dimensions input_strategy_;
|
||||||
ActivationCostPtr transpose_cost_ptr_;
|
|
||||||
};
|
};
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -32,16 +32,13 @@ class VirtualDatasetInfo : public OperatorInfo {
|
||||||
public:
|
public:
|
||||||
VirtualDatasetInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
VirtualDatasetInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
||||||
const PrimitiveAttrs& attrs)
|
const PrimitiveAttrs& attrs)
|
||||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
|
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>()) {}
|
||||||
vd_cost_ptr_ = std::make_shared<VirtualDatasetCost>();
|
|
||||||
}
|
|
||||||
~VirtualDatasetInfo() override = default;
|
~VirtualDatasetInfo() override = default;
|
||||||
Status Init(const StrategyPtr& strategy) override;
|
Status Init(const StrategyPtr& strategy) override;
|
||||||
Status InitForCostModel(const StrategyPtr& strategy) override;
|
Status InitForCostModel(const StrategyPtr& strategy) override;
|
||||||
|
|
||||||
Status GenerateStrategies(int32_t stage_id) override;
|
Status GenerateStrategies(int32_t stage_id) override;
|
||||||
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
||||||
OperatorCostPtr GetOperatorCost() const override { return vd_cost_ptr_; }
|
|
||||||
void ReComputeBatchSplitFlagList() override;
|
void ReComputeBatchSplitFlagList() override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -53,9 +50,6 @@ class VirtualDatasetInfo : public OperatorInfo {
|
||||||
Status InferTensorMap() override;
|
Status InferTensorMap() override;
|
||||||
Status GetAttrs() override;
|
Status GetAttrs() override;
|
||||||
Status InferAsLossDivisor() override;
|
Status InferAsLossDivisor() override;
|
||||||
|
|
||||||
private:
|
|
||||||
VirtualDatasetCostPtr vd_cost_ptr_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
|
|
|
@ -84,9 +84,9 @@ TEST_F(TestActivation, test_activation_strategies) {
|
||||||
act_ptr_->InitForCostModel(sp);
|
act_ptr_->InitForCostModel(sp);
|
||||||
std::vector<TensorInfo> inputs_info = act_ptr_->inputs_tensor_info();
|
std::vector<TensorInfo> inputs_info = act_ptr_->inputs_tensor_info();
|
||||||
std::vector<TensorInfo> outputs_info = act_ptr_->outputs_tensor_info();
|
std::vector<TensorInfo> outputs_info = act_ptr_->outputs_tensor_info();
|
||||||
ASSERT_DOUBLE_EQ(act_ptr_->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
|
ASSERT_DOUBLE_EQ(act_ptr_->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
|
||||||
cost.computation_cost_);
|
cost.computation_cost_);
|
||||||
ASSERT_DOUBLE_EQ(act_ptr_->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
|
ASSERT_DOUBLE_EQ(act_ptr_->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
|
||||||
cost.communication_cost_);
|
cost.communication_cost_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -109,9 +109,9 @@ TEST_F(TestActivation, test_softmax_strategies) {
|
||||||
soft_ptr_->InitForCostModel(sp);
|
soft_ptr_->InitForCostModel(sp);
|
||||||
std::vector<TensorInfo> inputs_info = soft_ptr_->inputs_tensor_info();
|
std::vector<TensorInfo> inputs_info = soft_ptr_->inputs_tensor_info();
|
||||||
std::vector<TensorInfo> outputs_info = soft_ptr_->outputs_tensor_info();
|
std::vector<TensorInfo> outputs_info = soft_ptr_->outputs_tensor_info();
|
||||||
ASSERT_DOUBLE_EQ(soft_ptr_->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
|
ASSERT_DOUBLE_EQ(soft_ptr_->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
|
||||||
cost.computation_cost_);
|
cost.computation_cost_);
|
||||||
ASSERT_DOUBLE_EQ(soft_ptr_->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
|
ASSERT_DOUBLE_EQ(soft_ptr_->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
|
||||||
cost.communication_cost_);
|
cost.communication_cost_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -569,7 +569,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies1) {
|
||||||
matmul1->InitForCostModel(sp);
|
matmul1->InitForCostModel(sp);
|
||||||
std::vector<TensorInfo> inputs_info = matmul1->inputs_tensor_info();
|
std::vector<TensorInfo> inputs_info = matmul1->inputs_tensor_info();
|
||||||
std::vector<TensorInfo> outputs_info = matmul1->outputs_tensor_info();
|
std::vector<TensorInfo> outputs_info = matmul1->outputs_tensor_info();
|
||||||
ASSERT_DOUBLE_EQ(matmul1->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
|
ASSERT_DOUBLE_EQ(matmul1->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
|
||||||
cost.computation_cost_);
|
cost.computation_cost_);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -599,7 +599,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies2) {
|
||||||
TensorInfo replica_input1_info(tly, input1_shape, input1_slice_shape);
|
TensorInfo replica_input1_info(tly, input1_shape, input1_slice_shape);
|
||||||
replica_inputs_info.push_back(replica_input1_info);
|
replica_inputs_info.push_back(replica_input1_info);
|
||||||
|
|
||||||
ASSERT_DOUBLE_EQ(matmul3->GetOperatorCost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
|
ASSERT_DOUBLE_EQ(matmul3->cost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
|
||||||
cost.computation_cost_);
|
cost.computation_cost_);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
|
@ -188,11 +188,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies) {
|
||||||
tensor_add->InitForCostModel(sp);
|
tensor_add->InitForCostModel(sp);
|
||||||
std::vector<TensorInfo> inputs_info = tensor_add->inputs_tensor_info();
|
std::vector<TensorInfo> inputs_info = tensor_add->inputs_tensor_info();
|
||||||
std::vector<TensorInfo> outputs_info = tensor_add->outputs_tensor_info();
|
std::vector<TensorInfo> outputs_info = tensor_add->outputs_tensor_info();
|
||||||
double memory_cost0 = tensor_add->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
|
double memory_cost0 = tensor_add->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
|
||||||
double memory_cost1 = cost.computation_cost_;
|
double memory_cost1 = cost.computation_cost_;
|
||||||
bool memory = memory_cost0 - memory_cost1 <= 1.0;
|
bool memory = memory_cost0 - memory_cost1 <= 1.0;
|
||||||
|
|
||||||
double comm_cost0 = tensor_add->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
|
double comm_cost0 = tensor_add->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
|
||||||
double comm_cost1 = cost.communication_cost_;
|
double comm_cost1 = cost.communication_cost_;
|
||||||
bool comm = comm_cost0 - comm_cost1 <= 1.0;
|
bool comm = comm_cost0 - comm_cost1 <= 1.0;
|
||||||
|
|
||||||
|
@ -210,11 +210,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies1) {
|
||||||
tensor_add1->InitForCostModel(sp);
|
tensor_add1->InitForCostModel(sp);
|
||||||
std::vector<TensorInfo> inputs_info = tensor_add1->inputs_tensor_info();
|
std::vector<TensorInfo> inputs_info = tensor_add1->inputs_tensor_info();
|
||||||
std::vector<TensorInfo> outputs_info = tensor_add1->outputs_tensor_info();
|
std::vector<TensorInfo> outputs_info = tensor_add1->outputs_tensor_info();
|
||||||
double memory_cost0 = tensor_add1->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
|
double memory_cost0 = tensor_add1->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
|
||||||
double memory_cost1 = cost.computation_cost_;
|
double memory_cost1 = cost.computation_cost_;
|
||||||
bool memory = memory_cost0 - memory_cost1 <= 1.0;
|
bool memory = memory_cost0 - memory_cost1 <= 1.0;
|
||||||
|
|
||||||
double comm_cost0 = tensor_add1->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
|
double comm_cost0 = tensor_add1->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
|
||||||
double comm_cost1 = cost.communication_cost_;
|
double comm_cost1 = cost.communication_cost_;
|
||||||
bool comm = comm_cost0 - comm_cost1 <= 1.0;
|
bool comm = comm_cost0 - comm_cost1 <= 1.0;
|
||||||
|
|
||||||
|
|
|
@ -145,9 +145,9 @@ TEST_F(TestTmpIdentityInfo, test_generate_strategies) {
|
||||||
identity_ptr->Init(sp);
|
identity_ptr->Init(sp);
|
||||||
std::vector<TensorInfo> inputs_info = identity_ptr->inputs_tensor_info();
|
std::vector<TensorInfo> inputs_info = identity_ptr->inputs_tensor_info();
|
||||||
std::vector<TensorInfo> outputs_info = identity_ptr->outputs_tensor_info();
|
std::vector<TensorInfo> outputs_info = identity_ptr->outputs_tensor_info();
|
||||||
ASSERT_DOUBLE_EQ(identity_ptr->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
|
ASSERT_DOUBLE_EQ(identity_ptr->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
|
||||||
cost.computation_cost_);
|
cost.computation_cost_);
|
||||||
ASSERT_DOUBLE_EQ(identity_ptr->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
|
ASSERT_DOUBLE_EQ(identity_ptr->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
|
||||||
cost.communication_cost_);
|
cost.communication_cost_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue