forked from mindspore-Ecosystem/mindspore
modify gen strategy for matmul op
This commit is contained in:
parent 41f1ac9573
commit e5b62b2714
@@ -69,7 +69,7 @@ void SetDevMatrixShape(const Dimensions &mat_a_strategy, const Dimensions &mat_b

 Status MatMulBase::GetAttrs() {
   if (attrs_.size() < MATMUL_ATTRS_SIZE) {
-    MS_LOG(ERROR) << name_ << " : The size of attrs small than 2.";
+    MS_LOG(ERROR) << name_ << ": The size of attrs small than 2, got " << attrs_.size();
     return FAILED;
   }
@@ -79,18 +79,23 @@ Status MatMulBase::GetAttrs() {
     if (transpose_a_iter->second->isa<BoolImm>()) {
       transpose_a_ = transpose_a_iter->second->cast<BoolImmPtr>()->value();
     } else {
-      MS_LOG(ERROR) << name_ << " : The value of transpose_a is not bool.";
+      MS_LOG(ERROR) << name_ << ": The value of transpose_a is not bool.";
       return FAILED;
     }
   }

+  if (transpose_a_) {
+    MS_LOG(ERROR) << name_ << ": The transpose_a=true is not be supported";
+    return FAILED;
+  }
+
   auto transpose_b_iter = attrs_.find(TRANSPOSE_B);
   if (transpose_b_iter != attrs_.end()) {
     MS_EXCEPTION_IF_NULL(transpose_b_iter->second);
     if (transpose_b_iter->second->isa<BoolImm>()) {
       transpose_b_ = transpose_b_iter->second->cast<BoolImmPtr>()->value();
     } else {
-      MS_LOG(ERROR) << name_ << " : The value of transpose_b is not bool.";
+      MS_LOG(ERROR) << name_ << ": The value of transpose_b is not bool.";
       return FAILED;
     }
   }
@@ -101,18 +106,22 @@ Status MatMulBase::GetAttrs() {
     if (field_size_iter->second->isa<Int64Imm>()) {
       field_size_ = field_size_iter->second->cast<Int64ImmPtr>()->value();
     } else {
-      MS_LOG(ERROR) << name_ << " : The value of field_size is not int64_t.";
+      MS_LOG(ERROR) << name_ << ": The value of field_size is not int64_t.";
       return FAILED;
     }
   }

   // infer inputs dimension size
   if ((inputs_shape_.size() != MATMUL_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_OUTPUTS_SIZE)) {
-    MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
+    MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong.";
     return FAILED;
   }
   mat_a_dimension_ = inputs_shape_.at(0).size();
   mat_b_dimension_ = inputs_shape_.at(1).size();
+  if (mat_a_dimension_ < 2 || mat_b_dimension_ < 2) {
+    MS_LOG(ERROR) << name_ << ": The dim of mat_a or mat_b can not smaller than 2, but the dim of mat_a is "
+                  << mat_a_dimension_ << ", the dim of mat_b is " << mat_b_dimension_;
+  }

   return SUCCESS;
 }
@@ -150,7 +159,7 @@ Status MatMul::CheckStrategy(const StrategyPtr &strategy) {
   size_t mat_a_size = mat_a_strategy.size();
   size_t mat_b_size = mat_b_strategy.size();
   if ((mat_a_size != mat_a_dimension_) || (mat_b_size != mat_b_dimension_)) {
-    MS_LOG(ERROR) << name_ << " : The dimensions of mat_a or mat_b's strategy is wrong.";
+    MS_LOG(ERROR) << name_ << ": The dimensions of mat_a or mat_b's strategy is wrong.";
     return FAILED;
   }
@@ -158,12 +167,12 @@ Status MatMul::CheckStrategy(const StrategyPtr &strategy) {
   // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false)
   // [16] in the example above
   if (!transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.at(SECOND_FROM_END(mat_b_size)))) {
-    MS_LOG(ERROR) << name_ << " : Can not do this operator in the strategy: " << StrategyToString(stra)
+    MS_LOG(ERROR) << name_ << ": Can not do this operator in the strategy: " << StrategyToString(stra)
                   << ", the transpose_b is false, the shard num of first input's column is " << mat_a_strategy.back()
                   << ", but the shard num of second input's row is " << mat_b_strategy.at(SECOND_FROM_END(mat_b_size));
     return FAILED;
   } else if (transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.back())) {
-    MS_LOG(ERROR) << name_ << " : Can not do this operator in the strategy: " << StrategyToString(stra)
+    MS_LOG(ERROR) << name_ << ": Can not do this operator in the strategy: " << StrategyToString(stra)
                   << ", the transpose_b is true, the shard num of first input's column is " << mat_a_strategy.back()
                   << ", but the shard num of second input's column is " << mat_b_strategy.back();
     return FAILED;
@@ -171,12 +180,12 @@ Status MatMul::CheckStrategy(const StrategyPtr &strategy) {

   if (mat_a_size >= mat_b_size) {
     if (CheckRelevantDimension(mat_a_strategy, mat_b_strategy) != SUCCESS) {
-      MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
+      MS_LOG(ERROR) << name_ << ": Strategies of relevant dimensions are not equal.";
       return FAILED;
     }
   } else {
     if (CheckRelevantDimension(mat_b_strategy, mat_a_strategy) != SUCCESS) {
-      MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
+      MS_LOG(ERROR) << name_ << ": Strategies of relevant dimensions are not equal.";
       return FAILED;
     }
   }
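The rule these checks enforce is worth stating on its own: the axis that gets reduced must be sharded identically on both operands, otherwise the partial products on each device do not line up. A minimal standalone sketch of the check, with a hypothetical helper name that is not part of the patch:

#include <cstdint>
#include <vector>

// Hypothetical standalone version of the rule CheckStrategy enforces: the
// shard count of mat_a's reduction axis (its last dimension; transpose_a is
// rejected earlier) must match the shard count of mat_b's reduction axis
// (second-to-last dimension, or last dimension when transpose_b is true).
bool ReductionAxisShardsMatch(const std::vector<int64_t> &mat_a_strategy,
                              const std::vector<int64_t> &mat_b_strategy,
                              bool transpose_b) {
  int64_t a_shard = mat_a_strategy.back();
  int64_t b_shard = transpose_b ? mat_b_strategy.back()
                                : mat_b_strategy[mat_b_strategy.size() - 2];
  return a_shard == b_shard;
}

// e.g. strategy {2, 4, 8, 16} x {2, 4, 16, 1} passes (16 == 16),
// while {2, 4, 8, 16} x {2, 4, 8, 1} is rejected (16 != 8).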
@@ -196,7 +205,7 @@ Status MatMul::CheckOutputStrategy(const StrategyPtr &out_strategy) {
   }

   if (CheckStrategyValue(out_strategy, outputs_shape_) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << " : Invalid output strategy.";
+    MS_LOG(ERROR) << name_ << ": Invalid output strategy.";
     return FAILED;
   }
@@ -255,16 +264,16 @@ Status MatMulBase::InferForwardCommunication() {
   // Relevant dimension is not split and all reduce is not required,
   // need to use origin_dev_matrix_shape_ here, since the dev_matrix_shape_ will be changed if repeated calculation.
   if (origin_dev_matrix_shape_.at(relevant_dimension_index) == MIN_SLICE_NUM) {
-    MS_LOG(INFO) << name_ << " : Forward all reduce is not required.";
+    MS_LOG(INFO) << name_ << ": Forward all reduce is not required.";
     return SUCCESS;
   }

   std::vector<Group> group_list;
   if (CreateGroupByDim(relevant_dimension_index, &group_list) != SUCCESS) {
-    ReportError(name_ + " : Infer forward communication, create group failed.");
+    ReportError(name_ + ": Infer forward communication, create group failed.");
     return FAILED;
   } else if (group_list.empty()) {
-    MS_LOG(INFO) << name_ << " : Forward all reduce is not required.";
+    MS_LOG(INFO) << name_ << ": Forward all reduce is not required.";
     return SUCCESS;
   }
@@ -276,7 +285,7 @@ Status MatMulBase::InferForwardCommunication() {
   }

   forward_op_.push_back(op);
-  MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name();
+  MS_LOG(INFO) << name_ << ": The group name of forward communication is " << group_list[0].name();
   return SUCCESS;
 }
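When the relevant (reduction) dimension is split, each device holds only a slice of K, so its local matmul produces a partial sum; the AllReduce pushed into forward_op_ is what sums those partials back into the full result. A toy illustration in plain C++, deliberately using no MindSpore API:

#include <cassert>
#include <cstddef>
#include <vector>

// Splitting the reduction dimension K across two "devices" leaves each with a
// partial dot product; summing the partials (what the inserted AllReduce does
// over the created group) recovers the full value.
int main() {
  std::vector<int> a = {1, 2, 3, 4};  // one row of mat_a
  std::vector<int> b = {5, 6, 7, 8};  // one column of mat_b
  int dev0 = a[0] * b[0] + a[1] * b[1];  // device 0 owns K slice [0, 2)
  int dev1 = a[2] * b[2] + a[3] * b[3];  // device 1 owns K slice [2, 4)
  int full = 0;
  for (std::size_t k = 0; k < a.size(); ++k) full += a[k] * b[k];
  assert(dev0 + dev1 == full);  // AllReduce(sum) restores the unsharded result
  return 0;
}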
@@ -392,7 +401,7 @@ Status MatMulBase::InferTensorInfo() {

 Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape *const input) {
   if (input->size() < 2) {
-    MS_LOG(ERROR) << name_ << " : The size of inputs small than 2.";
+    MS_LOG(ERROR) << name_ << ": The size of inputs small than 2.";
     return FAILED;
   }
   auto last_1st_value = input->at(input->size() - 1);
@@ -404,270 +413,76 @@ Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape *const input)
   return SUCCESS;
 }

-Status MatMulBase::GenerateStrategiesBase(int64_t stage_id, size_t dev_num, const Shape &input0_shape,
-                                          Shape input1_shape, std::vector<StrategyPtr> *const sp_vector) {
-  // The shape of input0 (input1)
-  // E.g., input0 = [100, 200, 300], input1 = [300, 400]
-
-  // Combining the input0_shape and input1_shape
-  // E.g., combined_shape = [100, 200, 300, 400]
-  size_t input1_shape_size = input1_shape.size(), input0_shape_size = input0_shape.size();
-  Dimensions combined_partitions;
-  Shape combined_shape;
-  // In SwapLastTwoElements(), it is guaranteed that input0_shape.size() and input1_shape.size() are both larger than 2
-  if (input0_shape.size() >= input1_shape.size()) {
-    combined_shape = input0_shape;
-    combined_shape.push_back(input1_shape[input1_shape.size() - 1]);
-  } else {
-    combined_shape = input1_shape;
-    combined_shape.push_back(input0_shape[input0_shape.size() - 2]);
-  }
-  std::function<void(uint64_t, size_t)> recursive = [&stage_id, &dev_num, &sp_vector, &combined_partitions,
-                                                     &combined_shape, &input1_shape_size, &recursive,
-                                                     &input0_shape_size, this](uint64_t current_index, size_t n) {
-    // Finishing the recursive steps, if the strategy is valid, then calculate the cost
-    // for this operator under the strategy.
-    if (current_index == combined_shape.size()) {
-      StrategyPtr sp;
-      if (this->PrepareStrategy(stage_id, dev_num, combined_partitions, input0_shape_size, input1_shape_size, &sp) ==
-          FAILED) {
-        return;
-      }
-      sp_vector->push_back(sp);
-    } else {
-      MS_LOG(DEBUG) << name_ << " : The value input0_shape_size: " << input0_shape_size
-                    << ", input1_shape_size: " << input1_shape_size;
-      for (uint64_t i = 1; i <= n; i *= 2) {
-        if (n % i == 0 && LongToSize(combined_shape[current_index]) % i == 0) {
-          combined_partitions.push_back(i);
-          recursive(current_index + 1, n / i);
-          combined_partitions.pop_back();
-        }
-      }
-    }
-  };
-  recursive(0, dev_num);
-  if (sp_vector->empty()) {
-    MS_LOG(ERROR) << name_ << " : No available strategy.";
-    return FAILED;
-  }
-  return Status::SUCCESS;
-}
-
-Status MatMulBase::GenerateStrategiesNotPower2(int64_t stage_id, size_t dev_num_not_2_power,
-                                               const std::vector<StrategyPtr> &sp_vector_2_power_part) {
-  std::vector<StrategyPtr> sp_vector;
-  size_t related_dim_left = transpose_a_ ? inputs_shape_[0].size() - 2 : inputs_shape_[0].size() - 1;
-  size_t related_dim_right = transpose_b_ ? inputs_shape_[1].size() - 1 : inputs_shape_[1].size() - 2;
-  // Handle the not power of 2 part.
-  for (auto &stra : sp_vector_2_power_part) {
-    auto stra_arrays = stra->GetInputDim();
-    if (stra_arrays.size() != 2) {
-      MS_LOG(ERROR) << "The generated strategy of matmul dose not match two input, the strategy is: " << stra_arrays;
-    }
-    for (size_t i = 0; i < 2; ++i) {
-      size_t stra_size = stra_arrays[i].size();
-      for (size_t j = 0; j < stra_size; ++j) {
-        if (i == 1 && j == related_dim_right) {
-          continue;
-        }
-        auto new_stra_arrays{stra_arrays};
-        new_stra_arrays[i][j] = new_stra_arrays[i][j] * SizeToLong(dev_num_not_2_power);
-        if (i == 0 && j == related_dim_left) {
-          new_stra_arrays[1][related_dim_right] =
-            new_stra_arrays[1][related_dim_right] * SizeToLong(dev_num_not_2_power);
-        }
-        StrategyPtr new_stra = std::make_shared<Strategy>(stage_id, new_stra_arrays);
-        sp_vector.push_back(new_stra);
-      }
-    }
-  }
-  strategy_cost_.clear();
-  // add the repeated strategy
-  auto repeated_stra_arrays{inputs_shape_};
-  for (auto &stra_array : repeated_stra_arrays) {
-    std::fill(stra_array.begin(), stra_array.end(), 1);
-  }
-  StrategyPtr repeated_stra = std::make_shared<Strategy>(stage_id, repeated_stra_arrays);
-  sp_vector.push_back(repeated_stra);
-  for (auto &sp : sp_vector) {
-    if (SetCostUnderStrategy(sp) == FAILED) {
-      MS_LOG(WARNING) << name_ << " : Calculating cost for strategy failed.";
-      continue;
-    }
-  }
-  if (strategy_cost_.empty()) {
-    MS_LOG(EXCEPTION) << name_ << " : No available strategy.";
-  }
-  return SUCCESS;
-}
-
-Status MatMulBase::GenerateStrategies(int64_t stage_id) {
-  if (GetAttrs() != SUCCESS) {
-    MS_LOG(ERROR) << name_ << " : GetAttrs failed.";
-    return FAILED;
-  }
-  CheckGlobalDeviceManager();
-  RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
-  size_t dev_num = dev_list.size();
-  Shape input0_shape = inputs_shape_[0], input1_shape = inputs_shape_[1];
-  if (transpose_a_) {
-    if (SwapLastTwoElements(&input0_shape) == FAILED) {
-      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
-    }
-  }
-  if (transpose_b_) {
-    if (SwapLastTwoElements(&input1_shape) == FAILED) {
-      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
-    }
-  }
-  auto dev_num_2_power = (dev_num & (dev_num - 1));
-  std::vector<StrategyPtr> sp_vector_2_power_part;
-  if (dev_num_2_power == 0) {
-    if (GenerateStrategiesBase(stage_id, dev_num, input0_shape, input1_shape, &sp_vector_2_power_part) != SUCCESS) {
-      MS_LOG(ERROR) << "No available strategy.";
-      return FAILED;
-    }
-    strategy_cost_.clear();
-    for (auto &sp : sp_vector_2_power_part) {
-      if (SetCostUnderStrategy(sp) == FAILED) {
-        MS_LOG(WARNING) << name_ << " : Calculating cost for strategy failed.";
-        continue;
-      }
-    }
-    if (strategy_cost_.empty()) {
-      MS_LOG(EXCEPTION) << name_ << " : No available strategy.";
-    }
-    return SUCCESS;
-  }
-  auto dev_num_not_2_power = dev_num / (dev_num - dev_num_2_power);
-  if (GenerateStrategiesBase(stage_id, dev_num - dev_num_2_power, input0_shape, input1_shape,
-                             &sp_vector_2_power_part) != SUCCESS) {
-    MS_LOG(ERROR) << "Generating strategy in power of 2 devices failed.";
-    return FAILED;
-  }
-  return GenerateStrategiesNotPower2(stage_id, dev_num_not_2_power, sp_vector_2_power_part);
-}
-
-std::vector<StrategyPtr> MatMulBase::GenerateOpStrategies(int64_t) {
-  std::vector<StrategyPtr> sp_vector;
-  return sp_vector;
-}
-
-Status MatMulBase::PrepareStrategy(int64_t stage_id, size_t dev_num,
-                                   mindspore::parallel::Dimensions combined_partitions, size_t input0_shape_size,
-                                   size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) {
-  int64_t product =
-    std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies<int64_t>());
-  const auto fully_use_device = CostModelContext::GetInstance()->fully_use_device();
-  if (!fully_use_device) {
-    if (LongToSize(product) > dev_num) {
-      return FAILED;
-    }
-  } else {
-    if (LongToSize(product) != dev_num) {
-      return FAILED;
-    }
-  }
-  Dimensions input0_partitions, input1_partitions;
-  if (input0_shape_size >= input1_shape_size) {
-    for (size_t i = 0; i < input0_shape_size; ++i) {
-      input0_partitions.push_back(combined_partitions[i]);
-    }
-    if (input1_shape_size == 2) {
-      input1_partitions.push_back(combined_partitions[combined_partitions.size() - 2]);
-      input1_partitions.push_back(combined_partitions[combined_partitions.size() - 1]);
-    } else {
-      // input1_shape.size() > 2
-      for (size_t j = combined_partitions.size() - input1_shape_size - 1; j < combined_partitions.size(); ++j) {
-        if (j == combined_partitions.size() - 3) {
-          continue;
-        }
-        input1_partitions.push_back(combined_partitions[j]);
-      }
-    }
-  } else {
-    for (size_t i = 0; i < input1_shape_size; ++i) {
-      input1_partitions.push_back(combined_partitions[i]);
-    }
-    for (size_t j = combined_partitions.size() - input0_shape_size - 1; j < combined_partitions.size() - 3; ++j) {
-      input0_partitions.push_back(combined_partitions[j]);
-    }
-    input0_partitions.push_back(combined_partitions[combined_partitions.size() - 1]);
-    input0_partitions.push_back(combined_partitions[combined_partitions.size() - 3]);
-  }
-  if (transpose_a_) {
-    if (SwapLastTwoElements(&input0_partitions) == FAILED) {
-      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
-    }
-  }
-  if (transpose_b_) {
-    if (SwapLastTwoElements(&input1_partitions) == FAILED) {
-      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
-    }
-  }
-  Strategys stras;
-  stras.push_back(input0_partitions);
-  stras.push_back(input1_partitions);
-  (*sp) = std::make_shared<Strategy>(stage_id, stras);
-
-  return SUCCESS;
-}
-
-void MatMulBase::InitTensorInfoForCost(std::vector<TensorInfo> *relica_inputs_tensor_vector) {
-  TensorLayout tly;
-  if (transpose_a_) {
-    Shape replica_input0_shape(inputs_tensor_info_[0].shape());
-    Shape replica_input0_slice_shape(inputs_tensor_info_[0].slice_shape());
-    if (SwapLastTwoElements(&replica_input0_shape) == FAILED) {
-      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
-    }
-    if (SwapLastTwoElements(&replica_input0_slice_shape) == FAILED) {
-      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
-    }
-
-    TensorInfo replica_input0_info(tly, replica_input0_shape, replica_input0_slice_shape);
-    relica_inputs_tensor_vector->push_back(replica_input0_info);
-  } else {
-    relica_inputs_tensor_vector->push_back(inputs_tensor_info_[0]);
-  }
-  if (transpose_b_) {
-    Shape replica_input1_shape(inputs_tensor_info_[1].shape());
-    Shape replica_input1_slice_shape(inputs_tensor_info_[1].slice_shape());
-    if (SwapLastTwoElements(&replica_input1_shape) == FAILED) {
-      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
-    }
-    if (SwapLastTwoElements(&replica_input1_slice_shape) == FAILED) {
-      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
-    }
-
-    TensorInfo replica_input1_info(tly, replica_input1_shape, replica_input1_slice_shape);
-    relica_inputs_tensor_vector->push_back(replica_input1_info);
-  } else {
-    relica_inputs_tensor_vector->push_back(inputs_tensor_info_[1]);
-  }
-}
-
-Status MatMulBase::CheckForTensorSliceValid() const {
-  const auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable();
-  const auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size();
-  if (!align_enable) {
-    return SUCCESS;
-  }
-  if (inputs_tensor_info_.empty()) {
-    return FAILED;
-  }
-  for (auto &one_input_tensor : inputs_tensor_info_) {
-    auto slice_shape = one_input_tensor.slice_shape();
-    if ((LongToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % align_size != 0) ||
-        (LongToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % align_size != 0)) {
-      return FAILED;
-    }
-  }
-  return SUCCESS;
-}
+std::vector<StrategyPtr> MatMulBase::GenerateOpStrategies(int64_t stage_id) {
+  Shape mat_a_shape = inputs_shape_[0];
+  Shape mat_b_shape = inputs_shape_[1];
+
+  // it is not support transpose_a
+  if (transpose_a_) {
+    MS_LOG(EXCEPTION) << name_ << ": It's not yet supported transpose_a";
+  }
+
+  // it is not support [B, C, D] * [A, B, D, E]
+  if (mat_b_shape.size() > mat_a_shape.size()) {
+    MS_LOG(EXCEPTION) << name_
+                      << ": It's not yet supported that the dim of mat_b larger than the dim of mat_a, but the dim of"
+                         " mat_a is "
+                      << mat_a_shape.size() << ", the dim of mat_b is " << mat_b_shape.size();
+  }
+
+  // it is not support that broadcasts containing 1, such as [A, B, C, D] * [A, 1, D, E]
+  size_t diff_len = mat_a_shape.size() - mat_b_shape.size();
+  for (size_t i = 0; i < mat_b_shape.size() - 2; ++i) {
+    if (mat_b_shape[i] != mat_a_shape[i + diff_len]) {
+      MS_LOG(EXCEPTION) << name_ << ": It's not yet supported that broadcasts containing 1, but the shape of mat a is "
+                        << mat_a_shape << ", the shape of mat_b is " << mat_b_shape;
+    }
+  }
+
+  // e.g. mat_a: [A, B, C, D], mat_b: [B, D, E], then to generate the strategy for [A, B, C, D, E]
+  Shape splittable_flag(mat_a_shape.size() + 1, 1);
+  Shapes splittable_input = {splittable_flag};
+  Shape tmp_shape = inputs_shape_[0];
+  size_t index = 0;
+  if (transpose_b_) {
+    index = inputs_shape_[1].size() - 2;
+    tmp_shape.push_back(inputs_shape_[1][index]);  // mat_a: [A, B, C, D], mat_b: [B, E, D], tmp_shape: [A, B, C, D, E]
+  } else {
+    index = inputs_shape_[1].size() - 1;
+    tmp_shape.push_back(inputs_shape_[1][index]);  // mat_a: [A, B, C, D], mat_b: [B, D, E], tmp_shape: [A, B, C, D, E]
+  }
+  Shapes tmp_inputs_shape = {tmp_shape};
+
+  std::vector<StrategyPtr> sp_vector;
+  if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
+    MS_LOG(EXCEPTION) << name_ << ": Generate strategies failed";
+  }
+
+  // set the inputs' strategies
+  for (auto &sp : sp_vector) {
+    if ((sp == nullptr) || sp->GetInputDim().empty()) {
+      MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
+    }
+    Strategys replace_strategy;
+    Dimensions tmp_strategy = sp->GetInputDim()[0];
+    Dimensions mat_a_strategy = tmp_strategy;
+    mat_a_strategy.pop_back();
+
+    // mat_b_shape: [B, D, E], tmp_strategy: [A, B, C, D, E]
+    // mat_b_strategy: init [A, B, C, D, E]
+    Dimensions mat_b_strategy = tmp_strategy;
+    // mat_b_strategy: delete C, [A, B, D, E]
+    (void)mat_b_strategy.erase(mat_b_strategy.end() - 3);
+    // mat_b_strategy: delete A, [B, D, E]
+    (void)mat_b_strategy.erase(mat_b_strategy.begin(), mat_b_strategy.begin() + static_cast<different_type>(diff_len));
+    // handle transpose_b
+    if (transpose_b_) {
+      (void)SwapLastTwoElements(&mat_b_strategy);
+    }
+    replace_strategy.push_back(mat_a_strategy);
+    replace_strategy.push_back(mat_b_strategy);
+    sp->ResetInputs(replace_strategy);
+  }
+  return sp_vector;
+}

 std::shared_ptr<Strategys> BatchMatMulInfo::GenerateBatchStrategies() {
   Dimensions batch_strategy(inputs_shape_[1].size() - 1, 1);
   (void)batch_strategy.insert(batch_strategy.begin(), stage_device_size_);
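The new generation scheme is easiest to see on concrete numbers: one strategy is enumerated for the combined shape [A, B, C, D, E], then cut back into the two input strategies. A standalone rehearsal of that splitting step, assuming mat_a = [A, B, C, D] and mat_b = [B, D, E], with std::ptrdiff_t standing in for the code's different_type:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Dimensions = std::vector<int64_t>;

int main() {
  Dimensions combined = {2, 4, 8, 16, 1};  // strategy over [A, B, C, D, E]
  std::size_t diff_len = 1;                // mat_a has one extra leading dim (A)

  Dimensions mat_a = combined;
  mat_a.pop_back();                        // drop E -> [A, B, C, D]

  Dimensions mat_b = combined;
  mat_b.erase(mat_b.end() - 3);            // delete C -> [A, B, D, E]
  mat_b.erase(mat_b.begin(),
              mat_b.begin() + static_cast<std::ptrdiff_t>(diff_len));  // delete A -> [B, D, E]

  assert((mat_a == Dimensions{2, 4, 8, 16}));
  assert((mat_b == Dimensions{4, 16, 1}));
  return 0;
}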
@@ -675,50 +490,6 @@ std::shared_ptr<Strategys> BatchMatMulInfo::GenerateBatchStrategies() {
   return std::make_shared<Strategys>(strategy_v);
 }

-Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
-  if (InitForCostModel(strategy, nullptr) == FAILED) {
-    MS_LOG(INFO) << name_ << " : Initialization under the strategy failed.";
-    return FAILED;
-  }
-  PrintStrategy(strategy);
-  // Check whether the tensor slice of input_tensor_info is valid or not
-  if (CheckForTensorSliceValid() != SUCCESS) {
-    MS_LOG(INFO) << name_ << " : The tensor slice is not valid under this strategy.";
-    return FAILED;
-  }
-  // Here, a replicated inputs_ is constructed for the transposed TensorInfo.
-  std::vector<TensorInfo> relica_inputs_tensor_vector;
-  InitTensorInfoForCost(&relica_inputs_tensor_vector);
-
-  int64_t stage_id = strategy->GetInputStage();
-  // Here, we use the origin outputs_, because we only use the slice size of the output tensor.
-  // It does not matter whether the output tensor is transposed or not.
-  double computation_cost =
-    operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
-  double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
-  const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
-  std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
-  result->communication_without_parameter_ =
-    operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
-  result->communication_with_partial_para_ =
-    result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_);
-
-  // Breaking ties for preferring data parallelization
-  BreakingTiesForPerferringDataParallel(strategy, result);
-  MS_LOG(DEBUG) << name_ << " : computation_cost: " << result->computation_cost_
-                << ", communication_cost: " << result->communication_cost_
-                << ", communication_without_parameter_: " << result->communication_without_parameter_
-                << ", communication_with_partial_para_: " << result->communication_with_partial_para_;
-  // refine communication cost calculation for practice
-  RefineForPracticalCost(result, false);
-  result->communication_forward_ = result->communication_without_parameter_;
-
-  std::shared_ptr<StrategyWithCost> swc =
-    std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
-  swc->cost_list.push_back(result);
-  (void)strategy_cost_.emplace_back(swc);
-
-  return SUCCESS;
-}
+Status MatMulBase::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
 } // namespace parallel
 } // namespace mindspore
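The deleted per-operator cost code boils down to one blend: communication that does not involve parameters is always paid, while the parameter-related share is discounted by the cost-model factor gamma, i.e. communication_with_partial_para = comm_without_param + gamma * (comm_total - comm_without_param); SetCostUnderStrategyBase now performs the equivalent bookkeeping generically. A small numeric check of that formula, with made-up values:

#include <iostream>

int main() {
  double communication_cost = 10.0;              // total communication (hypothetical)
  double communication_without_parameter = 6.0;  // activation traffic only (hypothetical)
  double gamma = 0.5;                            // costmodel_gamma, hypothetical value
  double with_partial_para =
      communication_without_parameter +
      gamma * (communication_cost - communication_without_parameter);
  std::cout << with_partial_para << "\n";        // 6 + 0.5 * 4 = 8
  return 0;
}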
@@ -38,12 +38,8 @@ class MatMulBase : public OperatorInfo {
   ~MatMulBase() override = default;

   // Generate all strategies and the corresponding cost for this MatMul operator
-  Status GenerateStrategies(int64_t stage_id) override;
+  std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
-  Status PrepareStrategy(int64_t stage_id, size_t dev_num, Dimensions combined_partitions, size_t input0_shape_size,
-                         size_t input1_shape_size, StrategyPtr *sp);

   Status SwapLastTwoElements(Shape *shape);

  protected:
@@ -52,12 +48,6 @@ class MatMulBase : public OperatorInfo {
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
   Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout);
-  void InitTensorInfoForCost(std::vector<TensorInfo> *);
-  Status GenerateStrategiesBase(int64_t stage_id, size_t dev_num, const Shape &input0_shape, Shape input1_shape,
-                                std::vector<StrategyPtr> *const sp_vector);
-  Status GenerateStrategiesNotPower2(int64_t stage_id, size_t dev_num_not_2_power,
-                                     const std::vector<StrategyPtr> &sp_vector_2_power_part);
-  Status CheckForTensorSliceValid() const;
   Status GetAttrs() override;

   bool transpose_a_ = false;
@@ -90,7 +90,7 @@ class OperatorInfo {

   // Given the stage_id (which indicates the number of devices),
   // generate all strategies for this operator
-  virtual Status GenerateStrategies(int64_t stage_id);
+  Status GenerateStrategies(int64_t stage_id);
   virtual std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) = 0;
   const OperatorCostPtr &operator_cost() const { return operator_cost_; }
   void set_cost(const OperatorCostPtr &cost) { operator_cost_ = cost; }
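This header change is the core of the refactor: GenerateStrategies stops being a per-operator virtual and becomes a non-virtual driver in the base class, while each operator only overrides the pure-virtual GenerateOpStrategies to supply its candidate strategies. A condensed sketch of the resulting call shape, with names simplified and Status reduced to int, so it is not the actual OperatorInfo interface:

#include <cstdint>
#include <memory>
#include <vector>

struct Strategy {};
using StrategyPtr = std::shared_ptr<Strategy>;

class OperatorInfoSketch {
 public:
  virtual ~OperatorInfoSketch() = default;

  // Non-virtual driver: enumerate the candidates, then cost each one.
  int GenerateStrategies(int64_t stage_id) {
    for (auto &sp : GenerateOpStrategies(stage_id)) {
      (void)SetCostUnderStrategy(sp);  // a failure just skips that candidate
    }
    return 0;
  }

  // Each concrete operator only supplies its candidate strategies.
  virtual std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) = 0;
  virtual int SetCostUnderStrategy(const StrategyPtr &strategy) = 0;
};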
@@ -420,15 +420,9 @@ void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &stra
   strategy_cost_.emplace_back(swc);
 }

-Status ReshapeInfo::GenerateStrategies(int64_t stage_id) {
-  if (GetAttrs() != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": GetAttrs failed.";
-    return FAILED;
-  }
-  if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) {
-    MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", "
-                  << outputs_shape_.size();
-    return FAILED;
+std::vector<StrategyPtr> ReshapeInfo::GenerateOpStrategies(int64_t stage_id) {
+  if (inputs_shape_.empty()) {
+    MS_LOG(EXCEPTION) << name_ << ": Inputs shape size or is empty";
   }
   Shape input0_split;
   (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size(), 1);
@@ -436,15 +430,10 @@ Status ReshapeInfo::GenerateStrategies(int64_t stage_id) {
   // strategy used only in the input node is parameter,
   // in other case, use the input node's output_layout as input_layout.
   if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector_) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
-    return FAILED;
+    MS_LOG(EXCEPTION) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
   }
-  return SUCCESS;
-}
-
-std::vector<StrategyPtr> ReshapeInfo::GenerateOpStrategies(int64_t) {
-  std::vector<StrategyPtr> sp_vector;
-  return sp_vector;
+  return sp_vector_;
 }

 Status ReshapeInfo::GenerateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
@@ -61,8 +61,7 @@ class ReshapeInfo : public OperatorInfo {
   Status GenerateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
                                const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs, int64_t out_index,
                                int64_t in_index, bool is_prev_param, bool is_next_reshape);
-  Status GenerateStrategies(int64_t stage_id) override;
-  std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
+  std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
   std::string pre_operator_name() const { return pre_operator_name_; }
   std::string next_operator_name() const { return next_operator_name_; }
@@ -159,7 +159,7 @@ Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
   return SetCostUnderStrategyBase(strategy);
 }

-Status VirtualDatasetInfo::GenerateStrategies(int64_t stage_id) {
+std::vector<StrategyPtr> VirtualDatasetInfo::GenerateOpStrategies(int64_t stage_id) {
   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
   StrategyPtr sp;
   Strategys strategy;
@@ -183,18 +183,8 @@ Status VirtualDatasetInfo::GenerateStrategies(int64_t stage_id) {
     }
   }
   sp = std::make_shared<Strategy>(stage_id, strategy);
-  if (SetCostUnderStrategy(sp) == SUCCESS) {
-    MS_LOG(INFO) << name_ << ": Successfully dataset strategy.";
-    PrintStrategy(sp);
-  } else {
-    MS_LOG(ERROR) << name_ << ": Generating dataset strategy failed.";
-    return FAILED;
-  }
-  return SUCCESS;
-}
-
-std::vector<StrategyPtr> VirtualDatasetInfo::GenerateOpStrategies(int64_t) {
-  std::vector<StrategyPtr> sp_vector;
-  return sp_vector;
+  std::vector<StrategyPtr> sp_vector;
+  sp_vector.push_back(sp);
+  return sp_vector;
 }

@@ -37,8 +37,7 @@ class VirtualDatasetInfo : public OperatorInfo {
   Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override;
   Status InitForCostModel(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override;

-  Status GenerateStrategies(int64_t stage_id) override;
-  std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
+  std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
   void ReComputeBatchSplitFlagList() override;
@@ -51,7 +51,7 @@ Status VirtualOutputInfo::CheckStrategy(const StrategyPtr &strategy) {
   return SUCCESS;
 }

-Status VirtualOutputInfo::GenerateStrategies(int64_t stage_id) {
+std::vector<StrategyPtr> VirtualOutputInfo::GenerateOpStrategies(int64_t stage_id) {
   StrategyPtr sp;
   Strategys strategy;
   bool full_batch = ParallelContext::GetInstance()->full_batch();
@@ -63,8 +63,7 @@ Status VirtualOutputInfo::GenerateStrategies(int64_t stage_id) {
   }

   if (total_dev_num == 0) {
-    MS_LOG(ERROR) << name_ << ": The total devices num is 0";
-    return FAILED;
+    MS_LOG(EXCEPTION) << name_ << ": The total devices num is 0";
   }

   for (auto &shape : inputs_shape_) {
@@ -80,14 +79,9 @@ Status VirtualOutputInfo::GenerateStrategies(int64_t stage_id) {
     strategy.push_back(temp);
   }
   sp = std::make_shared<Strategy>(stage_id, strategy);
-  if (SetCostUnderStrategy(sp) == SUCCESS) {
-    MS_LOG(INFO) << name_ << ": Successfully dataset strategy.";
-    PrintStrategy(sp);
-  } else {
-    MS_LOG(ERROR) << name_ << ": Generating dataset strategy failed.";
-    return FAILED;
-  }
-  return SUCCESS;
+  std::vector<StrategyPtr> sp_vector;
+  sp_vector.push_back(sp);
+  return sp_vector;
 }
 } // namespace parallel
 } // namespace mindspore
@@ -35,7 +35,7 @@ class VirtualOutputInfo : public VirtualDatasetInfo {
                    const PrimitiveAttrs &attrs)
       : VirtualDatasetInfo(name, inputs_shape, outputs_shape, attrs) {}
   ~VirtualOutputInfo() override = default;
-  Status GenerateStrategies(int64_t stage_id) override;
+  std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;

  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
@@ -93,6 +93,9 @@ void TestMatmulInfo::SetUp() {
   matmul4 = std::make_shared<MatMulInfo>("matmul_info", inputs_shape_3, outputs_shape_3, attr_4);
 }

+/// Feature: test matmul info
+/// Description: infer dev matrix
+/// Expectation: the dev matrix is right
 TEST_F(TestMatmulInfo, InferDevMatrixShape1) {
   Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

@@ -104,6 +107,9 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape1) {
   ASSERT_EQ(dev_matrix_shape, expect);
 }

+/// Feature: test matmul info
+/// Description: infer dev matrix
+/// Expectation: the dev matrix is right
 TEST_F(TestMatmulInfo, InferDevMatrixShape2) {
   Strategys inputs = {{2, 4, 8, 8}, {2, 4, 8, 2}};
   StrategyPtr strategy = NewStrategy(0, inputs);

@@ -115,7 +121,9 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape2) {
   ASSERT_EQ(dev_matrix_shape, expect);
 }

-// matmul2
+/// Feature: test matmul info
+/// Description: infer dev matrix
+/// Expectation: the dev matrix is right
 TEST_F(TestMatmulInfo, InferDevMatrixShape3) {
   Strategys inputs = {{2, 4, 8, 16}, {1, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);

@@ -127,7 +135,9 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape3) {
   ASSERT_EQ(dev_matrix_shape, expect);
 }

-// matmul2
+/// Feature: test matmul info
+/// Description: infer dev matrix
+/// Expectation: the dev matrix is right
 TEST_F(TestMatmulInfo, InferDevMatrixShape4) {
   Strategys inputs = {{2, 4, 8, 8}, {2, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

@@ -139,7 +149,9 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape4) {
   ASSERT_EQ(dev_matrix_shape, expect);
 }

-// matmul3
+/// Feature: test matmul info
+/// Description: infer dev matrix
+/// Expectation: the dev matrix is right
 TEST_F(TestMatmulInfo, InferDevMatrixShape5) {
   Strategys inputs = {{8, 16}, {2, 4, 1, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);

@@ -151,7 +163,9 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape5) {
   ASSERT_EQ(dev_matrix_shape, expect);
 }

-// matmul3
+/// Feature: test matmul info
+/// Description: infer dev matrix
+/// Expectation: the dev matrix is right
 TEST_F(TestMatmulInfo, InferDevMatrixShape6) {
   Strategys inputs = {{8, 8}, {2, 4, 2, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);
@@ -163,6 +177,9 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape6) {
   ASSERT_EQ(dev_matrix_shape, expect);
 }

+/// Feature: test matmul info
+/// Description: infer tensor map
+/// Expectation: the tensor map is right
 TEST_F(TestMatmulInfo, InferTensorMap1) {
   Strategys str = {{2, 4, 8, 16}, {2, 4, 16, 1}};
   StrategyPtr strategy = NewStrategy(0, str);

@@ -188,7 +205,9 @@ TEST_F(TestMatmulInfo, InferTensorMap1) {
   ASSERT_EQ(output_tensor_map.array(), output_expect);
 }

-// matmul2
+/// Feature: test matmul info
+/// Description: infer tensor map
+/// Expectation: the tensor map is right
 TEST_F(TestMatmulInfo, InferTensorMap2) {
   Strategys str = {{2, 4, 8, 16}, {1, 16}};
   StrategyPtr strategy = NewStrategy(0, str);

@@ -214,7 +233,9 @@ TEST_F(TestMatmulInfo, InferTensorMap2) {
   ASSERT_EQ(output_tensor_map.array(), output_expect);
 }

-// matmul3
+/// Feature: test matmul info
+/// Description: infer tensor map
+/// Expectation: the tensor map is right
 TEST_F(TestMatmulInfo, InferTensorMap3) {
   Strategys str = {{8, 16}, {2, 4, 1, 16}};
   StrategyPtr strategy = NewStrategy(0, str);

@@ -240,6 +261,9 @@ TEST_F(TestMatmulInfo, InferTensorMap3) {
   ASSERT_EQ(output_tensor_map.array(), output_expect);
 }

+/// Feature: test matmul info
+/// Description: infer slice shape
+/// Expectation: the slice shape is right
 TEST_F(TestMatmulInfo, InferSliceShape1) {
   Strategys str = {{2, 4, 8, 16}, {2, 4, 16, 1}};
   StrategyPtr strategy = NewStrategy(0, str);

@@ -265,7 +289,9 @@ TEST_F(TestMatmulInfo, InferSliceShape1) {
   ASSERT_EQ(output_slice_shape, output_slice_shape_expect);
 }

-// matmul2
+/// Feature: test matmul info
+/// Description: infer slice shape
+/// Expectation: the slice shape is right
 TEST_F(TestMatmulInfo, InferSliceShape2) {
   Strategys str = {{2, 4, 8, 16}, {1, 16}};
   StrategyPtr strategy = NewStrategy(0, str);

@@ -291,7 +317,9 @@ TEST_F(TestMatmulInfo, InferSliceShape2) {
   ASSERT_EQ(output_slice_shape, output_slice_shape_expect);
 }

-// matmul3
+/// Feature: test matmul info
+/// Description: infer slice shape
+/// Expectation: the slice shape is right
 TEST_F(TestMatmulInfo, InferSliceShape3) {
   Strategys str = {{8, 16}, {2, 4, 1, 16}};
   StrategyPtr strategy = NewStrategy(0, str);

@@ -317,7 +345,9 @@ TEST_F(TestMatmulInfo, InferSliceShape3) {
   ASSERT_EQ(output_slice_shape, output_slice_shape_expect);
 }

-// matmul3
+/// Feature: test matmul info
+/// Description: get tensor layout
+/// Expectation: the tensor layout is right
 TEST_F(TestMatmulInfo, GetTensorLayout3) {
   Strategys str = {{8, 16}, {2, 4, 1, 16}};
   StrategyPtr strategy = NewStrategy(0, str);
@@ -343,6 +373,9 @@ TEST_F(TestMatmulInfo, GetTensorLayout3) {
   ASSERT_EQ(output_tensor_map.array(), output_expect);
 }

+/// Feature: test matmul info
+/// Description: infer forward op
+/// Expectation: the forward op is right
 TEST_F(TestMatmulInfo, GetForwardOp1) {
   Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

@@ -369,6 +402,9 @@ TEST_F(TestMatmulInfo, GetForwardOp1) {
   ASSERT_EQ(arg1_value_is_string, true);
 }

+/// Feature: test matmul info
+/// Description: infer forward op
+/// Expectation: the forward op is right
 TEST_F(TestMatmulInfo, GetForwardOp2) {
   Strategys inputs = {{2, 4, 8, 1}, {2, 4, 1, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);

@@ -379,6 +415,9 @@ TEST_F(TestMatmulInfo, GetForwardOp2) {
   ASSERT_EQ(forward_op.size(), 0);
 }

+/// Feature: test matmul info
+/// Description: infer virtual_div op
+/// Expectation: the virtual_div op is right
 TEST_F(TestMatmulInfo, GetVirtualDivOp1) {
   Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

@@ -398,6 +437,9 @@ TEST_F(TestMatmulInfo, GetVirtualDivOp1) {
   ASSERT_EQ(divisor, 16);
 }

+/// Feature: test matmul info
+/// Description: infer mirror op
+/// Expectation: the mirror op is right
 TEST_F(TestMatmulInfo, GetMirrorOPs1) {
   Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

@@ -417,7 +459,9 @@ TEST_F(TestMatmulInfo, GetMirrorOPs1) {
   ASSERT_EQ(arg0_name, "group");
 }

-// matmul2
+/// Feature: test matmul info
+/// Description: infer mirror op
+/// Expectation: the mirror op is right
 TEST_F(TestMatmulInfo, GetMirrorOPs2) {
   Strategys inputs = {{2, 4, 1, 16}, {8, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);

@@ -437,7 +481,9 @@ TEST_F(TestMatmulInfo, GetMirrorOPs2) {
   ASSERT_EQ(arg0_name, "group");
 }

-// matmul3
+/// Feature: test matmul info
+/// Description: infer mirror op
+/// Expectation: the mirror op is right
 TEST_F(TestMatmulInfo, GetMirrorOPs3) {
   Strategys inputs = {{8, 16}, {2, 4, 1, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);

@@ -456,6 +502,9 @@ TEST_F(TestMatmulInfo, GetMirrorOPs3) {
   ASSERT_EQ(arg0_name, "group");
 }

+/// Feature: test matmul info
+/// Description: infer mirror op
+/// Expectation: the mirror op is right
 TEST_F(TestMatmulInfo, GetMirrorOPs4) {
   Strategys inputs = {{2, 4, 1, 16}, {2, 4, 16, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

@@ -466,6 +515,9 @@ TEST_F(TestMatmulInfo, GetMirrorOPs4) {
   ASSERT_EQ(mirror_ops.size(), 2);
 }

+/// Feature: test matmul info
+/// Description: init twice
+/// Expectation: the mirror op is right
 TEST_F(TestMatmulInfo, InitTwice) {
   Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);
@@ -487,6 +539,9 @@ TEST_F(TestMatmulInfo, InitTwice) {
   ASSERT_EQ(arg0_name, "group");
 }

+/// Feature: test matmul info
+/// Description: check strategy, the strategy is invalid
+/// Expectation: return FAILED
 TEST_F(TestMatmulInfo, CheckStrategy1) {
   // Success: {{2,4,8,16}, {2,4,16,1}}
   Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};

@@ -496,6 +551,9 @@ TEST_F(TestMatmulInfo, CheckStrategy1) {
   ASSERT_EQ(ret, FAILED);
 }

+/// Feature: test matmul info
+/// Description: check strategy, the strategy is invalid
+/// Expectation: return FAILED
 TEST_F(TestMatmulInfo, CheckStrategy2) {
   // Success: {{2,4,8,16}, {2,4,16,1}}
   Strategys inputs = {{2, 4, 8, 16}, {4, 16, 1}};

@@ -505,6 +563,9 @@ TEST_F(TestMatmulInfo, CheckStrategy2) {
   ASSERT_EQ(ret, FAILED);
 }

+/// Feature: test matmul info
+/// Description: check strategy, the strategy is invalid
+/// Expectation: return FAILED
 TEST_F(TestMatmulInfo, CheckStrategy3) {
   // Success: {{2,4,8,16}, {2,4,16,1}}
   Strategys inputs = {{2, 4, 8, 16}, {2, 4, 8, 1}};

@@ -514,6 +575,9 @@ TEST_F(TestMatmulInfo, CheckStrategy3) {
   ASSERT_EQ(ret, FAILED);
 }

+/// Feature: test matmul info
+/// Description: check strategy, the strategy is invalid
+/// Expectation: return FAILED
 TEST_F(TestMatmulInfo, CheckStrategy4) {
   // Success: {{2,4,8,16}, {2,4,16,1}}
   Strategys inputs = {{2, 4, 8, 16}, {2, 3, 16, 1}};

@@ -523,6 +587,9 @@ TEST_F(TestMatmulInfo, CheckStrategy4) {
   ASSERT_EQ(ret, FAILED);
 }

+/// Feature: test matmul info
+/// Description: check strategy, the strategy is invalid
+/// Expectation: return FAILED
 TEST_F(TestMatmulInfo, CheckStrategy5) {
   // Success: {{2,4,8,16}, {2,4,16,1}}
   Strategys inputs = {{0, 4, 8, 16}, {2, 4, 16, 1}};

@@ -532,6 +599,9 @@ TEST_F(TestMatmulInfo, CheckStrategy5) {
   ASSERT_EQ(ret, FAILED);
 }

+/// Feature: test matmul info
+/// Description: check strategy, the strategy is invalid
+/// Expectation: return FAILED
 TEST_F(TestMatmulInfo, CheckStrategy6) {
   // Success: {{2,4,8,16}, {2,4,16,1}}
   Strategys inputs = {{-1, 4, 8, 16}, {2, 4, 16, 1}};

@@ -541,6 +611,9 @@ TEST_F(TestMatmulInfo, CheckStrategy6) {
   ASSERT_EQ(ret, FAILED);
 }

+/// Feature: test matmul info
+/// Description: check strategy, the strategy is invalid
+/// Expectation: return FAILED
 TEST_F(TestMatmulInfo, CheckStrategy7) {
   // Success: {{2,4,8,16}, {2,4,16,1}}
   Strategys inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}};

@@ -550,6 +623,9 @@ TEST_F(TestMatmulInfo, CheckStrategy7) {
   ASSERT_EQ(ret, FAILED);
 }

+/// Feature: test matmul info
+/// Description: init, invalid strategy
+/// Expectation: return FAILED
 TEST_F(TestMatmulInfo, InitFailed) {
   // matmul4 attr is wrong
   Strategys inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}};

@@ -559,6 +635,9 @@ TEST_F(TestMatmulInfo, InitFailed) {
   ASSERT_EQ(ret, FAILED);
 }

+/// Feature: test matmul info
+/// Description: generate strategy
+/// Expectation: the computation cost is right
 TEST_F(TestMatmulInfo, test_GenerateStrategies1) {
   // the parameter '0' indicates that the stageId = 0, there are 1024 devices in the stage 0
   ASSERT_EQ(matmul1->GenerateStrategies(0), Status::SUCCESS);

@@ -574,35 +653,5 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies1) {
     break;
   }
 }
-
-TEST_F(TestMatmulInfo, test_GenerateStrategies2) {
-  // the parameter '0' indicates that the stageId = 0, there are 1024 devices in the stage 0
-  ASSERT_EQ(matmul3->GenerateStrategies(0), Status::SUCCESS);
-  std::vector<std::shared_ptr<StrategyWithCost>> sc = matmul3->GetStrategyCost();
-  for (const auto& swc : sc) {
-    StrategyPtr sp = swc->strategy_ptr;
-    Cost cost = *(swc->cost_list[0]);
-    matmul3->InitForCostModel(sp, nullptr);
-
-    std::vector<TensorInfo> inputs_info = matmul3->inputs_tensor_info();
-    std::vector<TensorInfo> outputs_info = matmul3->outputs_tensor_info();
-    std::vector<TensorInfo> replica_inputs_info;
-    replica_inputs_info.push_back(inputs_info[0]);
-
-    // transpose the tensor B
-    TensorInfo input1_info = inputs_info[1];
-    Shape input1_shape = input1_info.shape();
-    Shape input1_slice_shape = input1_info.slice_shape();
-    TensorLayout tly;
-    matmul3->SwapLastTwoElements(&input1_shape);
-    matmul3->SwapLastTwoElements(&input1_slice_shape);
-    TensorInfo replica_input1_info(tly, input1_shape, input1_slice_shape);
-    replica_inputs_info.push_back(replica_input1_info);
-
-    ASSERT_DOUBLE_EQ(matmul3->operator_cost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
-                     cost.computation_cost_);
-    break;
-  }
-}
 } // namespace parallel
 } // namespace mindspore
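With test_GenerateStrategies2 removed, coverage of the transposed case rests on the strategy checks above; a hypothetical follow-up test in the same gtest style could exercise the new entry point directly (not part of this commit; fixture and matmul1 as defined in the test file above):

TEST_F(TestMatmulInfo, GenerateOpStrategiesNotEmpty) {
  // Assumed check: every generated candidate carries one strategy per input.
  std::vector<StrategyPtr> sps = matmul1->GenerateOpStrategies(0);
  ASSERT_FALSE(sps.empty());
  for (auto &sp : sps) {
    ASSERT_NE(sp, nullptr);
    ASSERT_EQ(sp->GetInputDim().size(), 2);
  }
}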
@@ -58,6 +58,11 @@ def compile_net(net, x, y, b, phase):


 def test_auto_parallel_arithmetic():
+    """
+    Features: test auto parallel
+    Description: search strategies
+    Expectation: Generated strategies matching expectations
+    """
     class Net(nn.Cell):
         def __init__(self):
             super().__init__()
@@ -89,6 +94,11 @@ def test_auto_parallel_arithmetic():


 def test_auto_parallel_arithmetic_broadcast_both():
+    """
+    Features: test auto parallel
+    Description: search strategies for broadcast
+    Expectation: Generated strategies matching expectations
+    """
     class Net(nn.Cell):
         def __init__(self):
             super().__init__()
@@ -113,12 +123,17 @@ def test_auto_parallel_arithmetic_broadcast_both():
     strategies = _cell_graph_executor._get_shard_strategy(net)
     for (k, v) in strategies.items():
         if re.search('FloorDiv-op', k) is not None:
-            assert v == [[8, 1], [1, 1]]
+            assert v == [[1, 1], [1, 1]]
         elif re.search('MatMul-op', k) is not None:
-            assert v == [[8, 1], [1, 1]]
+            assert v == [[1, 1], [1, 1]]


 def test_auto_parallel_arithmetic_broadcast_right():
+    """
+    Features: test auto parallel
+    Description: search strategies for right broadcast
+    Expectation: Generated strategies matching expectations
+    """
     class Net(nn.Cell):
         def __init__(self):
             super().__init__()
@@ -150,6 +165,11 @@ def test_auto_parallel_arithmetic_broadcast_right():


 def test_auto_parallel_arithmetic_broadcast_left():
+    """
+    Features: test auto parallel
+    Description: search strategies for left broadcast
+    Expectation: Generated strategies matching expectations
+    """
     class Net(nn.Cell):
         def __init__(self):
             super().__init__()
@@ -39,6 +39,11 @@ class NetWithLoss(nn.Cell):


 def test_common_parameter():
+    """
+    Features: test auto parallel
+    Description: search strategies for cast parameter
+    Expectation: Generated strategies matching expectations
+    """
     class Net(nn.Cell):
         def __init__(self):
             super().__init__()
@@ -72,7 +77,9 @@ def test_common_parameter():
     _cell_graph_executor.compile(net, x, y, phase='train')
     strategies = _cell_graph_executor._get_shard_strategy(net)
     for (k, v) in strategies.items():
-        if re.search('MatMul-op', k) is not None:
+        if re.search('MatMul-op0', k) is not None:
             assert v == [[4, 1], [1, 2]]
+        elif re.search('MatMul-op', k) is not None:
+            assert v == [[8, 1], [1, 1]]
         elif re.search('Cast-op', k) is not None:
             assert v == [[1, 1]]