forked from mindspore-Ecosystem/mindspore
fix standalone prediction
This commit is contained in:
parent
8f2b70261a
commit
660f578988
|
@ -166,7 +166,8 @@ class Optimizer(Cell):
|
|||
if context.get_auto_parallel_context("enable_parallel_optimizer"):
|
||||
if _get_parallel_mode() == ParallelMode.DATA_PARALLEL and context.get_context("device_target") == "Ascend":
|
||||
self.use_parallel = True
|
||||
elif context.get_context("device_target") != "Ascend":
|
||||
elif _get_parallel_mode() == ParallelMode.DATA_PARALLEL \
|
||||
and context.get_context("device_target") != "Ascend":
|
||||
raise RuntimeError("Parallel optimizer only supports Ascend in data parallel mode.")
|
||||
elif _get_parallel_mode() in (ParallelMode.STAND_ALONE, ParallelMode.HYBRID_PARALLEL):
|
||||
raise RuntimeError("Parallel optimizer is not supported in {}.".format(_get_parallel_mode()))
|
||||
|
|
|
@ -241,13 +241,8 @@ def _infer_rank_list(train_map, predict_map=None):
|
|||
ret = {}
|
||||
for param_name in train_map:
|
||||
train_layout = train_map[param_name]
|
||||
predict_layout = predict_map[param_name]
|
||||
train_dev_mat = train_layout[0]
|
||||
dev_num = np.array(train_dev_mat).prod()
|
||||
if _check_same_layout(train_layout, predict_layout):
|
||||
dev_rank = _get_global_rank()
|
||||
ret[param_name] = ([dev_rank], True)
|
||||
continue
|
||||
new_train_layout = _remove_repeated_slices(train_layout)
|
||||
array = np.arange(dev_num).reshape(train_dev_mat)
|
||||
index = ()
|
||||
|
@ -263,7 +258,13 @@ def _infer_rank_list(train_map, predict_map=None):
|
|||
if param_name not in predict_map:
|
||||
logger.warning("predict_map does not contain %s", param_name)
|
||||
continue
|
||||
predict_layout = predict_map[param_name]
|
||||
dev_num = np.array(predict_layout[0].prod())
|
||||
# optimization pass
|
||||
if _check_same_layout(train_layout, predict_layout):
|
||||
dev_rank = _get_global_rank()
|
||||
ret[param_name] = ([dev_rank], True)
|
||||
continue
|
||||
if _check_similar_layout(train_layout, predict_layout):
|
||||
if len(rank_list) == 1:
|
||||
ret[param_name] = (rank_list, True)
|
||||
|
|
Loading…
Reference in New Issue