!16158 modify check strategy for scatter update

From: @yangzhenzhang
Reviewed-by: @kisnwang,@stsuteng
Signed-off-by: @stsuteng
This commit is contained in:
mindspore-ci-bot 2021-05-11 09:41:07 +08:00 committed by Gitee
commit 9b4ff99305
2 changed files with 27 additions and 1 deletions

View File

@ -71,10 +71,22 @@ Status ScatterUpdateInfo::CheckStrategy(const StrategyPtr &strategy) {
}
if (std::accumulate(stra[2].begin(), stra[2].begin() + stra[1].size(), 1, std::multiplies<int64_t>()) != 1) {
MS_LOG(ERROR) << name_ << ": The indices can not be split";
MS_LOG(ERROR) << name_ << ": The first " << stra[1].size() << " dimensions of updates can not be split";
return FAILED;
}
if (stra[0].size() - 1 != stra[2].size() - stra[1].size()) {
MS_LOG(ERROR) << name_ << ": updates.strategy must be equal to indices.strategy + input.strategy[1:]";
return FAILED;
}
for (size_t i = 1; i < stra[0].size(); ++i) {
if (stra[0][i] != stra[2][stra[1].size() + i - 1]) {
MS_LOG(ERROR) << name_ << ": updates.strategy must be equal to indices.strategy + input.strategy[1:]";
return FAILED;
}
}
return SUCCESS;
}

View File

@ -14,6 +14,7 @@
# ============================================================================
""" test scatter update """
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, Model, Parameter
from mindspore.ops import operations as P
@ -52,6 +53,19 @@ def test_distribute_predict():
return predict_map, output
def test_scatter_update_wrong_strategy():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32))
strategy1 = ((1, 2, 4), (1, 1), (1, 1, 4, 2))
strategy2 = ((1, 2, 4), (1, 2, 4))
net = Net(strategy1, strategy2)
model = Model(net)
with pytest.raises(RuntimeError):
model.predict(inputs)
context.reset_auto_parallel_context()
def test_distribute_predict_auto_parallel():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, full_batch=True)