forked from mindspore-Ecosystem/mindspore
Enable JIT Fallback in default.
This commit is contained in:
parent
c3ce9c56c6
commit
e8b421fe6e
|
@ -164,7 +164,7 @@ def resolve_symbol(namespace, symbol):
|
|||
return resolve_
|
||||
|
||||
# Raise a proper error if not using Fallback feature.
|
||||
if support_fallback_ != '1':
|
||||
if support_fallback_ == '0':
|
||||
# Raise NotImplementedError when parsing the numpy methods, but not the numpy constant.
|
||||
if namespace.name == "numpy" and \
|
||||
isinstance(resolve_, (types.FunctionType, types.MethodType, types.ModuleType)):
|
||||
|
@ -262,7 +262,7 @@ def get_obj_type(obj):
|
|||
obj_type = RESOLVE_TYPE_CLASS_INSTANCE
|
||||
else:
|
||||
# Raise a proper error if not using Fallback feature.
|
||||
if support_fallback_ == '1':
|
||||
if support_fallback_ != '0':
|
||||
obj_type = RESOLVE_TYPE_INVALID
|
||||
else:
|
||||
# here for ndarray, just print its shape (in case of the array to large and print many data in screen)
|
||||
|
|
|
@ -383,7 +383,7 @@ ValuePtr ConvertOtherObj(const py::object &obj) {
|
|||
// The fallback feature is enabled in default.
|
||||
// Not support change the flag during the process is alive.
|
||||
static const auto support_fallback = common::GetEnv("ENV_SUPPORT_FALLBACK");
|
||||
static const auto use_fallback = (support_fallback == "1");
|
||||
static const auto use_fallback = (support_fallback != "0");
|
||||
if (use_fallback) {
|
||||
auto res = std::make_shared<InterpretedObject>(obj, py::str(obj));
|
||||
MS_LOG(DEBUG) << "Get interpreted object: " << res->ToString();
|
||||
|
|
|
@ -75,7 +75,7 @@ void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr
|
|||
|
||||
// The fallback feature is enabled in default.
|
||||
// Not support change the flag during the process is alive.
|
||||
static const auto use_fallback = (parser_.support_fallback() == "1");
|
||||
static const auto use_fallback = (parser_.support_fallback() != "0");
|
||||
|
||||
auto [iter, is_new_name] = assigned_vars_.emplace(var_name, std::make_pair(node, false));
|
||||
if (!is_new_name) {
|
||||
|
@ -133,7 +133,7 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) {
|
|||
|
||||
// The fallback feature is enabled in default.
|
||||
// Not support change the flag during the process is alive.
|
||||
static const auto use_fallback = (parser_.support_fallback() == "1");
|
||||
static const auto use_fallback = (parser_.support_fallback() != "0");
|
||||
if (use_fallback) {
|
||||
MS_LOG(DEBUG) << "Update global params of block: " << ToString()
|
||||
<< ", with previous block: " << block->ToString() << ",\nCurrent: " << py::str(global_py_params())
|
||||
|
@ -242,7 +242,7 @@ AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
|
|||
|
||||
// The fallback feature is enabled in default.
|
||||
// Not support change the flag during the process is alive.
|
||||
static const auto use_fallback = (parser_.support_fallback() == "1");
|
||||
static const auto use_fallback = (parser_.support_fallback() != "0");
|
||||
if (!use_fallback) {
|
||||
py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
|
||||
return HandleNamespaceInfo(namespace_info);
|
||||
|
|
|
@ -1825,7 +1825,7 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
|
|||
const py::object &value_object) {
|
||||
// The fallback feature is enabled in default.
|
||||
// Not support change the flag during the process is alive.
|
||||
static const auto use_fallback = (support_fallback() == "1");
|
||||
static const auto use_fallback = (support_fallback() != "0");
|
||||
if (!use_fallback || !value_node->interpret()) {
|
||||
return value_node;
|
||||
}
|
||||
|
@ -1842,19 +1842,19 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
|
|||
auto [keys, values] = block->local_py_params();
|
||||
auto local_dict_node = ParseDictByKeysAndValues(block, keys, values);
|
||||
// Update the valued node if it need interpreting.
|
||||
constexpr int recursive_level = 3;
|
||||
constexpr int recursive_level = 2;
|
||||
MS_LOG(INFO) << "[" << block->func_graph()->ToString() << "] script_text: `" << script_text
|
||||
<< "`,\nvalue_node: " << value_node->DebugString(recursive_level)
|
||||
<< ",\nglobal_dict_node: " << global_dict_node->ToString()
|
||||
<< ",\nlocal_dict_node: " << local_dict_node->ToString();
|
||||
<< ",\nlocal_dict_node: " << local_dict_node->DebugString(recursive_level);
|
||||
AnfNodePtr interpreted_node = block->MakeInterpret(script_text, global_dict_node, local_dict_node, value_node);
|
||||
|
||||
// Print a hint for user.
|
||||
auto line_info = trace::GetDebugInfo(value_node->debug_info());
|
||||
MS_LOG(DEBUG) << "Found unsupported syntax in Graph mode, those codes would be fell back to Python interpreter:"
|
||||
<< "\n\n"
|
||||
<< line_info;
|
||||
InterpretNodeRecorder::GetInstance().Push(line_info);
|
||||
MS_LOG(INFO) << "Found unsupported syntax in Graph mode, those codes would be fallen back to Python interpreter:"
|
||||
<< "\n\n"
|
||||
<< line_info;
|
||||
InterpretNodeRecorder::GetInstance().PushLineInfo(line_info);
|
||||
return interpreted_node;
|
||||
}
|
||||
|
||||
|
|
|
@ -60,7 +60,7 @@ abstract::AbstractBasePtr ClassType::ToAbstract() {
|
|||
// The fallback feature is enabled in default.
|
||||
// Not support change the flag during the process is alive.
|
||||
static const auto support_fallback = common::GetEnv("ENV_SUPPORT_FALLBACK");
|
||||
static const auto use_fallback = (support_fallback == "1");
|
||||
static const auto use_fallback = (support_fallback != "0");
|
||||
if (use_fallback && !IsSupportedCreateInstanceType(obj())) {
|
||||
return abs_scalar;
|
||||
}
|
||||
|
|
|
@ -864,6 +864,24 @@ void CacheValidateFuncGraph(const std::string &phase, const ResourcePtr &resourc
|
|||
}
|
||||
}
|
||||
|
||||
void CheckInterpretNodeLineInfos() {
|
||||
auto &line_infos = InterpretNodeRecorder::GetInstance().LineInfos();
|
||||
if (line_infos.empty()) {
|
||||
return;
|
||||
}
|
||||
std::stringstream ss;
|
||||
ss << "Found unsupported syntax in Graph mode, those codes would be fallen back to Python interpreter:\n";
|
||||
size_t num = 1;
|
||||
for (auto &line : line_infos) {
|
||||
ss << "\t#" << num << ": " << line << "\n";
|
||||
++num;
|
||||
}
|
||||
ss << "\n";
|
||||
// Print the codes run in JIT Fallback with ERROR level.
|
||||
MS_LOG(ERROR) << ss.str();
|
||||
InterpretNodeRecorder::GetInstance().Clear();
|
||||
}
|
||||
|
||||
void Pipeline::Run(const std::string &phase) {
|
||||
MS_LOG(INFO) << "Pipeline run";
|
||||
MS_EXCEPTION_IF_NULL(resource_);
|
||||
|
@ -885,6 +903,7 @@ void Pipeline::Run(const std::string &phase) {
|
|||
if (action.first == "task_emit") {
|
||||
SetLoopCount(resource_);
|
||||
} else if (action.first == "validate") {
|
||||
CheckInterpretNodeLineInfos();
|
||||
CacheValidateFuncGraph(phase, resource_);
|
||||
}
|
||||
if (!result) {
|
||||
|
|
|
@ -1317,10 +1317,6 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
|
|||
MS_EXCEPTION_IF_NULL(global_dict);
|
||||
MS_LOG(DEBUG) << "arg_1, global_dict: " << global_dict->ToString() << ", [" << global_dict->type_name() << "]";
|
||||
ValuePtr global_dict_value = global_dict->BuildValue();
|
||||
if (global_dict_value == kAnyValue) {
|
||||
MS_LOG(EXCEPTION) << "Not support Tensor or variable type as input during running JIT Fallback, but got "
|
||||
<< global_dict->ToString();
|
||||
}
|
||||
py::object global_params_dict = ValueToPyData(global_dict_value);
|
||||
MS_LOG(DEBUG) << "arg_1, python global_params_dict: " << global_dict_value->ToString() << " -> "
|
||||
<< py::str(global_params_dict);
|
||||
|
@ -1331,10 +1327,6 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
|
|||
MS_EXCEPTION_IF_NULL(local_dict);
|
||||
MS_LOG(DEBUG) << "arg_2, local_dict: " << local_dict->ToString() << ", [" << local_dict->type_name() << "]";
|
||||
ValuePtr local_dict_value = local_dict->BuildValue();
|
||||
if (local_dict_value == kAnyValue) {
|
||||
MS_LOG(EXCEPTION) << "Not support Tensor or variable type as input during running JIT Fallback, but got "
|
||||
<< local_dict->ToString();
|
||||
}
|
||||
py::object local_params_dict = ValueToPyData(local_dict_value);
|
||||
MS_LOG(DEBUG) << "arg_2, python local_params_dict: " << local_dict_value->ToString() << " -> "
|
||||
<< py::str(local_params_dict);
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#ifndef MINDSPORE_CORE_UTILS_InterpretNodeRecorder_H_
|
||||
#define MINDSPORE_CORE_UTILS_InterpretNodeRecorder_H_
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_set>
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -32,7 +32,9 @@ class InterpretNodeRecorder {
|
|||
return instance;
|
||||
}
|
||||
|
||||
void Push(const std::string &line) { interpret_nodes_lines_.emplace_back(line); }
|
||||
void PushLineInfo(const std::string &line) { interpret_nodes_lines_.emplace(line); }
|
||||
|
||||
const std::unordered_set<std::string> &LineInfos() const { return interpret_nodes_lines_; }
|
||||
|
||||
void Clear() { interpret_nodes_lines_.clear(); }
|
||||
|
||||
|
@ -41,7 +43,7 @@ class InterpretNodeRecorder {
|
|||
virtual ~InterpretNodeRecorder() = default;
|
||||
|
||||
private:
|
||||
std::vector<std::string> interpret_nodes_lines_;
|
||||
std::unordered_set<std::string> interpret_nodes_lines_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_UTILS_InterpretNodeRecorder_H_
|
||||
|
|
|
@ -80,7 +80,6 @@ def np_fallback_func():
|
|||
return me_x
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_np_fallback_func():
|
||||
print(np_fallback_func())
|
||||
|
||||
|
@ -94,7 +93,6 @@ def div_mod_func1():
|
|||
return Tensor(a)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_div_mod_func1():
|
||||
print(div_mod_func1()) # (2, 2)
|
||||
|
||||
|
@ -106,7 +104,6 @@ def div_mod_func2(x, y):
|
|||
return Tensor(a)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_div_mod_func2_scalar():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -116,14 +113,16 @@ def test_div_mod_func2_scalar():
|
|||
print(div_mod_func2(8, 3)) # (2, 2)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
|
||||
def test_div_mod_func2_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test divmod in graph.
|
||||
Expectation: No exception.
|
||||
Description: Test divmod with Tensor input in graph. We'll support it in Tensor Input Fallback solution.
|
||||
Expectation: Not supported exception.
|
||||
"""
|
||||
print(div_mod_func2(Tensor(8), Tensor(3))) # name 'x' is not defined
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
print(div_mod_func2(Tensor(8), Tensor(3)))
|
||||
assert "Not support Tensor or variable type as input during running JIT Fallback, but got" in str(err.value)
|
||||
|
||||
|
||||
# NameError: name 'Tensor' is not defined.
|
||||
|
|
|
@ -45,12 +45,11 @@ def test_use_numpy_method():
|
|||
return ret
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(NotImplementedError) as err:
|
||||
# Not raise NotImplementedError('Mindspore not supports to use the numpy ...') any more,
|
||||
# but raise RuntimeError('Should not use Python object in runtime...'), after support JIT Fallback.
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
net()
|
||||
assert "Mindspore does not support to use the numpy methods " \
|
||||
"within the construct() or @ms_function decorated function in graph mode." \
|
||||
in str(err.value)
|
||||
|
||||
assert "Should not use Python object in runtime" in str(err.value)
|
||||
|
||||
def test_use_numpy_module():
|
||||
class Net(nn.Cell):
|
||||
|
@ -62,8 +61,8 @@ def test_use_numpy_module():
|
|||
return ret
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(NotImplementedError) as err:
|
||||
# Not raise NotImplementedError('Mindspore not supports to use the numpy ...') any more,
|
||||
# but raise RuntimeError('Should not use Python object in runtime...'), after support JIT Fallback.
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
net()
|
||||
assert "Mindspore does not support to use the numpy methods " \
|
||||
"within the construct() or @ms_function decorated function in graph mode." \
|
||||
in str(err.value)
|
||||
assert "Should not use Python object in runtime" in str(err.value)
|
||||
|
|
|
@ -330,6 +330,7 @@ def test_insert_defined_var_compute():
|
|||
net(Tensor([1, 2, 3], mstype.float32))
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
|
||||
def test_call_unsupported_builtin_function_in_while():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -352,6 +353,7 @@ def test_call_unsupported_builtin_function_in_while():
|
|||
assert "ret = divmod(x, y)" in str(err.value)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
|
||||
def test_call_unsupported_builtin_function_in_if_in_for():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -374,6 +376,7 @@ def test_call_unsupported_builtin_function_in_if_in_for():
|
|||
assert "x = divmod(x, i)" in str(err.value)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
|
||||
def test_use_defined_class_obj_in_for():
|
||||
class Test:
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue