!48184 [JIT Fallback] Support any type feature in runtime procedure, handle the exception for Fallback process, and hide no_recursive api.

Merge pull request !48184 from 张清华/opt_jit_fallback
This commit is contained in:
i-robot 2023-01-31 06:09:21 +00:00 committed by Gitee
commit f030af3a4b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
12 changed files with 195 additions and 37 deletions

View File

@ -20,6 +20,7 @@ mindspore/mindspore/ccsrc/pipeline/jit/parse/resolve.cc:mindspore::parse::Resolv
mindspore/mindspore/ccsrc/pipeline/jit/pipeline.cc:mindspore::pipeline::GraphExecutorPy::Compile
mindspore/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc:mindspore::abstract::ConvertAbstractToPython
mindspore/mindspore/ccsrc/pybind_api/ir/log_adapter_py.h:mindspore::PyExceptionInitializer::HandleExceptionPy
mindspore/mindspore/ccsrc/pybind_api/ir/py_execute_py.cc:mindspore::PyExecuteInitializer::CppInferShapeAndTypePy
mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/math/unary_op_gpu_kernel.h:mindspore::kernel::UnaryOpGpuKernel::Launch
mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/dynamic_rnn_grad_fission_v2.cc:mindspore::opt::AddLSTMInputGradNode
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/drop_out_gen_mask_kernels.cc:aicpu::ARMDropOutGenMaskKernel

View File

@ -35,6 +35,8 @@
namespace mindspore {
namespace opt::dynamic_shape {
// Inference callback consulted by InferShape() for nodes that carry PyExecute data or are PyExecute themselves.
// It stays null until set_cpp_infer_py_handler() installs a handler.
InfPyHandler cpp_infer_py_handler_ = nullptr;

// Installs `infer_handler` as the callback stored in cpp_infer_py_handler_.
void set_cpp_infer_py_handler(const InfPyHandler &infer_handler) {
  cpp_infer_py_handler_ = infer_handler;
}
namespace {
constexpr int64_t kInvalidShape = -2;
@ -279,6 +281,7 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
auto primitive = GetValueNode<PrimitivePtr>(inputs[0]);
auto input_size = common::AnfAlgo::GetInputTensorNum(cnode);
bool skip_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
bool has_py_execute_data = false;
for (size_t i = 0; i < input_size; i++) {
auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i, false);
auto real_input = input_node_with_index.first;
@ -297,6 +300,9 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
}
auto updated_abs = MakeNewAbstract(real_input, depended_value, real_input_index);
if (updated_abs->has_user_data<kernel::PyExecuteOutputData>()) {
has_py_execute_data = true;
}
(void)args_spec_list.emplace_back(updated_abs);
} else {
auto abs = real_input->abstract();
@ -312,9 +318,17 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
}
}
// PyNative mode relies on the original abstract of the cnode, so the abstract must not be modified in place; clone
// from the old abstract instead.
opt::CppInferShape(primitive, args_spec_list, cnode);
if (!has_py_execute_data && !IsPrimitiveCNode(cnode, prim::kPrimPyExecute)) {
// PyNative mode relies on the original abstract of the cnode, so the abstract must not be modified in place; clone
// from the old abstract instead.
opt::CppInferShape(primitive, args_spec_list, cnode);
} else {
if (cpp_infer_py_handler_ == nullptr) {
MS_LOG(EXCEPTION) << "\'cpp_infer_py_handler_\' should not be null.";
}
const auto &abs = cpp_infer_py_handler_(cnode, primitive, args_spec_list);
cnode->set_abstract(abs);
}
}
inline bool IsDeprecatedCpuOrGpuKernelMod(kernel::KernelModType kernel_mod_type) {

View File

@ -54,5 +54,9 @@ class CustomActorNodeManager {
DISABLE_COPY_AND_ASSIGN(CustomActorNodeManager)
OrderedMap<AnfNodePtr, RelatedCustomActorNode> custom_nodes_map_;
};
// Function-pointer type of the inference callback: it receives the cnode, its primitive and the list of input
// abstracts, and returns an abstract.
using InfPyHandler = abstract::AbstractBasePtr (*)(const CNodePtr &, const PrimitivePtr &, const AbstractBasePtrList &);
// Declaration only; the handler variable itself is defined outside this header.
extern InfPyHandler cpp_infer_py_handler_;
// Exported setter taking the handler to install.
// NOTE(review): presumably meant to be called once during initialization, before any inference runs — confirm.
BACKEND_EXPORT void set_cpp_infer_py_handler(const InfPyHandler &infer_handler);
} // namespace mindspore::opt::dynamic_shape
#endif // MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_DYNAMIC_SHAPE_DYNAMIC_SHAPE_HELPER_H

