From 516ff13c89c25e42271c1ee1d9985e526cc24c70 Mon Sep 17 00:00:00 2001 From: ligan Date: Tue, 20 Dec 2022 16:19:27 +0800 Subject: [PATCH] bug --- .../api_python/ops/mindspore.ops.MultitypeFuncGraph.rst | 3 ++- mindspore/ccsrc/frontend/operator/composite/composite.cc | 1 + .../frontend/operator/composite/multitype_funcgraph.cc | 4 ++-- mindspore/core/ops/equal.cc | 6 ------ mindspore/python/mindspore/nn/cell.py | 8 ++++---- mindspore/python/mindspore/ops/composite/base.py | 1 + 6 files changed, 10 insertions(+), 13 deletions(-) diff --git a/docs/api/api_python/ops/mindspore.ops.MultitypeFuncGraph.rst b/docs/api/api_python/ops/mindspore.ops.MultitypeFuncGraph.rst index e8341b841cc..f0b946593db 100644 --- a/docs/api/api_python/ops/mindspore.ops.MultitypeFuncGraph.rst +++ b/docs/api/api_python/ops/mindspore.ops.MultitypeFuncGraph.rst @@ -1,13 +1,14 @@ mindspore.ops.MultitypeFuncGraph ================================ -.. py:class:: mindspore.ops.MultitypeFuncGraph(name, read_value=False) +.. py:class:: mindspore.ops.MultitypeFuncGraph(name, read_value=False, doc_url="") MultitypeFuncGraph是一个用于生成重载函数的类,使用不同类型作为输入。使用 `name` 去初始化一个MultitypeFuncGraph对象,然后用带有输入类型的 `register` 注册器进行装饰注册类型。这样使该函数可以使用不同的类型作为输入调用,一般与 `HyperMap` 、 `Map` 结合使用。 参数: - **name** (str) - 操作名。 - **read_value** (bool, 可选) - 如果注册函数不需要对输入的值进行更改,即所有输入都为按值传递,则将 `read_value` 设置为True。默认值:False。 + - **doc_url** (str, 可选) - 注册函数对应的官方文档链接。默认值:""。 异常: - **ValueError** - 找不到给定参数类型所匹配的函数。 diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc index a453b4661f2..09201712a1b 100644 --- a/mindspore/ccsrc/frontend/operator/composite/composite.cc +++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc @@ -922,6 +922,7 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp MS_EXCEPTION_IF_NULL(value); if (value->isa()) { auto value_obj = dyn_cast_ptr(value); + MS_EXCEPTION_IF_NULL(value_obj); auto obj_name = std::regex_replace(value_obj->name(), std::regex("MsClassObject:"), ""); MS_LOG(EXCEPTION) << "For 'GradOperation', the first argument must be a 'Function' or 'Cell' type " << "object, but got object with jit_class type" << obj_name << ".\n'GradOperation' " diff --git a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc index eebe84721ad..ea6141f8b3c 100644 --- a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc +++ b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc @@ -110,8 +110,8 @@ std::string IntToNumber(const std::string &v) { } } -const std::vector GetSortedCache(const TypeListMap &fn_cache_py_, - const TypePtrList &types, size_t match_max_idx) { +std::vector GetSortedCache(const TypeListMap &fn_cache_py_, + const TypePtrList &types, size_t match_max_idx) { std::vector cache_vec; std::transform(fn_cache_py_.begin(), fn_cache_py_.end(), back_inserter(cache_vec), [](const auto &fcp) { return fcp.first; }); diff --git a/mindspore/core/ops/equal.cc b/mindspore/core/ops/equal.cc index f0ef2a6ae96..24d51eaa873 100644 --- a/mindspore/core/ops/equal.cc +++ b/mindspore/core/ops/equal.cc @@ -37,9 +37,6 @@ void EqualImpl(void *x1, void *x2, void *result, size_t size) { T *x1_data = static_cast(x1); T *x2_data = static_cast(x2); auto result_data = static_cast(result); - MS_EXCEPTION_IF_NULL(x1_data); - MS_EXCEPTION_IF_NULL(x2_data); - MS_EXCEPTION_IF_NULL(result_data); for (size_t i = 0; i < size; ++i) { result_data[i] = x1_data[i] == x2_data[i]; } @@ -53,9 +50,6 @@ void EqualFloatImpl(void *x1, void *x2, void *result, size_t size) { T *x1_data = static_cast(x1); T *x2_data = static_cast(x2); auto result_data = static_cast(result); - MS_EXCEPTION_IF_NULL(x1_data); - MS_EXCEPTION_IF_NULL(x2_data); - MS_EXCEPTION_IF_NULL(result_data); for (size_t i = 0; i < size; ++i) { result_data[i] = std::abs(x1_data[i] - x2_data[i]) < std::numeric_limits::epsilon(); } diff --git a/mindspore/python/mindspore/nn/cell.py b/mindspore/python/mindspore/nn/cell.py index 7e6c6c255dc..afe22c4b396 100755 --- a/mindspore/python/mindspore/nn/cell.py +++ b/mindspore/python/mindspore/nn/cell.py @@ -327,9 +327,9 @@ class Cell(Cell_): try: if self.compile_cache: _cell_graph_executor.del_net_res(self, self.compile_cache) - except AttributeError: + except AttributeError as e: raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. " - f"Please use 'super().__init__()'.") + f"Please use 'super().__init__()'.") from e def __delattr__(self, name): if name in self._params: @@ -453,9 +453,9 @@ class Cell(Cell_): try: if self._enable_forward_pre_hook or self._enable_forward_hook or self._enable_backward_hook: return True - except AttributeError: + except AttributeError as e: raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. " - f"Please use 'super().__init__()'.") + f"Please use 'super().__init__()'.") from e if not self._is_recursion_hook: self._is_recursion_hook = True for cell in self.cells(): diff --git a/mindspore/python/mindspore/ops/composite/base.py b/mindspore/python/mindspore/ops/composite/base.py index 064e19b7857..c014b42ced2 100644 --- a/mindspore/python/mindspore/ops/composite/base.py +++ b/mindspore/python/mindspore/ops/composite/base.py @@ -611,6 +611,7 @@ class MultitypeFuncGraph(MultitypeFuncGraph_): name (str): Operator name. read_value (bool, optional): If the registered function do not need to set value on Parameter, and all inputs will pass by value, set `read_value` to True. Default: False. + doc_url (str, optional): The official document link corresponding to the registered function. Default:"". Raises: ValueError: If failed to find a matching function for the given arguments.