forked from mindspore-Ecosystem/mindspore
!31469 Update json generation for custom op
Merge pull request !31469 from zichun_ye/custom_json_update
This commit is contained in:
commit
3d3c32125b
|
@ -9,6 +9,18 @@ mindspore.ops.Custom
|
|||
|
||||
.. warning:: 这是一个实验性接口,后续可能删除或修改。
|
||||
|
||||
.. note::
|
||||
|
||||
不同自定义算子的函数类型(func_type)支持的平台类型不同。每种类型支持的平台如下:
|
||||
|
||||
- "hybrid": ["Ascend", "GPU"].
|
||||
- "akg": ["Ascend", "GPU"].
|
||||
- "tbe": ["Ascend"].
|
||||
- "aot": ["GPU", "CPU"].
|
||||
- "pyfunc": ["CPU"].
|
||||
- "julia": ["CPU"].
|
||||
- "aicpu": ["Ascend"].
|
||||
|
||||
**参数:**
|
||||
|
||||
- **func** (Union[function, str]) - 自定义算子的函数表达。
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "utils/anf_utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "kernel/oplib/oplib.h"
|
||||
#include "common/graph_kernel/core/graph_kernel_utils.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
using kernel::OpAttr;
|
||||
|
@ -1002,11 +1003,7 @@ bool AkgKernelJsonGenerator::CollectFusedJsonWithSingleKernel(const CNodePtr &c_
|
|||
kernel_json_ = nlohmann::json();
|
||||
std::vector<AnfNodePtr> node_list, input_list, output_list;
|
||||
FuncGraphPtr fg = std::get<0>(BuildGraphFromNodes({c_node}));
|
||||
FuncGraphManagerPtr mng = fg->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(fg, false);
|
||||
fg->set_manager(mng);
|
||||
}
|
||||
FuncGraphManagerPtr mng = GkUtils::GetFuncGraphManager(fg);
|
||||
auto out_cnode = fg->output()->cast<CNodePtr>();
|
||||
if (out_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "Wrong graph generated for kernel [" << c_node->fullname_with_scope()
|
||||
|
@ -1031,6 +1028,25 @@ bool AkgKernelJsonGenerator::CollectFusedJsonWithSingleKernel(const CNodePtr &c_
|
|||
(void)mng->Replace(vnode, parameter);
|
||||
}
|
||||
|
||||
// add new parameter for the same inputs
|
||||
std::set<AnfNodePtr> inputs_set;
|
||||
bool changed = false;
|
||||
for (size_t i = 1; i < out_cnode->size(); i++) {
|
||||
auto inp = out_cnode->input(i);
|
||||
if (inputs_set.count(inp) == 0) {
|
||||
(void)inputs_set.insert(inp);
|
||||
} else {
|
||||
auto p = fg->add_parameter();
|
||||
p->set_abstract(inp->abstract());
|
||||
p->set_kernel_info(inp->kernel_info_ptr());
|
||||
out_cnode->set_input(i, p);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
if (changed) {
|
||||
GkUtils::UpdateFuncGraphManager(mng, fg);
|
||||
}
|
||||
|
||||
node_list.push_back(out_cnode);
|
||||
(void)input_list.insert(input_list.begin(), out_cnode->inputs().begin() + 1, out_cnode->inputs().end());
|
||||
auto output_num = static_cast<int64_t>(AnfUtils::GetOutputTensorNum(out_cnode));
|
||||
|
|
|
@ -217,7 +217,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
>>> # In this case, the input func must be a function written in the Hybrid DSL
|
||||
>>> # and decorated by @ms_hybrid.
|
||||
>>> @ms_hybrid
|
||||
>>> def outer_product_script(a, b):
|
||||
... def outer_product_script(a, b):
|
||||
... c = output_tensor(a.shape, a.dtype)
|
||||
... for i0 in range(a.shape[0]):
|
||||
... for i1 in range(b.shape[1]):
|
||||
|
|
Loading…
Reference in New Issue