!18768 [AutoParallel]Fix auto parallel gatherv2 bug
Merge pull request !18768 from lichen/fix_auto_parallel_gatherv2
Commit: 70152adcb3
@@ -79,7 +79,6 @@ class GatherPInfo : public OperatorInfo {
   int64_t index_offset_;
   int64_t slice_size_;
   std::string replace_op_name_ = GATHERV2;
-  Shape out_dev_matrix_shape_;
   Group group_;
   bool manual_split_ = false;
   bool dynamic_shape_indices_ = false;
@@ -1581,8 +1581,11 @@ Status OperatorInfo::InferAsLossDivisor() {
     return SUCCESS;
   }

-  as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]);
-  MS_LOG(INFO) << name_ << ": the dev matrix shape is " << ShapeToString(dev_matrix_shape_)
+  if (out_dev_matrix_shape_.empty()) {
+    out_dev_matrix_shape_ = dev_matrix_shape_;
+  }
+  as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(out_dev_matrix_shape_, outputs_tensor_map_[0]);
+  MS_LOG(INFO) << name_ << ": the dev matrix shape is " << ShapeToString(out_dev_matrix_shape_)
                << ", the output tensor map is " << ShapeToString(outputs_tensor_map_[0]) << ", loss divisor is "
                << as_loss_divisor_;
   return SUCCESS;
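
For readers unfamiliar with the loss-divisor term: it is the number of devices that hold identical copies of the operator's output, which is why it must be derived from the output device matrix (out_dev_matrix_shape_, with a fallback to dev_matrix_shape_) rather than always from the input device matrix for operators such as GatherV2, whose output is laid out on a different device arrangement. Below is a minimal Python sketch, not MindSpore source, of the semantics of ComputeRepeatDeviceNumByTensorMap, assuming the convention that a tensor-map entry m refers to device-matrix dimension dev_matrix[len(dev_matrix) - 1 - m] and that -1 means "not split":

    from functools import reduce

    # Sketch only: number of devices that hold identical copies of a tensor,
    # i.e. the product of device-matrix dimensions not referenced by its tensor map.
    def repeat_device_num(dev_matrix, tensor_map):
        used = {len(dev_matrix) - 1 - m for m in tensor_map if m != -1}
        unused = [dim for i, dim in enumerate(dev_matrix) if i not in used]
        return reduce(lambda x, y: x * y, unused, 1)

    # Example: output device matrix [2, 4] with output tensor map (1, -1):
    # only the dimension of size 2 is mapped, so 4 devices share the same
    # output slice and the loss divisor is 4.
    print(repeat_device_num([2, 4], (1, -1)))  # 4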
@@ -232,6 +232,7 @@ class OperatorInfo {
   std::vector<TensorInfo> inputs_tensor_info_;
   std::vector<TensorInfo> outputs_tensor_info_;
   Shape dev_matrix_shape_;  // if repeated calculation, it contains the repeated_calc_num_
+  Shape out_dev_matrix_shape_;
   int64_t repeated_calc_num_ = 1;
   int64_t as_loss_divisor_ = 1;
   TensorMaps inputs_tensor_map_;
@@ -146,10 +146,10 @@ class Primitive(Primitive_):
         mode = context.get_auto_parallel_context("parallel_mode")
         if strategy is not None:
             if not isinstance(strategy, tuple):
-                raise TypeError('strategy must be tuple type.')
+                raise TypeError(f'strategy must be tuple type, but got:{type(strategy)}')
             for ele in strategy:
                 if not isinstance(ele, tuple):
-                    raise TypeError('The element of strategy must be tuple type.')
+                    raise TypeError(f'The element of strategy must be tuple type, but got:{type(ele)}')
         if not _is_in_auto_parallel_mode() and strategy:
             logger.warning(f"The shard strategy {strategy} of {self.name} is not valid in {mode}. "
                            f"Please use semi auto or auto parallel mode.")
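
The tightened checks above follow the normal Primitive.shard usage; a short usage sketch, with illustrative strategy values:

    from mindspore import context, ops

    # shard() expects a tuple of tuples, one inner tuple per primitive input,
    # and the strategy only takes effect in semi-auto or auto parallel mode.
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    matmul = ops.MatMul().shard(((2, 1), (1, 2)))  # accepted: tuple of tuples

    # ops.MatMul().shard([(2, 1), (1, 2)])
    #   -> TypeError: strategy must be tuple type, but got:<class 'list'>
    # ops.MatMul().shard(((2, 1), [1, 2]))
    #   -> TypeError: The element of strategy must be tuple type, but got:<class 'list'>

In other parallel modes, a non-empty strategy triggers the warning shown in the diff rather than an error.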