forked from mindspore-Ecosystem/mindspore
!16158 modify check strategy for scatter update
From: @yangzhenzhang Reviewed-by: @kisnwang, @stsuteng Signed-off-by: @stsuteng
commit 9b4ff99305
@@ -71,10 +71,22 @@ Status ScatterUpdateInfo::CheckStrategy(const StrategyPtr &strategy) {
   }
 
   if (std::accumulate(stra[2].begin(), stra[2].begin() + stra[1].size(), 1, std::multiplies<int64_t>()) != 1) {
-    MS_LOG(ERROR) << name_ << ": The indices can not be split";
+    MS_LOG(ERROR) << name_ << ": The first " << stra[1].size() << " dimensions of updates can not be split";
     return FAILED;
   }
 
+  if (stra[0].size() - 1 != stra[2].size() - stra[1].size()) {
+    MS_LOG(ERROR) << name_ << ": updates.strategy must be equal to indices.strategy + input.strategy[1:]";
+    return FAILED;
+  }
+
+  for (size_t i = 1; i < stra[0].size(); ++i) {
+    if (stra[0][i] != stra[2][stra[1].size() + i - 1]) {
+      MS_LOG(ERROR) << name_ << ": updates.strategy must be equal to indices.strategy + input.strategy[1:]";
+      return FAILED;
+    }
+  }
+
   return SUCCESS;
 }
 
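As a quick illustration of the rule this hunk enforces, here is a minimal pure-Python sketch of the same checks; the function name check_scatter_update_strategy and the tuple layout are illustrative only, not MindSpore API.

from math import prod

def check_scatter_update_strategy(stra):
    """Sketch of the C++ checks: stra = (input_strategy, indices_strategy, updates_strategy)."""
    input_stra, indices_stra, updates_stra = stra
    # The first len(indices_strategy) dimensions of updates can not be split.
    if prod(updates_stra[:len(indices_stra)]) != 1:
        return False
    # updates.strategy must have the rank of indices.strategy + input.strategy[1:].
    if len(input_stra) - 1 != len(updates_stra) - len(indices_stra):
        return False
    # The remaining dimensions of updates must match input.strategy[1:].
    return all(input_stra[i] == updates_stra[len(indices_stra) + i - 1]
               for i in range(1, len(input_stra)))

# The strategy used in the new test below is rejected: the tail of updates is (4, 2),
# but input.strategy[1:] is (2, 4).
assert not check_scatter_update_strategy(((1, 2, 4), (1, 1), (1, 1, 4, 2)))
# A tail that makes updates.strategy == indices.strategy + input.strategy[1:] passes.
assert check_scatter_update_strategy(((1, 2, 4), (1, 1), (1, 1, 2, 4)))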
@@ -14,6 +14,7 @@
 # ============================================================================
 """ test scatter update """
 import numpy as np
+import pytest
 import mindspore.nn as nn
 from mindspore import Tensor, Model, Parameter
 from mindspore.ops import operations as P
 
@@ -52,6 +53,19 @@ def test_distribute_predict():
     return predict_map, output
 
 
+def test_scatter_update_wrong_strategy():
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
+    inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32))
+    strategy1 = ((1, 2, 4), (1, 1), (1, 1, 4, 2))
+    strategy2 = ((1, 2, 4), (1, 2, 4))
+    net = Net(strategy1, strategy2)
+    model = Model(net)
+    with pytest.raises(RuntimeError):
+        model.predict(inputs)
+    context.reset_auto_parallel_context()
+
+
 def test_distribute_predict_auto_parallel():
     context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, full_batch=True)
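For contrast, a companion test with a strategy that satisfies the new rule would be expected to predict without raising. This is only a sketch that reuses the shapes and the Net from this test file; the name test_scatter_update_right_strategy is hypothetical and not part of the commit.

def test_scatter_update_right_strategy():
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
    inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32))
    # updates.strategy == indices.strategy + input.strategy[1:] = (1, 1) + (2, 4)
    strategy1 = ((1, 2, 4), (1, 1), (1, 1, 2, 4))
    strategy2 = ((1, 2, 4), (1, 2, 4))
    net = Net(strategy1, strategy2)
    model = Model(net)
    model.predict(inputs)  # the strategy check should no longer raise RuntimeError
    context.reset_auto_parallel_context()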