forked from mindspore-Ecosystem/mindspore
!2413 [Auto parallel] Check 'CAST' from optimizers
Merge pull request !2413 from Xiaoda/7-auto-parallel-check-optimizer-ops
This commit is contained in:
commit
e9e4442dcb
|
@ -130,6 +130,7 @@ constexpr char FORWARD_OP[] = "forward_op";
|
||||||
constexpr char REDISTRIBUTION_OP[] = "redistribution_op";
|
constexpr char REDISTRIBUTION_OP[] = "redistribution_op";
|
||||||
constexpr char DARA_PARALLEL[] = "data_parallel";
|
constexpr char DARA_PARALLEL[] = "data_parallel";
|
||||||
constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
|
constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
|
||||||
|
constexpr char OPTIMIZER_SUB_STRING[] = "optimizer";
|
||||||
|
|
||||||
// Operator
|
// Operator
|
||||||
constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
|
constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
|
||||||
|
|
|
@ -283,6 +283,10 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) {
|
||||||
if (bool_result) {
|
if (bool_result) {
|
||||||
MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name();
|
MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name();
|
||||||
} else if (prim->name() == CAST) {
|
} else if (prim->name() == CAST) {
|
||||||
|
if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) {
|
||||||
|
// Do not care CASTs from optimizer
|
||||||
|
return false;
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name());
|
return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name());
|
||||||
|
|
|
@ -80,9 +80,9 @@ def test_double_star_graph():
|
||||||
|
|
||||||
_executor.compile(net, x, y, z, w, phase='train')
|
_executor.compile(net, x, y, z, w, phase='train')
|
||||||
strategies = _executor._get_strategy(net)
|
strategies = _executor._get_strategy(net)
|
||||||
expected_strategies = {'Default/network-Net/Cast-op1': [[8, 1]],
|
expected_strategies = {'Default/network-Net/Cast-op0': [[8, 1]],
|
||||||
'Default/network-Net/Cast-op3': [[1, 8]],
|
'Default/network-Net/Cast-op1': [[1, 8]],
|
||||||
'Default/network-Net/MatMul-op2': [[8, 1], [1, 1]],
|
'Default/network-Net/MatMul-op3': [[8, 1], [1, 1]],
|
||||||
'Default/network-Net/MatMul-op4': [[1, 1], [1, 8]],
|
'Default/network-Net/MatMul-op4': [[1, 1], [1, 8]],
|
||||||
'Default/network-Net/MatMul-op0': [[1, 8], [8, 1]]}
|
'Default/network-Net/MatMul-op2': [[1, 8], [8, 1]]}
|
||||||
assert strategies == expected_strategies
|
assert strategies == expected_strategies
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import re
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
|
@ -69,9 +70,8 @@ def test_common_parameter():
|
||||||
|
|
||||||
_executor.compile(net, x, y, phase='train')
|
_executor.compile(net, x, y, phase='train')
|
||||||
strategies = _executor._get_strategy(net)
|
strategies = _executor._get_strategy(net)
|
||||||
expected_strategies = {'Default/network-Net/MatMul-op1': [[8, 1], [1, 1]],
|
for (k, v) in strategies.items():
|
||||||
'Default/network-Net/MatMul-op3': [[8, 1], [1, 1]],
|
if re.search('MatMul-op', k) is not None:
|
||||||
'Default/network-Net/Cast-op2': [[1, 1]],
|
assert v == [[8, 1], [1, 1]]
|
||||||
'Default/network-Net/MatMul-op0': [[8, 1], [1, 1]],
|
elif re.search('Cast-op', k) is not None:
|
||||||
'Default/network-Net/Cast-op4': [[1, 1]]}
|
assert v == [[1, 1]]
|
||||||
assert strategies == expected_strategies
|
|
||||||
|
|
Loading…
Reference in New Issue