View File

@ -27,6 +27,7 @@
#include "include/common/utils/python_fallback_running.h"
#include "plugin/factory/ms_factory.h"
#include "mindspore/ccsrc/pipeline/jit/parse/resolve.h"
#include "utils/trace_base.h"
namespace mindspore {
namespace kernel {
@ -327,14 +328,28 @@ bool PyExecuteCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const
params[0] = global_dict;
params[1] = local_dict;
MS_LOG(DEBUG) << "Python script: " << py_script << ", params: " << params;
mindspore::ScopedFallbackRunning fallback_running;
const auto &output = CallPythonScript(py_script, params);
const auto &output_type = py::str(output.get_type());
MS_LOG(DEBUG) << "Python output type: " << output_type << ", output: " << output;
if (output_type.cast<std::string>() == "<class 'mindspore.common.tensor.Tensor'>") { // It's Python Tensor type.
TensorToRawMemory(output.cast<tensor::TensorPtr>(), outputs[0]);
try {
mindspore::ScopedFallbackRunning fallback_running;
const auto &output = CallPythonScript(py_script, params);
const auto &output_type = py::str(output.get_type());
MS_LOG(DEBUG) << "Python output type: " << output_type << ", output: " << output;
if (py::isinstance<tensor::Tensor>(output)) {
TensorToRawMemory(output.cast<tensor::TensorPtr>(), outputs[0]);
}
AttachPyOutputData(output);
} catch (const py::error_already_set &e) {
auto error_type_name = py::cast<std::string>(python_adapter::GetPyObjAttr(e.type(), "__name__"));
auto error_iter = exception_types_map.find(error_type_name);
if (error_iter != exception_types_map.end()) {
auto &handler = LogWriter::GetExceptionHandler();
if (handler != nullptr) {
std::stringstream ss;
ss << py::str(e.value()) << ".\n\n" << trace::GetDebugInfo(kernel_node_->debug_info());
handler(error_iter->second, ss.str());
}
}
throw std::runtime_error(py::str(e.value()));
}
AttachPyOutputData(output);
return true;
}

View File

@ -24,6 +24,7 @@
#include "mindspore/ccsrc/pipeline/jit/parse/data_converter.h"
#include "mindspore/ccsrc/pybind_api/ir/tensor_py.h"
#include "mindspore/ccsrc/plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
#include "mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
namespace py = pybind11;
namespace mindspore {
@ -38,12 +39,15 @@ py::object CallPythonGetGlobalParams() {
class PyExecuteInitializer {
public:
PyExecuteInitializer() { mindspore::ops::PyExecuteInfer::set_infer_handler(InferPy); }
PyExecuteInitializer() {
mindspore::ops::PyExecuteInfer::set_infer_handler(PyExecuteInferPy);
mindspore::opt::dynamic_shape::set_cpp_infer_py_handler(CppInferShapeAndTypePy);
}
~PyExecuteInitializer() = default;
private:
static abstract::ShapePtr InferPy(const std::vector<AbstractBasePtr> &input_args) {
static abstract::AbstractBasePtr PyExecuteInferPy(const std::vector<AbstractBasePtr> &input_args) {
const auto &script_abs = input_args[0];
const auto &script = script_abs->BuildValue();
const auto &script_str = dyn_cast<StringImm>(script);
@ -53,7 +57,8 @@ class PyExecuteInitializer {
const auto &keys = dyn_cast<ValueSequence>(keys_tuple);
if (keys == nullptr) {
MS_LOG(DEBUG) << "The keys is not tuple value, but got " << keys_tuple->ToString();
return std::make_shared<abstract::Shape>(ShapeVector({1}));
const auto &infer_shape = std::make_shared<abstract::Shape>(ShapeVector({1}));
return abstract::MakeAbstract(infer_shape, kFloat64);
}
const auto &values_tuple_abs = input_args[2];
const auto &values_tuple = values_tuple_abs->BuildValue();
@ -63,7 +68,8 @@ class PyExecuteInitializer {
const auto &values = dyn_cast<ValueSequence>(values_tuple);
if (values == nullptr) {
MS_LOG(DEBUG) << "The values is not tuple value, but got " << keys_tuple->ToString();
return std::make_shared<abstract::Shape>(ShapeVector({1}));
const auto &infer_shape = std::make_shared<abstract::Shape>(ShapeVector({1}));
return abstract::MakeAbstract(infer_shape, kFloat64);
}
MS_LOG(DEBUG) << "script: " << script->ToString() << ", keys_tuple: " << keys_tuple->ToString()
<< ", values_tuple: " << values_tuple->ToString();
@ -100,15 +106,125 @@ class PyExecuteInitializer {
params[0] = global_dict;
params[1] = local_dict;
MS_LOG(DEBUG) << "Python script: " << py_script << ", params: " << params;
mindspore::ScopedFallbackRunning fallback_running;
const auto &output = parse::data_converter::CallPythonScript(py_script, params);
MS_LOG(DEBUG) << "Python output type: " << py::str(output.get_type()) << ", output: " << output;
if (py::isinstance<tensor::Tensor>(output)) {
const auto &tensor = output.cast<tensor::TensorPtr>();
return std::make_shared<abstract::Shape>(tensor->shape());
try {
mindspore::ScopedFallbackRunning fallback_running;
const auto &output = parse::data_converter::CallPythonScript(py_script, params);
MS_LOG(DEBUG) << "Python output type: " << py::str(output.get_type()) << ", output: " << output;
if (py::isinstance<tensor::Tensor>(output)) {
const auto &tensor = output.cast<tensor::TensorPtr>();
const auto &infer_shape = std::make_shared<abstract::Shape>(tensor->shape());
return abstract::MakeAbstract(infer_shape, tensor->Dtype());
}
} catch (const py::error_already_set &e) {
auto error_type_name = py::cast<std::string>(python_adapter::GetPyObjAttr(e.type(), "__name__"));
auto error_iter = exception_types_map.find(error_type_name);
if (error_iter != exception_types_map.end()) {
auto &handler = LogWriter::GetExceptionHandler();
if (handler != nullptr) {
handler(error_iter->second, py::str(e.value()));
}
}
throw std::runtime_error(py::str(e.value()));
}
return std::make_shared<abstract::Shape>(ShapeVector({1}));
const auto &infer_shape = std::make_shared<abstract::Shape>(ShapeVector({1}));
return abstract::MakeAbstract(infer_shape, kFloat64);
}
static abstract::AbstractBasePtr CppInferShapeAndTypePy(const CNodePtr &cnode, const PrimitivePtr &primitive,
                                                        const AbstractBasePtrList &args_spec_list) {
  // Runs opt::CppInferShapeAndType() for `primitive` with `args_spec_list` and returns the resulting abstract.
  // Any exception listed below is rethrown as the same exception type, with the debug info of `cnode`
  // appended to the original message.

  // Builds the message for a rethrown exception: "<original what()>.\n\n<debug info of cnode>".
  const auto with_debug_info = [&cnode](const std::exception &err) {
    std::stringstream buffer;
    buffer << err.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
    return buffer.str();
  };

  // We can't catch the pybind11 exception by py::builtin_exception or its base class,
  // so every pybind11 exception type is listed and caught one by one, which keeps the
  // concrete exception type when it is thrown again.
  try {
    const auto &new_abs = opt::CppInferShapeAndType(primitive, args_spec_list);
    MS_LOG(DEBUG) << "The abstract of " << cnode->fullname_with_scope() << " changes from " << cnode->abstract()
                  << " to " << new_abs;
    return new_abs;
  } catch (const py::type_error &e) {
    throw py::type_error(with_debug_info(e));
  } catch (const py::value_error &e) {
    throw py::value_error(with_debug_info(e));
  } catch (const py::index_error &e) {
    throw py::index_error(with_debug_info(e));
  } catch (const py::key_error &e) {
    throw py::key_error(with_debug_info(e));
  } catch (const py::attribute_error &e) {
    throw py::attribute_error(with_debug_info(e));
  } catch (const py::name_error &e) {
    throw py::name_error(with_debug_info(e));
  } catch (const py::assertion_error &e) {
    throw py::assertion_error(with_debug_info(e));
  } catch (const py::base_exception &e) {
    throw py::base_exception(with_debug_info(e));
  } catch (const py::keyboard_interrupt &e) {
    throw py::keyboard_interrupt(with_debug_info(e));
  } catch (const py::stop_iteration &e) {
    throw py::stop_iteration(with_debug_info(e));
  } catch (const py::overflow_error &e) {
    throw py::overflow_error(with_debug_info(e));
  } catch (const py::zero_division_error &e) {
    throw py::zero_division_error(with_debug_info(e));
  } catch (const py::environment_error &e) {
    throw py::environment_error(with_debug_info(e));
  } catch (const py::io_error &e) {
    throw py::io_error(with_debug_info(e));
  } catch (const py::os_error &e) {
    throw py::os_error(with_debug_info(e));
  } catch (const py::memory_error &e) {
    throw py::memory_error(with_debug_info(e));
  } catch (const py::unbound_local_error &e) {
    throw py::unbound_local_error(with_debug_info(e));
  } catch (const py::not_implemented_error &e) {
    throw py::not_implemented_error(with_debug_info(e));
  } catch (const py::indentation_error &e) {
    throw py::indentation_error(with_debug_info(e));
  } catch (const py::runtime_warning &e) {
    throw py::runtime_warning(with_debug_info(e));
  } catch (const std::runtime_error &e) {
    throw std::runtime_error(with_debug_info(e));
  }
}
};

View File

@ -53,7 +53,7 @@ T ShapeSize(const std::vector<T> &shape) {
return std::accumulate(shape.begin(), shape.end(), static_cast<T>(1), std::multiplies<T>());
}
AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type);
MS_CORE_API AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type);
MS_CORE_API AbstractBasePtr MakeMonadAbstract(const MonadTypePtr &type);
MS_CORE_API AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type);

View File

@ -26,7 +26,7 @@ namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(PyExecute, BaseOperator);
BaseShapePtr PyExecuteInfer::InferShape(const PrimitivePtr &primitive,
AbstractBasePtr PyExecuteInfer::InferPy(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const {
MS_EXCEPTION_IF_NULL(primitive);
for (const auto &item : input_args) {
@ -37,17 +37,24 @@ BaseShapePtr PyExecuteInfer::InferShape(const PrimitivePtr &primitive,
if (infer_handler_ == nullptr) {
MS_LOG(EXCEPTION) << "infer_handler_ should not be null.";
}
return infer_handler_(input_args);
const auto &abs = infer_handler_(input_args);
return abs;
}
// Returns the shape built from the abstract that InferPy() produces for the given primitive and inputs.
BaseShapePtr PyExecuteInfer::InferShape(const PrimitivePtr &primitive,
                                        const std::vector<AbstractBasePtr> &input_args) const {
  return InferPy(primitive, input_args)->BuildShape();
}
TypePtr PyExecuteInfer::InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
return kFloat64;
MS_LOG(EXCEPTION) << "Should not invoke InferType().";
}
AbstractBasePtr PyExecuteInfer::InferShapeAndType(const abstract::AnalysisEnginePtr &engine,
const PrimitivePtr &primitive,
AbstractBasePtr PyExecuteInfer::InferShapeAndType(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const {
MS_LOG(EXCEPTION) << "Should not invoke InferShapeAndType.";
const auto &abs = infer_handler_(input_args);
return abs;
}
// Reports the value-dependent argument indices of PyExecute as a set holding the single entry -1.
// NOTE(review): -1 presumably acts as an "every input" marker — confirm against the OpInferBase contract.
std::set<int64_t> PyExecuteInfer::GetValueDependArgIndices() const {
  return std::set<int64_t>{-1};
}

View File

@ -47,9 +47,11 @@ class MIND_API PyExecuteInfer : public abstract::OpInferBase {
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override;
AbstractBasePtr InferPy(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const;
std::set<int64_t> GetValueDependArgIndices() const override;
using InferHandler = abstract::ShapePtr (*)(const std::vector<AbstractBasePtr> &);
using InferHandler = abstract::AbstractBasePtr (*)(const std::vector<AbstractBasePtr> &);
static void set_infer_handler(const InferHandler &infer_handler) { infer_handler_ = infer_handler; }
private:

View File

@ -15,7 +15,7 @@
"""Top-level reference to dtype of common module."""
from __future__ import absolute_import
from mindspore.common import dtype
from mindspore.common.api import no_recursive, ms_function, ms_memory_recycle, ms_class, jit, jit_class
from mindspore.common.api import ms_function, ms_memory_recycle, ms_class, jit, jit_class
from mindspore.common.dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \
uint8, ubyte, uint16, ushort, uint32, uintc, uint64, uintp, float16, half, \
float32, single, float64, double, bool_, float_, list_, tuple_, int_, \
@ -59,7 +59,7 @@ __all__ = [
__all__.extend([
"Tensor", "RowTensor", "SparseTensor", "COOTensor", "CSRTensor", # tensor
"no_recursive", "ms_function", "ms_class", 'jit', 'jit_class', # api
"ms_function", "ms_class", 'jit', 'jit_class', # api
"Parameter", "ParameterTuple", # parameter
"dtype",
"set_seed", "get_seed", # random seed

View File

@ -736,7 +736,7 @@ def _add_flags(fn=None, **flags):
return ret
def no_recursive(callable_obj):
def _no_recursive(callable_obj):
"""
Method or function decorator for ignoring recursive check.

View File

@ -483,7 +483,6 @@ def test_call_no_self_other_object_method_runtime():
assert np.all(result == z)
@pytest.mark.skip(reason="Not supported by now")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@ -502,11 +501,10 @@ def test_getattr_tensor_with_wrong_attr():
return abs_func()
with pytest.raises(AttributeError) as err:
foo(Tensor([-1, -2, -3])) # Not throw error any more, should move to ST.
foo(ms.Tensor([-1, -2, -3])) # Not throw error any more, should move to ST.
assert "object has no attribute" in str(err.value)
@pytest.mark.skip(reason="Not supported by now")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training

View File

@ -14,6 +14,7 @@
# ============================================================================
import numpy as np
import mindspore as ms
from mindspore.common.api import _no_recursive as no_recursive
ms.set_context(mode=ms.GRAPH_MODE)
@ -32,7 +33,7 @@ def test_cell_no_recursive():
Description: test no_recursive flag.
Expectation: No exception.
"""
@ms.no_recursive
@no_recursive
class Net(ms.nn.Cell):
def __init__(self):
super().__init__()
@ -48,7 +49,7 @@ def test_cell_no_recursive():
@ms.jit
@ms.no_recursive
@no_recursive
def func(x, y):
res = double(x) + double(y)
print(res)