diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/parallel/ops_info/ops_utils.h index 2d91338b114..4b8f61bb2e4 100644 --- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/parallel/ops_info/ops_utils.h @@ -130,6 +130,7 @@ constexpr char FORWARD_OP[] = "forward_op"; constexpr char REDISTRIBUTION_OP[] = "redistribution_op"; constexpr char DARA_PARALLEL[] = "data_parallel"; constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter"; +constexpr char OPTIMIZER_SUB_STRING[] = "optimizer"; // Operator constexpr char VIRTUAL_DIV[] = "_VirtualDiv"; diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc index 8b4f7e2dec2..894177df8d2 100644 --- a/mindspore/ccsrc/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc @@ -283,6 +283,10 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) { if (bool_result) { MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name(); } else if (prim->name() == CAST) { + if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) { + // Do not care CASTs from optimizer + return false; + } return true; } return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name()); diff --git a/tests/ut/python/parallel/test_auto_parallel_cast.py b/tests/ut/python/parallel/test_auto_parallel_cast.py index cac452de960..4a77fd0cd23 100644 --- a/tests/ut/python/parallel/test_auto_parallel_cast.py +++ b/tests/ut/python/parallel/test_auto_parallel_cast.py @@ -80,9 +80,9 @@ def test_double_star_graph(): _executor.compile(net, x, y, z, w, phase='train') strategies = _executor._get_strategy(net) - expected_strategies = {'Default/network-Net/Cast-op1': [[8, 1]], - 'Default/network-Net/Cast-op3': [[1, 8]], - 'Default/network-Net/MatMul-op2': [[8, 1], [1, 1]], + expected_strategies = {'Default/network-Net/Cast-op0': [[8, 1]], + 'Default/network-Net/Cast-op1': [[1, 8]], + 'Default/network-Net/MatMul-op3': [[8, 1], [1, 1]], 'Default/network-Net/MatMul-op4': [[1, 1], [1, 8]], - 'Default/network-Net/MatMul-op0': [[1, 8], [8, 1]]} + 'Default/network-Net/MatMul-op2': [[1, 8], [8, 1]]} assert strategies == expected_strategies diff --git a/tests/ut/python/parallel/test_auto_parallel_parameter_cast.py b/tests/ut/python/parallel/test_auto_parallel_parameter_cast.py index 691bfcdf32a..6d4452407c0 100644 --- a/tests/ut/python/parallel/test_auto_parallel_parameter_cast.py +++ b/tests/ut/python/parallel/test_auto_parallel_parameter_cast.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re import numpy as np import mindspore as ms @@ -69,9 +70,8 @@ def test_common_parameter(): _executor.compile(net, x, y, phase='train') strategies = _executor._get_strategy(net) - expected_strategies = {'Default/network-Net/MatMul-op1': [[8, 1], [1, 1]], - 'Default/network-Net/MatMul-op3': [[8, 1], [1, 1]], - 'Default/network-Net/Cast-op2': [[1, 1]], - 'Default/network-Net/MatMul-op0': [[8, 1], [1, 1]], - 'Default/network-Net/Cast-op4': [[1, 1]]} - assert strategies == expected_strategies + for (k, v) in strategies.items(): + if re.search('MatMul-op', k) is not None: + assert v == [[8, 1], [1, 1]] + elif re.search('Cast-op', k) is not None: + assert v == [[1, 1]]