diff --git a/docs/api/api_python/nn/mindspore.nn.Cell.rst b/docs/api/api_python/nn/mindspore.nn.Cell.rst index 39205a5ff8f..0b4a8999cfc 100644 --- a/docs/api/api_python/nn/mindspore.nn.Cell.rst +++ b/docs/api/api_python/nn/mindspore.nn.Cell.rst @@ -95,6 +95,19 @@ 检查Cell中的网络参数名称是否重复。 + + .. py:method:: set_inputs(*inputs) + + 设置编译计算图所需的输入,输入需与实例中定义的输入一致。 + + **参数:** + + - **inputs** (tuple) – Cell的输入。 + + .. note:: + + 这是一个实验接口,可能会被更改或者删除。 + .. py:method:: compile(*inputs) 编译Cell为计算图,输入需与construct中定义的输入一致。 @@ -173,6 +186,18 @@ **返回:** String类型,网络的作用域。 + + .. py:method:: get_inputs() + + 返回编译计算图所设置的输入。 + + **返回:** + + Tuple类型,编译计算图所设置的输入。 + + .. note:: + + 这是一个实验接口,可能会被更改或者删除。 .. py:method:: infer_param_pipeline_stage() diff --git a/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py b/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py index 67d09351cf7..0c4f389231b 100644 --- a/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +++ b/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py @@ -146,14 +146,13 @@ def get_single_io_arg(info): check_arg_info(info) del info['valid'] del info['name'] + if 'range' in info: + for i in range(len(info['range'])): + if info['range'][i][1] == -1: + info['range'][i][1] = None res = info else: res = None - if 'range' in info: - for i in range(len(info['range'])): - if info['range'][i][1] == -1: - info['range'][i][1] = None - res = info return res