!18768 [AutoParallel]Fix auto parallel gatherv2 bug

Merge pull request !18768 from lichen/fix_auto_parallel_gatherv2
i-robot 2021-06-24 06:36:21 +00:00 committed by Gitee
commit 70152adcb3
4 changed files with 8 additions and 5 deletions


@@ -79,7 +79,6 @@ class GatherPInfo : public OperatorInfo {
int64_t index_offset_;
int64_t slice_size_;
std::string replace_op_name_ = GATHERV2;
- Shape out_dev_matrix_shape_;
Group group_;
bool manual_split_ = false;
bool dynamic_shape_indices_ = false;


@@ -1581,8 +1581,11 @@ Status OperatorInfo::InferAsLossDivisor() {
return SUCCESS;
}
- as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]);
- MS_LOG(INFO) << name_ << ": the dev matrix shape is " << ShapeToString(dev_matrix_shape_)
+ if (out_dev_matrix_shape_.empty()) {
+   out_dev_matrix_shape_ = dev_matrix_shape_;
+ }
+ as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(out_dev_matrix_shape_, outputs_tensor_map_[0]);
+ MS_LOG(INFO) << name_ << ": the dev matrix shape is " << ShapeToString(out_dev_matrix_shape_)
<< ", the output tensor map is " << ShapeToString(outputs_tensor_map_[0]) << ", loss divisor is "
<< as_loss_divisor_;
return SUCCESS;
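
Context for the change above (not part of the patch): GatherV2 maintains its own output device matrix (out_dev_matrix_shape_, moved from GatherPInfo into the OperatorInfo base class in this patch), which can differ from dev_matrix_shape_, so deriving the loss divisor from the input device matrix gave a wrong repeat count for that operator. The new code prefers out_dev_matrix_shape_ and falls back to dev_matrix_shape_ when an operator does not set one. A minimal standalone Python sketch of that logic follows; the behavior attributed to ComputeRepeatDeviceNumByTensorMap and the right-to-left tensor-map convention are assumptions made for illustration, not the actual C++ implementation.

```python
# Standalone sketch only; names mirror the C++ code, semantics are assumed.
import math

def compute_repeat_device_num(dev_matrix, tensor_map):
    # Assumed convention: a tensor-map value m refers to device-matrix dimension
    # len(dev_matrix) - 1 - m (counted from the right); -1 means "not split".
    used = {len(dev_matrix) - 1 - m for m in tensor_map if m != -1}
    # The repeat count is the number of devices holding identical copies of the
    # tensor: the product of device-matrix dims the tensor map does not use.
    return math.prod(d for i, d in enumerate(dev_matrix) if i not in used)

def infer_as_loss_divisor(dev_matrix_shape, out_dev_matrix_shape, output_tensor_map):
    # The fix: use the output device matrix when the operator provides one
    # (as GatherV2 does), otherwise fall back to the input device matrix.
    out = out_dev_matrix_shape or dev_matrix_shape
    return compute_repeat_device_num(out, output_tensor_map)

# Example: 8 devices arranged as [4, 2], output split only along the "4" dim,
# so the output is repeated across the remaining 2 devices.
print(infer_as_loss_divisor([4, 2], [], [1]))  # -> 2
```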


@@ -232,6 +232,7 @@ class OperatorInfo {
std::vector<TensorInfo> inputs_tensor_info_;
std::vector<TensorInfo> outputs_tensor_info_;
Shape dev_matrix_shape_; // if repeated calculation, it contains the repeated_calc_num_
+ Shape out_dev_matrix_shape_;
int64_t repeated_calc_num_ = 1;
int64_t as_loss_divisor_ = 1;
TensorMaps inputs_tensor_map_;


@@ -146,10 +146,10 @@ class Primitive(Primitive_):
mode = context.get_auto_parallel_context("parallel_mode")
if strategy is not None:
if not isinstance(strategy, tuple):
- raise TypeError('strategy must be tuple type.')
+ raise TypeError(f'strategy must be tuple type, but got:{type(strategy)}')
for ele in strategy:
if not isinstance(ele, tuple):
- raise TypeError('The element of strategy must be tuple type.')
+ raise TypeError(f'The element of strategy must be tuple type, but got:{type(ele)}')
if not _is_in_auto_parallel_mode() and strategy:
logger.warning(f"The shard strategy {strategy} of {self.name} is not valid in {mode}. "
f"Please use semi auto or auto parallel mode.")