Add api mindspore.jit and mindspore.jit_class.

1、ms_function will be removed in a future version and replaced with jit.
2、ms_class will be removed in a future version and replaced with
jit_class.
This commit is contained in:
huangbingjian 2022-10-18 09:01:45 +08:00
parent 8f81c29530
commit b98fc0021a
459 changed files with 2304 additions and 2090 deletions

View File

@ -23,6 +23,7 @@ https://arxiv.org/abs/1409.3215-
https://arxiv.org/abs/1706.02515alpha
https://discuss.tvm.ai/t/pool2d-gives-bad-output-for-integer-inputs/3377low
http://www.apache.org/licenses/LICENSE-2.0////
http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.21.1417
https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/docs/inferbound/gatherbound.png:align
https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/docs/inferbound/inferbound_traversal.png:align
https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/relay/let_scope.png:align

View File

@ -546,14 +546,14 @@ MindSpore Numpy与MindSpore特性结合
mindspore.numpy能够充分利用MindSpore的强大功能实现算子的自动微分并使用图模式加速运算帮助用户快速构建高效的模型。同时MindSpore还支持多种后端设备包括Ascend、GPU和CPU等用户可以根据自己的需求灵活设置。以下提供了几种常用方法
- `ms_function`: 将代码包裹进图模式,用于提高代码运行效率。
- `jit` 装饰器: 将代码包裹进图模式,用于提高代码运行效率。
- `GradOperation`: 用于自动求导。
- `mindspore.set_context`: 用于设置运行模式和后端设备等。
- `mindspore.nn.Cell`: 用于建立深度学习模型。
使用示例如下:
- ms_function使用示例
- `jit` 装饰器使用示例
首先,以神经网络里经常使用到的矩阵乘与矩阵加算子为例:
@ -585,13 +585,13 @@ mindspore.numpy能够充分利用MindSpore的强大功能实现算子的自
[2816. 2816. 2816. 2816.]]
对上述示例,我们可以借助 `ms_function` 将所有算子编译到一张静态图里以加快运行效率,示例如下:
对上述示例,我们可以借助 `jit` 装饰器将所有算子编译到一张静态图里以加快运行效率,示例如下:
.. code-block:: python
from mindspore import ms_function
from mindspore import jit
forward_compiled = ms_function(forward)
forward_compiled = jit(forward)
print(forward(x, w1, b1, w2, b2, w3, b3))
运行结果如下:
@ -602,11 +602,11 @@ mindspore.numpy能够充分利用MindSpore的强大功能实现算子的自
[2816. 2816. 2816. 2816.]]
.. note::
目前静态图不支持在Python交互式模式下运行并且有部分语法限制。`ms_function` 的更多信息可参考 `API ms_function <https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.ms_function.html>`_
目前静态图不支持在Python交互式模式下运行并且有部分语法限制。
- GradOperation使用示例
`GradOperation` 可以实现自动求导。以下示例可以实现对上述没有用 `ms_function` 修饰的 `forward` 函数定义的计算求导。
`GradOperation` 可以实现自动求导。以下示例可以实现对上述没有用 `jit` 修饰的 `forward` 函数定义的计算求导。
.. code-block:: python
@ -630,15 +630,15 @@ mindspore.numpy能够充分利用MindSpore的强大功能实现算子的自
...
Tensor(shape=[4], dtype=Float32, value= [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]))
如果要对 `ms_function` 修饰的 `forward` 计算求导,需要提前使用 `set_context` 设置运算模式为图模式,示例如下:
如果要对 `jit` 修饰的 `forward` 计算求导,需要提前使用 `set_context` 设置运算模式为图模式,示例如下:
.. code-block:: python
from mindspore import ms_function, set_context, GRAPH_MODE
from mindspore import jit, set_context, GRAPH_MODE
set_context(mode=GRAPH_MODE)
grad_all = ops.composite.GradOperation(get_all=True)
print(grad_all(ms_function(forward))(x, w1, b1, w2, b2, w3, b3))
print(grad_all(jit(forward))(x, w1, b1, w2, b2, w3, b3))
运行结果如下:

View File

@ -136,8 +136,10 @@ mindspore
:toctree: mindspore
mindspore.JitConfig
mindspore.ms_function
mindspore.jit
mindspore.jit_class
mindspore.ms_class
mindspore.ms_function
mindspore.mutable
日志

View File

@ -8,11 +8,11 @@ mindspore.export
.. note::
- 当导出文件格式为AIR、ONNX时单个Tensor的大小不能超过2GB。
- 当 `file_name` 没有后缀时,系统会根据 `file_format` 自动添加后缀。
- 现已支持将Mindspore function (ms_function) 导出成MINDIR格式文件。
- 当导出ms_function时,函数内不能包含有类属性参与的计算。
- 现已支持将 `jit` 修饰的函数导出成MINDIR格式文件。
- 当导出 `jit` 修饰的函数时,函数内不能包含有类属性参与的计算。
参数:
- **net** (Union[Cell, ms_function]) - MindSpore网络结构。
- **net** (Union[Cell, function]) - MindSpore网络结构。
- **inputs** (Union[Tensor, Dataset, List, Tuple, Number, Bool]) - 网络的输入,如果网络有多个输入,需要一同传入。当传入的类型为 `Dataset`将会把数据预处理行为同步保存起来。需要手动调整batch的大小当前仅支持获取 `Dataset``image` 列。
- **file_name** (str) - 导出模型的文件名称。
- **file_format** (str) - MindSpore目前支持导出"AIR""ONNX"和"MINDIR"格式的模型。

View File

@ -0,0 +1,20 @@
mindspore.jit
=============
.. py:function:: mindspore.jit(fn=None, input_signature=None, hash_args=None, jit_config=None)
将Python函数编译为一张可调用的MindSpore图。
MindSpore可以在运行时对图进行优化。
参数:
- **fn** (Function) - 要编译成图的Python函数。默认值None。
- **input_signature** (Tensor) - 用于表示输入参数的Tensor。Tensor的shape和dtype将作为函数的输入shape和dtype。默认值None。
- **hash_args** (Union[Object, List or Tuple of Objects]) - `fn` 里面用到的自由变量,比如外部函数或类对象,再次调用时若 `hash_args` 出现变化会触发重新编译。默认值None。
- **jit_config** (JitConfig) - 编译时所使用的JitConfig配置项详细可参考 :class:`mindspore.JitConfig`。默认值None。
.. note::
- 如果指定了 `input_signature` ,则 `fn` 的每个输入都必须是Tensor。并且 `fn` 的输入参数将不会接受 `**kwargs` 参数。
返回:
函数,如果 `fn` 不是None则返回一个已经将输入 `fn` 编译成图的可执行函数;如果 `fn` 为None则返回一个装饰器。当这个装饰器使用单个 `fn` 参数进行调用时,等价于 `fn` 不是None的场景。

View File

@ -0,0 +1,18 @@
mindspore.jit_class
===================
.. py:function:: mindspore.jit_class(cls)
用户自定义类的类装饰器。
MindSpore可以通过jit_class识别用户定义的类从而获取这些类的属性和方法。
参数:
- **cls** (Class) - 用户自定义的类。
返回:
类。
异常:
- **TypeError** - 如果 jit_class 用于非 class 类型或者 nn.Cell。
- **AttributeError** - 如果调用了 jit_class 装饰的类的私有属性或魔术方法。

View File

@ -7,11 +7,14 @@ mindspore.ms_class
MindSpore可以通过ms_class识别用户定义的类从而获取这些类的属性和方法。
.. note::
`ms_class` 将在未来版本中弃用和移除,请改用 :func:`mindspore.jit_class`
参数:
- **cls** (Class) - 用户自定义的类。
返回:
带有 __ms_class__ 属性的类。
类。
异常:
- **TypeError** - 如果 ms_class 用于非 class 类型或者 nn.Cell。

View File

@ -7,6 +7,9 @@ mindspore.ms_function
MindSpore可以在运行时对图进行优化。
.. note::
`ms_function` 将在未来版本中弃用和移除,请改用 :func:`mindspore.jit`
参数:
- **fn** (Function) - 要编译成图的Python函数。默认值None。
- **input_signature** (Tensor) - 用于表示输入参数的Tensor。Tensor的shape和dtype将作为函数的输入shape和dtype。默认值None。
@ -18,4 +21,3 @@ mindspore.ms_function
返回:
函数,如果 `fn` 不是None则返回一个已经将输入 `fn` 编译成图的可执行函数;如果 `fn` 为None则返回一个装饰器。当这个装饰器使用单个 `fn` 参数进行调用时,等价于 `fn` 不是None的场景。

View File

@ -339,7 +339,7 @@
设置Cell对象的反向hook函数。
.. note::
- `register_backward_hook(hook_fn)` 在图模式下或者在PyNative模式下使用 `ms_function` 功能时不起作用。
- `register_backward_hook(hook_fn)` 在图模式下或者在PyNative模式下使用 `jit` 装饰器功能时不起作用。
- hook_fn必须有如下代码定义。 `cell_id` 是已注册Cell对象的信息包括名称和ID。 `grad_input` 是反向传递给Cell对象的梯度。 `grad_output` 是Cell对象的反向输出梯度。用户可以在hook_fn中打印梯度数据或者返回新的输出梯度。
- hook_fn返回新的输出梯度或者Nonehook_fn(cell_id, grad_input, grad_output) -> New grad_output or None。
- 为了避免脚本在切换到图模式时运行失败不建议在Cell对象的 `construct` 函数中调用 `register_backward_hook(hook_fn)`
@ -359,7 +359,7 @@
设置Cell对象的正向hook函数。
.. note::
- `register_forward_hook(hook_fn)` 在图模式下或者在PyNative模式下使用 `ms_function` 功能时不起作用。
- `register_forward_hook(hook_fn)` 在图模式下或者在PyNative模式下使用 `jit` 装饰器功能时不起作用。
- hook_fn必须有如下代码定义。 `cell_id` 是已注册Cell对象的信息包括名称和ID。 `inputs` 是网络正向传播时Cell对象的输入数据。 `outputs` 是网络正向传播时Cell对象的输出数据。用户可以在hook_fn中打印数据或者返回新的输出数据。
- hook_fn返回新的输出数据或者Nonehook_fn(cell_id, inputs, outputs) -> New outputs or None。
- 为了避免脚本在切换到图模式时运行失败不建议在Cell对象的 `construct` 函数中调用 `register_forward_hook(hook_fn)`
@ -379,7 +379,7 @@
设置Cell对象的正向pre_hook函数。
.. note::
- `register_forward_pre_hook(hook_fn)` 在图模式下或者在PyNative模式下使用 `ms_function` 功能时不起作用。
- `register_forward_pre_hook(hook_fn)` 在图模式下或者在PyNative模式下使用 `jit` 装饰器功能时不起作用。
- hook_fn必须有如下代码定义。 `cell_id` 是已注册Cell对象的信息包括名称和ID。 `inputs` 是网络正向传播时Cell对象的输入数据。用户可以在hook_fn中打印输入数据或者返回新的输入数据。
- hook_fn返回新的输入数据或者Nonehook_fn(cell_id, inputs) -> New inputs or None。
- 为了避免脚本在切换到图模式时运行失败不建议在Cell对象的 `construct` 函数中调用 `register_forward_pre_hook(hook_fn)`

