remove restriction for opt shard in inference

This commit is contained in:
Ziyan 2021-07-16 15:16:43 +08:00
parent 09a119cd7c
commit 1c9166e0a6
2 changed files with 2 additions and 7 deletions

View File

@ -222,9 +222,6 @@ def _parallel_predict_check():
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
if not context.get_auto_parallel_context("full_batch"): if not context.get_auto_parallel_context("full_batch"):
raise RuntimeError('Model prediction only supports full batch dataset. Please set "full_batch" with True.') raise RuntimeError('Model prediction only supports full batch dataset. Please set "full_batch" with True.')
if context.get_auto_parallel_context("enable_parallel_optimizer"):
raise RuntimeError('Model prediction does not support parallel optimizer. Please set'
'"enable_parallel_optimizer" with False.')
def _check_similar_layout(tensor_layout1, tensor_layout2): def _check_similar_layout(tensor_layout1, tensor_layout2):

View File

@ -49,7 +49,8 @@ class Net(nn.Cell):
def test_distribute_predict(): def test_distribute_predict():
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True,
enable_parallel_optimizer=True)
inputs = Tensor(np.ones([32, 128]).astype(np.float32)) inputs = Tensor(np.ones([32, 128]).astype(np.float32))
net = Net() net = Net()
model = Model(net) model = Model(net)
@ -69,9 +70,6 @@ def test_edge_case():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
model.infer_predict_layout(inputs) model.infer_predict_layout(inputs)
context.set_auto_parallel_context(full_batch=True, enable_parallel_optimizer=True)
with pytest.raises(RuntimeError):
model.predict(inputs)
# standalone predict # standalone predict