!30496 [Fallback] Suppport scipy module.

Merge pull request !30496 from huangbingjian/support_scipy
This commit is contained in:
i-robot 2022-02-28 08:33:34 +00:00 committed by Gitee
commit a92c54b206
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 34 additions and 8 deletions

View File

@ -591,6 +591,7 @@ class Parser:
# Used to resolve mindspore builtin ops namespace.
self.ms_common_ns = CellNamespace('mindspore.common')
self.ms_nn_ns = CellNamespace('mindspore.nn')
self.ms_scipy_ns = CellNamespace('mindspore.scipy')
self.ms_ops_ns = CellNamespace('mindspore.ops')
self.ms_ops_c_ns = CellNamespace('mindspore.ops.composite')
self.ms_ops_c_multitype_ns = CellNamespace('mindspore.ops.composite.multitype_ops')
@ -709,6 +710,9 @@ class Parser:
if name == 'mindspore.numpy':
logger.debug(f"Found 'mindspore.numpy' namespace.")
return True
if name == 'mindspore.scipy':
logger.debug(f"Found 'mindspore.scipy' namespace.")
return True
if name == 'mindspore.context':
logger.debug(f"Found 'mindspore.context' namespace.")
return True
@ -747,6 +751,9 @@ class Parser:
if rightmost_name in self.ms_nn_ns:
logger.debug(f"Found '{name}'({rightmost_name}) in nn namespace: {str(self.ms_nn_ns)}.")
return True
if rightmost_name in self.ms_scipy_ns:
logger.debug(f"Found '{name}'({rightmost_name}) in scipy namespace: {str(self.ms_scipy_ns)}.")
return True
if rightmost_name in trope_ns:
logger.debug(f"Found '{name}'({rightmost_name}) in trope namespace: {str(trope_ns)}.")
return True
@ -766,14 +773,16 @@ class Parser:
logger.debug(f"value: {type(value)}, '{value_str}', hasattr(__name__): {hasattr(value, '__name__')}.")
# To check if allowed to support.
if self.is_unsupported_internal_type(value):
return self.global_namespace, var, value, SYNTAX_UNSUPPORTED_INTERNAL_TYPE
if self.is_unsupported_python_builtin_type(value):
return self.global_namespace, var, value, SYNTAX_UNSUPPORTED_EXTERNAL_TYPE
if self.is_unsupported_special_type(value):
return self.global_namespace, var, value, SYNTAX_UNSUPPORTED_SPECIAL_TYPE
if self.is_unsupported_namespace(value) or not self.is_supported_namespace_module(value):
return self.global_namespace, var, value, SYNTAX_UNSUPPORTED_NAMESPACE
return self.global_namespace, var, value, SYNTAX_SUPPORTED
support_info = self.global_namespace, var, value, SYNTAX_UNSUPPORTED_INTERNAL_TYPE
elif self.is_unsupported_python_builtin_type(value):
support_info = self.global_namespace, var, value, SYNTAX_UNSUPPORTED_EXTERNAL_TYPE
elif self.is_unsupported_special_type(value):
support_info = self.global_namespace, var, value, SYNTAX_UNSUPPORTED_SPECIAL_TYPE
elif self.is_unsupported_namespace(value) or not self.is_supported_namespace_module(value):
support_info = self.global_namespace, var, value, SYNTAX_UNSUPPORTED_NAMESPACE
else:
support_info = self.global_namespace, var, value, SYNTAX_SUPPORTED
return support_info
error_info = f"The name '{var}' is not defined, or not supported in graph mode."
logger.debug(f"error_info: {error_info}")

View File

@ -23,6 +23,7 @@ from mindspore.ops import functional as F
from mindspore.nn.probability import distribution
import mindspore.common.dtype as mstype
import mindspore.common._monad as monad
import mindspore.scipy.linalg as alg
context.set_context(mode=context.GRAPH_MODE)
@ -226,6 +227,22 @@ def test_context():
print(out)
def test_scipy_module():
"""
Feature: JIT Fallback
Description: Test scipy module in graph.
Expectation: No exception.
"""
class Network(nn.Cell):
def construct(self, x):
return alg.eigh(x)
net = Network()
x = Tensor([[2, 0, 0, 0], [0, 1, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]])
out = net(x)
print(out)
def test_self_attr():
"""
Feature: JIT Fallback