forked from mindspore-Ecosystem/mindspore
!13382 [PipelineSplit]change pipeline key word
From: @lichen666 Reviewed-by: @kisnwang,@zhunaipan Signed-off-by: @zhunaipan
This commit is contained in:
commit
7454ac8ecd
|
@ -136,7 +136,7 @@ const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags";
|
|||
// define the parse constant
|
||||
const int64_t MAX_COMPARISON_OPS_SUPPORTED = 1;
|
||||
const char CUSTOM_BPROP_NAME[] = "bprop";
|
||||
const char STAGE_NAME[] = "stage";
|
||||
const char STAGE_NAME[] = "pipeline_stage";
|
||||
|
||||
// define the Namespace name
|
||||
const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // for ast type namespace
|
||||
|
|
|
@ -478,6 +478,9 @@ class Receive(PrimitiveWithInfer):
|
|||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
self.group = group
|
||||
valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
|
||||
args = {"dtype": dtype}
|
||||
validator.check_scalar_or_tensor_types_same(args, valid_type, self.name)
|
||||
|
||||
def infer_shape(self, x_shape=None):
|
||||
return self.shape
|
||||
|
|
|
@ -77,7 +77,7 @@ class Net(nn.Cell):
|
|||
self.block = nn.CellList()
|
||||
for i in range(2):
|
||||
cell = MatMulCell(strategy1, strategy2, param)
|
||||
cell.stage = i
|
||||
cell.pipeline_stage = i
|
||||
self.block.append(cell)
|
||||
|
||||
def construct(self, x):
|
||||
|
|
Loading…
Reference in New Issue