forked from mindspore-Ecosystem/mindspore
!31469 Update json generation for custom op
Merge pull request !31469 from zichun_ye/custom_json_update
commit 3d3c32125b
@@ -9,6 +9,18 @@ mindspore.ops.Custom
     .. warning:: This is an experimental interface, which may be deleted or changed in the future.
 
+    .. note::
+        Different function types (func_type) of custom operators support different platforms.
+        The platforms supported by each type are as follows:
+
+        - "hybrid": ["Ascend", "GPU"].
+        - "akg": ["Ascend", "GPU"].
+        - "tbe": ["Ascend"].
+        - "aot": ["GPU", "CPU"].
+        - "pyfunc": ["CPU"].
+        - "julia": ["CPU"].
+        - "aicpu": ["Ascend"].
+
     **Parameters:**
 
     - **func** (Union[function, str]) - Function expression of the custom operator.
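Since the note above only maps each func_type to its platforms, a minimal usage sketch may help. This is a hypothetical example, not part of the patch: it assumes the ops.Custom(func, out_shape, out_dtype, func_type) signature and a CPU device target, and add_np is an illustrative name.

    import numpy as np
    from mindspore import Tensor, context, ops

    context.set_context(device_target="CPU")

    def add_np(a, b):
        # A "pyfunc" custom op runs as plain Python; inputs arrive as numpy arrays.
        return a + b

    # out_shape/out_dtype may be callables that infer the output from the inputs;
    # here the output mirrors the first input's shape and dtype.
    add_op = ops.Custom(add_np, out_shape=lambda a, b: a, out_dtype=lambda a, b: a,
                        func_type="pyfunc")
    x = Tensor(np.ones((2, 3)).astype(np.float32))
    print(add_op(x, x))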
@@ -27,6 +27,7 @@
 #include "utils/anf_utils.h"
 #include "utils/ms_context.h"
 #include "kernel/oplib/oplib.h"
+#include "common/graph_kernel/core/graph_kernel_utils.h"
 
 namespace mindspore::graphkernel {
 using kernel::OpAttr;
@@ -1002,11 +1003,7 @@ bool AkgKernelJsonGenerator::CollectFusedJsonWithSingleKernel(const CNodePtr &c_
   kernel_json_ = nlohmann::json();
   std::vector<AnfNodePtr> node_list, input_list, output_list;
   FuncGraphPtr fg = std::get<0>(BuildGraphFromNodes({c_node}));
-  FuncGraphManagerPtr mng = fg->manager();
-  if (mng == nullptr) {
-    mng = Manage(fg, false);
-    fg->set_manager(mng);
-  }
+  FuncGraphManagerPtr mng = GkUtils::GetFuncGraphManager(fg);
   auto out_cnode = fg->output()->cast<CNodePtr>();
   if (out_cnode == nullptr) {
     MS_LOG(ERROR) << "Wrong graph generated for kernel [" << c_node->fullname_with_scope()
@@ -1031,6 +1028,25 @@ bool AkgKernelJsonGenerator::CollectFusedJsonWithSingleKernel(const CNodePtr &c_
     (void)mng->Replace(vnode, parameter);
   }
+
+  // add new parameter for the same inputs
+  std::set<AnfNodePtr> inputs_set;
+  bool changed = false;
+  for (size_t i = 1; i < out_cnode->size(); i++) {
+    auto inp = out_cnode->input(i);
+    if (inputs_set.count(inp) == 0) {
+      (void)inputs_set.insert(inp);
+    } else {
+      auto p = fg->add_parameter();
+      p->set_abstract(inp->abstract());
+      p->set_kernel_info(inp->kernel_info_ptr());
+      out_cnode->set_input(i, p);
+      changed = true;
+    }
+  }
+  if (changed) {
+    GkUtils::UpdateFuncGraphManager(mng, fg);
+  }
 
   node_list.push_back(out_cnode);
   (void)input_list.insert(input_list.begin(), out_cnode->inputs().begin() + 1, out_cnode->inputs().end());
   auto output_num = static_cast<int64_t>(AnfUtils::GetOutputTensorNum(out_cnode));
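The added block gives each repeated input its own fresh parameter, since the generated kernel JSON expects a distinct parameter per input slot. A toy Python mirror of that rule, operating on plain strings instead of AnfNodePtr (all names here are illustrative, not part of the patch):

    def dedup_inputs(inputs):
        """Keep the first occurrence of each input; replace later duplicates
        with a fresh placeholder standing in for a new graph parameter."""
        seen = set()
        deduped = []
        for inp in inputs:
            if inp not in seen:
                seen.add(inp)
                deduped.append(inp)
            else:
                deduped.append("param_for_" + inp)  # fresh parameter per duplicate
        return deduped

    print(dedup_inputs(["x", "y", "x"]))  # ['x', 'y', 'param_for_x']

In the C++ version the fresh parameter additionally copies the duplicate's abstract and kernel info, and the graph manager is refreshed only when at least one input was rewritten.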
@@ -217,7 +217,7 @@ class Custom(ops.PrimitiveWithInfer):
         >>> # In this case, the input func must be a function written in the Hybrid DSL
         >>> # and decorated by @ms_hybrid.
         >>> @ms_hybrid
-        >>> def outer_product_script(a, b):
+        ... def outer_product_script(a, b):
         ...     c = output_tensor(a.shape, a.dtype)
         ...     for i0 in range(a.shape[0]):
         ...         for i1 in range(b.shape[1]):