!35938 [Fallback] Support ms.Tensor() in construct or ms_function.

Merge pull request !35938 from Margaret_wangrui/ms_tensor
This commit is contained in:
i-robot 2022-06-15 02:25:28 +00:00 committed by Gitee
commit d79cfe29c5
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 57 additions and 2 deletions

View File

@ -1329,6 +1329,16 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
// Process the node attr
auto attr_str = python_adapter::GetPyObjAttr(node, "attr").cast<std::string>();
MS_LOG(DEBUG) << "Attr = " << attr_str;
// The fallback feature is enabled in default.
static const auto use_fallback = (support_fallback() != "0");
// Process xxx.Tensor(), eg: ms.Tensor()
if (use_fallback && attr_str == "Tensor") {
std::string script_text = py::cast<std::string>(ast()->GetAstNodeText(node));
AnfNodePtr interpret_node = MakeInterpretNode(block, value_node, script_text);
interpret_node->set_interpret(true);
interpret_node->set_interpret_internal_type(true);
return interpret_node;
}
AnfNodePtr attr_node = nullptr;
{
TraceGuard guard(GetLocation(python_adapter::GetPyObjAttr(node, "attr")));
@ -1337,8 +1347,6 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
MS_EXCEPTION_IF_NULL(block->func_graph());
// Create the apply node
auto attr_cnode = block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
// The fallback feature is enabled in default.
static const auto use_fallback = (support_fallback() != "0");
if (use_fallback) {
// Check whether it is constant, constant does not need interpret.
auto value_str = py::cast<std::string>(ast()->GetAstNodeText(value_body));

View File

@ -16,6 +16,7 @@
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
import mindspore as ms
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.common.initializer import One
@ -421,3 +422,49 @@ def test_fallback_tensor_slice():
out = Tensor(array)[1:5]
return out
print(foo())
def test_fallback_ms_tensor():
"""
Feature: JIT Fallback
Description: Test ms.Tensor() in graph mode.
Expectation: No exception.
"""
@ms_function
def foo():
me_x = ms.Tensor([1])
return me_x
res = foo()
assert (res.asnumpy() == [1]).all()
def test_fallback_ms_tensor_numpy():
"""
Feature: JIT Fallback
Description: Test ms.Tensor() in graph mode.
Expectation: No exception.
"""
@ms_function
def foo():
me_x = ms.Tensor(np.array([1, 2], dtype=np.float32))
return me_x
res = foo()
assert (res.asnumpy() == [1, 2]).all()
def test_fallback_ms_tensor_class():
"""
Feature: Fallback feature
Description: Test ms.Tensor() in graph mode.
Expectation: No exception.
"""
class Net(nn.Cell):
def construct(self):
np_array = np.array(9)
x = ms.Tensor(np_array)
res = x + ms.Tensor(np_array)
return res
net = Net()
res = net()
assert res == 18