!31469 Update json generation for custom op

Merge pull request !31469 from zichun_ye/custom_json_update
This commit is contained in:
i-robot 2022-03-22 09:03:36 +00:00 committed by Gitee
commit 3d3c32125b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 34 additions and 6 deletions

View File

@ -9,6 +9,18 @@ mindspore.ops.Custom
.. warning::这是一个实验性接口,后续可能删除或修改。 .. warning::这是一个实验性接口,后续可能删除或修改。
.. note::
不同自定义算子的函数类型func_type)支持的平台类型不同。每种类型支持的平台如下:
- "hybrid": ["Ascend", "GPU"].
- "akg": ["Ascend", "GPU"].
- "tbe": ["Ascend"].
- "aot": ["GPU", "CPU"].
- "pyfunc": ["CPU"].
- "julia": ["CPU"].
- "aicpu": ["Ascend"].
**参数:** **参数:**
- **func** (Union[function, str]) - 自定义算子的函数表达。 - **func** (Union[function, str]) - 自定义算子的函数表达。

View File

@ -27,6 +27,7 @@
#include "utils/anf_utils.h" #include "utils/anf_utils.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "kernel/oplib/oplib.h" #include "kernel/oplib/oplib.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
namespace mindspore::graphkernel { namespace mindspore::graphkernel {
using kernel::OpAttr; using kernel::OpAttr;
@ -1002,11 +1003,7 @@ bool AkgKernelJsonGenerator::CollectFusedJsonWithSingleKernel(const CNodePtr &c_
kernel_json_ = nlohmann::json(); kernel_json_ = nlohmann::json();
std::vector<AnfNodePtr> node_list, input_list, output_list; std::vector<AnfNodePtr> node_list, input_list, output_list;
FuncGraphPtr fg = std::get<0>(BuildGraphFromNodes({c_node})); FuncGraphPtr fg = std::get<0>(BuildGraphFromNodes({c_node}));
FuncGraphManagerPtr mng = fg->manager(); FuncGraphManagerPtr mng = GkUtils::GetFuncGraphManager(fg);
if (mng == nullptr) {
mng = Manage(fg, false);
fg->set_manager(mng);
}
auto out_cnode = fg->output()->cast<CNodePtr>(); auto out_cnode = fg->output()->cast<CNodePtr>();
if (out_cnode == nullptr) { if (out_cnode == nullptr) {
MS_LOG(ERROR) << "Wrong graph generated for kernel [" << c_node->fullname_with_scope() MS_LOG(ERROR) << "Wrong graph generated for kernel [" << c_node->fullname_with_scope()
@ -1031,6 +1028,25 @@ bool AkgKernelJsonGenerator::CollectFusedJsonWithSingleKernel(const CNodePtr &c_
(void)mng->Replace(vnode, parameter); (void)mng->Replace(vnode, parameter);
} }
// add new parameter for the same inputs
std::set<AnfNodePtr> inputs_set;
bool changed = false;
for (size_t i = 1; i < out_cnode->size(); i++) {
auto inp = out_cnode->input(i);
if (inputs_set.count(inp) == 0) {
(void)inputs_set.insert(inp);
} else {
auto p = fg->add_parameter();
p->set_abstract(inp->abstract());
p->set_kernel_info(inp->kernel_info_ptr());
out_cnode->set_input(i, p);
changed = true;
}
}
if (changed) {
GkUtils::UpdateFuncGraphManager(mng, fg);
}
node_list.push_back(out_cnode); node_list.push_back(out_cnode);
(void)input_list.insert(input_list.begin(), out_cnode->inputs().begin() + 1, out_cnode->inputs().end()); (void)input_list.insert(input_list.begin(), out_cnode->inputs().begin() + 1, out_cnode->inputs().end());
auto output_num = static_cast<int64_t>(AnfUtils::GetOutputTensorNum(out_cnode)); auto output_num = static_cast<int64_t>(AnfUtils::GetOutputTensorNum(out_cnode));

View File

@ -217,7 +217,7 @@ class Custom(ops.PrimitiveWithInfer):
>>> # In this case, the input func must be a function written in the Hybrid DSL >>> # In this case, the input func must be a function written in the Hybrid DSL
>>> # and decorated by @ms_hybrid. >>> # and decorated by @ms_hybrid.
>>> @ms_hybrid >>> @ms_hybrid
>>> def outer_product_script(a, b): ... def outer_product_script(a, b):
... c = output_tensor(a.shape, a.dtype) ... c = output_tensor(a.shape, a.dtype)
... for i0 in range(a.shape[0]): ... for i0 in range(a.shape[0]):
... for i1 in range(b.shape[1]): ... for i1 in range(b.shape[1]):