From 3c27c08b4621b24d720c3eba854b9a5686e38957 Mon Sep 17 00:00:00 2001 From: Xiaoda Zhang Date: Mon, 28 Dec 2020 10:06:18 +0800 Subject: [PATCH] change memory cost calculation in auto-parallel --- .../auto_parallel/operator_costmodel.cc | 803 ++++++++++++++++-- .../auto_parallel/operator_costmodel.h | 509 ++++++++--- .../parallel/ops_info/activation_info.h | 44 +- .../parallel/ops_info/arithmetic_info.h | 35 +- .../parallel/ops_info/batch_parallel_info.h | 6 +- .../parallel/ops_info/bias_add_info.h | 2 +- .../parallel/ops_info/broadcast_to_info.h | 2 +- .../ops_info/comparison_function_info.h | 18 +- .../frontend/parallel/ops_info/concat_info.h | 2 +- .../parallel/ops_info/dropout_do_mask_info.h | 2 +- .../ops_info/elementary_function_info.h | 61 +- .../parallel/ops_info/get_next_info.h | 2 +- .../parallel/ops_info/l2_normalize_info.h | 2 +- .../parallel/ops_info/layer_norm_info.h | 2 +- .../frontend/parallel/ops_info/loss_info.h | 3 +- .../frontend/parallel/ops_info/matmul_info.h | 2 +- .../frontend/parallel/ops_info/onehot_info.h | 2 +- .../parallel/ops_info/operator_info.cc | 18 +- .../frontend/parallel/ops_info/pack_info.h | 2 +- .../frontend/parallel/ops_info/prelu_info.h | 2 +- .../frontend/parallel/ops_info/range_info.h | 2 +- .../parallel/ops_info/reduce_method_info.h | 16 +- .../frontend/parallel/ops_info/reluv2_info.h | 2 +- .../frontend/parallel/ops_info/reshape_info.h | 2 +- .../frontend/parallel/ops_info/slice_info.h | 2 +- .../frontend/parallel/ops_info/split_info.h | 2 +- .../parallel/ops_info/strided_slice_info.cc | 2 + .../parallel/ops_info/strided_slice_info.h | 2 +- .../parallel/ops_info/tensordot_info.h | 2 +- .../frontend/parallel/ops_info/tile_info.h | 2 +- .../parallel/ops_info/tmp_identity_info.h | 2 +- .../parallel/ops_info/transpose_info.h | 2 +- .../frontend/parallel/ops_info/unique_info.h | 2 +- .../ops_info/unsorted_segment_op_info.h | 2 +- .../parallel/ops_info/virtual_dataset_info.h | 2 +- .../auto_parallel/operator_costmodel_test.cc | 4 +- 36 files changed, 1255 insertions(+), 312 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc index 7d0d7d81261..4f235ec3171 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc @@ -27,6 +27,7 @@ void OperatorCost::set_is_parameter(const std::vector &is_parameter) { is_ void OperatorCost::set_is_parameter_involve(const std::vector &is_parameter_inv) { is_parameter_involve_ = is_parameter_inv; + is_inputs_should_in_memory_ = std::vector(is_parameter_involve_.size(), false); } void OperatorCost::set_output_parameter_involve(int64_t output_para) { output_parameter_involve_ = output_para; } @@ -41,27 +42,28 @@ void OperatorCost::set_output_critical(int64_t critical) { is_outputs_critical_ double OperatorCost::GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const { + return GetInputMemoryCost(inputs, outputs) + GetOutputMemoryCost(inputs, outputs); +} + +double OperatorCost::GetInputMemoryCost(const std::vector &inputs, const std::vector &) const { double result = 0.0; - if (output_parameter_involve_ == 1) { + for (size_t i = 0; i < inputs.size(); ++i) { + if (is_inputs_should_in_memory_[i]) { + result += ListProduct(inputs[i].slice_shape()) * static_cast(inputs_type_lengths_[i]); + } + } + return result; +} + +double OperatorCost::GetOutputMemoryCost(const std::vector &inputs, + const 
std::vector<TensorInfo> &outputs) const {
+  double result = 0.0;
+  if (is_output_should_in_memory_) {
     // When this operator has multiple outputs, they all contribute to the memory.
     for (size_t i = 0; i < outputs.size(); ++i) {
       result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
     }
-    bool is_any_para_inv =
-      std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; });
-    if (is_any_para_inv) {
-      for (size_t i = 0; i < inputs.size(); ++i) {
-        if (is_parameter_[i]) {
-          result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
-        } else if (inputs_related_ && (!is_parameter_involve_[i])) {
-          // When the inputs of this operator are related, and they are not parameter-involved, then they are included
-          // in the memory cost.
-          result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
-        }
-      }
-    }
   }
-
   return result;
 }
@@ -166,16 +168,43 @@ double MatMulCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inp
   return result;
 }
+// Not taking account of output
+void MatMulCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Taking account of input
+void MatMulCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  if (is_parameter_[0]) {
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+
+  if (is_parameter_[1]) {
+    is_inputs_should_in_memory_[1] = true;
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  } else if (is_parameter_involve_[1]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  }
+}
+
 // Return the per device communication cost in the forward phase.
-double ActivationCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
-                                          int64_t) const {
+double CastCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
   // ReLU is an element-wise operator, thus it does not need communication in the forward phase
   return 0.0;
 }
 // Return the per device communication cost in the backward phase.
-double ActivationCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
-                                           int64_t stage_id) const {
+double CastCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
+                                     int64_t stage_id) const {
   double result = 0.0;
   if (is_parameter_[0]) {
     TensorInfo input1 = inputs[0];
@@ -196,8 +225,8 @@ double ActivationCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs
 // Return the per device computation cost in the forward phase.
The cost is calculated according to the bytes // this operator uses -double ActivationCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, - int64_t) const { +double CastCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, + int64_t) const { TensorInfo input0 = inputs[0]; Shape input0_slice_shape = input0.slice_shape(); return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); @@ -205,11 +234,33 @@ double ActivationCost::GetForwardComputationCost(const std::vector & // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double ActivationCost::GetBackwardComputationCost(const std::vector &, const std::vector &, - int64_t) const { +double CastCost::GetBackwardComputationCost(const std::vector &, const std::vector &, + int64_t) const { return 0.0; } +// Not taking account of output +void CastCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; } + +// Not taking account of input +void CastCost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + is_inputs_should_in_memory_[0] = is_parameter_[0]; +} + +// Taking account of output +void SqrtCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; } + +// Taking account of input +void GeLUCost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + if (is_parameter_[0]) { + is_inputs_should_in_memory_[0] = true; + } else if (is_parameter_involve_[0]) { + if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) { + is_inputs_should_in_memory_[0] = true; + } + } +} + // Return the per device communication cost in the forward phase. double SoftmaxCost::GetForwardCommCost(const std::vector &, const std::vector &, int64_t) const { @@ -259,6 +310,81 @@ double SoftmaxCost::GetBackwardComputationCost(const std::vector &prev_output_in_mem) { + is_inputs_should_in_memory_[0] = is_parameter_[0]; +} + +// Not taking account of output +void PackCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; } + +// Not taking account of input +void PackCost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + is_inputs_should_in_memory_[0] = is_parameter_[0]; +} + +// Not taking account of output +void TileCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; } + +// Taking account of input +void TileCost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + // When calculating 'dx', taking account of 'y' + if (is_parameter_[0]) { + is_inputs_should_in_memory_[0] = true; + if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) { + is_inputs_should_in_memory_[1] = true; + } + } else if (is_parameter_involve_[0]) { + if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) { + is_inputs_should_in_memory_[1] = true; + } + } + + if (!is_inputs_should_in_memory_[1]) { + is_inputs_should_in_memory_[1] = is_parameter_[1]; + } +} + +// Not taking account of output +void BroadcastToCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; } + +void BroadcastToCost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + is_inputs_should_in_memory_[0] = is_parameter_[0]; +} + +// Taking account of input +void ReLU6Cost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + if (is_parameter_[0]) { + is_inputs_should_in_memory_[0] = true; + } else if 
(is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  }
+}
+
+// Taking account of input
+void TransposeCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y'
+  if (is_parameter_[0]) {
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+
+  if (!is_inputs_should_in_memory_[1]) {
+    is_inputs_should_in_memory_[1] = is_parameter_[1];
+  }
+}
+
 // return the per device communication cost in the forward phase.
 double TmpIdentityCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                            int64_t) const {
@@ -288,9 +414,12 @@ double TmpIdentityCost::GetBackwardComputationCost(const std::vector<TensorInfo> &,
-double TmpIdentityCost::GetMemoryCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &) const {
-  return 0.0;
+// Not taking account of output
+void TmpIdentityCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Not taking account of input
+void TmpIdentityCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  is_inputs_should_in_memory_[0] = is_parameter_[0];
 }
 double BatchParallelCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
@@ -334,6 +463,42 @@ double BatchParallelCost::GetBackwardCommCost(const std::vector<TensorInfo> &inp
   return result;
 }
+
+void BatchParallelCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+void BatchParallelCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  if (is_parameter_[0]) {
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+
+  if (is_parameter_[1]) {
+    is_inputs_should_in_memory_[1] = true;
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  } else if (is_parameter_involve_[1]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  }
+}
+
+void SparseSoftmaxCrossEntropyWithLogitsCost::CalculateOutputInMemory() {
+  is_output_should_in_memory_ = is_parameter_involve_[0];
+}
+
+void SparseSoftmaxCrossEntropyWithLogitsCost::CalculateInputsInMemory(
+  const std::map<size_t, bool> &prev_output_in_mem) {
+  is_inputs_should_in_memory_[0] = is_parameter_[0];
+  is_inputs_should_in_memory_[1] = is_parameter_[1];
+}
 // return the per device communication cost in the forward phase.
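
// [Editor's note] Every Calculate*InMemory override in this patch repeats the same membership
// test on 'prev_output_in_mem'. A minimal illustrative sketch of what that test decides follows;
// the helper name is hypothetical and not part of this patch (the real code inlines the
// expression at each call site, and std::map is assumed available via the header's new include):
inline bool InputNotHeldByPrev(const std::map<size_t, bool> &prev_output_in_mem, size_t i) {
  // True when input 'i' is NOT already kept in memory as the previous operator's output,
  // in which case the current operator must count input 'i' against its own memory cost.
  auto iter = prev_output_in_mem.find(i);
  return (iter == prev_output_in_mem.end()) || (!iter->second);
}
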
double PReLUCost::GetForwardCommCost(const std::vector &, const std::vector &, int64_t) const { // prelu does not need communication in the forward phase @@ -401,6 +566,21 @@ double PReLUCost::GetBackwardComputationCost(const std::vector &prev_output_in_mem) { + // When calculating 'dx', taking account of both 'x' and 'y'; + // when calculating 'dy', taking account of both 'x' and 'y' + if (is_parameter_involve_[0] || is_parameter_involve_[1]) { + if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) { + is_inputs_should_in_memory_[0] = true; + } + if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) { + is_inputs_should_in_memory_[1] = true; + } + } +} + // return the per device communication cost in the forward phase. double OneHotCost::GetForwardCommCost(const std::vector &, const std::vector &, int64_t) const { // onehot does not need communication in the forward phase @@ -430,6 +610,17 @@ double OneHotCost::GetBackwardComputationCost(const std::vector &, c return 0.0; } +// Not taking account of output +void OneHotCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; } + +// Not taking account of input +void OneHotCost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + is_inputs_should_in_memory_[0] = is_parameter_[0]; + is_inputs_should_in_memory_[1] = is_parameter_[1]; + is_inputs_should_in_memory_[2] = is_parameter_[2]; + is_inputs_should_in_memory_[3] = is_parameter_[3]; +} + // return the per device communication cost in the forward phase. double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector &, const std::vector &, int64_t) const { @@ -463,6 +654,16 @@ double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std:: return 0.0; } +// Taking account of output +void SoftmaxCrossEntropyWithLogitsCost::CalculateOutputInMemory() { + is_output_should_in_memory_ = is_parameter_involve_[0]; +} + +void SoftmaxCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + is_inputs_should_in_memory_[0] = is_parameter_[0]; + is_inputs_should_in_memory_[1] = is_parameter_[1]; +} + // return the per device communication cost in the forward phase. 
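
// [Editor's note] Once the Calculate* passes have set the in-memory flags, GetInputMemoryCost /
// GetOutputMemoryCost (top of this file) reduce to a flag-guarded sum over slice shapes. The
// sketch below restates that accounting with simplified types; it is an editor's illustration,
// not code from this patch (the real code walks TensorInfo slice shapes through ListProduct,
// and <vector>/<cstdint> are assumed included):
static double MemoryBytesSketch(const std::vector<std::vector<int64_t>> &slice_shapes,
                                const std::vector<size_t> &type_lengths,
                                const std::vector<bool> &should_in_memory) {
  double result = 0.0;
  for (size_t i = 0; i < slice_shapes.size(); ++i) {
    if (!should_in_memory[i]) {
      continue;  // only tensors flagged as resident contribute to the peak
    }
    double numel = 1.0;
    for (int64_t dim : slice_shapes[i]) {
      numel *= static_cast<double>(dim);  // number of elements in the local slice
    }
    result += numel * static_cast<double>(type_lengths[i]);  // bytes = numel * element size
  }
  return result;
}
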
double ReshapeCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int64_t stage_id) const { @@ -524,50 +725,22 @@ double ReshapeCost::GetBackwardComputationCost(const std::vector &inputs, const std::vector &, - int64_t) const { +void ReshapeCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; } + +void ReshapeCost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + is_inputs_should_in_memory_[0] = is_parameter_[0]; + is_inputs_should_in_memory_[1] = is_parameter_[1]; +} + +double SubCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, + int64_t) const { double result; result = ListProduct(inputs[0].slice_shape()) * static_cast(inputs_type_lengths_[0]) + ListProduct(inputs[1].slice_shape()) * static_cast(inputs_type_lengths_[1]); return result; } -double ArithmeticCost::GetBackwardComputationCost(const std::vector &inputs, - const std::vector &, int64_t stage_id) const { - double result = 0.0; - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - - if (is_parameter_[0]) { - TensorInfo input_a_tensor_info = inputs[0]; - Shape input_a_shape = input_a_tensor_info.shape(); - Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); - int64_t used_device_num = 1; - for (size_t i = 0; i < input_a_shape.size(); ++i) { - used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; - } - - if (total_device_num != LongToSize(used_device_num)) - result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); - } - - if (is_parameter_[1]) { - TensorInfo input_b_tensor_info = inputs[1]; - Shape input_b_shape = input_b_tensor_info.shape(); - Shape input_b_slice_shape = input_b_tensor_info.slice_shape(); - int64_t used_device_num = 1; - for (size_t i = 0; i < input_b_shape.size(); ++i) { - used_device_num *= input_b_shape[i] / input_b_slice_shape[i]; - } - - if (total_device_num != LongToSize(used_device_num)) - result += ListProduct(input_b_slice_shape) * static_cast(inputs_type_lengths_[1]); - } - return result; -} - -double ArithmeticCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, +double SubCost::GetBackwardComputationCost(const std::vector &inputs, const std::vector &, int64_t stage_id) const { double result = 0.0; CheckGlobalDeviceManager(); @@ -587,6 +760,41 @@ double ArithmeticCost::GetBackwardCommCost(const std::vector &inputs result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); } + if (is_parameter_[1]) { + TensorInfo input_b_tensor_info = inputs[1]; + Shape input_b_shape = input_b_tensor_info.shape(); + Shape input_b_slice_shape = input_b_tensor_info.slice_shape(); + int64_t used_device_num = 1; + for (size_t i = 0; i < input_b_shape.size(); ++i) { + used_device_num *= input_b_shape[i] / input_b_slice_shape[i]; + } + + if (total_device_num != LongToSize(used_device_num)) + result += ListProduct(input_b_slice_shape) * static_cast(inputs_type_lengths_[1]); + } + return result; +} + +double SubCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int64_t stage_id) const { + double result = 0.0; + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + if (is_parameter_[0]) { + TensorInfo input_a_tensor_info = inputs[0]; + Shape input_a_shape = input_a_tensor_info.shape(); + Shape input_a_slice_shape = 
input_a_tensor_info.slice_shape();
+    int64_t used_device_num = 1;
+    for (size_t i = 0; i < input_a_shape.size(); ++i) {
+      used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
+    }
+
+    if (total_device_num != LongToSize(used_device_num))
+      result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
+  }
+
   if (is_parameter_[1]) {
     TensorInfo input_b_tensor_info = inputs[1];
     Shape input_b_shape = input_b_tensor_info.shape();
@@ -603,6 +811,273 @@ double ArithmeticCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs
   return result;
 }
+// Not taking account of output
+void SubCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Not taking account of input
+void SubCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  is_inputs_should_in_memory_[0] = is_parameter_[0];
+  is_inputs_should_in_memory_[1] = is_parameter_[1];
+}
+
+// Taking account of input
+void MulCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  if (is_parameter_[0]) {
+    // 'x' is parameter, so it should be in memory.
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      // In this case, if 'y' is not computed by the previous operator, then 'y' should be included here.
+      is_inputs_should_in_memory_[1] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+
+  if (is_parameter_[1]) {
+    is_inputs_should_in_memory_[1] = true;
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  } else if (is_parameter_involve_[1]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  }
+}
+
+// Taking account of output
+void DivCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }
+
+// Taking account of input
+void DivCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y'
+  if (is_parameter_[0]) {
+    // 'x' is parameter, so it should be in memory.
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      // In this case, if 'y' is not computed by the previous operator, then 'y' should be included here.
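      // [Editor's note] For z = x / y the gradients are dx = dout / y and dy = -dout * z / y:
      // the backward pass reads 'y' and the cached output z, but never 'x'. That is presumably
      // why DivCost keeps only 'y' here and keeps the output via CalculateOutputInMemory above.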
+ is_inputs_should_in_memory_[1] = true; + } + } else if (is_parameter_involve_[0]) { + if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) { + is_inputs_should_in_memory_[1] = true; + } + } + + // When calculating 'dy', taking account of 'y' + if (is_parameter_involve_[1]) { + if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) { + is_inputs_should_in_memory_[1] = true; + } + } +} + +// Taking account of input +void ModCost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + // When calculating 'dx', not taking account of 'x' and 'y' + is_inputs_should_in_memory_[0] = is_parameter_[0]; + // When calculating 'dy', taking account of 'x' and 'y' + if (is_parameter_involve_[1]) { + if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) { + is_inputs_should_in_memory_[0] = true; + } + if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) { + is_inputs_should_in_memory_[1] = true; + } + } +} + +void PowCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; } + +void PowCost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + // When calculating 'dx', taking account of both 'x' and 'power' + if (is_parameter_involve_[0]) { + if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) { + is_inputs_should_in_memory_[0] = true; + } + if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) { + is_inputs_should_in_memory_[1] = true; + } + } + // When calculating 'dpower', taking account of 'x' + if (is_parameter_[1]) { + is_inputs_should_in_memory_[1] = true; + if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) { + is_inputs_should_in_memory_[0] = true; + } + } else if (is_parameter_involve_[1]) { + if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) { + is_inputs_should_in_memory_[0] = true; + } + } +} + +void AssignCost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + // When calculating 'dx', taking account of 'x' + if (is_parameter_involve_[0]) { + if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) { + is_inputs_should_in_memory_[0] = true; + } + } + // When calculating 'dy', not taking account of 'x' and 'y' + is_inputs_should_in_memory_[1] = is_parameter_[1]; +} + +void SigmoidCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + // When calculating 'dx', taking account of both 'x' and 'y' + if (is_parameter_involve_[0]) { + if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) { + is_inputs_should_in_memory_[0] = true; + } + if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) { + is_inputs_should_in_memory_[1] = true; + } + } + // When calculating 'dy', not taking account of 'x' and 'y' + if (!is_inputs_should_in_memory_[1]) { + is_inputs_should_in_memory_[1] = is_parameter_[1]; + } +} + +void Atan2Cost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + // When calculating 'dx', taking account of both 'x' and 'y'; when calculating 'dy', taking account of both 'x' and + // 'y' + if (is_parameter_involve_[0] || is_parameter_involve_[1]) { + if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) { + 
is_inputs_should_in_memory_[0] = true;
+    }
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+}
+
+void DivNoNanCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }
+
+void DivNoNanCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y'
+  if (is_parameter_[0]) {
+    // 'x' is parameter, so it should be in memory.
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      // In this case, if 'y' is not computed by the previous operator, then 'y' should be included here.
+      is_inputs_should_in_memory_[1] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+
+  // When calculating 'dy', taking account of 'y'
+  if (is_parameter_involve_[1]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+}
+
+void MaximumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of both 'x' and 'y';
+  // when calculating 'dy', taking account of both 'x' and 'y'
+  if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+}
+
+void SliceCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y' and 'z'
+  if (is_parameter_[0]) {
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+    if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
+      is_inputs_should_in_memory_[2] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+    if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
+      is_inputs_should_in_memory_[2] = true;
+    }
+  }
+
+  if (!is_inputs_should_in_memory_[1]) {
+    is_inputs_should_in_memory_[1] = is_parameter_[1];
+  }
+  if (!is_inputs_should_in_memory_[2]) {
+    is_inputs_should_in_memory_[2] = is_parameter_[2];
+  }
+}
+
+void StridedSliceCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y', 'z' and 'w'
+  if (is_parameter_[0]) {
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+    if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
+      is_inputs_should_in_memory_[2] = true;
+    }
+    if ((prev_output_in_mem.find(3) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(3))) {
+      is_inputs_should_in_memory_[3] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end())
|| (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+    if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
+      is_inputs_should_in_memory_[2] = true;
+    }
+    if ((prev_output_in_mem.find(3) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(3))) {
+      is_inputs_should_in_memory_[3] = true;
+    }
+  }
+
+  if (!is_inputs_should_in_memory_[1]) {
+    is_inputs_should_in_memory_[1] = is_parameter_[1];
+  }
+  if (!is_inputs_should_in_memory_[2]) {
+    is_inputs_should_in_memory_[2] = is_parameter_[2];
+  }
+  if (!is_inputs_should_in_memory_[3]) {
+    is_inputs_should_in_memory_[3] = is_parameter_[3];
+  }
+}
+
+void DropOutDoMaskCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+void DropOutDoMaskCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y'
+  if (is_parameter_[0]) {
+    // 'x' is parameter, so it should be in memory.
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      // In this case, if 'y' is not computed by the previous operator, then 'y' should be included here.
+      is_inputs_should_in_memory_[1] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+
+  if (!is_inputs_should_in_memory_[1]) {
+    is_inputs_should_in_memory_[1] = is_parameter_[1];
+  }
+  is_inputs_should_in_memory_[2] = is_parameter_[2];
+}
+
 bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int64_t stage_id) {
   CheckGlobalDeviceManager();
   MS_EXCEPTION_IF_NULL(g_device_manager);
@@ -612,8 +1087,8 @@ bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int64_t stage_
   return (total_device_num == LongToSize(strategy0));
 }
-double ReduceMethodCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
-                                            const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
+double ReduceSumCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                                         int64_t stage_id) const {
   double result = 0.0;
   TensorInfo input0 = inputs[0];
   TensorInfo output0 = outputs[0];
@@ -634,8 +1109,8 @@ double ReduceMethodCost::GetForwardCommCost(const std::vector<TensorInfo> &input
   return result;
 }
-double ReduceMethodCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
-                                             int64_t stage_id) const {
+double ReduceSumCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
+                                          int64_t stage_id) const {
   double result = 0.0;
   if (is_parameter_[0]) {
     TensorInfo input_tensor_info = inputs[0];
@@ -657,8 +1132,8 @@ double ReduceMethodCost::GetBackwardCommCost(const std::vector<TensorInfo> &inpu
   return result;
 }
-double ReduceMethodCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
-                                                   const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
+double ReduceSumCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
+                                                const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
   double result = 0.0;
   TensorInfo input0 = inputs[0];
   TensorInfo output0 = outputs[0];
@@ -679,6 +1154,30 @@ double ReduceMethodCost::GetForwardComputationCost(const std::vector<TensorInfo>
   return result;
 }
+// Not taking account of output
+void ReduceSumCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+void ReduceSumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y'
+  if (is_parameter_[0]) {
+    // 'x' is
parameter, so it should be in memory.
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      // In this case, if 'y' is not computed by the previous operator, then 'y' should be included here.
+      is_inputs_should_in_memory_[1] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+
+  // Not taking account of 'y'
+  if (!is_inputs_should_in_memory_[1]) {
+    is_inputs_should_in_memory_[1] = is_parameter_[1];
+  }
+}
+
 double ReduceMeanCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                                  const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
   double result = 0.0;
@@ -701,6 +1200,42 @@ double ReduceMeanCost::GetForwardComputationCost(const std::vector<TensorInfo> &
   return result;
 }
+void ReduceMinCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
+
+void ReduceMinCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y'
+  if (is_parameter_[0]) {
+    // 'x' is parameter, so it should be in memory.
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      // In this case, if 'y' is not computed by the previous operator, then 'y' should be included here.
+      is_inputs_should_in_memory_[1] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+
+  // Not taking account of 'y'
+  if (!is_inputs_should_in_memory_[1]) {
+    is_inputs_should_in_memory_[1] = is_parameter_[1];
+  }
+}
+
+void ArgMaxWithValueCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
+
+void ArgMaxWithValueCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'x'
+  if (is_parameter_[0]) {
+    is_inputs_should_in_memory_[0] = true;
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  }
+}
+
 double DropOutCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                               int64_t) const {
   if (inputs.empty()) {
@@ -760,6 +1295,52 @@ double GatherV2Cost::GetBackwardComputationCost(const std::vector<TensorInfo> &,
   return 0.0;
 }
+// Not taking account of output
+void GatherV2Cost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+void GatherV2Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y' and 'z'
+  if (is_parameter_[0]) {
+    // 'x' is parameter, so it should be in memory.
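    // [Editor's note] For Gather, dx scatters 'dout' back through the indices 'y' along axis
    // 'z'; the backward pass therefore needs 'y' and 'z' but not the values of 'x', which is
    // presumably why only inputs 1 and 2 are pulled into memory below when the previous
    // operator does not already hold them.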
+ is_inputs_should_in_memory_[0] = true; + if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) { + is_inputs_should_in_memory_[1] = true; + } + if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) { + is_inputs_should_in_memory_[2] = true; + } + } else if (is_parameter_involve_[0]) { + if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) { + is_inputs_should_in_memory_[1] = true; + } + if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) { + is_inputs_should_in_memory_[2] = true; + } + } + + if (!is_inputs_should_in_memory_[1]) { + is_inputs_should_in_memory_[1] = is_parameter_[1]; + } + if (!is_inputs_should_in_memory_[2]) { + is_inputs_should_in_memory_[2] = is_parameter_[2]; + } +} + +void GetNextCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; } + +void GetNextCost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + if (is_inputs_should_in_memory_.size() == 0) { + return; + } + is_inputs_should_in_memory_[0] = is_parameter_[0]; +} + +void UniqueCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; } + +void UniqueCost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + is_inputs_should_in_memory_[0] = is_parameter_[0]; +} + double LayerNormCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int64_t stage_id) const { double result = 0.0; @@ -808,6 +1389,24 @@ double LayerNormCost::GetForwardComputationCost(const std::vector &i return result; } +void LayerNormCost::CalculateOutputInMemory() { + is_output_should_in_memory_ = is_parameter_involve_[0] || is_parameter_involve_[1] || is_parameter_involve_[2]; +} + +void LayerNormCost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + // When calculating 'dx', taking account of both 'x' and 'y' + // When calculating 'dy', taking account of both 'x' and 'y' + if (is_parameter_involve_[0] || is_parameter_involve_[1]) { + if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) { + is_inputs_should_in_memory_[0] = true; + } + if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) { + is_inputs_should_in_memory_[1] = true; + } + } + is_inputs_should_in_memory_[2] = is_parameter_[2]; +} + double UniqueCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int64_t stage_id) const { return 0.0; @@ -924,6 +1523,12 @@ double UniformCandidateSamplerCost::GetForwardComputationCost(const std::vector< return result; } +void UniformCandidateSamplerCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; } + +void UniformCandidateSamplerCost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + is_inputs_should_in_memory_[0] = is_parameter_[0]; +} + double GatherV2PCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int64_t stage_id) const { double result = 0.0; @@ -1019,6 +1624,29 @@ double UnsortedSegmentSumCost::GetForwardComputationCost(const std::vector &prev_output_in_mem) { + // When calculating 'dx', taking account of 'y' + if (is_parameter_[0]) { + is_inputs_should_in_memory_[0] = true; + if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) { + is_inputs_should_in_memory_[1] = true; + } + } else if (is_parameter_involve_[0]) { + if ((prev_output_in_mem.find(1) == 
prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) { + is_inputs_should_in_memory_[1] = true; + } + } + + if (!is_inputs_should_in_memory_[1]) { + is_inputs_should_in_memory_[1] = is_parameter_[1]; + } + is_inputs_should_in_memory_[2] = is_parameter_[2]; +} + double UnsortedSegmentMinCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int64_t stage_id) const { TensorInfo input0 = inputs[0]; @@ -1078,5 +1706,40 @@ double UnsortedSegmentMinCost::GetForwardComputationCost(const std::vector(outputs_type_lengths_[0]); // ReduceMin return result; } + +// Taking account of output +void UnsortedSegmentMinCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; } + +// Taking account of input +void UnsortedSegmentMinCost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + // When calculating 'dx', taking account of 'x', 'y' and 'z' + if (is_parameter_involve_[0]) { + if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) { + is_inputs_should_in_memory_[0] = true; + } + if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) { + is_inputs_should_in_memory_[1] = true; + } + if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) { + is_inputs_should_in_memory_[2] = true; + } + } + if (!is_inputs_should_in_memory_[1]) { + is_inputs_should_in_memory_[1] = is_parameter_[1]; + } + if (!is_inputs_should_in_memory_[2]) { + is_inputs_should_in_memory_[2] = is_parameter_[2]; + } +} + +// Not taking account of output +void VirtualDatasetCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; } + +// Not taking account of input +void VirtualDatasetCost::CalculateInputsInMemory(const std::map &prev_output_in_mem) { + for (size_t i = 0; i < is_inputs_should_in_memory_.size(); ++i) { + is_inputs_should_in_memory_[i] = is_parameter_[i]; + } +} } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h index 352145d46b6..bf68b8dc5bc 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h @@ -19,6 +19,7 @@ #include #include +#include #include "frontend/parallel/device_manager.h" #include "frontend/parallel/tensor_layout/tensor_info.h" @@ -47,16 +48,7 @@ double ListProduct(std::vector vec) { // entries timing the length of each entry's data type class OperatorCost { public: - explicit OperatorCost(bool is_inputs_related) : inputs_related_(is_inputs_related) { - // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked - for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) { - is_parameter_.push_back(false); - is_parameter_involve_.push_back(false); - inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH); - outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH); - } - } - OperatorCost() : inputs_related_(false) { + OperatorCost() { // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) { is_parameter_.push_back(false); @@ -89,10 +81,17 @@ class OperatorCost { const std::vector &outputs, int64_t stage_id) const = 0; virtual double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int64_t stage_id) 
const = 0;
+  virtual void CalculateOutputInMemory() = 0;
+  virtual void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) = 0;
+  bool is_output_in_memory() const { return is_output_should_in_memory_; }
   // per device PEAK memory cost in a training iteration
-  // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-invovled),
+  // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-involved),
   // plus necessary inputs.
   virtual double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
+  // Contributing the input part for 'GetMemoryCost'
+  double GetInputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
+  // Contributing the output part for 'GetMemoryCost'
+  double GetOutputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
   // per device memory cost in an inference phase
   double GetMemoryCostForInference(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &) const;
@@ -101,25 +100,25 @@ class OperatorCost {
   // pre-operator that has parameters as input.
   std::vector<bool> is_parameter_involve_;
   int64_t output_parameter_involve_ = -1;  // -1: unset; 0: not parameter_involved; 1: parameter_involved
-  // Whether the inputs are related or not? For example, TensorAdd's two inputs are independent (not related), while
-  // Mul's two inputs are dependent (related).
-  bool inputs_related_;
-  // for each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter
+  // For each input in 'inputs_', there is a bool variable indicating whether the corresponding input is a parameter
   std::vector<bool> is_parameter_;
-  // for each input and output, the followings record the number of bytes of each element
+  // Whether the input should be kept in memory in the training phase. It depends on the operator and the operator's
+  // previous operators.
+  std::vector<bool> is_inputs_should_in_memory_;
+  // Whether the output should be kept in memory in the training phase. It depends on 'is_parameter_involve_' and the
+  // operator.
+  bool is_output_should_in_memory_ = false;
+  // For each input and output, the following vectors record the number of bytes of each element
   std::vector<size_t> inputs_type_lengths_;
   std::vector<size_t> outputs_type_lengths_;
   // Whether the output is critical, which means that this output is included in calculating peak memory cost
   // in the inference phase.
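  // [Editor's note] Like 'output_parameter_involve_' above, this presumably encodes a
  // tri-state: -1 unset, 0 not critical, 1 critical; it is written by set_output_critical()
  // in operator_costmodel.cc.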
int64_t is_outputs_critical_ = -1; }; - using OperatorCostPtr = std::shared_ptr; class MatMulCost : public OperatorCost { public: - explicit MatMulCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - MatMulCost() : OperatorCost(true) {} + MatMulCost() : OperatorCost() {} ~MatMulCost() override = default; // per device communication cost @@ -141,14 +140,15 @@ class MatMulCost : public OperatorCost { int64_t stage_id) const override; double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int64_t stage_id) const override; + void CalculateOutputInMemory() override; + void CalculateInputsInMemory(const std::map &prev_output_in_mem) override; }; -using MatMulCostPtr = std::shared_ptr; +using TensorDotCost = MatMulCost; -class ActivationCost : public OperatorCost { +class CastCost : public OperatorCost { public: - explicit ActivationCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - ActivationCost() : OperatorCost(false) {} - ~ActivationCost() override = default; + CastCost() : OperatorCost() {} + ~CastCost() override = default; double GetCommCost(const std::vector &inputs, const std::vector &outputs, int64_t stage_id) const override { @@ -166,21 +166,95 @@ class ActivationCost : public OperatorCost { int64_t stage_id) const override; double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int64_t stage_id) const override; + // Not taking account of output + void CalculateOutputInMemory() override; + // Not Taking account of input + void CalculateInputsInMemory(const std::map &prev_output_in_mem) override; }; -using ActivationCostPtr = std::shared_ptr; -using TransposeCost = ActivationCost; -using TransposeCostPtr = std::shared_ptr; -using StridedSliceCost = ActivationCost; -using StridedSliceCostPtr = std::shared_ptr; -using SliceCost = ActivationCost; -using SliceCostPtr = std::shared_ptr; -using SplitCost = ActivationCost; -using SplitCostPtr = std::shared_ptr; +using RepeatElementsCost = CastCost; +using NegCost = CastCost; +using ExpandDimsCost = CastCost; +using SqueezeCost = CastCost; +using ConcatCost = CastCost; +using LogicalNotCost = CastCost; +using SignCost = CastCost; +using FloorCost = CastCost; +using RoundCost = CastCost; +using CeilCost = CastCost; +using ZerosLikeCost = CastCost; +using OnesLikeCost = CastCost; +using RangeCost = CastCost; +using SplitCost = CastCost; + +class SqrtCost : public CastCost { + public: + SqrtCost() : CastCost() {} + ~SqrtCost() override = default; + // Taking account of output, not taking accounting of input + void CalculateOutputInMemory() override; +}; +using TanhCost = SqrtCost; +using EluCost = SqrtCost; +using ReLUCost = SqrtCost; +using SigmoidCost = SqrtCost; +using ReciprocalCost = + SqrtCost; // The derivative of 'Reciprocal' is different on 'Ascend' and 'GPU'. 
Here, 'Ascend' is chosen +using InvCost = SqrtCost; +using RsqrtCost = SqrtCost; +using AsinhCost = SqrtCost; +using AcoshCost = SqrtCost; +using ReLUV2Cost = SqrtCost; + +class ReLU6Cost : public CastCost { + public: + ReLU6Cost() : CastCost() {} + ~ReLU6Cost() override = default; + // Taking account of input, not taking account of output + void CalculateInputsInMemory(const std::map &prev_output_in_mem) override; +}; +using SoftsignCost = ReLU6Cost; +using SoftplusCost = ReLU6Cost; +using SquareCost = ReLU6Cost; +using ExpCost = ReLU6Cost; +using LogCost = ReLU6Cost; +using CosCost = ReLU6Cost; +using ACosCost = ReLU6Cost; +using AbsCost = ReLU6Cost; +using TanCost = ReLU6Cost; +using SinCost = ReLU6Cost; +using SinhCost = ReLU6Cost; +using Log1pCost = ReLU6Cost; +using Expm1Cost = ReLU6Cost; +using CoshCost = ReLU6Cost; +using AtanhCost = ReLU6Cost; +using AtanCost = ReLU6Cost; +using AsinCost = ReLU6Cost; +using ErfCost = ReLU6Cost; +using ErfcCost = ReLU6Cost; +using ActivationInfoCost = ReLU6Cost; + +class TransposeCost : public CastCost { + public: + TransposeCost() : CastCost() {} + ~TransposeCost() override = default; + // Taking account of input, not taking account of output + void CalculateInputsInMemory(const std::map &prev_output_in_mem) override; +}; + +class GeLUCost : public SqrtCost { + public: + GeLUCost() : SqrtCost() {} + ~GeLUCost() override = default; + // Taking account of input and output + void CalculateInputsInMemory(const std::map &prev_output_in_mem) override; +}; +using BesselI0eCost = GeLUCost; +using BesselI1eCost = GeLUCost; +using L2NormalizeCost = GeLUCost; class SoftmaxCost : public OperatorCost { public: - explicit SoftmaxCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - SoftmaxCost() : OperatorCost(false) {} + SoftmaxCost() : OperatorCost() {} ~SoftmaxCost() override = default; double GetCommCost(const std::vector &inputs, const std::vector &outputs, @@ -199,21 +273,45 @@ class SoftmaxCost : public OperatorCost { int64_t stage_id) const override; double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int64_t) const override; + // Taking account of output + void CalculateOutputInMemory() override; + // Not Taking account of input + void CalculateInputsInMemory(const std::map &prev_output_in_mem) override; +}; + +class TileCost : public SoftmaxCost { + public: + TileCost() : SoftmaxCost() {} + ~TileCost() override = default; + // Not taking account of output + void CalculateOutputInMemory() override; + // Taking account of input + void CalculateInputsInMemory(const std::map &prev_output_in_mem) override; +}; + +class PackCost : public SoftmaxCost { + public: + PackCost() : SoftmaxCost() {} + ~PackCost() override = default; + // Not taking account of output + void CalculateOutputInMemory() override; + // Not taking account of input + void CalculateInputsInMemory(const std::map &prev_output_in_mem) override; +}; + +class BroadcastToCost : public SoftmaxCost { + public: + BroadcastToCost() : SoftmaxCost() {} + ~BroadcastToCost() override = default; + // Not taking account of output + void CalculateOutputInMemory() override; + // Not Taking account of input + void CalculateInputsInMemory(const std::map &prev_output_in_mem) override; }; -using SoftmaxCostPtr = std::shared_ptr; -using TileCost = SoftmaxCost; -using TileCostPtr = std::shared_ptr; -using PackCost = TileCost; -using PackCostPtr = std::shared_ptr; -using ConcatCost = TileCost; -using ConcatCostPtr = std::shared_ptr; -using BroadcastToCost = 
SoftmaxCost; -using BroadcastToCostPtr = std::shared_ptr; class TmpIdentityCost : public OperatorCost { public: - explicit TmpIdentityCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - TmpIdentityCost() : OperatorCost(false) {} + TmpIdentityCost() : OperatorCost() {} ~TmpIdentityCost() override = default; double GetCommCost(const std::vector &inputs, const std::vector &outputs, @@ -232,15 +330,16 @@ class TmpIdentityCost : public OperatorCost { int64_t stage_id) const override; double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int64_t stage_id) const override; - // per device PEAK memory cost in a training iteration - double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const override; + // Not taking account of output + void CalculateOutputInMemory() override; + // Not taking account of input + void CalculateInputsInMemory(const std::map &prev_output_in_mem) override; }; using TmpIdentityCostPtr = std::shared_ptr; class BatchParallelCost : public OperatorCost { public: - explicit BatchParallelCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - BatchParallelCost() : OperatorCost(false) {} + BatchParallelCost() : OperatorCost() {} ~BatchParallelCost() override = default; double GetCommCost(const std::vector &inputs, const std::vector &outputs, @@ -259,13 +358,25 @@ class BatchParallelCost : public OperatorCost { int64_t stage_id) const override; double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int64_t stage_id) const override; + // Not taking account of output + void CalculateOutputInMemory() override; + // Taking account of input + void CalculateInputsInMemory(const std::map &prev_output_in_mem) override; +}; + +class SparseSoftmaxCrossEntropyWithLogitsCost : public BatchParallelCost { + public: + SparseSoftmaxCrossEntropyWithLogitsCost() : BatchParallelCost() {} + ~SparseSoftmaxCrossEntropyWithLogitsCost() override = default; + // Taking account of output + void CalculateOutputInMemory() override; + // Not taking account of input + void CalculateInputsInMemory(const std::map &prev_output_in_mem) override; }; -using BatchParallelCostPtr = std::shared_ptr; class VirtualDatasetCost : public OperatorCost { public: - explicit VirtualDatasetCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - VirtualDatasetCost() : OperatorCost(false) {} + VirtualDatasetCost() : OperatorCost() {} ~VirtualDatasetCost() override = default; double GetCommCost(const std::vector &inputs, const std::vector &outputs, @@ -290,17 +401,15 @@ class VirtualDatasetCost : public OperatorCost { int64_t) const override { return 0.0; } - // per device PEAK memory cost in a training iteration - double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const override { - return 0.0; - } + // Not taking account of output + void CalculateOutputInMemory() override; + // Not taking account of input + void CalculateInputsInMemory(const std::map &prev_output_in_mem) override; }; -using VirtualDatasetCostPtr = std::shared_ptr; class GeneratorBaseCost : public OperatorCost { public: - explicit GeneratorBaseCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - GeneratorBaseCost() : OperatorCost(false) {} + GeneratorBaseCost() : OperatorCost() {} ~GeneratorBaseCost() override = default; double GetCommCost(const std::vector &inputs, const std::vector &outputs, @@ -332,8 +441,7 @@ using GeneratorBaseCostPtr = std::shared_ptr; class PReLUCost : public OperatorCost 
{ public: - explicit PReLUCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - PReLUCost() : OperatorCost(true) {} + PReLUCost() : OperatorCost() {} ~PReLUCost() override = default; // per device communication cost @@ -355,13 +463,16 @@ class PReLUCost : public OperatorCost { int64_t stage_id) const override; double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int64_t stage_id) const override; + // Not taking account of output + void CalculateOutputInMemory() override; + // Taking account of input + void CalculateInputsInMemory(const std::map &prev_output_in_mem) override; }; using PReLUCostPtr = std::shared_ptr; class OneHotCost : public OperatorCost { public: - explicit OneHotCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - OneHotCost() : OperatorCost(true) {} + OneHotCost() : OperatorCost() {} ~OneHotCost() override = default; // per device communication cost @@ -383,13 +494,16 @@ class OneHotCost : public OperatorCost { int64_t stage_id) const override; double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int64_t stage_id) const override; + // Not taking account of output + void CalculateOutputInMemory() override; + // Not taking account of input + void CalculateInputsInMemory(const std::map &prev_output_in_mem) override; }; using OneHotCostPtr = std::shared_ptr; class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost { public: - explicit SoftmaxCrossEntropyWithLogitsCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - SoftmaxCrossEntropyWithLogitsCost() : OperatorCost(false) {} + SoftmaxCrossEntropyWithLogitsCost() : OperatorCost() {} ~SoftmaxCrossEntropyWithLogitsCost() override = default; // per device communication cost @@ -411,13 +525,15 @@ class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost { int64_t stage_id) const override; double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int64_t stage_id) const override; + // Taking account of output + void CalculateOutputInMemory() override; + // Not taking account of input + void CalculateInputsInMemory(const std::map &prev_output_in_mem) override; }; -using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr; class ReshapeCost : public OperatorCost { public: - explicit ReshapeCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - ReshapeCost() : OperatorCost(true) {} + ReshapeCost() : OperatorCost() {} ~ReshapeCost() override = default; @@ -444,14 +560,17 @@ class ReshapeCost : public OperatorCost { double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int64_t stage_id) const override; + // Not taking account of output + void CalculateOutputInMemory() override; + // Not taking account of input + void CalculateInputsInMemory(const std::map &prev_output_in_mem) override; }; using ReshapeCostPtr = std::shared_ptr; -class ArithmeticCost : public OperatorCost { +class SubCost : public OperatorCost { public: - explicit ArithmeticCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - ArithmeticCost() : OperatorCost(false) {} - ~ArithmeticCost() override = default; + SubCost() : OperatorCost() {} + ~SubCost() override = default; double GetCommCost(const std::vector &inputs, const std::vector &outputs, int64_t stage_id) const override { @@ -470,16 +589,127 @@ class ArithmeticCost : public OperatorCost { int64_t stage_id) const override; double GetBackwardComputationCost(const std::vector &inputs, const 
                                     int64_t stage_id) const override;
+  // Not taking account of output
+  void CalculateOutputInMemory() override;
+  // Not taking account of input
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
 };
-using ArithmeticCostPtr = std::shared_ptr<ArithmeticCost>;
-using BiasAddCost = ArithmeticCost;
-using BiasAddCostPtr = std::shared_ptr<BiasAddCost>;
+using TensorAddCost = SubCost;
+using FloorDivCost = SubCost;
+using AssignSubCost = SubCost;
+using AssignAddCost = SubCost;
+using LogicalAndCost = SubCost;
+using LogicalOrCost = SubCost;
+using BiasAddCost = SubCost;
+using EqualCost = SubCost;
+using ApproximateEqualCost = SubCost;
+using NotEqualCost = SubCost;
+using GreaterCost = SubCost;
+using GreaterEqualCost = SubCost;
+using LessCost = SubCost;
+using LessEqualCost = SubCost;
 
-class ReduceMethodCost : public OperatorCost {
+class MulCost : public SubCost {
  public:
-  explicit ReduceMethodCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  ReduceMethodCost() : OperatorCost(true) {}
-  ~ReduceMethodCost() override = default;
+  MulCost() : SubCost() {}
+  ~MulCost() override = default;
+  // Taking account of input, not taking account of output
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
+};
+
+class DivCost : public SubCost {
  public:
+  DivCost() : SubCost() {}
+  ~DivCost() override = default;
+  // Taking account of output
+  void CalculateOutputInMemory() override;
+  // Taking account of input
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
+};
+using ReadDivCost = DivCost;
+
+class ModCost : public SubCost {
  public:
+  ModCost() : SubCost() {}
+  ~ModCost() override = default;
+  // Taking account of input, not taking account of output
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
+};
+using FloorModCost = ModCost;
+
+class PowCost : public SubCost {
  public:
+  PowCost() : SubCost() {}
+  ~PowCost() override = default;
+  // Taking account of output
+  void CalculateOutputInMemory() override;
+  // Taking account of input
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
+};
+
+class AssignCost : public SubCost {
  public:
+  AssignCost() : SubCost() {}
+  ~AssignCost() override = default;
+  // Taking account of input, not taking account of output
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
+};
+
+class SigmoidCrossEntropyWithLogitsCost : public SubCost {
  public:
+  SigmoidCrossEntropyWithLogitsCost() : SubCost() {}
+  ~SigmoidCrossEntropyWithLogitsCost() override = default;
+  // Taking account of input, not taking account of output
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
+};
+
+class Atan2Cost : public SubCost {
  public:
+  Atan2Cost() : SubCost() {}
+  ~Atan2Cost() override = default;
+  // Taking account of input, not taking account of output
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
+};
+
+class DivNoNanCost : public SubCost {
  public:
+  DivNoNanCost() : SubCost() {}
+  ~DivNoNanCost() override = default;
+  // Taking account of output
+  void CalculateOutputInMemory() override;
+  // Taking account of input
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
+};
+
+class MaximumCost : public SubCost {
  public:
+  MaximumCost() : SubCost() {}
+  ~MaximumCost() override = default;
+  // Taking account of input, not taking account of output
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
+};
+using MinimumCost = MaximumCost;
+
+class SliceCost : public CastCost {
  public:
+  SliceCost() : CastCost() {}
+  ~SliceCost() override = default;
+  // Not taking account of output, taking account of input
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
+};
+
+class StridedSliceCost : public CastCost {
  public:
+  StridedSliceCost() : CastCost() {}
+  ~StridedSliceCost() override = default;
+  // Not taking account of output, taking account of input
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
+};
+
+class ReduceSumCost : public OperatorCost {
  public:
+  ReduceSumCost() : OperatorCost() {}
+  ~ReduceSumCost() override = default;
 
   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                      int64_t stage_id) const override {
@@ -500,27 +730,50 @@ class ReduceMethodCost : public OperatorCost {
     return 0.0;
   }
   void set_cross_batch(bool cb) { cross_batch_ = cb; }
+  // Not taking account of output
+  void CalculateOutputInMemory() override;
+  // Taking account of input
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
 
  protected:
   bool cross_batch_ = false;
 };
-using ReduceMethodCostPtr = std::shared_ptr<ReduceMethodCost>;
+using ReduceMethodCost = ReduceSumCost;
 
-class ReduceMeanCost : public ReduceMethodCost {
+class ReduceMeanCost : public ReduceSumCost {
  public:
-  explicit ReduceMeanCost(bool is_inputs_related) : ReduceMethodCost(is_inputs_related) {}
-  ReduceMeanCost() : ReduceMethodCost(true) {}
+  ReduceMeanCost() : ReduceSumCost() {}
   ~ReduceMeanCost() override = default;
 
   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
 };
-using ReduceMeanCostPtr = std::shared_ptr<ReduceMeanCost>;
+
+class ReduceMinCost : public ReduceSumCost {
  public:
+  ReduceMinCost() : ReduceSumCost() {}
+  ~ReduceMinCost() override = default;
+  // Taking account of output
+  void CalculateOutputInMemory() override;
+  // Taking account of input
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
+};
+using ReduceMaxCost = ReduceMinCost;
+
+class ArgMaxWithValueCost : public ReduceSumCost {
  public:
+  ArgMaxWithValueCost() : ReduceSumCost() {}
+  ~ArgMaxWithValueCost() override = default;
+  // Taking account of output
+  void CalculateOutputInMemory() override;
+  // Taking account of input
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
+};
+using ArgMinWithValueCost = ArgMaxWithValueCost;
 
 class GetNextCost : public OperatorCost {
  public:
-  explicit GetNextCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  GetNextCost() : OperatorCost(false) {}
+  GetNextCost() : OperatorCost() {}
   ~GetNextCost() override = default;
   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@@ -547,13 +800,17 @@ class GetNextCost : public OperatorCost {
                      int64_t) const override {
     return 0.0;
   }
+  // Not taking account of output
+  void CalculateOutputInMemory() override;
+  // Not taking account of input
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
 };
 using GetNextCostPtr = std::shared_ptr<GetNextCost>;
 
-class DropOutCost : public OperatorCost {
+// For memory cost, taking account of output, not taking account of input
+class DropOutCost : public SqrtCost {
  public:
-  explicit DropOutCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  DropOutCost() : OperatorCost(true) {}
+  DropOutCost() : SqrtCost() {}
   ~DropOutCost() override = default;
 
   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@@ -578,12 +835,19 @@ class DropOutCost : public OperatorCost {
   }
 };
-using DropOutCostPtr = std::shared_ptr<DropOutCost>;
+class DropOutDoMaskCost : public DropOutCost {
  public:
+  DropOutDoMaskCost() : DropOutCost() {}
+  ~DropOutDoMaskCost() override = default;
+  // Not taking account of output
+  void CalculateOutputInMemory() override;
+  // Taking account of input
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
+};
 
 class UnsortedSegmentSumCost : public OperatorCost {
  public:
-  explicit UnsortedSegmentSumCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  UnsortedSegmentSumCost() : OperatorCost(true) {}
+  UnsortedSegmentSumCost() : OperatorCost() {}
   ~UnsortedSegmentSumCost() override = default;
   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@@ -602,14 +866,15 @@ class UnsortedSegmentSumCost : public OperatorCost {
                      int64_t) const override {
     return 0.0;
   }
+  // Not taking account of output
+  void CalculateOutputInMemory() override;
+  // Taking account of input
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
 };
-using UnsortedSegmentSumCostPtr = std::shared_ptr<UnsortedSegmentSumCost>;
-
 class UnsortedSegmentMinCost : public OperatorCost {
  public:
-  explicit UnsortedSegmentMinCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  UnsortedSegmentMinCost() : OperatorCost(true) {}
+  UnsortedSegmentMinCost() : OperatorCost() {}
   ~UnsortedSegmentMinCost() override = default;
   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@@ -628,14 +893,16 @@ class UnsortedSegmentMinCost : public OperatorCost {
                      int64_t) const override {
     return 0.0;
   }
+  // Taking account of output
+  void CalculateOutputInMemory() override;
+  // Taking account of input
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
 };
-
-using UnsortedSegmentMinCostPtr = std::shared_ptr<UnsortedSegmentMinCost>;
+using UnsortedSegmentMaxCost = UnsortedSegmentMinCost;
 
 class LayerNormCost : public OperatorCost {
  public:
-  explicit LayerNormCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  LayerNormCost() : OperatorCost(true) {}
+  LayerNormCost() : OperatorCost() {}
   ~LayerNormCost() override = default;
   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@@ -656,14 +923,15 @@ class LayerNormCost : public OperatorCost {
                      int64_t) const override {
     return 0.0;
   }
+  // Taking account of output
+  void CalculateOutputInMemory() override;
+  // Taking account of input
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
 };
-using DropOutCostPtr = std::shared_ptr<DropOutCost>;
-
 class UniqueCost : public OperatorCost {
  public:
-  explicit UniqueCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  UniqueCost() : OperatorCost(true) {}
+  UniqueCost() : OperatorCost() {}
   ~UniqueCost() override = default;
   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@@ -682,14 +950,15 @@ class UniqueCost : public OperatorCost {
                      int64_t stage_id) const override;
   double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                     int64_t) const override;
+  // Taking account of output
+  void CalculateOutputInMemory() override;
+  // Not taking account of input
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
 };
-using UniqueCostPtr = std::shared_ptr<UniqueCost>;
-
 class UniformCandidateSamplerCost : public OperatorCost {
  public:
-  explicit UniformCandidateSamplerCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  UniformCandidateSamplerCost() : OperatorCost(false) {}
+  UniformCandidateSamplerCost() : OperatorCost() {}
   ~UniformCandidateSamplerCost() override = default;
   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@@ -714,14 +983,15 @@ class UniformCandidateSamplerCost : public OperatorCost {
                      int64_t) const override {
     return 0.0;
   }
+  // Not taking account of output
+  void CalculateOutputInMemory() override;
+  // Not taking account of input
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
 };
-using UniformCandidateSamplerCostPtr = std::shared_ptr<UniformCandidateSamplerCost>;
-
 class GatherV2Cost : public OperatorCost {
  public:
-  explicit GatherV2Cost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  GatherV2Cost() : OperatorCost(true) {}
+  GatherV2Cost() : OperatorCost() {}
   ~GatherV2Cost() override = default;
   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@@ -740,14 +1010,15 @@ class GatherV2Cost : public OperatorCost {
                      int64_t stage_id) const override;
   double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                     int64_t) const override;
+  // Not taking account of output
+  void CalculateOutputInMemory() override;
+  // Taking account of input
+  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
 };
-using GatherV2CostPtr = std::shared_ptr<GatherV2Cost>;
-
-class GatherV2PCost : public OperatorCost {
+class GatherV2PCost : public GatherV2Cost {
  public:
-  explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related), axis_(0) {}
-  GatherV2PCost() : OperatorCost(true), axis_(0) {}
+  GatherV2PCost() : GatherV2Cost(), axis_(0) {}
   ~GatherV2PCost() override = default;
   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@@ -773,8 +1044,6 @@ class GatherV2PCost : public OperatorCost {
   int64_t axis_;
   Shape strategy_;
 };
-
-using GatherV2PCostPtr = std::shared_ptr<GatherV2PCost>;
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h
index 14229928ed7..9e9f536d788 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h
@@ -50,8 +50,8 @@ class ActivationBase : public OperatorInfo {
 class Activation : public ActivationBase {
  public:
   Activation(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
-             const PrimitiveAttrs &attrs)
-      : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(false)) {}
+             const PrimitiveAttrs &attrs, OperatorCostPtr cost)
+      : ActivationBase(name, inputs_shape, outputs_shape, attrs, cost) {}
   ~Activation() override = default;
   Status GenerateStrategies(int64_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
@@ -64,7 +64,7 @@ class ActivationInfo : public Activation {
  public:
   ActivationInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                  const PrimitiveAttrs &attrs)
-      : Activation(name, inputs_shape, outputs_shape, attrs) {}
+      : Activation(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {}
   ~ActivationInfo() override = default;
 
  protected:
@@ -74,8 +74,8 @@ class ActivationInfo : public Activation {
 class ActivationOther : public Activation {
  public:
   ActivationOther(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
-                  const PrimitiveAttrs
&attrs) - : Activation(name, inputs_shape, outputs_shape, attrs) {} + const PrimitiveAttrs &attrs, OperatorCostPtr cost) + : Activation(name, inputs_shape, outputs_shape, attrs, cost) {} ~ActivationOther() override = default; protected: @@ -86,7 +86,7 @@ class GeluInfo : public ActivationOther { public: GeluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~GeluInfo() override = default; }; @@ -94,7 +94,7 @@ class TanhInfo : public ActivationOther { public: TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~TanhInfo() override = default; }; @@ -102,7 +102,7 @@ class Softmax : public ActivationBase { public: explicit Softmax(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~Softmax() override = default; Status GenerateStrategies(int64_t stage_id) override; Status SetCostUnderStrategy(const StrategyPtr &strategy) override; @@ -134,7 +134,7 @@ class LogSoftmaxInfo : public Softmax { class EluInfo : public ActivationOther { public: EluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~EluInfo() override = default; }; @@ -142,7 +142,7 @@ class ReLUInfo : public ActivationOther { public: ReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~ReLUInfo() override = default; }; @@ -150,7 +150,7 @@ class RepeatElementsInfo : public ActivationOther { public: RepeatElementsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~RepeatElementsInfo() override = default; }; @@ -158,7 +158,7 @@ class ReLU6Info : public ActivationOther { public: ReLU6Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~ReLU6Info() override = default; }; @@ -166,7 +166,7 @@ class SoftsignInfo : public ActivationOther { public: SoftsignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~SoftsignInfo() override = default; }; @@ -174,7 +174,7 @@ class SoftplusInfo : public ActivationOther { public: SoftplusInfo(const 
std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~SoftplusInfo() override = default; }; @@ -182,7 +182,7 @@ class CastInfo : public ActivationOther { public: CastInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~CastInfo() override = default; protected: @@ -193,14 +193,14 @@ class SqrtInfo : public ActivationOther { public: SqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~SqrtInfo() override = default; }; class NegInfo : public ActivationOther { public: NegInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~NegInfo() override = default; }; @@ -208,7 +208,7 @@ class ExpandDimsInfo : public ActivationOther { public: ExpandDimsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~ExpandDimsInfo() override = default; protected: @@ -228,7 +228,7 @@ class SqueezeInfo : public ActivationOther { public: SqueezeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~SqueezeInfo() override = default; protected: @@ -247,7 +247,7 @@ class SquareInfo : public ActivationOther { public: SquareInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~SquareInfo() override = default; }; @@ -255,7 +255,7 @@ class SigmoidInfo : public ActivationOther { public: SigmoidInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~SigmoidInfo() override = default; }; @@ -263,7 +263,7 @@ class DropoutInfo : public ActivationOther { public: DropoutInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~DropoutInfo() override = default; Status GenerateStrategies(int64_t stage_id) override; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h index 
9d9d08af6b2..38ce29f5431 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h @@ -56,7 +56,7 @@ class ArithmeticBase : public OperatorInfo { class SubInfo : public ArithmeticBase { public: SubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~SubInfo() override = default; }; @@ -64,28 +64,28 @@ class TensorAddInfo : public ArithmeticBase { public: TensorAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~TensorAddInfo() override = default; }; class MulInfo : public ArithmeticBase { public: MulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~MulInfo() override = default; }; class DivInfo : public ArithmeticBase { public: DivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~DivInfo() override = default; }; class ModInfo : public ArithmeticBase { public: ModInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~ModInfo() override = default; }; @@ -93,7 +93,7 @@ class RealDivInfo : public ArithmeticBase { public: RealDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~RealDivInfo() override = default; }; @@ -101,7 +101,7 @@ class FloorDivInfo : public ArithmeticBase { public: FloorDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~FloorDivInfo() override = default; }; @@ -109,14 +109,14 @@ class FloorModInfo : public ArithmeticBase { public: FloorModInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~FloorModInfo() override = default; }; class PowInfo : public ArithmeticBase { public: PowInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, 
inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~PowInfo() override = default; }; @@ -124,7 +124,7 @@ class AssignSubInfo : public ArithmeticBase { public: AssignSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~AssignSubInfo() override = default; }; @@ -132,7 +132,7 @@ class AssignInfo : public ArithmeticBase { public: AssignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~AssignInfo() override = default; }; @@ -140,7 +140,7 @@ class AssignAddInfo : public ArithmeticBase { public: AssignAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~AssignAddInfo() override = default; }; @@ -149,7 +149,8 @@ class SigmoidCrossEntropyWithLogitsInfo : public ArithmeticBase { public: SigmoidCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, + std::make_shared()) {} ~SigmoidCrossEntropyWithLogitsInfo() override = default; }; @@ -157,7 +158,7 @@ class Atan2Info : public ArithmeticBase { public: Atan2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~Atan2Info() override = default; }; @@ -165,7 +166,7 @@ class DivNoNanInfo : public ArithmeticBase { public: DivNoNanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~DivNoNanInfo() override = default; }; @@ -173,7 +174,7 @@ class LogicalAndInfo : public ArithmeticBase { public: LogicalAndInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~LogicalAndInfo() override = default; }; @@ -181,7 +182,7 @@ class LogicalOrInfo : public ArithmeticBase { public: LogicalOrInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~LogicalOrInfo() override = default; }; } // namespace parallel diff --git 
a/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h index 1c73298c412..a96ba18c991 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h @@ -34,8 +34,7 @@ class BatchParallelInfo : public OperatorInfo { : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {} BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)), - dev_num_(1) {} + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()), dev_num_(1) {} ~BatchParallelInfo() override = default; Status Init(const StrategyPtr &strategy) override; @@ -62,7 +61,8 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo { public: SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, + std::make_shared()) {} ~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default; void ReComputeBatchSplitFlagList() override; }; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.h index 7e2b68a1a0d..bdbf88a29bd 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.h @@ -34,7 +34,7 @@ class BiasAddInfo : public OperatorInfo { public: BiasAddInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~BiasAddInfo() override = default; Status Init(const StrategyPtr &strategy) override; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/broadcast_to_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/broadcast_to_info.h index 212f3844df4..2c163c563e9 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/broadcast_to_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/broadcast_to_info.h @@ -36,7 +36,7 @@ class BroadcastToInfo : public OperatorInfo { public: BroadcastToInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~BroadcastToInfo() override = default; Status Init(const StrategyPtr &strategy) override; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/comparison_function_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/comparison_function_info.h index 74e231966ab..e457e5aa5f1 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/comparison_function_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/comparison_function_info.h @@ -32,7 +32,7 @@ class EqualInfo : public ArithmeticBase { public: EqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, 
inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~EqualInfo() override = default; }; @@ -40,7 +40,7 @@ class ApproximateEqualInfo : public ArithmeticBase { public: ApproximateEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~ApproximateEqualInfo() override = default; }; @@ -48,7 +48,7 @@ class NotEqualInfo : public ArithmeticBase { public: NotEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~NotEqualInfo() override = default; }; @@ -56,7 +56,7 @@ class MaximumInfo : public ArithmeticBase { public: MaximumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~MaximumInfo() override = default; }; @@ -64,7 +64,7 @@ class MinimumInfo : public ArithmeticBase { public: MinimumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~MinimumInfo() override = default; }; @@ -72,7 +72,7 @@ class GreaterInfo : public ArithmeticBase { public: GreaterInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~GreaterInfo() override = default; }; @@ -80,7 +80,7 @@ class GreaterEqualInfo : public ArithmeticBase { public: GreaterEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~GreaterEqualInfo() override = default; }; @@ -88,7 +88,7 @@ class LessInfo : public ArithmeticBase { public: LessInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~LessInfo() override = default; }; @@ -96,7 +96,7 @@ class LessEqualInfo : public ArithmeticBase { public: LessEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~LessEqualInfo() override = default; }; } // namespace parallel diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/concat_info.h 
b/mindspore/ccsrc/frontend/parallel/ops_info/concat_info.h index 151ef7eb456..558cf139d22 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/concat_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/concat_info.h @@ -33,7 +33,7 @@ class ConcatInfo : public OperatorInfo { public: ConcatInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~ConcatInfo() override = default; Status Init(const StrategyPtr &strategy) override; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.h index 4f514c2354e..f0c4573afba 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.h @@ -33,7 +33,7 @@ class DropoutDoMaskInfo : public OperatorInfo { public: DropoutDoMaskInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~DropoutDoMaskInfo() override = default; Status Init(const StrategyPtr &strategy) override; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/elementary_function_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/elementary_function_info.h index 194383a5612..79f1fe4beeb 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/elementary_function_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/elementary_function_info.h @@ -20,6 +20,7 @@ #include #include #include +#include #include "ir/value.h" #include "frontend/parallel/auto_parallel/operator_costmodel.h" #include "frontend/parallel/ops_info/activation_info.h" @@ -30,21 +31,21 @@ namespace parallel { class ExpInfo : public ActivationOther { public: ExpInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~ExpInfo() override = default; }; class LogInfo : public ActivationOther { public: LogInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~LogInfo() override = default; }; class CosInfo : public ActivationOther { public: CosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~CosInfo() override = default; }; @@ -52,7 +53,7 @@ class ACosInfo : public ActivationOther { public: ACosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~ACosInfo() override = default; }; @@ -60,14 +61,14 @@ class LogicalNotInfo : public 
ActivationOther { public: LogicalNotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~LogicalNotInfo() override = default; }; class AbsInfo : public ActivationOther { public: AbsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~AbsInfo() override = default; }; @@ -75,7 +76,7 @@ class SignInfo : public ActivationOther { public: SignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~SignInfo() override = default; }; @@ -83,7 +84,7 @@ class FloorInfo : public ActivationOther { public: FloorInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~FloorInfo() override = default; }; @@ -91,7 +92,7 @@ class RoundInfo : public ActivationOther { public: RoundInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~RoundInfo() override = default; }; @@ -99,14 +100,14 @@ class ReciprocalInfo : public ActivationOther { public: ReciprocalInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~ReciprocalInfo() override = default; }; class InvInfo : public ActivationOther { public: InvInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~InvInfo() override = default; }; @@ -114,21 +115,21 @@ class RsqrtInfo : public ActivationOther { public: RsqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~RsqrtInfo() override = default; }; class TanInfo : public ActivationOther { public: TanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~TanInfo() override = default; }; class SinInfo : public ActivationOther { public: SinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : 
ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~SinInfo() override = default; }; @@ -136,7 +137,7 @@ class SinhInfo : public ActivationOther { public: SinhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~SinhInfo() override = default; }; @@ -144,7 +145,7 @@ class Log1pInfo : public ActivationOther { public: Log1pInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~Log1pInfo() override = default; }; @@ -152,7 +153,7 @@ class Expm1Info : public ActivationOther { public: Expm1Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~Expm1Info() override = default; }; @@ -160,7 +161,7 @@ class CoshInfo : public ActivationOther { public: CoshInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~CoshInfo() override = default; }; @@ -168,7 +169,7 @@ class CeilInfo : public ActivationOther { public: CeilInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~CeilInfo() override = default; }; @@ -176,7 +177,7 @@ class AtanhInfo : public ActivationOther { public: AtanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~AtanhInfo() override = default; }; @@ -184,7 +185,7 @@ class AtanInfo : public ActivationOther { public: AtanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~AtanInfo() override = default; }; @@ -192,7 +193,7 @@ class AsinInfo : public ActivationOther { public: AsinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~AsinInfo() override = default; }; @@ -200,7 +201,7 @@ class AsinhInfo : public ActivationOther { public: AsinhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~AsinhInfo() override = default; }; @@ -208,14 +209,14 @@ class AcoshInfo : public 
ActivationOther { public: AcoshInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~AcoshInfo() override = default; }; class ErfInfo : public ActivationOther { public: ErfInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~ErfInfo() override = default; }; @@ -223,7 +224,7 @@ class ErfcInfo : public ActivationOther { public: ErfcInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~ErfcInfo() override = default; }; @@ -231,7 +232,7 @@ class ZerosLikeInfo : public ActivationOther { public: ZerosLikeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~ZerosLikeInfo() override = default; }; @@ -239,7 +240,7 @@ class OnesLikeInfo : public ActivationOther { public: OnesLikeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~OnesLikeInfo() override = default; }; @@ -247,7 +248,7 @@ class BesselI0eInfo : public ActivationOther { public: BesselI0eInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~BesselI0eInfo() override = default; }; @@ -255,7 +256,7 @@ class BesselI1eInfo : public ActivationOther { public: BesselI1eInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~BesselI1eInfo() override = default; }; } // namespace parallel diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h index 111fd9ff43f..b8c82bd0039 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h @@ -32,7 +32,7 @@ class GetNextInfo : public OperatorInfo { public: GetNextInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~GetNextInfo() override = default; Status Init(const StrategyPtr &strategy) override; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.h 
b/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.h index 06c61f8b157..084e0f29ce0 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.h @@ -33,7 +33,7 @@ class L2NormalizeInfo : public Activation { public: L2NormalizeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : Activation(name, inputs_shape, outputs_shape, attrs) {} + : Activation(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~L2NormalizeInfo() override = default; Status GenerateStrategies(int64_t stage_id) override; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.h index 5637d9829d9..25b1f6596d1 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.h @@ -40,7 +40,7 @@ class LayerNormInfo : public OperatorInfo { public: LayerNormInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(true)), + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()), begin_norm_axis_(0) {} ~LayerNormInfo() override = default; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h index e30864e6633..3f045929ac1 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h @@ -36,8 +36,7 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo { public: SoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, - std::make_shared(false)) {} + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~SoftmaxCrossEntropyWithLogitsInfo() override = default; Status Init(const StrategyPtr &strategy) override; Status InitForCostModel(const StrategyPtr &strategy) override; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h index 26afaaae69c..6ea93c7e7cd 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h @@ -34,7 +34,7 @@ class MatMulBase : public OperatorInfo { public: MatMulBase(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~MatMulBase() override = default; Status Init(const StrategyPtr &strategy) override; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h index 1b3624fb1ce..a1a44703f0f 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h @@ -33,7 +33,7 @@ class OneHotInfo : public OperatorInfo { public: OneHotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, 
outputs_shape, attrs, std::make_shared<OneHotCost>(false)) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>()) {}
   ~OneHotInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
   Status InitForCostModel(const StrategyPtr &strategy) override;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
index 89003c3aac9..f910be859de 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
@@ -1204,6 +1204,20 @@ int64_t OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() {
   } else {
     is_output_parameter_involve_ = 0;
   }
+  // Set 'is_parameter_involve_' and 'is_output_parameter_involve_' into operatorCost, which are used in
+  // calculating 'inputs_in_memory' and 'output_in_memory', respectively.
+  operator_cost()->set_is_parameter_involve(is_parameter_involve_);
+  operator_cost()->set_output_parameter_involve(is_output_parameter_involve_);
+  // Calculating 'output_in_memory'
+  operator_cost()->CalculateOutputInMemory();
+  // Calculating 'inputs_in_memory'
+  std::map<size_t, bool> input_in_memory;
+  for (auto &p_edge : prev_edges) {
+    auto input_index = p_edge->next_op_input_index();
+    auto is_in_mem = p_edge->prev_operator()->operator_cost()->is_output_in_memory();
+    input_in_memory.emplace(std::make_pair(input_index, is_in_mem));
+  }
+  operator_cost()->CalculateInputsInMemory(input_in_memory);
   return is_output_parameter_involve_;
 }
@@ -1220,14 +1234,10 @@ Status OperatorInfo::set_is_parameter(const std::vector<bool> &is_parameter) {
 }
 
 Status OperatorInfo::CalculateMemoryCost() {
-  // First, set the 'is_parameter_involve_' and 'is_output_parameter_involve_' into OperatorCost, which are necessary to
-  // calculate memory cost.
if (is_parameter_involve_.size() != is_parameter_.size()) { MS_LOG(ERROR) << "'is_parameter_' does not have the same number of input size of 'is_parameter_involve_'."; return FAILED; } - operator_cost()->set_is_parameter_involve(is_parameter_involve_); - operator_cost()->set_output_parameter_involve(is_output_parameter_involve_); // Set the memory cost in the 'strategy_cost_' for (auto &swc : strategy_cost_) { auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/pack_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/pack_info.h index 87de6811662..cc704ac6440 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/pack_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/pack_info.h @@ -33,7 +33,7 @@ class PackInfo : public OperatorInfo { public: PackInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~PackInfo() override = default; Status Init(const StrategyPtr &strategy) override; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h index 2d5b44d0018..401afb3c963 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h @@ -35,7 +35,7 @@ class PReLUInfo : public OperatorInfo { public: PReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~PReLUInfo() override = default; Status Init(const StrategyPtr &strategy) override; Status InitForCostModel(const StrategyPtr &strategy) override; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/range_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/range_info.h index e3aa63dc173..af77eb5b41e 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/range_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/range_info.h @@ -39,7 +39,7 @@ class RangeInfo : public OperatorInfo { public: RangeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~RangeInfo() override = default; Status Init(const StrategyPtr &strategy) override; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h index 9f2f3eeae1d..daa2cf222d8 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h @@ -33,8 +33,8 @@ namespace parallel { class ReduceMethod : public OperatorInfo { public: ReduceMethod(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + const PrimitiveAttrs &attrs, OperatorCostPtr cost) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost) {} ~ReduceMethod() override = default; Status 
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/pack_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/pack_info.h
index 87de6811662..cc704ac6440 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/pack_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/pack_info.h
@@ -33,7 +33,7 @@ class PackInfo : public OperatorInfo {
  public:
   PackInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {}
   ~PackInfo() override = default;
 
   Status Init(const StrategyPtr &strategy) override;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h
index 2d5b44d0018..401afb3c963 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h
@@ -35,7 +35,7 @@ class PReLUInfo : public OperatorInfo {
  public:
   PReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {}
   ~PReLUInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
   Status InitForCostModel(const StrategyPtr &strategy) override;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/range_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/range_info.h
index e3aa63dc173..af77eb5b41e 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/range_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/range_info.h
@@ -39,7 +39,7 @@ class RangeInfo : public OperatorInfo {
  public:
   RangeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {}
   ~RangeInfo() override = default;
 
   Status Init(const StrategyPtr &strategy) override;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h
index 9f2f3eeae1d..daa2cf222d8 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h
@@ -33,8 +33,8 @@ namespace parallel {
 class ReduceMethod : public OperatorInfo {
  public:
   ReduceMethod(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
-               const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {}
+               const PrimitiveAttrs &attrs, OperatorCostPtr cost)
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost) {}
   ~ReduceMethod() override = default;
 
   Status Init(const StrategyPtr &strategy) override;
@@ -62,7 +62,7 @@ class ReduceMaxInfo : public ReduceMethod {
  public:
   ReduceMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
+      : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {
     reduce_method_ = REDUCE_OP_MAX;
   }
@@ -73,7 +73,7 @@ class ArgMaxWithValueInfo : public ReduceMethod {
  public:
   ArgMaxWithValueInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                       const PrimitiveAttrs &attrs)
-      : ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
+      : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {
     reduce_method_ = REDUCE_OP_MAX;
   }
@@ -105,9 +105,7 @@ class ReduceMeanInfo : public ReduceMethod {
  public:
   ReduceMeanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                  const PrimitiveAttrs &attrs)
-      : ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
-    set_cost(std::make_shared());
-  }
+      : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {}
 
   ~ReduceMeanInfo() override = default;
@@ -119,7 +117,7 @@ class ReduceSumInfo : public ReduceMethod {
  public:
   ReduceSumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
+      : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {
     reduce_method_ = REDUCE_OP_SUM;
   }
@@ -130,7 +128,7 @@ class ReduceMinInfo : public ReduceMethod {
  public:
   ReduceMinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
+      : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {
     reduce_method_ = REDUCE_OP_MIN;
   }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reluv2_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/reluv2_info.h
index e30ecf267f5..c1482d83bb8 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/reluv2_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/reluv2_info.h
@@ -37,7 +37,7 @@ class ReLUV2Info : public OperatorInfo {
  public:
   ReLUV2Info(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
              const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {}
   ~ReLUV2Info() override = default;
 
   Status Init(const StrategyPtr &strategy) override;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h
index 21dfdce485c..e8dd69bc314 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h
@@ -36,7 +36,7 @@ class ReshapeInfo : public OperatorInfo {
  public:
   ReshapeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)),
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()),
         dev_num_(0),
         pre_operator_index_(0),
         next_operator_index_(0),
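The reduce-op hunks above all follow one refactor: ReduceMethod stops hard-coding a single cost type and instead receives an OperatorCostPtr, so each subclass supplies its own cost object in the initializer list (ReduceMeanInfo no longer needs a separate set_cost() call). A compilable sketch of that constructor-injection pattern, with illustrative class names rather than the real MindSpore ones:

#include <iostream>
#include <memory>
#include <string>
#include <utility>

struct Cost {
  virtual ~Cost() = default;
  virtual std::string name() const = 0;
};
struct MethodCost : Cost {
  std::string name() const override { return "MethodCost"; }
};
struct MeanCost : Cost {
  std::string name() const override { return "MeanCost"; }
};

// Base receives the cost object instead of constructing one fixed type itself.
class ReduceBase {
 public:
  explicit ReduceBase(std::shared_ptr<Cost> cost) : cost_(std::move(cost)) {}
  const std::shared_ptr<Cost> &cost() const { return cost_; }

 private:
  std::shared_ptr<Cost> cost_;
};

// Each subclass picks its own cost class in the initializer list.
class ReduceMax : public ReduceBase {
 public:
  ReduceMax() : ReduceBase(std::make_shared<MethodCost>()) {}
};
class ReduceMean : public ReduceBase {
 public:
  // No post-construction set_cost() call is needed any more.
  ReduceMean() : ReduceBase(std::make_shared<MeanCost>()) {}
};

int main() {
  std::cout << ReduceMax().cost()->name() << "\n";   // MethodCost
  std::cout << ReduceMean().cost()->name() << "\n";  // MeanCost
  return 0;
}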
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/slice_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/slice_info.h
index cdc948d345c..0edf6552294 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/slice_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/slice_info.h
@@ -34,7 +34,7 @@ class SliceInfo : public OperatorInfo {
  public:
   SliceInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)),
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()),
         slice_axis_(-1) {}
   ~SliceInfo() override = default;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/split_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/split_info.h
index d1cb9f13519..61419f2e305 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/split_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/split_info.h
@@ -31,7 +31,7 @@ class SplitInfo : public OperatorInfo {
  public:
   SplitInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {}
   ~SplitInfo() override = default;
 
   Status Init(const StrategyPtr &strategy) override;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc
index a9c8b4ec2e6..3ef53631811 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc
@@ -273,6 +273,8 @@ Status StridedSliceInfo::GenerateStrategies(int64_t stage_id) {
       PrintStrategy(sp);
     }
   }
+
+  MS_LOG(INFO) << name() << ", finishing GenerateStrategies().";
   return SUCCESS;
 }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.h
index 0d55557d40f..d68651cc10e 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.h
@@ -34,7 +34,7 @@ class StridedSliceInfo : public OperatorInfo {
  public:
   StridedSliceInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                    const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {}
   ~StridedSliceInfo() override = default;
 
   Status Init(const StrategyPtr &strategy) override;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/tensordot_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/tensordot_info.h
index 4ca1374a0c0..6709d335885 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/tensordot_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/tensordot_info.h
@@ -41,7 +41,7 @@ class TensorDotInfo : public OperatorInfo {
  public:
   TensorDotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {}
   ~TensorDotInfo() override = default;
 
   Status Init(const StrategyPtr &strategy) override;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/tile_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/tile_info.h
index 952b4918907..e3a6d824e67 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/tile_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/tile_info.h
@@ -34,7 +34,7 @@ class TileInfo : public OperatorInfo {
  public:
   TileInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {}
   ~TileInfo() override = default;
 
   Status Init(const StrategyPtr &strategy) override;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.h
index 690823d30f7..24f54eebf7c 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.h
@@ -34,7 +34,7 @@ class TmpIdentityInfo : public OperatorInfo {
  public:
   TmpIdentityInfo(const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs,
                   const std::string &name = IDENTITY_INFO)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {}
   ~TmpIdentityInfo() override = default;
 
   Status Init(const StrategyPtr &strategy) override;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.h
index babd5cb9bbc..f528ac79b57 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.h
@@ -35,7 +35,7 @@ class TransposeInfo : public OperatorInfo {
  public:
   TransposeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {}
   ~TransposeInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
   Status InitForCostModel(const StrategyPtr &strategy) override;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.h
index 0ecf6f6f43c..94323bc523a 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.h
@@ -32,7 +32,7 @@ class UniqueInfo : public OperatorInfo {
  public:
   UniqueInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
              const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {}
   ~UniqueInfo() override = default;
 
   Status Init(const StrategyPtr &strategy) override;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.h
index 33866ab60f9..539a2494af1 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.h
@@ -82,7 +82,7 @@ class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo {
  public:
   UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                          const PrimitiveAttrs &attrs)
-      : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {}
+      : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {}
   ~UnsortedSegmentMaxInfo() override = default;
 
   ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.h
index 7088a5267ee..a5674851c07 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.h
@@ -32,7 +32,7 @@ class VirtualDatasetInfo : public OperatorInfo {
  public:
   VirtualDatasetInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                      const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {}
   ~VirtualDatasetInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
   Status InitForCostModel(const StrategyPtr &strategy) override;
diff --git a/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc b/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc
index 13c278bebd9..7db3c4ae673 100644
--- a/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc
+++ b/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc
@@ -85,11 +85,11 @@ class TestActivationCost : public UT::Common {
   TestActivationCost() {}
   void SetUp();
   void TearDown();
-  ActivationCost ac_cost_;
+  ActivationInfoCost ac_cost_;
 };
 
 void TestActivationCost::SetUp() {
-  ac_cost_ = ActivationCost();
+  ac_cost_ = ActivationInfoCost();
   RankList dev_list;
   for (int32_t i = 0; i < 1050; i++) {
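The unit test now instantiates ActivationInfoCost where it previously used ActivationCost, matching the rename of ActivationCost's methods to CastCost elsewhere in this patch. One plausible way such a rename keeps old-style call sites compiling is a type alias onto the renamed class; the alias target below is an assumption for illustration, not something this hunk confirms:

#include <iostream>

// Stand-in for the renamed cost class; the real one models element-wise ops.
class CastCost {
 public:
  double GetMemoryCost() const { return 0.0; }  // placeholder body, not the real signature
};

// Hypothetical alias: the old-style name resolves to the new class.
using ActivationInfoCost = CastCost;

int main() {
  ActivationInfoCost ac_cost;  // same type as CastCost through the alias
  std::cout << ac_cost.GetMemoryCost() << "\n";
  return 0;
}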