fix standalone prediction

This commit is contained in:
Ziyan 2020-12-26 11:15:38 +08:00
parent 8f2b70261a
commit 660f578988
2 changed files with 8 additions and 6 deletions

View File

@ -166,7 +166,8 @@ class Optimizer(Cell):
if context.get_auto_parallel_context("enable_parallel_optimizer"):
if _get_parallel_mode() == ParallelMode.DATA_PARALLEL and context.get_context("device_target") == "Ascend":
self.use_parallel = True
elif context.get_context("device_target") != "Ascend":
elif _get_parallel_mode() == ParallelMode.DATA_PARALLEL \
and context.get_context("device_target") != "Ascend":
raise RuntimeError("Parallel optimizer only supports Ascend in data parallel mode.")
elif _get_parallel_mode() in (ParallelMode.STAND_ALONE, ParallelMode.HYBRID_PARALLEL):
raise RuntimeError("Parallel optimizer is not supported in {}.".format(_get_parallel_mode()))

View File

@ -241,13 +241,8 @@ def _infer_rank_list(train_map, predict_map=None):
ret = {}
for param_name in train_map:
train_layout = train_map[param_name]
predict_layout = predict_map[param_name]
train_dev_mat = train_layout[0]
dev_num = np.array(train_dev_mat).prod()
if _check_same_layout(train_layout, predict_layout):
dev_rank = _get_global_rank()
ret[param_name] = ([dev_rank], True)
continue
new_train_layout = _remove_repeated_slices(train_layout)
array = np.arange(dev_num).reshape(train_dev_mat)
index = ()
@ -263,7 +258,13 @@ def _infer_rank_list(train_map, predict_map=None):
if param_name not in predict_map:
logger.warning("predict_map does not contain %s", param_name)
continue
predict_layout = predict_map[param_name]
dev_num = np.array(predict_layout[0].prod())
# optimization pass
if _check_same_layout(train_layout, predict_layout):
dev_rank = _get_global_rank()
ret[param_name] = ([dev_rank], True)
continue
if _check_similar_layout(train_layout, predict_layout):
if len(rank_list) == 1:
ret[param_name] = (rank_list, True)