forked from mindspore-Ecosystem/mindspore
remove restriction for opt shard in inference
This commit is contained in:
parent
09a119cd7c
commit
1c9166e0a6
|
@ -222,9 +222,6 @@ def _parallel_predict_check():
|
||||||
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||||
if not context.get_auto_parallel_context("full_batch"):
|
if not context.get_auto_parallel_context("full_batch"):
|
||||||
raise RuntimeError('Model prediction only supports full batch dataset. Please set "full_batch" with True.')
|
raise RuntimeError('Model prediction only supports full batch dataset. Please set "full_batch" with True.')
|
||||||
if context.get_auto_parallel_context("enable_parallel_optimizer"):
|
|
||||||
raise RuntimeError('Model prediction does not support parallel optimizer. Please set'
|
|
||||||
'"enable_parallel_optimizer" with False.')
|
|
||||||
|
|
||||||
|
|
||||||
def _check_similar_layout(tensor_layout1, tensor_layout2):
|
def _check_similar_layout(tensor_layout1, tensor_layout2):
|
||||||
|
|
|
@ -49,7 +49,8 @@ class Net(nn.Cell):
|
||||||
|
|
||||||
def test_distribute_predict():
|
def test_distribute_predict():
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True,
|
||||||
|
enable_parallel_optimizer=True)
|
||||||
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
|
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
|
||||||
net = Net()
|
net = Net()
|
||||||
model = Model(net)
|
model = Model(net)
|
||||||
|
@ -69,9 +70,6 @@ def test_edge_case():
|
||||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
model.infer_predict_layout(inputs)
|
model.infer_predict_layout(inputs)
|
||||||
context.set_auto_parallel_context(full_batch=True, enable_parallel_optimizer=True)
|
|
||||||
with pytest.raises(RuntimeError):
|
|
||||||
model.predict(inputs)
|
|
||||||
|
|
||||||
|
|
||||||
# standalone predict
|
# standalone predict
|
||||||
|
|
Loading…
Reference in New Issue