forked from mindspore-Ecosystem/mindspore
!48184 [JIT Fallback] Support any type feature in runtime procedure, handle the exception for Fallback process, and hide no_recursive api.
Merge pull request !48184 from 张清华/opt_jit_fallback
This commit is contained in:
commit
f030af3a4b
|
@ -20,6 +20,7 @@ mindspore/mindspore/ccsrc/pipeline/jit/parse/resolve.cc:mindspore::parse::Resolv
|
|||
mindspore/mindspore/ccsrc/pipeline/jit/pipeline.cc:mindspore::pipeline::GraphExecutorPy::Compile
|
||||
mindspore/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc:mindspore::abstract::ConvertAbstractToPython
|
||||
mindspore/mindspore/ccsrc/pybind_api/ir/log_adapter_py.h:mindspore::PyExceptionInitializer::HandleExceptionPy
|
||||
mindspore/mindspore/ccsrc/pybind_api/ir/py_execute_py.cc:mindspore::PyExecuteInitializer::CppInferShapeAndTypePy
|
||||
mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/math/unary_op_gpu_kernel.h:mindspore::kernel::UnaryOpGpuKernel::Launch
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/dynamic_rnn_grad_fission_v2.cc:mindspore::opt::AddLSTMInputGradNode
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/drop_out_gen_mask_kernels.cc:aicpu::ARMDropOutGenMaskKernel
|
||||
|
|
|
@ -35,6 +35,8 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt::dynamic_shape {
|
||||
InfPyHandler cpp_infer_py_handler_{nullptr};
|
||||
void set_cpp_infer_py_handler(const InfPyHandler &infer_handler) { cpp_infer_py_handler_ = infer_handler; }
|
||||
namespace {
|
||||
constexpr int64_t kInvalidShape = -2;
|
||||
|
||||
|
@ -279,6 +281,7 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
|
|||
auto primitive = GetValueNode<PrimitivePtr>(inputs[0]);
|
||||
auto input_size = common::AnfAlgo::GetInputTensorNum(cnode);
|
||||
bool skip_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
|
||||
bool has_py_execute_data = false;
|
||||
for (size_t i = 0; i < input_size; i++) {
|
||||
auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i, false);
|
||||
auto real_input = input_node_with_index.first;
|
||||
|
@ -297,6 +300,9 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
|
|||
}
|
||||
|
||||
auto updated_abs = MakeNewAbstract(real_input, depended_value, real_input_index);
|
||||
if (updated_abs->has_user_data<kernel::PyExecuteOutputData>()) {
|
||||
has_py_execute_data = true;
|
||||
}
|
||||
(void)args_spec_list.emplace_back(updated_abs);
|
||||
} else {
|
||||
auto abs = real_input->abstract();
|
||||
|
@ -312,9 +318,17 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
|
|||
}
|
||||
}
|
||||
|
||||
// Pynative mode is rely on the origin abstract of cnode, so cannot modify the abstract inplace, clone from old
|
||||
// abstract instead.
|
||||
opt::CppInferShape(primitive, args_spec_list, cnode);
|
||||
if (!has_py_execute_data && !IsPrimitiveCNode(cnode, prim::kPrimPyExecute)) {
|
||||
// Pynative mode is rely on the origin abstract of cnode, so cannot modify the abstract inplace, clone from old
|
||||
// abstract instead.
|
||||
opt::CppInferShape(primitive, args_spec_list, cnode);
|
||||
} else {
|
||||
if (cpp_infer_py_handler_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "\'cpp_infer_py_handler_\' should not be null.";
|
||||
}
|
||||
const auto &abs = cpp_infer_py_handler_(cnode, primitive, args_spec_list);
|
||||
cnode->set_abstract(abs);
|
||||
}
|
||||
}
|
||||
|
||||
inline bool IsDeprecatedCpuOrGpuKernelMod(kernel::KernelModType kernel_mod_type) {
|
||||
|
|
|
@ -54,5 +54,9 @@ class CustomActorNodeManager {
|
|||
DISABLE_COPY_AND_ASSIGN(CustomActorNodeManager)
|
||||
OrderedMap<AnfNodePtr, RelatedCustomActorNode> custom_nodes_map_;
|
||||
};
|
||||
|
||||
using InfPyHandler = abstract::AbstractBasePtr (*)(const CNodePtr &, const PrimitivePtr &, const AbstractBasePtrList &);
|
||||
extern InfPyHandler cpp_infer_py_handler_;
|
||||
BACKEND_EXPORT void set_cpp_infer_py_handler(const InfPyHandler &infer_handler);
|
||||
} // namespace mindspore::opt::dynamic_shape
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_DYNAMIC_SHAPE_DYNAMIC_SHAPE_HELPER_H
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "include/common/utils/python_fallback_running.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "mindspore/ccsrc/pipeline/jit/parse/resolve.h"
|
||||
#include "utils/trace_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -327,14 +328,28 @@ bool PyExecuteCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const
|
|||
params[0] = global_dict;
|
||||
params[1] = local_dict;
|
||||
MS_LOG(DEBUG) << "Python script: " << py_script << ", params: " << params;
|
||||
mindspore::ScopedFallbackRunning fallback_running;
|
||||
const auto &output = CallPythonScript(py_script, params);
|
||||
const auto &output_type = py::str(output.get_type());
|
||||
MS_LOG(DEBUG) << "Python output type: " << output_type << ", output: " << output;
|
||||
if (output_type.cast<std::string>() == "<class 'mindspore.common.tensor.Tensor'>") { // It's Python Tensor type.
|
||||
TensorToRawMemory(output.cast<tensor::TensorPtr>(), outputs[0]);
|
||||
try {
|
||||
mindspore::ScopedFallbackRunning fallback_running;
|
||||
const auto &output = CallPythonScript(py_script, params);
|
||||
const auto &output_type = py::str(output.get_type());
|
||||
MS_LOG(DEBUG) << "Python output type: " << output_type << ", output: " << output;
|
||||
if (py::isinstance<tensor::Tensor>(output)) {
|
||||
TensorToRawMemory(output.cast<tensor::TensorPtr>(), outputs[0]);
|
||||
}
|
||||
AttachPyOutputData(output);
|
||||
} catch (const py::error_already_set &e) {
|
||||
auto error_type_name = py::cast<std::string>(python_adapter::GetPyObjAttr(e.type(), "__name__"));
|
||||
auto error_iter = exception_types_map.find(error_type_name);
|
||||
if (error_iter != exception_types_map.end()) {
|
||||
auto &handler = LogWriter::GetExceptionHandler();
|
||||
if (handler != nullptr) {
|
||||
std::stringstream ss;
|
||||
ss << py::str(e.value()) << ".\n\n" << trace::GetDebugInfo(kernel_node_->debug_info());
|
||||
handler(error_iter->second, ss.str());
|
||||
}
|
||||
}
|
||||
throw std::runtime_error(py::str(e.value()));
|
||||
}
|
||||
AttachPyOutputData(output);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "mindspore/ccsrc/pipeline/jit/parse/data_converter.h"
|
||||
#include "mindspore/ccsrc/pybind_api/ir/tensor_py.h"
|
||||
#include "mindspore/ccsrc/plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
|
||||
#include "mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
|
@ -38,12 +39,15 @@ py::object CallPythonGetGlobalParams() {
|
|||
|
||||
class PyExecuteInitializer {
|
||||
public:
|
||||
PyExecuteInitializer() { mindspore::ops::PyExecuteInfer::set_infer_handler(InferPy); }
|
||||
PyExecuteInitializer() {
|
||||
mindspore::ops::PyExecuteInfer::set_infer_handler(PyExecuteInferPy);
|
||||
mindspore::opt::dynamic_shape::set_cpp_infer_py_handler(CppInferShapeAndTypePy);
|
||||
}
|
||||
|
||||
~PyExecuteInitializer() = default;
|
||||
|
||||
private:
|
||||
static abstract::ShapePtr InferPy(const std::vector<AbstractBasePtr> &input_args) {
|
||||
static abstract::AbstractBasePtr PyExecuteInferPy(const std::vector<AbstractBasePtr> &input_args) {
|
||||
const auto &script_abs = input_args[0];
|
||||
const auto &script = script_abs->BuildValue();
|
||||
const auto &script_str = dyn_cast<StringImm>(script);
|
||||
|
@ -53,7 +57,8 @@ class PyExecuteInitializer {
|
|||
const auto &keys = dyn_cast<ValueSequence>(keys_tuple);
|
||||
if (keys == nullptr) {
|
||||
MS_LOG(DEBUG) << "The keys is not tuple value, but got " << keys_tuple->ToString();
|
||||
return std::make_shared<abstract::Shape>(ShapeVector({1}));
|
||||
const auto &infer_shape = std::make_shared<abstract::Shape>(ShapeVector({1}));
|
||||
return abstract::MakeAbstract(infer_shape, kFloat64);
|
||||
}
|
||||
const auto &values_tuple_abs = input_args[2];
|
||||
const auto &values_tuple = values_tuple_abs->BuildValue();
|
||||
|
@ -63,7 +68,8 @@ class PyExecuteInitializer {
|
|||
const auto &values = dyn_cast<ValueSequence>(values_tuple);
|
||||
if (values == nullptr) {
|
||||
MS_LOG(DEBUG) << "The values is not tuple value, but got " << keys_tuple->ToString();
|
||||
return std::make_shared<abstract::Shape>(ShapeVector({1}));
|
||||
const auto &infer_shape = std::make_shared<abstract::Shape>(ShapeVector({1}));
|
||||
return abstract::MakeAbstract(infer_shape, kFloat64);
|
||||
}
|
||||
MS_LOG(DEBUG) << "script: " << script->ToString() << ", keys_tuple: " << keys_tuple->ToString()
|
||||
<< ", values_tuple: " << values_tuple->ToString();
|
||||
|
@ -100,15 +106,125 @@ class PyExecuteInitializer {
|
|||
params[0] = global_dict;
|
||||
params[1] = local_dict;
|
||||
MS_LOG(DEBUG) << "Python script: " << py_script << ", params: " << params;
|
||||
mindspore::ScopedFallbackRunning fallback_running;
|
||||
const auto &output = parse::data_converter::CallPythonScript(py_script, params);
|
||||
MS_LOG(DEBUG) << "Python output type: " << py::str(output.get_type()) << ", output: " << output;
|
||||
if (py::isinstance<tensor::Tensor>(output)) {
|
||||
const auto &tensor = output.cast<tensor::TensorPtr>();
|
||||
return std::make_shared<abstract::Shape>(tensor->shape());
|
||||
try {
|
||||
mindspore::ScopedFallbackRunning fallback_running;
|
||||
const auto &output = parse::data_converter::CallPythonScript(py_script, params);
|
||||
MS_LOG(DEBUG) << "Python output type: " << py::str(output.get_type()) << ", output: " << output;
|
||||
if (py::isinstance<tensor::Tensor>(output)) {
|
||||
const auto &tensor = output.cast<tensor::TensorPtr>();
|
||||
const auto &infer_shape = std::make_shared<abstract::Shape>(tensor->shape());
|
||||
return abstract::MakeAbstract(infer_shape, tensor->Dtype());
|
||||
}
|
||||
} catch (const py::error_already_set &e) {
|
||||
auto error_type_name = py::cast<std::string>(python_adapter::GetPyObjAttr(e.type(), "__name__"));
|
||||
auto error_iter = exception_types_map.find(error_type_name);
|
||||
if (error_iter != exception_types_map.end()) {
|
||||
auto &handler = LogWriter::GetExceptionHandler();
|
||||
if (handler != nullptr) {
|
||||
handler(error_iter->second, py::str(e.value()));
|
||||
}
|
||||
}
|
||||
throw std::runtime_error(py::str(e.value()));
|
||||
}
|
||||
|
||||
return std::make_shared<abstract::Shape>(ShapeVector({1}));
|
||||
const auto &infer_shape = std::make_shared<abstract::Shape>(ShapeVector({1}));
|
||||
return abstract::MakeAbstract(infer_shape, kFloat64);
|
||||
}
|
||||
|
||||
static abstract::AbstractBasePtr CppInferShapeAndTypePy(const CNodePtr &cnode, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// We can't catch the pybind11 exception by py::builtin_exception or its base class,
|
||||
// so we have to list all pybind11 exceptions and catch one by one here.
|
||||
try {
|
||||
const auto &abs = opt::CppInferShapeAndType(primitive, args_spec_list);
|
||||
MS_LOG(DEBUG) << "The abstract of " << cnode->fullname_with_scope() << " changes from " << cnode->abstract()
|
||||
<< " to " << abs;
|
||||
return abs;
|
||||
} catch (const py::type_error &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::type_error(ss.str());
|
||||
} catch (const py::value_error &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::value_error(ss.str());
|
||||
} catch (const py::index_error &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::index_error(ss.str());
|
||||
} catch (const py::key_error &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::key_error(ss.str());
|
||||
} catch (const py::attribute_error &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::attribute_error(ss.str());
|
||||
} catch (const py::name_error &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::name_error(ss.str());
|
||||
} catch (const py::assertion_error &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::assertion_error(ss.str());
|
||||
} catch (const py::base_exception &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::base_exception(ss.str());
|
||||
} catch (const py::keyboard_interrupt &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::keyboard_interrupt(ss.str());
|
||||
} catch (const py::stop_iteration &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::stop_iteration(ss.str());
|
||||
} catch (const py::overflow_error &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::overflow_error(ss.str());
|
||||
} catch (const py::zero_division_error &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::zero_division_error(ss.str());
|
||||
} catch (const py::environment_error &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::environment_error(ss.str());
|
||||
} catch (const py::io_error &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::io_error(ss.str());
|
||||
} catch (const py::os_error &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::os_error(ss.str());
|
||||
} catch (const py::memory_error &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::memory_error(ss.str());
|
||||
} catch (const py::unbound_local_error &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::unbound_local_error(ss.str());
|
||||
} catch (const py::not_implemented_error &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::not_implemented_error(ss.str());
|
||||
} catch (const py::indentation_error &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::indentation_error(ss.str());
|
||||
} catch (const py::runtime_warning &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw py::runtime_warning(ss.str());
|
||||
} catch (const std::runtime_error &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ T ShapeSize(const std::vector<T> &shape) {
|
|||
return std::accumulate(shape.begin(), shape.end(), static_cast<T>(1), std::multiplies<T>());
|
||||
}
|
||||
|
||||
AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type);
|
||||
MS_CORE_API AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type);
|
||||
MS_CORE_API AbstractBasePtr MakeMonadAbstract(const MonadTypePtr &type);
|
||||
MS_CORE_API AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type);
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
MIND_API_OPERATOR_IMPL(PyExecute, BaseOperator);
|
||||
|
||||
BaseShapePtr PyExecuteInfer::InferShape(const PrimitivePtr &primitive,
|
||||
AbstractBasePtr PyExecuteInfer::InferPy(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
for (const auto &item : input_args) {
|
||||
|
@ -37,17 +37,24 @@ BaseShapePtr PyExecuteInfer::InferShape(const PrimitivePtr &primitive,
|
|||
if (infer_handler_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "infer_handler_ should not be null.";
|
||||
}
|
||||
return infer_handler_(input_args);
|
||||
const auto &abs = infer_handler_(input_args);
|
||||
return abs;
|
||||
}
|
||||
|
||||
BaseShapePtr PyExecuteInfer::InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const {
|
||||
const auto &abs = InferPy(primitive, input_args);
|
||||
return abs->BuildShape();
|
||||
}
|
||||
|
||||
TypePtr PyExecuteInfer::InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
|
||||
return kFloat64;
|
||||
MS_LOG(EXCEPTION) << "Should not invoke InferType().";
|
||||
}
|
||||
|
||||
AbstractBasePtr PyExecuteInfer::InferShapeAndType(const abstract::AnalysisEnginePtr &engine,
|
||||
const PrimitivePtr &primitive,
|
||||
AbstractBasePtr PyExecuteInfer::InferShapeAndType(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const {
|
||||
MS_LOG(EXCEPTION) << "Should not invoke InferShapeAndType.";
|
||||
const auto &abs = infer_handler_(input_args);
|
||||
return abs;
|
||||
}
|
||||
|
||||
std::set<int64_t> PyExecuteInfer::GetValueDependArgIndices() const { return {-1}; }
|
||||
|
|
|
@ -47,9 +47,11 @@ class MIND_API PyExecuteInfer : public abstract::OpInferBase {
|
|||
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override;
|
||||
|
||||
AbstractBasePtr InferPy(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const;
|
||||
|
||||
std::set<int64_t> GetValueDependArgIndices() const override;
|
||||
|
||||
using InferHandler = abstract::ShapePtr (*)(const std::vector<AbstractBasePtr> &);
|
||||
using InferHandler = abstract::AbstractBasePtr (*)(const std::vector<AbstractBasePtr> &);
|
||||
static void set_infer_handler(const InferHandler &infer_handler) { infer_handler_ = infer_handler; }
|
||||
|
||||
private:
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
"""Top-level reference to dtype of common module."""
|
||||
from __future__ import absolute_import
|
||||
from mindspore.common import dtype
|
||||
from mindspore.common.api import no_recursive, ms_function, ms_memory_recycle, ms_class, jit, jit_class
|
||||
from mindspore.common.api import ms_function, ms_memory_recycle, ms_class, jit, jit_class
|
||||
from mindspore.common.dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \
|
||||
uint8, ubyte, uint16, ushort, uint32, uintc, uint64, uintp, float16, half, \
|
||||
float32, single, float64, double, bool_, float_, list_, tuple_, int_, \
|
||||
|
@ -59,7 +59,7 @@ __all__ = [
|
|||
|
||||
__all__.extend([
|
||||
"Tensor", "RowTensor", "SparseTensor", "COOTensor", "CSRTensor", # tensor
|
||||
"no_recursive", "ms_function", "ms_class", 'jit', 'jit_class', # api
|
||||
"ms_function", "ms_class", 'jit', 'jit_class', # api
|
||||
"Parameter", "ParameterTuple", # parameter
|
||||
"dtype",
|
||||
"set_seed", "get_seed", # random seed
|
||||
|
|
|
@ -736,7 +736,7 @@ def _add_flags(fn=None, **flags):
|
|||
return ret
|
||||
|
||||
|
||||
def no_recursive(callable_obj):
|
||||
def _no_recursive(callable_obj):
|
||||
"""
|
||||
Method or function decorator for ignoring recursive check.
|
||||
|
||||
|
|
|
@ -483,7 +483,6 @@ def test_call_no_self_other_object_method_runtime():
|
|||
assert np.all(result == z)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not supported by now")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@ -502,11 +501,10 @@ def test_getattr_tensor_with_wrong_attr():
|
|||
return abs_func()
|
||||
|
||||
with pytest.raises(AttributeError) as err:
|
||||
foo(Tensor([-1, -2, -3])) # Not throw error any more, should move to ST.
|
||||
foo(ms.Tensor([-1, -2, -3])) # Not throw error any more, should move to ST.
|
||||
assert "object has no attribute" in str(err.value)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not supported by now")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
from mindspore.common.api import _no_recursive as no_recursive
|
||||
|
||||
ms.set_context(mode=ms.GRAPH_MODE)
|
||||
|
||||
|
@ -32,7 +33,7 @@ def test_cell_no_recursive():
|
|||
Description: test no_recursive flag.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms.no_recursive
|
||||
@no_recursive
|
||||
class Net(ms.nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -48,7 +49,7 @@ def test_cell_no_recursive():
|
|||
|
||||
|
||||
@ms.jit
|
||||
@ms.no_recursive
|
||||
@no_recursive
|
||||
def func(x, y):
|
||||
res = double(x) + double(y)
|
||||
print(res)
|
||||
|
|
Loading…
Reference in New Issue