forked from mindspore-Ecosystem/mindspore
remove restriction for opt shard in inference
This commit is contained in:
parent
09a119cd7c
commit
1c9166e0a6
|
@ -222,9 +222,6 @@ def _parallel_predict_check():
|
|||
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
if not context.get_auto_parallel_context("full_batch"):
|
||||
raise RuntimeError('Model prediction only supports full batch dataset. Please set "full_batch" with True.')
|
||||
if context.get_auto_parallel_context("enable_parallel_optimizer"):
|
||||
raise RuntimeError('Model prediction does not support parallel optimizer. Please set'
|
||||
'"enable_parallel_optimizer" with False.')
|
||||
|
||||
|
||||
def _check_similar_layout(tensor_layout1, tensor_layout2):
|
||||
|
|
|
@ -49,7 +49,8 @@ class Net(nn.Cell):
|
|||
|
||||
def test_distribute_predict():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True,
|
||||
enable_parallel_optimizer=True)
|
||||
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
|
||||
net = Net()
|
||||
model = Model(net)
|
||||
|
@ -69,9 +70,6 @@ def test_edge_case():
|
|||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
with pytest.raises(RuntimeError):
|
||||
model.infer_predict_layout(inputs)
|
||||
context.set_auto_parallel_context(full_batch=True, enable_parallel_optimizer=True)
|
||||
with pytest.raises(RuntimeError):
|
||||
model.predict(inputs)
|
||||
|
||||
|
||||
# standalone predict
|
||||
|
|
Loading…
Reference in New Issue