!35938 [Fallback] Support ms.Tensor() in construct or ms_function.
Merge pull request !35938 from Margaret_wangrui/ms_tensor
This commit is contained in:
commit
d79cfe29c5
|
@ -1329,6 +1329,16 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
|
|||
// Process the node attr
|
||||
auto attr_str = python_adapter::GetPyObjAttr(node, "attr").cast<std::string>();
|
||||
MS_LOG(DEBUG) << "Attr = " << attr_str;
|
||||
// The fallback feature is enabled in default.
|
||||
static const auto use_fallback = (support_fallback() != "0");
|
||||
// Process xxx.Tensor(), eg: ms.Tensor()
|
||||
if (use_fallback && attr_str == "Tensor") {
|
||||
std::string script_text = py::cast<std::string>(ast()->GetAstNodeText(node));
|
||||
AnfNodePtr interpret_node = MakeInterpretNode(block, value_node, script_text);
|
||||
interpret_node->set_interpret(true);
|
||||
interpret_node->set_interpret_internal_type(true);
|
||||
return interpret_node;
|
||||
}
|
||||
AnfNodePtr attr_node = nullptr;
|
||||
{
|
||||
TraceGuard guard(GetLocation(python_adapter::GetPyObjAttr(node, "attr")));
|
||||
|
@ -1337,8 +1347,6 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
|
|||
MS_EXCEPTION_IF_NULL(block->func_graph());
|
||||
// Create the apply node
|
||||
auto attr_cnode = block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
|
||||
// The fallback feature is enabled in default.
|
||||
static const auto use_fallback = (support_fallback() != "0");
|
||||
if (use_fallback) {
|
||||
// Check whether it is constant, constant does not need interpret.
|
||||
auto value_str = py::cast<std::string>(ast()->GetAstNodeText(value_body));
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
import mindspore as ms
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.initializer import One
|
||||
|
@ -421,3 +422,49 @@ def test_fallback_tensor_slice():
|
|||
out = Tensor(array)[1:5]
|
||||
return out
|
||||
print(foo())
|
||||
|
||||
|
||||
def test_fallback_ms_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test ms.Tensor() in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
me_x = ms.Tensor([1])
|
||||
return me_x
|
||||
res = foo()
|
||||
assert (res.asnumpy() == [1]).all()
|
||||
|
||||
|
||||
def test_fallback_ms_tensor_numpy():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test ms.Tensor() in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
me_x = ms.Tensor(np.array([1, 2], dtype=np.float32))
|
||||
return me_x
|
||||
res = foo()
|
||||
assert (res.asnumpy() == [1, 2]).all()
|
||||
|
||||
|
||||
def test_fallback_ms_tensor_class():
|
||||
"""
|
||||
Feature: Fallback feature
|
||||
Description: Test ms.Tensor() in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def construct(self):
|
||||
np_array = np.array(9)
|
||||
x = ms.Tensor(np_array)
|
||||
res = x + ms.Tensor(np_array)
|
||||
return res
|
||||
|
||||
net = Net()
|
||||
res = net()
|
||||
assert res == 18
|
||||
|
|
Loading…
Reference in New Issue