From a2b2727ba8006eda7ac6d87eb6e0d1e715d19dcc Mon Sep 17 00:00:00 2001 From: lichenever Date: Tue, 16 Mar 2021 09:52:15 +0800 Subject: [PATCH] change_pipeline_key_word --- mindspore/ccsrc/pipeline/jit/parse/parse_base.h | 2 +- mindspore/ops/operations/_inner_ops.py | 3 +++ tests/ut/python/parallel/test_pipeline_split.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h index e666a0b3d71..73d6086707f 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h @@ -136,7 +136,7 @@ const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags"; // define the parse constant const int64_t MAX_COMPARISON_OPS_SUPPORTED = 1; const char CUSTOM_BPROP_NAME[] = "bprop"; -const char STAGE_NAME[] = "stage"; +const char STAGE_NAME[] = "pipeline_stage"; // define the Namespace name const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // for ast type namespace diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index 68e18d7eeeb..f458e4cd7e3 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -471,6 +471,9 @@ class Receive(PrimitiveWithInfer): self.shape = shape self.dtype = dtype self.group = group + valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] + args = {"dtype": dtype} + validator.check_scalar_or_tensor_types_same(args, valid_type, self.name) def infer_shape(self, x_shape=None): return self.shape diff --git a/tests/ut/python/parallel/test_pipeline_split.py b/tests/ut/python/parallel/test_pipeline_split.py index abc09fb44e7..957586ddcf8 100644 --- a/tests/ut/python/parallel/test_pipeline_split.py +++ b/tests/ut/python/parallel/test_pipeline_split.py @@ -77,7 +77,7 @@ class Net(nn.Cell): self.block = nn.CellList() for i in range(2): cell = MatMulCell(strategy1, strategy2, param) - cell.stage = i + cell.pipeline_stage = i self.block.append(cell) def construct(self, x):