!31469 Update json generation for custom op

Merge pull request !31469 from zichun_ye/custom_json_update
This commit is contained in:
i-robot 2022-03-22 09:03:36 +00:00 committed by Gitee
commit 3d3c32125b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 34 additions and 6 deletions

View File

@ -9,6 +9,18 @@ mindspore.ops.Custom
.. warning::这是一个实验性接口,后续可能删除或修改。
.. note::
不同自定义算子的函数类型func_type)支持的平台类型不同。每种类型支持的平台如下:
- "hybrid": ["Ascend", "GPU"].
- "akg": ["Ascend", "GPU"].
- "tbe": ["Ascend"].
- "aot": ["GPU", "CPU"].
- "pyfunc": ["CPU"].
- "julia": ["CPU"].
- "aicpu": ["Ascend"].
**参数:**
- **func** (Union[function, str]) - 自定义算子的函数表达。

View File

@ -27,6 +27,7 @@
#include "utils/anf_utils.h"
#include "utils/ms_context.h"
#include "kernel/oplib/oplib.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
namespace mindspore::graphkernel {
using kernel::OpAttr;
@ -1002,11 +1003,7 @@ bool AkgKernelJsonGenerator::CollectFusedJsonWithSingleKernel(const CNodePtr &c_
kernel_json_ = nlohmann::json();
std::vector<AnfNodePtr> node_list, input_list, output_list;
FuncGraphPtr fg = std::get<0>(BuildGraphFromNodes({c_node}));
FuncGraphManagerPtr mng = fg->manager();
if (mng == nullptr) {
mng = Manage(fg, false);
fg->set_manager(mng);
}
FuncGraphManagerPtr mng = GkUtils::GetFuncGraphManager(fg);
auto out_cnode = fg->output()->cast<CNodePtr>();
if (out_cnode == nullptr) {
MS_LOG(ERROR) << "Wrong graph generated for kernel [" << c_node->fullname_with_scope()
@ -1031,6 +1028,25 @@ bool AkgKernelJsonGenerator::CollectFusedJsonWithSingleKernel(const CNodePtr &c_
(void)mng->Replace(vnode, parameter);
}
// add new parameter for the same inputs
std::set<AnfNodePtr> inputs_set;
bool changed = false;
for (size_t i = 1; i < out_cnode->size(); i++) {
auto inp = out_cnode->input(i);
if (inputs_set.count(inp) == 0) {
(void)inputs_set.insert(inp);
} else {
auto p = fg->add_parameter();
p->set_abstract(inp->abstract());
p->set_kernel_info(inp->kernel_info_ptr());
out_cnode->set_input(i, p);
changed = true;
}
}
if (changed) {
GkUtils::UpdateFuncGraphManager(mng, fg);
}
node_list.push_back(out_cnode);
(void)input_list.insert(input_list.begin(), out_cnode->inputs().begin() + 1, out_cnode->inputs().end());
auto output_num = static_cast<int64_t>(AnfUtils::GetOutputTensorNum(out_cnode));

View File

@ -217,7 +217,7 @@ class Custom(ops.PrimitiveWithInfer):
>>> # In this case, the input func must be a function written in the Hybrid DSL
>>> # and decorated by @ms_hybrid.
>>> @ms_hybrid
>>> def outer_product_script(a, b):
... def outer_product_script(a, b):
... c = output_tensor(a.shape, a.dtype)
... for i0 in range(a.shape[0]):
... for i1 in range(b.shape[1]):