modify gen strategy for matmul op

This commit is contained in:
yangzhenzhang 2022-06-01 10:12:19 +08:00
parent 41f1ac9573
commit e5b62b2714
12 changed files with 217 additions and 409 deletions


@ -69,7 +69,7 @@ void SetDevMatrixShape(const Dimensions &mat_a_strategy, const Dimensions &mat_b
Status MatMulBase::GetAttrs() {
if (attrs_.size() < MATMUL_ATTRS_SIZE) {
MS_LOG(ERROR) << name_ << " : The size of attrs small than 2.";
MS_LOG(ERROR) << name_ << ": The size of attrs is smaller than 2, got " << attrs_.size();
return FAILED;
}
@ -79,18 +79,23 @@ Status MatMulBase::GetAttrs() {
if (transpose_a_iter->second->isa<BoolImm>()) {
transpose_a_ = transpose_a_iter->second->cast<BoolImmPtr>()->value();
} else {
MS_LOG(ERROR) << name_ << " : The value of transpose_a is not bool.";
MS_LOG(ERROR) << name_ << ": The value of transpose_a is not bool.";
return FAILED;
}
}
if (transpose_a_) {
MS_LOG(ERROR) << name_ << ": transpose_a=true is not supported";
return FAILED;
}
auto transpose_b_iter = attrs_.find(TRANSPOSE_B);
if (transpose_b_iter != attrs_.end()) {
MS_EXCEPTION_IF_NULL(transpose_b_iter->second);
if (transpose_b_iter->second->isa<BoolImm>()) {
transpose_b_ = transpose_b_iter->second->cast<BoolImmPtr>()->value();
} else {
MS_LOG(ERROR) << name_ << " : The value of transpose_b is not bool.";
MS_LOG(ERROR) << name_ << ": The value of transpose_b is not bool.";
return FAILED;
}
}
@ -101,18 +106,22 @@ Status MatMulBase::GetAttrs() {
if (field_size_iter->second->isa<Int64Imm>()) {
field_size_ = field_size_iter->second->cast<Int64ImmPtr>()->value();
} else {
MS_LOG(ERROR) << name_ << " : The value of field_size is not int64_t.";
MS_LOG(ERROR) << name_ << ": The value of field_size is not int64_t.";
return FAILED;
}
}
// infer inputs dimension size
if ((inputs_shape_.size() != MATMUL_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_OUTPUTS_SIZE)) {
MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong.";
return FAILED;
}
mat_a_dimension_ = inputs_shape_.at(0).size();
mat_b_dimension_ = inputs_shape_.at(1).size();
if (mat_a_dimension_ < 2 || mat_b_dimension_ < 2) {
MS_LOG(ERROR) << name_ << ": The dim of mat_a or mat_b can not be smaller than 2, but the dim of mat_a is "
<< mat_a_dimension_ << ", the dim of mat_b is " << mat_b_dimension_;
return FAILED;
}
return SUCCESS;
}
@ -150,7 +159,7 @@ Status MatMul::CheckStrategy(const StrategyPtr &strategy) {
size_t mat_a_size = mat_a_strategy.size();
size_t mat_b_size = mat_b_strategy.size();
if ((mat_a_size != mat_a_dimension_) || (mat_b_size != mat_b_dimension_)) {
MS_LOG(ERROR) << name_ << " : The dimensions of mat_a or mat_b's strategy is wrong.";
MS_LOG(ERROR) << name_ << ": The dimensions of mat_a or mat_b's strategy is wrong.";
return FAILED;
}
@ -158,12 +167,12 @@ Status MatMul::CheckStrategy(const StrategyPtr &strategy) {
// dev_matrix_shape:[2,4,8,16,32] (transpose_b is false)
// [16] in the example above
if (!transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.at(SECOND_FROM_END(mat_b_size)))) {
MS_LOG(ERROR) << name_ << " : Can not do this operator in the strategy: " << StrategyToString(stra)
MS_LOG(ERROR) << name_ << ": Can not do this operator in the strategy: " << StrategyToString(stra)
<< ", the transpose_b is false, the shard num of first input's column is " << mat_a_strategy.back()
<< ", but the shard num of second input's row is " << mat_b_strategy.at(SECOND_FROM_END(mat_b_size));
return FAILED;
} else if (transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.back())) {
MS_LOG(ERROR) << name_ << " : Can not do this operator in the strategy: " << StrategyToString(stra)
MS_LOG(ERROR) << name_ << ": Can not do this operator in the strategy: " << StrategyToString(stra)
<< ", the transpose_b is true, the shard num of first input's column is " << mat_a_strategy.back()
<< ", but the shard num of second input's column is " << mat_b_strategy.back();
return FAILED;
@ -171,12 +180,12 @@ Status MatMul::CheckStrategy(const StrategyPtr &strategy) {
if (mat_a_size >= mat_b_size) {
if (CheckRelevantDimension(mat_a_strategy, mat_b_strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
MS_LOG(ERROR) << name_ << ": Strategies of relevant dimensions are not equal.";
return FAILED;
}
} else {
if (CheckRelevantDimension(mat_b_strategy, mat_a_strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
MS_LOG(ERROR) << name_ << ": Strategies of relevant dimensions are not equal.";
return FAILED;
}
}
@ -196,7 +205,7 @@ Status MatMul::CheckOutputStrategy(const StrategyPtr &out_strategy) {
}
if (CheckStrategyValue(out_strategy, outputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Invalid output strategy.";
MS_LOG(ERROR) << name_ << ": Invalid output strategy.";
return FAILED;
}
@ -255,16 +264,16 @@ Status MatMulBase::InferForwardCommunication() {
// Relevant dimension is not split and all reduce is not required,
// need to use origin_dev_matrix_shape_ here, since the dev_matrix_shape_ will be changed if repeated calculation.
if (origin_dev_matrix_shape_.at(relevant_dimension_index) == MIN_SLICE_NUM) {
MS_LOG(INFO) << name_ << " : Forward all reduce is not required.";
MS_LOG(INFO) << name_ << ": Forward all reduce is not required.";
return SUCCESS;
}
std::vector<Group> group_list;
if (CreateGroupByDim(relevant_dimension_index, &group_list) != SUCCESS) {
ReportError(name_ + " : Infer forward communication, create group failed.");
ReportError(name_ + ": Infer forward communication, create group failed.");
return FAILED;
} else if (group_list.empty()) {
MS_LOG(INFO) << name_ << " : Forward all reduce is not required.";
MS_LOG(INFO) << name_ << ": Forward all reduce is not required.";
return SUCCESS;
}
@ -276,7 +285,7 @@ Status MatMulBase::InferForwardCommunication() {
}
forward_op_.push_back(op);
MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name();
MS_LOG(INFO) << name_ << ": The group name of forward communication is " << group_list[0].name();
return SUCCESS;
}
@ -392,7 +401,7 @@ Status MatMulBase::InferTensorInfo() {
Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape *const input) {
if (input->size() < 2) {
MS_LOG(ERROR) << name_ << " : The size of inputs small than 2.";
MS_LOG(ERROR) << name_ << ": The size of inputs is smaller than 2.";
return FAILED;
}
auto last_1st_value = input->at(input->size() - 1);
@ -404,270 +413,76 @@ Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape *const input)
return SUCCESS;
}
Status MatMulBase::GenerateStrategiesBase(int64_t stage_id, size_t dev_num, const Shape &input0_shape,
Shape input1_shape, std::vector<StrategyPtr> *const sp_vector) {
// The shape of input0 (input1)
// E.g., input0 = [100, 200, 300], input1 = [300, 400]
// Combining the input0_shape and input1_shape
// E.g., combined_shape = [100, 200, 300, 400]
size_t input1_shape_size = input1_shape.size(), input0_shape_size = input0_shape.size();
Dimensions combined_partitions;
Shape combined_shape;
// In SwapLastTwoElements(), it is guaranteed that input0_shape.size() and input1_shape.size() are both larger than 2
if (input0_shape.size() >= input1_shape.size()) {
combined_shape = input0_shape;
combined_shape.push_back(input1_shape[input1_shape.size() - 1]);
} else {
combined_shape = input1_shape;
combined_shape.push_back(input0_shape[input0_shape.size() - 2]);
std::vector<StrategyPtr> MatMulBase::GenerateOpStrategies(int64_t stage_id) {
Shape mat_a_shape = inputs_shape_[0];
Shape mat_b_shape = inputs_shape_[1];
// transpose_a is not supported
if (transpose_a_) {
MS_LOG(EXCEPTION) << name_ << ": transpose_a is not supported yet";
}
std::function<void(uint64_t, size_t)> recursive = [&stage_id, &dev_num, &sp_vector, &combined_partitions,
&combined_shape, &input1_shape_size, &recursive,
&input0_shape_size, this](uint64_t current_index, size_t n) {
// Finishing the recursive steps, if the strategy is valid, then calculate the cost
// for this operator under the strategy.
if (current_index == combined_shape.size()) {
StrategyPtr sp;
if (this->PrepareStrategy(stage_id, dev_num, combined_partitions, input0_shape_size, input1_shape_size, &sp) ==
FAILED) {
return;
}
sp_vector->push_back(sp);
} else {
MS_LOG(DEBUG) << name_ << " : The value input0_shape_size: " << input0_shape_size
<< ", input1_shape_size: " << input1_shape_size;
for (uint64_t i = 1; i <= n; i *= 2) {
if (n % i == 0 && LongToSize(combined_shape[current_index]) % i == 0) {
combined_partitions.push_back(i);
recursive(current_index + 1, n / i);
combined_partitions.pop_back();
}
}
// the case [B, C, D] * [A, B, D, E] is not supported
if (mat_b_shape.size() > mat_a_shape.size()) {
MS_LOG(EXCEPTION) << name_
<< ": The dim of mat_b can not be larger than the dim of mat_a, but the dim of"
" mat_a is "
<< mat_a_shape.size() << ", the dim of mat_b is " << mat_b_shape.size();
}
// broadcasting that contains 1 is not supported, such as [A, B, C, D] * [A, 1, D, E]
size_t diff_len = mat_a_shape.size() - mat_b_shape.size();
for (size_t i = 0; i < mat_b_shape.size() - 2; ++i) {
if (mat_b_shape[i] != mat_a_shape[i + diff_len]) {
MS_LOG(EXCEPTION) << name_ << ": Broadcasting that contains 1 is not supported, but the shape of mat_a is "
<< mat_a_shape << ", the shape of mat_b is " << mat_b_shape;
}
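// Example (hypothetical shapes): mat_a = [A, B, C, D], mat_b = [F, D, E] gives diff_len = 1, so mat_b_shape[0]
// (F) is compared with mat_a_shape[1] (B); any mismatch in these leading dims raises the exception above.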
};
recursive(0, dev_num);
if (sp_vector->empty()) {
MS_LOG(ERROR) << name_ << " : No available strategy.";
return FAILED;
}
return Status::SUCCESS;
}
Status MatMulBase::GenerateStrategiesNotPower2(int64_t stage_id, size_t dev_num_not_2_power,
const std::vector<StrategyPtr> &sp_vector_2_power_part) {
// e.g. mat_a: [A, B, C, D], mat_b: [B, D, E]: generate the strategy for the combined shape [A, B, C, D, E]
std::vector<StrategyPtr> sp_vector;
size_t related_dim_left = transpose_a_ ? inputs_shape_[0].size() - 2 : inputs_shape_[0].size() - 1;
size_t related_dim_right = transpose_b_ ? inputs_shape_[1].size() - 1 : inputs_shape_[1].size() - 2;
// Handle the not power of 2 part.
for (auto &stra : sp_vector_2_power_part) {
auto stra_arrays = stra->GetInputDim();
if (stra_arrays.size() != 2) {
MS_LOG(ERROR) << "The generated strategy of matmul dose not match two input, the strategy is: " << stra_arrays;
}
for (size_t i = 0; i < 2; ++i) {
size_t stra_size = stra_arrays[i].size();
for (size_t j = 0; j < stra_size; ++j) {
if (i == 1 && j == related_dim_right) {
continue;
}
auto new_stra_arrays{stra_arrays};
new_stra_arrays[i][j] = new_stra_arrays[i][j] * SizeToLong(dev_num_not_2_power);
if (i == 0 && j == related_dim_left) {
new_stra_arrays[1][related_dim_right] =
new_stra_arrays[1][related_dim_right] * SizeToLong(dev_num_not_2_power);
}
StrategyPtr new_stra = std::make_shared<Strategy>(stage_id, new_stra_arrays);
sp_vector.push_back(new_stra);
}
}
Shape splittable_flag(mat_a_shape.size() + 1, 1);
Shapes splittable_input = {splittable_flag};
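// For illustration: a 4-dim mat_a gives splittable_flag = [1, 1, 1, 1, 1]; a flag of 1 marks the
// corresponding dim of the combined tmp_shape below as splittable.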
Shape tmp_shape = inputs_shape_[0];
size_t index = 0;
if (transpose_b_) {
index = inputs_shape_[1].size() - 2;
tmp_shape.push_back(inputs_shape_[1][index]); // mat_a: [A, B, C, D], mat_b: [B, E, D], tmp_shape: [A, B, C, D, E]
} else {
index = inputs_shape_[1].size() - 1;
tmp_shape.push_back(inputs_shape_[1][index]); // mat_a: [A, B, C, D], mat_b: [B, D, E], tmp_shape: [A, B, C, D, E]
}
strategy_cost_.clear();
// add the repeated strategy
auto repeated_stra_arrays{inputs_shape_};
for (auto &stra_array : repeated_stra_arrays) {
std::fill(stra_array.begin(), stra_array.end(), 1);
}
StrategyPtr repeated_stra = std::make_shared<Strategy>(stage_id, repeated_stra_arrays);
sp_vector.push_back(repeated_stra);
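// Example (hypothetical shapes): inputs_shape_ = {[A, B, C, D], [B, D, E]} gives the repeated strategy
// {[1, 1, 1, 1], [1, 1, 1]}, i.e. nothing is split and the computation is fully replicated.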
Shapes tmp_inputs_shape = {tmp_shape};
if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Generate strategies failed";
}
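// For illustration (assuming 8 devices in the stage): with tmp_shape = [A, B, C, D, E] this call may append
// candidate strategies such as [8, 1, 1, 1, 1], [2, 4, 1, 1, 1] or [1, 1, 1, 4, 2], each sharding the
// combined shape; they are split back into the two inputs' strategies below.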
// set the inputs' strategies
for (auto &sp : sp_vector) {
if (SetCostUnderStrategy(sp) == FAILED) {
MS_LOG(WARNING) << name_ << " : Calculating cost for strategy failed.";
continue;
if ((sp == nullptr) || sp->GetInputDim().empty()) {
MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
}
}
if (strategy_cost_.empty()) {
MS_LOG(EXCEPTION) << name_ << " : No available strategy.";
}
return SUCCESS;
}
Strategys replace_strategy;
Dimensions tmp_strategy = sp->GetInputDim()[0];
Dimensions mat_a_strategy = tmp_strategy;
mat_a_strategy.pop_back();
Status MatMulBase::GenerateStrategies(int64_t stage_id) {
if (GetAttrs() != SUCCESS) {
MS_LOG(ERROR) << name_ << " : GetAttrs failed.";
return FAILED;
}
CheckGlobalDeviceManager();
RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
size_t dev_num = dev_list.size();
Shape input0_shape = inputs_shape_[0], input1_shape = inputs_shape_[1];
if (transpose_a_) {
if (SwapLastTwoElements(&input0_shape) == FAILED) {
MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
// mat_b_shape: [B, D, E], tmp_strategy: [A, B, C, D, E]
// mat_b_strategy: init [A, B, C, D, E]
Dimensions mat_b_strategy = tmp_strategy;
// mat_b_strategy: delete C, [A, B, D, E]
(void)mat_b_strategy.erase(mat_b_strategy.end() - 3);
// mat_b_strategy: delete A, [B, D, E]
(void)mat_b_strategy.erase(mat_b_strategy.begin(), mat_b_strategy.begin() + static_cast<different_type>(diff_len));
// handle transpose_b
if (transpose_b_) {
(void)SwapLastTwoElements(&mat_b_strategy);
}
replace_strategy.push_back(mat_a_strategy);
replace_strategy.push_back(mat_b_strategy);
sp->ResetInputs(replace_strategy);
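// Worked example (hypothetical values): for mat_a [A, B, C, D] and mat_b [B, D, E], suppose the generated
// tmp_strategy is [2, 4, 1, 8, 2]. Then mat_a_strategy = [2, 4, 1, 8] (the last element is dropped), and
// mat_b_strategy starts as [2, 4, 1, 8, 2]: erasing C gives [2, 4, 8, 2], erasing the first diff_len (= 1)
// elements gives [4, 8, 2] for [B, D, E]; with transpose_b the last two are swapped, giving [4, 2, 8].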
}
if (transpose_b_) {
if (SwapLastTwoElements(&input1_shape) == FAILED) {
MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
}
}
auto dev_num_2_power = (dev_num & (dev_num - 1));
std::vector<StrategyPtr> sp_vector_2_power_part;
if (dev_num_2_power == 0) {
if (GenerateStrategiesBase(stage_id, dev_num, input0_shape, input1_shape, &sp_vector_2_power_part) != SUCCESS) {
MS_LOG(ERROR) << "No available strategy.";
return FAILED;
}
strategy_cost_.clear();
for (auto &sp : sp_vector_2_power_part) {
if (SetCostUnderStrategy(sp) == FAILED) {
MS_LOG(WARNING) << name_ << " : Calculating cost for strategy failed.";
continue;
}
}
if (strategy_cost_.empty()) {
MS_LOG(EXCEPTION) << name_ << " : No available strategy.";
}
return SUCCESS;
}
auto dev_num_not_2_power = dev_num / (dev_num - dev_num_2_power);
if (GenerateStrategiesBase(stage_id, dev_num - dev_num_2_power, input0_shape, input1_shape,
&sp_vector_2_power_part) != SUCCESS) {
MS_LOG(ERROR) << "Generating strategy in power of 2 devices failed.";
return FAILED;
}
return GenerateStrategiesNotPower2(stage_id, dev_num_not_2_power, sp_vector_2_power_part);
}
std::vector<StrategyPtr> MatMulBase::GenerateOpStrategies(int64_t) {
std::vector<StrategyPtr> sp_vector;
return sp_vector;
}
Status MatMulBase::PrepareStrategy(int64_t stage_id, size_t dev_num,
mindspore::parallel::Dimensions combined_partitions, size_t input0_shape_size,
size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) {
int64_t product =
std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies<int64_t>());
const auto fully_use_device = CostModelContext::GetInstance()->fully_use_device();
if (!fully_use_device) {
if (LongToSize(product) > dev_num) {
return FAILED;
}
} else {
if (LongToSize(product) != dev_num) {
return FAILED;
}
}
Dimensions input0_partitions, input1_partitions;
if (input0_shape_size >= input1_shape_size) {
for (size_t i = 0; i < input0_shape_size; ++i) {
input0_partitions.push_back(combined_partitions[i]);
}
if (input1_shape_size == 2) {
input1_partitions.push_back(combined_partitions[combined_partitions.size() - 2]);
input1_partitions.push_back(combined_partitions[combined_partitions.size() - 1]);
} else {
// input1_shape.size() > 2
for (size_t j = combined_partitions.size() - input1_shape_size - 1; j < combined_partitions.size(); ++j) {
if (j == combined_partitions.size() - 3) {
continue;
}
input1_partitions.push_back(combined_partitions[j]);
}
}
} else {
for (size_t i = 0; i < input1_shape_size; ++i) {
input1_partitions.push_back(combined_partitions[i]);
}
for (size_t j = combined_partitions.size() - input0_shape_size - 1; j < combined_partitions.size() - 3; ++j) {
input0_partitions.push_back(combined_partitions[j]);
}
input0_partitions.push_back(combined_partitions[combined_partitions.size() - 1]);
input0_partitions.push_back(combined_partitions[combined_partitions.size() - 3]);
}
if (transpose_a_) {
if (SwapLastTwoElements(&input0_partitions) == FAILED) {
MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
}
}
if (transpose_b_) {
if (SwapLastTwoElements(&input1_partitions) == FAILED) {
MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
}
}
Strategys stras;
stras.push_back(input0_partitions);
stras.push_back(input1_partitions);
(*sp) = std::make_shared<Strategy>(stage_id, stras);
return SUCCESS;
}
void MatMulBase::InitTensorInfoForCost(std::vector<TensorInfo> *relica_inputs_tensor_vector) {
TensorLayout tly;
if (transpose_a_) {
Shape replica_input0_shape(inputs_tensor_info_[0].shape());
Shape replica_input0_slice_shape(inputs_tensor_info_[0].slice_shape());
if (SwapLastTwoElements(&replica_input0_shape) == FAILED) {
MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
}
if (SwapLastTwoElements(&replica_input0_slice_shape) == FAILED) {
MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
}
TensorInfo replica_input0_info(tly, replica_input0_shape, replica_input0_slice_shape);
relica_inputs_tensor_vector->push_back(replica_input0_info);
} else {
relica_inputs_tensor_vector->push_back(inputs_tensor_info_[0]);
}
if (transpose_b_) {
Shape replica_input1_shape(inputs_tensor_info_[1].shape());
Shape replica_input1_slice_shape(inputs_tensor_info_[1].slice_shape());
if (SwapLastTwoElements(&replica_input1_shape) == FAILED) {
MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
}
if (SwapLastTwoElements(&replica_input1_slice_shape) == FAILED) {
MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
}
TensorInfo replica_input1_info(tly, replica_input1_shape, replica_input1_slice_shape);
relica_inputs_tensor_vector->push_back(replica_input1_info);
} else {
relica_inputs_tensor_vector->push_back(inputs_tensor_info_[1]);
}
}
Status MatMulBase::CheckForTensorSliceValid() const {
const auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable();
const auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size();
if (!align_enable) {
return SUCCESS;
}
if (inputs_tensor_info_.empty()) {
return FAILED;
}
for (auto &one_input_tensor : inputs_tensor_info_) {
auto slice_shape = one_input_tensor.slice_shape();
if ((LongToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % align_size != 0) ||
(LongToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % align_size != 0)) {
return FAILED;
}
}
return SUCCESS;
}
std::shared_ptr<Strategys> BatchMatMulInfo::GenerateBatchStrategies() {
Dimensions batch_strategy(inputs_shape_[1].size() - 1, 1);
(void)batch_strategy.insert(batch_strategy.begin(), stage_device_size_);
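// Example (hypothetical values): a 4-dim mat_b with stage_device_size_ = 8 gives batch_strategy = [8, 1, 1, 1],
// i.e. only the leading batch dim is sharded.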
@ -675,50 +490,6 @@ std::shared_ptr<Strategys> BatchMatMulInfo::GenerateBatchStrategies() {
return std::make_shared<Strategys>(strategy_v);
}
Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
if (InitForCostModel(strategy, nullptr) == FAILED) {
MS_LOG(INFO) << name_ << " : Initialization under the strategy failed.";
return FAILED;
}
PrintStrategy(strategy);
// Check whether the tensor slice of input_tensor_info is valid or not
if (CheckForTensorSliceValid() != SUCCESS) {
MS_LOG(INFO) << name_ << " : The tensor slice is not valid under this strategy.";
return FAILED;
}
// Here, a replicated inputs_ is constructed for the transposed TensorInfo.
std::vector<TensorInfo> relica_inputs_tensor_vector;
InitTensorInfoForCost(&relica_inputs_tensor_vector);
int64_t stage_id = strategy->GetInputStage();
// Here, we use the origin outputs_, because we only use the slice size of the output tensor.
// It does not matter whether the output tensor is transposed or not.
double computation_cost =
operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
result->communication_without_parameter_ =
operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
result->communication_with_partial_para_ =
result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_);
// Breaking ties for preferring data parallelization
BreakingTiesForPerferringDataParallel(strategy, result);
MS_LOG(DEBUG) << name_ << " : computation_cost: " << result->computation_cost_
<< ", communication_cost: " << result->communication_cost_
<< ", communication_without_parameter_: " << result->communication_without_parameter_
<< ", communication_with_partial_para_: " << result->communication_with_partial_para_;
// refine communication cost calculation for practice
RefineForPracticalCost(result, false);
result->communication_forward_ = result->communication_without_parameter_;
std::shared_ptr<StrategyWithCost> swc =
std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
swc->cost_list.push_back(result);
(void)strategy_cost_.emplace_back(swc);
return SUCCESS;
}
Status MatMulBase::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
} // namespace parallel
} // namespace mindspore


@ -38,12 +38,8 @@ class MatMulBase : public OperatorInfo {
~MatMulBase() override = default;
// Generate all strategies and the corresponding cost for this MatMul operator
Status GenerateStrategies(int64_t stage_id) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
Status PrepareStrategy(int64_t stage_id, size_t dev_num, Dimensions combined_partitions, size_t input0_shape_size,
size_t input1_shape_size, StrategyPtr *sp);
Status SwapLastTwoElements(Shape *shape);
protected:
@ -52,12 +48,6 @@ class MatMulBase : public OperatorInfo {
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout);
void InitTensorInfoForCost(std::vector<TensorInfo> *);
Status GenerateStrategiesBase(int64_t stage_id, size_t dev_num, const Shape &input0_shape, Shape input1_shape,
std::vector<StrategyPtr> *const sp_vector);
Status GenerateStrategiesNotPower2(int64_t stage_id, size_t dev_num_not_2_power,
const std::vector<StrategyPtr> &sp_vector_2_power_part);
Status CheckForTensorSliceValid() const;
Status GetAttrs() override;
bool transpose_a_ = false;


@ -90,7 +90,7 @@ class OperatorInfo {
// Given the stage_id (which indicates the number of devices),
// generate all strategies for this operator
virtual Status GenerateStrategies(int64_t stage_id);
Status GenerateStrategies(int64_t stage_id);
virtual std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) = 0;
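// Presumably the non-virtual GenerateStrategies() now drives the common flow: it calls the operator's
// GenerateOpStrategies(stage_id) and costs each returned StrategyPtr via SetCostUnderStrategy();
// its body is not shown in this diff.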
const OperatorCostPtr &operator_cost() const { return operator_cost_; }
void set_cost(const OperatorCostPtr &cost) { operator_cost_ = cost; }


@ -420,15 +420,9 @@ void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &stra
strategy_cost_.emplace_back(swc);
}
Status ReshapeInfo::GenerateStrategies(int64_t stage_id) {
if (GetAttrs() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": GetAttrs failed.";
return FAILED;
}
if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) {
MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", "
<< outputs_shape_.size();
return FAILED;
std::vector<StrategyPtr> ReshapeInfo::GenerateOpStrategies(int64_t stage_id) {
if (inputs_shape_.empty()) {
MS_LOG(EXCEPTION) << name_ << ": The inputs shape is empty";
}
Shape input0_split;
(void)input0_split.insert(input0_split.end(), inputs_shape_[0].size(), 1);
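// For illustration: a 3-dim input gives input0_split = [1, 1, 1], i.e. every dim is allowed to be split.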
@ -436,15 +430,10 @@ Status ReshapeInfo::GenerateStrategies(int64_t stage_id) {
// the strategy is used only when the input node is a parameter;
// otherwise, use the input node's output_layout as the input_layout.
if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
return FAILED;
MS_LOG(EXCEPTION) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
}
return SUCCESS;
}
std::vector<StrategyPtr> ReshapeInfo::GenerateOpStrategies(int64_t) {
std::vector<StrategyPtr> sp_vector;
return sp_vector;
return sp_vector_;
}
Status ReshapeInfo::GenerateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,


@ -61,8 +61,7 @@ class ReshapeInfo : public OperatorInfo {
Status GenerateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs, int64_t out_index,
int64_t in_index, bool is_prev_param, bool is_next_reshape);
Status GenerateStrategies(int64_t stage_id) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
std::string pre_operator_name() const { return pre_operator_name_; }
std::string next_operator_name() const { return next_operator_name_; }


@ -159,7 +159,7 @@ Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
return SetCostUnderStrategyBase(strategy);
}
Status VirtualDatasetInfo::GenerateStrategies(int64_t stage_id) {
std::vector<StrategyPtr> VirtualDatasetInfo::GenerateOpStrategies(int64_t stage_id) {
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
StrategyPtr sp;
Strategys strategy;
@ -183,18 +183,8 @@ Status VirtualDatasetInfo::GenerateStrategies(int64_t stage_id) {
}
}
sp = std::make_shared<Strategy>(stage_id, strategy);
if (SetCostUnderStrategy(sp) == SUCCESS) {
MS_LOG(INFO) << name_ << ": Successfully dataset strategy.";
PrintStrategy(sp);
} else {
MS_LOG(ERROR) << name_ << ": Generating dataset strategy failed.";
return FAILED;
}
return SUCCESS;
}
std::vector<StrategyPtr> VirtualDatasetInfo::GenerateOpStrategies(int64_t) {
std::vector<StrategyPtr> sp_vector;
sp_vector.push_back(sp);
return sp_vector;
}


@ -37,8 +37,7 @@ class VirtualDatasetInfo : public OperatorInfo {
Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override;
Status InitForCostModel(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override;
Status GenerateStrategies(int64_t stage_id) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
void ReComputeBatchSplitFlagList() override;


@ -51,7 +51,7 @@ Status VirtualOutputInfo::CheckStrategy(const StrategyPtr &strategy) {
return SUCCESS;
}
Status VirtualOutputInfo::GenerateStrategies(int64_t stage_id) {
std::vector<StrategyPtr> VirtualOutputInfo::GenerateOpStrategies(int64_t stage_id) {
StrategyPtr sp;
Strategys strategy;
bool full_batch = ParallelContext::GetInstance()->full_batch();
@ -63,8 +63,7 @@ Status VirtualOutputInfo::GenerateStrategies(int64_t stage_id) {
}
if (total_dev_num == 0) {
MS_LOG(ERROR) << name_ << ": The total devices num is 0";
return FAILED;
MS_LOG(EXCEPTION) << name_ << ": The total devices num is 0";
}
for (auto &shape : inputs_shape_) {
@ -80,14 +79,9 @@ Status VirtualOutputInfo::GenerateStrategies(int64_t stage_id) {
strategy.push_back(temp);
}
sp = std::make_shared<Strategy>(stage_id, strategy);
if (SetCostUnderStrategy(sp) == SUCCESS) {
MS_LOG(INFO) << name_ << ": Successfully dataset strategy.";
PrintStrategy(sp);
} else {
MS_LOG(ERROR) << name_ << ": Generating dataset strategy failed.";
return FAILED;
}
return SUCCESS;
std::vector<StrategyPtr> sp_vector;
sp_vector.push_back(sp);
return sp_vector;
}
} // namespace parallel
} // namespace mindspore


@ -35,7 +35,7 @@ class VirtualOutputInfo : public VirtualDatasetInfo {
const PrimitiveAttrs &attrs)
: VirtualDatasetInfo(name, inputs_shape, outputs_shape, attrs) {}
~VirtualOutputInfo() override = default;
Status GenerateStrategies(int64_t stage_id) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;


@ -93,6 +93,9 @@ void TestMatmulInfo::SetUp() {
matmul4 = std::make_shared<MatMulInfo>("matmul_info", inputs_shape_3, outputs_shape_3, attr_4);
}
/// Feature: test matmul info
/// Description: infer dev matrix
/// Expectation: the dev matrix is right
TEST_F(TestMatmulInfo, InferDevMatrixShape1) {
Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs);
@ -104,6 +107,9 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape1) {
ASSERT_EQ(dev_matrix_shape, expect);
}
/// Feature: test matmul info
/// Description: infer dev matrix
/// Expectation: the dev matrix is right
TEST_F(TestMatmulInfo, InferDevMatrixShape2) {
Strategys inputs = {{2, 4, 8, 8}, {2, 4, 8, 2}};
StrategyPtr strategy = NewStrategy(0, inputs);
@ -115,7 +121,9 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape2) {
ASSERT_EQ(dev_matrix_shape, expect);
}
// matmul2
/// Feature: test matmul info
/// Description: infer dev matrix
/// Expectation: the dev matrix is right
TEST_F(TestMatmulInfo, InferDevMatrixShape3) {
Strategys inputs = {{2, 4, 8, 16}, {1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs);
@ -127,7 +135,9 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape3) {
ASSERT_EQ(dev_matrix_shape, expect);
}
// matmul2
/// Feature: test matmul info
/// Description: infer dev matrix
/// Expectation: the dev matrix is right
TEST_F(TestMatmulInfo, InferDevMatrixShape4) {
Strategys inputs = {{2, 4, 8, 8}, {2, 8}};
StrategyPtr strategy = NewStrategy(0, inputs);
@ -139,7 +149,9 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape4) {
ASSERT_EQ(dev_matrix_shape, expect);
}
// matmul3
/// Feature: test matmul info
/// Description: infer dev matrix
/// Expectation: the dev matrix is right
TEST_F(TestMatmulInfo, InferDevMatrixShape5) {
Strategys inputs = {{8, 16}, {2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs);
@ -151,7 +163,9 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape5) {
ASSERT_EQ(dev_matrix_shape, expect);
}
// matmul3
/// Feature: test matmul info
/// Description: infer dev matrix
/// Expectation: the dev matrix is right
TEST_F(TestMatmulInfo, InferDevMatrixShape6) {
Strategys inputs = {{8, 8}, {2, 4, 2, 8}};
StrategyPtr strategy = NewStrategy(0, inputs);
@ -163,6 +177,9 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape6) {
ASSERT_EQ(dev_matrix_shape, expect);
}
/// Feature: test matmul info
/// Description: infer tensor map
/// Expectation: the tensor map is right
TEST_F(TestMatmulInfo, InferTensorMap1) {
Strategys str = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, str);
@ -188,7 +205,9 @@ TEST_F(TestMatmulInfo, InferTensorMap1) {
ASSERT_EQ(output_tensor_map.array(), output_expect);
}
// matmul2
/// Feature: test matmul info
/// Description: infer tensor map
/// Expectation: the tensor map is right
TEST_F(TestMatmulInfo, InferTensorMap2) {
Strategys str = {{2, 4, 8, 16}, {1, 16}};
StrategyPtr strategy = NewStrategy(0, str);
@ -214,7 +233,9 @@ TEST_F(TestMatmulInfo, InferTensorMap2) {
ASSERT_EQ(output_tensor_map.array(), output_expect);
}
// matmul3
/// Feature: test matmul info
/// Description: infer tensor map
/// Expectation: the tensor map is right
TEST_F(TestMatmulInfo, InferTensorMap3) {
Strategys str = {{8, 16}, {2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str);
@ -240,6 +261,9 @@ TEST_F(TestMatmulInfo, InferTensorMap3) {
ASSERT_EQ(output_tensor_map.array(), output_expect);
}
/// Feature: test matmul info
/// Description: infer slice shape
/// Expectation: the slice shape is right
TEST_F(TestMatmulInfo, InferSliceShape1) {
Strategys str = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, str);
@ -265,7 +289,9 @@ TEST_F(TestMatmulInfo, InferSliceShape1) {
ASSERT_EQ(output_slice_shape, output_slice_shape_expect);
}
// matmul2
/// Feature: test matmul info
/// Description: infer slice shape
/// Expectation: the slice shape is right
TEST_F(TestMatmulInfo, InferSliceShape2) {
Strategys str = {{2, 4, 8, 16}, {1, 16}};
StrategyPtr strategy = NewStrategy(0, str);
@ -291,7 +317,9 @@ TEST_F(TestMatmulInfo, InferSliceShape2) {
ASSERT_EQ(output_slice_shape, output_slice_shape_expect);
}
// matmul3
/// Feature: test matmul info
/// Description: infer slice shape
/// Expectation: the slice shape is right
TEST_F(TestMatmulInfo, InferSliceShape3) {
Strategys str = {{8, 16}, {2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str);
@ -317,7 +345,9 @@ TEST_F(TestMatmulInfo, InferSliceShape3) {
ASSERT_EQ(output_slice_shape, output_slice_shape_expect);
}
// matmul3
/// Feature: test matmul info
/// Description: get tensor layout
/// Expectation: the tensor layout is right
TEST_F(TestMatmulInfo, GetTensorLayout3) {
Strategys str = {{8, 16}, {2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str);
@ -343,6 +373,9 @@ TEST_F(TestMatmulInfo, GetTensorLayout3) {
ASSERT_EQ(output_tensor_map.array(), output_expect);
}
/// Feature: test matmul info
/// Description: infer forward op
/// Expectation: the forward op is right
TEST_F(TestMatmulInfo, GetForwardOp1) {
Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs);
@ -369,6 +402,9 @@ TEST_F(TestMatmulInfo, GetForwardOp1) {
ASSERT_EQ(arg1_value_is_string, true);
}
/// Feature: test matmul info
/// Description: infer forward op
/// Expectation: the forward op is right
TEST_F(TestMatmulInfo, GetForwardOp2) {
Strategys inputs = {{2, 4, 8, 1}, {2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs);
@ -379,6 +415,9 @@ TEST_F(TestMatmulInfo, GetForwardOp2) {
ASSERT_EQ(forward_op.size(), 0);
}
/// Feature: test matmul info
/// Description: infer virtual_div op
/// Expectation: the virtual_div op is right
TEST_F(TestMatmulInfo, GetVirtualDivOp1) {
Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs);
@ -398,6 +437,9 @@ TEST_F(TestMatmulInfo, GetVirtualDivOp1) {
ASSERT_EQ(divisor, 16);
}
/// Feature: test matmul info
/// Description: infer mirror op
/// Expectation: the mirror op is right
TEST_F(TestMatmulInfo, GetMirrorOPs1) {
Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs);
@ -417,7 +459,9 @@ TEST_F(TestMatmulInfo, GetMirrorOPs1) {
ASSERT_EQ(arg0_name, "group");
}
// matmul2
/// Feature: test matmul info
/// Description: infer mirror op
/// Expectation: the mirror op is right
TEST_F(TestMatmulInfo, GetMirrorOPs2) {
Strategys inputs = {{2, 4, 1, 16}, {8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs);
@ -437,7 +481,9 @@ TEST_F(TestMatmulInfo, GetMirrorOPs2) {
ASSERT_EQ(arg0_name, "group");
}
// matmul3
/// Feature: test matmul info
/// Description: infer mirror op
/// Expectation: the mirror op is right
TEST_F(TestMatmulInfo, GetMirrorOPs3) {
Strategys inputs = {{8, 16}, {2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs);
@ -456,6 +502,9 @@ TEST_F(TestMatmulInfo, GetMirrorOPs3) {
ASSERT_EQ(arg0_name, "group");
}
/// Feature: test matmul info
/// Description: infer mirror op
/// Expectation: the mirror op is right
TEST_F(TestMatmulInfo, GetMirrorOPs4) {
Strategys inputs = {{2, 4, 1, 16}, {2, 4, 16, 8}};
StrategyPtr strategy = NewStrategy(0, inputs);
@ -466,6 +515,9 @@ TEST_F(TestMatmulInfo, GetMirrorOPs4) {
ASSERT_EQ(mirror_ops.size(), 2);
}
/// Feature: test matmul info
/// Description: init twice
/// Expectation: the mirror op is right
TEST_F(TestMatmulInfo, InitTwice) {
Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs);
@ -487,6 +539,9 @@ TEST_F(TestMatmulInfo, InitTwice) {
ASSERT_EQ(arg0_name, "group");
}
/// Feature: test matmul info
/// Description: check strategy, the strategy is invalid
/// Expectation: return FAILED
TEST_F(TestMatmulInfo, CheckStrategy1) {
// Success: {{2,4,8,16}, {2,4,16,1}}
Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
@ -496,6 +551,9 @@ TEST_F(TestMatmulInfo, CheckStrategy1) {
ASSERT_EQ(ret, FAILED);
}
/// Feature: test matmul info
/// Description: check strategy, the strategy is invalid
/// Expectation: return FAILED
TEST_F(TestMatmulInfo, CheckStrategy2) {
// Success: {{2,4,8,16}, {2,4,16,1}}
Strategys inputs = {{2, 4, 8, 16}, {4, 16, 1}};
@ -505,6 +563,9 @@ TEST_F(TestMatmulInfo, CheckStrategy2) {
ASSERT_EQ(ret, FAILED);
}
/// Feature: test matmul info
/// Description: check strategy, the strategy is invalid
/// Expectation: return FAILED
TEST_F(TestMatmulInfo, CheckStrategy3) {
// Success: {{2,4,8,16}, {2,4,16,1}}
Strategys inputs = {{2, 4, 8, 16}, {2, 4, 8, 1}};
@ -514,6 +575,9 @@ TEST_F(TestMatmulInfo, CheckStrategy3) {
ASSERT_EQ(ret, FAILED);
}
/// Feature: test matmul info
/// Description: check strategy, the strategy is invalid
/// Expectation: return FAILED
TEST_F(TestMatmulInfo, CheckStrategy4) {
// Success: {{2,4,8,16}, {2,4,16,1}}
Strategys inputs = {{2, 4, 8, 16}, {2, 3, 16, 1}};
@ -523,6 +587,9 @@ TEST_F(TestMatmulInfo, CheckStrategy4) {
ASSERT_EQ(ret, FAILED);
}
/// Feature: test matmul info
/// Description: check strategy, the strategy is invalid
/// Expectation: return FAILED
TEST_F(TestMatmulInfo, CheckStrategy5) {
// Success: {{2,4,8,16}, {2,4,16,1}}
Strategys inputs = {{0, 4, 8, 16}, {2, 4, 16, 1}};
@ -532,6 +599,9 @@ TEST_F(TestMatmulInfo, CheckStrategy5) {
ASSERT_EQ(ret, FAILED);
}
/// Feature: test matmul info
/// Description: check strategy, the strategy is invalid
/// Expectation: return FAILED
TEST_F(TestMatmulInfo, CheckStrategy6) {
// Success: {{2,4,8,16}, {2,4,16,1}}
Strategys inputs = {{-1, 4, 8, 16}, {2, 4, 16, 1}};
@ -541,6 +611,9 @@ TEST_F(TestMatmulInfo, CheckStrategy6) {
ASSERT_EQ(ret, FAILED);
}
/// Feature: test matmul info
/// Description: check strategy, the strategy is invalid
/// Expectation: return FAILED
TEST_F(TestMatmulInfo, CheckStrategy7) {
// Success: {{2,4,8,16}, {2,4,16,1}}
Strategys inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}};
@ -550,6 +623,9 @@ TEST_F(TestMatmulInfo, CheckStrategy7) {
ASSERT_EQ(ret, FAILED);
}
/// Feature: test matmul info
/// Description: init, invalid strategy
/// Expectation: return FAILED
TEST_F(TestMatmulInfo, InitFailed) {
// matmul4 attr is wrong
Strategys inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}};
@ -559,6 +635,9 @@ TEST_F(TestMatmulInfo, InitFailed) {
ASSERT_EQ(ret, FAILED);
}
/// Feature: test matmul info
/// Description: generate strategy
/// Expectation: the computation cost is right
TEST_F(TestMatmulInfo, test_GenerateStrategies1) {
// the parameter '0' indicates that the stageId = 0, there are 1024 devices in the stage 0
ASSERT_EQ(matmul1->GenerateStrategies(0), Status::SUCCESS);
@ -574,35 +653,5 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies1) {
break;
}
}
TEST_F(TestMatmulInfo, test_GenerateStrategies2) {
// the parameter '0' indicates that the stageId = 0, there are 1024 devices in the stage 0
ASSERT_EQ(matmul3->GenerateStrategies(0), Status::SUCCESS);
std::vector<std::shared_ptr<StrategyWithCost>> sc = matmul3->GetStrategyCost();
for (const auto& swc : sc) {
StrategyPtr sp = swc->strategy_ptr;
Cost cost = *(swc->cost_list[0]);
matmul3->InitForCostModel(sp, nullptr);
std::vector<TensorInfo> inputs_info = matmul3->inputs_tensor_info();
std::vector<TensorInfo> outputs_info = matmul3->outputs_tensor_info();
std::vector<TensorInfo> replica_inputs_info;
replica_inputs_info.push_back(inputs_info[0]);
// transpose the tensor B
TensorInfo input1_info = inputs_info[1];
Shape input1_shape = input1_info.shape();
Shape input1_slice_shape = input1_info.slice_shape();
TensorLayout tly;
matmul3->SwapLastTwoElements(&input1_shape);
matmul3->SwapLastTwoElements(&input1_slice_shape);
TensorInfo replica_input1_info(tly, input1_shape, input1_slice_shape);
replica_inputs_info.push_back(replica_input1_info);
ASSERT_DOUBLE_EQ(matmul3->operator_cost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
cost.computation_cost_);
break;
}
}
} // namespace parallel
} // namespace mindspore


@ -58,6 +58,11 @@ def compile_net(net, x, y, b, phase):
def test_auto_parallel_arithmetic():
"""
Features: test auto parallel
Description: search strategies
Expectation: generated strategies match expectations
"""
class Net(nn.Cell):
def __init__(self):
super().__init__()
@ -89,6 +94,11 @@ def test_auto_parallel_arithmetic():
def test_auto_parallel_arithmetic_broadcast_both():
"""
Features: test auto parallel
Description: search strategies for broadcast
Expectation: generated strategies match expectations
"""
class Net(nn.Cell):
def __init__(self):
super().__init__()
@ -113,12 +123,17 @@ def test_auto_parallel_arithmetic_broadcast_both():
strategies = _cell_graph_executor._get_shard_strategy(net)
for (k, v) in strategies.items():
if re.search('FloorDiv-op', k) is not None:
assert v == [[8, 1], [1, 1]]
assert v == [[1, 1], [1, 1]]
elif re.search('MatMul-op', k) is not None:
assert v == [[8, 1], [1, 1]]
assert v == [[1, 1], [1, 1]]
def test_auto_parallel_arithmetic_broadcast_right():
"""
Features: test auto parallel
Description: search strategies for right broadcast
Expectation: generated strategies match expectations
"""
class Net(nn.Cell):
def __init__(self):
super().__init__()
@ -150,6 +165,11 @@ def test_auto_parallel_arithmetic_broadcast_right():
def test_auto_parallel_arithmetic_broadcast_left():
"""
Features: test auto parallel
Description: search strategies for left broadcast
Expectation: generated strategies match expectations
"""
class Net(nn.Cell):
def __init__(self):
super().__init__()


@ -39,6 +39,11 @@ class NetWithLoss(nn.Cell):
def test_common_parameter():
"""
Features: test auto parallel
Description: search strategies for cast parameter
Expectation: generated strategies match expectations
"""
class Net(nn.Cell):
def __init__(self):
super().__init__()
@ -72,7 +77,9 @@ def test_common_parameter():
_cell_graph_executor.compile(net, x, y, phase='train')
strategies = _cell_graph_executor._get_shard_strategy(net)
for (k, v) in strategies.items():
if re.search('MatMul-op', k) is not None:
if re.search('MatMul-op0', k) is not None:
assert v == [[4, 1], [1, 2]]
elif re.search('MatMul-op', k) is not None:
assert v == [[8, 1], [1, 1]]
elif re.search('Cast-op', k) is not None:
assert v == [[1, 1]]