forked from mindspore-Ecosystem/mindspore
!10683 [Auto parallel] Change memory cost calculation in auto-parallel
From: @xiaoda_zh Reviewed-by: @kisnwang,@stsuteng Signed-off-by: @stsuteng
commit fae4246811
@ -27,6 +27,7 @@ void OperatorCost::set_is_parameter(const std::vector<bool> &is_parameter) { is_
|
|||
|
||||
void OperatorCost::set_is_parameter_involve(const std::vector<bool> &is_parameter_inv) {
|
||||
is_parameter_involve_ = is_parameter_inv;
|
||||
is_inputs_should_in_memory_ = std::vector<bool>(is_parameter_involve_.size(), false);
|
||||
}
|
||||
|
||||
void OperatorCost::set_output_parameter_involve(int64_t output_para) { output_parameter_involve_ = output_para; }
|
||||
|
@ -41,27 +42,28 @@ void OperatorCost::set_output_critical(int64_t critical) { is_outputs_critical_
|
|||
|
||||
double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs,
|
||||
const std::vector<TensorInfo> &outputs) const {
|
||||
return GetInputMemoryCost(inputs, outputs) + GetOutputMemoryCost(inputs, outputs);
|
||||
}
|
||||
|
||||
double OperatorCost::GetInputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &) const {
|
||||
double result = 0.0;
|
||||
if (output_parameter_involve_ == 1) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (is_inputs_should_in_memory_[i]) {
|
||||
result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
double OperatorCost::GetOutputMemoryCost(const std::vector<TensorInfo> &inputs,
|
||||
const std::vector<TensorInfo> &outputs) const {
|
||||
double result = 0.0;
|
||||
if (is_output_should_in_memory_) {
|
||||
// When this operator has multiple outputs, they all contribute to the memory.
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
|
||||
}
|
||||
bool is_any_para_inv =
|
||||
std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; });
|
||||
if (is_any_para_inv) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (is_parameter_[i]) {
|
||||
result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
|
||||
} else if (inputs_related_ && (!is_parameter_involve_[i])) {
|
||||
// When the inputs of this operator are related, and they are not parameter-involved, then they are included
|
||||
// in the memory cost.
|
||||
result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
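For orientation, a minimal driver-side sketch (not part of this commit; the names op_cost, is_parameter, is_parameter_involve, prev_outputs_in_mem, inputs and outputs are illustrative only) showing how the new memory hooks and GetMemoryCost are assumed to be combined:
// Hedged usage sketch of the new memory-cost interface.
op_cost->set_is_parameter(is_parameter);                  // flag which inputs are Parameters
op_cost->set_is_parameter_involve(is_parameter_involve);  // flag which inputs trace back to Parameters
op_cost->CalculateOutputInMemory();                       // fills is_output_should_in_memory_
op_cost->CalculateInputsInMemory(prev_outputs_in_mem);    // fills is_inputs_should_in_memory_
double peak_memory = op_cost->GetMemoryCost(inputs, outputs);  // input part + output part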
|
||||
|
||||
|
@ -166,16 +168,43 @@ double MatMulCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inp
|
|||
return result;
|
||||
}
|
||||
|
||||
// Not taking account of output
|
||||
void MatMulCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
|
||||
|
||||
// Taking account of input
|
||||
void MatMulCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
if (is_parameter_[0]) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
} else if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_parameter_[1]) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
} else if (is_parameter_involve_[1]) {
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return the per device communication cost in the forward phase.
|
||||
double ActivationCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
|
||||
int64_t) const {
|
||||
double CastCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
|
||||
// ReLU is the element-wise operator, thus it does not need communication in the forward phase
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Return the per device communication cost in the backward phase.
|
||||
double ActivationCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
|
||||
int64_t stage_id) const {
|
||||
double CastCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
|
||||
int64_t stage_id) const {
|
||||
double result = 0.0;
|
||||
if (is_parameter_[0]) {
|
||||
TensorInfo input1 = inputs[0];
|
||||
|
@ -196,8 +225,8 @@ double ActivationCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs
|
|||
|
||||
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
|
||||
// this operator uses
|
||||
double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
|
||||
int64_t) const {
|
||||
double CastCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
|
||||
int64_t) const {
|
||||
TensorInfo input0 = inputs[0];
|
||||
Shape input0_slice_shape = input0.slice_shape();
|
||||
return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
|
||||
|
@ -205,11 +234,33 @@ double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo> &
|
|||
|
||||
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
|
||||
// this operator uses
|
||||
double ActivationCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
|
||||
int64_t) const {
|
||||
double CastCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
|
||||
int64_t) const {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Not taking account of output
|
||||
void CastCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
|
||||
|
||||
// Not taking account of input
|
||||
void CastCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
is_inputs_should_in_memory_[0] = is_parameter_[0];
|
||||
}
|
||||
|
||||
// Taking account of output
|
||||
void SqrtCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
|
||||
|
||||
// Taking account of input
|
||||
void GeLUCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
if (is_parameter_[0]) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
} else if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return the per device communication cost in the forward phase.
|
||||
double SoftmaxCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
|
||||
int64_t) const {
|
||||
|
@ -259,6 +310,81 @@ double SoftmaxCost::GetBackwardComputationCost(const std::vector<mindspore::para
|
|||
return 0.0;
|
||||
}
|
||||
|
||||
// Taking account of output
|
||||
void SoftmaxCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
|
||||
|
||||
// Not taking account of input
|
||||
void SoftmaxCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
is_inputs_should_in_memory_[0] = is_parameter_[0];
|
||||
}
|
||||
|
||||
// Not taking account of output
|
||||
void PackCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
|
||||
|
||||
// Not taking account of input
|
||||
void PackCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
is_inputs_should_in_memory_[0] = is_parameter_[0];
|
||||
}
|
||||
|
||||
// Not taking account of output
|
||||
void TileCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
|
||||
|
||||
// Taking account of input
|
||||
void TileCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of 'y'
|
||||
if (is_parameter_[0]) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
} else if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_inputs_should_in_memory_[1]) {
|
||||
is_inputs_should_in_memory_[1] = is_parameter_[1];
|
||||
}
|
||||
}
|
||||
|
||||
// Not taking account of output
|
||||
void BroadcastToCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
|
||||
|
||||
void BroadcastToCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
is_inputs_should_in_memory_[0] = is_parameter_[0];
|
||||
}
|
||||
|
||||
// Taking account of input
|
||||
void ReLU6Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
if (is_parameter_[0]) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
} else if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Taking account of input
|
||||
void TransposeCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of 'y'
|
||||
if (is_parameter_[0]) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
} else if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_inputs_should_in_memory_[1]) {
|
||||
is_inputs_should_in_memory_[1] = is_parameter_[1];
|
||||
}
|
||||
}
|
||||
|
||||
// return the per device communication cost in the forward phase.
|
||||
double TmpIdentityCost::GetForwardCommCost(const std::vector<mindspore::parallel::TensorInfo> &,
|
||||
const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
|
||||
|
@ -288,9 +414,12 @@ double TmpIdentityCost::GetBackwardComputationCost(const std::vector<mindspore::
|
|||
return 0.0;
|
||||
}
|
||||
|
||||
// Return the per device PEAK memory cost contributed by this operator in a training iteration.
|
||||
double TmpIdentityCost::GetMemoryCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &) const {
|
||||
return 0.0;
|
||||
// Not taking account of output
|
||||
void TmpIdentityCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
|
||||
|
||||
// Not taking account of input
|
||||
void TmpIdentityCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
is_inputs_should_in_memory_[0] = is_parameter_[0];
|
||||
}
|
||||
|
||||
double BatchParallelCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &inputs,
|
||||
|
@ -334,6 +463,42 @@ double BatchParallelCost::GetBackwardCommCost(const std::vector<TensorInfo> &inp
|
|||
|
||||
return result;
|
||||
}
|
||||
|
||||
void BatchParallelCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
|
||||
|
||||
void BatchParallelCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
if (is_parameter_[0]) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
} else if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_parameter_[1]) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
} else if (is_parameter_involve_[1]) {
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SparseSoftmaxCrossEntropyWithLogitsCost::CalculateOutputInMemory() {
|
||||
is_output_should_in_memory_ = is_parameter_involve_[0];
|
||||
}
|
||||
|
||||
void SparseSoftmaxCrossEntropyWithLogitsCost::CalculateInputsInMemory(
|
||||
const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
is_inputs_should_in_memory_[0] = is_parameter_[0];
|
||||
is_inputs_should_in_memory_[1] = is_parameter_[1];
|
||||
}
|
||||
// return the per device communication cost in the forward phase.
|
||||
double PReLUCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
|
||||
// prelu does not need communication in the forward phase
|
||||
|
@ -401,6 +566,21 @@ double PReLUCost::GetBackwardComputationCost(const std::vector<mindspore::parall
|
|||
return result;
|
||||
}
|
||||
|
||||
void PReLUCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
|
||||
|
||||
void PReLUCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of both 'x' and 'y';
|
||||
// when calculating 'dy', taking account of both 'x' and 'y'
|
||||
if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// return the per device communication cost in the forward phase.
|
||||
double OneHotCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
|
||||
// onehot does not need communication in the forward phase
|
||||
|
@ -430,6 +610,17 @@ double OneHotCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, c
|
|||
return 0.0;
|
||||
}
|
||||
|
||||
// Not taking account of output
|
||||
void OneHotCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
|
||||
|
||||
// Not taking account of input
|
||||
void OneHotCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
is_inputs_should_in_memory_[0] = is_parameter_[0];
|
||||
is_inputs_should_in_memory_[1] = is_parameter_[1];
|
||||
is_inputs_should_in_memory_[2] = is_parameter_[2];
|
||||
is_inputs_should_in_memory_[3] = is_parameter_[3];
|
||||
}
|
||||
|
||||
// return the per device communication cost in the forward phase.
|
||||
double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector<TensorInfo> &,
|
||||
const std::vector<TensorInfo> &, int64_t) const {
|
||||
|
@ -463,6 +654,16 @@ double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::
|
|||
return 0.0;
|
||||
}
|
||||
|
||||
// Taking account of output
|
||||
void SoftmaxCrossEntropyWithLogitsCost::CalculateOutputInMemory() {
|
||||
is_output_should_in_memory_ = is_parameter_involve_[0];
|
||||
}
|
||||
|
||||
void SoftmaxCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
is_inputs_should_in_memory_[0] = is_parameter_[0];
|
||||
is_inputs_should_in_memory_[1] = is_parameter_[1];
|
||||
}
|
||||
|
||||
// return the per device communication cost in the forward phase.
|
||||
double ReshapeCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
|
||||
int64_t stage_id) const {
|
||||
|
@ -524,50 +725,22 @@ double ReshapeCost::GetBackwardComputationCost(const std::vector<mindspore::para
|
|||
return 0.0;
|
||||
}
|
||||
|
||||
double ArithmeticCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
|
||||
int64_t) const {
|
||||
void ReshapeCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
|
||||
|
||||
void ReshapeCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
is_inputs_should_in_memory_[0] = is_parameter_[0];
|
||||
is_inputs_should_in_memory_[1] = is_parameter_[1];
|
||||
}
|
||||
|
||||
double SubCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
|
||||
int64_t) const {
|
||||
double result;
|
||||
result = ListProduct(inputs[0].slice_shape()) * static_cast<double>(inputs_type_lengths_[0]) +
|
||||
ListProduct(inputs[1].slice_shape()) * static_cast<double>(inputs_type_lengths_[1]);
|
||||
return result;
|
||||
}
|
||||
|
||||
double ArithmeticCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
|
||||
const std::vector<TensorInfo> &, int64_t stage_id) const {
|
||||
double result = 0.0;
|
||||
CheckGlobalDeviceManager();
|
||||
MS_EXCEPTION_IF_NULL(g_device_manager);
|
||||
auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
|
||||
|
||||
if (is_parameter_[0]) {
|
||||
TensorInfo input_a_tensor_info = inputs[0];
|
||||
Shape input_a_shape = input_a_tensor_info.shape();
|
||||
Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
|
||||
int64_t used_device_num = 1;
|
||||
for (size_t i = 0; i < input_a_shape.size(); ++i) {
|
||||
used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
|
||||
}
|
||||
|
||||
if (total_device_num != LongToSize(used_device_num))
|
||||
result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
|
||||
}
|
||||
|
||||
if (is_parameter_[1]) {
|
||||
TensorInfo input_b_tensor_info = inputs[1];
|
||||
Shape input_b_shape = input_b_tensor_info.shape();
|
||||
Shape input_b_slice_shape = input_b_tensor_info.slice_shape();
|
||||
int64_t used_device_num = 1;
|
||||
for (size_t i = 0; i < input_b_shape.size(); ++i) {
|
||||
used_device_num *= input_b_shape[i] / input_b_slice_shape[i];
|
||||
}
|
||||
|
||||
if (total_device_num != LongToSize(used_device_num))
|
||||
result += ListProduct(input_b_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
double ArithmeticCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
|
||||
double SubCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
|
||||
int64_t stage_id) const {
|
||||
double result = 0.0;
|
||||
CheckGlobalDeviceManager();
|
||||
|
@ -587,6 +760,41 @@ double ArithmeticCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs
|
|||
result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
|
||||
}
|
||||
|
||||
if (is_parameter_[1]) {
|
||||
TensorInfo input_b_tensor_info = inputs[1];
|
||||
Shape input_b_shape = input_b_tensor_info.shape();
|
||||
Shape input_b_slice_shape = input_b_tensor_info.slice_shape();
|
||||
int64_t used_device_num = 1;
|
||||
for (size_t i = 0; i < input_b_shape.size(); ++i) {
|
||||
used_device_num *= input_b_shape[i] / input_b_slice_shape[i];
|
||||
}
|
||||
|
||||
if (total_device_num != LongToSize(used_device_num))
|
||||
result += ListProduct(input_b_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
double SubCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
|
||||
int64_t stage_id) const {
|
||||
double result = 0.0;
|
||||
CheckGlobalDeviceManager();
|
||||
MS_EXCEPTION_IF_NULL(g_device_manager);
|
||||
auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
|
||||
|
||||
if (is_parameter_[0]) {
|
||||
TensorInfo input_a_tensor_info = inputs[0];
|
||||
Shape input_a_shape = input_a_tensor_info.shape();
|
||||
Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
|
||||
int64_t used_device_num = 1;
|
||||
for (size_t i = 0; i < input_a_shape.size(); ++i) {
|
||||
used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
|
||||
}
|
||||
|
||||
if (total_device_num != LongToSize(used_device_num))
|
||||
result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
|
||||
}
|
||||
|
||||
if (is_parameter_[1]) {
|
||||
TensorInfo input_b_tensor_info = inputs[1];
|
||||
Shape input_b_shape = input_b_tensor_info.shape();
|
||||
|
@ -603,6 +811,273 @@ double ArithmeticCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs
|
|||
return result;
|
||||
}
|
||||
|
||||
// Not taking account of output
|
||||
void SubCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
|
||||
|
||||
// Not taking account of input
|
||||
void SubCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
is_inputs_should_in_memory_[0] = is_parameter_[0];
|
||||
is_inputs_should_in_memory_[1] = is_parameter_[1];
|
||||
}
|
||||
|
||||
// Taking account of input
|
||||
void MulCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
if (is_parameter_[0]) {
|
||||
// 'x' is a parameter, so it should be in memory.
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
// In this case, if 'y' is not calculated by the previous operator, then 'y' should be included here.
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
} else if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_parameter_[1]) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
} else if (is_parameter_involve_[1]) {
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Taking account of output
|
||||
void DivCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }
|
||||
|
||||
// Taking account of input
|
||||
void DivCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of 'y'
|
||||
if (is_parameter_[0]) {
|
||||
// 'x' is a parameter, so it should be in memory.
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
// In this case, if 'y' is not calculated by the previous operator, then 'y' should be included here.
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
} else if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
|
||||
// When calculating 'dy', taking account of 'y'
|
||||
if (is_parameter_involve_[1]) {
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Taking account of input
|
||||
void ModCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', not taking account of 'x' and 'y'
|
||||
is_inputs_should_in_memory_[0] = is_parameter_[0];
|
||||
// When calculating 'dy', taking account of 'x' and 'y'
|
||||
if (is_parameter_involve_[1]) {
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PowCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }
|
||||
|
||||
void PowCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of both 'x' and 'power'
|
||||
if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
// When calculating 'dpower', taking account of 'x'
|
||||
if (is_parameter_[1]) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
} else if (is_parameter_involve_[1]) {
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AssignCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of 'x'
|
||||
if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
}
|
||||
// When calculating 'dy', not taking account of 'x' and 'y'
|
||||
is_inputs_should_in_memory_[1] = is_parameter_[1];
|
||||
}
|
||||
|
||||
void SigmoidCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of both 'x' and 'y'
|
||||
if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
// When calculating 'dy', not taking account of 'x' and 'y'
|
||||
if (!is_inputs_should_in_memory_[1]) {
|
||||
is_inputs_should_in_memory_[1] = is_parameter_[1];
|
||||
}
|
||||
}
|
||||
|
||||
void Atan2Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of both 'x' and 'y'; when calculating 'dy', taking account of both 'x' and
|
||||
// 'y'
|
||||
if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DivNoNanCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }
|
||||
|
||||
void DivNoNanCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of 'y'
|
||||
if (is_parameter_[0]) {
|
||||
// 'x' is a parameter, so it should be in memory.
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
// In this case, if 'y' is not calculated by the previous operator, then 'y' should be included here.
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
} else if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
|
||||
// When calculating 'dy', taking account of 'y'
|
||||
if (is_parameter_involve_[1]) {
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MaximumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of both 'x' and 'y';
|
||||
// when calculating 'dy', taking account of both 'x' and 'y'
|
||||
if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SliceCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of 'y' and 'z'
|
||||
if (is_parameter_[0]) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
|
||||
is_inputs_should_in_memory_[2] = true;
|
||||
}
|
||||
} else if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
|
||||
is_inputs_should_in_memory_[2] = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_inputs_should_in_memory_[1]) {
|
||||
is_inputs_should_in_memory_[1] = is_parameter_[1];
|
||||
}
|
||||
if (!is_inputs_should_in_memory_[2]) {
|
||||
is_inputs_should_in_memory_[2] = is_parameter_[2];
|
||||
}
|
||||
}
|
||||
|
||||
void StridedSliceCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of 'y', 'z' and 'w'
|
||||
if (is_parameter_[0]) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
|
||||
is_inputs_should_in_memory_[2] = true;
|
||||
}
|
||||
if ((prev_output_in_mem.find(3) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(3))) {
|
||||
is_inputs_should_in_memory_[3] = true;
|
||||
}
|
||||
} else if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
|
||||
is_inputs_should_in_memory_[2] = true;
|
||||
}
|
||||
if ((prev_output_in_mem.find(3) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(3))) {
|
||||
is_inputs_should_in_memory_[3] = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_inputs_should_in_memory_[1]) {
|
||||
is_inputs_should_in_memory_[1] = is_parameter_[1];
|
||||
}
|
||||
if (!is_inputs_should_in_memory_[2]) {
|
||||
is_inputs_should_in_memory_[2] = is_parameter_[2];
|
||||
}
|
||||
if (!is_inputs_should_in_memory_[3]) {
|
||||
is_inputs_should_in_memory_[3] = is_parameter_[3];
|
||||
}
|
||||
}
|
||||
|
||||
void DropOutDoMaskCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
|
||||
|
||||
void DropOutDoMaskCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of 'y'
|
||||
if (is_parameter_[0]) {
|
||||
// 'x' is a parameter, so it should be in memory.
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
// In this case, if 'y' is not calculated by the previous operator, then 'y' should be included here.
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
} else if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_inputs_should_in_memory_[1]) {
|
||||
is_inputs_should_in_memory_[1] = is_parameter_[1];
|
||||
}
|
||||
is_inputs_should_in_memory_[2] = is_parameter_[2];
|
||||
}
|
||||
|
||||
bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int64_t stage_id) {
|
||||
CheckGlobalDeviceManager();
|
||||
MS_EXCEPTION_IF_NULL(g_device_manager);
|
||||
|
@ -612,8 +1087,8 @@ bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int64_t stage_
|
|||
return (total_device_num == LongToSize(strategy0));
|
||||
}
|
||||
|
||||
double ReduceMethodCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
|
||||
const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
|
||||
double ReduceSumCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
|
||||
int64_t stage_id) const {
|
||||
double result = 0.0;
|
||||
TensorInfo input0 = inputs[0];
|
||||
TensorInfo output0 = outputs[0];
|
||||
|
@ -634,8 +1109,8 @@ double ReduceMethodCost::GetForwardCommCost(const std::vector<TensorInfo> &input
|
|||
return result;
|
||||
}
|
||||
|
||||
double ReduceMethodCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
|
||||
int64_t stage_id) const {
|
||||
double ReduceSumCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
|
||||
int64_t stage_id) const {
|
||||
double result = 0.0;
|
||||
if (is_parameter_[0]) {
|
||||
TensorInfo input_tensor_info = inputs[0];
|
||||
|
@ -657,8 +1132,8 @@ double ReduceMethodCost::GetBackwardCommCost(const std::vector<TensorInfo> &inpu
|
|||
return result;
|
||||
}
|
||||
|
||||
double ReduceMethodCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
|
||||
const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
|
||||
double ReduceSumCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
|
||||
const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
|
||||
double result = 0.0;
|
||||
TensorInfo input0 = inputs[0];
|
||||
TensorInfo output0 = outputs[0];
|
||||
|
@ -679,6 +1154,30 @@ double ReduceMethodCost::GetForwardComputationCost(const std::vector<TensorInfo>
|
|||
return result;
|
||||
}
|
||||
|
||||
// Not taking account of output
|
||||
void ReduceSumCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
|
||||
|
||||
void ReduceSumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of 'y'
|
||||
if (is_parameter_[0]) {
|
||||
// 'x' is a parameter, so it should be in memory.
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
// In this case, if 'y' is not calculated by the previous operator, then 'y' should be included here.
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
} else if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Not taking account of 'y'
|
||||
if (!is_inputs_should_in_memory_[1]) {
|
||||
is_inputs_should_in_memory_[1] = is_parameter_[1];
|
||||
}
|
||||
}
|
||||
|
||||
double ReduceMeanCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
|
||||
const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
|
||||
double result = 0.0;
|
||||
|
@ -701,6 +1200,42 @@ double ReduceMeanCost::GetForwardComputationCost(const std::vector<TensorInfo> &
|
|||
return result;
|
||||
}
|
||||
|
||||
void ReduceMinCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
|
||||
|
||||
void ReduceMinCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of 'y'
|
||||
if (is_parameter_[0]) {
|
||||
// 'x' is a parameter, so it should be in memory.
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
// In this case, if 'y' is not calculated by the previous operator, then 'y' should be included here.
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
} else if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Not taking account of 'y'
|
||||
if (!is_inputs_should_in_memory_[1]) {
|
||||
is_inputs_should_in_memory_[1] = is_parameter_[1];
|
||||
}
|
||||
}
|
||||
|
||||
void ArgMaxWithValueCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
|
||||
|
||||
void ArgMaxWithValueCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of 'x'
|
||||
if (is_parameter_[0]) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
} else if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
double DropOutCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
|
||||
int64_t) const {
|
||||
if (inputs.empty()) {
|
||||
|
@ -760,6 +1295,52 @@ double GatherV2Cost::GetBackwardComputationCost(const std::vector<TensorInfo> &,
|
|||
return 0.0;
|
||||
}
|
||||
|
||||
// Not taking account of output
|
||||
void GatherV2Cost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
|
||||
|
||||
void GatherV2Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of 'y' and 'z'
|
||||
if (is_parameter_[0]) {
|
||||
// 'x' is a parameter, so it should be in memory.
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
|
||||
is_inputs_should_in_memory_[2] = true;
|
||||
}
|
||||
} else if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
|
||||
is_inputs_should_in_memory_[2] = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_inputs_should_in_memory_[1]) {
|
||||
is_inputs_should_in_memory_[1] = is_parameter_[1];
|
||||
}
|
||||
if (!is_inputs_should_in_memory_[2]) {
|
||||
is_inputs_should_in_memory_[2] = is_parameter_[2];
|
||||
}
|
||||
}
|
||||
|
||||
void GetNextCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
|
||||
|
||||
void GetNextCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
if (is_inputs_should_in_memory_.size() == 0) {
|
||||
return;
|
||||
}
|
||||
is_inputs_should_in_memory_[0] = is_parameter_[0];
|
||||
}
|
||||
|
||||
void UniqueCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
|
||||
|
||||
void UniqueCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
is_inputs_should_in_memory_[0] = is_parameter_[0];
|
||||
}
|
||||
|
||||
double LayerNormCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
|
||||
int64_t stage_id) const {
|
||||
double result = 0.0;
|
||||
|
@ -808,6 +1389,24 @@ double LayerNormCost::GetForwardComputationCost(const std::vector<TensorInfo> &i
|
|||
return result;
|
||||
}
|
||||
|
||||
void LayerNormCost::CalculateOutputInMemory() {
|
||||
is_output_should_in_memory_ = is_parameter_involve_[0] || is_parameter_involve_[1] || is_parameter_involve_[2];
|
||||
}
|
||||
|
||||
void LayerNormCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of both 'x' and 'y'
|
||||
// When calculating 'dy', taking account of both 'x' and 'y'
|
||||
if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
is_inputs_should_in_memory_[2] = is_parameter_[2];
|
||||
}
|
||||
|
||||
double UniqueCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
|
||||
int64_t stage_id) const {
|
||||
return 0.0;
|
||||
|
@ -924,6 +1523,12 @@ double UniformCandidateSamplerCost::GetForwardComputationCost(const std::vector<
|
|||
return result;
|
||||
}
|
||||
|
||||
void UniformCandidateSamplerCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
|
||||
|
||||
void UniformCandidateSamplerCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
is_inputs_should_in_memory_[0] = is_parameter_[0];
|
||||
}
|
||||
|
||||
double GatherV2PCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
|
||||
const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
|
||||
double result = 0.0;
|
||||
|
@ -1019,6 +1624,29 @@ double UnsortedSegmentSumCost::GetForwardComputationCost(const std::vector<Tenso
|
|||
return result;
|
||||
}
|
||||
|
||||
// Not taking account of output
|
||||
void UnsortedSegmentSumCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
|
||||
|
||||
// Taking account of input
|
||||
void UnsortedSegmentSumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of 'y'
|
||||
if (is_parameter_[0]) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
} else if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_inputs_should_in_memory_[1]) {
|
||||
is_inputs_should_in_memory_[1] = is_parameter_[1];
|
||||
}
|
||||
is_inputs_should_in_memory_[2] = is_parameter_[2];
|
||||
}
|
||||
|
||||
double UnsortedSegmentMinCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
|
||||
const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
|
||||
TensorInfo input0 = inputs[0];
|
||||
|
@ -1078,5 +1706,40 @@ double UnsortedSegmentMinCost::GetForwardComputationCost(const std::vector<Tenso
|
|||
ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]); // ReduceMin
|
||||
return result;
|
||||
}
|
||||
|
||||
// Taking account of output
|
||||
void UnsortedSegmentMinCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
|
||||
|
||||
// Taking account of input
|
||||
void UnsortedSegmentMinCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
// When calculating 'dx', taking account of 'x', 'y' and 'z'
|
||||
if (is_parameter_involve_[0]) {
|
||||
if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
|
||||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
|
||||
is_inputs_should_in_memory_[1] = true;
|
||||
}
|
||||
if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
|
||||
is_inputs_should_in_memory_[2] = true;
|
||||
}
|
||||
}
|
||||
if (!is_inputs_should_in_memory_[1]) {
|
||||
is_inputs_should_in_memory_[1] = is_parameter_[1];
|
||||
}
|
||||
if (!is_inputs_should_in_memory_[2]) {
|
||||
is_inputs_should_in_memory_[2] = is_parameter_[2];
|
||||
}
|
||||
}
|
||||
|
||||
// Not taking account of output
|
||||
void VirtualDatasetCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
|
||||
|
||||
// Not taking account of input
|
||||
void VirtualDatasetCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
|
||||
for (size_t i = 0; i < is_inputs_should_in_memory_.size(); ++i) {
|
||||
is_inputs_should_in_memory_[i] = is_parameter_[i];
|
||||
}
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
@ -19,6 +19,7 @@
|
|||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "frontend/parallel/device_manager.h"
|
||||
#include "frontend/parallel/tensor_layout/tensor_info.h"
|
||||
|
||||
|
@ -47,16 +48,7 @@ double ListProduct(std::vector<T> vec) {
|
|||
// entries times the length of each entry's data type
|
||||
class OperatorCost {
|
||||
public:
|
||||
explicit OperatorCost(bool is_inputs_related) : inputs_related_(is_inputs_related) {
|
||||
// this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
|
||||
for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
|
||||
is_parameter_.push_back(false);
|
||||
is_parameter_involve_.push_back(false);
|
||||
inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
|
||||
outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
|
||||
}
|
||||
}
|
||||
OperatorCost() : inputs_related_(false) {
|
||||
OperatorCost() {
|
||||
// this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
|
||||
for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
|
||||
is_parameter_.push_back(false);
|
||||
|
@ -89,10 +81,17 @@ class OperatorCost {
|
|||
const std::vector<TensorInfo> &outputs, int64_t stage_id) const = 0;
|
||||
virtual double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
|
||||
const std::vector<TensorInfo> &outputs, int64_t stage_id) const = 0;
|
||||
virtual void CalculateOutputInMemory() = 0;
|
||||
virtual void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) = 0;
|
||||
bool is_output_in_memory() const { return is_output_should_in_memory_; }
|
||||
// per device PEAK memory cost in a training iteration
|
||||
// Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-invovled),
|
||||
// Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-involved),
|
||||
// plus necessary inputs.
|
||||
virtual double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
|
||||
// Contributing the input part for 'GetMemoryCost'
|
||||
double GetInputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
|
||||
// Contributing the output part for 'GetMemoryCost'
|
||||
double GetOutputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
|
||||
// per device memory cost in an inference phase
|
||||
double GetMemoryCostForInference(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &) const;
|
||||
|
||||
|
@ -101,25 +100,25 @@ class OperatorCost {
|
|||
// pre-operator that has parameters as input.
|
||||
std::vector<bool> is_parameter_involve_;
|
||||
int64_t output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved
|
||||
// Whether the inputs are related or not? For example, TensorAdd's two inputs are independent (not related), while
|
||||
// Mul's two inputs are dependent (related).
|
||||
bool inputs_related_;
|
||||
// for each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter
|
||||
// For each input in 'inputs_', there is a bool variable indicating whether the corresponding input is a parameter
|
||||
std::vector<bool> is_parameter_;
|
||||
// for each input and output, the followings record the number of bytes of each element
|
||||
// Whether the input should be kept in memory in the training phase. It depends on the operator and the operator's
// previous operators.
|
||||
std::vector<bool> is_inputs_should_in_memory_;
|
||||
// Whether the output should be kept in memory in the training phase. It depends on 'is_parameter_involve_' and the operator.
|
||||
bool is_output_should_in_memory_ = false;
|
||||
// For each input and output, the following vectors record the number of bytes of each element
|
||||
std::vector<size_t> inputs_type_lengths_;
|
||||
std::vector<size_t> outputs_type_lengths_;
|
||||
// Whether the output is critical, which means that this output is included in calculating peak memory cost
|
||||
// in the inference phase.
|
||||
int64_t is_outputs_critical_ = -1;
|
||||
};
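To make the contract of the two new pure-virtual hooks concrete, here is a hedged sketch (not part of this commit) of a derived cost class for a unary element-wise operator; the class name is hypothetical and the inherited pure-virtual communication/computation cost methods are omitted for brevity:
class HypotheticalUnaryCost : public OperatorCost {
 public:
  HypotheticalUnaryCost() : OperatorCost() {}
  ~HypotheticalUnaryCost() override = default;
  // Keep the output only when it lies on a parameter-involved gradient path.
  void CalculateOutputInMemory() override { is_output_should_in_memory_ = is_parameter_involve_[0]; }
  // Keep the input only when it is itself a Parameter.
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override {
    is_inputs_should_in_memory_[0] = is_parameter_[0];
  }
  // ... the remaining pure-virtual Get*Cost methods would still need overrides.
};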
|
||||
|
||||
using OperatorCostPtr = std::shared_ptr<OperatorCost>;
|
||||
|
||||
class MatMulCost : public OperatorCost {
|
||||
public:
|
||||
explicit MatMulCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
|
||||
MatMulCost() : OperatorCost(true) {}
|
||||
MatMulCost() : OperatorCost() {}
|
||||
~MatMulCost() override = default;
|
||||
|
||||
// per device communication cost
|
||||
|
@ -141,14 +140,15 @@ class MatMulCost : public OperatorCost {
|
|||
int64_t stage_id) const override;
|
||||
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
|
||||
int64_t stage_id) const override;
|
||||
void CalculateOutputInMemory() override;
|
||||
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
|
||||
};
|
||||
using MatMulCostPtr = std::shared_ptr<MatMulCost>;
|
||||
using TensorDotCost = MatMulCost;
|
||||
|
||||
class ActivationCost : public OperatorCost {
|
||||
class CastCost : public OperatorCost {
|
||||
public:
|
||||
explicit ActivationCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
|
||||
ActivationCost() : OperatorCost(false) {}
|
||||
~ActivationCost() override = default;
|
||||
CastCost() : OperatorCost() {}
|
||||
~CastCost() override = default;
|
||||
|
||||
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
|
||||
int64_t stage_id) const override {
|
||||
|
@ -166,21 +166,95 @@ class ActivationCost : public OperatorCost {
|
|||
int64_t stage_id) const override;
|
||||
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
|
||||
int64_t stage_id) const override;
|
||||
// Not taking account of output
|
||||
void CalculateOutputInMemory() override;
|
||||
// Not taking account of input
|
||||
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
|
||||
};
|
||||
using ActivationCostPtr = std::shared_ptr<ActivationCost>;
|
||||
using TransposeCost = ActivationCost;
|
||||
using TransposeCostPtr = std::shared_ptr<TransposeCost>;
|
||||
using StridedSliceCost = ActivationCost;
|
||||
using StridedSliceCostPtr = std::shared_ptr<StridedSliceCost>;
|
||||
using SliceCost = ActivationCost;
|
||||
using SliceCostPtr = std::shared_ptr<SliceCost>;
|
||||
using SplitCost = ActivationCost;
|
||||
using SplitCostPtr = std::shared_ptr<SplitCost>;
|
||||
using RepeatElementsCost = CastCost;
|
||||
using NegCost = CastCost;
|
||||
using ExpandDimsCost = CastCost;
|
||||
using SqueezeCost = CastCost;
|
||||
using ConcatCost = CastCost;
|
||||
using LogicalNotCost = CastCost;
|
||||
using SignCost = CastCost;
|
||||
using FloorCost = CastCost;
|
||||
using RoundCost = CastCost;
|
||||
using CeilCost = CastCost;
|
||||
using ZerosLikeCost = CastCost;
|
||||
using OnesLikeCost = CastCost;
|
||||
using RangeCost = CastCost;
|
||||
using SplitCost = CastCost;
|
||||
|
||||
class SqrtCost : public CastCost {
|
||||
public:
|
||||
SqrtCost() : CastCost() {}
|
||||
~SqrtCost() override = default;
|
||||
// Taking account of output, not taking account of input
|
||||
void CalculateOutputInMemory() override;
|
||||
};
|
||||
using TanhCost = SqrtCost;
|
||||
using EluCost = SqrtCost;
|
||||
using ReLUCost = SqrtCost;
|
||||
using SigmoidCost = SqrtCost;
|
||||
using ReciprocalCost =
|
||||
SqrtCost; // The derivative of 'Reciprocal' is different on 'Ascend' and 'GPU'. Here, 'Ascend' is chosen
|
||||
using InvCost = SqrtCost;
|
||||
using RsqrtCost = SqrtCost;
|
||||
using AsinhCost = SqrtCost;
|
||||
using AcoshCost = SqrtCost;
|
||||
using ReLUV2Cost = SqrtCost;
|
||||
|
||||
class ReLU6Cost : public CastCost {
|
||||
public:
|
||||
ReLU6Cost() : CastCost() {}
|
||||
~ReLU6Cost() override = default;
|
||||
// Taking account of input, not taking account of output
|
||||
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
|
||||
};
using SoftsignCost = ReLU6Cost;
using SoftplusCost = ReLU6Cost;
using SquareCost = ReLU6Cost;
using ExpCost = ReLU6Cost;
using LogCost = ReLU6Cost;
using CosCost = ReLU6Cost;
using ACosCost = ReLU6Cost;
using AbsCost = ReLU6Cost;
using TanCost = ReLU6Cost;
using SinCost = ReLU6Cost;
using SinhCost = ReLU6Cost;
using Log1pCost = ReLU6Cost;
using Expm1Cost = ReLU6Cost;
using CoshCost = ReLU6Cost;
using AtanhCost = ReLU6Cost;
using AtanCost = ReLU6Cost;
using AsinCost = ReLU6Cost;
using ErfCost = ReLU6Cost;
using ErfcCost = ReLU6Cost;
using ActivationInfoCost = ReLU6Cost;

class TransposeCost : public CastCost {
 public:
  TransposeCost() : CastCost() {}
  ~TransposeCost() override = default;
  // Taking account of input, not taking account of output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class GeLUCost : public SqrtCost {
 public:
  GeLUCost() : SqrtCost() {}
  ~GeLUCost() override = default;
  // Taking account of input and output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using BesselI0eCost = GeLUCost;
using BesselI1eCost = GeLUCost;
using L2NormalizeCost = GeLUCost;

class SoftmaxCost : public OperatorCost {
 public:
  explicit SoftmaxCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
  SoftmaxCost() : OperatorCost(false) {}
  SoftmaxCost() : OperatorCost() {}
  ~SoftmaxCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -199,21 +273,45 @@ class SoftmaxCost : public OperatorCost {
                     int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t) const override;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class TileCost : public SoftmaxCost {
 public:
  TileCost() : SoftmaxCost() {}
  ~TileCost() override = default;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class PackCost : public SoftmaxCost {
 public:
  PackCost() : SoftmaxCost() {}
  ~PackCost() override = default;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class BroadcastToCost : public SoftmaxCost {
 public:
  BroadcastToCost() : SoftmaxCost() {}
  ~BroadcastToCost() override = default;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>;
using TileCost = SoftmaxCost;
using TileCostPtr = std::shared_ptr<TileCost>;
using PackCost = TileCost;
using PackCostPtr = std::shared_ptr<PackCost>;
using ConcatCost = TileCost;
using ConcatCostPtr = std::shared_ptr<ConcatCost>;
using BroadcastToCost = SoftmaxCost;
using BroadcastToCostPtr = std::shared_ptr<BroadcastToCost>;

class TmpIdentityCost : public OperatorCost {
 public:
  explicit TmpIdentityCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
  TmpIdentityCost() : OperatorCost(false) {}
  TmpIdentityCost() : OperatorCost() {}
  ~TmpIdentityCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -232,15 +330,16 @@ class TmpIdentityCost : public OperatorCost {
                     int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  // per device PEAK memory cost in a training iteration
  double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const override;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using TmpIdentityCostPtr = std::shared_ptr<TmpIdentityCost>;
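How these hooks are meant to be driven is not shown in this header. A plausible consumer, sketched here with invented names (EstimatePeakMemory and the per-operator inputs/outputs vectors are hypothetical), would first fix each operator's output flag, then propagate the producers' decisions into CalculateInputsInMemory before summing GetMemoryCost:

// Hypothetical driver, for illustration only.
double EstimatePeakMemory(const std::vector<OperatorCostPtr> &ops,
                          const std::vector<std::vector<TensorInfo>> &inputs,
                          const std::vector<std::vector<TensorInfo>> &outputs) {
  double total = 0.0;
  for (size_t i = 0; i < ops.size(); ++i) {
    ops[i]->CalculateOutputInMemory();
    std::map<size_t, bool> prev_output_in_mem;  // input index -> producer already keeps that tensor
    // ... filled from the producers' output-in-memory decisions ...
    ops[i]->CalculateInputsInMemory(prev_output_in_mem);
    total += ops[i]->GetMemoryCost(inputs[i], outputs[i]);
  }
  return total;
}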

class BatchParallelCost : public OperatorCost {
 public:
  explicit BatchParallelCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
  BatchParallelCost() : OperatorCost(false) {}
  BatchParallelCost() : OperatorCost() {}
  ~BatchParallelCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -259,13 +358,25 @@ class BatchParallelCost : public OperatorCost {
                     int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class SparseSoftmaxCrossEntropyWithLogitsCost : public BatchParallelCost {
 public:
  SparseSoftmaxCrossEntropyWithLogitsCost() : BatchParallelCost() {}
  ~SparseSoftmaxCrossEntropyWithLogitsCost() override = default;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using BatchParallelCostPtr = std::shared_ptr<BatchParallelCost>;

class VirtualDatasetCost : public OperatorCost {
 public:
  explicit VirtualDatasetCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
  VirtualDatasetCost() : OperatorCost(false) {}
  VirtualDatasetCost() : OperatorCost() {}
  ~VirtualDatasetCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -290,17 +401,15 @@ class VirtualDatasetCost : public OperatorCost {
                     int64_t) const override {
    return 0.0;
  }
  // per device PEAK memory cost in a training iteration
  double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const override {
    return 0.0;
  }
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using VirtualDatasetCostPtr = std::shared_ptr<VirtualDatasetCost>;

class GeneratorBaseCost : public OperatorCost {
 public:
  explicit GeneratorBaseCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
  GeneratorBaseCost() : OperatorCost(false) {}
  GeneratorBaseCost() : OperatorCost() {}
  ~GeneratorBaseCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -332,8 +441,7 @@ using GeneratorBaseCostPtr = std::shared_ptr<GeneratorBaseCost>;

class PReLUCost : public OperatorCost {
 public:
  explicit PReLUCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
  PReLUCost() : OperatorCost(true) {}
  PReLUCost() : OperatorCost() {}
  ~PReLUCost() override = default;

  // per device communication cost
@ -355,13 +463,16 @@ class PReLUCost : public OperatorCost {
                     int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using PReLUCostPtr = std::shared_ptr<PReLUCost>;

class OneHotCost : public OperatorCost {
 public:
  explicit OneHotCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
  OneHotCost() : OperatorCost(true) {}
  OneHotCost() : OperatorCost() {}
  ~OneHotCost() override = default;

  // per device communication cost
@ -383,13 +494,16 @@ class OneHotCost : public OperatorCost {
                     int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using OneHotCostPtr = std::shared_ptr<OneHotCost>;

class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost {
 public:
  explicit SoftmaxCrossEntropyWithLogitsCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
  SoftmaxCrossEntropyWithLogitsCost() : OperatorCost(false) {}
  SoftmaxCrossEntropyWithLogitsCost() : OperatorCost() {}
  ~SoftmaxCrossEntropyWithLogitsCost() override = default;

  // per device communication cost
@ -411,13 +525,15 @@ class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost {
                     int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr<SoftmaxCrossEntropyWithLogitsCost>;

class ReshapeCost : public OperatorCost {
 public:
  explicit ReshapeCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
  ReshapeCost() : OperatorCost(true) {}
  ReshapeCost() : OperatorCost() {}

  ~ReshapeCost() override = default;

@ -444,14 +560,17 @@ class ReshapeCost : public OperatorCost {

  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using ReshapeCostPtr = std::shared_ptr<ReshapeCost>;

class ArithmeticCost : public OperatorCost {
class SubCost : public OperatorCost {
 public:
  explicit ArithmeticCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
  ArithmeticCost() : OperatorCost(false) {}
  ~ArithmeticCost() override = default;
  SubCost() : OperatorCost() {}
  ~SubCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
@ -470,16 +589,127 @@ class ArithmeticCost : public OperatorCost {
                     int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using ArithmeticCostPtr = std::shared_ptr<ArithmeticCost>;
using BiasAddCost = ArithmeticCost;
using BiasAddCostPtr = std::shared_ptr<BiasAddCost>;
using TensorAddCost = SubCost;
using FloorDivCost = SubCost;
using AssignSubCost = SubCost;
using AssignAddCost = SubCost;
using LogicalAndCost = SubCost;
using LogicalOrCost = SubCost;
using BiasAddCost = SubCost;
using EqualCost = SubCost;
using ApproximateEqualCost = SubCost;
using NotEqualCost = SubCost;
using GreaterCost = SubCost;
using GreaterEqualCost = SubCost;
using LessCost = SubCost;
using LessEqualCost = SubCost;

class ReduceMethodCost : public OperatorCost {
class MulCost : public SubCost {
 public:
  explicit ReduceMethodCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
  ReduceMethodCost() : OperatorCost(true) {}
  ~ReduceMethodCost() override = default;
  MulCost() : SubCost() {}
  ~MulCost() override = default;
  // Taking account of input, not taking account of output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class DivCost : public SubCost {
 public:
  DivCost() : SubCost() {}
  ~DivCost() override = default;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using ReadDivCost = DivCost;
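DivCost is one of the few SubCost variants that keeps both sides: with y = a / b, the gradients are dy/da = 1/b and dy/db = -y/b, so the backward pass can reuse the forward output y and still needs the divisor input b. A sketch of flag settings consistent with that reasoning, illustrative only and not the exact .cc bodies:

// Illustrative only: keep the output because dy/db reuses y = a / b.
void DivCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }

// Illustrative only: the divisor (input 1) appears in both partial derivatives.
void DivCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
    auto it = prev_output_in_mem.find(1);
    is_inputs_should_in_memory_[1] = (it == prev_output_in_mem.end()) || (!it->second);
  }
}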

class ModCost : public SubCost {
 public:
  ModCost() : SubCost() {}
  ~ModCost() override = default;
  // Taking account of input, not taking account of output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using FloorModCost = ModCost;

class PowCost : public SubCost {
 public:
  PowCost() : SubCost() {}
  ~PowCost() override = default;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class AssignCost : public SubCost {
 public:
  AssignCost() : SubCost() {}
  ~AssignCost() override = default;
  // Taking account of input, not taking account of output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class SigmoidCrossEntropyWithLogitsCost : public SubCost {
 public:
  SigmoidCrossEntropyWithLogitsCost() : SubCost() {}
  ~SigmoidCrossEntropyWithLogitsCost() override = default;
  // Taking account of input, not taking account of output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class Atan2Cost : public SubCost {
 public:
  Atan2Cost() : SubCost() {}
  ~Atan2Cost() override = default;
  // Taking account of input, not taking account of output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class DivNoNanCost : public SubCost {
 public:
  DivNoNanCost() : SubCost() {}
  ~DivNoNanCost() override = default;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class MaximumCost : public SubCost {
 public:
  MaximumCost() : SubCost() {}
  ~MaximumCost() override = default;
  // Taking account of input, not taking account of output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using MinimumCost = MaximumCost;

class SliceCost : public CastCost {
 public:
  SliceCost() : CastCost() {}
  ~SliceCost() override = default;
  // Not taking account of output, taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class StridedSliceCost : public CastCost {
 public:
  StridedSliceCost() : CastCost() {}
  ~StridedSliceCost() override = default;
  // Not taking account of output, taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class ReduceSumCost : public OperatorCost {
 public:
  ReduceSumCost() : OperatorCost() {}
  ~ReduceSumCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
@ -500,27 +730,50 @@ class ReduceMethodCost : public OperatorCost {
    return 0.0;
  }
  void set_cross_batch(bool cb) { cross_batch_ = cb; }
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;

 protected:
  bool cross_batch_ = false;
};
using ReduceMethodCostPtr = std::shared_ptr<ReduceMethodCost>;
using ReduceMethodCost = ReduceSumCost;

class ReduceMeanCost : public ReduceMethodCost {
class ReduceMeanCost : public ReduceSumCost {
 public:
  explicit ReduceMeanCost(bool is_inputs_related) : ReduceMethodCost(is_inputs_related) {}
  ReduceMeanCost() : ReduceMethodCost(true) {}
  ReduceMeanCost() : ReduceSumCost() {}
  ~ReduceMeanCost() override = default;

  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
};
using ReduceMeanCostPtr = std::shared_ptr<ReduceMeanCost>;

class ReduceMinCost : public ReduceSumCost {
 public:
  ReduceMinCost() : ReduceSumCost() {}
  ~ReduceMinCost() override = default;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using ReduceMaxCost = ReduceMinCost;

class ArgMaxWithValueCost : public ReduceSumCost {
 public:
  ArgMaxWithValueCost() : ReduceSumCost() {}
  ~ArgMaxWithValueCost() override = default;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using ArgMinWithValueCost = ArgMaxWithValueCost;

class GetNextCost : public OperatorCost {
 public:
  explicit GetNextCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
  GetNextCost() : OperatorCost(false) {}
  GetNextCost() : OperatorCost() {}
  ~GetNextCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -547,13 +800,17 @@ class GetNextCost : public OperatorCost {
                     int64_t) const override {
    return 0.0;
  }
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using GetNextCostPtr = std::shared_ptr<GetNextCost>;

class DropOutCost : public OperatorCost {
// For memory cost, taking account of output, not taking account of input
class DropOutCost : public SqrtCost {
 public:
  explicit DropOutCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
  DropOutCost() : OperatorCost(true) {}
  DropOutCost() : SqrtCost() {}
  ~DropOutCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -578,12 +835,19 @@ class DropOutCost : public OperatorCost {
  }
};

using DropOutCostPtr = std::shared_ptr<DropOutCost>;
class DropOutDoMaskCost : public DropOutCost {
 public:
  DropOutDoMaskCost() : DropOutCost() {}
  ~DropOutDoMaskCost() override = default;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class UnsortedSegmentSumCost : public OperatorCost {
 public:
  explicit UnsortedSegmentSumCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
  UnsortedSegmentSumCost() : OperatorCost(true) {}
  UnsortedSegmentSumCost() : OperatorCost() {}
  ~UnsortedSegmentSumCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -602,14 +866,15 @@ class UnsortedSegmentSumCost : public OperatorCost {
                     int64_t) const override {
    return 0.0;
  }
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

using UnsortedSegmentSumCostPtr = std::shared_ptr<UnsortedSegmentSumCost>;

class UnsortedSegmentMinCost : public OperatorCost {
 public:
  explicit UnsortedSegmentMinCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
  UnsortedSegmentMinCost() : OperatorCost(true) {}
  UnsortedSegmentMinCost() : OperatorCost() {}
  ~UnsortedSegmentMinCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -628,14 +893,16 @@ class UnsortedSegmentMinCost : public OperatorCost {
                     int64_t) const override {
    return 0.0;
  }
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

using UnsortedSegmentMinCostPtr = std::shared_ptr<UnsortedSegmentMinCost>;
using UnsortedSegmentMaxCost = UnsortedSegmentMinCost;

class LayerNormCost : public OperatorCost {
 public:
  explicit LayerNormCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
  LayerNormCost() : OperatorCost(true) {}
  LayerNormCost() : OperatorCost() {}
  ~LayerNormCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -656,14 +923,15 @@ class LayerNormCost : public OperatorCost {
                     int64_t) const override {
    return 0.0;
  }
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

using DropOutCostPtr = std::shared_ptr<DropOutCost>;

class UniqueCost : public OperatorCost {
 public:
  explicit UniqueCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
  UniqueCost() : OperatorCost(true) {}
  UniqueCost() : OperatorCost() {}
  ~UniqueCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -682,14 +950,15 @@ class UniqueCost : public OperatorCost {
                     int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t) const override;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

using UniqueCostPtr = std::shared_ptr<UniqueCost>;

class UniformCandidateSamplerCost : public OperatorCost {
 public:
  explicit UniformCandidateSamplerCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
  UniformCandidateSamplerCost() : OperatorCost(false) {}
  UniformCandidateSamplerCost() : OperatorCost() {}
  ~UniformCandidateSamplerCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -714,14 +983,15 @@ class UniformCandidateSamplerCost : public OperatorCost {
                     int64_t) const override {
    return 0.0;
  }
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

using UniformCandidateSamplerCostPtr = std::shared_ptr<UniformCandidateSamplerCost>;

class GatherV2Cost : public OperatorCost {
 public:
  explicit GatherV2Cost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
  GatherV2Cost() : OperatorCost(true) {}
  GatherV2Cost() : OperatorCost() {}
  ~GatherV2Cost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -740,14 +1010,15 @@ class GatherV2Cost : public OperatorCost {
                     int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t) const override;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

using GatherV2CostPtr = std::shared_ptr<GatherV2Cost>;

class GatherV2PCost : public OperatorCost {
class GatherV2PCost : public GatherV2Cost {
 public:
  explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related), axis_(0) {}
  GatherV2PCost() : OperatorCost(true), axis_(0) {}
  GatherV2PCost() : GatherV2Cost(), axis_(0) {}
  ~GatherV2PCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@ -773,8 +1044,6 @@ class GatherV2PCost : public OperatorCost {
  int64_t axis_;
  Shape strategy_;
};

using GatherV2PCostPtr = std::shared_ptr<GatherV2PCost>;
}  // namespace parallel
}  // namespace mindspore
#endif  // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_

@ -50,8 +50,8 @@ class ActivationBase : public OperatorInfo {
class Activation : public ActivationBase {
 public:
  Activation(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
      : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(false)) {}
             const PrimitiveAttrs &attrs, OperatorCostPtr cost)
      : ActivationBase(name, inputs_shape, outputs_shape, attrs, cost) {}
  ~Activation() override = default;
  Status GenerateStrategies(int64_t stage_id) override;
  Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
@ -64,7 +64,7 @@ class ActivationInfo : public Activation {
 public:
  ActivationInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
      : Activation(name, inputs_shape, outputs_shape, attrs) {}
      : Activation(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationInfoCost>()) {}
  ~ActivationInfo() override = default;

 protected:
@ -74,8 +74,8 @@ class ActivationInfo : public Activation {
class ActivationOther : public Activation {
 public:
  ActivationOther(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                  const PrimitiveAttrs &attrs)
      : Activation(name, inputs_shape, outputs_shape, attrs) {}
                  const PrimitiveAttrs &attrs, OperatorCostPtr cost)
      : Activation(name, inputs_shape, outputs_shape, attrs, cost) {}
  ~ActivationOther() override = default;

 protected:
@ -86,7 +86,7 @@ class GeluInfo : public ActivationOther {
 public:
  GeluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
           const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<GeLUCost>()) {}
  ~GeluInfo() override = default;
};
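On the ops_info side the pattern of this change is uniform: every operator now hands its own cost class to the base constructor instead of a shared cost plus an is_inputs_related flag. A new unary activation would be wired the same way; MishInfo and MishCost below are hypothetical and only illustrate the pattern:

using MishCost = ReLU6Cost;  // hypothetical alias: backward needs the input, not the output

class MishInfo : public ActivationOther {
 public:
  MishInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
           const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<MishCost>()) {}
  ~MishInfo() override = default;
};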

@ -94,7 +94,7 @@ class TanhInfo : public ActivationOther {
 public:
  TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
           const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<TanhCost>()) {}
  ~TanhInfo() override = default;
};

@ -102,7 +102,7 @@ class Softmax : public ActivationBase {
 public:
  explicit Softmax(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                   const PrimitiveAttrs &attrs)
      : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>(false)) {}
      : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>()) {}
  ~Softmax() override = default;
  Status GenerateStrategies(int64_t stage_id) override;
  Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
@ -134,7 +134,7 @@ class LogSoftmaxInfo : public Softmax {
class EluInfo : public ActivationOther {
 public:
  EluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<EluCost>()) {}
  ~EluInfo() override = default;
};

@ -142,7 +142,7 @@ class ReLUInfo : public ActivationOther {
 public:
  ReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
           const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLUCost>()) {}
  ~ReLUInfo() override = default;
};

@ -150,7 +150,7 @@ class RepeatElementsInfo : public ActivationOther {
 public:
  RepeatElementsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                     const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<RepeatElementsCost>()) {}
  ~RepeatElementsInfo() override = default;
};

@ -158,7 +158,7 @@ class ReLU6Info : public ActivationOther {
 public:
  ReLU6Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLU6Cost>()) {}
  ~ReLU6Info() override = default;
};

@ -166,7 +166,7 @@ class SoftsignInfo : public ActivationOther {
 public:
  SoftsignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftsignCost>()) {}
  ~SoftsignInfo() override = default;
};

@ -174,7 +174,7 @@ class SoftplusInfo : public ActivationOther {
 public:
  SoftplusInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftplusCost>()) {}
  ~SoftplusInfo() override = default;
};

@ -182,7 +182,7 @@ class CastInfo : public ActivationOther {
 public:
  CastInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
           const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<CastCost>()) {}
  ~CastInfo() override = default;

 protected:
@ -193,14 +193,14 @@ class SqrtInfo : public ActivationOther {
 public:
  SqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
           const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SqrtCost>()) {}
  ~SqrtInfo() override = default;
};

class NegInfo : public ActivationOther {
 public:
  NegInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<NegCost>()) {}
  ~NegInfo() override = default;
};

@ -208,7 +208,7 @@ class ExpandDimsInfo : public ActivationOther {
 public:
  ExpandDimsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ExpandDimsCost>()) {}
  ~ExpandDimsInfo() override = default;

 protected:
@ -228,7 +228,7 @@ class SqueezeInfo : public ActivationOther {
 public:
  SqueezeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
              const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SqueezeCost>()) {}
  ~SqueezeInfo() override = default;

 protected:
@ -247,7 +247,7 @@ class SquareInfo : public ActivationOther {
 public:
  SquareInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SquareCost>()) {}
  ~SquareInfo() override = default;
};

@ -255,7 +255,7 @@ class SigmoidInfo : public ActivationOther {
 public:
  SigmoidInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
              const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SigmoidCost>()) {}
  ~SigmoidInfo() override = default;
};

@ -263,7 +263,7 @@ class DropoutInfo : public ActivationOther {
 public:
  DropoutInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
              const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>()) {}
  ~DropoutInfo() override = default;
  Status GenerateStrategies(int64_t stage_id) override;

@ -56,7 +56,7 @@ class ArithmeticBase : public OperatorInfo {
class SubInfo : public ArithmeticBase {
 public:
  SubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SubCost>()) {}
  ~SubInfo() override = default;
};

@ -64,28 +64,28 @@ class TensorAddInfo : public ArithmeticBase {
 public:
  TensorAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorAddCost>()) {}
  ~TensorAddInfo() override = default;
};

class MulInfo : public ArithmeticBase {
 public:
  MulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<MulCost>()) {}
  ~MulInfo() override = default;
};

class DivInfo : public ArithmeticBase {
 public:
  DivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<DivCost>()) {}
  ~DivInfo() override = default;
};

class ModInfo : public ArithmeticBase {
 public:
  ModInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ModCost>()) {}
  ~ModInfo() override = default;
};

@ -93,7 +93,7 @@ class RealDivInfo : public ArithmeticBase {
 public:
  RealDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
              const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReadDivCost>()) {}
  ~RealDivInfo() override = default;
};

@ -101,7 +101,7 @@ class FloorDivInfo : public ArithmeticBase {
 public:
  FloorDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<FloorDivCost>()) {}
  ~FloorDivInfo() override = default;
};

@ -109,14 +109,14 @@ class FloorModInfo : public ArithmeticBase {
 public:
  FloorModInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<FloorModCost>()) {}
  ~FloorModInfo() override = default;
};

class PowInfo : public ArithmeticBase {
 public:
  PowInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<PowCost>()) {}
  ~PowInfo() override = default;
};

@ -124,7 +124,7 @@ class AssignSubInfo : public ArithmeticBase {
 public:
  AssignSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<AssignSubCost>()) {}
  ~AssignSubInfo() override = default;
};

@ -132,7 +132,7 @@ class AssignInfo : public ArithmeticBase {
 public:
  AssignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<AssignCost>()) {}
  ~AssignInfo() override = default;
};

@ -140,7 +140,7 @@ class AssignAddInfo : public ArithmeticBase {
 public:
  AssignAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<AssignAddCost>()) {}
  ~AssignAddInfo() override = default;
};

@ -149,7 +149,8 @@ class SigmoidCrossEntropyWithLogitsInfo : public ArithmeticBase {
 public:
  SigmoidCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                                    const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs,
                       std::make_shared<SigmoidCrossEntropyWithLogitsCost>()) {}
  ~SigmoidCrossEntropyWithLogitsInfo() override = default;
};

@ -157,7 +158,7 @@ class Atan2Info : public ArithmeticBase {
 public:
  Atan2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<Atan2Cost>()) {}
  ~Atan2Info() override = default;
};

@ -165,7 +166,7 @@ class DivNoNanInfo : public ArithmeticBase {
 public:
  DivNoNanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<DivNoNanCost>()) {}
  ~DivNoNanInfo() override = default;
};

@ -173,7 +174,7 @@ class LogicalAndInfo : public ArithmeticBase {
 public:
  LogicalAndInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogicalAndCost>()) {}
  ~LogicalAndInfo() override = default;
};

@ -181,7 +182,7 @@ class LogicalOrInfo : public ArithmeticBase {
 public:
  LogicalOrInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogicalOrCost>()) {}
  ~LogicalOrInfo() override = default;
};
}  // namespace parallel
|
||||
|
|
|
@ -34,8 +34,7 @@ class BatchParallelInfo : public OperatorInfo {
|
|||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {}
|
||||
BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>(false)),
|
||||
dev_num_(1) {}
|
||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()), dev_num_(1) {}
|
||||
|
||||
~BatchParallelInfo() override = default;
|
||||
Status Init(const StrategyPtr &strategy) override;
|
||||
|
@ -62,7 +61,8 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
|
|||
public:
|
||||
SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape,
|
||||
const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
|
||||
: BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>(true)) {}
|
||||
: BatchParallelInfo(name, inputs_shape, outputs_shape, attrs,
|
||||
std::make_shared<SparseSoftmaxCrossEntropyWithLogitsCost>()) {}
|
||||
~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default;
|
||||
void ReComputeBatchSplitFlagList() override;
|
||||
};
|
||||
|
|
|
@ -34,7 +34,7 @@ class BiasAddInfo : public OperatorInfo {
|
|||
public:
|
||||
BiasAddInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>(false)) {}
|
||||
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>()) {}
|
||||
~BiasAddInfo() override = default;
|
||||
|
||||
Status Init(const StrategyPtr &strategy) override;
|
||||
|
|
|
@ -36,7 +36,7 @@ class BroadcastToInfo : public OperatorInfo {
|
|||
public:
|
||||
BroadcastToInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BroadcastToCost>(false)) {}
|
||||
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BroadcastToCost>()) {}
|
||||
~BroadcastToInfo() override = default;
|
||||
|
||||
Status Init(const StrategyPtr &strategy) override;
|
||||
|
|
|
@ -32,7 +32,7 @@ class EqualInfo : public ArithmeticBase {
|
|||
public:
|
||||
EqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<EqualCost>()) {}
|
||||
~EqualInfo() override = default;
|
||||
};
|
||||
|
||||
|
@ -40,7 +40,7 @@ class ApproximateEqualInfo : public ArithmeticBase {
|
|||
public:
|
||||
ApproximateEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ApproximateEqualCost>()) {}
|
||||
~ApproximateEqualInfo() override = default;
|
||||
};
|
||||
|
||||
|
@ -48,7 +48,7 @@ class NotEqualInfo : public ArithmeticBase {
|
|||
public:
|
||||
NotEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<NotEqualCost>()) {}
|
||||
~NotEqualInfo() override = default;
|
||||
};
|
||||
|
||||
|
@ -56,7 +56,7 @@ class MaximumInfo : public ArithmeticBase {
|
|||
public:
|
||||
MaximumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<MaximumCost>()) {}
|
||||
~MaximumInfo() override = default;
|
||||
};
|
||||
|
||||
|
@ -64,7 +64,7 @@ class MinimumInfo : public ArithmeticBase {
|
|||
public:
|
||||
MinimumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<MinimumCost>()) {}
|
||||
~MinimumInfo() override = default;
|
||||
};
|
||||
|
||||
|
@ -72,7 +72,7 @@ class GreaterInfo : public ArithmeticBase {
|
|||
public:
|
||||
GreaterInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<GreaterCost>()) {}
|
||||
~GreaterInfo() override = default;
|
||||
};
|
||||
|
||||
|
@ -80,7 +80,7 @@ class GreaterEqualInfo : public ArithmeticBase {
|
|||
public:
|
||||
GreaterEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<GreaterEqualCost>()) {}
|
||||
~GreaterEqualInfo() override = default;
|
||||
};
|
||||
|
||||
|
@ -88,7 +88,7 @@ class LessInfo : public ArithmeticBase {
|
|||
public:
|
||||
LessInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LessCost>()) {}
|
||||
~LessInfo() override = default;
|
||||
};
|
||||
|
||||
|
@ -96,7 +96,7 @@ class LessEqualInfo : public ArithmeticBase {
|
|||
public:
|
||||
LessEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LessEqualCost>()) {}
|
||||
~LessEqualInfo() override = default;
|
||||
};
|
||||
} // namespace parallel
|
||||
|
|
|
@ -33,7 +33,7 @@ class ConcatInfo : public OperatorInfo {
|
|||
public:
|
||||
ConcatInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ConcatCost>(false)) {}
|
||||
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ConcatCost>()) {}
|
||||
~ConcatInfo() override = default;
|
||||
|
||||
Status Init(const StrategyPtr &strategy) override;
|
||||
|
|
|
@ -33,7 +33,7 @@ class DropoutDoMaskInfo : public OperatorInfo {
|
|||
public:
|
||||
DropoutDoMaskInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>(true)) {}
|
||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutDoMaskCost>()) {}
|
||||
~DropoutDoMaskInfo() override = default;
|
||||
|
||||
Status Init(const StrategyPtr &strategy) override;
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ir/value.h"
|
||||
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
|
||||
#include "frontend/parallel/ops_info/activation_info.h"
|
||||
|
@ -30,21 +31,21 @@ namespace parallel {
|
|||
class ExpInfo : public ActivationOther {
|
||||
public:
|
||||
ExpInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ExpCost>()) {}
|
||||
~ExpInfo() override = default;
|
||||
};
|
||||
|
||||
class LogInfo : public ActivationOther {
|
||||
public:
|
||||
LogInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogCost>()) {}
|
||||
~LogInfo() override = default;
|
||||
};
|
||||
|
||||
class CosInfo : public ActivationOther {
|
||||
public:
|
||||
CosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<CosCost>()) {}
|
||||
~CosInfo() override = default;
|
||||
};
|
||||
|
||||
|
@ -52,7 +53,7 @@ class ACosInfo : public ActivationOther {
|
|||
public:
|
||||
ACosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ACosCost>()) {}
|
||||
~ACosInfo() override = default;
|
||||
};
|
||||
|
||||
|
@ -60,14 +61,14 @@ class LogicalNotInfo : public ActivationOther {
|
|||
public:
|
||||
LogicalNotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogicalNotCost>()) {}
|
||||
~LogicalNotInfo() override = default;
|
||||
};
|
||||
|
||||
class AbsInfo : public ActivationOther {
|
||||
public:
|
||||
AbsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AbsCost>()) {}
|
||||
~AbsInfo() override = default;
|
||||
};
|
||||
|
||||
|
@ -75,7 +76,7 @@ class SignInfo : public ActivationOther {
 public:
  SignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
           const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SignCost>()) {}
  ~SignInfo() override = default;
};

@ -83,7 +84,7 @@ class FloorInfo : public ActivationOther {
 public:
  FloorInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<FloorCost>()) {}
  ~FloorInfo() override = default;
};

@ -91,7 +92,7 @@ class RoundInfo : public ActivationOther {
 public:
  RoundInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<RoundCost>()) {}
  ~RoundInfo() override = default;
};

@ -99,14 +100,14 @@ class ReciprocalInfo : public ActivationOther {
 public:
  ReciprocalInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReciprocalCost>()) {}
  ~ReciprocalInfo() override = default;
};

class InvInfo : public ActivationOther {
 public:
  InvInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<InvCost>()) {}
  ~InvInfo() override = default;
};
@ -114,21 +115,21 @@ class RsqrtInfo : public ActivationOther {
 public:
  RsqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<RsqrtCost>()) {}
  ~RsqrtInfo() override = default;
};

class TanInfo : public ActivationOther {
 public:
  TanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<TanCost>()) {}
  ~TanInfo() override = default;
};

class SinInfo : public ActivationOther {
 public:
  SinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SinCost>()) {}
  ~SinInfo() override = default;
};

@ -136,7 +137,7 @@ class SinhInfo : public ActivationOther {
 public:
  SinhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
           const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SinhCost>()) {}
  ~SinhInfo() override = default;
};

@ -144,7 +145,7 @@ class Log1pInfo : public ActivationOther {
 public:
  Log1pInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<Log1pCost>()) {}
  ~Log1pInfo() override = default;
};

@ -152,7 +153,7 @@ class Expm1Info : public ActivationOther {
 public:
  Expm1Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<Expm1Cost>()) {}
  ~Expm1Info() override = default;
};
@ -160,7 +161,7 @@ class CoshInfo : public ActivationOther {
 public:
  CoshInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
           const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<CoshCost>()) {}
  ~CoshInfo() override = default;
};

@ -168,7 +169,7 @@ class CeilInfo : public ActivationOther {
 public:
  CeilInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
           const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<CeilCost>()) {}
  ~CeilInfo() override = default;
};

@ -176,7 +177,7 @@ class AtanhInfo : public ActivationOther {
 public:
  AtanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AtanhCost>()) {}
  ~AtanhInfo() override = default;
};

@ -184,7 +185,7 @@ class AtanInfo : public ActivationOther {
 public:
  AtanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
           const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AtanCost>()) {}
  ~AtanInfo() override = default;
};

@ -192,7 +193,7 @@ class AsinInfo : public ActivationOther {
 public:
  AsinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
           const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AsinCost>()) {}
  ~AsinInfo() override = default;
};

@ -200,7 +201,7 @@ class AsinhInfo : public ActivationOther {
 public:
  AsinhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AsinhCost>()) {}
  ~AsinhInfo() override = default;
};
@ -208,14 +209,14 @@ class AcoshInfo : public ActivationOther {
 public:
  AcoshInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AcoshCost>()) {}
  ~AcoshInfo() override = default;
};

class ErfInfo : public ActivationOther {
 public:
  ErfInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ErfCost>()) {}
  ~ErfInfo() override = default;
};

@ -223,7 +224,7 @@ class ErfcInfo : public ActivationOther {
 public:
  ErfcInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
           const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ErfcCost>()) {}
  ~ErfcInfo() override = default;
};

@ -231,7 +232,7 @@ class ZerosLikeInfo : public ActivationOther {
 public:
  ZerosLikeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ZerosLikeCost>()) {}
  ~ZerosLikeInfo() override = default;
};

@ -239,7 +240,7 @@ class OnesLikeInfo : public ActivationOther {
 public:
  OnesLikeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<OnesLikeCost>()) {}
  ~OnesLikeInfo() override = default;
};

@ -247,7 +248,7 @@ class BesselI0eInfo : public ActivationOther {
 public:
  BesselI0eInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<BesselI0eCost>()) {}
  ~BesselI0eInfo() override = default;
};

@ -255,7 +256,7 @@ class BesselI1eInfo : public ActivationOther {
 public:
  BesselI1eInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<BesselI1eCost>()) {}
  ~BesselI1eInfo() override = default;
};
}  // namespace parallel
@ -32,7 +32,7 @@ class GetNextInfo : public OperatorInfo {
 public:
  GetNextInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
              const PrimitiveAttrs &attrs)
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>(false)) {}
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>()) {}
  ~GetNextInfo() override = default;

  Status Init(const StrategyPtr &strategy) override;

@ -33,7 +33,7 @@ class L2NormalizeInfo : public Activation {
 public:
  L2NormalizeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                  const PrimitiveAttrs &attrs)
      : Activation(name, inputs_shape, outputs_shape, attrs) {}
      : Activation(name, inputs_shape, outputs_shape, attrs, std::make_shared<L2NormalizeCost>()) {}
  ~L2NormalizeInfo() override = default;
  Status GenerateStrategies(int64_t stage_id) override;

@ -40,7 +40,7 @@ class LayerNormInfo : public OperatorInfo {
 public:
  LayerNormInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<LayerNormCost>(true)),
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<LayerNormCost>()),
        begin_norm_axis_(0) {}
  ~LayerNormInfo() override = default;

@ -36,8 +36,7 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
 public:
  SoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                                    const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs,
                     std::make_shared<SoftmaxCrossEntropyWithLogitsCost>(false)) {}
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCrossEntropyWithLogitsCost>()) {}
  ~SoftmaxCrossEntropyWithLogitsInfo() override = default;
  Status Init(const StrategyPtr &strategy) override;
  Status InitForCostModel(const StrategyPtr &strategy) override;

@ -34,7 +34,7 @@ class MatMulBase : public OperatorInfo {
 public:
  MatMulBase(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>(true)) {}
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>()) {}
  ~MatMulBase() override = default;

  Status Init(const StrategyPtr &strategy) override;

@ -33,7 +33,7 @@ class OneHotInfo : public OperatorInfo {
 public:
  OneHotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>(false)) {}
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>()) {}
  ~OneHotInfo() override = default;
  Status Init(const StrategyPtr &strategy) override;
  Status InitForCostModel(const StrategyPtr &strategy) override;
@ -1204,6 +1204,20 @@ int64_t OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() {
  } else {
    is_output_parameter_involve_ = 0;
  }
  // Set 'is_parameter_involve_' and 'is_output_parameter_involve_' into operatorCost, which are used in
  // calculating 'inputs_in_memory' and 'output_in_memory', respectively.
  operator_cost()->set_is_parameter_involve(is_parameter_involve_);
  operator_cost()->set_output_parameter_involve(is_output_parameter_involve_);
  // Calculating 'output_in_memory'
  operator_cost()->CalculateOutputInMemory();
  // Calculating 'inputs_in_memory'
  std::map<size_t, bool> input_in_memory;
  for (auto &p_edge : prev_edges) {
    auto input_index = p_edge->next_op_input_index();
    auto is_in_mem = p_edge->prev_operator()->operator_cost()->is_output_in_memory();
    input_in_memory.emplace(std::make_pair(input_index, is_in_mem));
  }
  operator_cost()->CalculateInputsInMemory(input_in_memory);

  return is_output_parameter_involve_;
}
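The hunk above adds the new memory-bookkeeping pass: once is_parameter_involve_ and is_output_parameter_involve_ are pushed into the operator's cost model, the cost model first decides whether its own output must stay in memory, and then, given a map from input index to whether the upstream operator already keeps that tensor resident, decides which of its inputs must stay in memory. Below is a minimal sketch of a cost class consuming that map; the class name, member fields, and the element-wise rule are assumptions for illustration, not the MindSpore implementation.

#include <cstddef>
#include <map>
#include <utility>
#include <vector>

// Illustrative only: one plausible way a simple element-wise cost model could
// implement the two hooks called by ComputeOpAndPrevEdgeParameterInvolved().
class ElementWiseCostSketch {
 public:
  explicit ElementWiseCostSketch(std::vector<bool> is_parameter_involve)
      : is_parameter_involve_(std::move(is_parameter_involve)),
        inputs_in_memory_(is_parameter_involve_.size(), false) {}

  // Assumption: an element-wise backward pass does not reuse the forward output.
  void CalculateOutputInMemory() { output_in_memory_ = false; }

  // Keep an input resident only if it is parameter-involved and the upstream
  // operator does not already hold that tensor (prev_output_in_mem.at(i) == true).
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
    for (size_t i = 0; i < inputs_in_memory_.size(); ++i) {
      auto it = prev_output_in_mem.find(i);
      bool held_upstream = (it != prev_output_in_mem.end()) && it->second;
      inputs_in_memory_[i] = is_parameter_involve_[i] && !held_upstream;
    }
  }

  bool is_output_in_memory() const { return output_in_memory_; }

 private:
  std::vector<bool> is_parameter_involve_;
  std::vector<bool> inputs_in_memory_;
  bool output_in_memory_ = false;
};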
@ -1220,14 +1234,10 @@ Status OperatorInfo::set_is_parameter(const std::vector<bool> &is_parameter) {
}

Status OperatorInfo::CalculateMemoryCost() {
  // First, set the 'is_parameter_involve_' and 'is_output_parameter_involve_' into OperatorCost, which are necessary to
  // calculate memory cost.
  if (is_parameter_involve_.size() != is_parameter_.size()) {
    MS_LOG(ERROR) << "'is_parameter_' does not have the same number of input size of 'is_parameter_involve_'.";
    return FAILED;
  }
  operator_cost()->set_is_parameter_involve(is_parameter_involve_);
  operator_cost()->set_output_parameter_involve(is_output_parameter_involve_);
  // Set the memory cost in the 'strategy_cost_'
  for (auto &swc : strategy_cost_) {
    auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr);
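CalculateMemoryCost() then records, for every candidate strategy in strategy_cost_, the memory cost the cost model reports for that strategy's input and output tensor layouts. As a rough, purely illustrative approximation (the helper below is not part of the code base), the contribution of one resident tensor is its slice element count times the element size:

#include <cstdint>
#include <vector>

// Hypothetical helper for illustration: bytes occupied by one tensor slice.
// A {512, 1024} float32 slice (4-byte elements) yields 512 * 1024 * 4 = 2 MiB.
int64_t SliceBytes(const std::vector<int64_t> &slice_shape, int64_t type_length) {
  int64_t elements = 1;
  for (int64_t dim : slice_shape) {
    elements *= dim;
  }
  return elements * type_length;
}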
@ -33,7 +33,7 @@ class PackInfo : public OperatorInfo {
 public:
  PackInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
           const PrimitiveAttrs &attrs)
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<PackCost>(false)) {}
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<PackCost>()) {}
  ~PackInfo() override = default;

  Status Init(const StrategyPtr &strategy) override;

@ -35,7 +35,7 @@ class PReLUInfo : public OperatorInfo {
 public:
  PReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>(true)) {}
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>()) {}
  ~PReLUInfo() override = default;
  Status Init(const StrategyPtr &strategy) override;
  Status InitForCostModel(const StrategyPtr &strategy) override;

@ -39,7 +39,7 @@ class RangeInfo : public OperatorInfo {
 public:
  RangeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(true)) {}
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<RangeCost>()) {}
  ~RangeInfo() override = default;

  Status Init(const StrategyPtr &strategy) override;
@ -33,8 +33,8 @@ namespace parallel {
class ReduceMethod : public OperatorInfo {
 public:
  ReduceMethod(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMethodCost>(true)) {}
               const PrimitiveAttrs &attrs, OperatorCostPtr cost)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost) {}
  ~ReduceMethod() override = default;

  Status Init(const StrategyPtr &strategy) override;

@ -62,7 +62,7 @@ class ReduceMaxInfo : public ReduceMethod {
 public:
  ReduceMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
      : ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
      : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMaxCost>()) {
    reduce_method_ = REDUCE_OP_MAX;
  }

@ -73,7 +73,7 @@ class ArgMaxWithValueInfo : public ReduceMethod {
 public:
  ArgMaxWithValueInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                      const PrimitiveAttrs &attrs)
      : ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
      : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArgMaxWithValueCost>()) {
    reduce_method_ = REDUCE_OP_MAX;
  }

@ -105,9 +105,7 @@ class ReduceMeanInfo : public ReduceMethod {
 public:
  ReduceMeanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
      : ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
    set_cost(std::make_shared<ReduceMeanCost>());
  }
      : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMeanCost>()) {}

  ~ReduceMeanInfo() override = default;

@ -119,7 +117,7 @@ class ReduceSumInfo : public ReduceMethod {
 public:
  ReduceSumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
      : ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
      : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceSumCost>()) {
    reduce_method_ = REDUCE_OP_SUM;
  }

@ -130,7 +128,7 @@ class ReduceMinInfo : public ReduceMethod {
 public:
  ReduceMinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
      : ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
      : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMinCost>()) {
    reduce_method_ = REDUCE_OP_MIN;
  }
@ -37,7 +37,7 @@ class ReLUV2Info : public OperatorInfo {
 public:
  ReLUV2Info(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(false)) {}
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLUV2Cost>()) {}
  ~ReLUV2Info() override = default;

  Status Init(const StrategyPtr &strategy) override;

@ -36,7 +36,7 @@ class ReshapeInfo : public OperatorInfo {
 public:
  ReshapeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
              const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>(false)),
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>()),
        dev_num_(0),
        pre_operator_index_(0),
        next_operator_index_(0),

@ -34,7 +34,7 @@ class SliceInfo : public OperatorInfo {
 public:
  SliceInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<SliceCost>(false)),
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<SliceCost>()),
        slice_axis_(-1) {}
  ~SliceInfo() override = default;

@ -31,7 +31,7 @@ class SplitInfo : public OperatorInfo {
 public:
  SplitInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ConcatCost>(false)) {}
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<SplitCost>()) {}
  ~SplitInfo() override = default;

  Status Init(const StrategyPtr &strategy) override;
@ -273,6 +273,8 @@ Status StridedSliceInfo::GenerateStrategies(int64_t stage_id) {
      PrintStrategy(sp);
    }
  }

  MS_LOG(INFO) << name() << ", finishing GenerateStrategies().";
  return SUCCESS;
}
@ -34,7 +34,7 @@ class StridedSliceInfo : public OperatorInfo {
 public:
  StridedSliceInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                   const PrimitiveAttrs &attrs)
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<StridedSliceCost>(false)) {}
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<StridedSliceCost>()) {}
  ~StridedSliceInfo() override = default;

  Status Init(const StrategyPtr &strategy) override;

@ -41,7 +41,7 @@ class TensorDotInfo : public OperatorInfo {
 public:
  TensorDotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>(true)) {}
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorDotCost>()) {}
  ~TensorDotInfo() override = default;

  Status Init(const StrategyPtr &strategy) override;

@ -34,7 +34,7 @@ class TileInfo : public OperatorInfo {
 public:
  TileInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
           const PrimitiveAttrs &attrs)
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TileCost>(false)) {}
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TileCost>()) {}
  ~TileInfo() override = default;

  Status Init(const StrategyPtr &strategy) override;

@ -34,7 +34,7 @@ class TmpIdentityInfo : public OperatorInfo {
 public:
  TmpIdentityInfo(const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs,
                  const std::string &name = IDENTITY_INFO)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>(false)) {}
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>()) {}
  ~TmpIdentityInfo() override = default;

  Status Init(const StrategyPtr &strategy) override;
@ -35,7 +35,7 @@ class TransposeInfo : public OperatorInfo {
 public:
  TransposeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>(false)) {}
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>()) {}
  ~TransposeInfo() override = default;
  Status Init(const StrategyPtr &strategy) override;
  Status InitForCostModel(const StrategyPtr &strategy) override;

@ -32,7 +32,7 @@ class UniqueInfo : public OperatorInfo {
 public:
  UniqueInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>(false)) {}
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<UniqueCost>()) {}
  ~UniqueInfo() override = default;

  Status Init(const StrategyPtr &strategy) override;

@ -82,7 +82,7 @@ class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo {
 public:
  UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                         const PrimitiveAttrs &attrs)
      : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMinCost>()) {}
      : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMaxCost>()) {}
  ~UnsortedSegmentMaxInfo() override = default;

  ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;

@ -32,7 +32,7 @@ class VirtualDatasetInfo : public OperatorInfo {
 public:
  VirtualDatasetInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                     const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>(false)) {}
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>()) {}
  ~VirtualDatasetInfo() override = default;
  Status Init(const StrategyPtr &strategy) override;
  Status InitForCostModel(const StrategyPtr &strategy) override;
@ -85,11 +85,11 @@ class TestActivationCost : public UT::Common {
  TestActivationCost() {}
  void SetUp();
  void TearDown();
  ActivationCost ac_cost_;
  ActivationInfoCost ac_cost_;
};

void TestActivationCost::SetUp() {
  ac_cost_ = ActivationCost();
  ac_cost_ = ActivationInfoCost();
  RankList dev_list;

  for (int32_t i = 0; i < 1050; i++) {