forked from mindspore-Ecosystem/mindspore
commit
ed75ccc589
|
@ -1,13 +1,14 @@
|
|||
mindspore.ops.MultitypeFuncGraph
|
||||
================================
|
||||
|
||||
.. py:class:: mindspore.ops.MultitypeFuncGraph(name, read_value=False)
|
||||
.. py:class:: mindspore.ops.MultitypeFuncGraph(name, read_value=False, doc_url="")
|
||||
|
||||
MultitypeFuncGraph是一个用于生成重载函数的类,使用不同类型作为输入。使用 `name` 去初始化一个MultitypeFuncGraph对象,然后用带有输入类型的 `register` 注册器进行装饰注册类型。这样使该函数可以使用不同的类型作为输入调用,一般与 `HyperMap` 、 `Map` 结合使用。
|
||||
|
||||
参数:
|
||||
- **name** (str) - 操作名。
|
||||
- **read_value** (bool, 可选) - 如果注册函数不需要对输入的值进行更改,即所有输入都为按值传递,则将 `read_value` 设置为True。默认值:False。
|
||||
- **doc_url** (str, 可选) - 注册函数对应的官方文档链接。默认值:""。
|
||||
|
||||
异常:
|
||||
- **ValueError** - 找不到给定参数类型所匹配的函数。
|
||||
|
|
|
@ -922,6 +922,7 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
|
|||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (value->isa<parse::MsClassObject>()) {
|
||||
auto value_obj = dyn_cast_ptr<parse::MsClassObject>(value);
|
||||
MS_EXCEPTION_IF_NULL(value_obj);
|
||||
auto obj_name = std::regex_replace(value_obj->name(), std::regex("MsClassObject:"), "");
|
||||
MS_LOG(EXCEPTION) << "For 'GradOperation', the first argument must be a 'Function' or 'Cell' type "
|
||||
<< "object, but got object with jit_class type" << obj_name << ".\n'GradOperation' "
|
||||
|
|
|
@ -110,8 +110,8 @@ std::string IntToNumber(const std::string &v) {
|
|||
}
|
||||
}
|
||||
|
||||
const std::vector<mindspore::TypePtrList> GetSortedCache(const TypeListMap<py::function> &fn_cache_py_,
|
||||
const TypePtrList &types, size_t match_max_idx) {
|
||||
std::vector<mindspore::TypePtrList> GetSortedCache(const TypeListMap<py::function> &fn_cache_py_,
|
||||
const TypePtrList &types, size_t match_max_idx) {
|
||||
std::vector<mindspore::TypePtrList> cache_vec;
|
||||
std::transform(fn_cache_py_.begin(), fn_cache_py_.end(), back_inserter(cache_vec),
|
||||
[](const auto &fcp) { return fcp.first; });
|
||||
|
|
|
@ -37,9 +37,6 @@ void EqualImpl(void *x1, void *x2, void *result, size_t size) {
|
|||
T *x1_data = static_cast<T *>(x1);
|
||||
T *x2_data = static_cast<T *>(x2);
|
||||
auto result_data = static_cast<bool *>(result);
|
||||
MS_EXCEPTION_IF_NULL(x1_data);
|
||||
MS_EXCEPTION_IF_NULL(x2_data);
|
||||
MS_EXCEPTION_IF_NULL(result_data);
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
result_data[i] = x1_data[i] == x2_data[i];
|
||||
}
|
||||
|
@ -53,9 +50,6 @@ void EqualFloatImpl(void *x1, void *x2, void *result, size_t size) {
|
|||
T *x1_data = static_cast<T *>(x1);
|
||||
T *x2_data = static_cast<T *>(x2);
|
||||
auto result_data = static_cast<bool *>(result);
|
||||
MS_EXCEPTION_IF_NULL(x1_data);
|
||||
MS_EXCEPTION_IF_NULL(x2_data);
|
||||
MS_EXCEPTION_IF_NULL(result_data);
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
result_data[i] = std::abs(x1_data[i] - x2_data[i]) < std::numeric_limits<T>::epsilon();
|
||||
}
|
||||
|
|
|
@ -327,9 +327,9 @@ class Cell(Cell_):
|
|||
try:
|
||||
if self.compile_cache:
|
||||
_cell_graph_executor.del_net_res(self, self.compile_cache)
|
||||
except AttributeError:
|
||||
except AttributeError as e:
|
||||
raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
|
||||
f"Please use 'super().__init__()'.")
|
||||
f"Please use 'super().__init__()'.") from e
|
||||
|
||||
def __delattr__(self, name):
|
||||
if name in self._params:
|
||||
|
@ -453,9 +453,9 @@ class Cell(Cell_):
|
|||
try:
|
||||
if self._enable_forward_pre_hook or self._enable_forward_hook or self._enable_backward_hook:
|
||||
return True
|
||||
except AttributeError:
|
||||
except AttributeError as e:
|
||||
raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
|
||||
f"Please use 'super().__init__()'.")
|
||||
f"Please use 'super().__init__()'.") from e
|
||||
if not self._is_recursion_hook:
|
||||
self._is_recursion_hook = True
|
||||
for cell in self.cells():
|
||||
|
|
|
@ -611,6 +611,7 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
|
|||
name (str): Operator name.
|
||||
read_value (bool, optional): If the registered function do not need to set value on Parameter,
|
||||
and all inputs will pass by value, set `read_value` to True. Default: False.
|
||||
doc_url (str, optional): The official document link corresponding to the registered function. Default:"".
|
||||
|
||||
Raises:
|
||||
ValueError: If failed to find a matching function for the given arguments.
|
||||
|
|
Loading…
Reference in New Issue