!20411 enable optimizer parallel for inference
Merge pull request !20411 from gziyan/enable_opt_shard_predict
This commit is contained in:
commit
c9d3c1d346
|
@ -224,9 +224,6 @@ def _parallel_predict_check():
|
||||||
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||||
if not context.get_auto_parallel_context("full_batch"):
|
if not context.get_auto_parallel_context("full_batch"):
|
||||||
raise RuntimeError('Model prediction only supports full batch dataset. Please set "full_batch" with True.')
|
raise RuntimeError('Model prediction only supports full batch dataset. Please set "full_batch" with True.')
|
||||||
if context.get_auto_parallel_context("enable_parallel_optimizer"):
|
|
||||||
raise RuntimeError('Model prediction does not support parallel optimizer. Please set'
|
|
||||||
'"enable_parallel_optimizer" with False.')
|
|
||||||
|
|
||||||
|
|
||||||
def _check_similar_layout(tensor_layout1, tensor_layout2):
|
def _check_similar_layout(tensor_layout1, tensor_layout2):
|
||||||
|
|
|
@ -49,7 +49,8 @@ class Net(nn.Cell):
|
||||||
|
|
||||||
def test_distribute_predict():
|
def test_distribute_predict():
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True,
|
||||||
|
enable_parallel_optimizer=True)
|
||||||
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
|
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
|
||||||
net = Net()
|
net = Net()
|
||||||
model = Model(net)
|
model = Model(net)
|
||||||
|
@ -69,9 +70,6 @@ def test_edge_case():
|
||||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
model.infer_predict_layout(inputs)
|
model.infer_predict_layout(inputs)
|
||||||
context.set_auto_parallel_context(full_batch=True, enable_parallel_optimizer=True)
|
|
||||||
with pytest.raises(RuntimeError):
|
|
||||||
model.predict(inputs)
|
|
||||||
|
|
||||||
|
|
||||||
# standalone predict
|
# standalone predict
|
||||||
|
|
Loading…
Reference in New Issue