blacklist

This commit is contained in:
ligan 2022-12-29 18:41:05 +08:00
parent aa8858ab05
commit e485af77f6
10 changed files with 305 additions and 3 deletions

View File

@ -427,10 +427,38 @@ ValuePtr ConvertConstantNumpyNumber(const py::object &obj, ResolveTypeDef obj_ty
return nullptr;
}
void CheckJITForbiddenAPI(const py::object &obj) {
auto module = python_adapter::GetPyModule(PYTHON_MOD_MODULE);
py::list obj_info = python_adapter::CallPyModFn(module, PYTHON_MOD_GET_MODULE_AND_NAME_INFO, obj);
std::ostringstream oss;
auto obj_module = py::cast<std::string>(obj_info[0]);
auto obj_name = py::cast<std::string>(obj_info[1]);
auto obj_type = py::cast<std::string>(obj_info[2]);
oss << "Failed to compile in GRAPH_MODE because the " << obj_type << " '" << obj_module << "." << obj_name
<< "' is not supported in 'construct' or function with @jit decorator. "
<< "Try to use the " << obj_type << " '" << obj_module << "." << obj_name << "' externally "
<< "such as initialized in the method '__init__' before assigning"
<< ".\nFor more details, please refer to "
<< "https://www.mindspore.cn/docs/zh-CN/master/design/dynamic_graph_and_static_graph.html \n";
// Check if the API is decoratored by @jit_forbidden_register.
bool is_jit_forbidden_register = data_converter::IsJITForbiddenAPI(obj);
if (is_jit_forbidden_register) {
MS_LOG(EXCEPTION) << oss.str();
}
// Check if the API's module is in the JIT forbidden module set.
bool is_jit_forbidden_module =
py::cast<bool>(python_adapter::CallPyModFn(module, PYTHON_MOD_IS_JIT_FORBIDDEN_MODULE, obj_info[0]));
if (is_jit_forbidden_module) {
MS_LOG(EXCEPTION) << oss.str();
}
}
ValuePtr ConvertOtherObj(const py::object &obj, bool forbid_reuse = false) {
auto obj_type = data_converter::GetObjType(obj);
MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
// Check JIT forbidden API
CheckJITForbiddenAPI(obj);
MS_LOG(DEBUG) << "Resolve the class type, need create class instance.";
std::string desc = py::str(obj);
// desc has format "<class xxxx>", strip the '<' and '>' by offset 1.
@ -438,6 +466,10 @@ ValuePtr ConvertOtherObj(const py::object &obj, bool forbid_reuse = false) {
}
if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD ||
(obj_type == RESOLVE_TYPE_CLASS_INSTANCE && py::hasattr(obj, PYTHON_PARSE_METHOD))) {
if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) {
// Check JIT forbidden API
CheckJITForbiddenAPI(obj);
}
MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type;
FuncGraphPtr func_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_PARSE_METHOD, forbid_reuse);
if (func_graph == nullptr) {
@ -733,6 +765,9 @@ bool IsNumpyArrayInstance(const py::object &obj) {
// Check if the object is MsClass instance.
bool IsMsClassInstance(const py::object &obj) { return py::hasattr(obj, PYTHON_MS_CLASS); }
// Check if the object is jit forbidden api.
bool IsJITForbiddenAPI(const py::object &obj) { return py::hasattr(obj, PYTHON_JIT_FORBIDDEN); }
// Check if the object is class type.
bool IsClassType(const py::object &obj) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);

View File

@ -47,6 +47,7 @@ ClassInstanceTypeDef GetClassInstanceType(const py::object &obj);
bool IsCellInstance(const py::object &obj);
bool IsNumpyArrayInstance(const py::object &obj);
bool IsMsClassInstance(const py::object &obj);
bool IsJITForbiddenAPI(const py::object &obj);
bool IsClassType(const py::object &obj);
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs);
py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs);

View File

@ -55,6 +55,7 @@ enum ParseTargetTypeDef {
};
// Define python module name.
const char PYTHON_MOD_MODULE[] = "mindspore";
const char PYTHON_MOD_PARSE_MODULE[] = "mindspore._extends.parse";
const char PYTHON_MOD_PARSE_OBJECT_FUNCTION[] = "parse_cb";
const char PYTHON_MOD_RESOLVE_FUNCTION[] = "resolve_symbol";
@ -108,6 +109,9 @@ const char PYTHON_PARSE_GET_CONVERT_OBJECT_FOR_UNSUPPORTED_TYPE[] = "get_convert
const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";
const char PYTHON_MOD_GET_MODULE_AND_NAME_INFO[] = "get_obj_module_and_name_info";
const char PYTHON_MOD_IS_JIT_FORBIDDEN_MODULE[] = "is_jit_forbidden_module";
// Define the common name.
const char NAMED_PRIMITIVE_LEN[] = "len";
const char NAMED_PRIMITIVE_BODY[] = "body";