View File

@ -551,14 +551,14 @@ Interact With MindSpore Functions
Since `mindspore.numpy` directly wraps MindSpore tensors and operators, it has all the advantages and properties of MindSpore. In this section, we will briefly introduce how to employ MindSpore execution management and automatic differentiation in `mindspore.numpy` coding scenarios. These include:
- `ms_function`: for running codes in static graph mode for better efficiency.
- `jit` decorator: for running codes in static graph mode for better efficiency.
- `GradOperation`: for automatic gradient computation.
- `mindspore.set_context`: for `mindspore.numpy` execution management.
- `mindspore.nn.Cell`: for using `mindspore.numpy` interfaces in MindSpore Deep Learning Models.
The following are examples:
- Use ms_function to run code in static graph mode
- Use `jit` decorator to run code in static graph mode
Let's first see an example consisted of matrix multiplication and bias add, which is a typical process in Neural Networks:
@ -590,13 +590,13 @@ The following are examples:
[2816. 2816. 2816. 2816.]]
In this function, MindSpore dispatches each computing kernel to device separately. However, with the help of `ms_function`, we can compile all operations into a single static computing graph.
In this function, MindSpore dispatches each computing kernel to device separately. However, with the help of `jit` decorator, we can compile all operations into a single static computing graph.
.. code-block:: python
from mindspore import ms_function
from mindspore import jit
forward_compiled = ms_function(forward)
forward_compiled = jit(forward)
print(forward(x, w1, b1, w2, b2, w3, b3))
The result is as follows:
@ -607,11 +607,11 @@ The following are examples:
[2816. 2816. 2816. 2816.]]
.. note::
Currently, static graph cannot run in Python interactive mode and not all python types can be passed into functions decorated with `ms_function`. For details about how to use `ms_function`, see `API ms_function <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.ms_function.html>`_ .
Currently, static graph cannot run in Python interactive mode and not all python types can be passed into functions decorated with `jit`.
- Use GradOperation to compute deratives
`GradOperation` can be used to take deratives from normal functions and functions decorated with `ms_function`. Take the previous example:
`GradOperation` can be used to take deratives from normal functions and functions decorated with `jit`. Take the previous example:
.. code-block:: python
@ -635,16 +635,16 @@ The following are examples:
...
Tensor(shape=[4], dtype=Float32, value= [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]))
To take the gradient of `ms_function` compiled functions, first we need to set the execution mode to static graph mode.
To take the gradient of `jit` compiled functions, first we need to set the execution mode to static graph mode.
.. code-block:: python
from mindspore import ms_function, set_context, GRAPH_MODE
from mindspore import jit, set_context, GRAPH_MODE
set_context(mode=GRAPH_MODE)
grad_all = ops.composite.GradOperation(get_all=True)
print(grad_all(ms_function(forward))(x, w1, b1, w2, b2, w3, b3))
print(grad_all(jit(forward))(x, w1, b1, w2, b2, w3, b3))
The result is as follows:

View File

@ -246,8 +246,10 @@ JIT
:template: classtemplate.rst
mindspore.JitConfig
mindspore.ms_function
mindspore.jit
mindspore.jit_class
mindspore.ms_class
mindspore.ms_function
mindspore.mutable
Log

View File

