Add api mindspore.jit and mindspore.jit_class.
1、ms_function will be removed in a future version and replaced with jit. 2、ms_class will be removed in a future version and replaced with jit_class.
This commit is contained in:
parent
8f81c29530
commit
b98fc0021a
|
@ -23,6 +23,7 @@ https://arxiv.org/abs/1409.3215-
|
|||
https://arxiv.org/abs/1706.02515alpha
|
||||
https://discuss.tvm.ai/t/pool2d-gives-bad-output-for-integer-inputs/3377low
|
||||
http://www.apache.org/licenses/LICENSE-2.0////
|
||||
http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.21.1417
|
||||
https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/docs/inferbound/gatherbound.png:align
|
||||
https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/docs/inferbound/inferbound_traversal.png:align
|
||||
https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/relay/let_scope.png:align
|
||||
|
|
|
@ -546,14 +546,14 @@ MindSpore Numpy与MindSpore特性结合
|
|||
|
||||
mindspore.numpy能够充分利用MindSpore的强大功能,实现算子的自动微分,并使用图模式加速运算,帮助用户快速构建高效的模型。同时,MindSpore还支持多种后端设备,包括Ascend、GPU和CPU等,用户可以根据自己的需求灵活设置。以下提供了几种常用方法:
|
||||
|
||||
- `ms_function`: 将代码包裹进图模式,用于提高代码运行效率。
|
||||
- `jit` 装饰器: 将代码包裹进图模式,用于提高代码运行效率。
|
||||
- `GradOperation`: 用于自动求导。
|
||||
- `mindspore.set_context`: 用于设置运行模式和后端设备等。
|
||||
- `mindspore.nn.Cell`: 用于建立深度学习模型。
|
||||
|
||||
使用示例如下:
|
||||
|
||||
- ms_function使用示例
|
||||
- `jit` 装饰器使用示例
|
||||
|
||||
首先,以神经网络里经常使用到的矩阵乘与矩阵加算子为例:
|
||||
|
||||
|
@ -585,13 +585,13 @@ mindspore.numpy能够充分利用MindSpore的强大功能,实现算子的自
|
|||
[2816. 2816. 2816. 2816.]]
|
||||
|
||||
|
||||
对上述示例,我们可以借助 `ms_function` 将所有算子编译到一张静态图里以加快运行效率,示例如下:
|
||||
对上述示例,我们可以借助 `jit` 装饰器将所有算子编译到一张静态图里以加快运行效率,示例如下:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from mindspore import ms_function
|
||||
from mindspore import jit
|
||||
|
||||
forward_compiled = ms_function(forward)
|
||||
forward_compiled = jit(forward)
|
||||
print(forward(x, w1, b1, w2, b2, w3, b3))
|
||||
|
||||
运行结果如下:
|
||||
|
@ -602,11 +602,11 @@ mindspore.numpy能够充分利用MindSpore的强大功能,实现算子的自
|
|||
[2816. 2816. 2816. 2816.]]
|
||||
|
||||
.. note::
|
||||
目前静态图不支持在Python交互式模式下运行,并且有部分语法限制。`ms_function` 的更多信息可参考 `API ms_function <https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.ms_function.html>`_ 。
|
||||
目前静态图不支持在Python交互式模式下运行,并且有部分语法限制。
|
||||
|
||||
- GradOperation使用示例
|
||||
|
||||
`GradOperation` 可以实现自动求导。以下示例可以实现对上述没有用 `ms_function` 修饰的 `forward` 函数定义的计算求导。
|
||||
`GradOperation` 可以实现自动求导。以下示例可以实现对上述没有用 `jit` 修饰的 `forward` 函数定义的计算求导。
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -630,15 +630,15 @@ mindspore.numpy能够充分利用MindSpore的强大功能,实现算子的自
|
|||
...
|
||||
Tensor(shape=[4], dtype=Float32, value= [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]))
|
||||
|
||||
如果要对 `ms_function` 修饰的 `forward` 计算求导,需要提前使用 `set_context` 设置运算模式为图模式,示例如下:
|
||||
如果要对 `jit` 修饰的 `forward` 计算求导,需要提前使用 `set_context` 设置运算模式为图模式,示例如下:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from mindspore import ms_function, set_context, GRAPH_MODE
|
||||
from mindspore import jit, set_context, GRAPH_MODE
|
||||
|
||||
set_context(mode=GRAPH_MODE)
|
||||
grad_all = ops.composite.GradOperation(get_all=True)
|
||||
print(grad_all(ms_function(forward))(x, w1, b1, w2, b2, w3, b3))
|
||||
print(grad_all(jit(forward))(x, w1, b1, w2, b2, w3, b3))
|
||||
|
||||
运行结果如下:
|
||||
|
||||
|
|
|
@ -136,8 +136,10 @@ mindspore
|
|||
:toctree: mindspore
|
||||
|
||||
mindspore.JitConfig
|
||||
mindspore.ms_function
|
||||
mindspore.jit
|
||||
mindspore.jit_class
|
||||
mindspore.ms_class
|
||||
mindspore.ms_function
|
||||
mindspore.mutable
|
||||
|
||||
日志
|
||||
|
|
|
@ -8,11 +8,11 @@ mindspore.export
|
|||
.. note::
|
||||
- 当导出文件格式为AIR、ONNX时,单个Tensor的大小不能超过2GB。
|
||||
- 当 `file_name` 没有后缀时,系统会根据 `file_format` 自动添加后缀。
|
||||
- 现已支持将Mindspore function (ms_function) 导出成MINDIR格式文件。
|
||||
- 当导出ms_function时,函数内不能包含有类属性参与的计算。
|
||||
- 现已支持将 `jit` 修饰的函数导出成MINDIR格式文件。
|
||||
- 当导出 `jit` 修饰的函数时,函数内不能包含有类属性参与的计算。
|
||||
|
||||
参数:
|
||||
- **net** (Union[Cell, ms_function]) - MindSpore网络结构。
|
||||
- **net** (Union[Cell, function]) - MindSpore网络结构。
|
||||
- **inputs** (Union[Tensor, Dataset, List, Tuple, Number, Bool]) - 网络的输入,如果网络有多个输入,需要一同传入。当传入的类型为 `Dataset` 时,将会把数据预处理行为同步保存起来。需要手动调整batch的大小,当前仅支持获取 `Dataset` 的 `image` 列。
|
||||
- **file_name** (str) - 导出模型的文件名称。
|
||||
- **file_format** (str) - MindSpore目前支持导出"AIR","ONNX"和"MINDIR"格式的模型。
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
mindspore.jit
|
||||
=============
|
||||
|
||||
.. py:function:: mindspore.jit(fn=None, input_signature=None, hash_args=None, jit_config=None)
|
||||
|
||||
将Python函数编译为一张可调用的MindSpore图。
|
||||
|
||||
MindSpore可以在运行时对图进行优化。
|
||||
|
||||
参数:
|
||||
- **fn** (Function) - 要编译成图的Python函数。默认值:None。
|
||||
- **input_signature** (Tensor) - 用于表示输入参数的Tensor。Tensor的shape和dtype将作为函数的输入shape和dtype。默认值:None。
|
||||
- **hash_args** (Union[Object, List or Tuple of Objects]) - `fn` 里面用到的自由变量,比如外部函数或类对象,再次调用时若 `hash_args` 出现变化会触发重新编译。默认值:None。
|
||||
- **jit_config** (JitConfig) - 编译时所使用的JitConfig配置项,详细可参考 :class:`mindspore.JitConfig`。默认值:None。
|
||||
|
||||
.. note::
|
||||
- 如果指定了 `input_signature` ,则 `fn` 的每个输入都必须是Tensor。并且 `fn` 的输入参数将不会接受 `**kwargs` 参数。
|
||||
|
||||
返回:
|
||||
函数,如果 `fn` 不是None,则返回一个已经将输入 `fn` 编译成图的可执行函数;如果 `fn` 为None,则返回一个装饰器。当这个装饰器使用单个 `fn` 参数进行调用时,等价于 `fn` 不是None的场景。
|
|
@ -0,0 +1,18 @@
|
|||
mindspore.jit_class
|
||||
===================
|
||||
|
||||
.. py:function:: mindspore.jit_class(cls)
|
||||
|
||||
用户自定义类的类装饰器。
|
||||
|
||||
MindSpore可以通过jit_class识别用户定义的类,从而获取这些类的属性和方法。
|
||||
|
||||
参数:
|
||||
- **cls** (Class) - 用户自定义的类。
|
||||
|
||||
返回:
|
||||
类。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果 jit_class 用于非 class 类型或者 nn.Cell。
|
||||
- **AttributeError** - 如果调用了 jit_class 装饰的类的私有属性或魔术方法。
|
|
@ -7,11 +7,14 @@ mindspore.ms_class
|
|||
|
||||
MindSpore可以通过ms_class识别用户定义的类,从而获取这些类的属性和方法。
|
||||
|
||||
.. note::
|
||||
`ms_class` 将在未来版本中弃用和移除,请改用 :func:`mindspore.jit_class`。
|
||||
|
||||
参数:
|
||||
- **cls** (Class) - 用户自定义的类。
|
||||
|
||||
返回:
|
||||
带有 __ms_class__ 属性的类。
|
||||
类。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果 ms_class 用于非 class 类型或者 nn.Cell。
|
||||
|
|
|
@ -7,6 +7,9 @@ mindspore.ms_function
|
|||
|
||||
MindSpore可以在运行时对图进行优化。
|
||||
|
||||
.. note::
|
||||
`ms_function` 将在未来版本中弃用和移除,请改用 :func:`mindspore.jit`。
|
||||
|
||||
参数:
|
||||
- **fn** (Function) - 要编译成图的Python函数。默认值:None。
|
||||
- **input_signature** (Tensor) - 用于表示输入参数的Tensor。Tensor的shape和dtype将作为函数的输入shape和dtype。默认值:None。
|
||||
|
@ -18,4 +21,3 @@ mindspore.ms_function
|
|||
|
||||
返回:
|
||||
函数,如果 `fn` 不是None,则返回一个已经将输入 `fn` 编译成图的可执行函数;如果 `fn` 为None,则返回一个装饰器。当这个装饰器使用单个 `fn` 参数进行调用时,等价于 `fn` 不是None的场景。
|
||||
|
||||
|
|
|
@ -339,7 +339,7 @@
|
|||
设置Cell对象的反向hook函数。
|
||||
|
||||
.. note::
|
||||
- `register_backward_hook(hook_fn)` 在图模式下,或者在PyNative模式下使用 `ms_function` 功能时不起作用。
|
||||
- `register_backward_hook(hook_fn)` 在图模式下,或者在PyNative模式下使用 `jit` 装饰器功能时不起作用。
|
||||
- hook_fn必须有如下代码定义。 `cell_id` 是已注册Cell对象的信息,包括名称和ID。 `grad_input` 是反向传递给Cell对象的梯度。 `grad_output` 是Cell对象的反向输出梯度。用户可以在hook_fn中打印梯度数据或者返回新的输出梯度。
|
||||
- hook_fn返回新的输出梯度或者None:hook_fn(cell_id, grad_input, grad_output) -> New grad_output or None。
|
||||
- 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_backward_hook(hook_fn)` 。
|
||||
|
@ -359,7 +359,7 @@
|
|||
设置Cell对象的正向hook函数。
|
||||
|
||||
.. note::
|
||||
- `register_forward_hook(hook_fn)` 在图模式下,或者在PyNative模式下使用 `ms_function` 功能时不起作用。
|
||||
- `register_forward_hook(hook_fn)` 在图模式下,或者在PyNative模式下使用 `jit` 装饰器功能时不起作用。
|
||||
- hook_fn必须有如下代码定义。 `cell_id` 是已注册Cell对象的信息,包括名称和ID。 `inputs` 是网络正向传播时Cell对象的输入数据。 `outputs` 是网络正向传播时Cell对象的输出数据。用户可以在hook_fn中打印数据或者返回新的输出数据。
|
||||
- hook_fn返回新的输出数据或者None:hook_fn(cell_id, inputs, outputs) -> New outputs or None。
|
||||
- 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_forward_hook(hook_fn)` 。
|
||||
|
@ -379,7 +379,7 @@
|
|||
设置Cell对象的正向pre_hook函数。
|
||||
|
||||
.. note::
|
||||
- `register_forward_pre_hook(hook_fn)` 在图模式下,或者在PyNative模式下使用 `ms_function` 功能时不起作用。
|
||||
- `register_forward_pre_hook(hook_fn)` 在图模式下,或者在PyNative模式下使用 `jit` 装饰器功能时不起作用。
|
||||
- hook_fn必须有如下代码定义。 `cell_id` 是已注册Cell对象的信息,包括名称和ID。 `inputs` 是网络正向传播时Cell对象的输入数据。用户可以在hook_fn中打印输入数据或者返回新的输入数据。
|
||||
- hook_fn返回新的输入数据或者None:hook_fn(cell_id, inputs) -> New inputs or None。
|
||||
- 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_forward_pre_hook(hook_fn)` 。
|
||||
|
|
|
@ -551,14 +551,14 @@ Interact With MindSpore Functions
|
|||
|
||||
Since `mindspore.numpy` directly wraps MindSpore tensors and operators, it has all the advantages and properties of MindSpore. In this section, we will briefly introduce how to employ MindSpore execution management and automatic differentiation in `mindspore.numpy` coding scenarios. These include:
|
||||
|
||||
- `ms_function`: for running codes in static graph mode for better efficiency.
|
||||
- `jit` decorator: for running codes in static graph mode for better efficiency.
|
||||
- `GradOperation`: for automatic gradient computation.
|
||||
- `mindspore.set_context`: for `mindspore.numpy` execution management.
|
||||
- `mindspore.nn.Cell`: for using `mindspore.numpy` interfaces in MindSpore Deep Learning Models.
|
||||
|
||||
The following are examples:
|
||||
|
||||
- Use ms_function to run code in static graph mode
|
||||
- Use `jit` decorator to run code in static graph mode
|
||||
|
||||
Let's first see an example consisted of matrix multiplication and bias add, which is a typical process in Neural Networks:
|
||||
|
||||
|
@ -590,13 +590,13 @@ The following are examples:
|
|||
[2816. 2816. 2816. 2816.]]
|
||||
|
||||
|
||||
In this function, MindSpore dispatches each computing kernel to device separately. However, with the help of `ms_function`, we can compile all operations into a single static computing graph.
|
||||
In this function, MindSpore dispatches each computing kernel to device separately. However, with the help of `jit` decorator, we can compile all operations into a single static computing graph.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from mindspore import ms_function
|
||||
from mindspore import jit
|
||||
|
||||
forward_compiled = ms_function(forward)
|
||||
forward_compiled = jit(forward)
|
||||
print(forward(x, w1, b1, w2, b2, w3, b3))
|
||||
|
||||
The result is as follows:
|
||||
|
@ -607,11 +607,11 @@ The following are examples:
|
|||
[2816. 2816. 2816. 2816.]]
|
||||
|
||||
.. note::
|
||||
Currently, static graph cannot run in Python interactive mode and not all python types can be passed into functions decorated with `ms_function`. For details about how to use `ms_function`, see `API ms_function <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.ms_function.html>`_ .
|
||||
Currently, static graph cannot run in Python interactive mode and not all python types can be passed into functions decorated with `jit`.
|
||||
|
||||
- Use GradOperation to compute deratives
|
||||
|
||||
`GradOperation` can be used to take deratives from normal functions and functions decorated with `ms_function`. Take the previous example:
|
||||
`GradOperation` can be used to take deratives from normal functions and functions decorated with `jit`. Take the previous example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -635,16 +635,16 @@ The following are examples:
|
|||
...
|
||||
Tensor(shape=[4], dtype=Float32, value= [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]))
|
||||
|
||||
To take the gradient of `ms_function` compiled functions, first we need to set the execution mode to static graph mode.
|
||||
To take the gradient of `jit` compiled functions, first we need to set the execution mode to static graph mode.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from mindspore import ms_function, set_context, GRAPH_MODE
|
||||
from mindspore import jit, set_context, GRAPH_MODE
|
||||
|
||||
set_context(mode=GRAPH_MODE)
|
||||
grad_all = ops.composite.GradOperation(get_all=True)
|
||||
print(grad_all(ms_function(forward))(x, w1, b1, w2, b2, w3, b3))
|
||||
print(grad_all(jit(forward))(x, w1, b1, w2, b2, w3, b3))
|
||||
|
||||
The result is as follows:
|
||||
|
||||
|
|
|
@ -246,8 +246,10 @@ JIT
|
|||
:template: classtemplate.rst
|
||||
|
||||
mindspore.JitConfig
|
||||
mindspore.ms_function
|
||||
mindspore.jit
|
||||
mindspore.jit_class
|
||||
mindspore.ms_class
|
||||
mindspore.ms_function
|
||||
mindspore.mutable
|
||||
|
||||
Log
|
||||
|
|
|
@ -189,8 +189,8 @@ static bool HasSideEffectBackProp(const CNodePtr &cnode) {
|
|||
static AnfNodePtr SkipHookNodeInBackProp(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (IsPrimitiveCNode(node, prim::kPrimHookBackward) || IsPrimitiveCNode(node, prim::kPrimCellBackwardHook)) {
|
||||
MS_LOG(WARNING)
|
||||
<< "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
|
||||
MS_LOG(WARNING) << "Hook operation does not work in graph mode or functions decorated with 'jit', it will be "
|
||||
"eliminated during compilation.";
|
||||
auto output_cnode = node->cast_ptr<CNode>();
|
||||
if (output_cnode->size() - 1 == 1) {
|
||||
return output_cnode->input(1);
|
||||
|
@ -220,8 +220,8 @@ static AnfNodePtr SkipHookNodeInBackProp(const AnfNodePtr &node) {
|
|||
auto tuple_get_item = node->cast_ptr<CNode>();
|
||||
auto inp = tuple_get_item->input(1);
|
||||
if (IsPrimitiveCNode(inp, prim::kPrimHookBackward) || IsPrimitiveCNode(inp, prim::kPrimCellBackwardHook)) {
|
||||
MS_LOG(WARNING)
|
||||
<< "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
|
||||
MS_LOG(WARNING) << "Hook operation does not work in graph mode or functions decorated with 'jit', it will be "
|
||||
"eliminated during compilation.";
|
||||
constexpr size_t idx = 2;
|
||||
auto v_node = dyn_cast_ptr<ValueNode>(tuple_get_item->input(idx));
|
||||
MS_EXCEPTION_IF_NULL(v_node);
|
||||
|
@ -362,7 +362,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
|
|||
TraceGuard guard(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
|
||||
k_app = k_graph_->NewCNode(inputs);
|
||||
}
|
||||
// Run in pynative mode, when @ms_function is used.
|
||||
// Run in pynative mode, when @jit is used.
|
||||
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
auto pynative_exec = pynative::PyNativeExecutor::GetInstance();
|
||||
auto grad_exec = pynative_exec->grad_executor();
|
||||
|
|
|
@ -211,7 +211,7 @@ std::vector<AnfNodePtr> PynativeDFunctor::RunInputReplace(const FuncGraphPtr &bp
|
|||
}
|
||||
|
||||
void PynativeDFunctor::ReplaceEquivdout(const CNodePtr &k_app, const CNodePtr &cnode_morph) {
|
||||
// The process of replacing forward node only works in pynative mode, when @ms_function is used.
|
||||
// The process of replacing forward node only works in pynative mode, when @jit is used.
|
||||
MS_EXCEPTION_IF_NULL(cnode_morph);
|
||||
MS_LOG(DEBUG) << "Run replace for cnode morph: " << cnode_morph->DebugString(2);
|
||||
// Get forward node and its fprop graph, bprop graph.
|
||||
|
|
|
@ -68,8 +68,8 @@ class SpecialOpEliminater : public OptimizerCaller {
|
|||
new_node = (*eliminater)(optimizer, node);
|
||||
if (new_node != nullptr) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimHookBackward) || IsPrimitiveCNode(node, prim::kPrimCellBackwardHook)) {
|
||||
MS_LOG(WARNING)
|
||||
<< "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
|
||||
MS_LOG(WARNING) << "Hook operation does not work in graph mode or functions decorated with 'jit', it will be "
|
||||
"eliminated during compilation.";
|
||||
}
|
||||
return new_node;
|
||||
}
|
||||
|
|
|
@ -697,7 +697,7 @@ AnfNodePtr ExpandVmap(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr
|
|||
return NewValueNode(tf_fg);
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Currently, the first argument in F.vmap only supports Cell, Python defined "
|
||||
"function or @ms_function decorated function.";
|
||||
"function or @jit decorated function.";
|
||||
}
|
||||
|
||||
std::string GetShapeString(const ShapeVector &tensor_shape) {
|
||||
|
|
|
@ -921,7 +921,7 @@ bool CheckGraphOutputConstOrParameter(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
|
||||
bool EliminateForwardCNode(const ResourcePtr &resource) {
|
||||
// This function only works in Pynative mode. The func_graph is decorated by ms_function.
|
||||
// This function only works in Pynative mode. The func_graph is decorated with 'jit'.
|
||||
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ namespace mindspore {
|
|||
namespace pipeline {
|
||||
struct ExecutorInfo {
|
||||
FuncGraphPtr func_graph;
|
||||
// The grad graph of func_graph, it will create in PyNative mode when @ms_function is used.
|
||||
// The grad graph of func_graph, it will create in PyNative mode when @jit is used.
|
||||
FuncGraphPtr grad_graph;
|
||||
ResourcePtr resource;
|
||||
// The num of input data.
|
||||
|
|
|
@ -248,7 +248,7 @@ ValuePtr ConvertModuleNameSpace(const py::object &obj) {
|
|||
|
||||
ValuePtr ConvertMsClass(const py::object &obj) {
|
||||
MS_LOG(DEBUG) << "Converting ms class";
|
||||
// Convert class instance decorated with ms_class.
|
||||
// Convert class instance decorated with jit_class.
|
||||
if (py::hasattr(obj, PYTHON_PARSE_METHOD)) {
|
||||
MS_LOG(DEBUG) << "Convert obj to func graph.";
|
||||
FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
|
||||
|
|
|
@ -799,7 +799,7 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
|
|||
}
|
||||
|
||||
void GraphExecutorPy::InitCompileCacheInfo(const ResourcePtr &resource, const std::string &phase) {
|
||||
// The compilation cache only support for training cell or ms_function currently.
|
||||
// The compilation cache only support for training cell or functions decorated with 'jit' currently.
|
||||
// If enable compilation cache, it will get a non-empty dependent files list from python.
|
||||
if (compile_cache_dep_files_.empty()) {
|
||||
return;
|
||||
|
|
|
@ -129,7 +129,7 @@ class Resource : public ResourceBase {
|
|||
// We keep all arguments inputs here for subsequent procedure.
|
||||
std::vector<ValuePtr> arguments_;
|
||||
abstract::AbstractBasePtrList args_abs_;
|
||||
// The source obj to compile, usually a `Cell` or `ms_function` decorated function.
|
||||
// The source obj to compile, usually a `Cell` or `jit` decorated function.
|
||||
py::object source_input_;
|
||||
bool is_cleaned_;
|
||||
// The func_graph_ is loaded from mindir
|
||||
|
|
|
@ -1721,7 +1721,7 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
|
|||
}
|
||||
}
|
||||
|
||||
// Get attribute or method of class object decorated by ms_class.
|
||||
// Get attribute or method of class object decorated with 'jit_class'.
|
||||
auto class_value = GetMsClassObject(data_args);
|
||||
if (class_value != nullptr) {
|
||||
return GetEvaluatedValueForMsClassAttrOrMethod(args_spec_list, class_value, out_conf);
|
||||
|
@ -2012,7 +2012,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
if (py::isinstance<py::none>(obj)) {
|
||||
MS_LOG(EXCEPTION) << "Create python object `" << py::str(class_type)
|
||||
<< "` failed, only support to create \'Cell\', \'Primitive\' or "
|
||||
<< "user-defined Class decorated with \'ms_class\'.";
|
||||
<< "user-defined Class decorated with \'jit_class\'.";
|
||||
}
|
||||
|
||||
// Process the object.
|
||||
|
|
|
@ -365,7 +365,7 @@ EvalResultPtr ConvertClassToFunc(const CNodePtr &cnode, const AbstractBasePtr &a
|
|||
MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << abs->ToString() << ".";
|
||||
MS_LOG(ERROR) << "It's called at: " << cnode->DebugString();
|
||||
MS_EXCEPTION(ValueError) << "Can not call " << class_name << " to create python object in graph mode. "
|
||||
<< "Try using ms_class to decorate the class?";
|
||||
<< "Try using 'jit_class' to decorate the class?";
|
||||
}
|
||||
auto list_func_fg = parse::ParsePythonCode(py_fn);
|
||||
auto fg = cnode->func_graph();
|
||||
|
|
|
@ -503,7 +503,7 @@ void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::args &
|
|||
|
||||
py::object co_name = py::getattr(code_obj, "co_name");
|
||||
if (std::string(py::str(co_name)) == "staging_specialize") {
|
||||
MS_LOG(EXCEPTION) << "Decorating bprop with '@ms_function' is not supported.";
|
||||
MS_LOG(EXCEPTION) << "Decorating bprop with '@jit' is not supported.";
|
||||
}
|
||||
// Three parameters self, out and dout need to be excluded
|
||||
const size_t inputs_num = static_cast<size_t>(py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3);
|
||||
|
|
|
@ -373,8 +373,7 @@ BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
|
|||
py::object co_name = py::getattr(code_obj, "co_name");
|
||||
if (std::string(py::str(co_name)) == "staging_specialize") {
|
||||
py::object name_obj = py::getattr(elem.second, "__name__");
|
||||
MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj)
|
||||
<< " with '@ms_function' is not supported.";
|
||||
MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj) << " with '@jit' is not supported.";
|
||||
}
|
||||
SyncData(grad_output);
|
||||
py::tuple hook_fn_args = ConstructCellHookFnArgs(cell_id, iter->second, grad_output);
|
||||
|
@ -404,7 +403,7 @@ BaseRef PrimitivePy::RunVariableHookFunction(const py::tuple &py_args) const {
|
|||
py::object co_name = py::getattr(code_obj, "co_name");
|
||||
if (std::string(py::str(co_name)) == "staging_specialize") {
|
||||
py::object name_obj = py::getattr(elem.second, "__name__");
|
||||
MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj) << " with '@ms_function' is not supported.";
|
||||
MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj) << " with '@jit' is not supported.";
|
||||
}
|
||||
SyncData(grad_output);
|
||||
py::object ret = elem.second(py::make_tuple(grad_output));
|
||||
|
|
|
@ -461,7 +461,7 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
|
|||
// If the graph is a bprop graph, it should has a hash of the bprop directory.
|
||||
std::string bprop_hash_;
|
||||
|
||||
// If the graph is decorated by @ms_function and runs grad process in pynative mode,
|
||||
// If the graph is decorated with @jit and runs grad process in pynative mode,
|
||||
// forward nodes used in grad graph will be added to output for holding output values.
|
||||
bool modify_output_ = false;
|
||||
mindspore::HashSet<AnfNodePtr> used_forward_nodes_;
|
||||
|
|
|
@ -119,7 +119,7 @@ class ClassMemberNamespace(Namespace):
|
|||
except ValueError:
|
||||
raise UnboundLocalError(name)
|
||||
except KeyError:
|
||||
# Check if cls is user-defined class decorated with ms_class. If true, an exception will be thrown.
|
||||
# Check if cls is user-defined class decorated with jit_class. If true, an exception will be thrown.
|
||||
cls = d.__class__
|
||||
if hasattr(cls, '__ms_class__'):
|
||||
raise NotImplementedError(f"'{cls.__name__ }' object has no attribute or method: '{name}'.")
|
||||
|
|
|
@ -238,7 +238,7 @@ def resolve_symbol(namespace, symbol):
|
|||
if namespace.name == "numpy" and \
|
||||
isinstance(resolve_, (types.FunctionType, types.MethodType, types.ModuleType)):
|
||||
raise NotImplementedError("Mindspore does not support to use the numpy methods " \
|
||||
"within the construct() or @ms_function decorated function in graph mode.")
|
||||
"within the construct() or @jit decorated function in graph mode.")
|
||||
|
||||
# If need trope the obj
|
||||
if resolve_ in convert_object_map:
|
||||
|
@ -539,7 +539,7 @@ def is_class_type(cls):
|
|||
|
||||
|
||||
def get_ms_class_name(cls):
|
||||
"""Get the name of the class instance decorated by ms_class."""
|
||||
"""Get the name of the class instance decorated with jit_class."""
|
||||
if isinstance(cls, type):
|
||||
return cls.__name__
|
||||
return cls.__class__.__name__
|
||||
|
|
|
@ -21,7 +21,7 @@ from ._checkparam import Validator as validator
|
|||
from .common import dtype as mstype
|
||||
from . import context
|
||||
from . import ops
|
||||
from .common.api import ms_class
|
||||
from .common.api import jit_class
|
||||
from .common.parameter import Parameter
|
||||
from .common.tensor import Tensor
|
||||
from .train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager, FixedLossScaleManager
|
||||
|
@ -96,7 +96,7 @@ def all_finite(inputs):
|
|||
return ops.stack(outputs).all()
|
||||
|
||||
|
||||
@ms_class
|
||||
@jit_class
|
||||
class LossScaler(ABC):
|
||||
r"""
|
||||
Loss scaler abstract class when using mixed precision.
|
||||
|
|
|
@ -15,7 +15,8 @@
|
|||
"""Top-level reference to dtype of common module."""
|
||||
from __future__ import absolute_import
|
||||
from mindspore.common import dtype
|
||||
from mindspore.common.api import no_recursive, ms_function, ms_memory_recycle, ms_class, _convert_python_data
|
||||
from mindspore.common.api import no_recursive, ms_function, ms_memory_recycle, ms_class, _convert_python_data, \
|
||||
jit, jit_class
|
||||
from mindspore.common.dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \
|
||||
uint8, ubyte, uint16, ushort, uint32, uintc, uint64, uintp, float16, half, \
|
||||
float32, single, float64, double, bool_, float_, list_, tuple_, int_, \
|
||||
|
@ -58,7 +59,7 @@ __all__ = [
|
|||
|
||||
__all__.extend([
|
||||
"Tensor", "RowTensor", "SparseTensor", "COOTensor", "CSRTensor", # tensor
|
||||
"no_recursive", "ms_function", "ms_class", # api
|
||||
"no_recursive", "ms_function", "ms_class", 'jit', 'jit_class', # api
|
||||
"Parameter", "ParameterTuple", # parameter
|
||||
"dtype", "_convert_python_data",
|
||||
"set_seed", "get_seed", # random seed
|
||||
|
|
|
@ -254,7 +254,7 @@ class _MindsporeFunctionExecutor:
|
|||
Args:
|
||||
fn (Function): The root function to compile.
|
||||
input_signature (Function): User defines signature to verify input.
|
||||
ms_create_time(TimeStamp): The time ms_function created
|
||||
ms_create_time(TimeStamp): Time the function was created
|
||||
obj (Object): If function is a method, obj is the owner of function,
|
||||
else, obj is none.
|
||||
|
||||
|
@ -300,9 +300,9 @@ class _MindsporeFunctionExecutor:
|
|||
# Check whether hook function registered on Cell object.
|
||||
if self.obj and hasattr(self.obj, "_hook_fn_registered"):
|
||||
if self.obj._hook_fn_registered():
|
||||
logger.warning(f"For 'Cell', it's not support hook function when using ms_function. If you want to "
|
||||
f"use hook function, please use context.set_context to set pynative mode and remove "
|
||||
f"`ms_function`.")
|
||||
logger.warning(f"For 'Cell', it's not support hook function when using 'jit' decorator. "
|
||||
f"If you want to use hook function, please use context.set_context to set "
|
||||
f"pynative mode and remove 'jit' decorator.")
|
||||
# Chose dynamic shape tensors or actual input tensors as compile args.
|
||||
compile_args = self._generate_compile_args(args_list)
|
||||
# Restore the mutable attr for every arg.
|
||||
|
@ -409,7 +409,7 @@ class _MindsporeFunctionExecutor:
|
|||
self.input_signature = list(self.input_signature)
|
||||
dyn_shape = False
|
||||
for sig_args in self.input_signature:
|
||||
Validator.check_isinstance("args in `input_signature` of `ms_function`", sig_args, MetaTensor)
|
||||
Validator.check_isinstance("args in `input_signature` of `jit` decorator", sig_args, MetaTensor)
|
||||
if is_shape_unknown(sig_args.shape):
|
||||
dyn_shape = True
|
||||
if not dyn_shape:
|
||||
|
@ -465,6 +465,114 @@ def _get_ms_function_hash(hash_input):
|
|||
return _get_obj_id(hash_input)
|
||||
|
||||
|
||||
def jit(fn=None, input_signature=None, hash_args=None, jit_config=None):
|
||||
"""
|
||||
Create a callable MindSpore graph from a Python function.
|
||||
|
||||
This allows the MindSpore runtime to apply optimizations based on graph.
|
||||
|
||||
Note:
|
||||
If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
|
||||
will not accept `**kwargs`.
|
||||
|
||||
Args:
|
||||
fn (Function): The Python function that will be run as a graph. Default: None.
|
||||
input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
|
||||
will be supplied to this function. If input_signature is specified, each input to `fn` must be a `Tensor`.
|
||||
And the input parameters of `fn` cannot accept `**kwargs`. The shape and dtype of actual inputs should
|
||||
keep the same as input_signature. Otherwise, TypeError will be raised. Default: None.
|
||||
hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`,
|
||||
like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
|
||||
will trigger recompilation.
|
||||
jit_config (JitConfig): Jit config for compile. Default: None.
|
||||
|
||||
Returns:
|
||||
Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
|
||||
None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
|
||||
equal to the case when `fn` is not None.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore import ops
|
||||
>>> from mindspore import jit
|
||||
...
|
||||
>>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
|
||||
>>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
|
||||
...
|
||||
>>> # create a callable MindSpore graph by calling decorator @jit
|
||||
>>> def tensor_add(x, y):
|
||||
... z = x + y
|
||||
... return z
|
||||
...
|
||||
>>> tensor_add_graph = jit(fn=tensor_add)
|
||||
>>> out = tensor_add_graph(x, y)
|
||||
...
|
||||
>>> # create a callable MindSpore graph through decorator @jit
|
||||
>>> @jit
|
||||
... def tensor_add_with_dec(x, y):
|
||||
... z = x + y
|
||||
... return z
|
||||
...
|
||||
>>> out = tensor_add_with_dec(x, y)
|
||||
...
|
||||
>>> # create a callable MindSpore graph through decorator @jit with input_signature parameter
|
||||
>>> @jit(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
|
||||
... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
|
||||
... def tensor_add_with_sig(x, y):
|
||||
... z = x + y
|
||||
... return z
|
||||
...
|
||||
>>> out = tensor_add_with_sig(x, y)
|
||||
...
|
||||
... # Set hash_args as fn, otherwise cache of compiled `closure_fn` will not be reused.
|
||||
... # While fn differs during calling again, recompilation will be triggered.
|
||||
>>> def func(x):
|
||||
... return ops.exp(x)
|
||||
...
|
||||
>>> def closure_fn(x, fn):
|
||||
... @jit(hash_args=fn)
|
||||
... def inner_fn(a):
|
||||
... return fn(a)
|
||||
... return inner_fn(x)
|
||||
...
|
||||
>>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
|
||||
>>> for i in range(10):
|
||||
... closure_fn(inputs, func)
|
||||
"""
|
||||
|
||||
def wrap_mindspore(func):
|
||||
if hash_args:
|
||||
hash_obj = _get_ms_function_hash(hash_args)
|
||||
else:
|
||||
hash_obj = int(time.time() * 1e9)
|
||||
|
||||
@wraps(func)
|
||||
def staging_specialize(*args, **kwargs):
|
||||
if os.getenv("MS_JIT") == '0':
|
||||
return func(*args, **kwargs)
|
||||
|
||||
args = _handle_func_args(func, *args, **kwargs)
|
||||
process_obj = None
|
||||
if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
|
||||
process_obj = args[0]
|
||||
# only the function or cell instance wrapped by shard will fall into this branch
|
||||
if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARRALLEL_FUNC_NAME:
|
||||
process_obj = args[0]
|
||||
args = args[1:]
|
||||
out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj, jit_config)(*args)
|
||||
return out
|
||||
|
||||
return staging_specialize
|
||||
|
||||
if fn is not None:
|
||||
return wrap_mindspore(fn)
|
||||
return wrap_mindspore
|
||||
|
||||
|
||||
def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
|
||||
"""
|
||||
Create a callable MindSpore graph from a Python function.
|
||||
|
@ -472,6 +580,7 @@ def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
|
|||
This allows the MindSpore runtime to apply optimizations based on graph.
|
||||
|
||||
Note:
|
||||
`ms_function` will be deprecated and removed in a future version. Please use `jit` instead.
|
||||
If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
|
||||
will not accept `**kwargs`.
|
||||
|
||||
|
@ -544,33 +653,9 @@ def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
|
|||
... closure_fn(inputs, func)
|
||||
"""
|
||||
|
||||
def wrap_mindspore(func):
|
||||
if hash_args:
|
||||
hash_obj = _get_ms_function_hash(hash_args)
|
||||
else:
|
||||
hash_obj = int(time.time() * 1e9)
|
||||
|
||||
@wraps(func)
|
||||
def staging_specialize(*args, **kwargs):
|
||||
if os.getenv("MS_JIT") == '0':
|
||||
return func(*args, **kwargs)
|
||||
|
||||
args = _handle_func_args(func, *args, **kwargs)
|
||||
process_obj = None
|
||||
if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
|
||||
process_obj = args[0]
|
||||
# only the function or cell instance wrapped by shard will fall into this branch
|
||||
if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARRALLEL_FUNC_NAME:
|
||||
process_obj = args[0]
|
||||
args = args[1:]
|
||||
out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj, jit_config)(*args)
|
||||
return out
|
||||
|
||||
return staging_specialize
|
||||
|
||||
if fn is not None:
|
||||
return wrap_mindspore(fn)
|
||||
return wrap_mindspore
|
||||
logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. " \
|
||||
"Please use 'mindspore.jit' instead.")
|
||||
return jit(fn=fn, input_signature=input_signature, hash_args=hash_args, jit_config=jit_config)
|
||||
|
||||
|
||||
def _core(fn=None, **flags):
|
||||
|
@ -671,15 +756,18 @@ def ms_class(cls):
|
|||
|
||||
This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods.
|
||||
|
||||
Note:
|
||||
`ms_class` will be deprecated and removed in a future version. Please use `jit_class` instead.
|
||||
|
||||
Args:
|
||||
cls (Class): User-defined class.
|
||||
|
||||
Returns:
|
||||
Class with __ms_class__ attribute.
|
||||
Class.
|
||||
|
||||
Raises:
|
||||
TypeError: If ms_class is used for non-class types or nn.Cell.
|
||||
AttributeError: If the private attributes or magic methods of the class decorated by ms_class is called.
|
||||
AttributeError: If the private attributes or magic methods of the class decorated with ms_class is called.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
@ -711,6 +799,9 @@ def ms_class(cls):
|
|||
20
|
||||
"""
|
||||
|
||||
logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. " \
|
||||
"Please use 'mindspore.jit_class' instead.")
|
||||
|
||||
# Check if cls is of type class.
|
||||
if not inspect.isclass(cls):
|
||||
raise TypeError(f'Decorator ms_class can only be used for class type, but got {cls}.')
|
||||
|
@ -722,6 +813,62 @@ def ms_class(cls):
|
|||
return cls
|
||||
|
||||
|
||||
def jit_class(cls):
|
||||
"""
|
||||
Class decorator for user-defined classes.
|
||||
|
||||
This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods.
|
||||
|
||||
Args:
|
||||
cls (Class): User-defined class.
|
||||
|
||||
Returns:
|
||||
Class.
|
||||
|
||||
Raises:
|
||||
TypeError: If jit_class is used for non-class types or nn.Cell.
|
||||
AttributeError: If the private attributes or magic methods of the class decorated with jit_class is called.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.nn as nn
|
||||
>>> from mindspore import jit_class
|
||||
...
|
||||
>>> @jit_class
|
||||
... class UserDefinedNet:
|
||||
... def __init__(self):
|
||||
... self.value = 10
|
||||
...
|
||||
... def func(self, x):
|
||||
... return 2 * x
|
||||
...
|
||||
>>> class Net(nn.Cell):
|
||||
... def __init__(self):
|
||||
... super(Net, self).__init__()
|
||||
... self.net = UserDefinedNet()
|
||||
...
|
||||
... def construct(self, x):
|
||||
... out = self.net.value + self.net.func(x)
|
||||
... return out
|
||||
...
|
||||
>>> net = Net()
|
||||
>>> out = net(5)
|
||||
>>> print(out)
|
||||
20
|
||||
"""
|
||||
|
||||
# Check if cls is of type class.
|
||||
if not inspect.isclass(cls):
|
||||
raise TypeError(f'Decorator jit_class can only be used for class type, but got {cls}.')
|
||||
# Check if cls is nn.Cell.
|
||||
if issubclass(cls, ms.nn.Cell):
|
||||
raise TypeError(f"Decorator jit_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
|
||||
setattr(cls, '__ms_class__', True)
|
||||
return cls
|
||||
|
||||
|
||||
class _MsFunctionCompileContext:
|
||||
"""
|
||||
ms_function compile status manager
|
||||
|
@ -1401,4 +1548,4 @@ def ms_memory_recycle():
|
|||
_cell_graph_executor = _CellGraphExecutor()
|
||||
_pynative_executor = _PyNativeExecutor()
|
||||
|
||||
__all__ = ['ms_function', 'ms_memory_recycle', 'ms_class']
|
||||
__all__ = ['ms_function', 'ms_memory_recycle', 'ms_class', 'jit', 'jit_class']
|
||||
|
|
|
@ -1689,7 +1689,7 @@ class Cell(Cell_):
|
|||
Register forward pre hook function for Cell object.
|
||||
|
||||
Note:
|
||||
- The `register_forward_pre_hook(hook_fn)` does not work in graph mode or ms_function.
|
||||
- The `register_forward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
|
||||
- 'hook_fn' must be defined as the following code.
|
||||
`cell_id` is the information of registered Cell object, including name and ID. `inputs` is the forward
|
||||
input objects passed to the Cell. The 'hook_fn' can modify the forward input objects by returning new
|
||||
|
@ -1752,7 +1752,7 @@ class Cell(Cell_):
|
|||
raise TypeError(f"When using 'register_forward_pre_hook(hook_fn)', the type of 'hook_fn' must be python "
|
||||
f"function, but got {type(hook_fn)}.")
|
||||
if hook_fn.__code__.co_name == "staging_specialize":
|
||||
raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@ms_function' is not supported.")
|
||||
raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
|
||||
|
||||
self._enable_forward_pre_hook = True
|
||||
_pynative_executor.set_hook_changed(self)
|
||||
|
@ -1791,7 +1791,7 @@ class Cell(Cell_):
|
|||
Set the Cell forward hook function.
|
||||
|
||||
Note:
|
||||
- The `register_forward_hook(hook_fn)` does not work in graph mode or ms_function.
|
||||
- The `register_forward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
|
||||
- 'hook_fn' must be defined as the following code.
|
||||
`cell_id` is the information of registered Cell object, including name and ID. `inputs` is the forward
|
||||
input objects passed to the Cell. `output` is the forward output object of the Cell. The 'hook_fn' can
|
||||
|
@ -1856,7 +1856,7 @@ class Cell(Cell_):
|
|||
raise TypeError(f"When using 'register_forward_hook(hook_fn)', the type of 'hook_fn' must be python "
|
||||
f"function, but got {type(hook_fn)}.")
|
||||
if hook_fn.__code__.co_name == "staging_specialize":
|
||||
raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@ms_function' is not supported.")
|
||||
raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
|
||||
|
||||
self._enable_forward_hook = True
|
||||
_pynative_executor.set_hook_changed(self)
|
||||
|
@ -1893,7 +1893,7 @@ class Cell(Cell_):
|
|||
Register the backward hook function.
|
||||
|
||||
Note:
|
||||
- The `register_backward_hook(hook_fn)` does not work in graph mode or ms_function.
|
||||
- The `register_backward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
|
||||
- The 'hook_fn' must be defined as the following code.
|
||||
`cell_id` is the information of registered Cell object, including name and ID. `grad_input` is the
|
||||
gradient passed to the Cell. `grad_output` is the gradient computed and passed to the next Cell or
|
||||
|
|
|
@ -20,7 +20,7 @@ from mindspore.ops import composite as C
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.primitive import Primitive
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
from mindspore.common._decorator import deprecated
|
||||
|
||||
|
||||
|
@ -85,7 +85,7 @@ class Jvp(Cell):
|
|||
self.make_tuple = Primitive('MakeTuple')
|
||||
self.tuple_len = Primitive("tuple_len")
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def construct(self, *args):
|
||||
"""construct for jvp."""
|
||||
jvp_input = args[0:-1]
|
||||
|
@ -186,7 +186,7 @@ class Vjp(Cell):
|
|||
self.typeof = Primitive('typeof')
|
||||
self.tuple_len = Primitive("tuple_len")
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def construct(self, *args):
|
||||
front_input = args[0:-1]
|
||||
output = self.fn(*front_input)
|
||||
|
|
|
@ -17,7 +17,7 @@ from __future__ import absolute_import
|
|||
|
||||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
from mindspore.nn.optim.optimizer import opt_init_args_register
|
||||
|
||||
|
@ -194,7 +194,7 @@ class Adagrad(Optimizer):
|
|||
self.accum = self._parameters.clone(prefix="accum", init=accum)
|
||||
self.opt = P.ApplyAdagrad(update_slots=update_slots)
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def construct(self, grads):
|
||||
params = self._parameters
|
||||
accum = self.accum
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore import context
|
|||
from mindspore.common import dtype as mstype
|
||||
from mindspore.log import logging
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
|
@ -404,7 +404,7 @@ class AdaFactor(Optimizer):
|
|||
"""
|
||||
return False
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def construct(self, gradients):
|
||||
gradients = self.flatten_gradients(gradients)
|
||||
lr = self.get_lr()
|
||||
|
|
|
@ -20,7 +20,7 @@ import numpy as np
|
|||
from mindspore import context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
|
@ -712,7 +712,7 @@ class Adam(Optimizer):
|
|||
|
||||
self._init_distributed_opts(use_locking, use_nesterov)
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def construct(self, gradients):
|
||||
params = self._parameters
|
||||
moment1 = self.moment1
|
||||
|
@ -970,7 +970,7 @@ class AdamWeightDecay(Optimizer):
|
|||
else:
|
||||
self.use_fused_opt = False
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def construct(self, gradients):
|
||||
gradients = self.flatten_gradients(gradients)
|
||||
weight_decay = self.get_weight_decay()
|
||||
|
@ -1185,7 +1185,7 @@ class AdamOffload(Optimizer):
|
|||
self.opt = P.AdamNoUpdateParam(use_locking, use_nesterov)
|
||||
self.opt.set_device("CPU")
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def construct(self, gradients):
|
||||
params = self._parameters
|
||||
moment1 = self.moment1
|
||||
|
|
|
@ -17,7 +17,7 @@ from __future__ import absolute_import
|
|||
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
|
@ -196,7 +196,7 @@ class AdaMax(Optimizer):
|
|||
|
||||
self.opt = P.ApplyAdaMax()
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def construct(self, gradients):
|
||||
gradients = self.flatten_gradients(gradients)
|
||||
gradients = self.decay_weight(gradients)
|
||||
|
|
|
@ -17,7 +17,7 @@ from __future__ import absolute_import
|
|||
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore
|
||||
|
@ -176,7 +176,7 @@ class ASGD(Optimizer):
|
|||
self.cast = P.Cast()
|
||||
self.squeeze = P.Squeeze()
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def construct(self, gradients):
|
||||
gradients = self.flatten_gradients(gradients)
|
||||
gradients = self.decay_weight(gradients)
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
from __future__ import absolute_import
|
||||
|
||||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
|
@ -276,7 +276,7 @@ class FTRL(Optimizer):
|
|||
|
||||
self._init_distributed_opts(use_locking, learning_rate, l1, l2, lr_power)
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def construct(self, grads):
|
||||
params = self._parameters
|
||||
moments = self.moments
|
||||
|
|
|
@ -22,7 +22,7 @@ from mindspore.ops import composite as C
|
|||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
|
@ -259,7 +259,7 @@ class Lamb(Optimizer):
|
|||
self.moments2 = self.params.clone(prefix="lamb_v", init='zeros')
|
||||
self.device_ascend = context.get_context("device_target") == "Ascend"
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def construct(self, gradients):
|
||||
weight_decay = self.get_weight_decay()
|
||||
lr = self.get_lr()
|
||||
|
|
|
@ -20,7 +20,7 @@ from mindspore.ops import composite as C
|
|||
from mindspore.ops import functional as F
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.common import Tensor, Parameter, dtype as mstype
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
from mindspore.nn.optim.optimizer import _grad_scale, Optimizer
|
||||
from mindspore.nn.optim.optimizer import opt_init_args_register
|
||||
|
||||
|
@ -171,7 +171,7 @@ class LARS(Optimizer):
|
|||
|
||||
return lr
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def construct(self, gradients):
|
||||
params = self.parameters
|
||||
gradients = self.flatten_gradients(gradients)
|
||||
|
|
|
@ -17,7 +17,7 @@ from __future__ import absolute_import
|
|||
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
|
@ -361,7 +361,7 @@ class LazyAdam(Optimizer):
|
|||
|
||||
self._init_distributed_opts(use_locking, use_nesterov)
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def construct(self, gradients):
|
||||
gradients = self.flatten_gradients(gradients)
|
||||
gradients = self.decay_weight(gradients)
|
||||
|
|
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
|||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
|
@ -208,7 +208,7 @@ class Momentum(Optimizer):
|
|||
self._get_distributed_optimizer_list("momentum", use_nesterov=self.use_nesterov)
|
||||
self.use_dist_optimizer = self._use_distibuted_optimizer()
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def construct(self, gradients):
|
||||
params = self.params
|
||||
moments = self.moments
|
||||
|
|
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
|||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore.common import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
from mindspore.nn.optim.optimizer import opt_init_args_register
|
||||
|
@ -197,7 +197,7 @@ class ProximalAdagrad(Optimizer):
|
|||
self.opt = P.ApplyProximalAdagrad(use_locking=use_locking)
|
||||
self.sparse_opt = P.SparseApplyProximalAdagrad(use_locking=use_locking)
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def construct(self, grads):
|
||||
params = self._parameters
|
||||
accum = self.accum
|
||||
|
|
|
@ -17,7 +17,7 @@ from __future__ import absolute_import
|
|||
|
||||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
from mindspore.nn.optim.optimizer import opt_init_args_register
|
||||
|
||||
|
@ -225,7 +225,7 @@ class RMSProp(Optimizer):
|
|||
self.epsilon = epsilon
|
||||
self.decay = decay
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def construct(self, gradients):
|
||||
params = self._parameters
|
||||
gradients = self.flatten_gradients(gradients)
|
||||
|
|
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
|||
from mindspore import ops
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
|
@ -191,7 +191,7 @@ class Rprop(Optimizer):
|
|||
self.select = P.Select()
|
||||
self.ones_like = P.OnesLike()
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def construct(self, gradients):
|
||||
gradients = self.flatten_gradients(gradients)
|
||||
gradients = self.decay_weight(gradients)
|
||||
|
|
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
|||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
|
@ -193,7 +193,7 @@ class SGD(Optimizer):
|
|||
self.accum = self._parameters.clone(prefix="accum", init='zeros')
|
||||
self.stat = self._parameters.clone(prefix="stat", init='ones')
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def construct(self, gradients):
|
||||
params = self._parameters
|
||||
accum = self.accum
|
||||
|
|
|
@ -25,7 +25,7 @@ from mindspore.ops.operations.comm_ops import AllReduce, AllGather
|
|||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
|
||||
|
||||
reduce_opt = C.MultitypeFuncGraph("reduce_opt")
|
||||
|
@ -411,7 +411,7 @@ class DistributedGradReducer(Cell):
|
|||
self.mode = context.get_context("mode")
|
||||
self.enable_tuple_broaden = True
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def construct(self, grads):
|
||||
"""
|
||||
Under certain circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
|
||||
|
|
|
@ -26,7 +26,7 @@ from mindspore.ops import functional as F
|
|||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
|
||||
_grad_scale = C.MultitypeFuncGraph("grad_scale")
|
||||
reciprocal = P.Reciprocal()
|
||||
|
@ -399,7 +399,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
|
|||
compute_input = F.depend(compute_input, clear_status)
|
||||
return status, compute_input
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def get_overflow_status(self, status, compute_output):
|
||||
"""
|
||||
Get floating-point overflow status.
|
||||
|
|
|
@ -30,7 +30,7 @@ from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFu
|
|||
ListClear_, ListReverse_, ListExtend_, ListCount_, DictClear_, DictHasKey_, DictUpdate_, \
|
||||
DictFromKeys_
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.api import ms_function, _pynative_executor, _wrap_func
|
||||
from mindspore.common.api import jit, _pynative_executor, _wrap_func
|
||||
from mindspore.common.api import _add_flags, _core
|
||||
from mindspore.ops.primitive import Primitive
|
||||
from mindspore.ops import signature as sig
|
||||
|
@ -340,21 +340,22 @@ class GradOperation(GradOperation_):
|
|||
if self.grad_fn is not None and self.fn == fn and self.weights_id == weights_id:
|
||||
return self.grad_fn
|
||||
grad_ = GradOperation(self.get_all, self.get_by_list, self.sens_param)
|
||||
# If calling Grad in GRAPH_MODE or calling Grad in ms_function, do grad in GRAPH_MODE
|
||||
# If calling Grad in GRAPH_MODE or calling Grad in functions decorated with 'jit', do grad in GRAPH_MODE
|
||||
# If calling Grad in pure PYNATIVE_MODE do grad in PYNATIVE_MODE
|
||||
# In pure PYNATIVE_MODE the out layer after_grad just used to set pynative flag for inner GradOperation.
|
||||
# In PYNATIVE_MODE calling Grad from ms_function, use the out layer after_grad do grad in GRAPH_MODE.
|
||||
# In PYNATIVE_MODE calling Grad from functions decorated with 'jit', use the out layer after_grad do
|
||||
# grad in GRAPH_MODE.
|
||||
if context.get_context("mode") == context.GRAPH_MODE:
|
||||
dynamic_shape_inputs = None
|
||||
if isinstance(fn, ms.nn.Cell):
|
||||
dynamic_shape_inputs = fn.get_inputs()
|
||||
fn.grad_ops_label = True
|
||||
if self.get_by_list:
|
||||
@ms_function(input_signature=dynamic_shape_inputs)
|
||||
@jit(input_signature=dynamic_shape_inputs)
|
||||
def after_grad(*args):
|
||||
return grad_(fn, weights)(*args)
|
||||
else:
|
||||
@ms_function(input_signature=dynamic_shape_inputs)
|
||||
@jit(input_signature=dynamic_shape_inputs)
|
||||
def after_grad(*args):
|
||||
return grad_(fn)(*args)
|
||||
elif self.pynative_:
|
||||
|
@ -368,7 +369,7 @@ class GradOperation(GradOperation_):
|
|||
return out
|
||||
else:
|
||||
grad_.pynative_ = True
|
||||
# after_grad of this branch can't use @ms_function, just directly call grad_
|
||||
# after_grad of this branch can't use @jit, just directly call grad_
|
||||
if self.get_by_list:
|
||||
def after_grad(*args, **kwargs):
|
||||
return grad_(fn, weights)(*args, **kwargs)
|
||||
|
@ -420,9 +421,9 @@ class _TaylorOperation(TaylorOperation_):
|
|||
return self.grad_fn
|
||||
taylor_grad_ = _TaylorOperation()
|
||||
|
||||
# If calling Grad in GRAPH_MODE or calling Grad in ms_function, do grad in GRAPH_MODE
|
||||
# If calling Grad in GRAPH_MODE or calling Grad in functions decorated with 'jit', do grad in GRAPH_MODE
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def after_taylor_grad(*args):
|
||||
return taylor_grad_(fn)(*args)
|
||||
|
||||
|
@ -483,25 +484,26 @@ class _Grad(GradOperation_):
|
|||
return res
|
||||
|
||||
grad_ = _Grad(self.get_by_list, self.sens_param, self.get_by_position, self.has_aux, self.get_value)
|
||||
# If calling Grad in GRAPH_MODE or calling Grad in ms_function, do grad in GRAPH_MODE
|
||||
# If calling Grad in GRAPH_MODE or calling Grad in functions decorated with 'jit', do grad in GRAPH_MODE
|
||||
# If calling Grad in pure PYNATIVE_MODE do grad in PYNATIVE_MODE
|
||||
# In pure PYNATIVE_MODE the out layer after_grad just used to set pynative flag for inner GradOperation.
|
||||
# In PYNATIVE_MODE calling Grad from ms_function, use the out layer after_grad do grad in GRAPH_MODE.
|
||||
# In PYNATIVE_MODE calling Grad from functions decorated with 'jit', use the out layer after_grad do
|
||||
# grad in GRAPH_MODE.
|
||||
if context.get_context("mode") == context.GRAPH_MODE:
|
||||
dynamic_shape_inputs = None
|
||||
if isinstance(fn, ms.nn.Cell):
|
||||
dynamic_shape_inputs = fn.get_inputs()
|
||||
if self.get_by_position:
|
||||
@ms_function(input_signature=dynamic_shape_inputs)
|
||||
@jit(input_signature=dynamic_shape_inputs)
|
||||
def after_grad(*args):
|
||||
return grad_(fn, weights, grad_position)(*args)
|
||||
else:
|
||||
if self.get_by_list:
|
||||
@ms_function(input_signature=dynamic_shape_inputs)
|
||||
@jit(input_signature=dynamic_shape_inputs)
|
||||
def after_grad(*args):
|
||||
return grad_(fn, weights)(*args)
|
||||
else:
|
||||
@ms_function(input_signature=dynamic_shape_inputs)
|
||||
@jit(input_signature=dynamic_shape_inputs)
|
||||
def after_grad(*args):
|
||||
return grad_(fn)(*args)
|
||||
elif self.pynative_:
|
||||
|
@ -522,7 +524,7 @@ class _Grad(GradOperation_):
|
|||
fn_ = fn
|
||||
if self.has_aux:
|
||||
fn_ = aux_fn
|
||||
# after_grad of this branch can't use @ms_function, just directly call grad_
|
||||
# after_grad of this branch can't use @jit, just directly call grad_
|
||||
if self.get_by_position:
|
||||
def after_grad(*args, **kwargs):
|
||||
return grad_(fn_, weights, grad_position)(*args, **kwargs)
|
||||
|
@ -589,7 +591,7 @@ class _Vmap(VmapOperation_):
|
|||
|
||||
vmap_ = self
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def after_vmap(*args):
|
||||
return vmap_(fn, in_axes, out_axes)(*args)
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
from __future__ import absolute_import
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
from mindspore.common import ms_function
|
||||
from mindspore.common import jit
|
||||
from mindspore.common import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.cell import Cell
|
||||
|
@ -647,7 +647,7 @@ def jvp(fn, inputs, v, has_aux=False):
|
|||
def grad_all(u, first_grad):
|
||||
return _grad_all(fn_)(*first_grad, u)
|
||||
|
||||
@ms_function(hash_args=fn_)
|
||||
@jit(hash_args=fn_)
|
||||
def _wrap_container(*arg):
|
||||
jvp_inputs = arg[1:]
|
||||
vectors = arg[0]
|
||||
|
@ -730,7 +730,7 @@ def linearize(fn, inputs):
|
|||
"""
|
||||
linearize_inner = _LinearizeInner()
|
||||
|
||||
@ms_function(hash_args=fn)
|
||||
@jit(hash_args=fn)
|
||||
def _wrap_container(*arg):
|
||||
args = arg[1:-1]
|
||||
vectors = arg[-1]
|
||||
|
@ -1006,7 +1006,7 @@ def jacfwd(fn, grad_position=0, has_aux=False):
|
|||
return _grad_all(aux_fn)(*first_grad, u)
|
||||
return _grad_all(fn)(*first_grad, u)
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def wrapped(*args):
|
||||
checked_grad_position = _check_grad_position(grad_position, len(args))
|
||||
primals, v, inputs_shape = _jacfwd_construct_v(args, checked_grad_position)
|
||||
|
@ -1188,7 +1188,7 @@ def jacrev(fn, grad_position=0, has_aux=False):
|
|||
res = outputs[0]
|
||||
return res
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def wrapped(*args):
|
||||
checked_grad_position = _check_grad_position(grad_position, len(args))
|
||||
outputs = fn(*args)
|
||||
|
|
|
@ -284,7 +284,7 @@ class InsertGradientOf(PrimitiveWithInfer):
|
|||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor, ops, ms_function
|
||||
>>> from mindspore import Tensor, ops, jit
|
||||
>>> a = Tensor(np.array([1.0]).astype(np.float32))
|
||||
>>> b = Tensor(np.array([0.2]).astype(np.float32))
|
||||
>>> def clip_gradient(dx):
|
||||
|
@ -306,7 +306,7 @@ class InsertGradientOf(PrimitiveWithInfer):
|
|||
... c = x * y
|
||||
... return c
|
||||
...
|
||||
... @ms_function
|
||||
... @jit
|
||||
... def f(x, y):
|
||||
... return clip_test(x, y)
|
||||
...
|
||||
|
|
|
@ -19,7 +19,7 @@ from __future__ import division
|
|||
from mindspore.nn.cell import Cell
|
||||
from mindspore.ops.operations.comm_ops import AllGather
|
||||
from mindspore.communication import GlobalComm
|
||||
from mindspore.common import ms_function
|
||||
from mindspore.common import jit
|
||||
|
||||
_ALLGATHER_CELL = None
|
||||
|
||||
|
@ -35,7 +35,7 @@ class AllGatherCell(Cell):
|
|||
self.allgather = AllGather(group)
|
||||
self.add_flags(skip_auto_parallel_compile=True)
|
||||
|
||||
@ms_function()
|
||||
@jit()
|
||||
def construct(self, x):
|
||||
x = self.allgather(x)
|
||||
|
||||
|
|
|
@ -86,7 +86,7 @@ class Shard(Shard_):
|
|||
def shard_fn(*args):
|
||||
args = (fn,) + args
|
||||
|
||||
@ms.common.ms_function(hash_args=fn)
|
||||
@ms.common.jit(hash_args=fn)
|
||||
def after_shard(*args):
|
||||
return shard_(fn, in_strategy, out_strategy, parameter_plan, device, level)(*args)
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ from functools import wraps
|
|||
import mindspore.ops as ops
|
||||
from mindspore import context
|
||||
from mindspore.common.dtype import pytype_to_dtype
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
|
||||
from mindspore.train.dataset_helper import _has_dynamic_shape, _check_inputs
|
||||
import mindspore.dataset as ds
|
||||
|
@ -104,7 +104,7 @@ def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
|
|||
if jit_config is None:
|
||||
dst_sink_fun = sink_fun
|
||||
else:
|
||||
dst_sink_fun = ms_function(sink_fun, jit_config=jit_config)
|
||||
dst_sink_fun = jit(sink_fun, jit_config=jit_config)
|
||||
dataset.__sink_fun__ = dst_sink_fun
|
||||
|
||||
return dataset.__sink_fun__
|
||||
|
@ -116,7 +116,7 @@ def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
|
|||
if jit_config is None:
|
||||
dst_sink_fun = sink_fun
|
||||
else:
|
||||
dst_sink_fun = ms_function(sink_fun, jit_config=jit_config)
|
||||
dst_sink_fun = jit(sink_fun, jit_config=jit_config)
|
||||
dataset.__sink_aux__.sink_funcs[key] = dst_sink_fun
|
||||
|
||||
return dst_sink_fun
|
||||
|
|
|
@ -1010,11 +1010,12 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|||
Note:
|
||||
1. When exporting AIR, ONNX format, the size of a single tensor can not exceed 2GB.
|
||||
2. When file_name does not have a suffix, the system will automatically add one according to the file_format.
|
||||
3. Mindspore functions (ms_function) export as mindir format is enabled.
|
||||
4. When export ms_function, the function should not involve class properties in calculations.
|
||||
3. Exporting functions decorated with 'jit' to mindir format is supported.
|
||||
4. When exporting a function decorated with 'jit', the function should not involve class properties in
|
||||
calculations.
|
||||
|
||||
Args:
|
||||
net (Union[Cell, ms_function]): MindSpore network.
|
||||
net (Union[Cell, function]): MindSpore network.
|
||||
inputs (Union[Tensor, Dataset, List, Tuple, Number, Bool]): It represents the inputs
|
||||
of the `net`, if the network has multiple inputs, set them together. While its type is Dataset,
|
||||
it represents the preprocess behavior of the `net`, data preprocess operations will be serialized.
|
||||
|
|
|
@ -21,7 +21,7 @@ import numpy as np
|
|||
|
||||
from mindspore import ParameterTuple
|
||||
from mindspore import nn, context
|
||||
from mindspore.common.api import _cell_graph_executor, ms_function
|
||||
from mindspore.common.api import _cell_graph_executor, jit
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
|
@ -53,7 +53,7 @@ def run_block(net, *inputs, rand_func=None, training=True):
|
|||
set_block_param_with_rand(net, rand_func)
|
||||
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
def func_pynative(*inputs):
|
||||
@ms_function
|
||||
@jit
|
||||
def _func_pynative(*inputs):
|
||||
return net(*inputs)
|
||||
|
||||
|
|
|
@ -17,14 +17,14 @@
|
|||
|
||||
from mindspore import context
|
||||
from mindspore.common import ParameterTuple
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.ops.composite.base import GradOperation
|
||||
|
||||
|
||||
class Bprop(Cell):
|
||||
"""
|
||||
The gradient wraper.
|
||||
The gradient wrapper.
|
||||
"""
|
||||
|
||||
def __init__(self, func, wrt_params, params, grad_op, sens):
|
||||
|
@ -90,7 +90,7 @@ def bprop(func, *inputs, grads_wrt_outputs=None, wrt: list = None, params: list
|
|||
|
||||
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
def func_pynative(*inputs):
|
||||
@ms_function
|
||||
@jit
|
||||
def _func_pynative(*inputs):
|
||||
return grad(*inputs)
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ==============================================================================
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context, Tensor, Parameter, ms_function
|
||||
from mindspore import context, Tensor, Parameter, jit
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import functional as F
|
||||
|
@ -44,7 +44,7 @@ def test_monad_vmap():
|
|||
|
||||
vampfunc = F.vmap(AssignNet())
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def test_monad(a):
|
||||
c = Tensor([[1, 2], [3, 4], [5, 6]], mstype.int32)
|
||||
out = vampfunc(a)
|
||||
|
|
|
@ -76,7 +76,7 @@ dict_arg = {"x": tensor_x, "y": tensor_y, "u": tensor_u}
|
|||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_non_tensor_inputs(mode):
|
||||
"""
|
||||
Feature: Construct()/ms_function input type with back propagate.
|
||||
Feature: Input type with back propagate.
|
||||
Description: Normal input type without tensor.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
@ -116,7 +116,7 @@ class GradNet1(nn.Cell):
|
|||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
|
||||
def test_grad_first_input_net(mode):
|
||||
"""
|
||||
Feature: Construct()/ms_function input type with back propagate.
|
||||
Feature: Input type with back propagate.
|
||||
Description: Normal input type.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
@ -183,7 +183,7 @@ class GradCell(nn.Cell):
|
|||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_grad_parameter_input(mode):
|
||||
"""
|
||||
Feature: Construct()/ms_function input type with back propagate.
|
||||
Feature: Input type with back propagate.
|
||||
Description: Grad with Parameter as input type.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
@ -209,7 +209,7 @@ def test_grad_parameter_input(mode):
|
|||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
|
||||
def test_grad_parameter_as_input_and_fv(mode):
|
||||
"""
|
||||
Feature: Construct()/ms_function input type with back propagate.
|
||||
Feature: Input type with back propagate.
|
||||
Description: Grad with Parameters as input type and fv.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
@ -236,7 +236,7 @@ def test_grad_parameter_as_input_and_fv(mode):
|
|||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
|
||||
def test_grad_same_parameter_both_input_and_fv(mode):
|
||||
"""
|
||||
Feature: Construct()/ms_function input type with back propagate.
|
||||
Feature: Input type with back propagate.
|
||||
Description: Grad with the same Parameter used as input type and fv at the same time.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
@ -304,7 +304,7 @@ class GradCellWithParameterTuple(nn.Cell):
|
|||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_grad_parameter_as_input_and_fv2(mode):
|
||||
"""
|
||||
Feature: Construct()/ms_function input type with back propagate.
|
||||
Feature: Input type with back propagate.
|
||||
Description: Grad with Parameters as input type and fv. ParameterTuple as fv.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
import pytest
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore import nn
|
||||
from mindspore import Tensor, ms_function
|
||||
from mindspore import Tensor, jit
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore import context
|
||||
|
||||
|
@ -83,7 +83,7 @@ def test_single_while():
|
|||
Description: The else branches of while loops aren't supported.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while(x, y):
|
||||
while x > y:
|
||||
y += x
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
from mindspore import context, ms_function
|
||||
from mindspore import context, jit
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.ops import composite as C
|
||||
|
@ -267,7 +267,7 @@ def test_single_for():
|
|||
Description: The else branches of for loops aren't supported.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for(x, y):
|
||||
for _ in range(3):
|
||||
y += x
|
||||
|
@ -294,7 +294,7 @@ def test_single_for_with_not_iterable_object():
|
|||
Description: The else branches of for loops aren't supported.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for_with_not_iterable_object():
|
||||
ret = 0
|
||||
a = 1
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore import ms_function
|
||||
from mindspore import jit
|
||||
from mindspore import nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import composite as C
|
||||
|
@ -95,7 +95,7 @@ def test_if_in_for_dict():
|
|||
Description: Execute 'for x in xs' when xs is dictionary.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for(xs):
|
||||
result = 0
|
||||
ys = {'b': 0, 'g': 0}
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
import mindspore as ms
|
||||
from mindspore import context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import jit
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
@ -67,7 +67,7 @@ class TestClass(nn.Cell):
|
|||
out = self.test(state, init_global_obs)
|
||||
return out
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def test(self, state, init_global_obs):
|
||||
num_agent = self.zero
|
||||
while self.less(num_agent, 3):
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import mindspore
|
||||
from mindspore import context, nn, ops, Tensor, CSRTensor, Parameter, ms_function, mutable
|
||||
from mindspore import context, nn, ops, Tensor, CSRTensor, Parameter, jit, mutable
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
|
||||
|
@ -64,7 +64,7 @@ def test_repeat_control_arrow_for_stack_actor():
|
|||
assert out == result
|
||||
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def switch_op(x, y):
|
||||
z1 = y + 1
|
||||
z2 = Tensor(5, mindspore.int32)
|
||||
|
@ -87,7 +87,7 @@ def test_switch_op():
|
|||
assert out == 5
|
||||
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def switch_single_op(x, y, z):
|
||||
return F.switch(x, y, z)
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ import numpy as np
|
|||
from mindspore.nn import Cell
|
||||
from mindspore.common import Tensor, dtype, Parameter
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import ms_function
|
||||
from mindspore import jit
|
||||
import mindspore.ops.functional as F
|
||||
|
||||
|
||||
|
@ -78,7 +78,7 @@ def test_poly_delay_specialize():
|
|||
"""
|
||||
pow_ops = P.Pow()
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def poly_node_network(x, y):
|
||||
def function_h():
|
||||
pow_res = pow_ops(x, x)
|
||||
|
|
|
@ -14,11 +14,11 @@
|
|||
# ============================================================================
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor, ms_function
|
||||
from mindspore import Tensor, jit
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def hof(x):
|
||||
def f(x):
|
||||
return x + 3
|
||||
|
|
|
@ -17,7 +17,7 @@ import pytest
|
|||
|
||||
import mindspore.context as context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common import ms_function
|
||||
from mindspore.common import jit
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
|
||||
|
@ -32,7 +32,7 @@ c4 = Tensor([0], mstype.int32)
|
|||
c5 = Tensor([14], mstype.int32)
|
||||
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def simple_if(x, y):
|
||||
if x < y:
|
||||
x = x + 1
|
||||
|
@ -42,7 +42,7 @@ def simple_if(x, y):
|
|||
return x
|
||||
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def if_by_if(x, y):
|
||||
if x < y:
|
||||
x = x + 1
|
||||
|
@ -52,7 +52,7 @@ def if_by_if(x, y):
|
|||
return x
|
||||
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def if_in_if(x, y, z):
|
||||
out = c4
|
||||
if x < y:
|
||||
|
@ -65,7 +65,7 @@ def if_in_if(x, y, z):
|
|||
return out
|
||||
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def simple_while(x, y):
|
||||
y = y + 4
|
||||
while x < y:
|
||||
|
@ -74,7 +74,7 @@ def simple_while(x, y):
|
|||
return x
|
||||
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def while_by_while(x, y, z):
|
||||
while x < y:
|
||||
x = x + 1
|
||||
|
@ -85,7 +85,7 @@ def while_by_while(x, y, z):
|
|||
return x
|
||||
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def while_in_while(x, y, z):
|
||||
out = c4
|
||||
while x < y:
|
||||
|
@ -98,7 +98,7 @@ def while_in_while(x, y, z):
|
|||
return out
|
||||
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def while_by_while_in_while(x, y, z):
|
||||
out = c4
|
||||
while x < c2:
|
||||
|
@ -115,7 +115,7 @@ def while_by_while_in_while(x, y, z):
|
|||
return out
|
||||
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def while_in_while_in_while(x, y, z):
|
||||
out = c4
|
||||
while x < c2:
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor, ms_function
|
||||
from mindspore import Tensor, jit
|
||||
from mindspore.common import dtype as mstype, Parameter
|
||||
from mindspore.nn import Cell
|
||||
import pytest
|
||||
|
@ -31,7 +31,7 @@ def test_while_return_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -57,7 +57,7 @@ def test_if_return_in_while_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -86,7 +86,7 @@ def test_if_return_else_break_in_while_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -114,7 +114,7 @@ def test_if_return_else_return_in_while_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -142,7 +142,7 @@ def test_if_break_else_return_in_while_in_else_take_break():
|
|||
Expectation: take the break branch, success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -170,7 +170,7 @@ def test_if_break_else_return_in_while_in_else_take_return():
|
|||
Expectation: take the return branch, success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -198,7 +198,7 @@ def test_while_return_in_while_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -225,7 +225,7 @@ def test_if_return_in_while_in_while_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -255,7 +255,7 @@ def test_if_return_else_return_in_while_in_while_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -284,7 +284,7 @@ def test_while_return_after_if_else_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -315,7 +315,7 @@ def test_if_else_after_while_return_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -346,7 +346,7 @@ def test_if_return_after_if_else_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -377,7 +377,7 @@ def test_if_else_after_if_return_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -407,7 +407,7 @@ def test_while_return_in_else_after_if_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if x > y:
|
||||
x = x + y
|
||||
|
@ -438,7 +438,7 @@ def test_if_else_after_by_while_return_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -468,7 +468,7 @@ def test_if_return_in_else_after_if_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if x > y:
|
||||
x = x + y
|
||||
|
@ -499,7 +499,7 @@ def test_if_else_after_by_if_return_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -530,7 +530,7 @@ def test_if_else_in_if_while_return_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -560,7 +560,7 @@ def test_if_else_in_if_if_return_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -590,7 +590,7 @@ def test_for_return_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -616,7 +616,7 @@ def test_if_return_in_for_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -645,7 +645,7 @@ def test_if_return_else_break_in_for_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -673,7 +673,7 @@ def test_if_return_else_return_in_for_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -701,7 +701,7 @@ def test_for_return_in_for_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -728,7 +728,7 @@ def test_if_return_in_for_in_for_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -758,7 +758,7 @@ def test_if_return_else_return_in_for_in_for_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -787,7 +787,7 @@ def test_for_return_after_if_else_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -818,7 +818,7 @@ def test_if_else_after_for_return_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -849,7 +849,7 @@ def test_for_return_in_else_after_if_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if x > y:
|
||||
x = x + y
|
||||
|
@ -880,7 +880,7 @@ def test_if_else_after_by_for_return_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -910,7 +910,7 @@ def test_if_else_in_if_for_return_in_else():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, bias):
|
||||
if bias > y:
|
||||
y = x + y
|
||||
|
@ -940,7 +940,7 @@ def test_if_by_if_break_in_if_in_while():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, z):
|
||||
out = z
|
||||
while x < y:
|
||||
|
@ -970,7 +970,7 @@ def test_if_raise_raise():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, z):
|
||||
out = z
|
||||
if x >= y:
|
||||
|
@ -997,7 +997,7 @@ def test_if_raise_not_raise():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, z):
|
||||
out = z
|
||||
if x >= y:
|
||||
|
@ -1023,7 +1023,7 @@ def test_if_assert_success():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, z):
|
||||
out = z
|
||||
out = z
|
||||
|
@ -1051,7 +1051,7 @@ def test_if_assert_failure():
|
|||
Expectation: success
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def foo(x, y, z):
|
||||
out = z
|
||||
if x >= y:
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor, ms_function
|
||||
from mindspore import Tensor, jit
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore import ops
|
||||
import mindspore.nn as nn
|
||||
|
@ -23,7 +23,7 @@ ZERO = Tensor([0], mstype.int32)
|
|||
ONE = Tensor([1], mstype.int32)
|
||||
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def f(x):
|
||||
y = ZERO
|
||||
if x < 0:
|
||||
|
@ -38,7 +38,7 @@ def f(x):
|
|||
return z
|
||||
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def fr(x):
|
||||
y = ZERO
|
||||
if x < 0:
|
||||
|
@ -53,7 +53,7 @@ def fr(x):
|
|||
return z
|
||||
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def f_pythonerr(x):
|
||||
if x > 0:
|
||||
return f_pythonerr(x - 1)
|
||||
|
@ -69,7 +69,7 @@ def test_python_error():
|
|||
assert 'not defined' in str(e)
|
||||
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def f_recrusive_endless(x):
|
||||
if x > 0:
|
||||
return f_recrusive_endless(x - 1)
|
||||
|
@ -94,7 +94,7 @@ def test_endless():
|
|||
assert 'loop' in str(e)
|
||||
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def f_ok(x):
|
||||
if x > 0:
|
||||
return f_ok(x - 1) + 1
|
||||
|
|
|
@ -13,11 +13,11 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor, ms_function
|
||||
from mindspore import Tensor, jit
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def t1_while(x, y):
|
||||
y = y + 4
|
||||
while x < y:
|
||||
|
|
|
@ -16,7 +16,7 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context, ms_function
|
||||
from mindspore import context, jit
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.train.serialization import export, load
|
||||
|
||||
|
@ -59,7 +59,12 @@ def test_single_while():
|
|||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ms_function_while():
|
||||
def test_jit_function_while():
|
||||
"""
|
||||
Features: Control flow.
|
||||
Description: Test while in @jit decorated function.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
network = SingleWhileNet()
|
||||
|
||||
|
@ -76,7 +81,7 @@ def test_ms_function_while():
|
|||
loaded_net = nn.GraphCell(graph)
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def run_graph(x, y):
|
||||
outputs = loaded_net(x, y)
|
||||
return outputs
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
from mindspore import dtype as mstype
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
@ -32,7 +32,7 @@ def test_single_if_4():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if():
|
||||
x = Tensor(7).astype("int32")
|
||||
y = Tensor(0).astype("int32")
|
||||
|
@ -55,7 +55,7 @@ def test_single_if_two_cond():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if():
|
||||
x = Tensor(1)
|
||||
y = np.array(2)
|
||||
|
@ -76,7 +76,7 @@ def test_single_if_builtin_function_abs():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if():
|
||||
x = Tensor(-11, mstype.float32)
|
||||
if abs(x) > Tensor(np.array(2)):
|
||||
|
@ -96,7 +96,7 @@ def test_single_if_builtin_function_abs_min():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if():
|
||||
x = Tensor(-11, mstype.float32)
|
||||
y = Tensor(12, mstype.float32)
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import pytest
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor, ms_function, context, Parameter
|
||||
from mindspore import Tensor, jit, context, Parameter
|
||||
from mindspore.nn import Cell
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
@ -33,7 +33,7 @@ def test_single_while_1():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while():
|
||||
x = Tensor(1)
|
||||
while x < Tensor(7):
|
||||
|
@ -54,7 +54,7 @@ def test_single_while_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while():
|
||||
x = Tensor(7).astype("int32")
|
||||
y = Tensor(0).astype("int32")
|
||||
|
@ -76,7 +76,7 @@ def test_single_while_3():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while():
|
||||
x = Tensor(7).astype("int32")
|
||||
y = Tensor(0).astype("int32")
|
||||
|
@ -99,7 +99,7 @@ def test_single_while_two_cond_1():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while():
|
||||
x = Tensor(1)
|
||||
y = Tensor(8)
|
||||
|
@ -123,7 +123,7 @@ def test_single_while_two_cond_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while():
|
||||
x = Tensor(7).astype("int32")
|
||||
y = Tensor(0).astype("int32")
|
||||
|
@ -172,7 +172,7 @@ def test_single_while_numpy():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while():
|
||||
x = np.array([1, 2, 3, 4, 5])
|
||||
y = np.array([0, 2, 4, 6, 8])
|
||||
|
@ -193,7 +193,7 @@ def test_single_while_two_cond_3():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while():
|
||||
x = np.array([1, 2, 3, 4, 5])
|
||||
y = Tensor(1)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
from mindspore import dtype as mstype
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
@ -32,7 +32,7 @@ def test_single_for_1():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for():
|
||||
x = Tensor(7).astype("int32")
|
||||
y = Tensor(0).astype("int32")
|
||||
|
@ -54,7 +54,7 @@ def test_single_for_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for():
|
||||
x = Tensor(7).astype("int32")
|
||||
y = Tensor(0).astype("int32")
|
||||
|
@ -78,7 +78,7 @@ def test_single_for_zip():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for():
|
||||
tuple_x = (Tensor(1).astype("int32"), Tensor(3).astype("int32"), Tensor(5).astype("int32"))
|
||||
sum_x = Tensor(0).astype("int32")
|
||||
|
@ -101,7 +101,7 @@ def test_single_for_builtin_function_int():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for():
|
||||
x = np.array(1.1)
|
||||
for _ in range(3):
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow if in if scenario"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -31,7 +31,7 @@ def test_if_in_if_5():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_in_if():
|
||||
x = list([1, 2, 3, 4])
|
||||
if max(x) >= 4:
|
||||
|
@ -55,7 +55,7 @@ def test_if_else_in_if_else_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_in_if():
|
||||
x = Tensor(10)
|
||||
y = Tensor(7)
|
||||
|
@ -90,7 +90,7 @@ def test_if_in_if_multi_conds_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_in_if():
|
||||
x = Tensor(10)
|
||||
y = Tensor(2)
|
||||
|
@ -119,7 +119,7 @@ def test_if_in_if_4():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_in_if():
|
||||
x = np.array([1, 2, 3, 4, 5])
|
||||
y = x % 2
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow if in while scenario"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -31,7 +31,7 @@ def test_if_in_while_1():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_in_while():
|
||||
x = Tensor(1)
|
||||
y = Tensor(0)
|
||||
|
@ -55,7 +55,7 @@ def test_if_in_while_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_in_while():
|
||||
x = Tensor(1)
|
||||
while x < Tensor(5):
|
||||
|
@ -78,7 +78,7 @@ def test_if_in_while_3():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_in_while():
|
||||
x = Tensor(1)
|
||||
y = Tensor(0)
|
||||
|
@ -104,7 +104,7 @@ def test_if_in_while_4():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_in_while():
|
||||
x = Tensor(1)
|
||||
y = Tensor(0)
|
||||
|
@ -135,7 +135,7 @@ def test_if_in_while_numpy():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_in_while():
|
||||
x = np.array([1, 2])
|
||||
y = np.array([3, 2])
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -31,7 +31,7 @@ def test_if_in_for_tensor():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for():
|
||||
x = Tensor(7)
|
||||
y = Tensor(0)
|
||||
|
@ -54,7 +54,7 @@ def test_if_in_for_tensor_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for():
|
||||
x = Tensor(7)
|
||||
y = Tensor(0)
|
||||
|
@ -79,7 +79,7 @@ def test_if_in_for_tensor_3():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for():
|
||||
x = Tensor(7)
|
||||
y = Tensor(0)
|
||||
|
@ -103,7 +103,7 @@ def test_if_in_for_numpy_5():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for():
|
||||
x = np.array([1, 2, 3, 4])
|
||||
y = (Tensor(1), Tensor(3), Tensor(5))
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor, ms_function, context, nn, Parameter
|
||||
from mindspore import Tensor, jit, context, nn, Parameter
|
||||
import numpy as np
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
@ -33,7 +33,7 @@ def test_while_in_if_1():
|
|||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if():
|
||||
x = Tensor([1])
|
||||
if x > Tensor([0]):
|
||||
|
@ -58,7 +58,7 @@ def test_while_in_if_2():
|
|||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while():
|
||||
x = Tensor([6]).astype("int32")
|
||||
y = Tensor([0]).astype("int32")
|
||||
|
@ -84,7 +84,7 @@ def test_while_in_if_3():
|
|||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while():
|
||||
x = Tensor([7]).astype("int32")
|
||||
y = Tensor([0]).astype("int32")
|
||||
|
@ -111,7 +111,7 @@ def test_while_two_cond_in_if_1():
|
|||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while():
|
||||
x = Tensor([1])
|
||||
y = Tensor([8])
|
||||
|
@ -138,7 +138,7 @@ def test_while_two_cond_in_if_2():
|
|||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while():
|
||||
x = Tensor([7]).astype("int32")
|
||||
y = Tensor([0]).astype("int32")
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor, ms_function, context, nn, Parameter
|
||||
from mindspore import Tensor, jit, context, nn, Parameter
|
||||
import numpy as np
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
@ -33,7 +33,7 @@ def test_while_in_while_1():
|
|||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if():
|
||||
x = Tensor([1])
|
||||
y = Tensor([3])
|
||||
|
@ -60,7 +60,7 @@ def test_while_in_while_2():
|
|||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while():
|
||||
|
||||
x = Tensor([3]).astype("int32")
|
||||
|
@ -86,7 +86,7 @@ def test_while_in_while_3():
|
|||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while():
|
||||
x = Tensor([7]).astype("int32")
|
||||
y = Tensor([0]).astype("int32")
|
||||
|
@ -114,7 +114,7 @@ def test_while_in_while_with_two_cond_1():
|
|||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while():
|
||||
x = Tensor([1])
|
||||
y = Tensor([8])
|
||||
|
@ -141,7 +141,7 @@ def test_while_in_while_with_two_cond_2():
|
|||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while():
|
||||
x = Tensor([7]).astype("int32")
|
||||
y = Tensor([0]).astype("int32")
|
||||
|
@ -167,7 +167,7 @@ def test_while_in_while_with_two_cond_3():
|
|||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while():
|
||||
x = Tensor([7]).astype("int32")
|
||||
y = Tensor([0]).astype("int32")
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor, ms_function, context, nn, Parameter
|
||||
from mindspore import Tensor, jit, context, nn, Parameter
|
||||
import numpy as np
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
@ -33,7 +33,7 @@ def test_while_in_for_1():
|
|||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for():
|
||||
x = Tensor([7]).astype("int32")
|
||||
y = Tensor([0]).astype("int32")
|
||||
|
@ -61,7 +61,7 @@ def test_while_in_for_zip():
|
|||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for():
|
||||
tuple_x = (Tensor(1).astype("int32"), Tensor(3).astype("int32"), Tensor(5).astype("int32"))
|
||||
sum_x = Tensor([0]).astype("int32")
|
||||
|
@ -90,7 +90,7 @@ def test_while_in_for_numpy():
|
|||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for():
|
||||
x = np.array([1, 3, 5])
|
||||
y = np.array([0, 2, 4])
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context, nn
|
||||
from mindspore import Tensor, jit, context, nn
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
@ -33,7 +33,7 @@ def test_for_in_if_tensor():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for_in_if():
|
||||
x = Tensor(1)
|
||||
y = Tensor(0)
|
||||
|
@ -57,7 +57,7 @@ def test_for_in_if_tensor_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for_in_if():
|
||||
x = Tensor(1)
|
||||
y = Tensor(0)
|
||||
|
@ -115,7 +115,7 @@ def test_for_in_if_numpy():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for_in_if():
|
||||
x = np.array([1, 1, 1])
|
||||
y = list((4, 6, -2))
|
||||
|
@ -138,7 +138,7 @@ def test_for_in_if_isinstance_raise():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for_in_if(x):
|
||||
if isinstance(x, Tensor):
|
||||
print("before add:", x)
|
||||
|
@ -160,7 +160,7 @@ def test_for_in_if_dict_isinstance():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for_in_if():
|
||||
dict_x = {'a': 1, 'b': 2}
|
||||
res = 0
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context, nn
|
||||
from mindspore import Tensor, jit, context, nn
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.ops import operations as P
|
||||
|
@ -34,7 +34,7 @@ def test_for_in_while_tensor():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for_in_while():
|
||||
x = Tensor(1)
|
||||
y = Tensor(0)
|
||||
|
@ -59,7 +59,7 @@ def test_for_in_while_numpy_append():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for_in_while():
|
||||
x = np.array([[1, 2, 3], [3, 4, 5], [4, 5, 6]])
|
||||
y = Tensor(0)
|
||||
|
@ -119,7 +119,7 @@ def test_for_in_while_print():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for_in_while():
|
||||
x = Tensor(1)
|
||||
y = Tensor(0)
|
||||
|
@ -145,7 +145,7 @@ def test_for_in_while_round():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_for_in_while():
|
||||
x = 3.14159
|
||||
y = 3
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
""" test graph fallback control flow if after if scenario"""
|
||||
import pytest
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -30,7 +30,7 @@ def test_if_after_if_tensor():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_if():
|
||||
x = Tensor(1)
|
||||
y = Tensor(0)
|
||||
|
@ -54,7 +54,7 @@ def test_if_after_if_tensor_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_if():
|
||||
x = Tensor(1)
|
||||
y = Tensor(0)
|
||||
|
@ -84,7 +84,7 @@ def test_if_after_if_tensor_3():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_if(a):
|
||||
if a > 15:
|
||||
y = Tensor(1)
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
""" test graph fallback control flow if after while scenario"""
|
||||
import pytest
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -30,7 +30,7 @@ def test_if_after_while_tensor():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_while():
|
||||
x = Tensor(1)
|
||||
y = Tensor(0)
|
||||
|
@ -56,7 +56,7 @@ def test_if_after_while_tensor_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_while():
|
||||
x = Tensor(1)
|
||||
y = Tensor(0)
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -30,7 +30,7 @@ def test_if_after_if_in_if_tensor():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_if_in_if():
|
||||
x = Tensor(1)
|
||||
y = Tensor(10)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -31,7 +31,7 @@ def test_if_after_if_in_while_tensor():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_if_in_while():
|
||||
x = Tensor(1)
|
||||
y = Tensor(10)
|
||||
|
@ -60,7 +60,7 @@ def test_if_after_if_in_while_numpy():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_if_in_while():
|
||||
x = np.array([1, 2])
|
||||
y = np.array([3, 2])
|
||||
|
@ -88,7 +88,7 @@ def test_if_after_if_in_while_tensor_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_if_in_while():
|
||||
x = Tensor(5)
|
||||
y = Tensor(2)
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -30,7 +30,7 @@ def test_if_after_if_in_for_tensor_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_if_in_for():
|
||||
x = Tensor(5)
|
||||
y = Tensor(2)
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -30,7 +30,7 @@ def test_if_after_while_in_if_tensor():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_while_in_if():
|
||||
x = Tensor(1)
|
||||
y = Tensor(2)
|
||||
|
@ -57,7 +57,7 @@ def test_if_after_while_in_if_tensor_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_while_in_if():
|
||||
x = Tensor(1)
|
||||
y = Tensor(2)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -31,7 +31,7 @@ def test_if_after_while_in_while_tensor():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_while_in_while():
|
||||
x = Tensor(1)
|
||||
y = Tensor(2)
|
||||
|
@ -58,7 +58,7 @@ def test_if_after_while_in_while_tensor_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_while_in_while():
|
||||
x = Tensor(1)
|
||||
y = Tensor(2)
|
||||
|
@ -86,7 +86,7 @@ def test_if_after_while_in_while_numpy():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_while_in_while():
|
||||
x = np.array([1, 2, 3, 4])
|
||||
y = Tensor(5)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -31,7 +31,7 @@ def test_if_after_while_in_for_tensor():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_while_in_for():
|
||||
x = Tensor(1)
|
||||
y = Tensor(2)
|
||||
|
@ -58,7 +58,7 @@ def test_if_after_while_in_for_tensor_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_while_in_for():
|
||||
x = Tensor(1)
|
||||
y = Tensor(2)
|
||||
|
@ -86,7 +86,7 @@ def test_if_after_while_in_for_numpy():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_while_in_for():
|
||||
x = np.array([1, 2, 3, 4])
|
||||
y = Tensor(5)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -31,7 +31,7 @@ def test_if_after_for_in_if_tensor():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_for_in_if():
|
||||
x = Tensor([1])
|
||||
y = Tensor([2])
|
||||
|
@ -53,7 +53,7 @@ def test_if_after_for_in_if_numpy():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_for_in_if():
|
||||
x = np.array([1, 2])
|
||||
y = np.array([3, 4])
|
||||
|
@ -80,7 +80,7 @@ def test_if_after_for_in_if_tensor_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_for_in_if():
|
||||
x = Tensor([1])
|
||||
y = Tensor([2])
|
||||
|
@ -109,7 +109,7 @@ def test_if_after_for_in_if_numpy_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_for_in_for():
|
||||
x = np.array([3, 2])
|
||||
y = Tensor(np.array([0, 2, 4, 6, 8]))
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -31,7 +31,7 @@ def test_if_after_for_in_while_tensor():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_for_in_while():
|
||||
x = Tensor([1])
|
||||
y = Tensor([2])
|
||||
|
@ -58,7 +58,7 @@ def test_if_after_for_in_while_tensor_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_for_in_while():
|
||||
x = Tensor([1])
|
||||
y = Tensor([2])
|
||||
|
@ -85,7 +85,7 @@ def test_if_after_for_in_while_numpy_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_for_in_while():
|
||||
x = np.array([5, 4, 3, 2, 1])
|
||||
y = (Tensor(1), Tensor(3), Tensor(5))
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -30,7 +30,7 @@ def test_if_after_for_in_for_numpy():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_if_after_for_in_for():
|
||||
x = np.array([3, 2])
|
||||
y = Tensor(np.array([0, 2, 4, 6, 8]))
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -31,7 +31,7 @@ def test_while_after_if_tensor():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_if():
|
||||
x = Tensor([1])
|
||||
y = Tensor([2])
|
||||
|
@ -57,7 +57,7 @@ def test_while_after_if_tensor_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_if():
|
||||
x = Tensor([1])
|
||||
y = Tensor([2])
|
||||
|
@ -87,7 +87,7 @@ def test_while_after_if_numpy():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_if():
|
||||
x = np.array([3, 2])
|
||||
y = Tensor(np.array([3, 2]))
|
||||
|
@ -113,7 +113,7 @@ def test_while_after_if_numpy_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_if():
|
||||
x = np.array([3, 2])
|
||||
y = [1, 2, 3, 4]
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -31,7 +31,7 @@ def test_while_after_while_tensor():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_while():
|
||||
x = Tensor([1])
|
||||
y = Tensor([2])
|
||||
|
@ -57,7 +57,7 @@ def test_while_after_while_tensor_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_while():
|
||||
x = Tensor([1])
|
||||
y = Tensor([2])
|
||||
|
@ -78,7 +78,7 @@ def test_while_after_while_numpy_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_while():
|
||||
x = np.array([3, 2])
|
||||
y = [1, 2, 3, 4]
|
||||
|
@ -103,7 +103,7 @@ def test_while_after_while_numpy():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_while():
|
||||
x = [1, 2, 3, 4]
|
||||
y = Tensor([8])
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -31,7 +31,7 @@ def test_while_after_for_tensor_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_for():
|
||||
x = Tensor([1])
|
||||
y = Tensor([2])
|
||||
|
@ -53,7 +53,7 @@ def test_while_after_for_numpy_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_for():
|
||||
x = np.array([3, 2])
|
||||
y = [1, 2, 3, 4]
|
||||
|
@ -81,7 +81,7 @@ def test_while_after_for_tensor():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_for():
|
||||
x = Tensor([1])
|
||||
y = Tensor([2])
|
||||
|
@ -107,7 +107,7 @@ def test_while_after_for_numpy():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_for():
|
||||
x = [1, 2, 3, 4, 5]
|
||||
y = Tensor([3])
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -31,7 +31,7 @@ def test_while_after_if_in_if_tensor():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_if_in_if():
|
||||
x = Tensor([1])
|
||||
y = Tensor([2])
|
||||
|
@ -63,7 +63,7 @@ def test_while_after_if_in_if_tensor_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_if_in_if():
|
||||
x = Tensor([1])
|
||||
y = Tensor([2])
|
||||
|
@ -89,7 +89,7 @@ def test_while_after_if_in_if_numpy():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_if_in_if():
|
||||
x = np.array([1])
|
||||
y = np.array([10])
|
||||
|
@ -113,7 +113,7 @@ def test_while_after_if_in_if_numpy_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_if_in_if():
|
||||
x = np.array([1])
|
||||
y = np.array([10])
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -31,7 +31,7 @@ def test_while_after_if_in_while_tensor():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_if_in_while():
|
||||
x = Tensor([1])
|
||||
y = Tensor([2])
|
||||
|
@ -64,7 +64,7 @@ def test_while_after_if_in_while_numpy_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_if_in_while():
|
||||
x = np.array([1])
|
||||
y = np.array([10])
|
||||
|
@ -95,7 +95,7 @@ def test_while_after_if_in_while_tensor_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_if_in_while():
|
||||
x = Tensor([1])
|
||||
y = Tensor([2])
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -31,7 +31,7 @@ def test_while_after_if_in_for_tensor():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_if_in_for():
|
||||
x = Tensor([1])
|
||||
y = Tensor([2])
|
||||
|
@ -63,7 +63,7 @@ def test_while_after_if_in_for_tensor_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_if_in_for():
|
||||
x = Tensor([1])
|
||||
y = Tensor([2])
|
||||
|
@ -94,7 +94,7 @@ def test_while_after_if_in_for_numpy_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_if_in_for():
|
||||
x = np.array([1])
|
||||
y = np.array([10])
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -31,7 +31,7 @@ def test_while_after_while_in_if_tensor():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_while_in_if():
|
||||
x = Tensor([1])
|
||||
y = Tensor([2])
|
||||
|
@ -60,7 +60,7 @@ def test_while_after_while_in_if_tensor_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_while_in_if():
|
||||
x = Tensor([3])
|
||||
y = Tensor([5])
|
||||
|
@ -95,7 +95,7 @@ def test_while_after_while_in_if_numpy_2():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_while_in_if():
|
||||
x = Tensor([1])
|
||||
y = Tensor([2])
|
||||
|
@ -124,7 +124,7 @@ def test_while_after_while_in_if_numpy():
|
|||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
@jit
|
||||
def control_flow_while_after_while_in_if():
|
||||
x = Tensor([1])
|
||||
y = Tensor([2])
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue