From b413638f2328288ce3b693ed161dd99f210f58d9 Mon Sep 17 00:00:00 2001
From: c00425699
Date: Thu, 9 Apr 2020 14:40:43 +0800
Subject: [PATCH] refactor OperatorCostPtr in OperatorInfo

---
 .../auto_parallel/operator_costmodel.cc       | 54 -------------------
 .../auto_parallel/operator_costmodel.h        | 30 ++---------
 .../ccsrc/parallel/ops_info/activation_info.h | 18 ++-----
 .../ccsrc/parallel/ops_info/arithmetic_info.h |  6 +--
 .../parallel/ops_info/batch_parallel_info.h   |  6 +--
 .../ccsrc/parallel/ops_info/bias_add_info.h   |  6 +--
 .../parallel/ops_info/dropout_do_mask_info.h  |  8 +--
 .../ccsrc/parallel/ops_info/generator_info.h  |  6 +--
 .../ccsrc/parallel/ops_info/get_next_info.h   |  6 +--
 .../parallel/ops_info/l2_normalize_info.h     |  6 +--
 mindspore/ccsrc/parallel/ops_info/loss_info.h |  6 +--
 .../ccsrc/parallel/ops_info/matmul_info.cc    |  6 +--
 .../ccsrc/parallel/ops_info/matmul_info.h     |  7 +--
 .../ccsrc/parallel/ops_info/onehot_info.h     |  6 +--
 .../ccsrc/parallel/ops_info/operator_info.cc  | 13 +++--
 .../ccsrc/parallel/ops_info/operator_info.h   | 13 +++--
 .../ccsrc/parallel/ops_info/prelu_info.h      |  6 +--
 .../parallel/ops_info/reduce_method_info.cc   |  8 ++-
 .../parallel/ops_info/reduce_method_info.h    |  8 +--
 .../ccsrc/parallel/ops_info/reshape_info.h    |  8 +--
 .../parallel/ops_info/tmp_identity_info.h     |  8 +--
 .../ccsrc/parallel/ops_info/transpose_info.h  |  6 +--
 .../parallel/ops_info/virtual_dataset_info.h  |  8 +--
 .../cpp/parallel/ops_info/activation_test.cc  |  8 +--
 .../cpp/parallel/ops_info/matmul_info_test.cc |  4 +-
 .../parallel/ops_info/tensor_add_info_test.cc |  8 +--
 .../cpp/parallel/ops_info/tmpidentity_test.cc |  4 +-
 27 files changed, 62 insertions(+), 211 deletions(-)

diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
index 7c17b499b11..93d7dc56c5e 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
@@ -514,60 +514,6 @@ double ArithmeticCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs
   return result;
 }
 
-double L2NormalizeCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
-                                            const int32_t& stage_id) const {
-  double result = 0.0;
-  if (is_parameter_[0]) {
-    TensorInfo input_tensor_info = inputs[0];
-    CheckGlobalDeviceManager();
-    MS_EXCEPTION_IF_NULL(g_device_manager);
-    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
-
-    Shape input_shape = input_tensor_info.shape();
-    Shape input_slice_shape = input_tensor_info.slice_shape();
-    int32_t used_device_num = 1;
-    for (size_t i = 0; i < input_shape.size(); ++i) {
-      used_device_num *= input_shape[i] / input_slice_shape[i];
-    }
-
-    if (total_device_num != IntToSize(used_device_num))
-      result += ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
-  }
-
-  return result;
-}
-
-double L2NormalizeCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
-                                                  const int32_t&) const {
-  TensorInfo input0_info = inputs[0];
-  Shape input0_slice_shape = input0_info.slice_shape();
-  return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
-}
-
-double L2NormalizeCost::GetBackwardComputationCost(const std::vector<TensorInfo>& inputs,
-                                                   const std::vector<TensorInfo>&, const int32_t& stage_id) const {
-  double result = 0.0;
-
-  if (is_parameter_[0]) {
-    TensorInfo input_tensor_info = inputs[0];
-    CheckGlobalDeviceManager();
-    MS_EXCEPTION_IF_NULL(g_device_manager);
-    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
-
-    Shape input_shape = input_tensor_info.shape();
-    Shape input_slice_shape = input_tensor_info.slice_shape();
-    int32_t used_device_num = 1;
-    for (size_t i = 0; i < input_shape.size(); ++i) {
-      used_device_num *= input_shape[i] / input_slice_shape[i];
-    }
-
-    if (total_device_num != IntToSize(used_device_num))
-      result += ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
-  }
-
-  return result;
-}
-
 bool IsDataParallel(const Shape& shape, const Shape& slice_shape, const int32_t& stage_id) {
   CheckGlobalDeviceManager();
   MS_EXCEPTION_IF_NULL(g_device_manager);
diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
index 8f0099bba3a..73f3ff139f2 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
@@ -132,6 +132,8 @@ class ActivationCost : public OperatorCost {
 };
 using ActivationCostPtr = std::shared_ptr<ActivationCost>;
+using TransposeCost = ActivationCost;
+using TransposeCostPtr = std::shared_ptr<TransposeCost>;
 
 class SoftmaxCost : public OperatorCost {
  public:
@@ -415,32 +417,8 @@ class ArithmeticCost : public OperatorCost {
                                     const int32_t& stage_id) const override;
 };
 using ArithmeticCostPtr = std::shared_ptr<ArithmeticCost>;
-
-class L2NormalizeCost : public OperatorCost {
- public:
-  L2NormalizeCost() = default;
-  ~L2NormalizeCost() override = default;
-
-  double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
-                     const int32_t& stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardCommCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
-                            const int32_t&) const override {
-    return 0.0;
-  }
-  double GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
-                             const int32_t& stage_id) const override;
-  double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
-                            const int32_t& stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
-  }
-  double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
-                                   const int32_t& stage_id) const override;
-  double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
-                                    const int32_t& stage_id) const override;
-};
-using L2NormalizeCostPtr = std::shared_ptr<L2NormalizeCost>;
+using BiasAddCost = ArithmeticCost;
+using BiasAddCostPtr = std::shared_ptr<BiasAddCost>;
 
 class ReduceMethodCost : public OperatorCost {
  public:
diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.h b/mindspore/ccsrc/parallel/ops_info/activation_info.h
index 183b593e238..21774c43ee2 100644
--- a/mindspore/ccsrc/parallel/ops_info/activation_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/activation_info.h
@@ -32,8 +32,8 @@ namespace parallel {
 class ActivationBase : public OperatorInfo {
  public:
   ActivationBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
-                 const PrimitiveAttrs& attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs) {}
+                 const PrimitiveAttrs& attrs, OperatorCostPtr cost)
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {}
   ~ActivationBase() override = default;
 
   Status Init(const StrategyPtr& strategy) override;
@@ -51,19 +51,13 @@ class Activation : public ActivationBase {
  public:
   Activation(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
              const PrimitiveAttrs& attrs)
-      : ActivationBase(name, inputs_shape, outputs_shape, attrs) {
-    ac_cost_ptr_ = std::make_shared<ActivationCost>();
-  }
+      : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>()) {}
   ~Activation() override = default;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return ac_cost_ptr_; }
 
  protected:
   Status CheckStrategy(const StrategyPtr& strategy) override;
-
- private:
-  ActivationCostPtr ac_cost_ptr_;
 };
 
 class ActivationInfo : public Activation {
@@ -108,13 +102,10 @@ class Softmax : public ActivationBase {
  public:
   explicit Softmax(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                    const PrimitiveAttrs& attrs)
-      : ActivationBase(name, inputs_shape, outputs_shape, attrs) {
-    sm_cost_ptr_ = std::make_shared<SoftmaxCost>();
-  }
+      : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>()) {}
   ~Softmax() override = default;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return sm_cost_ptr_; }
 
  protected:
   Status CheckStrategy(const StrategyPtr& strategy) override;
@@ -122,7 +113,6 @@ class Softmax : public ActivationBase {
 
  private:
   std::vector<int32_t> axis_;
-  SoftmaxCostPtr sm_cost_ptr_;
 };
 
 class SoftmaxInfo : public Softmax {
diff --git a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h
index 7cd0d66b1bc..daa2ad595c5 100644
--- a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h
@@ -33,15 +33,12 @@ class ArithmeticBase : public OperatorInfo {
  public:
   ArithmeticBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                  const PrimitiveAttrs& attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs) {
-    arithmeticcost_ptr_ = std::make_shared<ArithmeticCost>();
-  }
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>()) {}
   ~ArithmeticBase() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t) override;
   Status SetCostUnderStrategy(const StrategyPtr&) override;
-  OperatorCostPtr GetOperatorCost() const override { return arithmeticcost_ptr_; }
   void ReComputeBatchSplitFlagList() override;
 
  protected:
@@ -54,7 +51,6 @@ class ArithmeticBase : public OperatorInfo {
   Status InferTensorMap() override;
   Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout, const Shape& dev_matrix_array);
   Shapes InferExpendShape();
-  ArithmeticCostPtr arithmeticcost_ptr_;
 };
 
 class SubInfo : public ArithmeticBase {
diff --git a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h
index 57711b52983..fae96dcab5b 100644
--- a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h
@@ -31,16 +31,13 @@ class BatchParallelInfo : public OperatorInfo {
  public:
   BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                     const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs), dev_num_(1) {
-    bp_cost_ptr_ = std::make_shared<BatchParallelCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()), dev_num_(1) {}
   ~BatchParallelInfo() override = default;
 
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return bp_cost_ptr_; }
 
  protected:
   Status CheckStrategy(const StrategyPtr& strategy) override;
@@ -55,7 +52,6 @@ class BatchParallelInfo : public OperatorInfo {
 
  private:
   int32_t dev_num_;
-  BatchParallelCostPtr bp_cost_ptr_;
 };
 
 class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
diff --git a/mindspore/ccsrc/parallel/ops_info/bias_add_info.h b/mindspore/ccsrc/parallel/ops_info/bias_add_info.h
index 07f0bc00ffe..dea5c90c88a 100644
--- a/mindspore/ccsrc/parallel/ops_info/bias_add_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/bias_add_info.h
@@ -34,16 +34,13 @@ class BiasAddInfo : public OperatorInfo {
  public:
   BiasAddInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
               const PrimitiveAttrs& attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs) {
-    biasaddcost_ptr_ = std::make_shared<ArithmeticCost>();
-  }
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>()) {}
   ~BiasAddInfo() override = default;
 
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t) override;
   Status SetCostUnderStrategy(const StrategyPtr&) override;
-  OperatorCostPtr GetOperatorCost() const override { return biasaddcost_ptr_; }
   void ReComputeBatchSplitFlagList() override;
 
  protected:
@@ -55,7 +52,6 @@ class BiasAddInfo : public OperatorInfo {
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
   Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout, const Shape& dev_matrix_array);
-  ArithmeticCostPtr biasaddcost_ptr_;
 };
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h
index e43601355a6..859b3e06a41 100644
--- a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h
@@ -33,15 +33,12 @@ class DropoutDoMaskInfo : public OperatorInfo {
  public:
   DropoutDoMaskInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                     const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
-    bpcost_ptr_ = std::make_shared<BatchParallelCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()) {}
   ~DropoutDoMaskInfo() override = default;
 
   Status Init(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return bpcost_ptr_; }
   Status InitForCostModel(const StrategyPtr& strategy) override;
   std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;
 
@@ -53,9 +50,6 @@ class DropoutDoMaskInfo : public OperatorInfo {
   Status GetAttrs() override { return SUCCESS; }
   Status InferTensorInfo() override;
   Status InferDevMatrixShape() override;
-
- private:
-  BatchParallelCostPtr bpcost_ptr_;
 };
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/ops_info/generator_info.h b/mindspore/ccsrc/parallel/ops_info/generator_info.h
index a280fac28e1..68024593f3c 100644
--- a/mindspore/ccsrc/parallel/ops_info/generator_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/generator_info.h
@@ -32,15 +32,12 @@ class GeneratorBase : public OperatorInfo {
  public:
   GeneratorBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs) {
-    generatorbasecost_ptr_ = std::make_shared<GeneratorBaseCost>();
-  }
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GeneratorBaseCost>()) {}
   ~GeneratorBase() override = default;
 
   Status Init(const StrategyPtr &strategy) override;
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return generatorbasecost_ptr_; }
   Status InitForCostModel(const StrategyPtr &strategy) override;
 
  protected:
@@ -52,7 +49,6 @@ class GeneratorBase : public OperatorInfo {
   Status InferMirrorOps() override { return SUCCESS; }
   Status InferForwardCommunication() override { return SUCCESS; }
   virtual Status InferReplaceOps(const StrategyPtr &strategy) = 0;
-  GeneratorBaseCostPtr generatorbasecost_ptr_;
 };
 
 class DropoutGenMaskInfo : public GeneratorBase {
diff --git a/mindspore/ccsrc/parallel/ops_info/get_next_info.h b/mindspore/ccsrc/parallel/ops_info/get_next_info.h
index 32adce11652..9a65eff035c 100644
--- a/mindspore/ccsrc/parallel/ops_info/get_next_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/get_next_info.h
@@ -32,14 +32,11 @@ class GetNextInfo : public OperatorInfo {
  public:
   GetNextInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs) {
-    getnextcost_ptr_ = std::make_shared<GetNextCost>();
-  }
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>()) {}
   ~GetNextInfo() override = default;
 
   Status Init(const StrategyPtr &strategy) override;
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return getnextcost_ptr_; }
   Status InitForCostModel(const StrategyPtr &strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
 
@@ -65,7 +62,6 @@ class GetNextInfo : public OperatorInfo {
   Shapes shapes_;
   int32_t output_num_ = 0;
   std::string shared_name_;
-  GetNextCostPtr getnextcost_ptr_;
 };
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h b/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h
index c0af9dbcb9c..22ed5a965b3 100644
--- a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h
@@ -33,12 +33,9 @@ class L2NormalizeInfo : public Activation {
  public:
   L2NormalizeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                   const PrimitiveAttrs& attrs)
-      : Activation(name, inputs_shape, outputs_shape, attrs) {
-    l2normalizecost_ptr_ = std::make_shared<L2NormalizeCost>();
-  }
+      : Activation(name, inputs_shape, outputs_shape, attrs) {}
   ~L2NormalizeInfo() override = default;
   Status GenerateStrategies(int32_t stage_id) override;
-  OperatorCostPtr GetOperatorCost() const override { return l2normalizecost_ptr_; }
 
  protected:
   Status GetAttrs() override;
@@ -47,7 +44,6 @@ class L2NormalizeInfo : public Activation {
 
  private:
   int32_t axis_ = 0;  // Default value = 0
-  L2NormalizeCostPtr l2normalizecost_ptr_;
 };
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/ops_info/loss_info.h b/mindspore/ccsrc/parallel/ops_info/loss_info.h
index 6a9697a4471..f1c2537a39d 100644
--- a/mindspore/ccsrc/parallel/ops_info/loss_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/loss_info.h
@@ -36,16 +36,13 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
  public:
   SoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                                     const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
-    softmax_loss_cost_ptr_ = std::make_shared<SoftmaxCrossEntropyWithLogitsCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCrossEntropyWithLogitsCost>()) {}
   ~SoftmaxCrossEntropyWithLogitsInfo() override = default;
 
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return softmax_loss_cost_ptr_; }
   void ReComputeBatchSplitFlagList() override;
 
  protected:
@@ -59,7 +56,6 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
   // There are two outputs for SoftmaxCrossEntropyWithLogits, and outputs[1] is used for grad and overload
   // the InferAsLossDivisor.
   Status InferAsLossDivisor() override;
-  SoftmaxCrossEntropyWithLogitsCostPtr softmax_loss_cost_ptr_;
 
  private:
   int32_t axis_ = -1;  // default -1
diff --git a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc b/mindspore/ccsrc/parallel/ops_info/matmul_info.cc
index 2b02dc100d4..848116d68a6 100644
--- a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/matmul_info.cc
@@ -593,11 +593,11 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr&
   // Here, we use the origin outputs_, because we only use the slice size of the output tensor.
   // It does not matter whether the output tensor is transposed or not.
   double computation_cost =
-    matmulcost_ptr->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
-  double communication_cost = matmulcost_ptr->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
+    cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
+  double communication_cost = cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
   std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
   result->communication_without_parameter_ =
-    matmulcost_ptr->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
+    cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
   result->communication_with_partial_para_ =
     result->communication_without_parameter_ +
     COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
diff --git a/mindspore/ccsrc/parallel/ops_info/matmul_info.h b/mindspore/ccsrc/parallel/ops_info/matmul_info.h
index 7ced12b14af..2d3312774df 100644
--- a/mindspore/ccsrc/parallel/ops_info/matmul_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/matmul_info.h
@@ -34,9 +34,7 @@ class MatMulBase : public OperatorInfo {
  public:
   MatMulBase(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
              const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
-    matmulcost_ptr = std::make_shared<MatMulCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>()) {}
   ~MatMulBase() override = default;
 
   Status Init(const StrategyPtr& strategy) override;
@@ -48,7 +46,6 @@ class MatMulBase : public OperatorInfo {
 
   Status PrepareStrategy(int32_t stage_id, size_t dev_num, Dimensions combined_partitions, size_t input0_shape_size,
                          size_t input1_shape_size, StrategyPtr* sp);
-  OperatorCostPtr GetOperatorCost() const override { return matmulcost_ptr; }
   Status SwapLastTwoElements(Shape* shape);
 
  protected:
@@ -66,8 +63,6 @@ class MatMulBase : public OperatorInfo {
   bool transpose_b_ = false;
   size_t mat_a_dimension_ = 0;
   size_t mat_b_dimension_ = 0;
-
-  MatMulCostPtr matmulcost_ptr;
 };
 
 class MatMul : public MatMulBase {
diff --git a/mindspore/ccsrc/parallel/ops_info/onehot_info.h b/mindspore/ccsrc/parallel/ops_info/onehot_info.h
index 4697e201a43..a54d8479b33 100644
--- a/mindspore/ccsrc/parallel/ops_info/onehot_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/onehot_info.h
@@ -33,16 +33,13 @@ class OneHotInfo : public OperatorInfo {
  public:
   OneHotInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
              const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
-    onehot_cost_ptr_ = std::make_shared<OneHotCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>()) {}
   ~OneHotInfo() override = default;
 
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return onehot_cost_ptr_; }
   ReplaceGraphPtr replace_graph(const CNodePtr& cnode) override;
   std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;
 
@@ -60,7 +57,6 @@ class OneHotInfo : public OperatorInfo {
   Status ComputeReplaceGraph(const CNodePtr& cnode);
 
   int axis_ = -1;
-  OneHotCostPtr onehot_cost_ptr_;
   int32_t rank_ = 0;
   int32_t total_class_number_ = 1;
   int32_t classes_each_device_ = 1;
diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/parallel/ops_info/operator_info.cc
index 11c518d844b..a24f3e616b8 100644
--- a/mindspore/ccsrc/parallel/ops_info/operator_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/operator_info.cc
@@ -1034,12 +1034,11 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) {
     return FAILED;
   }
   int32_t stage_id = strategy->GetInputStage();
-  double computation_cost =
-    GetOperatorCost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
-  double communication_cost = GetOperatorCost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+  double computation_cost = cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+  double communication_cost = cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
   std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
   result->communication_without_parameter_ =
-    GetOperatorCost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+    cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
   result->communication_with_partial_para_ =
     result->communication_without_parameter_ +
     COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
@@ -1096,7 +1095,7 @@ Status OperatorInfo::set_is_parameter(const std::vector<bool>& is_parameter) {
     return FAILED;
   }
   is_parameter_ = is_parameter;
-  GetOperatorCost()->set_is_parameter(is_parameter);
+  cost()->set_is_parameter(is_parameter);
   return SUCCESS;
 }
 
@@ -1193,7 +1192,7 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t>& inpu
   }
   inputs_type_lengths_ = input_lengths;
   outputs_type_lengths_ = output_lengths;
-  GetOperatorCost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
+  cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
   return SUCCESS;
 }
 
@@ -1211,7 +1210,7 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr& stra
 }
 
 double OperatorInfo::GetForwardMemoryCostFromCNode() {
-  return GetOperatorCost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
+  return cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
 }
 
 }  // namespace parallel
diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.h b/mindspore/ccsrc/parallel/ops_info/operator_info.h
index e7b8af0a7ed..8fcae8ad330 100644
--- a/mindspore/ccsrc/parallel/ops_info/operator_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/operator_info.h
@@ -53,12 +53,13 @@ class Edge;
 
 class OperatorInfo {
  public:
-  OperatorInfo(std::string name, Shapes inputs_shape, Shapes outputs_shape, PrimitiveAttrs attrs)
+  OperatorInfo(std::string name, Shapes inputs_shape, Shapes outputs_shape, PrimitiveAttrs attrs, OperatorCostPtr cost)
       : name_(std::move(name)),
        inputs_shape_(std::move(inputs_shape)),
        outputs_shape_(std::move(outputs_shape)),
        attrs_(std::move(attrs)),
-        is_alive_(true) {
+        is_alive_(true),
+        cost_(cost) {
     std::vector<bool> not_parameteter(inputs_shape_.size(), false);
     is_parameter_ = not_parameteter;
     refkey_parameter_name_ = "";
@@ -75,7 +76,8 @@ class OperatorInfo {
   // Given the stage_id (which indicates the number of devices),
   // generate all strategies for this operator
   virtual Status GenerateStrategies(int32_t stage_id) = 0;
-  virtual OperatorCostPtr GetOperatorCost() const = 0;
+  const OperatorCostPtr& cost() const { return cost_; }
+  void set_cost(const OperatorCostPtr& cost) { cost_ = cost; }
   virtual Status SetCostUnderStrategy(const StrategyPtr& strategy) = 0;
   virtual std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies();
 
@@ -115,7 +117,7 @@ class OperatorInfo {
   void ReplaceSuccEdge(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
   void ReplacePreEdges(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
   void ReplaceSuccEdges(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
-  std::vector<size_t> GetOutputTypeLengths() const { return GetOperatorCost()->outputs_type_lengths(); }
+  std::vector<size_t> GetOutputTypeLengths() const { return cost()->outputs_type_lengths(); }
   void SetSelectedStrategyAndCost(const StrategyPtr& s_strategy, const CostPtr& cost) {
     selected_strategy_ = s_strategy;
     selected_cost_ = cost;
@@ -221,6 +223,9 @@ class OperatorInfo {
   std::string refkey_parameter_name_;
   CNodePtr cnode_;
   int32_t used_devices_ = -1;
+
+ private:
+  OperatorCostPtr cost_;
 };
 
 Shape GetSliceShape(const Shape& tensor_shape, const Dimensions& strategy);
diff --git a/mindspore/ccsrc/parallel/ops_info/prelu_info.h b/mindspore/ccsrc/parallel/ops_info/prelu_info.h
index d491ecb3319..bdfb11550b8 100644
--- a/mindspore/ccsrc/parallel/ops_info/prelu_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/prelu_info.h
@@ -35,15 +35,12 @@ class PReLUInfo : public OperatorInfo {
  public:
   PReLUInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
             const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
-    prelucost_ptr = std::make_shared<PReLUCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>()) {}
   ~PReLUInfo() override = default;
 
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
-  OperatorCostPtr GetOperatorCost() const override { return prelucost_ptr; }
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
 
  protected:
@@ -59,7 +56,6 @@ class PReLUInfo : public OperatorInfo {
 
  private:
   Dimensions input_strategy_;
-  PReLUCostPtr prelucost_ptr;
 };
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/ops_info/reduce_method_info.cc b/mindspore/ccsrc/parallel/ops_info/reduce_method_info.cc
index 5b07f8d0a96..aa64e72d05e 100644
--- a/mindspore/ccsrc/parallel/ops_info/reduce_method_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/reduce_method_info.cc
@@ -109,8 +109,12 @@ Status ReduceMethod::GetAttrs() {
     }
     cross_batch_ = cross_batch_iter->second->cast<BoolImmPtr>()->value();
   }
-  reducemethodcost_ptr_->set_cross_batch(cross_batch_);
-
+  auto reducemethodcost = std::dynamic_pointer_cast<ReduceMethodCost>(cost());
+  if (reducemethodcost == nullptr) {
+    MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!";
+    return FAILED;
+  }
+  reducemethodcost->set_cross_batch(cross_batch_);
   return SUCCESS;
 }
diff --git a/mindspore/ccsrc/parallel/ops_info/reduce_method_info.h b/mindspore/ccsrc/parallel/ops_info/reduce_method_info.h
index 8e2e17af99e..c2ddbc87ce3 100644
--- a/mindspore/ccsrc/parallel/ops_info/reduce_method_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/reduce_method_info.h
@@ -34,9 +34,7 @@ class ReduceMethod : public OperatorInfo {
  public:
   ReduceMethod(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
-    reducemethodcost_ptr_ = std::make_shared<ReduceMethodCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMethodCost>()) {}
   ~ReduceMethod() override = default;
 
   Status Init(const StrategyPtr &strategy) override;
@@ -44,13 +42,11 @@ class ReduceMethod : public OperatorInfo {
 
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return reducemethodcost_ptr_; }
 
  protected:
   std::string reduce_method_;
   bool keepdims_ = false;
   bool cross_batch_ = false;
-  ReduceMethodCostPtr reducemethodcost_ptr_;
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status GetAttrs() override;
   Dimensions InferOutputStrategy();
@@ -110,7 +106,7 @@ class ReduceMeanInfo : public ReduceMethod {
   ReduceMeanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                  const PrimitiveAttrs &attrs)
       : ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
-    reducemethodcost_ptr_ = std::make_shared<ReduceMeanCost>();
+    set_cost(std::make_shared<ReduceMeanCost>());
   }
 
   ~ReduceMeanInfo() override = default;
diff --git a/mindspore/ccsrc/parallel/ops_info/reshape_info.h b/mindspore/ccsrc/parallel/ops_info/reshape_info.h
index 1d6a14b1f6d..38192a5d017 100644
--- a/mindspore/ccsrc/parallel/ops_info/reshape_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/reshape_info.h
@@ -36,12 +36,10 @@ class ReshapeInfo : public OperatorInfo {
  public:
   ReshapeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
               const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs),
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>()),
         dev_num_(0),
         input_layout_set_flag_(false),
-        output_layout_set_flag_(false) {
-    reshape_cost_ptr_ = std::make_shared<ReshapeCost>();
-  }
+        output_layout_set_flag_(false) {}
   ~ReshapeInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
   void SetInputLayout(const TensorLayout& input_layout) {
@@ -55,7 +53,6 @@ class ReshapeInfo : public OperatorInfo {
   Status InitForCostModel(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return reshape_cost_ptr_; }
 
  protected:
   Status CheckStrategy(const StrategyPtr& strategy) override;
@@ -67,7 +64,6 @@ class ReshapeInfo : public OperatorInfo {
   Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout);
   Status GetAttrs() override;
   Strategys GetOutputsStrategy();
-  ReshapeCostPtr reshape_cost_ptr_;
 
  private:
   Status GetParameterInput();
diff --git a/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h b/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h
index 6df5856e0ca..cf850683a61 100644
--- a/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h
@@ -34,9 +34,7 @@ class TmpIdentityInfo : public OperatorInfo {
  public:
   TmpIdentityInfo(const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs,
                   const std::string& name = IDENTITY_INFO)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
-    id_cost_ptr_ = std::make_shared<TmpIdentityCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>()) {}
   ~TmpIdentityInfo() override = default;
 
   Status Init(const StrategyPtr& strategy) override;
@@ -44,7 +42,6 @@ class TmpIdentityInfo : public OperatorInfo {
 
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return id_cost_ptr_; }
 
  protected:
   Status CheckStrategy(const StrategyPtr& strategy) override;
@@ -54,9 +51,6 @@ class TmpIdentityInfo : public OperatorInfo {
   Status InferTensorInfo() override;
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
-
- private:
-  TmpIdentityCostPtr id_cost_ptr_;
 };
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/ops_info/transpose_info.h b/mindspore/ccsrc/parallel/ops_info/transpose_info.h
index 4f6f6bb6951..2714b352b6f 100644
--- a/mindspore/ccsrc/parallel/ops_info/transpose_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/transpose_info.h
@@ -35,15 +35,12 @@ class TransposeInfo : public OperatorInfo {
  public:
   TransposeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                 const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
-    transpose_cost_ptr_ = std::make_shared<ActivationCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>()) {}
   ~TransposeInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return transpose_cost_ptr_; }
 
  protected:
   Status CheckStrategy(const StrategyPtr& strategy) override;
@@ -60,7 +57,6 @@ class TransposeInfo : public OperatorInfo {
   Status ComputeAxis();
   std::vector<int32_t> axis_v_;
   Dimensions input_strategy_;
-  ActivationCostPtr transpose_cost_ptr_;
 };
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h
index d0278f27d9d..b958adeabee 100644
--- a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h
@@ -32,16 +32,13 @@ class VirtualDatasetInfo : public OperatorInfo {
  public:
   VirtualDatasetInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                      const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
-    vd_cost_ptr_ = std::make_shared<VirtualDatasetCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>()) {}
   ~VirtualDatasetInfo() override = default;
 
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return vd_cost_ptr_; }
   void ReComputeBatchSplitFlagList() override;
 
  protected:
@@ -53,9 +50,6 @@ class VirtualDatasetInfo : public OperatorInfo {
   Status InferTensorMap() override;
   Status GetAttrs() override;
   Status InferAsLossDivisor() override;
-
- private:
-  VirtualDatasetCostPtr vd_cost_ptr_;
 };
 
 }  // namespace parallel
diff --git a/tests/ut/cpp/parallel/ops_info/activation_test.cc b/tests/ut/cpp/parallel/ops_info/activation_test.cc
index 5d18c5372f7..a8f8425ae98 100644
--- a/tests/ut/cpp/parallel/ops_info/activation_test.cc
+++ b/tests/ut/cpp/parallel/ops_info/activation_test.cc
@@ -84,9 +84,9 @@ TEST_F(TestActivation, test_activation_strategies) {
     act_ptr_->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = act_ptr_->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = act_ptr_->outputs_tensor_info();
-    ASSERT_DOUBLE_EQ(act_ptr_->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(act_ptr_->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.computation_cost_);
-    ASSERT_DOUBLE_EQ(act_ptr_->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(act_ptr_->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.communication_cost_);
   }
 }
@@ -109,9 +109,9 @@ TEST_F(TestActivation, test_softmax_strategies) {
     soft_ptr_->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = soft_ptr_->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = soft_ptr_->outputs_tensor_info();
-    ASSERT_DOUBLE_EQ(soft_ptr_->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(soft_ptr_->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.computation_cost_);
-    ASSERT_DOUBLE_EQ(soft_ptr_->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(soft_ptr_->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.communication_cost_);
   }
 }
diff --git a/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc b/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc
index 99ca9f8e0ed..2fece098e81 100644
--- a/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc
+++ b/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc
@@ -569,7 +569,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies1) {
       matmul1->InitForCostModel(sp);
       std::vector<TensorInfo> inputs_info = matmul1->inputs_tensor_info();
       std::vector<TensorInfo> outputs_info = matmul1->outputs_tensor_info();
-      ASSERT_DOUBLE_EQ(matmul1->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+      ASSERT_DOUBLE_EQ(matmul1->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
                        cost.computation_cost_);
       break;
     }
@@ -599,7 +599,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies2) {
       TensorInfo replica_input1_info(tly, input1_shape, input1_slice_shape);
      replica_inputs_info.push_back(replica_input1_info);
 
-      ASSERT_DOUBLE_EQ(matmul3->GetOperatorCost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
+      ASSERT_DOUBLE_EQ(matmul3->cost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
                        cost.computation_cost_);
       break;
     }
diff --git a/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc b/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc
index 6cb9739b1cd..8c956328a77 100644
--- a/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc
+++ b/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc
@@ -188,11 +188,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies) {
     tensor_add->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = tensor_add->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = tensor_add->outputs_tensor_info();
-    double memory_cost0 = tensor_add->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
+    double memory_cost0 = tensor_add->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
     double memory_cost1 = cost.computation_cost_;
     bool memory = memory_cost0 - memory_cost1 <= 1.0;
 
-    double comm_cost0 = tensor_add->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
+    double comm_cost0 = tensor_add->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
     double comm_cost1 = cost.communication_cost_;
     bool comm = comm_cost0 - comm_cost1 <= 1.0;
 
@@ -210,11 +210,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies1) {
     tensor_add1->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = tensor_add1->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = tensor_add1->outputs_tensor_info();
-    double memory_cost0 = tensor_add1->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
+    double memory_cost0 = tensor_add1->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
     double memory_cost1 = cost.computation_cost_;
     bool memory = memory_cost0 - memory_cost1 <= 1.0;
 
-    double comm_cost0 = tensor_add1->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
+    double comm_cost0 = tensor_add1->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
     double comm_cost1 = cost.communication_cost_;
     bool comm = comm_cost0 - comm_cost1 <= 1.0;
 
diff --git a/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc b/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc
index 043746498f1..3971a2b4713 100644
--- a/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc
+++ b/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc
@@ -145,9 +145,9 @@ TEST_F(TestTmpIdentityInfo, test_generate_strategies) {
     identity_ptr->Init(sp);
     std::vector<TensorInfo> inputs_info = identity_ptr->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = identity_ptr->outputs_tensor_info();
-    ASSERT_DOUBLE_EQ(identity_ptr->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(identity_ptr->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.computation_cost_);
-    ASSERT_DOUBLE_EQ(identity_ptr->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(identity_ptr->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.communication_cost_);
   }
 }
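-- 
The sketch below (after the signature delimiter, so it is ignored by git am) condenses the ownership pattern
this refactor adopts. It is illustrative only, with simplified stand-in types rather than the real MindSpore
classes: the cost model moves from per-subclass *_cost_ptr_ members behind a pure-virtual GetOperatorCost()
into a single OperatorCostPtr owned by the base class, injected through the constructor, read through cost(),
and replaced via set_cost() by the few subclasses (e.g. ReduceMeanInfo) that need a different model.

// Minimal sketch of the constructor-injection pattern (hypothetical names, not MindSpore's actual classes).
#include <memory>
#include <utility>

struct OperatorCost {  // stand-in for the real cost-model interface
  virtual ~OperatorCost() = default;
  virtual double GetForwardComputationCost() const { return 0.0; }
};
using OperatorCostPtr = std::shared_ptr<OperatorCost>;

// Duplicate cost classes collapse into aliases, as TransposeCost and BiasAddCost do in this patch.
struct ActivationCost : OperatorCost {};
using TransposeCost = ActivationCost;

class OperatorInfo {
 public:
  explicit OperatorInfo(OperatorCostPtr cost) : cost_(std::move(cost)) {}
  virtual ~OperatorInfo() = default;
  const OperatorCostPtr& cost() const { return cost_; }         // replaces the virtual GetOperatorCost()
  void set_cost(const OperatorCostPtr& cost) { cost_ = cost; }  // for subclasses that swap the model

 private:
  OperatorCostPtr cost_;  // single owner; no per-subclass *_cost_ptr_ members
};

class TransposeInfo : public OperatorInfo {
 public:
  TransposeInfo() : OperatorInfo(std::make_shared<TransposeCost>()) {}
};

// Callers read the cost through the one base accessor:
double Forward(const OperatorInfo& op) { return op.cost()->GetForwardComputationCost(); }

The payoff of this shape: subclasses can no longer forget to implement the getter, and every strategy-costing
path (SetCostUnderStrategyBase, the MatMulBase override, the unit tests above) goes through the same accessor.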