forked from mindspore-Ecosystem/mindspore
blacklist
This commit is contained in:
parent
aa8858ab05
commit
e485af77f6
|
@ -427,10 +427,38 @@ ValuePtr ConvertConstantNumpyNumber(const py::object &obj, ResolveTypeDef obj_ty
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
void CheckJITForbiddenAPI(const py::object &obj) {
|
||||
auto module = python_adapter::GetPyModule(PYTHON_MOD_MODULE);
|
||||
py::list obj_info = python_adapter::CallPyModFn(module, PYTHON_MOD_GET_MODULE_AND_NAME_INFO, obj);
|
||||
std::ostringstream oss;
|
||||
auto obj_module = py::cast<std::string>(obj_info[0]);
|
||||
auto obj_name = py::cast<std::string>(obj_info[1]);
|
||||
auto obj_type = py::cast<std::string>(obj_info[2]);
|
||||
oss << "Failed to compile in GRAPH_MODE because the " << obj_type << " '" << obj_module << "." << obj_name
|
||||
<< "' is not supported in 'construct' or function with @jit decorator. "
|
||||
<< "Try to use the " << obj_type << " '" << obj_module << "." << obj_name << "' externally "
|
||||
<< "such as initialized in the method '__init__' before assigning"
|
||||
<< ".\nFor more details, please refer to "
|
||||
<< "https://www.mindspore.cn/docs/zh-CN/master/design/dynamic_graph_and_static_graph.html \n";
|
||||
// Check if the API is decoratored by @jit_forbidden_register.
|
||||
bool is_jit_forbidden_register = data_converter::IsJITForbiddenAPI(obj);
|
||||
if (is_jit_forbidden_register) {
|
||||
MS_LOG(EXCEPTION) << oss.str();
|
||||
}
|
||||
// Check if the API's module is in the JIT forbidden module set.
|
||||
bool is_jit_forbidden_module =
|
||||
py::cast<bool>(python_adapter::CallPyModFn(module, PYTHON_MOD_IS_JIT_FORBIDDEN_MODULE, obj_info[0]));
|
||||
if (is_jit_forbidden_module) {
|
||||
MS_LOG(EXCEPTION) << oss.str();
|
||||
}
|
||||
}
|
||||
|
||||
ValuePtr ConvertOtherObj(const py::object &obj, bool forbid_reuse = false) {
|
||||
auto obj_type = data_converter::GetObjType(obj);
|
||||
MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
|
||||
if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
|
||||
// Check JIT forbidden API
|
||||
CheckJITForbiddenAPI(obj);
|
||||
MS_LOG(DEBUG) << "Resolve the class type, need create class instance.";
|
||||
std::string desc = py::str(obj);
|
||||
// desc has format "<class xxxx>", strip the '<' and '>' by offset 1.
|
||||
|
@ -438,6 +466,10 @@ ValuePtr ConvertOtherObj(const py::object &obj, bool forbid_reuse = false) {
|
|||
}
|
||||
if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD ||
|
||||
(obj_type == RESOLVE_TYPE_CLASS_INSTANCE && py::hasattr(obj, PYTHON_PARSE_METHOD))) {
|
||||
if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) {
|
||||
// Check JIT forbidden API
|
||||
CheckJITForbiddenAPI(obj);
|
||||
}
|
||||
MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type;
|
||||
FuncGraphPtr func_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_PARSE_METHOD, forbid_reuse);
|
||||
if (func_graph == nullptr) {
|
||||
|
@ -733,6 +765,9 @@ bool IsNumpyArrayInstance(const py::object &obj) {
|
|||
// Check if the object is MsClass instance.
|
||||
bool IsMsClassInstance(const py::object &obj) { return py::hasattr(obj, PYTHON_MS_CLASS); }
|
||||
|
||||
// Check if the object is jit forbidden api.
|
||||
bool IsJITForbiddenAPI(const py::object &obj) { return py::hasattr(obj, PYTHON_JIT_FORBIDDEN); }
|
||||
|
||||
// Check if the object is class type.
|
||||
bool IsClassType(const py::object &obj) {
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
|
|
|
@ -47,6 +47,7 @@ ClassInstanceTypeDef GetClassInstanceType(const py::object &obj);
|
|||
bool IsCellInstance(const py::object &obj);
|
||||
bool IsNumpyArrayInstance(const py::object &obj);
|
||||
bool IsMsClassInstance(const py::object &obj);
|
||||
bool IsJITForbiddenAPI(const py::object &obj);
|
||||
bool IsClassType(const py::object &obj);
|
||||
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs);
|
||||
py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs);
|
||||
|
|
|
@ -55,6 +55,7 @@ enum ParseTargetTypeDef {
|
|||
};
|
||||
|
||||
// Define python module name.
|
||||
const char PYTHON_MOD_MODULE[] = "mindspore";
|
||||
const char PYTHON_MOD_PARSE_MODULE[] = "mindspore._extends.parse";
|
||||
const char PYTHON_MOD_PARSE_OBJECT_FUNCTION[] = "parse_cb";
|
||||
const char PYTHON_MOD_RESOLVE_FUNCTION[] = "resolve_symbol";
|
||||
|
@ -108,6 +109,9 @@ const char PYTHON_PARSE_GET_CONVERT_OBJECT_FOR_UNSUPPORTED_TYPE[] = "get_convert
|
|||
const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
|
||||
const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";
|
||||
|
||||
const char PYTHON_MOD_GET_MODULE_AND_NAME_INFO[] = "get_obj_module_and_name_info";
|
||||
const char PYTHON_MOD_IS_JIT_FORBIDDEN_MODULE[] = "is_jit_forbidden_module";
|
||||
|
||||
// Define the common name.
|
||||
const char NAMED_PRIMITIVE_LEN[] = "len";
|
||||
const char NAMED_PRIMITIVE_BODY[] = "body";
|
||||
|
|
|
@ -19,6 +19,7 @@ namespace mindspore {
|
|||
const char PYTHON_PRIMITIVE_FLAG[] = "__primitive_flag__";
|
||||
const char PYTHON_CELL_AS_LIST[] = "__cell_as_list__";
|
||||
const char PYTHON_MS_CLASS[] = "__ms_class__";
|
||||
const char PYTHON_JIT_FORBIDDEN[] = "__jit_forbidden__";
|
||||
const char PYTHON_ADAPTER_TENSOR[] = "__adapter_tensor__";
|
||||
const char PYTHON_CLASS_MEMBER_NAMESPACE[] = "__class_member_namespace__";
|
||||
const char PYTHON_FUNCTION_FORBID_REUSE[] = "__function_forbid_reuse__";
|
||||
|
|
|
@ -21,6 +21,7 @@ namespace mindspore {
|
|||
extern const char PYTHON_PRIMITIVE_FLAG[];
|
||||
extern const char PYTHON_CELL_AS_LIST[];
|
||||
extern const char PYTHON_MS_CLASS[];
|
||||
extern const char PYTHON_JIT_FORBIDDEN[];
|
||||
extern const char PYTHON_ADAPTER_TENSOR[];
|
||||
extern const char PYTHON_CLASS_MEMBER_NAMESPACE[];
|
||||
extern const char PYTHON_FUNCTION_FORBID_REUSE[];
|
||||
|
|
|
@ -31,6 +31,7 @@ from mindspore.profiler import Profiler, EnvProfiler
|
|||
from mindspore.parallel import set_algo_parameters, get_algo_parameters, reset_algo_parameters, \
|
||||
rank_list_for_transform, transform_checkpoint_by_rank, transform_checkpoints, merge_pipeline_strategys, shard
|
||||
from mindspore.rewrite import SymbolTree, ScopedValue, Node, NodeType, TreeNodeHelper
|
||||
from mindspore._check_jit_forbidden_api import get_obj_module_and_name_info, is_jit_forbidden_module
|
||||
|
||||
|
||||
__all__ = ["run_check"]
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Check jit forbidden api."""
|
||||
|
||||
import types
|
||||
|
||||
from mindspore import log as logger
|
||||
|
||||
# module: such as "mindspore.common.initializer"
|
||||
_jit_forbidden_module = set()
|
||||
|
||||
|
||||
def jit_forbidden_register(fn):
|
||||
setattr(fn, '__jit_forbidden__', True)
|
||||
def jit_forbidden(*args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
return jit_forbidden
|
||||
|
||||
|
||||
def add_jit_forbidden_module(jit_forbidden_module):
|
||||
logger.debug(f'add jit_forbidden_module_set: {_jit_forbidden_module}')
|
||||
return _jit_forbidden_module.add(jit_forbidden_module)
|
||||
|
||||
|
||||
def remove_jit_forbidden_module(jit_forbidden_module):
|
||||
logger.debug(f'remove jit_forbidden_module_set: {_jit_forbidden_module}')
|
||||
return _jit_forbidden_module.remove(jit_forbidden_module)
|
||||
|
||||
|
||||
def get_jit_forbidden_module(jit_forbidden_module):
|
||||
logger.debug(f'get jit_forbidden_module_set: {_jit_forbidden_module}')
|
||||
return _jit_forbidden_module
|
||||
|
||||
|
||||
def get_obj_module_and_name_info(obj):
|
||||
"""Return the description of the object whose type is class, function or method."""
|
||||
if isinstance(obj, (types.FunctionType, types.MethodType)):
|
||||
return obj.__module__, obj.__qualname__, "method or function"
|
||||
return obj.__module__, obj.__name__, "class"
|
||||
|
||||
|
||||
def is_jit_forbidden_module(obj_module):
|
||||
"""Return the matching result of object module in jit forbidden module set."""
|
||||
if obj_module in _jit_forbidden_module:
|
||||
return True
|
||||
return False
|
||||
|
||||
add_jit_forbidden_module("mindspore.common.initializer")
|
|
@ -41,6 +41,7 @@ from mindspore.ops.operations import Cast
|
|||
from mindspore.ops.primitive import Primitive
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.parallel.shard import Shard
|
||||
from mindspore._check_jit_forbidden_api import jit_forbidden_register
|
||||
|
||||
|
||||
class Cell(Cell_):
|
||||
|
@ -1195,6 +1196,7 @@ class Cell(Cell_):
|
|||
param.is_init = False
|
||||
param.name = prefix + name
|
||||
|
||||
@jit_forbidden_register
|
||||
def trainable_params(self, recurse=True):
|
||||
"""
|
||||
Returns all trainable parameters.
|
||||
|
@ -1209,6 +1211,7 @@ class Cell(Cell_):
|
|||
"""
|
||||
return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))
|
||||
|
||||
@jit_forbidden_register
|
||||
def untrainable_params(self, recurse=True):
|
||||
"""
|
||||
Returns all untrainable parameters.
|
||||
|
@ -1223,6 +1226,7 @@ class Cell(Cell_):
|
|||
"""
|
||||
return list(filter(lambda x: not x.requires_grad, self.get_parameters(expand=recurse)))
|
||||
|
||||
@jit_forbidden_register
|
||||
def get_parameters(self, expand=True):
|
||||
"""
|
||||
Returns an iterator over cell parameters.
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021-2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -107,13 +107,17 @@ def test_fallback_tensor_with_init():
|
|||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test Tensor() with init in graph mode.
|
||||
Expectation: No exception.
|
||||
Expectation: throw RuntimeError.
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
me_x = Tensor(shape=(1, 3), dtype=mstype.float32, init=One())
|
||||
return me_x
|
||||
print(foo())
|
||||
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
foo()
|
||||
assert "Failed to compile in GRAPH_MODE" in str(ex.value)
|
||||
assert "the class 'mindspore.common.initializer.One'" in str(ex.value)
|
||||
|
||||
|
||||
def test_fallback_tensor_reshape():
|
||||
|
|
|
@ -0,0 +1,191 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test jit forbidden api in graph mode. """
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context, jit
|
||||
from mindspore.common.initializer import initializer, One
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_jit_forbidden_api_one1():
|
||||
"""
|
||||
Feature: mindspore.common.initializer.One
|
||||
Description: test jit forbidden api 'One' in graph mode.
|
||||
Expectation: throw RuntimeError
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self):
|
||||
t = One()
|
||||
return t
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
net()
|
||||
assert "Failed to compile in GRAPH_MODE" in str(ex.value)
|
||||
assert "the class 'mindspore.common.initializer.One'" in str(ex.value)
|
||||
|
||||
|
||||
def test_jit_forbidden_api_one2():
|
||||
"""
|
||||
Feature: mindspore.common.initializer.One
|
||||
Description: test jit forbidden api 'One' in graph mode.
|
||||
Expectation: throw RuntimeError
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
t = One()
|
||||
return t
|
||||
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
foo()
|
||||
assert "Failed to compile in GRAPH_MODE" in str(ex.value)
|
||||
assert "the class 'mindspore.common.initializer.One'" in str(ex.value)
|
||||
|
||||
|
||||
def test_jit_forbidden_api_initializer1():
|
||||
"""
|
||||
Feature: mindspore.common.initializer.initializer
|
||||
Description: test jit forbidden api 'initializer' in graph mode.
|
||||
Expectation: throw RuntimeError
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self):
|
||||
t = initializer('ones', [1, 2, 3], mstype.float32)
|
||||
return t
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
net()
|
||||
assert "Failed to compile in GRAPH_MODE" in str(ex.value)
|
||||
assert "the method or function 'mindspore.common.initializer.initializer'" in str(ex.value)
|
||||
|
||||
|
||||
def test_jit_forbidden_api_initializer2():
|
||||
"""
|
||||
Feature: mindspore.common.initializer.initializer
|
||||
Description: test jit forbidden api 'initializer' in graph mode.
|
||||
Expectation: throw RuntimeError
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
t = initializer('ones', [1, 2, 3], mstype.float32)
|
||||
return t
|
||||
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
foo()
|
||||
assert "Failed to compile in GRAPH_MODE" in str(ex.value)
|
||||
assert "the method or function 'mindspore.common.initializer.initializer'" in str(ex.value)
|
||||
|
||||
|
||||
def test_jit_forbidden_api_untrainable_params1():
|
||||
"""
|
||||
Feature: mindspore.nn.cell.Cell.untrainable_params
|
||||
Description: test jit forbidden api 'untrainable_params' in graph mode.
|
||||
Expectation: throw RuntimeError
|
||||
"""
|
||||
class InnerNet(nn.Cell):
|
||||
def construct(self):
|
||||
return True
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = InnerNet()
|
||||
|
||||
def construct(self):
|
||||
out = self.inner_net.untrainable_params()
|
||||
return out
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
net()
|
||||
assert "Failed to compile in GRAPH_MODE" in str(ex.value)
|
||||
assert "the method or function 'mindspore.nn.cell.Cell.untrainable_params'" in str(ex.value)
|
||||
|
||||
|
||||
def test_jit_forbidden_api_untrainable_params2():
|
||||
"""
|
||||
Feature: mindspore.nn.cell.Cell.untrainable_params
|
||||
Description: test jit forbidden api 'untrainable_params' in graph mode.
|
||||
Expectation: throw RuntimeError
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def construct(self):
|
||||
return True
|
||||
|
||||
@jit
|
||||
def foo():
|
||||
return Net().untrainable_params()
|
||||
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
foo()
|
||||
assert "Failed to compile in GRAPH_MODE" in str(ex.value)
|
||||
assert "the method or function 'mindspore.nn.cell.Cell.untrainable_params'" in str(ex.value)
|
||||
|
||||
|
||||
def test_jit_forbidden_api_get_parameters1():
|
||||
"""
|
||||
Feature: mindspore.nn.cell.Cell.get_parameters
|
||||
Description: test jit forbidden api 'get_parameters' in graph mode.
|
||||
Expectation: throw RuntimeError
|
||||
"""
|
||||
class InnerNet(nn.Cell):
|
||||
def construct(self):
|
||||
return True
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = InnerNet()
|
||||
|
||||
def construct(self):
|
||||
out = self.inner_net.get_parameters()
|
||||
return out
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
net()
|
||||
assert "Failed to compile in GRAPH_MODE" in str(ex.value)
|
||||
assert "the method or function 'mindspore.nn.cell.Cell.get_parameters'" in str(ex.value)
|
||||
|
||||
|
||||
def test_jit_forbidden_api_get_parameters2():
|
||||
"""
|
||||
Feature: mindspore.nn.cell.Cell.untrainable_params
|
||||
Description: test jit forbidden api 'get_parameters' in graph mode.
|
||||
Expectation: throw RuntimeError
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def construct(self):
|
||||
return True
|
||||
|
||||
@jit
|
||||
def foo():
|
||||
return Net().get_parameters()
|
||||
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
foo()
|
||||
assert "Failed to compile in GRAPH_MODE" in str(ex.value)
|
||||
assert "the method or function 'mindspore.nn.cell.Cell.get_parameters'" in str(ex.value)
|
Loading…
Reference in New Issue