@ -189,8 +189,8 @@ static bool HasSideEffectBackProp(const CNodePtr &cnode) {
static AnfNodePtr SkipHookNodeInBackProp(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (IsPrimitiveCNode(node, prim::kPrimHookBackward) || IsPrimitiveCNode(node, prim::kPrimCellBackwardHook)) {
MS_LOG(WARNING)
<< "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
MS_LOG(WARNING) << "Hook operation does not work in graph mode or functions decorated with 'jit', it will be "
"eliminated during compilation.";
auto output_cnode = node->cast_ptr<CNode>();
if (output_cnode->size() - 1 == 1) {
return output_cnode->input(1);
@ -220,8 +220,8 @@ static AnfNodePtr SkipHookNodeInBackProp(const AnfNodePtr &node) {
auto tuple_get_item = node->cast_ptr<CNode>();
auto inp = tuple_get_item->input(1);
if (IsPrimitiveCNode(inp, prim::kPrimHookBackward) || IsPrimitiveCNode(inp, prim::kPrimCellBackwardHook)) {
MS_LOG(WARNING)
<< "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
MS_LOG(WARNING) << "Hook operation does not work in graph mode or functions decorated with 'jit', it will be "
"eliminated during compilation.";
constexpr size_t idx = 2;
auto v_node = dyn_cast_ptr<ValueNode>(tuple_get_item->input(idx));
MS_EXCEPTION_IF_NULL(v_node);
@ -362,7 +362,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
TraceGuard guard(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
k_app = k_graph_->NewCNode(inputs);
}
// Run in pynative mode, when @ms_function is used.
// Run in pynative mode, when @jit is used.
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
auto pynative_exec = pynative::PyNativeExecutor::GetInstance();
auto grad_exec = pynative_exec->grad_executor();

View File

@ -211,7 +211,7 @@ std::vector<AnfNodePtr> PynativeDFunctor::RunInputReplace(const FuncGraphPtr &bp
}
void PynativeDFunctor::ReplaceEquivdout(const CNodePtr &k_app, const CNodePtr &cnode_morph) {
// The process of replacing forward node only works in pynative mode, when @ms_function is used.
// The process of replacing forward node only works in pynative mode, when @jit is used.
MS_EXCEPTION_IF_NULL(cnode_morph);
MS_LOG(DEBUG) << "Run replace for cnode morph: " << cnode_morph->DebugString(2);
// Get forward node and its fprop graph, bprop graph.

View File

@ -68,8 +68,8 @@ class SpecialOpEliminater : public OptimizerCaller {
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
if (IsPrimitiveCNode(node, prim::kPrimHookBackward) || IsPrimitiveCNode(node, prim::kPrimCellBackwardHook)) {
MS_LOG(WARNING)
<< "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
MS_LOG(WARNING) << "Hook operation does not work in graph mode or functions decorated with 'jit', it will be "
"eliminated during compilation.";
}
return new_node;
}

View File

@ -697,7 +697,7 @@ AnfNodePtr ExpandVmap(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr
return NewValueNode(tf_fg);
}
MS_LOG(EXCEPTION) << "Currently, the first argument in F.vmap only supports Cell, Python defined "
"function or @ms_function decorated function.";
"function or @jit decorated function.";
}
std::string GetShapeString(const ShapeVector &tensor_shape) {

View File

@ -921,7 +921,7 @@ bool CheckGraphOutputConstOrParameter(const FuncGraphPtr &func_graph) {
}
bool EliminateForwardCNode(const ResourcePtr &resource) {
// This function only works in Pynative mode. The func_graph is decorated by ms_function.
// This function only works in Pynative mode. The func_graph is decorated with 'jit'.
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
return true;
}

View File

@ -29,7 +29,7 @@ namespace mindspore {
namespace pipeline {
struct ExecutorInfo {
FuncGraphPtr func_graph;
// The grad graph of func_graph, it will create in PyNative mode when @ms_function is used.
// The grad graph of func_graph, it will create in PyNative mode when @jit is used.
FuncGraphPtr grad_graph;
ResourcePtr resource;
// The num of input data.

View File

@ -248,7 +248,7 @@ ValuePtr ConvertModuleNameSpace(const py::object &obj) {
ValuePtr ConvertMsClass(const py::object &obj) {
MS_LOG(DEBUG) << "Converting ms class";
// Convert class instance decorated with ms_class.
// Convert class instance decorated with jit_class.
if (py::hasattr(obj, PYTHON_PARSE_METHOD)) {
MS_LOG(DEBUG) << "Convert obj to func graph.";
FuncGraphPtr func_graph = ConvertToFuncGraph(obj);

View File

@ -799,7 +799,7 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
}
void GraphExecutorPy::InitCompileCacheInfo(const ResourcePtr &resource, const std::string &phase) {
// The compilation cache only support for training cell or ms_function currently.
// The compilation cache only support for training cell or functions decorated with 'jit' currently.
// If enable compilation cache, it will get a non-empty dependent files list from python.
if (compile_cache_dep_files_.empty()) {
return;

View File

@ -129,7 +129,7 @@ class Resource : public ResourceBase {
// We keep all arguments inputs here for subsequent procedure.
std::vector<ValuePtr> arguments_;
abstract::AbstractBasePtrList args_abs_;
// The source obj to compile, usually a `Cell` or `ms_function` decorated function.
// The source obj to compile, usually a `Cell` or `jit` decorated function.
py::object source_input_;
bool is_cleaned_;
// The func_graph_ is loaded from mindir

View File

@ -1721,7 +1721,7 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
}
}
// Get attribute or method of class object decorated by ms_class.
// Get attribute or method of class object decorated with 'jit_class'.
auto class_value = GetMsClassObject(data_args);
if (class_value != nullptr) {
return GetEvaluatedValueForMsClassAttrOrMethod(args_spec_list, class_value, out_conf);
@ -2012,7 +2012,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
if (py::isinstance<py::none>(obj)) {
MS_LOG(EXCEPTION) << "Create python object `" << py::str(class_type)
<< "` failed, only support to create \'Cell\', \'Primitive\' or "
<< "user-defined Class decorated with \'ms_class\'.";
<< "user-defined Class decorated with \'jit_class\'.";
}
// Process the object.

View File

@ -365,7 +365,7 @@ EvalResultPtr ConvertClassToFunc(const CNodePtr &cnode, const AbstractBasePtr &a
MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << abs->ToString() << ".";
MS_LOG(ERROR) << "It's called at: " << cnode->DebugString();
MS_EXCEPTION(ValueError) << "Can not call " << class_name << " to create python object in graph mode. "
<< "Try using ms_class to decorate the class?";
<< "Try using 'jit_class' to decorate the class?";
}
auto list_func_fg = parse::ParsePythonCode(py_fn);
auto fg = cnode->func_graph();

View File

@ -503,7 +503,7 @@ void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::args &
py::object co_name = py::getattr(code_obj, "co_name");
if (std::string(py::str(co_name)) == "staging_specialize") {
MS_LOG(EXCEPTION) << "Decorating bprop with '@ms_function' is not supported.";
MS_LOG(EXCEPTION) << "Decorating bprop with '@jit' is not supported.";
}
// Three parameters self, out and dout need to be excluded
const size_t inputs_num = static_cast<size_t>(py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3);

View File

@ -373,8 +373,7 @@ BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
py::object co_name = py::getattr(code_obj, "co_name");
if (std::string(py::str(co_name)) == "staging_specialize") {
py::object name_obj = py::getattr(elem.second, "__name__");
MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj)
<< " with '@ms_function' is not supported.";
MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj) << " with '@jit' is not supported.";
}
SyncData(grad_output);
py::tuple hook_fn_args = ConstructCellHookFnArgs(cell_id, iter->second, grad_output);
@ -404,7 +403,7 @@ BaseRef PrimitivePy::RunVariableHookFunction(const py::tuple &py_args) const {
py::object co_name = py::getattr(code_obj, "co_name");
if (std::string(py::str(co_name)) == "staging_specialize") {
py::object name_obj = py::getattr(elem.second, "__name__");
MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj) << " with '@ms_function' is not supported.";
MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj) << " with '@jit' is not supported.";
}
SyncData(grad_output);
py::object ret = elem.second(py::make_tuple(grad_output));

View File

@ -461,7 +461,7 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
// If the graph is a bprop graph, it should has a hash of the bprop directory.
std::string bprop_hash_;
// If the graph is decorated by @ms_function and runs grad process in pynative mode,
// If the graph is decorated with @jit and runs grad process in pynative mode,
// forward nodes used in grad graph will be added to output for holding output values.
bool modify_output_ = false;
mindspore::HashSet<AnfNodePtr> used_forward_nodes_;

View File

@ -119,7 +119,7 @@ class ClassMemberNamespace(Namespace):
except ValueError:
raise UnboundLocalError(name)
except KeyError:
# Check if cls is user-defined class decorated with ms_class. If true, an exception will be thrown.
# Check if cls is user-defined class decorated with jit_class. If true, an exception will be thrown.
cls = d.__class__
if hasattr(cls, '__ms_class__'):
raise NotImplementedError(f"'{cls.__name__ }' object has no attribute or method: '{name}'.")

View File

@ -238,7 +238,7 @@ def resolve_symbol(namespace, symbol):
if namespace.name == "numpy" and \
isinstance(resolve_, (types.FunctionType, types.MethodType, types.ModuleType)):
raise NotImplementedError("Mindspore does not support to use the numpy methods " \
"within the construct() or @ms_function decorated function in graph mode.")
"within the construct() or @jit decorated function in graph mode.")
# If need trope the obj
if resolve_ in convert_object_map:
@ -539,7 +539,7 @@ def is_class_type(cls):
def get_ms_class_name(cls):
"""Get the name of the class instance decorated by ms_class."""
"""Get the name of the class instance decorated with jit_class."""
if isinstance(cls, type):
return cls.__name__
return cls.__class__.__name__

View File

@ -21,7 +21,7 @@ from ._checkparam import Validator as validator
from .common import dtype as mstype
from . import context
from . import ops
from .common.api import ms_class
from .common.api import jit_class
from .common.parameter import Parameter
from .common.tensor import Tensor
from .train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager, FixedLossScaleManager
@ -96,7 +96,7 @@ def all_finite(inputs):
return ops.stack(outputs).all()
@ms_class
@jit_class
class LossScaler(ABC):
r"""
Loss scaler abstract class when using mixed precision.

View File

@ -15,7 +15,8 @@
"""Top-level reference to dtype of common module."""
from __future__ import absolute_import
from mindspore.common import dtype
from mindspore.common.api import no_recursive, ms_function, ms_memory_recycle, ms_class, _convert_python_data
from mindspore.common.api import no_recursive, ms_function, ms_memory_recycle, ms_class, _convert_python_data, \
jit, jit_class
from mindspore.common.dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \
uint8, ubyte, uint16, ushort, uint32, uintc, uint64, uintp, float16, half, \
float32, single, float64, double, bool_, float_, list_, tuple_, int_, \
@ -58,7 +59,7 @@ __all__ = [
__all__.extend([
"Tensor", "RowTensor", "SparseTensor", "COOTensor", "CSRTensor", # tensor
"no_recursive", "ms_function", "ms_class", # api
"no_recursive", "ms_function", "ms_class", 'jit', 'jit_class', # api
"Parameter", "ParameterTuple", # parameter
"dtype", "_convert_python_data",
"set_seed", "get_seed", # random seed

View File

@ -254,7 +254,7 @@ class _MindsporeFunctionExecutor:
Args:
fn (Function): The root function to compile.
input_signature (Function): User defines signature to verify input.
ms_create_time(TimeStamp): The time ms_function created
ms_create_time(TimeStamp): Time the function was created
obj (Object): If function is a method, obj is the owner of function,
else, obj is none.
@ -300,9 +300,9 @@ class _MindsporeFunctionExecutor:
# Check whether hook function registered on Cell object.
if self.obj and hasattr(self.obj, "_hook_fn_registered"):
if self.obj._hook_fn_registered():
logger.warning(f"For 'Cell', it's not support hook function when using ms_function. If you want to "
f"use hook function, please use context.set_context to set pynative mode and remove "
f"`ms_function`.")
logger.warning(f"For 'Cell', it's not support hook function when using 'jit' decorator. "
f"If you want to use hook function, please use context.set_context to set "
f"pynative mode and remove 'jit' decorator.")
# Chose dynamic shape tensors or actual input tensors as compile args.
compile_args = self._generate_compile_args(args_list)
# Restore the mutable attr for every arg.
@ -409,7 +409,7 @@ class _MindsporeFunctionExecutor:
self.input_signature = list(self.input_signature)
dyn_shape = False
for sig_args in self.input_signature:
Validator.check_isinstance("args in `input_signature` of `ms_function`", sig_args, MetaTensor)
Validator.check_isinstance("args in `input_signature` of `jit` decorator", sig_args, MetaTensor)
if is_shape_unknown(sig_args.shape):
dyn_shape = True
if not dyn_shape:
@ -465,6 +465,114 @@ def _get_ms_function_hash(hash_input):
return _get_obj_id(hash_input)
def jit(fn=None, input_signature=None, hash_args=None, jit_config=None):
"""
Create a callable MindSpore graph from a Python function.
This allows the MindSpore runtime to apply optimizations based on graph.
Note:
If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
will not accept `**kwargs`.
Args:
fn (Function): The Python function that will be run as a graph. Default: None.
input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
will be supplied to this function. If input_signature is specified, each input to `fn` must be a `Tensor`.
And the input parameters of `fn` cannot accept `**kwargs`. The shape and dtype of actual inputs should
keep the same as input_signature. Otherwise, TypeError will be raised. Default: None.
hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`,
like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
will trigger recompilation.
jit_config (JitConfig): Jit config for compile. Default: None.
Returns:
Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
equal to the case when `fn` is not None.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore import ops
>>> from mindspore import jit
...
>>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
>>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
...
>>> # create a callable MindSpore graph by calling decorator @jit
>>> def tensor_add(x, y):
... z = x + y
... return z
...
>>> tensor_add_graph = jit(fn=tensor_add)
>>> out = tensor_add_graph(x, y)
...
>>> # create a callable MindSpore graph through decorator @jit
>>> @jit
... def tensor_add_with_dec(x, y):
... z = x + y
... return z
...
>>> out = tensor_add_with_dec(x, y)
...
>>> # create a callable MindSpore graph through decorator @jit with input_signature parameter
>>> @jit(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
... def tensor_add_with_sig(x, y):
... z = x + y
... return z
...
>>> out = tensor_add_with_sig(x, y)
...
... # Set hash_args as fn, otherwise cache of compiled `closure_fn` will not be reused.
... # While fn differs during calling again, recompilation will be triggered.
>>> def func(x):
... return ops.exp(x)
...
>>> def closure_fn(x, fn):
... @jit(hash_args=fn)
... def inner_fn(a):
... return fn(a)
... return inner_fn(x)
...
>>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
>>> for i in range(10):
... closure_fn(inputs, func)
"""
def wrap_mindspore(func):
if hash_args:
hash_obj = _get_ms_function_hash(hash_args)
else:
hash_obj = int(time.time() * 1e9)
@wraps(func)
def staging_specialize(*args, **kwargs):
if os.getenv("MS_JIT") == '0':
return func(*args, **kwargs)
args = _handle_func_args(func, *args, **kwargs)
process_obj = None
if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
process_obj = args[0]
# only the function or cell instance wrapped by shard will fall into this branch
if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARRALLEL_FUNC_NAME:
process_obj = args[0]
args = args[1:]
out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj, jit_config)(*args)
return out
return staging_specialize
if fn is not None:
return wrap_mindspore(fn)
return wrap_mindspore
def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
"""
Create a callable MindSpore graph from a Python function.
@ -472,6 +580,7 @@ def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
This allows the MindSpore runtime to apply optimizations based on graph.
Note:
`ms_function` will be deprecated and removed in a future version. Please use `jit` instead.
If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
will not accept `**kwargs`.
@ -544,33 +653,9 @@ def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
... closure_fn(inputs, func)
"""
def wrap_mindspore(func):
if hash_args:
hash_obj = _get_ms_function_hash(hash_args)
else:
hash_obj = int(time.time() * 1e9)
@wraps(func)
def staging_specialize(*args, **kwargs):
if os.getenv("MS_JIT") == '0':
return func(*args, **kwargs)
args = _handle_func_args(func, *args, **kwargs)
process_obj = None
if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
process_obj = args[0]
# only the function or cell instance wrapped by shard will fall into this branch
if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARRALLEL_FUNC_NAME:
process_obj = args[0]
args = args[1:]
out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj, jit_config)(*args)
return out
return staging_specialize
if fn is not None:
return wrap_mindspore(fn)
return wrap_mindspore
logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. " \
"Please use 'mindspore.jit' instead.")
return jit(fn=fn, input_signature=input_signature, hash_args=hash_args, jit_config=jit_config)
def _core(fn=None, **flags):
@ -671,15 +756,18 @@ def ms_class(cls):
This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods.
Note:
`ms_class` will be deprecated and removed in a future version. Please use `jit_class` instead.
Args:
cls (Class): User-defined class.
Returns:
Class with __ms_class__ attribute.
Class.
Raises:
TypeError: If ms_class is used for non-class types or nn.Cell.
AttributeError: If the private attributes or magic methods of the class decorated by ms_class is called.
AttributeError: If the private attributes or magic methods of the class decorated with ms_class is called.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -711,6 +799,9 @@ def ms_class(cls):
20
"""
logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. " \
"Please use 'mindspore.jit_class' instead.")
# Check if cls is of type class.
if not inspect.isclass(cls):
raise TypeError(f'Decorator ms_class can only be used for class type, but got {cls}.')
@ -722,6 +813,62 @@ def ms_class(cls):
return cls
def jit_class(cls):
"""
Class decorator for user-defined classes.
This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods.
Args:
cls (Class): User-defined class.
Returns:
Class.
Raises:
TypeError: If jit_class is used for non-class types or nn.Cell.
AttributeError: If the private attributes or magic methods of the class decorated with jit_class is called.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore.nn as nn
>>> from mindspore import jit_class
...
>>> @jit_class
... class UserDefinedNet:
... def __init__(self):
... self.value = 10
...
... def func(self, x):
... return 2 * x
...
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self.net = UserDefinedNet()
...
... def construct(self, x):
... out = self.net.value + self.net.func(x)
... return out
...
>>> net = Net()
>>> out = net(5)
>>> print(out)
20
"""
# Check if cls is of type class.
if not inspect.isclass(cls):
raise TypeError(f'Decorator jit_class can only be used for class type, but got {cls}.')
# Check if cls is nn.Cell.
if issubclass(cls, ms.nn.Cell):
raise TypeError(f"Decorator jit_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
setattr(cls, '__ms_class__', True)
return cls
class _MsFunctionCompileContext:
"""
ms_function compile status manager
@ -1401,4 +1548,4 @@ def ms_memory_recycle():
_cell_graph_executor = _CellGraphExecutor()
_pynative_executor = _PyNativeExecutor()
__all__ = ['ms_function', 'ms_memory_recycle', 'ms_class']
__all__ = ['ms_function', 'ms_memory_recycle', 'ms_class', 'jit', 'jit_class']

View File

@ -1689,7 +1689,7 @@ class Cell(Cell_):
Register forward pre hook function for Cell object.
Note:
- The `register_forward_pre_hook(hook_fn)` does not work in graph mode or ms_function.
- The `register_forward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
- 'hook_fn' must be defined as the following code.
`cell_id` is the information of registered Cell object, including name and ID. `inputs` is the forward
input objects passed to the Cell. The 'hook_fn' can modify the forward input objects by returning new
@ -1752,7 +1752,7 @@ class Cell(Cell_):
raise TypeError(f"When using 'register_forward_pre_hook(hook_fn)', the type of 'hook_fn' must be python "
f"function, but got {type(hook_fn)}.")
if hook_fn.__code__.co_name == "staging_specialize":
raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@ms_function' is not supported.")
raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
self._enable_forward_pre_hook = True
_pynative_executor.set_hook_changed(self)
@ -1791,7 +1791,7 @@ class Cell(Cell_):
Set the Cell forward hook function.
Note:
- The `register_forward_hook(hook_fn)` does not work in graph mode or ms_function.
- The `register_forward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
- 'hook_fn' must be defined as the following code.
`cell_id` is the information of registered Cell object, including name and ID. `inputs` is the forward
input objects passed to the Cell. `output` is the forward output object of the Cell. The 'hook_fn' can
@ -1856,7 +1856,7 @@ class Cell(Cell_):
raise TypeError(f"When using 'register_forward_hook(hook_fn)', the type of 'hook_fn' must be python "
f"function, but got {type(hook_fn)}.")
if hook_fn.__code__.co_name == "staging_specialize":
raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@ms_function' is not supported.")
raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
self._enable_forward_hook = True
_pynative_executor.set_hook_changed(self)
@ -1893,7 +1893,7 @@ class Cell(Cell_):
Register the backward hook function.
Note:
- The `register_backward_hook(hook_fn)` does not work in graph mode or ms_function.
- The `register_backward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
- The 'hook_fn' must be defined as the following code.
`cell_id` is the information of registered Cell object, including name and ID. `grad_input` is the
gradient passed to the Cell. `grad_output` is the gradient computed and passed to the next Cell or

View File

@ -20,7 +20,7 @@ from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops.primitive import Primitive
from mindspore.common import dtype as mstype
from mindspore.common.api import ms_function
from mindspore.common.api import jit
from mindspore.common._decorator import deprecated
@ -85,7 +85,7 @@ class Jvp(Cell):
self.make_tuple = Primitive('MakeTuple')
self.tuple_len = Primitive("tuple_len")
@ms_function
@jit
def construct(self, *args):
"""construct for jvp."""
jvp_input = args[0:-1]
@ -186,7 +186,7 @@ class Vjp(Cell):
self.typeof = Primitive('typeof')
self.tuple_len = Primitive("tuple_len")
@ms_function
@jit
def construct(self, *args):
front_input = args[0:-1]
output = self.fn(*front_input)

View File

@ -17,7 +17,7 @@ from __future__ import absolute_import
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore._checkparam import Validator as validator
from mindspore.common.api import ms_function
from mindspore.common.api import jit
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register
@ -194,7 +194,7 @@ class Adagrad(Optimizer):
self.accum = self._parameters.clone(prefix="accum", init=accum)
self.opt = P.ApplyAdagrad(update_slots=update_slots)
@ms_function
@jit
def construct(self, grads):
params = self._parameters
accum = self.accum

View File

@ -19,7 +19,7 @@ from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.log import logging
from mindspore.common.initializer import initializer
from mindspore.common.api import ms_function
from mindspore.common.api import jit
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
@ -404,7 +404,7 @@ class AdaFactor(Optimizer):
"""
return False
@ms_function
@jit
def construct(self, gradients):
gradients = self.flatten_gradients(gradients)
lr = self.get_lr()

View File

@ -20,7 +20,7 @@ import numpy as np
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.api import ms_function
from mindspore.common.api import jit
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
@ -712,7 +712,7 @@ class Adam(Optimizer):
self._init_distributed_opts(use_locking, use_nesterov)
@ms_function
@jit
def construct(self, gradients):
params = self._parameters
moment1 = self.moment1
@ -970,7 +970,7 @@ class AdamWeightDecay(Optimizer):
else:
self.use_fused_opt = False
@ms_function
@jit
def construct(self, gradients):
gradients = self.flatten_gradients(gradients)
weight_decay = self.get_weight_decay()
@ -1185,7 +1185,7 @@ class AdamOffload(Optimizer):
self.opt = P.AdamNoUpdateParam(use_locking, use_nesterov)
self.opt.set_device("CPU")
@ms_function
@jit
def construct(self, gradients):
params = self._parameters
moment1 = self.moment1

View File

@ -17,7 +17,7 @@ from __future__ import absolute_import
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.api import ms_function
from mindspore.common.api import jit
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
@ -196,7 +196,7 @@ class AdaMax(Optimizer):
self.opt = P.ApplyAdaMax()
@ms_function
@jit
def construct(self, gradients):
gradients = self.flatten_gradients(gradients)
gradients = self.decay_weight(gradients)

View File

@ -17,7 +17,7 @@ from __future__ import absolute_import
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.api import ms_function
from mindspore.common.api import jit
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
import mindspore
@ -176,7 +176,7 @@ class ASGD(Optimizer):
self.cast = P.Cast()
self.squeeze = P.Squeeze()
@ms_function
@jit
def construct(self, gradients):
gradients = self.flatten_gradients(gradients)
gradients = self.decay_weight(gradients)

View File

@ -16,7 +16,7 @@
from __future__ import absolute_import
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.api import ms_function
from mindspore.common.api import jit
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn.optim.optimizer import Optimizer
@ -276,7 +276,7 @@ class FTRL(Optimizer):
self._init_distributed_opts(use_locking, learning_rate, l1, l2, lr_power)
@ms_function
@jit
def construct(self, grads):
params = self._parameters
moments = self.moments

View File

@ -22,7 +22,7 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common.tensor import Tensor
from mindspore.common.api import ms_function
from mindspore.common.api import jit
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn.optim.optimizer import Optimizer
@ -259,7 +259,7 @@ class Lamb(Optimizer):
self.moments2 = self.params.clone(prefix="lamb_v", init='zeros')
self.device_ascend = context.get_context("device_target") == "Ascend"
@ms_function
@jit
def construct(self, gradients):
weight_decay = self.get_weight_decay()
lr = self.get_lr()

View File

@ -20,7 +20,7 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore._checkparam import Validator as validator
from mindspore.common import Tensor, Parameter, dtype as mstype
from mindspore.common.api import ms_function
from mindspore.common.api import jit
from mindspore.nn.optim.optimizer import _grad_scale, Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register
@ -171,7 +171,7 @@ class LARS(Optimizer):
return lr
@ms_function
@jit
def construct(self, gradients):
params = self.parameters
gradients = self.flatten_gradients(gradients)

View File

@ -17,7 +17,7 @@ from __future__ import absolute_import
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.api import ms_function
from mindspore.common.api import jit
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
@ -361,7 +361,7 @@ class LazyAdam(Optimizer):
self._init_distributed_opts(use_locking, use_nesterov)
@ms_function
@jit
def construct(self, gradients):
gradients = self.flatten_gradients(gradients)
gradients = self.decay_weight(gradients)

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.common.api import ms_function
from mindspore.common.api import jit
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator
from mindspore.nn.optim.optimizer import Optimizer
@ -208,7 +208,7 @@ class Momentum(Optimizer):
self._get_distributed_optimizer_list("momentum", use_nesterov=self.use_nesterov)
self.use_dist_optimizer = self._use_distibuted_optimizer()
@ms_function
@jit
def construct(self, gradients):
params = self.params
moments = self.moments

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common import Tensor
import mindspore.common.dtype as mstype
from mindspore.common.api import ms_function
from mindspore.common.api import jit
from mindspore._checkparam import Validator as validator
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register
@ -197,7 +197,7 @@ class ProximalAdagrad(Optimizer):
self.opt = P.ApplyProximalAdagrad(use_locking=use_locking)
self.sparse_opt = P.SparseApplyProximalAdagrad(use_locking=use_locking)
@ms_function
@jit
def construct(self, grads):
params = self._parameters
accum = self.accum

View File

@ -17,7 +17,7 @@ from __future__ import absolute_import
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore._checkparam import Validator as validator
from mindspore.common.api import ms_function
from mindspore.common.api import jit
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register
@ -225,7 +225,7 @@ class RMSProp(Optimizer):
self.epsilon = epsilon
self.decay = decay
@ms_function
@jit
def construct(self, gradients):
params = self._parameters
gradients = self.flatten_gradients(gradients)

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from mindspore import ops
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype
from mindspore.common.api import ms_function
from mindspore.common.api import jit
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn.optim.optimizer import Optimizer
@ -191,7 +191,7 @@ class Rprop(Optimizer):
self.select = P.Select()
self.ones_like = P.OnesLike()
@ms_function
@jit
def construct(self, gradients):
gradients = self.flatten_gradients(gradients)
gradients = self.decay_weight(gradients)

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.common.api import ms_function
from mindspore.common.api import jit
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore.nn.optim.optimizer import Optimizer
@ -193,7 +193,7 @@ class SGD(Optimizer):
self.accum = self._parameters.clone(prefix="accum", init='zeros')
self.stat = self._parameters.clone(prefix="stat", init='ones')
@ms_function
@jit
def construct(self, gradients):
params = self._parameters
accum = self.accum

View File

@ -25,7 +25,7 @@ from mindspore.ops.operations.comm_ops import AllReduce, AllGather
from mindspore.parallel._auto_parallel_context import auto_parallel_context
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.common.api import ms_function
from mindspore.common.api import jit
reduce_opt = C.MultitypeFuncGraph("reduce_opt")
@ -411,7 +411,7 @@ class DistributedGradReducer(Cell):
self.mode = context.get_context("mode")
self.enable_tuple_broaden = True
@ms_function
@jit
def construct(self, grads):
"""
Under certain circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the

View File

@ -26,7 +26,7 @@ from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore.common.api import ms_function
from mindspore.common.api import jit
_grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@ -399,7 +399,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
compute_input = F.depend(compute_input, clear_status)
return status, compute_input
@ms_function
@jit
def get_overflow_status(self, status, compute_output):
"""
Get floating-point overflow status.

View File

@ -30,7 +30,7 @@ from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFu
ListClear_, ListReverse_, ListExtend_, ListCount_, DictClear_, DictHasKey_, DictUpdate_, \
DictFromKeys_
from mindspore.common import dtype as mstype
from mindspore.common.api import ms_function, _pynative_executor, _wrap_func
from mindspore.common.api import jit, _pynative_executor, _wrap_func
from mindspore.common.api import _add_flags, _core
from mindspore.ops.primitive import Primitive
from mindspore.ops import signature as sig
@ -340,21 +340,22 @@ class GradOperation(GradOperation_):
if self.grad_fn is not None and self.fn == fn and self.weights_id == weights_id:
return self.grad_fn
grad_ = GradOperation(self.get_all, self.get_by_list, self.sens_param)
# If calling Grad in GRAPH_MODE or calling Grad in ms_function, do grad in GRAPH_MODE
# If calling Grad in GRAPH_MODE or calling Grad in functions decorated with 'jit', do grad in GRAPH_MODE
# If calling Grad in pure PYNATIVE_MODE do grad in PYNATIVE_MODE
# In pure PYNATIVE_MODE the out layer after_grad just used to set pynative flag for inner GradOperation.
# In PYNATIVE_MODE calling Grad from ms_function, use the out layer after_grad do grad in GRAPH_MODE.
# In PYNATIVE_MODE calling Grad from functions decorated with 'jit', use the out layer after_grad do
# grad in GRAPH_MODE.
if context.get_context("mode") == context.GRAPH_MODE:
dynamic_shape_inputs = None
if isinstance(fn, ms.nn.Cell):
dynamic_shape_inputs = fn.get_inputs()
fn.grad_ops_label = True
if self.get_by_list:
@ms_function(input_signature=dynamic_shape_inputs)
@jit(input_signature=dynamic_shape_inputs)
def after_grad(*args):
return grad_(fn, weights)(*args)
else:
@ms_function(input_signature=dynamic_shape_inputs)
@jit(input_signature=dynamic_shape_inputs)
def after_grad(*args):
return grad_(fn)(*args)
elif self.pynative_:
@ -368,7 +369,7 @@ class GradOperation(GradOperation_):
return out
else:
grad_.pynative_ = True
# after_grad of this branch can't use @ms_function, just directly call grad_
# after_grad of this branch can't use @jit, just directly call grad_
if self.get_by_list:
def after_grad(*args, **kwargs):
return grad_(fn, weights)(*args, **kwargs)
@ -420,9 +421,9 @@ class _TaylorOperation(TaylorOperation_):
return self.grad_fn
taylor_grad_ = _TaylorOperation()
# If calling Grad in GRAPH_MODE or calling Grad in ms_function, do grad in GRAPH_MODE
# If calling Grad in GRAPH_MODE or calling Grad in functions decorated with 'jit', do grad in GRAPH_MODE
@ms_function
@jit
def after_taylor_grad(*args):
return taylor_grad_(fn)(*args)
@ -483,25 +484,26 @@ class _Grad(GradOperation_):
return res
grad_ = _Grad(self.get_by_list, self.sens_param, self.get_by_position, self.has_aux, self.get_value)
# If calling Grad in GRAPH_MODE or calling Grad in ms_function, do grad in GRAPH_MODE
# If calling Grad in GRAPH_MODE or calling Grad in functions decorated with 'jit', do grad in GRAPH_MODE
# If calling Grad in pure PYNATIVE_MODE do grad in PYNATIVE_MODE
# In pure PYNATIVE_MODE the out layer after_grad just used to set pynative flag for inner GradOperation.
# In PYNATIVE_MODE calling Grad from ms_function, use the out layer after_grad do grad in GRAPH_MODE.
# In PYNATIVE_MODE calling Grad from functions decorated with 'jit', use the out layer after_grad do
# grad in GRAPH_MODE.
if context.get_context("mode") == context.GRAPH_MODE:
dynamic_shape_inputs = None
if isinstance(fn, ms.nn.Cell):
dynamic_shape_inputs = fn.get_inputs()
if self.get_by_position:
@ms_function(input_signature=dynamic_shape_inputs)
@jit(input_signature=dynamic_shape_inputs)
def after_grad(*args):
return grad_(fn, weights, grad_position)(*args)
else:
if self.get_by_list:
@ms_function(input_signature=dynamic_shape_inputs)
@jit(input_signature=dynamic_shape_inputs)
def after_grad(*args):
return grad_(fn, weights)(*args)
else:
@ms_function(input_signature=dynamic_shape_inputs)
@jit(input_signature=dynamic_shape_inputs)
def after_grad(*args):
return grad_(fn)(*args)
elif self.pynative_:
@ -522,7 +524,7 @@ class _Grad(GradOperation_):
fn_ = fn
if self.has_aux:
fn_ = aux_fn
# after_grad of this branch can't use @ms_function, just directly call grad_
# after_grad of this branch can't use @jit, just directly call grad_
if self.get_by_position:
def after_grad(*args, **kwargs):
return grad_(fn_, weights, grad_position)(*args, **kwargs)
@ -589,7 +591,7 @@ class _Vmap(VmapOperation_):
vmap_ = self
@ms_function
@jit
def after_vmap(*args):
return vmap_(fn, in_axes, out_axes)(*args)

View File

@ -17,7 +17,7 @@
from __future__ import absolute_import
from functools import partial
import numpy as np
from mindspore.common import ms_function
from mindspore.common import jit
from mindspore.common import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.cell import Cell
@ -647,7 +647,7 @@ def jvp(fn, inputs, v, has_aux=False):
def grad_all(u, first_grad):
return _grad_all(fn_)(*first_grad, u)
@ms_function(hash_args=fn_)
@jit(hash_args=fn_)
def _wrap_container(*arg):
jvp_inputs = arg[1:]
vectors = arg[0]
@ -730,7 +730,7 @@ def linearize(fn, inputs):
"""
linearize_inner = _LinearizeInner()
@ms_function(hash_args=fn)
@jit(hash_args=fn)
def _wrap_container(*arg):
args = arg[1:-1]
vectors = arg[-1]
@ -1006,7 +1006,7 @@ def jacfwd(fn, grad_position=0, has_aux=False):
return _grad_all(aux_fn)(*first_grad, u)
return _grad_all(fn)(*first_grad, u)
@ms_function
@jit
def wrapped(*args):
checked_grad_position = _check_grad_position(grad_position, len(args))
primals, v, inputs_shape = _jacfwd_construct_v(args, checked_grad_position)
@ -1188,7 +1188,7 @@ def jacrev(fn, grad_position=0, has_aux=False):
res = outputs[0]
return res
@ms_function
@jit
def wrapped(*args):
checked_grad_position = _check_grad_position(grad_position, len(args))
outputs = fn(*args)

View File

@ -284,7 +284,7 @@ class InsertGradientOf(PrimitiveWithInfer):
Examples:
>>> import numpy as np
>>> from mindspore import Tensor, ops, ms_function
>>> from mindspore import Tensor, ops, jit
>>> a = Tensor(np.array([1.0]).astype(np.float32))
>>> b = Tensor(np.array([0.2]).astype(np.float32))
>>> def clip_gradient(dx):
@ -306,7 +306,7 @@ class InsertGradientOf(PrimitiveWithInfer):
... c = x * y
... return c
...
... @ms_function
... @jit
... def f(x, y):
... return clip_test(x, y)
...

View File

@ -19,7 +19,7 @@ from __future__ import division
from mindspore.nn.cell import Cell
from mindspore.ops.operations.comm_ops import AllGather
from mindspore.communication import GlobalComm
from mindspore.common import ms_function
from mindspore.common import jit
_ALLGATHER_CELL = None
@ -35,7 +35,7 @@ class AllGatherCell(Cell):
self.allgather = AllGather(group)
self.add_flags(skip_auto_parallel_compile=True)
@ms_function()
@jit()
def construct(self, x):
x = self.allgather(x)

View File

@ -86,7 +86,7 @@ class Shard(Shard_):
def shard_fn(*args):
args = (fn,) + args
@ms.common.ms_function(hash_args=fn)
@ms.common.jit(hash_args=fn)
def after_shard(*args):
return shard_(fn, in_strategy, out_strategy, parameter_plan, device, level)(*args)

View File

@ -18,7 +18,7 @@ from functools import wraps
import mindspore.ops as ops
from mindspore import context
from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.api import ms_function
from mindspore.common.api import jit
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
from mindspore.train.dataset_helper import _has_dynamic_shape, _check_inputs
import mindspore.dataset as ds
@ -104,7 +104,7 @@ def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
if jit_config is None:
dst_sink_fun = sink_fun
else:
dst_sink_fun = ms_function(sink_fun, jit_config=jit_config)
dst_sink_fun = jit(sink_fun, jit_config=jit_config)
dataset.__sink_fun__ = dst_sink_fun
return dataset.__sink_fun__
@ -116,7 +116,7 @@ def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
if jit_config is None:
dst_sink_fun = sink_fun
else:
dst_sink_fun = ms_function(sink_fun, jit_config=jit_config)
dst_sink_fun = jit(sink_fun, jit_config=jit_config)
dataset.__sink_aux__.sink_funcs[key] = dst_sink_fun
return dst_sink_fun

View File

@ -1010,11 +1010,12 @@ def export(net, *inputs, file_name, file_format, **kwargs):
Note:
1. When exporting AIR, ONNX format, the size of a single tensor can not exceed 2GB.
2. When file_name does not have a suffix, the system will automatically add one according to the file_format.
3. Mindspore functions (ms_function) export as mindir format is enabled.
4. When export ms_function, the function should not involve class properties in calculations.
3. Exporting functions decorated with 'jit' to mindir format is supported.
4. When exporting a function decorated with 'jit', the function should not involve class properties in
calculations.
Args:
net (Union[Cell, ms_function]): MindSpore network.
net (Union[Cell, function]): MindSpore network.
inputs (Union[Tensor, Dataset, List, Tuple, Number, Bool]): It represents the inputs
of the `net`, if the network has multiple inputs, set them together. While its type is Dataset,
it represents the preprocess behavior of the `net`, data preprocess operations will be serialized.

View File

@ -21,7 +21,7 @@ import numpy as np
from mindspore import ParameterTuple
from mindspore import nn, context
from mindspore.common.api import _cell_graph_executor, ms_function
from mindspore.common.api import _cell_graph_executor, jit
from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F
from mindspore.ops import operations as P
@ -53,7 +53,7 @@ def run_block(net, *inputs, rand_func=None, training=True):
set_block_param_with_rand(net, rand_func)
if context.get_context("mode") == context.PYNATIVE_MODE:
def func_pynative(*inputs):
@ms_function
@jit
def _func_pynative(*inputs):
return net(*inputs)

View File

@ -17,14 +17,14 @@
from mindspore import context
from mindspore.common import ParameterTuple
from mindspore.common.api import ms_function
from mindspore.common.api import jit
from mindspore.nn import Cell
from mindspore.ops.composite.base import GradOperation
class Bprop(Cell):
"""
The gradient wraper.
The gradient wrapper.
"""
def __init__(self, func, wrt_params, params, grad_op, sens):
@ -90,7 +90,7 @@ def bprop(func, *inputs, grads_wrt_outputs=None, wrt: list = None, params: list
if context.get_context("mode") == context.PYNATIVE_MODE:
def func_pynative(*inputs):
@ms_function
@jit
def _func_pynative(*inputs):
return grad(*inputs)

View File

@ -14,7 +14,7 @@
# ==============================================================================
import pytest
import mindspore.nn as nn
from mindspore import context, Tensor, Parameter, ms_function
from mindspore import context, Tensor, Parameter, jit
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
@ -44,7 +44,7 @@ def test_monad_vmap():
vampfunc = F.vmap(AssignNet())
@ms_function
@jit
def test_monad(a):
c = Tensor([[1, 2], [3, 4], [5, 6]], mstype.int32)
out = vampfunc(a)

View File

@ -76,7 +76,7 @@ dict_arg = {"x": tensor_x, "y": tensor_y, "u": tensor_u}
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_non_tensor_inputs(mode):
"""
Feature: Construct()/ms_function input type with back propagate.
Feature: Input type with back propagate.
Description: Normal input type without tensor.
Expectation: No exception.
"""
@ -116,7 +116,7 @@ class GradNet1(nn.Cell):
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_first_input_net(mode):
"""
Feature: Construct()/ms_function input type with back propagate.
Feature: Input type with back propagate.
Description: Normal input type.
Expectation: No exception.
"""
@ -183,7 +183,7 @@ class GradCell(nn.Cell):
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_grad_parameter_input(mode):
"""
Feature: Construct()/ms_function input type with back propagate.
Feature: Input type with back propagate.
Description: Grad with Parameter as input type.
Expectation: No exception.
"""
@ -209,7 +209,7 @@ def test_grad_parameter_input(mode):
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_parameter_as_input_and_fv(mode):
"""
Feature: Construct()/ms_function input type with back propagate.
Feature: Input type with back propagate.
Description: Grad with Parameters as input type and fv.
Expectation: No exception.
"""
@ -236,7 +236,7 @@ def test_grad_parameter_as_input_and_fv(mode):
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_same_parameter_both_input_and_fv(mode):
"""
Feature: Construct()/ms_function input type with back propagate.
Feature: Input type with back propagate.
Description: Grad with the same Parameter used as input type and fv at the same time.
Expectation: No exception.
"""
@ -304,7 +304,7 @@ class GradCellWithParameterTuple(nn.Cell):
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_grad_parameter_as_input_and_fv2(mode):
"""
Feature: Construct()/ms_function input type with back propagate.
Feature: Input type with back propagate.
Description: Grad with Parameters as input type and fv. ParameterTuple as fv.
Expectation: No exception.
"""

View File

@ -15,7 +15,7 @@
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor, ms_function
from mindspore import Tensor, jit
from mindspore.ops import composite as C
from mindspore import context
@ -83,7 +83,7 @@ def test_single_while():
Description: The else branches of while loops aren't supported.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while(x, y):
while x > y:
y += x

View File

@ -14,7 +14,7 @@
# ============================================================================
import numpy as np
import pytest
from mindspore import context, ms_function
from mindspore import context, jit
from mindspore import Tensor, nn
from mindspore.common.parameter import Parameter
from mindspore.ops import composite as C
@ -267,7 +267,7 @@ def test_single_for():
Description: The else branches of for loops aren't supported.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for(x, y):
for _ in range(3):
y += x
@ -294,7 +294,7 @@ def test_single_for_with_not_iterable_object():
Description: The else branches of for loops aren't supported.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for_with_not_iterable_object():
ret = 0
a = 1

View File

@ -15,7 +15,7 @@
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import ms_function
from mindspore import jit
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
@ -95,7 +95,7 @@ def test_if_in_for_dict():
Description: Execute 'for x in xs' when xs is dictionary.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for(xs):
result = 0
ys = {'b': 0, 'g': 0}

View File

@ -15,7 +15,7 @@
import mindspore as ms
from mindspore import context
from mindspore.ops import operations as P
from mindspore.common.api import ms_function
from mindspore.common.api import jit
from mindspore.common.tensor import Tensor
import mindspore.nn as nn
@ -67,7 +67,7 @@ class TestClass(nn.Cell):
out = self.test(state, init_global_obs)
return out
@ms_function
@jit
def test(self, state, init_global_obs):
num_agent = self.zero
while self.less(num_agent, 3):

View File

@ -15,7 +15,7 @@
import numpy as np
import pytest
import mindspore
from mindspore import context, nn, ops, Tensor, CSRTensor, Parameter, ms_function, mutable
from mindspore import context, nn, ops, Tensor, CSRTensor, Parameter, jit, mutable
from mindspore.ops import functional as F
@ -64,7 +64,7 @@ def test_repeat_control_arrow_for_stack_actor():
assert out == result
@ms_function
@jit
def switch_op(x, y):
z1 = y + 1
z2 = Tensor(5, mindspore.int32)
@ -87,7 +87,7 @@ def test_switch_op():
assert out == 5
@ms_function
@jit
def switch_single_op(x, y, z):
return F.switch(x, y, z)

View File

@ -19,7 +19,7 @@ import numpy as np
from mindspore.nn import Cell
from mindspore.common import Tensor, dtype, Parameter
from mindspore.ops import operations as P
from mindspore import ms_function
from mindspore import jit
import mindspore.ops.functional as F
@ -78,7 +78,7 @@ def test_poly_delay_specialize():
"""
pow_ops = P.Pow()
@ms_function
@jit
def poly_node_network(x, y):
def function_h():
pow_res = pow_ops(x, x)

View File

@ -14,11 +14,11 @@
# ============================================================================
import pytest
import mindspore.context as context
from mindspore import Tensor, ms_function
from mindspore import Tensor, jit
from mindspore.common import dtype as mstype
@ms_function
@jit
def hof(x):
def f(x):
return x + 3

View File

@ -17,7 +17,7 @@ import pytest
import mindspore.context as context
from mindspore.common import dtype as mstype
from mindspore.common import ms_function
from mindspore.common import jit
from mindspore.common.tensor import Tensor
@ -32,7 +32,7 @@ c4 = Tensor([0], mstype.int32)
c5 = Tensor([14], mstype.int32)
@ms_function
@jit
def simple_if(x, y):
if x < y:
x = x + 1
@ -42,7 +42,7 @@ def simple_if(x, y):
return x
@ms_function
@jit
def if_by_if(x, y):
if x < y:
x = x + 1
@ -52,7 +52,7 @@ def if_by_if(x, y):
return x
@ms_function
@jit
def if_in_if(x, y, z):
out = c4
if x < y:
@ -65,7 +65,7 @@ def if_in_if(x, y, z):
return out
@ms_function
@jit
def simple_while(x, y):
y = y + 4
while x < y:
@ -74,7 +74,7 @@ def simple_while(x, y):
return x
@ms_function
@jit
def while_by_while(x, y, z):
while x < y:
x = x + 1
@ -85,7 +85,7 @@ def while_by_while(x, y, z):
return x
@ms_function
@jit
def while_in_while(x, y, z):
out = c4
while x < y:
@ -98,7 +98,7 @@ def while_in_while(x, y, z):
return out
@ms_function
@jit
def while_by_while_in_while(x, y, z):
out = c4
while x < c2:
@ -115,7 +115,7 @@ def while_by_while_in_while(x, y, z):
return out
@ms_function
@jit
def while_in_while_in_while(x, y, z):
out = c4
while x < c2:

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
import mindspore.context as context
from mindspore import Tensor, ms_function
from mindspore import Tensor, jit
from mindspore.common import dtype as mstype, Parameter
from mindspore.nn import Cell
import pytest
@ -31,7 +31,7 @@ def test_while_return_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -57,7 +57,7 @@ def test_if_return_in_while_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -86,7 +86,7 @@ def test_if_return_else_break_in_while_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -114,7 +114,7 @@ def test_if_return_else_return_in_while_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -142,7 +142,7 @@ def test_if_break_else_return_in_while_in_else_take_break():
Expectation: take the break branch, success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -170,7 +170,7 @@ def test_if_break_else_return_in_while_in_else_take_return():
Expectation: take the return branch, success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -198,7 +198,7 @@ def test_while_return_in_while_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -225,7 +225,7 @@ def test_if_return_in_while_in_while_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -255,7 +255,7 @@ def test_if_return_else_return_in_while_in_while_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -284,7 +284,7 @@ def test_while_return_after_if_else_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -315,7 +315,7 @@ def test_if_else_after_while_return_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -346,7 +346,7 @@ def test_if_return_after_if_else_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -377,7 +377,7 @@ def test_if_else_after_if_return_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -407,7 +407,7 @@ def test_while_return_in_else_after_if_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if x > y:
x = x + y
@ -438,7 +438,7 @@ def test_if_else_after_by_while_return_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -468,7 +468,7 @@ def test_if_return_in_else_after_if_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if x > y:
x = x + y
@ -499,7 +499,7 @@ def test_if_else_after_by_if_return_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -530,7 +530,7 @@ def test_if_else_in_if_while_return_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -560,7 +560,7 @@ def test_if_else_in_if_if_return_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -590,7 +590,7 @@ def test_for_return_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -616,7 +616,7 @@ def test_if_return_in_for_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -645,7 +645,7 @@ def test_if_return_else_break_in_for_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -673,7 +673,7 @@ def test_if_return_else_return_in_for_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -701,7 +701,7 @@ def test_for_return_in_for_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -728,7 +728,7 @@ def test_if_return_in_for_in_for_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -758,7 +758,7 @@ def test_if_return_else_return_in_for_in_for_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -787,7 +787,7 @@ def test_for_return_after_if_else_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -818,7 +818,7 @@ def test_if_else_after_for_return_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -849,7 +849,7 @@ def test_for_return_in_else_after_if_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if x > y:
x = x + y
@ -880,7 +880,7 @@ def test_if_else_after_by_for_return_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -910,7 +910,7 @@ def test_if_else_in_if_for_return_in_else():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, bias):
if bias > y:
y = x + y
@ -940,7 +940,7 @@ def test_if_by_if_break_in_if_in_while():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, z):
out = z
while x < y:
@ -970,7 +970,7 @@ def test_if_raise_raise():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, z):
out = z
if x >= y:
@ -997,7 +997,7 @@ def test_if_raise_not_raise():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, z):
out = z
if x >= y:
@ -1023,7 +1023,7 @@ def test_if_assert_success():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, z):
out = z
out = z
@ -1051,7 +1051,7 @@ def test_if_assert_failure():
Expectation: success
"""
@ms_function
@jit
def foo(x, y, z):
out = z
if x >= y:

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
import mindspore.context as context
from mindspore import Tensor, ms_function
from mindspore import Tensor, jit
from mindspore.common import dtype as mstype
from mindspore import ops
import mindspore.nn as nn
@ -23,7 +23,7 @@ ZERO = Tensor([0], mstype.int32)
ONE = Tensor([1], mstype.int32)
@ms_function
@jit
def f(x):
y = ZERO
if x < 0:
@ -38,7 +38,7 @@ def f(x):
return z
@ms_function
@jit
def fr(x):
y = ZERO
if x < 0:
@ -53,7 +53,7 @@ def fr(x):
return z
@ms_function
@jit
def f_pythonerr(x):
if x > 0:
return f_pythonerr(x - 1)
@ -69,7 +69,7 @@ def test_python_error():
assert 'not defined' in str(e)
@ms_function
@jit
def f_recrusive_endless(x):
if x > 0:
return f_recrusive_endless(x - 1)
@ -94,7 +94,7 @@ def test_endless():
assert 'loop' in str(e)
@ms_function
@jit
def f_ok(x):
if x > 0:
return f_ok(x - 1) + 1

View File

@ -13,11 +13,11 @@
# limitations under the License.
# ============================================================================
import mindspore.context as context
from mindspore import Tensor, ms_function
from mindspore import Tensor, jit
from mindspore.common import dtype as mstype
@ms_function
@jit
def t1_while(x, y):
y = y + 4
while x < y:

View File

@ -16,7 +16,7 @@ import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import context, ms_function
from mindspore import context, jit
from mindspore.common.tensor import Tensor
from mindspore.train.serialization import export, load
@ -59,7 +59,12 @@ def test_single_while():
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_ms_function_while():
def test_jit_function_while():
"""
Features: Control flow.
Description: Test while in @jit decorated function.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE)
network = SingleWhileNet()
@ -76,7 +81,7 @@ def test_ms_function_while():
loaded_net = nn.GraphCell(graph)
context.set_context(mode=context.PYNATIVE_MODE)
@ms_function
@jit
def run_graph(x, y):
outputs = loaded_net(x, y)
return outputs

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
from mindspore import dtype as mstype
context.set_context(mode=context.GRAPH_MODE)
@ -32,7 +32,7 @@ def test_single_if_4():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if():
x = Tensor(7).astype("int32")
y = Tensor(0).astype("int32")
@ -55,7 +55,7 @@ def test_single_if_two_cond():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if():
x = Tensor(1)
y = np.array(2)
@ -76,7 +76,7 @@ def test_single_if_builtin_function_abs():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if():
x = Tensor(-11, mstype.float32)
if abs(x) > Tensor(np.array(2)):
@ -96,7 +96,7 @@ def test_single_if_builtin_function_abs_min():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if():
x = Tensor(-11, mstype.float32)
y = Tensor(12, mstype.float32)

View File

@ -16,7 +16,7 @@
import pytest
import numpy as np
import mindspore as ms
from mindspore import Tensor, ms_function, context, Parameter
from mindspore import Tensor, jit, context, Parameter
from mindspore.nn import Cell
context.set_context(mode=context.GRAPH_MODE)
@ -33,7 +33,7 @@ def test_single_while_1():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while():
x = Tensor(1)
while x < Tensor(7):
@ -54,7 +54,7 @@ def test_single_while_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while():
x = Tensor(7).astype("int32")
y = Tensor(0).astype("int32")
@ -76,7 +76,7 @@ def test_single_while_3():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while():
x = Tensor(7).astype("int32")
y = Tensor(0).astype("int32")
@ -99,7 +99,7 @@ def test_single_while_two_cond_1():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while():
x = Tensor(1)
y = Tensor(8)
@ -123,7 +123,7 @@ def test_single_while_two_cond_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while():
x = Tensor(7).astype("int32")
y = Tensor(0).astype("int32")
@ -172,7 +172,7 @@ def test_single_while_numpy():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while():
x = np.array([1, 2, 3, 4, 5])
y = np.array([0, 2, 4, 6, 8])
@ -193,7 +193,7 @@ def test_single_while_two_cond_3():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while():
x = np.array([1, 2, 3, 4, 5])
y = Tensor(1)

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
from mindspore import dtype as mstype
context.set_context(mode=context.GRAPH_MODE)
@ -32,7 +32,7 @@ def test_single_for_1():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for():
x = Tensor(7).astype("int32")
y = Tensor(0).astype("int32")
@ -54,7 +54,7 @@ def test_single_for_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for():
x = Tensor(7).astype("int32")
y = Tensor(0).astype("int32")
@ -78,7 +78,7 @@ def test_single_for_zip():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for():
tuple_x = (Tensor(1).astype("int32"), Tensor(3).astype("int32"), Tensor(5).astype("int32"))
sum_x = Tensor(0).astype("int32")
@ -101,7 +101,7 @@ def test_single_for_builtin_function_int():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for():
x = np.array(1.1)
for _ in range(3):

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow if in if scenario"""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -31,7 +31,7 @@ def test_if_in_if_5():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_in_if():
x = list([1, 2, 3, 4])
if max(x) >= 4:
@ -55,7 +55,7 @@ def test_if_else_in_if_else_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_in_if():
x = Tensor(10)
y = Tensor(7)
@ -90,7 +90,7 @@ def test_if_in_if_multi_conds_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_in_if():
x = Tensor(10)
y = Tensor(2)
@ -119,7 +119,7 @@ def test_if_in_if_4():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_in_if():
x = np.array([1, 2, 3, 4, 5])
y = x % 2

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow if in while scenario"""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -31,7 +31,7 @@ def test_if_in_while_1():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_in_while():
x = Tensor(1)
y = Tensor(0)
@ -55,7 +55,7 @@ def test_if_in_while_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_in_while():
x = Tensor(1)
while x < Tensor(5):
@ -78,7 +78,7 @@ def test_if_in_while_3():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_in_while():
x = Tensor(1)
y = Tensor(0)
@ -104,7 +104,7 @@ def test_if_in_while_4():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_in_while():
x = Tensor(1)
y = Tensor(0)
@ -135,7 +135,7 @@ def test_if_in_while_numpy():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_in_while():
x = np.array([1, 2])
y = np.array([3, 2])

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -31,7 +31,7 @@ def test_if_in_for_tensor():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for():
x = Tensor(7)
y = Tensor(0)
@ -54,7 +54,7 @@ def test_if_in_for_tensor_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for():
x = Tensor(7)
y = Tensor(0)
@ -79,7 +79,7 @@ def test_if_in_for_tensor_3():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for():
x = Tensor(7)
y = Tensor(0)
@ -103,7 +103,7 @@ def test_if_in_for_numpy_5():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for():
x = np.array([1, 2, 3, 4])
y = (Tensor(1), Tensor(3), Tensor(5))

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import mindspore as ms
from mindspore import Tensor, ms_function, context, nn, Parameter
from mindspore import Tensor, jit, context, nn, Parameter
import numpy as np
context.set_context(mode=context.GRAPH_MODE)
@ -33,7 +33,7 @@ def test_while_in_if_1():
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if():
x = Tensor([1])
if x > Tensor([0]):
@ -58,7 +58,7 @@ def test_while_in_if_2():
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while():
x = Tensor([6]).astype("int32")
y = Tensor([0]).astype("int32")
@ -84,7 +84,7 @@ def test_while_in_if_3():
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while():
x = Tensor([7]).astype("int32")
y = Tensor([0]).astype("int32")
@ -111,7 +111,7 @@ def test_while_two_cond_in_if_1():
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while():
x = Tensor([1])
y = Tensor([8])
@ -138,7 +138,7 @@ def test_while_two_cond_in_if_2():
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while():
x = Tensor([7]).astype("int32")
y = Tensor([0]).astype("int32")

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import mindspore as ms
from mindspore import Tensor, ms_function, context, nn, Parameter
from mindspore import Tensor, jit, context, nn, Parameter
import numpy as np
context.set_context(mode=context.GRAPH_MODE)
@ -33,7 +33,7 @@ def test_while_in_while_1():
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if():
x = Tensor([1])
y = Tensor([3])
@ -60,7 +60,7 @@ def test_while_in_while_2():
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while():
x = Tensor([3]).astype("int32")
@ -86,7 +86,7 @@ def test_while_in_while_3():
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while():
x = Tensor([7]).astype("int32")
y = Tensor([0]).astype("int32")
@ -114,7 +114,7 @@ def test_while_in_while_with_two_cond_1():
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while():
x = Tensor([1])
y = Tensor([8])
@ -141,7 +141,7 @@ def test_while_in_while_with_two_cond_2():
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while():
x = Tensor([7]).astype("int32")
y = Tensor([0]).astype("int32")
@ -167,7 +167,7 @@ def test_while_in_while_with_two_cond_3():
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while():
x = Tensor([7]).astype("int32")
y = Tensor([0]).astype("int32")

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import mindspore as ms
from mindspore import Tensor, ms_function, context, nn, Parameter
from mindspore import Tensor, jit, context, nn, Parameter
import numpy as np
context.set_context(mode=context.GRAPH_MODE)
@ -33,7 +33,7 @@ def test_while_in_for_1():
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for():
x = Tensor([7]).astype("int32")
y = Tensor([0]).astype("int32")
@ -61,7 +61,7 @@ def test_while_in_for_zip():
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for():
tuple_x = (Tensor(1).astype("int32"), Tensor(3).astype("int32"), Tensor(5).astype("int32"))
sum_x = Tensor([0]).astype("int32")
@ -90,7 +90,7 @@ def test_while_in_for_numpy():
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for():
x = np.array([1, 3, 5])
y = np.array([0, 2, 4])

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context, nn
from mindspore import Tensor, jit, context, nn
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
@ -33,7 +33,7 @@ def test_for_in_if_tensor():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for_in_if():
x = Tensor(1)
y = Tensor(0)
@ -57,7 +57,7 @@ def test_for_in_if_tensor_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for_in_if():
x = Tensor(1)
y = Tensor(0)
@ -115,7 +115,7 @@ def test_for_in_if_numpy():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for_in_if():
x = np.array([1, 1, 1])
y = list((4, 6, -2))
@ -138,7 +138,7 @@ def test_for_in_if_isinstance_raise():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for_in_if(x):
if isinstance(x, Tensor):
print("before add:", x)
@ -160,7 +160,7 @@ def test_for_in_if_dict_isinstance():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for_in_if():
dict_x = {'a': 1, 'b': 2}
res = 0

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context, nn
from mindspore import Tensor, jit, context, nn
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
@ -34,7 +34,7 @@ def test_for_in_while_tensor():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for_in_while():
x = Tensor(1)
y = Tensor(0)
@ -59,7 +59,7 @@ def test_for_in_while_numpy_append():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for_in_while():
x = np.array([[1, 2, 3], [3, 4, 5], [4, 5, 6]])
y = Tensor(0)
@ -119,7 +119,7 @@ def test_for_in_while_print():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for_in_while():
x = Tensor(1)
y = Tensor(0)
@ -145,7 +145,7 @@ def test_for_in_while_round():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_for_in_while():
x = 3.14159
y = 3

View File

@ -14,7 +14,7 @@
# ============================================================================
""" test graph fallback control flow if after if scenario"""
import pytest
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -30,7 +30,7 @@ def test_if_after_if_tensor():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_if():
x = Tensor(1)
y = Tensor(0)
@ -54,7 +54,7 @@ def test_if_after_if_tensor_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_if():
x = Tensor(1)
y = Tensor(0)
@ -84,7 +84,7 @@ def test_if_after_if_tensor_3():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_if(a):
if a > 15:
y = Tensor(1)

View File

@ -14,7 +14,7 @@
# ============================================================================
""" test graph fallback control flow if after while scenario"""
import pytest
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -30,7 +30,7 @@ def test_if_after_while_tensor():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_while():
x = Tensor(1)
y = Tensor(0)
@ -56,7 +56,7 @@ def test_if_after_while_tensor_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_while():
x = Tensor(1)
y = Tensor(0)

View File

@ -14,7 +14,7 @@
# ============================================================================
""" test graph fallback control flow."""
import pytest
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -30,7 +30,7 @@ def test_if_after_if_in_if_tensor():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_if_in_if():
x = Tensor(1)
y = Tensor(10)

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -31,7 +31,7 @@ def test_if_after_if_in_while_tensor():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_if_in_while():
x = Tensor(1)
y = Tensor(10)
@ -60,7 +60,7 @@ def test_if_after_if_in_while_numpy():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_if_in_while():
x = np.array([1, 2])
y = np.array([3, 2])
@ -88,7 +88,7 @@ def test_if_after_if_in_while_tensor_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_if_in_while():
x = Tensor(5)
y = Tensor(2)

View File

@ -14,7 +14,7 @@
# ============================================================================
""" test graph fallback control flow."""
import pytest
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -30,7 +30,7 @@ def test_if_after_if_in_for_tensor_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_if_in_for():
x = Tensor(5)
y = Tensor(2)

View File

@ -14,7 +14,7 @@
# ============================================================================
""" test graph fallback control flow."""
import pytest
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -30,7 +30,7 @@ def test_if_after_while_in_if_tensor():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_while_in_if():
x = Tensor(1)
y = Tensor(2)
@ -57,7 +57,7 @@ def test_if_after_while_in_if_tensor_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_while_in_if():
x = Tensor(1)
y = Tensor(2)

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -31,7 +31,7 @@ def test_if_after_while_in_while_tensor():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_while_in_while():
x = Tensor(1)
y = Tensor(2)
@ -58,7 +58,7 @@ def test_if_after_while_in_while_tensor_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_while_in_while():
x = Tensor(1)
y = Tensor(2)
@ -86,7 +86,7 @@ def test_if_after_while_in_while_numpy():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_while_in_while():
x = np.array([1, 2, 3, 4])
y = Tensor(5)

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -31,7 +31,7 @@ def test_if_after_while_in_for_tensor():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_while_in_for():
x = Tensor(1)
y = Tensor(2)
@ -58,7 +58,7 @@ def test_if_after_while_in_for_tensor_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_while_in_for():
x = Tensor(1)
y = Tensor(2)
@ -86,7 +86,7 @@ def test_if_after_while_in_for_numpy():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_while_in_for():
x = np.array([1, 2, 3, 4])
y = Tensor(5)

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -31,7 +31,7 @@ def test_if_after_for_in_if_tensor():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_for_in_if():
x = Tensor([1])
y = Tensor([2])
@ -53,7 +53,7 @@ def test_if_after_for_in_if_numpy():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_for_in_if():
x = np.array([1, 2])
y = np.array([3, 4])
@ -80,7 +80,7 @@ def test_if_after_for_in_if_tensor_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_for_in_if():
x = Tensor([1])
y = Tensor([2])
@ -109,7 +109,7 @@ def test_if_after_for_in_if_numpy_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_for_in_for():
x = np.array([3, 2])
y = Tensor(np.array([0, 2, 4, 6, 8]))

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -31,7 +31,7 @@ def test_if_after_for_in_while_tensor():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_for_in_while():
x = Tensor([1])
y = Tensor([2])
@ -58,7 +58,7 @@ def test_if_after_for_in_while_tensor_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_for_in_while():
x = Tensor([1])
y = Tensor([2])
@ -85,7 +85,7 @@ def test_if_after_for_in_while_numpy_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_for_in_while():
x = np.array([5, 4, 3, 2, 1])
y = (Tensor(1), Tensor(3), Tensor(5))

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -30,7 +30,7 @@ def test_if_after_for_in_for_numpy():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_if_after_for_in_for():
x = np.array([3, 2])
y = Tensor(np.array([0, 2, 4, 6, 8]))

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -31,7 +31,7 @@ def test_while_after_if_tensor():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_if():
x = Tensor([1])
y = Tensor([2])
@ -57,7 +57,7 @@ def test_while_after_if_tensor_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_if():
x = Tensor([1])
y = Tensor([2])
@ -87,7 +87,7 @@ def test_while_after_if_numpy():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_if():
x = np.array([3, 2])
y = Tensor(np.array([3, 2]))
@ -113,7 +113,7 @@ def test_while_after_if_numpy_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_if():
x = np.array([3, 2])
y = [1, 2, 3, 4]

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -31,7 +31,7 @@ def test_while_after_while_tensor():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_while():
x = Tensor([1])
y = Tensor([2])
@ -57,7 +57,7 @@ def test_while_after_while_tensor_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_while():
x = Tensor([1])
y = Tensor([2])
@ -78,7 +78,7 @@ def test_while_after_while_numpy_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_while():
x = np.array([3, 2])
y = [1, 2, 3, 4]
@ -103,7 +103,7 @@ def test_while_after_while_numpy():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_while():
x = [1, 2, 3, 4]
y = Tensor([8])

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -31,7 +31,7 @@ def test_while_after_for_tensor_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_for():
x = Tensor([1])
y = Tensor([2])
@ -53,7 +53,7 @@ def test_while_after_for_numpy_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_for():
x = np.array([3, 2])
y = [1, 2, 3, 4]
@ -81,7 +81,7 @@ def test_while_after_for_tensor():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_for():
x = Tensor([1])
y = Tensor([2])
@ -107,7 +107,7 @@ def test_while_after_for_numpy():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_for():
x = [1, 2, 3, 4, 5]
y = Tensor([3])

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -31,7 +31,7 @@ def test_while_after_if_in_if_tensor():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_if_in_if():
x = Tensor([1])
y = Tensor([2])
@ -63,7 +63,7 @@ def test_while_after_if_in_if_tensor_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_if_in_if():
x = Tensor([1])
y = Tensor([2])
@ -89,7 +89,7 @@ def test_while_after_if_in_if_numpy():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_if_in_if():
x = np.array([1])
y = np.array([10])
@ -113,7 +113,7 @@ def test_while_after_if_in_if_numpy_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_if_in_if():
x = np.array([1])
y = np.array([10])

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -31,7 +31,7 @@ def test_while_after_if_in_while_tensor():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_if_in_while():
x = Tensor([1])
y = Tensor([2])
@ -64,7 +64,7 @@ def test_while_after_if_in_while_numpy_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_if_in_while():
x = np.array([1])
y = np.array([10])
@ -95,7 +95,7 @@ def test_while_after_if_in_while_tensor_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_if_in_while():
x = Tensor([1])
y = Tensor([2])

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -31,7 +31,7 @@ def test_while_after_if_in_for_tensor():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_if_in_for():
x = Tensor([1])
y = Tensor([2])
@ -63,7 +63,7 @@ def test_while_after_if_in_for_tensor_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_if_in_for():
x = Tensor([1])
y = Tensor([2])
@ -94,7 +94,7 @@ def test_while_after_if_in_for_numpy_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_if_in_for():
x = np.array([1])
y = np.array([10])

View File

@ -15,7 +15,7 @@
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@ -31,7 +31,7 @@ def test_while_after_while_in_if_tensor():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_while_in_if():
x = Tensor([1])
y = Tensor([2])
@ -60,7 +60,7 @@ def test_while_after_while_in_if_tensor_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_while_in_if():
x = Tensor([3])
y = Tensor([5])
@ -95,7 +95,7 @@ def test_while_after_while_in_if_numpy_2():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_while_in_if():
x = Tensor([1])
y = Tensor([2])
@ -124,7 +124,7 @@ def test_while_after_while_in_if_numpy():
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
@jit
def control_flow_while_after_while_in_if():
x = Tensor([1])
y = Tensor([2])

Some files were not shown because too many files have changed in this diff Show More