View File

@ -19,6 +19,7 @@ namespace mindspore {
const char PYTHON_PRIMITIVE_FLAG[] = "__primitive_flag__";
const char PYTHON_CELL_AS_LIST[] = "__cell_as_list__";
const char PYTHON_MS_CLASS[] = "__ms_class__";
const char PYTHON_JIT_FORBIDDEN[] = "__jit_forbidden__";
const char PYTHON_ADAPTER_TENSOR[] = "__adapter_tensor__";
const char PYTHON_CLASS_MEMBER_NAMESPACE[] = "__class_member_namespace__";
const char PYTHON_FUNCTION_FORBID_REUSE[] = "__function_forbid_reuse__";

View File

@ -21,6 +21,7 @@ namespace mindspore {
extern const char PYTHON_PRIMITIVE_FLAG[];
extern const char PYTHON_CELL_AS_LIST[];
extern const char PYTHON_MS_CLASS[];
extern const char PYTHON_JIT_FORBIDDEN[];
extern const char PYTHON_ADAPTER_TENSOR[];
extern const char PYTHON_CLASS_MEMBER_NAMESPACE[];
extern const char PYTHON_FUNCTION_FORBID_REUSE[];

View File

@ -31,6 +31,7 @@ from mindspore.profiler import Profiler, EnvProfiler
from mindspore.parallel import set_algo_parameters, get_algo_parameters, reset_algo_parameters, \
rank_list_for_transform, transform_checkpoint_by_rank, transform_checkpoints, merge_pipeline_strategys, shard
from mindspore.rewrite import SymbolTree, ScopedValue, Node, NodeType, TreeNodeHelper
from mindspore._check_jit_forbidden_api import get_obj_module_and_name_info, is_jit_forbidden_module
__all__ = ["run_check"]

View File

@ -0,0 +1,60 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Check jit forbidden api."""
import types
from mindspore import log as logger
# module: such as "mindspore.common.initializer"
_jit_forbidden_module = set()
def jit_forbidden_register(fn):
setattr(fn, '__jit_forbidden__', True)
def jit_forbidden(*args, **kwargs):
return fn(*args, **kwargs)
return jit_forbidden
def add_jit_forbidden_module(jit_forbidden_module):
logger.debug(f'add jit_forbidden_module_set: {_jit_forbidden_module}')
return _jit_forbidden_module.add(jit_forbidden_module)
def remove_jit_forbidden_module(jit_forbidden_module):
logger.debug(f'remove jit_forbidden_module_set: {_jit_forbidden_module}')
return _jit_forbidden_module.remove(jit_forbidden_module)
def get_jit_forbidden_module(jit_forbidden_module):
logger.debug(f'get jit_forbidden_module_set: {_jit_forbidden_module}')
return _jit_forbidden_module
def get_obj_module_and_name_info(obj):
"""Return the description of the object whose type is class, function or method."""
if isinstance(obj, (types.FunctionType, types.MethodType)):
return obj.__module__, obj.__qualname__, "method or function"
return obj.__module__, obj.__name__, "class"
def is_jit_forbidden_module(obj_module):
"""Return the matching result of object module in jit forbidden module set."""
if obj_module in _jit_forbidden_module:
return True
return False
add_jit_forbidden_module("mindspore.common.initializer")

View File

@ -41,6 +41,7 @@ from mindspore.ops.operations import Cast
from mindspore.ops.primitive import Primitive
from mindspore.ops.operations import _inner_ops as inner
from mindspore.parallel.shard import Shard
from mindspore._check_jit_forbidden_api import jit_forbidden_register
class Cell(Cell_):
@ -1195,6 +1196,7 @@ class Cell(Cell_):
param.is_init = False
param.name = prefix + name
@jit_forbidden_register
def trainable_params(self, recurse=True):
"""
Returns all trainable parameters.
@ -1209,6 +1211,7 @@ class Cell(Cell_):
"""
return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))
@jit_forbidden_register
def untrainable_params(self, recurse=True):
"""
Returns all untrainable parameters.
@ -1223,6 +1226,7 @@ class Cell(Cell_):
"""
return list(filter(lambda x: not x.requires_grad, self.get_parameters(expand=recurse)))
@jit_forbidden_register
def get_parameters(self, expand=True):
"""
Returns an iterator over cell parameters.

View File

@ -1,4 +1,4 @@
# Copyright 2021-2022 Huawei Technologies Co., Ltd
# Copyright 2021-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -107,13 +107,17 @@ def test_fallback_tensor_with_init():
"""
Feature: JIT Fallback
Description: Test Tensor() with init in graph mode.
Expectation: No exception.
Expectation: throw RuntimeError.
"""
@jit
def foo():
me_x = Tensor(shape=(1, 3), dtype=mstype.float32, init=One())
return me_x
print(foo())
with pytest.raises(RuntimeError) as ex:
foo()
assert "Failed to compile in GRAPH_MODE" in str(ex.value)
assert "the class 'mindspore.common.initializer.One'" in str(ex.value)
def test_fallback_tensor_reshape():

View File

@ -0,0 +1,191 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test jit forbidden api in graph mode. """
import pytest
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import context, jit
from mindspore.common.initializer import initializer, One
context.set_context(mode=context.GRAPH_MODE)
def test_jit_forbidden_api_one1():
"""
Feature: mindspore.common.initializer.One
Description: test jit forbidden api 'One' in graph mode.
Expectation: throw RuntimeError
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self):
t = One()
return t
net = Net()
with pytest.raises(RuntimeError) as ex:
net()
assert "Failed to compile in GRAPH_MODE" in str(ex.value)
assert "the class 'mindspore.common.initializer.One'" in str(ex.value)
def test_jit_forbidden_api_one2():
"""
Feature: mindspore.common.initializer.One
Description: test jit forbidden api 'One' in graph mode.
Expectation: throw RuntimeError
"""
@jit
def foo():
t = One()
return t
with pytest.raises(RuntimeError) as ex:
foo()
assert "Failed to compile in GRAPH_MODE" in str(ex.value)
assert "the class 'mindspore.common.initializer.One'" in str(ex.value)
def test_jit_forbidden_api_initializer1():
"""
Feature: mindspore.common.initializer.initializer
Description: test jit forbidden api 'initializer' in graph mode.
Expectation: throw RuntimeError
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self):
t = initializer('ones', [1, 2, 3], mstype.float32)
return t
net = Net()
with pytest.raises(RuntimeError) as ex:
net()
assert "Failed to compile in GRAPH_MODE" in str(ex.value)
assert "the method or function 'mindspore.common.initializer.initializer'" in str(ex.value)
def test_jit_forbidden_api_initializer2():
"""
Feature: mindspore.common.initializer.initializer
Description: test jit forbidden api 'initializer' in graph mode.
Expectation: throw RuntimeError
"""
@jit
def foo():
t = initializer('ones', [1, 2, 3], mstype.float32)
return t
with pytest.raises(RuntimeError) as ex:
foo()
assert "Failed to compile in GRAPH_MODE" in str(ex.value)
assert "the method or function 'mindspore.common.initializer.initializer'" in str(ex.value)
def test_jit_forbidden_api_untrainable_params1():
"""
Feature: mindspore.nn.cell.Cell.untrainable_params
Description: test jit forbidden api 'untrainable_params' in graph mode.
Expectation: throw RuntimeError
"""
class InnerNet(nn.Cell):
def construct(self):
return True
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self):
out = self.inner_net.untrainable_params()
return out
net = Net()
with pytest.raises(RuntimeError) as ex:
net()
assert "Failed to compile in GRAPH_MODE" in str(ex.value)
assert "the method or function 'mindspore.nn.cell.Cell.untrainable_params'" in str(ex.value)
def test_jit_forbidden_api_untrainable_params2():
"""
Feature: mindspore.nn.cell.Cell.untrainable_params
Description: test jit forbidden api 'untrainable_params' in graph mode.
Expectation: throw RuntimeError
"""
class Net(nn.Cell):
def construct(self):
return True
@jit
def foo():
return Net().untrainable_params()
with pytest.raises(RuntimeError) as ex:
foo()
assert "Failed to compile in GRAPH_MODE" in str(ex.value)
assert "the method or function 'mindspore.nn.cell.Cell.untrainable_params'" in str(ex.value)
def test_jit_forbidden_api_get_parameters1():
"""
Feature: mindspore.nn.cell.Cell.get_parameters
Description: test jit forbidden api 'get_parameters' in graph mode.
Expectation: throw RuntimeError
"""
class InnerNet(nn.Cell):
def construct(self):
return True
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self):
out = self.inner_net.get_parameters()
return out
net = Net()
with pytest.raises(RuntimeError) as ex:
net()
assert "Failed to compile in GRAPH_MODE" in str(ex.value)
assert "the method or function 'mindspore.nn.cell.Cell.get_parameters'" in str(ex.value)
def test_jit_forbidden_api_get_parameters2():
"""
Feature: mindspore.nn.cell.Cell.untrainable_params
Description: test jit forbidden api 'get_parameters' in graph mode.
Expectation: throw RuntimeError
"""
class Net(nn.Cell):
def construct(self):
return True
@jit
def foo():
return Net().get_parameters()
with pytest.raises(RuntimeError) as ex:
foo()
assert "Failed to compile in GRAPH_MODE" in str(ex.value)
assert "the method or function 'mindspore.nn.cell.Cell.get_parameters'" in str(ex.value)