!10683 [Auto parallel] Change memory cost calculation in auto-parallel

From: @xiaoda_zh
Reviewed-by: @kisnwang,@stsuteng
Signed-off-by: @stsuteng
This commit is contained in:
mindspore-ci-bot 2020-12-29 10:40:02 +08:00 committed by Gitee
commit fae4246811
36 changed files with 1255 additions and 312 deletions


@ -27,6 +27,7 @@ void OperatorCost::set_is_parameter(const std::vector<bool> &is_parameter) { is_
void OperatorCost::set_is_parameter_involve(const std::vector<bool> &is_parameter_inv) {
is_parameter_involve_ = is_parameter_inv;
is_inputs_should_in_memory_ = std::vector<bool>(is_parameter_involve_.size(), false);
}
void OperatorCost::set_output_parameter_involve(int64_t output_para) { output_parameter_involve_ = output_para; }
@ -41,27 +42,28 @@ void OperatorCost::set_output_critical(int64_t critical) { is_outputs_critical_
double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs) const {
return GetInputMemoryCost(inputs, outputs) + GetOutputMemoryCost(inputs, outputs);
}
double OperatorCost::GetInputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &) const {
double result = 0.0;
for (size_t i = 0; i < inputs.size(); ++i) {
if (is_inputs_should_in_memory_[i]) {
result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
}
}
return result;
}
double OperatorCost::GetOutputMemoryCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs) const {
double result = 0.0;
if (is_output_should_in_memory_) {
// When this operator has multiple outputs, they all contribute to the memory.
for (size_t i = 0; i < outputs.size(); ++i) {
result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
}
bool is_any_para_inv =
std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; });
if (is_any_para_inv) {
for (size_t i = 0; i < inputs.size(); ++i) {
if (is_parameter_[i]) {
result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
} else if (inputs_related_ && (!is_parameter_involve_[i])) {
// When the inputs of this operator are related, and they are not parameter-involved, then they are included
// in the memory cost.
result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
}
}
}
}
return result;
}
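With this split, the per-device memory cost is the byte size of every input slice flagged in is_inputs_should_in_memory_, plus the byte sizes of all output slices when is_output_should_in_memory_ is set. The following self-contained sketch mirrors that formula with plain vectors standing in for TensorInfo; the helper name, shapes, element widths, and flags are invented purely for illustration.

// Illustrative only: plain-vector restatement of GetInputMemoryCost + GetOutputMemoryCost.
// All names, shapes, byte widths, and flags below are made-up example values.
#include <cstdint>
#include <iostream>
#include <vector>

static double ListProductOf(const std::vector<int64_t> &shape) {
  double product = 1.0;
  for (int64_t dim : shape) product *= static_cast<double>(dim);
  return product;
}

int main() {
  std::vector<std::vector<int64_t>> input_slices = {{128, 64}, {64, 32}};  // per-device slice shapes
  std::vector<size_t> input_bytes = {4, 4};                                // float32 inputs
  std::vector<bool> inputs_should_in_memory = {true, false};               // only input 0 is charged

  std::vector<std::vector<int64_t>> output_slices = {{128, 32}};
  std::vector<size_t> output_bytes = {4};
  bool output_should_in_memory = true;

  double cost = 0.0;
  for (size_t i = 0; i < input_slices.size(); ++i) {
    if (inputs_should_in_memory[i]) {
      cost += ListProductOf(input_slices[i]) * static_cast<double>(input_bytes[i]);
    }
  }
  if (output_should_in_memory) {
    for (size_t i = 0; i < output_slices.size(); ++i) {
      cost += ListProductOf(output_slices[i]) * static_cast<double>(output_bytes[i]);
    }
  }
  std::cout << cost << std::endl;  // 128*64*4 + 128*32*4 = 49152 bytes
  return 0;
}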
@ -166,16 +168,43 @@ double MatMulCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inp
return result;
}
// Not taking account of output
void MatMulCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
// Taking account of input
void MatMulCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
if (is_parameter_[0]) {
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
} else if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
if (is_parameter_[1]) {
is_inputs_should_in_memory_[1] = true;
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
} else if (is_parameter_involve_[1]) {
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
}
}
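The prev_output_in_mem map tells an operator, per input index, whether the producing operator already keeps that tensor in memory. For MatMul the rule above amounts to: an input is kept if it is itself a parameter, or if the other operand is parameter-involved and the producer does not already hold this input (for MatMul, the gradient of one operand is computed from the incoming gradient and the other operand). A stand-alone restatement of that rule, with a hypothetical helper name and made-up flag values:

// Hypothetical restatement of the MatMul rule above; MatMulInputsInMemory is an invented name.
#include <iostream>
#include <map>
#include <vector>

std::vector<bool> MatMulInputsInMemory(const std::vector<bool> &is_parameter,
                                       const std::vector<bool> &is_parameter_involve,
                                       const std::map<size_t, bool> &prev_output_in_mem) {
  auto not_held = [&prev_output_in_mem](size_t i) {
    auto it = prev_output_in_mem.find(i);
    return (it == prev_output_in_mem.end()) || (!it->second);
  };
  std::vector<bool> keep(2, false);
  if (is_parameter[0]) {
    keep[0] = true;
    if (not_held(1)) keep[1] = true;
  } else if (is_parameter_involve[0] && not_held(1)) {
    keep[1] = true;
  }
  if (is_parameter[1]) {
    keep[1] = true;
    if (not_held(0)) keep[0] = true;
  } else if (is_parameter_involve[1] && not_held(0)) {
    keep[0] = true;
  }
  return keep;
}

int main() {
  // Input 1 is a weight parameter; input 0 is an activation already kept by its producer.
  auto keep = MatMulInputsInMemory({false, true}, {false, true}, {{0, true}});
  std::cout << keep[0] << " " << keep[1] << std::endl;  // 0 1 -> only the weight is charged
  return 0;
}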
// Return the per device communication cost in the forward phase.
double ActivationCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
int64_t) const {
double CastCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
// ReLU is an element-wise operator, thus it does not need communication in the forward phase
return 0.0;
}
// Return the per device communication cost in the backward phase.
double ActivationCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
int64_t stage_id) const {
double CastCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
int64_t stage_id) const {
double result = 0.0;
if (is_parameter_[0]) {
TensorInfo input1 = inputs[0];
@ -196,8 +225,8 @@ double ActivationCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
int64_t) const {
double CastCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
int64_t) const {
TensorInfo input0 = inputs[0];
Shape input0_slice_shape = input0.slice_shape();
return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
@ -205,11 +234,33 @@ double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo> &
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double ActivationCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
int64_t) const {
double CastCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
int64_t) const {
return 0.0;
}
// Not taking account of output
void CastCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
// Not taking account of input
void CastCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
is_inputs_should_in_memory_[0] = is_parameter_[0];
}
// Taking account of output
void SqrtCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
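SqrtCost, which several activation-style costs alias in the header (Tanh, ReLU, Sigmoid, and others), keeps the forward output rather than the input when the input is parameter-involved. The likely reason is that for these element-wise operators the derivative can be written in terms of the output alone, e.g. d(sqrt(x))/dx = 1/(2*sqrt(x)) = 1/(2*out). A quick numeric check of that identity (illustrative only, not part of this change):

// Illustrative check that the Sqrt backward pass needs only the forward output:
// analytic d(sqrt(x))/dx = 1/(2*sqrt(x)), which equals 1/(2*out).
#include <cmath>
#include <cstdio>

int main() {
  double x = 9.0;
  double out = std::sqrt(x);                       // forward output
  double from_output = 1.0 / (2.0 * out);          // uses only the output
  double analytic = 1.0 / (2.0 * std::sqrt(x));    // textbook derivative
  std::printf("%f %f\n", from_output, analytic);   // both 0.166667
  return 0;
}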
// Taking account of input
void GeLUCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
if (is_parameter_[0]) {
is_inputs_should_in_memory_[0] = true;
} else if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
}
}
// Return the per device communication cost in the forward phase.
double SoftmaxCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
int64_t) const {
@ -259,6 +310,81 @@ double SoftmaxCost::GetBackwardComputationCost(const std::vector<mindspore::para
return 0.0;
}
// Taking account of output
void SoftmaxCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
// Not taking account of input
void SoftmaxCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
is_inputs_should_in_memory_[0] = is_parameter_[0];
}
// Not taking account of output
void PackCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
// Not taking account of input
void PackCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
is_inputs_should_in_memory_[0] = is_parameter_[0];
}
// Not taking account of output
void TileCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
// Taking account of input
void TileCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of 'y'
if (is_parameter_[0]) {
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
} else if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
if (!is_inputs_should_in_memory_[1]) {
is_inputs_should_in_memory_[1] = is_parameter_[1];
}
}
// Not taking account of output
void BroadcastToCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
void BroadcastToCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
is_inputs_should_in_memory_[0] = is_parameter_[0];
}
// Taking account of input
void ReLU6Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
if (is_parameter_[0]) {
is_inputs_should_in_memory_[0] = true;
} else if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
}
}
// Taking account of input
void TransposeCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of 'y'
if (is_parameter_[0]) {
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
} else if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
if (!is_inputs_should_in_memory_[1]) {
is_inputs_should_in_memory_[1] = is_parameter_[1];
}
}
// return the per device communication cost in the forward phase.
double TmpIdentityCost::GetForwardCommCost(const std::vector<mindspore::parallel::TensorInfo> &,
const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
@ -288,9 +414,12 @@ double TmpIdentityCost::GetBackwardComputationCost(const std::vector<mindspore::
return 0.0;
}
// Return the per device PEAK memory cost contributed by this operator in a training iteration.
double TmpIdentityCost::GetMemoryCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &) const {
return 0.0;
// Not taking account of output
void TmpIdentityCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
// Not taking account of input
void TmpIdentityCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
is_inputs_should_in_memory_[0] = is_parameter_[0];
}
double BatchParallelCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &inputs,
@ -334,6 +463,42 @@ double BatchParallelCost::GetBackwardCommCost(const std::vector<TensorInfo> &inp
return result;
}
void BatchParallelCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
void BatchParallelCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
if (is_parameter_[0]) {
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
} else if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
if (is_parameter_[1]) {
is_inputs_should_in_memory_[1] = true;
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
} else if (is_parameter_involve_[1]) {
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
}
}
void SparseSoftmaxCrossEntropyWithLogitsCost::CalculateOutputInMemory() {
is_output_should_in_memory_ = is_parameter_involve_[0];
}
void SparseSoftmaxCrossEntropyWithLogitsCost::CalculateInputsInMemory(
const std::map<size_t, bool> &prev_output_in_mem) {
is_inputs_should_in_memory_[0] = is_parameter_[0];
is_inputs_should_in_memory_[1] = is_parameter_[1];
}
// return the per device communication cost in the forward phase.
double PReLUCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
// prelu does not need communication in the forward phase
@ -401,6 +566,21 @@ double PReLUCost::GetBackwardComputationCost(const std::vector<mindspore::parall
return result;
}
void PReLUCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
void PReLUCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of both 'x' and 'y';
// when calculating 'dy', taking account of both 'x' and 'y'
if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
}
// return the per device communication cost in the forward phase.
double OneHotCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
// onehot does not need communication in the forward phase
@ -430,6 +610,17 @@ double OneHotCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, c
return 0.0;
}
// Not taking account of output
void OneHotCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
// Not taking account of input
void OneHotCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
is_inputs_should_in_memory_[0] = is_parameter_[0];
is_inputs_should_in_memory_[1] = is_parameter_[1];
is_inputs_should_in_memory_[2] = is_parameter_[2];
is_inputs_should_in_memory_[3] = is_parameter_[3];
}
// return the per device communication cost in the forward phase.
double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector<TensorInfo> &,
const std::vector<TensorInfo> &, int64_t) const {
@ -463,6 +654,16 @@ double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::
return 0.0;
}
// Taking account of output
void SoftmaxCrossEntropyWithLogitsCost::CalculateOutputInMemory() {
is_output_should_in_memory_ = is_parameter_involve_[0];
}
void SoftmaxCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
is_inputs_should_in_memory_[0] = is_parameter_[0];
is_inputs_should_in_memory_[1] = is_parameter_[1];
}
// return the per device communication cost in the forward phase.
double ReshapeCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const {
@ -524,50 +725,22 @@ double ReshapeCost::GetBackwardComputationCost(const std::vector<mindspore::para
return 0.0;
}
double ArithmeticCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
int64_t) const {
void ReshapeCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
void ReshapeCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
is_inputs_should_in_memory_[0] = is_parameter_[0];
is_inputs_should_in_memory_[1] = is_parameter_[1];
}
double SubCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
int64_t) const {
double result;
result = ListProduct(inputs[0].slice_shape()) * static_cast<double>(inputs_type_lengths_[0]) +
ListProduct(inputs[1].slice_shape()) * static_cast<double>(inputs_type_lengths_[1]);
return result;
}
double ArithmeticCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &, int64_t stage_id) const {
double result = 0.0;
CheckGlobalDeviceManager();
MS_EXCEPTION_IF_NULL(g_device_manager);
auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
if (is_parameter_[0]) {
TensorInfo input_a_tensor_info = inputs[0];
Shape input_a_shape = input_a_tensor_info.shape();
Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
int64_t used_device_num = 1;
for (size_t i = 0; i < input_a_shape.size(); ++i) {
used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
}
if (total_device_num != LongToSize(used_device_num))
result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
}
if (is_parameter_[1]) {
TensorInfo input_b_tensor_info = inputs[1];
Shape input_b_shape = input_b_tensor_info.shape();
Shape input_b_slice_shape = input_b_tensor_info.slice_shape();
int64_t used_device_num = 1;
for (size_t i = 0; i < input_b_shape.size(); ++i) {
used_device_num *= input_b_shape[i] / input_b_slice_shape[i];
}
if (total_device_num != LongToSize(used_device_num))
result += ListProduct(input_b_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
}
return result;
}
double ArithmeticCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
double SubCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
int64_t stage_id) const {
double result = 0.0;
CheckGlobalDeviceManager();
@ -587,6 +760,41 @@ double ArithmeticCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs
result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
}
if (is_parameter_[1]) {
TensorInfo input_b_tensor_info = inputs[1];
Shape input_b_shape = input_b_tensor_info.shape();
Shape input_b_slice_shape = input_b_tensor_info.slice_shape();
int64_t used_device_num = 1;
for (size_t i = 0; i < input_b_shape.size(); ++i) {
used_device_num *= input_b_shape[i] / input_b_slice_shape[i];
}
if (total_device_num != LongToSize(used_device_num))
result += ListProduct(input_b_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
}
return result;
}
double SubCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
int64_t stage_id) const {
double result = 0.0;
CheckGlobalDeviceManager();
MS_EXCEPTION_IF_NULL(g_device_manager);
auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
if (is_parameter_[0]) {
TensorInfo input_a_tensor_info = inputs[0];
Shape input_a_shape = input_a_tensor_info.shape();
Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
int64_t used_device_num = 1;
for (size_t i = 0; i < input_a_shape.size(); ++i) {
used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
}
if (total_device_num != LongToSize(used_device_num))
result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
}
if (is_parameter_[1]) {
TensorInfo input_b_tensor_info = inputs[1];
Shape input_b_shape = input_b_tensor_info.shape();
@ -603,6 +811,273 @@ double ArithmeticCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs
return result;
}
// Not taking account of output
void SubCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
// Not taking account of input
void SubCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
is_inputs_should_in_memory_[0] = is_parameter_[0];
is_inputs_should_in_memory_[1] = is_parameter_[1];
}
// Taking account of input
void MulCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
if (is_parameter_[0]) {
// 'x' is a parameter, so it should be in memory.
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
// In this case, if 'y' is not already kept in memory by the previous operator, then 'y' should be included here.
is_inputs_should_in_memory_[1] = true;
}
} else if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
if (is_parameter_[1]) {
is_inputs_should_in_memory_[1] = true;
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
} else if (is_parameter_involve_[1]) {
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
}
}
// Taking account of output
void DivCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }
// Taking account of input
void DivCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of 'y'
if (is_parameter_[0]) {
// 'x' is a parameter, so it should be in memory.
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
// In this case, if 'y' is not already kept in memory by the previous operator, then 'y' should be included here.
is_inputs_should_in_memory_[1] = true;
}
} else if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
// When calculating 'dy', taking account of 'y'
if (is_parameter_involve_[1]) {
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
}
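Div keeps 'y' for both gradients and keeps its output (rather than 'x') when 'y' is parameter-involved. This is consistent with d(x/y)/dx = 1/y needing only 'y', while d(x/y)/dy = -x/y^2 = -(x/y)/y can be formed from the forward output and 'y' without 'x'. A numeric check of the second identity (illustrative only, not part of this change):

// Illustrative check: the Div gradient w.r.t. y can be formed from the forward output and y.
#include <cstdio>

int main() {
  double x = 3.0, y = 2.0;
  double out = x / y;                             // forward output
  double from_output = -out / y;                  // uses only out and y
  double analytic = -x / (y * y);                 // textbook derivative
  std::printf("%f %f\n", from_output, analytic);  // both -0.750000
  return 0;
}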
// Taking account of input
void ModCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', not taking account of 'x' and 'y'
is_inputs_should_in_memory_[0] = is_parameter_[0];
// When calculating 'dy', taking account of 'x' and 'y'
if (is_parameter_involve_[1]) {
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
}
void PowCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }
void PowCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of both 'x' and 'power'
if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
// When calculating 'dpower', taking account of 'x'
if (is_parameter_[1]) {
is_inputs_should_in_memory_[1] = true;
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
} else if (is_parameter_involve_[1]) {
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
}
}
void AssignCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of 'x'
if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
}
// When calculating 'dy', not taking account of 'x' and 'y'
is_inputs_should_in_memory_[1] = is_parameter_[1];
}
void SigmoidCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of both 'x' and 'y'
if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
// When calculating 'dy', not taking account of 'x' and 'y'
if (!is_inputs_should_in_memory_[1]) {
is_inputs_should_in_memory_[1] = is_parameter_[1];
}
}
void Atan2Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of both 'x' and 'y'; when calculating 'dy', taking account of both 'x' and
// 'y'
if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
}
void DivNoNanCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }
void DivNoNanCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of 'y'
if (is_parameter_[0]) {
// 'x' is a parameter, so it should be in memory.
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
// In this case, if 'y' is not already kept in memory by the previous operator, then 'y' should be included here.
is_inputs_should_in_memory_[1] = true;
}
} else if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
// When calculating 'dy', taking account of 'y'
if (is_parameter_involve_[1]) {
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
}
void MaximumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of both 'x' and 'y';
// when calculating 'dy', taking account of both 'x' and 'y'
if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
}
void SliceCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of 'y' and 'z'
if (is_parameter_[0]) {
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
is_inputs_should_in_memory_[2] = true;
}
} else if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
is_inputs_should_in_memory_[2] = true;
}
}
if (!is_inputs_should_in_memory_[1]) {
is_inputs_should_in_memory_[1] = is_parameter_[1];
}
if (!is_inputs_should_in_memory_[2]) {
is_inputs_should_in_memory_[2] = is_parameter_[2];
}
}
void StridedSliceCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of 'y', 'z' and 'w'
if (is_parameter_[0]) {
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
is_inputs_should_in_memory_[2] = true;
}
if ((prev_output_in_mem.find(3) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(3))) {
is_inputs_should_in_memory_[3] = true;
}
} else if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
is_inputs_should_in_memory_[2] = true;
}
if ((prev_output_in_mem.find(3) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(3))) {
is_inputs_should_in_memory_[3] = true;
}
}
if (!is_inputs_should_in_memory_[1]) {
is_inputs_should_in_memory_[1] = is_parameter_[1];
}
if (!is_inputs_should_in_memory_[2]) {
is_inputs_should_in_memory_[2] = is_parameter_[2];
}
if (!is_inputs_should_in_memory_[3]) {
is_inputs_should_in_memory_[3] = is_parameter_[3];
}
}
void DropOutDoMaskCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
void DropOutDoMaskCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of 'y'
if (is_parameter_[0]) {
// 'x' is a parameter, so it should be in memory.
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
// In this case, if 'y' is not already kept in memory by the previous operator, then 'y' should be included here.
is_inputs_should_in_memory_[1] = true;
}
} else if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
if (!is_inputs_should_in_memory_[1]) {
is_inputs_should_in_memory_[1] = is_parameter_[1];
}
is_inputs_should_in_memory_[2] = is_parameter_[2];
}
bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int64_t stage_id) {
CheckGlobalDeviceManager();
MS_EXCEPTION_IF_NULL(g_device_manager);
@ -612,8 +1087,8 @@ bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int64_t stage_
return (total_device_num == LongToSize(strategy0));
}
double ReduceMethodCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
double ReduceSumCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const {
double result = 0.0;
TensorInfo input0 = inputs[0];
TensorInfo output0 = outputs[0];
@ -634,8 +1109,8 @@ double ReduceMethodCost::GetForwardCommCost(const std::vector<TensorInfo> &input
return result;
}
double ReduceMethodCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
int64_t stage_id) const {
double ReduceSumCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
int64_t stage_id) const {
double result = 0.0;
if (is_parameter_[0]) {
TensorInfo input_tensor_info = inputs[0];
@ -657,8 +1132,8 @@ double ReduceMethodCost::GetBackwardCommCost(const std::vector<TensorInfo> &inpu
return result;
}
double ReduceMethodCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
double ReduceSumCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
double result = 0.0;
TensorInfo input0 = inputs[0];
TensorInfo output0 = outputs[0];
@ -679,6 +1154,30 @@ double ReduceMethodCost::GetForwardComputationCost(const std::vector<TensorInfo>
return result;
}
// Not taking account of output
void ReduceSumCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
void ReduceSumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of 'y'
if (is_parameter_[0]) {
// 'x' is a parameter, so it should be in memory.
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
// In this case, if 'y' is not already kept in memory by the previous operator, then 'y' should be included here.
is_inputs_should_in_memory_[1] = true;
}
} else if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
// Not taking account of 'y'
if (!is_inputs_should_in_memory_[1]) {
is_inputs_should_in_memory_[1] = is_parameter_[1];
}
}
double ReduceMeanCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
double result = 0.0;
@ -701,6 +1200,42 @@ double ReduceMeanCost::GetForwardComputationCost(const std::vector<TensorInfo> &
return result;
}
void ReduceMinCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
void ReduceMinCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of 'y'
if (is_parameter_[0]) {
// 'x' is a parameter, so it should be in memory.
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
// In this case, if 'y' is not already kept in memory by the previous operator, then 'y' should be included here.
is_inputs_should_in_memory_[1] = true;
}
} else if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
// Not taking account of 'y'
if (!is_inputs_should_in_memory_[1]) {
is_inputs_should_in_memory_[1] = is_parameter_[1];
}
}
void ArgMaxWithValueCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
void ArgMaxWithValueCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of 'x'
if (is_parameter_[0]) {
is_inputs_should_in_memory_[0] = true;
} else if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
}
}
double DropOutCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
int64_t) const {
if (inputs.empty()) {
@ -760,6 +1295,52 @@ double GatherV2Cost::GetBackwardComputationCost(const std::vector<TensorInfo> &,
return 0.0;
}
// Not taking account of output
void GatherV2Cost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
void GatherV2Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of 'y' and 'z'
if (is_parameter_[0]) {
// 'x' is a parameter, so it should be in memory.
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
is_inputs_should_in_memory_[2] = true;
}
} else if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
is_inputs_should_in_memory_[2] = true;
}
}
if (!is_inputs_should_in_memory_[1]) {
is_inputs_should_in_memory_[1] = is_parameter_[1];
}
if (!is_inputs_should_in_memory_[2]) {
is_inputs_should_in_memory_[2] = is_parameter_[2];
}
}
void GetNextCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
void GetNextCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
if (is_inputs_should_in_memory_.size() == 0) {
return;
}
is_inputs_should_in_memory_[0] = is_parameter_[0];
}
void UniqueCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
void UniqueCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
is_inputs_should_in_memory_[0] = is_parameter_[0];
}
double LayerNormCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
int64_t stage_id) const {
double result = 0.0;
@ -808,6 +1389,24 @@ double LayerNormCost::GetForwardComputationCost(const std::vector<TensorInfo> &i
return result;
}
void LayerNormCost::CalculateOutputInMemory() {
is_output_should_in_memory_ = is_parameter_involve_[0] || is_parameter_involve_[1] || is_parameter_involve_[2];
}
void LayerNormCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of both 'x' and 'y'
// When calculating 'dy', taking account of both 'x' and 'y'
if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
is_inputs_should_in_memory_[2] = is_parameter_[2];
}
double UniqueCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const {
return 0.0;
@ -924,6 +1523,12 @@ double UniformCandidateSamplerCost::GetForwardComputationCost(const std::vector<
return result;
}
void UniformCandidateSamplerCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
void UniformCandidateSamplerCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
is_inputs_should_in_memory_[0] = is_parameter_[0];
}
double GatherV2PCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
double result = 0.0;
@ -1019,6 +1624,29 @@ double UnsortedSegmentSumCost::GetForwardComputationCost(const std::vector<Tenso
return result;
}
// Not taking account of output
void UnsortedSegmentSumCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
// Taking account of input
void UnsortedSegmentSumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of 'y'
if (is_parameter_[0]) {
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
} else if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
}
if (!is_inputs_should_in_memory_[1]) {
is_inputs_should_in_memory_[1] = is_parameter_[1];
}
is_inputs_should_in_memory_[2] = is_parameter_[2];
}
double UnsortedSegmentMinCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
TensorInfo input0 = inputs[0];
@ -1078,5 +1706,40 @@ double UnsortedSegmentMinCost::GetForwardComputationCost(const std::vector<Tenso
ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]); // ReduceMin
return result;
}
// Taking account of output
void UnsortedSegmentMinCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
// Taking account of input
void UnsortedSegmentMinCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calculating 'dx', taking account of 'x', 'y' and 'z'
if (is_parameter_involve_[0]) {
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
is_inputs_should_in_memory_[0] = true;
}
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
is_inputs_should_in_memory_[1] = true;
}
if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
is_inputs_should_in_memory_[2] = true;
}
}
if (!is_inputs_should_in_memory_[1]) {
is_inputs_should_in_memory_[1] = is_parameter_[1];
}
if (!is_inputs_should_in_memory_[2]) {
is_inputs_should_in_memory_[2] = is_parameter_[2];
}
}
// Not taking account of output
void VirtualDatasetCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
// Not taking account of input
void VirtualDatasetCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
for (size_t i = 0; i < is_inputs_should_in_memory_.size(); ++i) {
is_inputs_should_in_memory_[i] = is_parameter_[i];
}
}
} // namespace parallel
} // namespace mindspore


@ -19,6 +19,7 @@
#include <memory>
#include <vector>
#include <map>
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/tensor_layout/tensor_info.h"
@ -47,16 +48,7 @@ double ListProduct(std::vector<T> vec) {
// entries times the length of each entry's data type
class OperatorCost {
public:
explicit OperatorCost(bool is_inputs_related) : inputs_related_(is_inputs_related) {
// this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
is_parameter_.push_back(false);
is_parameter_involve_.push_back(false);
inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
}
}
OperatorCost() : inputs_related_(false) {
OperatorCost() {
// this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
is_parameter_.push_back(false);
@ -89,10 +81,17 @@ class OperatorCost {
const std::vector<TensorInfo> &outputs, int64_t stage_id) const = 0;
virtual double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs, int64_t stage_id) const = 0;
virtual void CalculateOutputInMemory() = 0;
virtual void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) = 0;
bool is_output_in_memory() const { return is_output_should_in_memory_; }
// per device PEAK memory cost in a training iteration
// Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-invovled),
// Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-involved),
// plus necessary inputs.
virtual double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
// Computes the input part of 'GetMemoryCost'
double GetInputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
// Computes the output part of 'GetMemoryCost'
double GetOutputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
// per device memory cost in an inference phase
double GetMemoryCostForInference(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &) const;
@ -101,25 +100,25 @@ class OperatorCost {
// pre-operator that has parameters as input.
std::vector<bool> is_parameter_involve_;
int64_t output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved
// Whether the inputs are related or not? For example, TensorAdd's two inputs are independent (not related), while
// Mul's two inputs are dependent (related).
bool inputs_related_;
// for each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter
// For each input in 'inputs_', there is a bool variable indicating whether the corresponding input is a parameter
std::vector<bool> is_parameter_;
// for each input and output, the followings record the number of bytes of each element
// Whether the input should be kept in memory in the training phase. It depends on the operator and the operator's
// previous operators.
std::vector<bool> is_inputs_should_in_memory_;
// Whether the output should be kept in memory in the training phase. It depends on 'is_parameter_involve_' and the operator.
bool is_output_should_in_memory_ = false;
// For each input and output, the following vectors record the number of bytes of each element
std::vector<size_t> inputs_type_lengths_;
std::vector<size_t> outputs_type_lengths_;
// Whether the output is critical, which means that this output is included in calculating peak memory cost
// in the inference phase.
int64_t is_outputs_critical_ = -1;
};
using OperatorCostPtr = std::shared_ptr<OperatorCost>;
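The two pure virtuals added above replace the old one-shot GetMemoryCost decision with per-operator hooks that a cost-graph pass can compose: every operator first decides whether its own output must stay resident, and each operator is then told, per input, whether its producer already keeps that tensor. The driver and the tiny stand-in interface below are invented for illustration only; the actual traversal is implemented elsewhere in this change and is not shown in this excerpt.

// Hypothetical driver sketch; 'CostLike', 'Node', 'ElementwiseCost' and 'DecideMemoryFlags' are invented names.
#include <map>
#include <memory>
#include <vector>

struct CostLike {  // stand-in for the new OperatorCost hooks
  virtual ~CostLike() = default;
  virtual void CalculateOutputInMemory() = 0;
  virtual void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) = 0;
  virtual bool is_output_in_memory() const = 0;
};

struct Node {
  std::shared_ptr<CostLike> cost;
  std::vector<Node *> producers;  // producers[i] feeds input i (nullptr if fed by a constant)
};

void DecideMemoryFlags(std::vector<Node> *graph) {
  // Pass 1: each operator decides whether its output must stay in memory.
  for (auto &node : *graph) node.cost->CalculateOutputInMemory();
  // Pass 2: each operator learns which of its inputs are already held by their producers,
  // then decides which inputs to charge to the peak-memory cost.
  for (auto &node : *graph) {
    std::map<size_t, bool> prev_output_in_mem;
    for (size_t i = 0; i < node.producers.size(); ++i) {
      if (node.producers[i] != nullptr) {
        prev_output_in_mem[i] = node.producers[i]->cost->is_output_in_memory();
      }
    }
    node.cost->CalculateInputsInMemory(prev_output_in_mem);
  }
}

struct ElementwiseCost : CostLike {  // toy single-input cost: keep the input only if its producer drops it
  explicit ElementwiseCost(bool keep_output) : keep_output_(keep_output) {}
  void CalculateOutputInMemory() override { output_in_memory_ = keep_output_; }
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev) override {
    input_in_memory_ = (prev.find(0) == prev.end()) || !prev.at(0);
  }
  bool is_output_in_memory() const override { return output_in_memory_; }
  bool keep_output_ = false;
  bool output_in_memory_ = false;
  bool input_in_memory_ = false;
};

int main() {
  std::vector<Node> graph(2);
  graph[0].cost = std::make_shared<ElementwiseCost>(true);   // producer keeps its output
  graph[1].cost = std::make_shared<ElementwiseCost>(false);
  graph[1].producers = {&graph[0]};
  DecideMemoryFlags(&graph);
  // graph[1]'s single input is already held by graph[0], so it is not charged again.
  return 0;
}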
class MatMulCost : public OperatorCost {
public:
explicit MatMulCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
MatMulCost() : OperatorCost(true) {}
MatMulCost() : OperatorCost() {}
~MatMulCost() override = default;
// per device communication cost
@ -141,14 +140,15 @@ class MatMulCost : public OperatorCost {
int64_t stage_id) const override;
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override;
void CalculateOutputInMemory() override;
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using MatMulCostPtr = std::shared_ptr<MatMulCost>;
using TensorDotCost = MatMulCost;
class ActivationCost : public OperatorCost {
class CastCost : public OperatorCost {
public:
explicit ActivationCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
ActivationCost() : OperatorCost(false) {}
~ActivationCost() override = default;
CastCost() : OperatorCost() {}
~CastCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override {
@ -166,21 +166,95 @@ class ActivationCost : public OperatorCost {
int64_t stage_id) const override;
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override;
// Not taking account of output
void CalculateOutputInMemory() override;
// Not taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using ActivationCostPtr = std::shared_ptr<ActivationCost>;
using TransposeCost = ActivationCost;
using TransposeCostPtr = std::shared_ptr<TransposeCost>;
using StridedSliceCost = ActivationCost;
using StridedSliceCostPtr = std::shared_ptr<StridedSliceCost>;
using SliceCost = ActivationCost;
using SliceCostPtr = std::shared_ptr<SliceCost>;
using SplitCost = ActivationCost;
using SplitCostPtr = std::shared_ptr<SplitCost>;
using RepeatElementsCost = CastCost;
using NegCost = CastCost;
using ExpandDimsCost = CastCost;
using SqueezeCost = CastCost;
using ConcatCost = CastCost;
using LogicalNotCost = CastCost;
using SignCost = CastCost;
using FloorCost = CastCost;
using RoundCost = CastCost;
using CeilCost = CastCost;
using ZerosLikeCost = CastCost;
using OnesLikeCost = CastCost;
using RangeCost = CastCost;
using SplitCost = CastCost;
class SqrtCost : public CastCost {
public:
SqrtCost() : CastCost() {}
~SqrtCost() override = default;
// Taking account of output, not taking account of input
void CalculateOutputInMemory() override;
};
using TanhCost = SqrtCost;
using EluCost = SqrtCost;
using ReLUCost = SqrtCost;
using SigmoidCost = SqrtCost;
using ReciprocalCost =
SqrtCost; // The derivative of 'Reciprocal' is different on 'Ascend' and 'GPU'. Here, 'Ascend' is chosen
using InvCost = SqrtCost;
using RsqrtCost = SqrtCost;
using AsinhCost = SqrtCost;
using AcoshCost = SqrtCost;
using ReLUV2Cost = SqrtCost;
class ReLU6Cost : public CastCost {
public:
ReLU6Cost() : CastCost() {}
~ReLU6Cost() override = default;
// Taking account of input, not taking account of output
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using SoftsignCost = ReLU6Cost;
using SoftplusCost = ReLU6Cost;
using SquareCost = ReLU6Cost;
using ExpCost = ReLU6Cost;
using LogCost = ReLU6Cost;
using CosCost = ReLU6Cost;
using ACosCost = ReLU6Cost;
using AbsCost = ReLU6Cost;
using TanCost = ReLU6Cost;
using SinCost = ReLU6Cost;
using SinhCost = ReLU6Cost;
using Log1pCost = ReLU6Cost;
using Expm1Cost = ReLU6Cost;
using CoshCost = ReLU6Cost;
using AtanhCost = ReLU6Cost;
using AtanCost = ReLU6Cost;
using AsinCost = ReLU6Cost;
using ErfCost = ReLU6Cost;
using ErfcCost = ReLU6Cost;
using ActivationInfoCost = ReLU6Cost;
class TransposeCost : public CastCost {
public:
TransposeCost() : CastCost() {}
~TransposeCost() override = default;
// Taking account of input, not taking account of output
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
class GeLUCost : public SqrtCost {
public:
GeLUCost() : SqrtCost() {}
~GeLUCost() override = default;
// Taking account of input and output
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using BesselI0eCost = GeLUCost;
using BesselI1eCost = GeLUCost;
using L2NormalizeCost = GeLUCost;
class SoftmaxCost : public OperatorCost {
public:
explicit SoftmaxCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
SoftmaxCost() : OperatorCost(false) {}
SoftmaxCost() : OperatorCost() {}
~SoftmaxCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -199,21 +273,45 @@ class SoftmaxCost : public OperatorCost {
int64_t stage_id) const override;
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t) const override;
// Taking account of output
void CalculateOutputInMemory() override;
// Not taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
class TileCost : public SoftmaxCost {
public:
TileCost() : SoftmaxCost() {}
~TileCost() override = default;
// Not taking account of output
void CalculateOutputInMemory() override;
// Taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
class PackCost : public SoftmaxCost {
public:
PackCost() : SoftmaxCost() {}
~PackCost() override = default;
// Not taking account of output
void CalculateOutputInMemory() override;
// Not taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
class BroadcastToCost : public SoftmaxCost {
public:
BroadcastToCost() : SoftmaxCost() {}
~BroadcastToCost() override = default;
// Not taking account of output
void CalculateOutputInMemory() override;
// Not taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>;
using TileCost = SoftmaxCost;
using TileCostPtr = std::shared_ptr<TileCost>;
using PackCost = TileCost;
using PackCostPtr = std::shared_ptr<PackCost>;
using ConcatCost = TileCost;
using ConcatCostPtr = std::shared_ptr<ConcatCost>;
using BroadcastToCost = SoftmaxCost;
using BroadcastToCostPtr = std::shared_ptr<BroadcastToCost>;
class TmpIdentityCost : public OperatorCost {
public:
explicit TmpIdentityCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
TmpIdentityCost() : OperatorCost(false) {}
TmpIdentityCost() : OperatorCost() {}
~TmpIdentityCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -232,15 +330,16 @@ class TmpIdentityCost : public OperatorCost {
int64_t stage_id) const override;
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override;
// per device PEAK memory cost in a training iteration
double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const override;
// Not taking account of output
void CalculateOutputInMemory() override;
// Not taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using TmpIdentityCostPtr = std::shared_ptr<TmpIdentityCost>;
class BatchParallelCost : public OperatorCost {
public:
explicit BatchParallelCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
BatchParallelCost() : OperatorCost(false) {}
BatchParallelCost() : OperatorCost() {}
~BatchParallelCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -259,13 +358,25 @@ class BatchParallelCost : public OperatorCost {
int64_t stage_id) const override;
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override;
// Not taking account of output
void CalculateOutputInMemory() override;
// Taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
class SparseSoftmaxCrossEntropyWithLogitsCost : public BatchParallelCost {
public:
SparseSoftmaxCrossEntropyWithLogitsCost() : BatchParallelCost() {}
~SparseSoftmaxCrossEntropyWithLogitsCost() override = default;
// Taking account of output
void CalculateOutputInMemory() override;
// Not taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using BatchParallelCostPtr = std::shared_ptr<BatchParallelCost>;
class VirtualDatasetCost : public OperatorCost {
public:
explicit VirtualDatasetCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
VirtualDatasetCost() : OperatorCost(false) {}
VirtualDatasetCost() : OperatorCost() {}
~VirtualDatasetCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -290,17 +401,15 @@ class VirtualDatasetCost : public OperatorCost {
int64_t) const override {
return 0.0;
}
// per device PEAK memory cost in a training iteration
double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const override {
return 0.0;
}
// Not taking account of output
void CalculateOutputInMemory() override;
// Not taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using VirtualDatasetCostPtr = std::shared_ptr<VirtualDatasetCost>;
class GeneratorBaseCost : public OperatorCost {
public:
explicit GeneratorBaseCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
GeneratorBaseCost() : OperatorCost(false) {}
GeneratorBaseCost() : OperatorCost() {}
~GeneratorBaseCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -332,8 +441,7 @@ using GeneratorBaseCostPtr = std::shared_ptr<GeneratorBaseCost>;
class PReLUCost : public OperatorCost {
public:
explicit PReLUCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
PReLUCost() : OperatorCost(true) {}
PReLUCost() : OperatorCost() {}
~PReLUCost() override = default;
// per device communication cost
@ -355,13 +463,16 @@ class PReLUCost : public OperatorCost {
int64_t stage_id) const override;
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override;
// Not taking account of output
void CalculateOutputInMemory() override;
// Taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using PReLUCostPtr = std::shared_ptr<PReLUCost>;
class OneHotCost : public OperatorCost {
public:
explicit OneHotCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
OneHotCost() : OperatorCost(true) {}
OneHotCost() : OperatorCost() {}
~OneHotCost() override = default;
// per device communication cost
@ -383,13 +494,16 @@ class OneHotCost : public OperatorCost {
int64_t stage_id) const override;
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override;
// Not taking account of output
void CalculateOutputInMemory() override;
// Not taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using OneHotCostPtr = std::shared_ptr<OneHotCost>;
class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost {
public:
explicit SoftmaxCrossEntropyWithLogitsCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
SoftmaxCrossEntropyWithLogitsCost() : OperatorCost(false) {}
SoftmaxCrossEntropyWithLogitsCost() : OperatorCost() {}
~SoftmaxCrossEntropyWithLogitsCost() override = default;
// per device communication cost
@ -411,13 +525,15 @@ class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost {
int64_t stage_id) const override;
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override;
// Taking account of output
void CalculateOutputInMemory() override;
// Not taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr<SoftmaxCrossEntropyWithLogitsCost>;
class ReshapeCost : public OperatorCost {
public:
explicit ReshapeCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
ReshapeCost() : OperatorCost(true) {}
ReshapeCost() : OperatorCost() {}
~ReshapeCost() override = default;
@ -444,14 +560,17 @@ class ReshapeCost : public OperatorCost {
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override;
// Not taking account of output
void CalculateOutputInMemory() override;
// Not taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using ReshapeCostPtr = std::shared_ptr<ReshapeCost>;
class ArithmeticCost : public OperatorCost {
class SubCost : public OperatorCost {
public:
explicit ArithmeticCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
ArithmeticCost() : OperatorCost(false) {}
~ArithmeticCost() override = default;
SubCost() : OperatorCost() {}
~SubCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override {
@ -470,16 +589,127 @@ class ArithmeticCost : public OperatorCost {
int64_t stage_id) const override;
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override;
// Not taking account of output
void CalculateOutputInMemory() override;
// Not taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using ArithmeticCostPtr = std::shared_ptr<ArithmeticCost>;
using BiasAddCost = ArithmeticCost;
using BiasAddCostPtr = std::shared_ptr<BiasAddCost>;
using TensorAddCost = SubCost;
using FloorDivCost = SubCost;
using AssignSubCost = SubCost;
using AssignAddCost = SubCost;
using LogicalAndCost = SubCost;
using LogicalOrCost = SubCost;
using BiasAddCost = SubCost;
using EqualCost = SubCost;
using ApproximateEqualCost = SubCost;
using NotEqualCost = SubCost;
using GreaterCost = SubCost;
using GreaterEqualCost = SubCost;
using LessCost = SubCost;
using LessEqualCost = SubCost;
class ReduceMethodCost : public OperatorCost {
class MulCost : public SubCost {
public:
explicit ReduceMethodCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
ReduceMethodCost() : OperatorCost(true) {}
~ReduceMethodCost() override = default;
MulCost() : SubCost() {}
~MulCost() override = default;
// Taking account of input, not taking account of output
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
class DivCost : public SubCost {
public:
DivCost() : SubCost() {}
~DivCost() override = default;
// Taking account of output
void CalculateOutputInMemory() override;
// Taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using ReadDivCost = DivCost;
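For context on these overrides: the classes commented "Taking account of output" keep the forward result resident because their backward pass reuses it (for y = a / b, dy/db = -y / b, which is why Div, DivNoNan and Pow fall in this group). A minimal, hypothetical sketch of such an override in the corresponding .cc file could look as follows; the member names and the prev_output_in_mem convention are assumptions drawn from the declarations in this header, not the committed implementation:

// Hypothetical sketch only (assumed member names, not the committed code).
void DivCost::CalculateOutputInMemory() {
  // The divisor gradient reuses the forward output, so keep it in memory.
  is_output_should_in_memory_ = true;
}

void DivCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // Keep the divisor (input 1) unless its producer already keeps that tensor alive.
  if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
    is_inputs_should_in_memory_[1] = true;
  }
}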
class ModCost : public SubCost {
public:
ModCost() : SubCost() {}
~ModCost() override = default;
// Taking account of input, not taking account of output
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using FloorModCost = ModCost;
class PowCost : public SubCost {
public:
PowCost() : SubCost() {}
~PowCost() override = default;
// Taking account of output
void CalculateOutputInMemory() override;
// Taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
class AssignCost : public SubCost {
public:
AssignCost() : SubCost() {}
~AssignCost() override = default;
// Taking account of input, not taking account of output
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
class SigmoidCrossEntropyWithLogitsCost : public SubCost {
public:
SigmoidCrossEntropyWithLogitsCost() : SubCost() {}
~SigmoidCrossEntropyWithLogitsCost() override = default;
// Taking account of input, not taking account of output
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
class Atan2Cost : public SubCost {
public:
Atan2Cost() : SubCost() {}
~Atan2Cost() override = default;
// Taking account of input, not taking account of output
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
class DivNoNanCost : public SubCost {
public:
DivNoNanCost() : SubCost() {}
~DivNoNanCost() override = default;
// Taking account of output
void CalculateOutputInMemory() override;
// Taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
class MaximumCost : public SubCost {
public:
MaximumCost() : SubCost() {}
~MaximumCost() override = default;
// Taking account of input, not taking account of output
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using MinimumCost = MaximumCost;
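MaximumCost/MinimumCost keep their inputs rather than the output because the backward pass must re-compare the two operands to decide which one receives the incoming gradient. A hedged sketch under the same assumptions as above:

// Hypothetical sketch only.
void MaximumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // Both operands are needed to rebuild the comparison mask in the backward pass;
  // keep each one unless the producing operator already keeps it in memory.
  for (size_t i = 0; i < is_inputs_should_in_memory_.size(); ++i) {
    if ((prev_output_in_mem.find(i) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(i))) {
      is_inputs_should_in_memory_[i] = true;
    }
  }
}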
class SliceCost : public CastCost {
public:
SliceCost() : CastCost() {}
~SliceCost() override = default;
// Not taking account of output, taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
class StridedSliceCost : public CastCost {
public:
StridedSliceCost() : CastCost() {}
~StridedSliceCost() override = default;
// Not taking account of output, taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
class ReduceSumCost : public OperatorCost {
public:
ReduceSumCost() : OperatorCost() {}
~ReduceSumCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override {
@ -500,27 +730,50 @@ class ReduceMethodCost : public OperatorCost {
return 0.0;
}
void set_cross_batch(bool cb) { cross_batch_ = cb; }
// Not taking account of output
void CalculateOutputInMemory() override;
// Taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
protected:
bool cross_batch_ = false;
};
using ReduceMethodCostPtr = std::shared_ptr<ReduceMethodCost>;
using ReduceMethodCost = ReduceSumCost;
class ReduceMeanCost : public ReduceMethodCost {
class ReduceMeanCost : public ReduceSumCost {
public:
explicit ReduceMeanCost(bool is_inputs_related) : ReduceMethodCost(is_inputs_related) {}
ReduceMeanCost() : ReduceMethodCost(true) {}
ReduceMeanCost() : ReduceSumCost() {}
~ReduceMeanCost() override = default;
double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override;
};
using ReduceMeanCostPtr = std::shared_ptr<ReduceMeanCost>;
class ReduceMinCost : public ReduceSumCost {
public:
ReduceMinCost() : ReduceSumCost() {}
~ReduceMinCost() override = default;
// Taking account of output
void CalculateOutputInMemory() override;
// Taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using ReduceMaxCost = ReduceMinCost;
class ArgMaxWithValueCost : public ReduceSumCost {
public:
ArgMaxWithValueCost() : ReduceSumCost() {}
~ArgMaxWithValueCost() override = default;
// Taking account of output
void CalculateOutputInMemory() override;
// Taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using ArgMinWithValueCost = ArgMaxWithValueCost;
class GetNextCost : public OperatorCost {
public:
explicit GetNextCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
GetNextCost() : OperatorCost(false) {}
GetNextCost() : OperatorCost() {}
~GetNextCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -547,13 +800,17 @@ class GetNextCost : public OperatorCost {
int64_t) const override {
return 0.0;
}
// Not taking account of output
void CalculateOutputInMemory() override;
// Not taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using GetNextCostPtr = std::shared_ptr<GetNextCost>;
class DropOutCost : public OperatorCost {
// For memory cost, taking account of output, not taking account of input
class DropOutCost : public SqrtCost {
public:
explicit DropOutCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
DropOutCost() : OperatorCost(true) {}
DropOutCost() : SqrtCost() {}
~DropOutCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -578,12 +835,19 @@ class DropOutCost : public OperatorCost {
}
};
using DropOutCostPtr = std::shared_ptr<DropOutCost>;
class DropOutDoMaskCost : public DropOutCost {
public:
DropOutDoMaskCost() : DropOutCost() {}
~DropOutDoMaskCost() override = default;
// Not taking account of output
void CalculateOutputInMemory() override;
// Taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
class UnsortedSegmentSumCost : public OperatorCost {
public:
explicit UnsortedSegmentSumCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
UnsortedSegmentSumCost() : OperatorCost(true) {}
UnsortedSegmentSumCost() : OperatorCost() {}
~UnsortedSegmentSumCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -602,14 +866,15 @@ class UnsortedSegmentSumCost : public OperatorCost {
int64_t) const override {
return 0.0;
}
// Not taking account of output
void CalculateOutputInMemory() override;
// Taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using UnsortedSegmentSumCostPtr = std::shared_ptr<UnsortedSegmentSumCost>;
class UnsortedSegmentMinCost : public OperatorCost {
public:
explicit UnsortedSegmentMinCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
UnsortedSegmentMinCost() : OperatorCost(true) {}
UnsortedSegmentMinCost() : OperatorCost() {}
~UnsortedSegmentMinCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -628,14 +893,16 @@ class UnsortedSegmentMinCost : public OperatorCost {
int64_t) const override {
return 0.0;
}
// Taking account of output
void CalculateOutputInMemory() override;
// Taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using UnsortedSegmentMinCostPtr = std::shared_ptr<UnsortedSegmentMinCost>;
using UnsortedSegmentMaxCost = UnsortedSegmentMinCost;
class LayerNormCost : public OperatorCost {
public:
explicit LayerNormCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
LayerNormCost() : OperatorCost(true) {}
LayerNormCost() : OperatorCost() {}
~LayerNormCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -656,14 +923,15 @@ class LayerNormCost : public OperatorCost {
int64_t) const override {
return 0.0;
}
// Taking account of output
void CalculateOutputInMemory() override;
// Taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using DropOutCostPtr = std::shared_ptr<DropOutCost>;
class UniqueCost : public OperatorCost {
public:
explicit UniqueCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
UniqueCost() : OperatorCost(true) {}
UniqueCost() : OperatorCost() {}
~UniqueCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -682,14 +950,15 @@ class UniqueCost : public OperatorCost {
int64_t stage_id) const override;
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t) const override;
// Taking account of output
void CalculateOutputInMemory() override;
// Not taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using UniqueCostPtr = std::shared_ptr<UniqueCost>;
class UniformCandidateSamplerCost : public OperatorCost {
public:
explicit UniformCandidateSamplerCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
UniformCandidateSamplerCost() : OperatorCost(false) {}
UniformCandidateSamplerCost() : OperatorCost() {}
~UniformCandidateSamplerCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -714,14 +983,15 @@ class UniformCandidateSamplerCost : public OperatorCost {
int64_t) const override {
return 0.0;
}
// Not taking account of output
void CalculateOutputInMemory() override;
// Not taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using UniformCandidateSamplerCostPtr = std::shared_ptr<UniformCandidateSamplerCost>;
class GatherV2Cost : public OperatorCost {
public:
explicit GatherV2Cost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
GatherV2Cost() : OperatorCost(true) {}
GatherV2Cost() : OperatorCost() {}
~GatherV2Cost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -740,14 +1010,15 @@ class GatherV2Cost : public OperatorCost {
int64_t stage_id) const override;
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t) const override;
// Not taking account of output
void CalculateOutputInMemory() override;
// Taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using GatherV2CostPtr = std::shared_ptr<GatherV2Cost>;
class GatherV2PCost : public OperatorCost {
class GatherV2PCost : public GatherV2Cost {
public:
explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related), axis_(0) {}
GatherV2PCost() : OperatorCost(true), axis_(0) {}
GatherV2PCost() : GatherV2Cost(), axis_(0) {}
~GatherV2PCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -773,8 +1044,6 @@ class GatherV2PCost : public OperatorCost {
int64_t axis_;
Shape strategy_;
};
using GatherV2PCostPtr = std::shared_ptr<GatherV2PCost>;
} // namespace parallel
} // namespace mindspore
#endif // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_

View File

@ -50,8 +50,8 @@ class ActivationBase : public OperatorInfo {
class Activation : public ActivationBase {
public:
Activation(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(false)) {}
const PrimitiveAttrs &attrs, OperatorCostPtr cost)
: ActivationBase(name, inputs_shape, outputs_shape, attrs, cost) {}
~Activation() override = default;
Status GenerateStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
@ -64,7 +64,7 @@ class ActivationInfo : public Activation {
public:
ActivationInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: Activation(name, inputs_shape, outputs_shape, attrs) {}
: Activation(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationInfoCost>()) {}
~ActivationInfo() override = default;
protected:
@ -74,8 +74,8 @@ class ActivationInfo : public Activation {
class ActivationOther : public Activation {
public:
ActivationOther(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: Activation(name, inputs_shape, outputs_shape, attrs) {}
const PrimitiveAttrs &attrs, OperatorCostPtr cost)
: Activation(name, inputs_shape, outputs_shape, attrs, cost) {}
~ActivationOther() override = default;
protected:
@ -86,7 +86,7 @@ class GeluInfo : public ActivationOther {
public:
GeluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<GeLUCost>()) {}
~GeluInfo() override = default;
};
@ -94,7 +94,7 @@ class TanhInfo : public ActivationOther {
public:
TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<TanhCost>()) {}
~TanhInfo() override = default;
};
@ -102,7 +102,7 @@ class Softmax : public ActivationBase {
public:
explicit Softmax(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>(false)) {}
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>()) {}
~Softmax() override = default;
Status GenerateStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
@ -134,7 +134,7 @@ class LogSoftmaxInfo : public Softmax {
class EluInfo : public ActivationOther {
public:
EluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<EluCost>()) {}
~EluInfo() override = default;
};
@ -142,7 +142,7 @@ class ReLUInfo : public ActivationOther {
public:
ReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLUCost>()) {}
~ReLUInfo() override = default;
};
@ -150,7 +150,7 @@ class RepeatElementsInfo : public ActivationOther {
public:
RepeatElementsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<RepeatElementsCost>()) {}
~RepeatElementsInfo() override = default;
};
@ -158,7 +158,7 @@ class ReLU6Info : public ActivationOther {
public:
ReLU6Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLU6Cost>()) {}
~ReLU6Info() override = default;
};
@ -166,7 +166,7 @@ class SoftsignInfo : public ActivationOther {
public:
SoftsignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftsignCost>()) {}
~SoftsignInfo() override = default;
};
@ -174,7 +174,7 @@ class SoftplusInfo : public ActivationOther {
public:
SoftplusInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftplusCost>()) {}
~SoftplusInfo() override = default;
};
@ -182,7 +182,7 @@ class CastInfo : public ActivationOther {
public:
CastInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<CastCost>()) {}
~CastInfo() override = default;
protected:
@ -193,14 +193,14 @@ class SqrtInfo : public ActivationOther {
public:
SqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SqrtCost>()) {}
~SqrtInfo() override = default;
};
class NegInfo : public ActivationOther {
public:
NegInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<NegCost>()) {}
~NegInfo() override = default;
};
@ -208,7 +208,7 @@ class ExpandDimsInfo : public ActivationOther {
public:
ExpandDimsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ExpandDimsCost>()) {}
~ExpandDimsInfo() override = default;
protected:
@ -228,7 +228,7 @@ class SqueezeInfo : public ActivationOther {
public:
SqueezeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SqueezeCost>()) {}
~SqueezeInfo() override = default;
protected:
@ -247,7 +247,7 @@ class SquareInfo : public ActivationOther {
public:
SquareInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SquareCost>()) {}
~SquareInfo() override = default;
};
@ -255,7 +255,7 @@ class SigmoidInfo : public ActivationOther {
public:
SigmoidInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SigmoidCost>()) {}
~SigmoidInfo() override = default;
};
@ -263,7 +263,7 @@ class DropoutInfo : public ActivationOther {
public:
DropoutInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>()) {}
~DropoutInfo() override = default;
Status GenerateStrategies(int64_t stage_id) override;

View File

@ -56,7 +56,7 @@ class ArithmeticBase : public OperatorInfo {
class SubInfo : public ArithmeticBase {
public:
SubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SubCost>()) {}
~SubInfo() override = default;
};
@ -64,28 +64,28 @@ class TensorAddInfo : public ArithmeticBase {
public:
TensorAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorAddCost>()) {}
~TensorAddInfo() override = default;
};
class MulInfo : public ArithmeticBase {
public:
MulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<MulCost>()) {}
~MulInfo() override = default;
};
class DivInfo : public ArithmeticBase {
public:
DivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<DivCost>()) {}
~DivInfo() override = default;
};
class ModInfo : public ArithmeticBase {
public:
ModInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ModCost>()) {}
~ModInfo() override = default;
};
@ -93,7 +93,7 @@ class RealDivInfo : public ArithmeticBase {
public:
RealDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReadDivCost>()) {}
~RealDivInfo() override = default;
};
@ -101,7 +101,7 @@ class FloorDivInfo : public ArithmeticBase {
public:
FloorDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<FloorDivCost>()) {}
~FloorDivInfo() override = default;
};
@ -109,14 +109,14 @@ class FloorModInfo : public ArithmeticBase {
public:
FloorModInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<FloorModCost>()) {}
~FloorModInfo() override = default;
};
class PowInfo : public ArithmeticBase {
public:
PowInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<PowCost>()) {}
~PowInfo() override = default;
};
@ -124,7 +124,7 @@ class AssignSubInfo : public ArithmeticBase {
public:
AssignSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<AssignSubCost>()) {}
~AssignSubInfo() override = default;
};
@ -132,7 +132,7 @@ class AssignInfo : public ArithmeticBase {
public:
AssignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<AssignCost>()) {}
~AssignInfo() override = default;
};
@ -140,7 +140,7 @@ class AssignAddInfo : public ArithmeticBase {
public:
AssignAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<AssignAddCost>()) {}
~AssignAddInfo() override = default;
};
@ -149,7 +149,8 @@ class SigmoidCrossEntropyWithLogitsInfo : public ArithmeticBase {
public:
SigmoidCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs,
std::make_shared<SigmoidCrossEntropyWithLogitsCost>()) {}
~SigmoidCrossEntropyWithLogitsInfo() override = default;
};
@ -157,7 +158,7 @@ class Atan2Info : public ArithmeticBase {
public:
Atan2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<Atan2Cost>()) {}
~Atan2Info() override = default;
};
@ -165,7 +166,7 @@ class DivNoNanInfo : public ArithmeticBase {
public:
DivNoNanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<DivNoNanCost>()) {}
~DivNoNanInfo() override = default;
};
@ -173,7 +174,7 @@ class LogicalAndInfo : public ArithmeticBase {
public:
LogicalAndInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogicalAndCost>()) {}
~LogicalAndInfo() override = default;
};
@ -181,7 +182,7 @@ class LogicalOrInfo : public ArithmeticBase {
public:
LogicalOrInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogicalOrCost>()) {}
~LogicalOrInfo() override = default;
};
} // namespace parallel

View File

@ -34,8 +34,7 @@ class BatchParallelInfo : public OperatorInfo {
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {}
BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>(false)),
dev_num_(1) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()), dev_num_(1) {}
~BatchParallelInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
@ -62,7 +61,8 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
public:
SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape,
const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>(true)) {}
: BatchParallelInfo(name, inputs_shape, outputs_shape, attrs,
std::make_shared<SparseSoftmaxCrossEntropyWithLogitsCost>()) {}
~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default;
void ReComputeBatchSplitFlagList() override;
};

View File

@ -34,7 +34,7 @@ class BiasAddInfo : public OperatorInfo {
public:
BiasAddInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>(false)) {}
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>()) {}
~BiasAddInfo() override = default;
Status Init(const StrategyPtr &strategy) override;

View File

@ -36,7 +36,7 @@ class BroadcastToInfo : public OperatorInfo {
public:
BroadcastToInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BroadcastToCost>(false)) {}
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BroadcastToCost>()) {}
~BroadcastToInfo() override = default;
Status Init(const StrategyPtr &strategy) override;

View File

@ -32,7 +32,7 @@ class EqualInfo : public ArithmeticBase {
public:
EqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<EqualCost>()) {}
~EqualInfo() override = default;
};
@ -40,7 +40,7 @@ class ApproximateEqualInfo : public ArithmeticBase {
public:
ApproximateEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ApproximateEqualCost>()) {}
~ApproximateEqualInfo() override = default;
};
@ -48,7 +48,7 @@ class NotEqualInfo : public ArithmeticBase {
public:
NotEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<NotEqualCost>()) {}
~NotEqualInfo() override = default;
};
@ -56,7 +56,7 @@ class MaximumInfo : public ArithmeticBase {
public:
MaximumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<MaximumCost>()) {}
~MaximumInfo() override = default;
};
@ -64,7 +64,7 @@ class MinimumInfo : public ArithmeticBase {
public:
MinimumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<MinimumCost>()) {}
~MinimumInfo() override = default;
};
@ -72,7 +72,7 @@ class GreaterInfo : public ArithmeticBase {
public:
GreaterInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<GreaterCost>()) {}
~GreaterInfo() override = default;
};
@ -80,7 +80,7 @@ class GreaterEqualInfo : public ArithmeticBase {
public:
GreaterEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<GreaterEqualCost>()) {}
~GreaterEqualInfo() override = default;
};
@ -88,7 +88,7 @@ class LessInfo : public ArithmeticBase {
public:
LessInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LessCost>()) {}
~LessInfo() override = default;
};
@ -96,7 +96,7 @@ class LessEqualInfo : public ArithmeticBase {
public:
LessEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LessEqualCost>()) {}
~LessEqualInfo() override = default;
};
} // namespace parallel

View File

@ -33,7 +33,7 @@ class ConcatInfo : public OperatorInfo {
public:
ConcatInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ConcatCost>(false)) {}
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ConcatCost>()) {}
~ConcatInfo() override = default;
Status Init(const StrategyPtr &strategy) override;

View File

@ -33,7 +33,7 @@ class DropoutDoMaskInfo : public OperatorInfo {
public:
DropoutDoMaskInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>(true)) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutDoMaskCost>()) {}
~DropoutDoMaskInfo() override = default;
Status Init(const StrategyPtr &strategy) override;

View File

@ -20,6 +20,7 @@
#include <string>
#include <unordered_map>
#include <vector>
#include <memory>
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/activation_info.h"
@ -30,21 +31,21 @@ namespace parallel {
class ExpInfo : public ActivationOther {
public:
ExpInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ExpCost>()) {}
~ExpInfo() override = default;
};
class LogInfo : public ActivationOther {
public:
LogInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogCost>()) {}
~LogInfo() override = default;
};
class CosInfo : public ActivationOther {
public:
CosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<CosCost>()) {}
~CosInfo() override = default;
};
@ -52,7 +53,7 @@ class ACosInfo : public ActivationOther {
public:
ACosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ACosCost>()) {}
~ACosInfo() override = default;
};
@ -60,14 +61,14 @@ class LogicalNotInfo : public ActivationOther {
public:
LogicalNotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogicalNotCost>()) {}
~LogicalNotInfo() override = default;
};
class AbsInfo : public ActivationOther {
public:
AbsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AbsCost>()) {}
~AbsInfo() override = default;
};
@ -75,7 +76,7 @@ class SignInfo : public ActivationOther {
public:
SignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SignCost>()) {}
~SignInfo() override = default;
};
@ -83,7 +84,7 @@ class FloorInfo : public ActivationOther {
public:
FloorInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<FloorCost>()) {}
~FloorInfo() override = default;
};
@ -91,7 +92,7 @@ class RoundInfo : public ActivationOther {
public:
RoundInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<RoundCost>()) {}
~RoundInfo() override = default;
};
@ -99,14 +100,14 @@ class ReciprocalInfo : public ActivationOther {
public:
ReciprocalInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReciprocalCost>()) {}
~ReciprocalInfo() override = default;
};
class InvInfo : public ActivationOther {
public:
InvInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<InvCost>()) {}
~InvInfo() override = default;
};
@ -114,21 +115,21 @@ class RsqrtInfo : public ActivationOther {
public:
RsqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<RsqrtCost>()) {}
~RsqrtInfo() override = default;
};
class TanInfo : public ActivationOther {
public:
TanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<TanCost>()) {}
~TanInfo() override = default;
};
class SinInfo : public ActivationOther {
public:
SinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SinCost>()) {}
~SinInfo() override = default;
};
@ -136,7 +137,7 @@ class SinhInfo : public ActivationOther {
public:
SinhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SinhCost>()) {}
~SinhInfo() override = default;
};
@ -144,7 +145,7 @@ class Log1pInfo : public ActivationOther {
public:
Log1pInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<Log1pCost>()) {}
~Log1pInfo() override = default;
};
@ -152,7 +153,7 @@ class Expm1Info : public ActivationOther {
public:
Expm1Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<Expm1Cost>()) {}
~Expm1Info() override = default;
};
@ -160,7 +161,7 @@ class CoshInfo : public ActivationOther {
public:
CoshInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<CoshCost>()) {}
~CoshInfo() override = default;
};
@ -168,7 +169,7 @@ class CeilInfo : public ActivationOther {
public:
CeilInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<CeilCost>()) {}
~CeilInfo() override = default;
};
@ -176,7 +177,7 @@ class AtanhInfo : public ActivationOther {
public:
AtanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AtanhCost>()) {}
~AtanhInfo() override = default;
};
@ -184,7 +185,7 @@ class AtanInfo : public ActivationOther {
public:
AtanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AtanCost>()) {}
~AtanInfo() override = default;
};
@ -192,7 +193,7 @@ class AsinInfo : public ActivationOther {
public:
AsinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AsinCost>()) {}
~AsinInfo() override = default;
};
@ -200,7 +201,7 @@ class AsinhInfo : public ActivationOther {
public:
AsinhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AsinhCost>()) {}
~AsinhInfo() override = default;
};
@ -208,14 +209,14 @@ class AcoshInfo : public ActivationOther {
public:
AcoshInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AcoshCost>()) {}
~AcoshInfo() override = default;
};
class ErfInfo : public ActivationOther {
public:
ErfInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ErfCost>()) {}
~ErfInfo() override = default;
};
@ -223,7 +224,7 @@ class ErfcInfo : public ActivationOther {
public:
ErfcInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ErfcCost>()) {}
~ErfcInfo() override = default;
};
@ -231,7 +232,7 @@ class ZerosLikeInfo : public ActivationOther {
public:
ZerosLikeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ZerosLikeCost>()) {}
~ZerosLikeInfo() override = default;
};
@ -239,7 +240,7 @@ class OnesLikeInfo : public ActivationOther {
public:
OnesLikeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<OnesLikeCost>()) {}
~OnesLikeInfo() override = default;
};
@ -247,7 +248,7 @@ class BesselI0eInfo : public ActivationOther {
public:
BesselI0eInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<BesselI0eCost>()) {}
~BesselI0eInfo() override = default;
};
@ -255,7 +256,7 @@ class BesselI1eInfo : public ActivationOther {
public:
BesselI1eInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<BesselI1eCost>()) {}
~BesselI1eInfo() override = default;
};
} // namespace parallel

View File

@ -32,7 +32,7 @@ class GetNextInfo : public OperatorInfo {
public:
GetNextInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>(false)) {}
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>()) {}
~GetNextInfo() override = default;
Status Init(const StrategyPtr &strategy) override;

View File

@ -33,7 +33,7 @@ class L2NormalizeInfo : public Activation {
public:
L2NormalizeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: Activation(name, inputs_shape, outputs_shape, attrs) {}
: Activation(name, inputs_shape, outputs_shape, attrs, std::make_shared<L2NormalizeCost>()) {}
~L2NormalizeInfo() override = default;
Status GenerateStrategies(int64_t stage_id) override;

View File

@ -40,7 +40,7 @@ class LayerNormInfo : public OperatorInfo {
public:
LayerNormInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<LayerNormCost>(true)),
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<LayerNormCost>()),
begin_norm_axis_(0) {}
~LayerNormInfo() override = default;

View File

@ -36,8 +36,7 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
public:
SoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs,
std::make_shared<SoftmaxCrossEntropyWithLogitsCost>(false)) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCrossEntropyWithLogitsCost>()) {}
~SoftmaxCrossEntropyWithLogitsInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;

View File

@ -34,7 +34,7 @@ class MatMulBase : public OperatorInfo {
public:
MatMulBase(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>(true)) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>()) {}
~MatMulBase() override = default;
Status Init(const StrategyPtr &strategy) override;

View File

@ -33,7 +33,7 @@ class OneHotInfo : public OperatorInfo {
public:
OneHotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>(false)) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>()) {}
~OneHotInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;

View File

@ -1204,6 +1204,20 @@ int64_t OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() {
} else {
is_output_parameter_involve_ = 0;
}
// Set 'is_parameter_involve_' and 'is_output_parameter_involve_' into OperatorCost; they are used when
// calculating 'inputs_in_memory' and 'output_in_memory', respectively.
operator_cost()->set_is_parameter_involve(is_parameter_involve_);
operator_cost()->set_output_parameter_involve(is_output_parameter_involve_);
// Calculating 'output_in_memory'
operator_cost()->CalculateOutputInMemory();
// Calculating 'inputs_in_memory'
std::map<size_t, bool> input_in_memory;
for (auto &p_edge : prev_edges) {
auto input_index = p_edge->next_op_input_index();
auto is_in_mem = p_edge->prev_operator()->operator_cost()->is_output_in_memory();
input_in_memory.emplace(std::make_pair(input_index, is_in_mem));
}
operator_cost()->CalculateInputsInMemory(input_in_memory);
return is_output_parameter_involve_;
}
@ -1220,14 +1234,10 @@ Status OperatorInfo::set_is_parameter(const std::vector<bool> &is_parameter) {
}
Status OperatorInfo::CalculateMemoryCost() {
// First, set the 'is_parameter_involve_' and 'is_output_parameter_involve_' into OperatorCost, which are necessary to
// calculate memory cost.
if (is_parameter_involve_.size() != is_parameter_.size()) {
MS_LOG(ERROR) << "'is_parameter_' does not have the same number of input size of 'is_parameter_involve_'.";
return FAILED;
}
operator_cost()->set_is_parameter_involve(is_parameter_involve_);
operator_cost()->set_output_parameter_involve(is_output_parameter_involve_);
// Set the memory cost in the 'strategy_cost_'
for (auto &swc : strategy_cost_) {
auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr);
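Taken together with the hunk above, the per-operator flow is: propagate parameter involvement into the cost object, fix the output flag, collect which producers already keep their outputs, fix the input flags, and only then query GetMemoryCost for each strategy. A condensed sketch of that sequence (local names such as cost, inputs_info and outputs_info are illustrative, not verbatim from the commit):

// Condensed, illustrative sequence; not a verbatim excerpt of the commit.
auto cost = operator_cost();
cost->set_is_parameter_involve(is_parameter_involve_);
cost->set_output_parameter_involve(is_output_parameter_involve_);
cost->CalculateOutputInMemory();

std::map<size_t, bool> input_in_memory;
for (auto &p_edge : prev_edges) {
  input_in_memory.emplace(p_edge->next_op_input_index(),
                          p_edge->prev_operator()->operator_cost()->is_output_in_memory());
}
cost->CalculateInputsInMemory(input_in_memory);

// Later, per strategy-with-cost entry:
double mem_cost = cost->GetMemoryCost(inputs_info, outputs_info);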

View File

@ -33,7 +33,7 @@ class PackInfo : public OperatorInfo {
public:
PackInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<PackCost>(false)) {}
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<PackCost>()) {}
~PackInfo() override = default;
Status Init(const StrategyPtr &strategy) override;

View File

@ -35,7 +35,7 @@ class PReLUInfo : public OperatorInfo {
public:
PReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>(true)) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>()) {}
~PReLUInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;

View File

@ -39,7 +39,7 @@ class RangeInfo : public OperatorInfo {
public:
RangeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(true)) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<RangeCost>()) {}
~RangeInfo() override = default;
Status Init(const StrategyPtr &strategy) override;

View File

@ -33,8 +33,8 @@ namespace parallel {
class ReduceMethod : public OperatorInfo {
public:
ReduceMethod(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMethodCost>(true)) {}
const PrimitiveAttrs &attrs, OperatorCostPtr cost)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost) {}
~ReduceMethod() override = default;
Status Init(const StrategyPtr &strategy) override;
@ -62,7 +62,7 @@ class ReduceMaxInfo : public ReduceMethod {
public:
ReduceMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
: ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMaxCost>()) {
reduce_method_ = REDUCE_OP_MAX;
}
@ -73,7 +73,7 @@ class ArgMaxWithValueInfo : public ReduceMethod {
public:
ArgMaxWithValueInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
: ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArgMaxWithValueCost>()) {
reduce_method_ = REDUCE_OP_MAX;
}
@ -105,9 +105,7 @@ class ReduceMeanInfo : public ReduceMethod {
public:
ReduceMeanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
set_cost(std::make_shared<ReduceMeanCost>());
}
: ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMeanCost>()) {}
~ReduceMeanInfo() override = default;
@ -119,7 +117,7 @@ class ReduceSumInfo : public ReduceMethod {
public:
ReduceSumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
: ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceSumCost>()) {
reduce_method_ = REDUCE_OP_SUM;
}
@ -130,7 +128,7 @@ class ReduceMinInfo : public ReduceMethod {
public:
ReduceMinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
: ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMinCost>()) {
reduce_method_ = REDUCE_OP_MIN;
}

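The ReduceMethod hunks above thread the cost object through the base-class constructor so each derived reduce op supplies its own cost type. A hedged sketch of that pattern, using simplified stand-ins (ReduceProdInfoSketch and ReduceProdCost are hypothetical names for illustration, not part of this commit):

    #include <memory>
    #include <string>
    #include <utility>

    struct OperatorCost {
      virtual ~OperatorCost() = default;
    };
    using OperatorCostPtr = std::shared_ptr<OperatorCost>;
    struct ReduceProdCost : public OperatorCost {};  // hypothetical dedicated cost type

    // Simplified base mirroring ReduceMethod: the cost pointer comes from the
    // derived class instead of being hard-coded in the base.
    class ReduceMethodSketch {
     public:
      ReduceMethodSketch(const std::string &name, OperatorCostPtr cost)
          : name_(name), cost_(std::move(cost)) {}
      virtual ~ReduceMethodSketch() = default;

     protected:
      std::string name_;
      OperatorCostPtr cost_;
    };

    // Hypothetical derived operator following the same pattern as ReduceMaxInfo:
    // it passes its own dedicated cost object up to the base constructor.
    class ReduceProdInfoSketch : public ReduceMethodSketch {
     public:
      explicit ReduceProdInfoSketch(const std::string &name)
          : ReduceMethodSketch(name, std::make_shared<ReduceProdCost>()) {}
    };

    int main() {
      ReduceProdInfoSketch op("ReduceProd-0");
      (void)op;  // silence unused-variable warnings in the sketch
      return 0;
    }
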
@ -37,7 +37,7 @@ class ReLUV2Info : public OperatorInfo {
public:
ReLUV2Info(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(false)) {}
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLUV2Cost>()) {}
~ReLUV2Info() override = default;
Status Init(const StrategyPtr &strategy) override;

@ -36,7 +36,7 @@ class ReshapeInfo : public OperatorInfo {
public:
ReshapeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>(false)),
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>()),
dev_num_(0),
pre_operator_index_(0),
next_operator_index_(0),

@ -34,7 +34,7 @@ class SliceInfo : public OperatorInfo {
public:
SliceInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<SliceCost>(false)),
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<SliceCost>()),
slice_axis_(-1) {}
~SliceInfo() override = default;

@ -31,7 +31,7 @@ class SplitInfo : public OperatorInfo {
public:
SplitInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ConcatCost>(false)) {}
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<SplitCost>()) {}
~SplitInfo() override = default;
Status Init(const StrategyPtr &strategy) override;

@ -273,6 +273,8 @@ Status StridedSliceInfo::GenerateStrategies(int64_t stage_id) {
PrintStrategy(sp);
}
}
MS_LOG(INFO) << name() << ", finishing GenerateStrategies().";
return SUCCESS;
}

@ -34,7 +34,7 @@ class StridedSliceInfo : public OperatorInfo {
public:
StridedSliceInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<StridedSliceCost>(false)) {}
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<StridedSliceCost>()) {}
~StridedSliceInfo() override = default;
Status Init(const StrategyPtr &strategy) override;

@ -41,7 +41,7 @@ class TensorDotInfo : public OperatorInfo {
public:
TensorDotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>(true)) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorDotCost>()) {}
~TensorDotInfo() override = default;
Status Init(const StrategyPtr &strategy) override;

@ -34,7 +34,7 @@ class TileInfo : public OperatorInfo {
public:
TileInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TileCost>(false)) {}
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TileCost>()) {}
~TileInfo() override = default;
Status Init(const StrategyPtr &strategy) override;

@ -34,7 +34,7 @@ class TmpIdentityInfo : public OperatorInfo {
public:
TmpIdentityInfo(const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs,
const std::string &name = IDENTITY_INFO)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>(false)) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>()) {}
~TmpIdentityInfo() override = default;
Status Init(const StrategyPtr &strategy) override;

@ -35,7 +35,7 @@ class TransposeInfo : public OperatorInfo {
public:
TransposeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>(false)) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>()) {}
~TransposeInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;

@ -32,7 +32,7 @@ class UniqueInfo : public OperatorInfo {
public:
UniqueInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>(false)) {}
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<UniqueCost>()) {}
~UniqueInfo() override = default;
Status Init(const StrategyPtr &strategy) override;

@ -82,7 +82,7 @@ class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo {
public:
UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMinCost>()) {}
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMaxCost>()) {}
~UnsortedSegmentMaxInfo() override = default;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;

@ -32,7 +32,7 @@ class VirtualDatasetInfo : public OperatorInfo {
public:
VirtualDatasetInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>(false)) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>()) {}
~VirtualDatasetInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;

@ -85,11 +85,11 @@ class TestActivationCost : public UT::Common {
TestActivationCost() {}
void SetUp();
void TearDown();
ActivationCost ac_cost_;
ActivationInfoCost ac_cost_;
};
void TestActivationCost::SetUp() {
ac_cost_ = ActivationCost();
ac_cost_ = ActivationInfoCost();
RankList dev_list;
for (int32_t i = 0; i < 1050; i++) {