!25374 Adapt Custom Op to Pyfunc Kernel

Merge pull request !25374 from jiaoy1224/arithmetic
i-robot 2021-11-01 02:23:26 +00:00 committed by Gitee
commit e80639357d
12 changed files with 425 additions and 96 deletions

View File

@@ -40,7 +40,7 @@ CustomAOTCpuKernel::~CustomAOTCpuKernel() {
void CustomAOTCpuKernel::InitKernel(const CNodePtr &kernel_node) {
const auto &exec_info = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "func_name");
if (auto pos = exec_info.find(":"); pos != std::string::npos) {
cuda_path_ = exec_info.substr(0, pos);
file_path_ = exec_info.substr(0, pos);
func_name_ = exec_info.substr(pos + 1);
} else {
MS_LOG(EXCEPTION) << "Wrong execute info:" << exec_info;
@@ -58,9 +58,9 @@ void CustomAOTCpuKernel::InitKernel(const CNodePtr &kernel_node) {
std::vector<int64_t> in_shape_tmp;
std::for_each(in_shape.begin(), in_shape.end(),
[&in_shape_tmp](size_t c) { in_shape_tmp.push_back(SizeToLong(c)); });
shape_list_.push_back(in_shape_tmp);
shape_list_.emplace_back(in_shape_tmp);
ndims_.push_back(SizeToInt(in_shape_tmp.size()));
type_list_.push_back(TypeIdToString(input_type_list[i], true));
type_list_.emplace_back(TypeIdToString(input_type_list[i], true));
}
num_output_ = AnfAlgo::GetOutputTensorNum(kernel_node);
@@ -75,9 +75,9 @@ void CustomAOTCpuKernel::InitKernel(const CNodePtr &kernel_node) {
std::vector<int64_t> out_shape_tmp;
std::for_each(out_shape.begin(), out_shape.end(),
[&out_shape_tmp](size_t c) { out_shape_tmp.push_back(SizeToLong(c)); });
shape_list_.push_back(out_shape_tmp);
shape_list_.emplace_back(out_shape_tmp);
ndims_.push_back(SizeToInt(out_shape_tmp.size()));
type_list_.push_back(TypeIdToString(output_type_list[i], true));
type_list_.emplace_back(TypeIdToString(output_type_list[i], true));
}
std::transform(std::begin(shape_list_), std::end(shape_list_), std::back_inserter(shapes_),
@@ -99,7 +99,7 @@ bool CustomAOTCpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std
#if !defined(_WIN32) && !defined(_WIN64)
if (!handle_) {
handle_ = dlopen(cuda_path_.c_str(), RTLD_LAZY | RTLD_LOCAL);
handle_ = dlopen(file_path_.c_str(), RTLD_LAZY | RTLD_LOCAL);
if (!handle_) {
MS_LOG(EXCEPTION) << "Open Error: " << dlerror();
}
@@ -116,10 +116,15 @@ bool CustomAOTCpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std
int nparam = SizeToInt(params.size());
int ret = 0;
if (nparam == 0) {
ret = aot_func_(0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr);
} else {
ret = aot_func_(nparam, &params[0], &ndims_[0], &shapes_[0], &type_pointer_list_[0], nullptr, nullptr);
try {
if (nparam == 0) {
ret = aot_func_(0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr);
} else {
ret = aot_func_(nparam, &params[0], &ndims_[0], &shapes_[0], &type_pointer_list_[0], nullptr, nullptr);
}
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "CustomAOT operator failed when running user defined file " << file_path_ << "! "
<< "Error message is " << e.what();
}
switch (ret) {

View File

@@ -43,7 +43,7 @@ class CustomAOTCpuKernel : public CPUKernel {
size_t num_input_;
size_t num_output_;
std::string cuda_path_;
std::string file_path_;
std::string func_name_;
void *handle_;
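// Entry point resolved from the user library via dlsym. Assumed argument layout, inferred from Launch:
// aot_func_(nparam, param_addrs, ndims, shapes, dtype_strings, stream, extra), with input addresses
// packed before output addresses; stream is nullptr on CPU and the CUDA stream on GPU.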
int (*aot_func_)(int, void **, int *, int64_t **, const char **, void *, void *);

View File

@@ -27,8 +27,8 @@
namespace mindspore {
namespace kernel {
namespace {
py::object RawMemoryToScalar(const void *data, const TypePtr &type) {
switch (type->type_id()) {
py::object RawMemoryToScalar(const void *data, const TypeId &type) {
switch (type) {
case kNumberTypeBool:
return py::bool_(*reinterpret_cast<const bool *>(data));
case kNumberTypeInt16:
@@ -56,12 +56,12 @@ py::object RawMemoryToScalar(const void *data, const TypePtr &type) {
case kNumberTypeFloat64:
return py::float_(*reinterpret_cast<const double *>(data));
default:
MS_LOG(EXCEPTION) << "Type: " << type->type_id() << " not supported.";
MS_LOG(EXCEPTION) << "Type: " << type << " not supported.";
}
}
void ScalarToRawMemory(const py::object &obj, const TypePtr &type, const AddressPtr &address) {
switch (type->type_id()) {
void ScalarToRawMemory(const py::object &obj, const TypeId &type, const AddressPtr &address) {
switch (type) {
case kNumberTypeBool: {
bool data = py::cast<bool>(obj);
CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, &data, sizeof(bool)), EOK, "memcpy failed.");
@@ -128,7 +128,7 @@ void ScalarToRawMemory(const py::object &obj, const TypePtr &type, const Address
return;
}
default:
MS_LOG(EXCEPTION) << "Type: " << type->type_id() << " not supported.";
MS_LOG(EXCEPTION) << "Type: " << type << " not supported.";
}
}
@@ -155,7 +155,7 @@ void ArrayToRawMemory(const py::array &array, const AddressPtr &address) {
}
}
void ObjectToRawMemory(const py::object &object, const PythonOjectType &object_type, const TypePtr &data_type,
void ObjectToRawMemory(const py::object &object, const PythonOjectType &object_type, const TypeId &data_type,
const AddressPtr &address) {
switch (object_type) {
case PythonOjectType::kScalar:
@@ -212,14 +212,17 @@ void PyObjectToRawMemorys(const py::object &object, const PyFuncArgumentInfo &ou
} // namespace
void PyFuncCpuKernel::InitKernel(const CNodePtr &kernel_node) {
is_custom_ = IsPrimitiveCNode(kernel_node, prim::kPrimCustom);
func_id_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "fn_id");
fake_output_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "fake_output");
single_scalar_output_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "single_scalar_output");
BuildFuncInfo(kernel_node);
}
bool PyFuncCpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
if (!init_) {
py_func_ = GetPythonFunc(func_id_);
py_func_ = GetPythonFunc();
init_ = true;
}
@@ -227,15 +230,40 @@ bool PyFuncCpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::v
}
void PyFuncCpuKernel::BuildFuncInfo(const CNodePtr &kernel_node) {
const auto &in_shapes = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "in_shapes");
const auto &in_types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "in_types");
const auto &out_shapes = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "out_shapes");
const auto &out_types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "out_types");
std::vector<TypeId> in_types = AnfAlgo::GetAllInputDeviceTypes(kernel_node);
std::vector<TypeId> out_types = AnfAlgo::GetAllOutputDeviceTypes(kernel_node);
std::vector<std::vector<int64_t>> in_shapes;
std::vector<std::vector<int64_t>> out_shapes;
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel_node); i++) {
std::vector<size_t> in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
std::vector<int64_t> in_shape_tmp;
std::for_each(in_shape.begin(), in_shape.end(),
[&in_shape_tmp](size_t c) { in_shape_tmp.push_back(SizeToLong(c)); });
in_shapes.emplace_back(in_shape_tmp);
}
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel_node); i++) {
std::vector<size_t> out_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, i);
std::vector<int64_t> out_shape_tmp;
std::for_each(out_shape.begin(), out_shape.end(),
[&out_shape_tmp](size_t c) { out_shape_tmp.push_back(SizeToLong(c)); });
out_shapes.emplace_back(out_shape_tmp);
}
if (in_shapes.size() != in_types.size()) {
MS_LOG(EXCEPTION) << "Input shapes'size is " << in_shapes.size() << ", while input types' size is "
<< in_types.size();
}
if (out_shapes.size() != out_types.size()) {
MS_LOG(EXCEPTION) << "Output shapes'size is " << out_shapes.size() << ", while output types' size is "
<< out_types.size();
}
input_infos_.dtypes = in_types;
input_infos_.shapes = in_shapes;
for (size_t i = 0; i < in_shapes.size(); i++) {
auto tensor = std::make_shared<tensor::Tensor>(in_types[i]->type_id(), in_shapes[i]);
auto tensor = std::make_shared<tensor::Tensor>(in_types[i], in_shapes[i]);
input_tensors_.push_back(tensor);
const auto &object_type = in_shapes[i].empty() ? PythonOjectType::kScalar : PythonOjectType::kNumpyArray;
@@ -244,9 +272,13 @@ void PyFuncCpuKernel::BuildFuncInfo(const CNodePtr &kernel_node) {
output_infos_.dtypes = out_types;
output_infos_.shapes = out_shapes;
for (size_t j = 0; j < out_shapes.size(); j++) {
const auto &object_type = out_shapes[j].empty() ? PythonOjectType::kScalar : PythonOjectType::kNumpyArray;
(void)output_infos_.object_types.emplace_back(object_type);
if (single_scalar_output_) {
(void)output_infos_.object_types.emplace_back(PythonOjectType::kScalar);
} else {
for (size_t j = 0; j < out_shapes.size(); j++) {
const auto &object_type = out_shapes[j].empty() ? PythonOjectType::kScalar : PythonOjectType::kNumpyArray;
(void)output_infos_.object_types.emplace_back(object_type);
}
}
}
@@ -265,28 +297,35 @@ bool PyFuncCpuKernel::ExecuteKernel(const std::vector<AddressPtr> &inputs, const
result = py_func_();
}
if (output_infos_.shapes.empty()) {
return true;
if (fake_output_) {
if (result.is_none()) {
return true;
} else {
MS_LOG(ERROR) << "This CustomPyfunc should have no outputs, but got 1";
return false;
}
}
PyObjectToRawMemorys(result, output_infos_, outputs);
return true;
}
py::function PyFuncCpuKernel::GetPythonFunc(const int64_t &func_id) {
py::function PyFuncCpuKernel::GetPythonFunc() {
py::gil_scoped_acquire gil_acquire;
static const std::string &module_name = "mindspore.ops.operations.other_ops";
static const std::string &func_name = "get_pyfunc";
static const std::string &module_name =
is_custom_ ? "mindspore.ops.operations.custom_ops" : "mindspore.ops.operations.other_ops";
static const std::string &entrance = "get_pyfunc";
py::module module = py::module::import(module_name.c_str());
py::object get_pyfunc_obj = module.attr(func_name.c_str());
py::object get_pyfunc_obj = module.attr(entrance.c_str());
if (get_pyfunc_obj.is_none()) {
MS_LOG(EXCEPTION) << "Cannot find a python function named " << func_name << "in module" << module_name;
MS_LOG(EXCEPTION) << "Cannot find a python function named " << entrance << "in module" << module_name;
}
py::function get_pyfunc = get_pyfunc_obj.cast<py::function>();
py::object py_func_obj = get_pyfunc(py::int_(func_id));
py::object py_func_obj = get_pyfunc(py::int_(func_id_));
if (py_func_obj.is_none()) {
MS_LOG(EXCEPTION) << "Cannot find python func with id: " << func_id;
MS_LOG(EXCEPTION) << "Cannot find python func with id: " << func_id_;
}
return py_func_obj.cast<py::function>();

View File

@@ -35,14 +35,15 @@ struct PyFuncArgumentInfo {
// An empty vector indicates the Python object is a scalar; a non-empty one means a NumPy array.
std::vector<std::vector<int64_t>> shapes;
// Data type, such as int, float, bool.
std::vector<TypePtr> dtypes;
std::vector<TypeId> dtypes;
// Python object type
std::vector<PythonOjectType> object_types;
};
class PyFuncCpuKernel : public CPUKernel {
public:
PyFuncCpuKernel() : init_(false), func_id_(-1) {}
PyFuncCpuKernel()
: is_custom_(false), init_(false), fake_output_(false), single_scalar_output_(false), func_id_(-1) {}
~PyFuncCpuKernel() = default;
// Init kernel including analyse PyFunc input and output info.
@@ -55,12 +56,18 @@ class PyFuncCpuKernel : public CPUKernel {
// Analyse PyFunc input/output spec.
void BuildFuncInfo(const CNodePtr &kernel_node);
// Get Python function from anchor.
py::function GetPythonFunc(const int64_t &func_id);
py::function GetPythonFunc();
bool ExecuteKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
// Both mindspore.ops.operations.custom_ops.Custom and mindspore.ops.operations.PyFunc will launch
// this kernel (the two have similar features and will be unified later). If is_custom_ is true, the kernel
// is launched from Custom; otherwise it is launched from PyFunc.
bool is_custom_;
bool init_;
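// fake_output_: the Python function returns nothing, so a placeholder output is faked for the graph;
// single_scalar_output_: the function returns exactly one scalar, i.e. output dtype set but output
// shape empty (see Custom.__init__ on the Python side of this diff).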
bool fake_output_;
bool single_scalar_output_;
// A Python object is not acceptable as a `Primitive` attribute, so we pass a unique key instead of the Python function.
// ME store the Python function to a dict, and pass the key to backend kernel.
// mindspore.ops.operations.PyFunc stores the Python function in a dict and passes the key to the backend kernel.
// The kernel gets the Python function by the key from the dict when it is first invoked.
int64_t func_id_;
py::function py_func_;

View File

@@ -52,7 +52,7 @@ class CustomAOTGpuKernel : public GpuKernel {
}
if (!handle_) {
handle_ = dlopen(cuda_path_.c_str(), RTLD_LAZY | RTLD_LOCAL);
handle_ = dlopen(file_path_.c_str(), RTLD_LAZY | RTLD_LOCAL);
if (!handle_) {
MS_LOG(ERROR) << "Open Error: " << dlerror();
return false;
@@ -71,10 +71,16 @@ class CustomAOTGpuKernel : public GpuKernel {
int nparam = SizeToInt(params.size());
int ret = 0;
if (nparam == 0) {
ret = aot_func_(0, nullptr, nullptr, nullptr, nullptr, stream_ptr, nullptr);
} else {
ret = aot_func_(nparam, &params[0], &ndims_[0], &shapes_[0], &type_pointer_list_[0], stream_ptr, nullptr);
try {
if (nparam == 0) {
ret = aot_func_(0, nullptr, nullptr, nullptr, nullptr, stream_ptr, nullptr);
} else {
ret = aot_func_(nparam, &params[0], &ndims_[0], &shapes_[0], &type_pointer_list_[0], stream_ptr, nullptr);
}
} catch (const std::exception &e) {
MS_LOG(ERROR) << "CustomAOT operator failed when running user defined file " << file_path_ << "! "
<< "Error message is " << e.what();
return false;
}
switch (ret) {
@@ -99,7 +105,7 @@ class CustomAOTGpuKernel : public GpuKernel {
bool Init(const CNodePtr &kernel_node) override {
const auto &exec_info = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "func_name");
if (auto pos = exec_info.find(":"); pos != std::string::npos) {
cuda_path_ = exec_info.substr(0, pos);
file_path_ = exec_info.substr(0, pos);
func_name_ = exec_info.substr(pos + 1);
} else {
MS_LOG(ERROR) << "Wrong execute info:" << exec_info;
@@ -119,9 +125,9 @@ class CustomAOTGpuKernel : public GpuKernel {
std::vector<int64_t> in_shape_tmp;
std::for_each(in_shape.begin(), in_shape.end(),
[&in_shape_tmp](size_t c) { in_shape_tmp.push_back(SizeToLong(c)); });
shape_list_.push_back(in_shape_tmp);
shape_list_.emplace_back(in_shape_tmp);
ndims_.push_back(SizeToInt(in_shape_tmp.size()));
type_list_.push_back(TypeIdToString(input_type_list[i], true));
type_list_.emplace_back(TypeIdToString(input_type_list[i], true));
}
num_output_ = AnfAlgo::GetOutputTensorNum(kernel_node);
@@ -138,9 +144,9 @@ class CustomAOTGpuKernel : public GpuKernel {
std::vector<int64_t> out_shape_tmp;
std::for_each(out_shape.begin(), out_shape.end(),
[&out_shape_tmp](size_t c) { out_shape_tmp.push_back(SizeToLong(c)); });
shape_list_.push_back(out_shape_tmp);
shape_list_.emplace_back(out_shape_tmp);
ndims_.push_back(SizeToInt(out_shape_tmp.size()));
type_list_.push_back(TypeIdToString(output_type_list[i], true));
type_list_.emplace_back(TypeIdToString(output_type_list[i], true));
}
std::transform(std::begin(shape_list_), std::end(shape_list_), std::back_inserter(shapes_),
@@ -182,7 +188,7 @@ class CustomAOTGpuKernel : public GpuKernel {
size_t num_input_;
size_t num_output_;
std::string cuda_path_;
std::string file_path_;
std::string func_name_;
void *handle_;
int (*aot_func_)(int, void **, int *, int64_t **, const char **, void *, void *);

View File

@@ -114,12 +114,9 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
} // use LoadIm2Col only for THOR optimizer
// Use AKG_KERNEL if func_type of Custom is not tbe
if (IsPrimitiveCNode(kernel_node, prim::kPrimCustom)) {
auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node);
auto func_type = primitive->GetAttr("func_type");
if (func_type && GetValue<std::string>(func_type) != "tbe") {
kernel_type = KernelType::AKG_KERNEL;
}
if (IsPrimitiveCNode(kernel_node, prim::kPrimCustom) &&
AnfAlgo::GetNodeAttr<std::string>(kernel_node, kAttrFuncType) != kCustomTypeTbe) {
kernel_type = KernelType::AKG_KERNEL;
}
switch (kernel_type) {

View File

@@ -23,6 +23,7 @@
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/kernel_compiler/oplib/opinfo.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "backend/kernel_compiler/cpu/pyfunc/py_func_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/custom/custom_aot_cpu_kernel.h"
#include "utils/trace_base.h"
@@ -317,9 +318,16 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
// Select for dynamic kernel (both the number and data type are undetermined).
const std::string &op_name = AnfAlgo::GetCNodeName(kernel_node);
if (IsPrimitiveCNode(kernel_node, prim::kPrimCustom) &&
!kernel::CPUKernelFactory::GetInstance().SearchRegisteredOp(op_name)) {
kernel::CPUKernelRegistrar(op_name, KernelAttr(), []() { return std::make_shared<kernel::CustomAOTCpuKernel>(); });
auto tp = AnfAlgo::GetNodeAttr<std::string>(kernel_node, kAttrFuncType);
if (tp == kCustomTypePyfunc) {
kernel::CPUKernelRegistrar(op_name, KernelAttr(), []() { return std::make_shared<kernel::PyFuncCpuKernel>(); });
} else if (tp == kCustomTypeAOT) {
kernel::CPUKernelRegistrar(op_name, KernelAttr(),
[]() { return std::make_shared<kernel::CustomAOTCpuKernel>(); });
}
}
if (IsDynamicParamKernel(op_name)) {

View File

@@ -481,6 +481,12 @@ constexpr auto kAttrHiddenSize = "hidden_size";
constexpr auto kAttrInputSize = "input_size";
constexpr auto kAttrDstType = "dst_type";
constexpr auto kAttrSkipNopOpAddr = "skip_nop_op_addr";
constexpr auto kAttrFuncType = "func_type";
// custom operator func type
constexpr auto kCustomTypeAOT = "aot";
constexpr auto kCustomTypePyfunc = "pyfunc";
constexpr auto kCustomTypeTbe = "tbe";
// primal attr key name
constexpr auto kPrimalAttrForwardNodeName = "forward_node_name";

View File

@@ -19,7 +19,10 @@ import inspect
import json
import functools
from mindspore import ops
from mindspore.ops.op_info_register import RegOp
from mindspore import log as logger
from mindspore.ops.op_info_register import RegOp, DataType
from mindspore.ops._register_for_op import PyFuncRegistry
from mindspore.common import dtype as mstype
from mindspore._c_expression import Oplib
@@ -91,8 +94,8 @@ class CustomRegOp(RegOp):
Please note that target and the `func_type` of `Custom` op have some constraints.
If func_type is "akg", target can be one of ["Ascend", "GPU"].
If func_type is "tbe", target can only be "Ascend".
If func_type is "aot", target can only be ["GPU", "CPU"].
If func_type is "py_func", target can only be "CPU".
If func_type is "aot", target can be one of ["GPU", "CPU"].
If func_type is "pyfunc", target can only be "CPU".
Default: None.
"""
self._is_string(target)
@@ -126,6 +129,10 @@ def custom_op_info_register(*reg_info):
return decorator
def get_pyfunc(fn):
return Custom.registered_py_id.get(fn)
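A hedged sketch of the round trip this helper enables; the PyFuncRegistry below is a minimal stand-in for mindspore.ops._register_for_op.PyFuncRegistry, not its actual implementation:

class PyFuncRegistry(dict):
    def register(self, key, value):
        self[key] = value

registered_py_id = PyFuncRegistry()

def my_func(x1, x2):
    return x1 + x2

# Python side: Custom.__init__ registers the callable under id(func) and attaches
# that id to the primitive as the "fn_id" attribute.
registered_py_id.register(id(my_func), my_func)

# C++ side: PyFuncCpuKernel::GetPythonFunc imports this module through pybind11
# and calls get_pyfunc(fn_id) to recover the callable at first launch.
assert registered_py_id.get(id(my_func)) is my_func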
class Custom(ops.PrimitiveWithInfer):
r"""
`Custom` primitive is used for user-defined operators and is to enhance the expressive ability of built-in
@@ -143,6 +150,7 @@ class Custom(ops.PrimitiveWithInfer):
the computation logic of a user-defined operator. The function can be one of the following:
1. An AKG operator implementation function, which can use ir builder/tvm compute/hybrid grammar.
2. A TBE operator implementation function.
3. A pure Python function.
str:
If func is of str type, then str should be a path to a binary file along with a function name. This could
@@ -198,7 +206,7 @@ class Custom(ops.PrimitiveWithInfer):
function or the value of output dtype of func.
If func has a single output, then the value of output dtype is a mindspore.dtype.
If func has multiple outputs, then the value of output dtype is a tuple of mindspore.dtype.
func_type (str): The implementation type of func, should be one of ["akg", "tbe", "aot", "py_func"].
func_type (str): The implementation type of func, should be one of ["akg", "tbe", "aot", "pyfunc"].
bprop (function): The gradient function of func. Default: None.
reg_info (Union[str, dict, list, tuple]): Represents the registration information of func with json format of
type str or dict. The registration information specifies supported formats of input and output, attributes
@@ -223,11 +231,13 @@ class Custom(ops.PrimitiveWithInfer):
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore as ms
>>> from mindspore.common import dtype as mstype
>>> from mindspore.ops.operations.custom_ops import Custom, CustomRegOp, custom_op_info_register
>>> from mindspore.ops.op_info_register import DataType
>>> from mindspore.nn import Cell
>>>
>>> #func_type="tbe"
>>>
>>> #-----------------------------------------------
>>> #--------------func_type="tbe"------------------
>>> square_with_bias_op_info = CustomRegOp() \
... .fusion_type("OPAQUE") \
... .attr("bias", "required", "float") \
@@ -237,7 +247,6 @@ class Custom(ops.PrimitiveWithInfer):
... .dtype_format(DataType.F16_Default, DataType.F16_Default) \
... .target("Ascend") \
... .get_op_info()
>>>
>>> @custom_op_info_register(square_with_bias_op_info)
... def square_with_bias(input_x, output_y, bias=0.0, kernel_name="square_with_bias"):
... import te.lang.cce
@@ -261,42 +270,53 @@ class Custom(ops.PrimitiveWithInfer):
... "tensor_list": [data, res]}
...
... te.lang.cce.cce_build_code(sch, config)
>>>
>>> class Net(Cell):
... def __init__(self):
... super(Net, self).__init__()
... self.square_with_bias = Custom(square_with_bias, out_shape=[2, 3], out_dtype=mstype.float32, \
... func_type="tbe")
...
... def construct(self, x):
... res = self.square_with_bias(x, 1.0)
... return res
>>>
>>> #func_type="aot", platform=GPU
>>>
>>> class AOTSingleOutputNet(Cell):
... def __init__(self, func, out_shapes, out_types, reg=None):
... super(AOTSingleOutputNet, self).__init__()
... self.program = Custom(func, out_shapes, out_types, "aot", reg_info=reg)
... def construct(self, x, y):
... return self.program(x, y)
>>>
>>> #-----------------------------------------------
>>> #--------func_type="aot", platform=GPU----------
>>> reorganize_op_info = CustomRegOp() \
... .fusion_type("OPAQUE") \
... .input(0, "x1") \
... .input(1, "x2") \
... .output(0, "y") \
... .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
... .target("GPU") \
... .get_op_info()
>>>
>>> #test = AOTSingleOutputNet("./reorganize.so:CustomReorganize", shape, mstype.float32, reorganize_gpu_info)
>>> #output = test(Tensor(input_x), Tensor(input_y))
>>> #see more details in tests/st/ops/graph_kernel/custom/test_custom_aot.py
>>> class AOTSingleOutputNet(Cell):
... def __init__(self):
... super(AOTSingleOutputNet, self).__init__()
... self.program = Custom("./reorganize.so:CustomReorganize", (2, 3), mstype.float32, "aot", \
... reg_info=reorganize_gpu_info)
... def construct(self, x, y):
... return self.program(x, y)
>>> #-----------------------------------------------
>>> #------------func_type="pyfunc"-----------------
>>> multi_output_op_info = CustomRegOp() \
... .input(0, "x1") \
... .input(1, "x2") \
... .output(0, "y1") \
... .output(1, "y2") \
... .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
... .get_op_info()
>>> @custom_op_info_register(multi_output_op_info)
... def func_multi_output(x1, x2):
... return (x1 + x2), (x1 - x2)
>>> class PyFuncNet(Cell):
... def __init__(self):
... super().__init__()
... self.func = Custom(func_multi_output, ((2, 3), (2, 3)), (ms.float32, ms.float32), "pyfunc")
... def construct(self, x1, x2):
... return self.func(x1, x2)
"""
registered_func = {}
attr_dict = {} # Save input_names and attr_names for func.
registered_py_id = PyFuncRegistry()
def __init__(self, func, out_shape, out_dtype, func_type, bprop=None, reg_info=None):
ops.PrimitiveWithInfer.__init__(self, "Custom")
@@ -305,6 +325,7 @@ class Custom(ops.PrimitiveWithInfer):
self.func = func
self.func_name = ""
self.uniq_name = ""
self.fn_id = -1
self.imply_path = ""
if callable(self.func):
# Get the original function if func is decorated
@@ -312,7 +333,10 @@ class Custom(ops.PrimitiveWithInfer):
self.func = self.func.__dict__["__wrapped__"]
self.imply_path = os.path.realpath(inspect.getfile(self.func))
self.func_name = self.func.__name__
self.uniq_name = self.name + "_" + self.func_name + "_" + str(id(self.func))
self.fn_id = id(self.func)
self.uniq_name = self.name+"_"+self.func_name+"_"+str(self.fn_id)
if func_type == "pyfunc":
Custom.registered_py_id.register(self.fn_id, self.func)
elif isinstance(self.func, str):
self.func_name = self.func
self.uniq_name = self.name + "_" + self.func_name
@@ -320,11 +344,22 @@ class Custom(ops.PrimitiveWithInfer):
raise TypeError("func should be of type function or str, but got {}".format(type(self.func)))
self.add_prim_attr("func_name", self.func_name)
self.add_prim_attr("uniq_name", self.uniq_name)
self.add_prim_attr("fn_id", self.fn_id)
self.add_prim_attr("imply_path", self.imply_path)
self.out_shape = out_shape
self.out_dtype = out_dtype
self.bprop = bprop
self.func_type = func_type
self.fake_output = False
self.single_scalar_output = False
if not self.out_dtype:
self.fake_output = True
elif not self.out_shape:
self.single_scalar_output = True
self.add_prim_attr("fake_output", self.fake_output)
self.add_prim_attr("single_scalar_output", self.single_scalar_output)
# Register info
self.register_info(reg_info)
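Illustrative cases for the two flags derived above; the functions here are hypothetical, not from this diff:

from mindspore.common import dtype as mstype
from mindspore.ops.operations.custom_ops import Custom

def func_no_output(x1, x2):
    print(x1, x2)  # returns None

def func_scalar(x1, x2):
    return int(x1) + int(x2)  # returns a single scalar

# Empty out_dtype -> fake_output=True: a placeholder output is faked for the graph.
op1 = Custom(func_no_output, (), (), "pyfunc")
# out_dtype given but out_shape empty -> single_scalar_output=True.
op2 = Custom(func_scalar, (), mstype.int32, "pyfunc")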
@@ -346,12 +381,20 @@ class Custom(ops.PrimitiveWithInfer):
def infer_shape(self, *args):
if callable(self.out_shape):
return self.out_shape(*args)
return self.out_shape
if self.out_shape:
return self.out_shape
logger.warning("The function output are empty tuple. Add a placeholder instead. "
"Do not use it as it could be any uninitialized data.")
return (1,)
def infer_dtype(self, *args):
if callable(self.out_dtype):
return self.out_dtype(*args)
return self.out_dtype
if self.out_dtype:
return self.out_dtype
logger.warning("The function output are empty tuple. Add a placeholder instead. "
"Do not use it as it could be any uninitialized data.")
return mstype.int32
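A hedged illustration of the placeholder behavior above, reusing the hypothetical func_no_output from the earlier sketch:

op = Custom(func_no_output, (), (), "pyfunc")
assert op.infer_shape() == (1,)          # placeholder shape instead of an empty tuple
assert op.infer_dtype() == mstype.int32  # placeholder dtype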
def get_bprop(self):
return self.bprop
@@ -367,6 +410,12 @@ class Custom(ops.PrimitiveWithInfer):
continue
if isinstance(reg_info, str):
reg_info = json.loads(reg_info)
if self.fake_output:
reg_info["outputs"].append(dict({"index": 0, "name": "y", "param_type": "required"}))
new_dtype_format = []
for i in reg_info["dtype_format"]:
new_dtype_format.append(i+(DataType.I32_Default,))
reg_info["dtype_format"] = new_dtype_format
target = self.get_target(reg_info)
# Reg info for func is only registered once for a certain target
if self.has_registered(target):
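A hypothetical before/after for the fake_output patch in register_info above; the reg_info entries are illustrative:

from mindspore.ops.op_info_register import DataType

# reg_info for a pyfunc with one input and no real outputs (fake_output case):
reg_info = {"outputs": [], "dtype_format": [(DataType.F32_Default,)]}
# The patch appends a placeholder int32 output "y" and widens each dtype_format tuple:
reg_info["outputs"].append({"index": 0, "name": "y", "param_type": "required"})
reg_info["dtype_format"] = [i + (DataType.I32_Default,) for i in reg_info["dtype_format"]]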
@@ -379,6 +428,24 @@ class Custom(ops.PrimitiveWithInfer):
raise ValueError('Invalid reg info {}: {}\n'.format(self.imply_path, reg_info_str))
self.save_attr(reg_info)
self.save_register_status(target)
registered_targets = getattr(self.func, "registered_targets", [])
if self.func_type == "pyfunc":
self.add_prim_attr("primitive_target", "CPU")
if registered_targets != ["CPU"]:
logger.warning("CustomPyfunc only supports CPU platform, but gets registered target as {}. We will\
run CustomPyfunc on CPU".format(registered_targets))
elif self.func_type == "aot":
if set(registered_targets) == set(["GPU", "CPU"]):
logger.warning(
"Both GPU and CPU target are registered for CustomAOT. Target will be set according to context.")
elif registered_targets == ["GPU"]:
self.add_prim_attr("primitive_target", "GPU")
elif registered_targets == ["CPU"]:
self.add_prim_attr("primitive_target", "CPU")
else:
logger.warning(
"CustomAOT only supports CPU/GPU platform, but the registered target is {}. Target "
"will be set according to context.".format(registered_targets))
def get_expanded_list(self, data):
"""Recursive function to parse elements in list or tuple."""
@@ -453,7 +520,7 @@ class Custom(ops.PrimitiveWithInfer):
target = imply_type_to_target.get(reg_info.get("imply_type"))
# Infer target from func_type
if target not in self.supported_targets:
func_type_to_target = {"tbe": "Ascend"}
func_type_to_target = {"tbe": "Ascend", "pyfunc": "CPU"}
target = func_type_to_target.get(self.func_type)
if target not in self.supported_targets:
raise ValueError("target should be one of {}, but got {}".format(self.supported_targets, target))
@@ -466,7 +533,7 @@ class Custom(ops.PrimitiveWithInfer):
reg_info["imply_type"].strip():
return reg_info["imply_type"]
# Infer imply_type from func_type
func_type_to_imply_type = {"akg": "AKG", "tbe": "TBE", "aot": target, "py_func": target}
func_type_to_imply_type = {"akg": "AKG", "tbe": "TBE", "aot": target, "pyfunc": target}
return func_type_to_imply_type.get(self.func_type, "AKG")
def save_attr(self, reg_info):

View File

@@ -889,6 +889,7 @@ class identity(Primitive):
def __call__(self, x):
return x
pyfunc_register = PyFuncRegistry()
@@ -961,6 +962,14 @@ class PyFunc(PrimitiveWithInfer):
validator.check("out_types length", len(out_types), "out_shapes length", len(out_shapes), Rel.EQ, self.name)
self.add_prim_attr("side_effect_io", stateful)
self.add_prim_attr("primitive_target", "CPU")
fake_output = False
single_scalar_output = False
if not out_types:
fake_output = True
elif not out_shapes:
single_scalar_output = True
self.add_prim_attr("fake_output", fake_output)
self.add_prim_attr("single_scalar_output", single_scalar_output)
def infer_shape(self, *args):
if self.out_shapes:

View File

@@ -78,7 +78,6 @@ def aot_single_output(get_file_path, source, execf, reg):
add_gpu_info = CustomRegOp() \
.fusion_type("OPAQUE") \
.input(0, "x1") \
.input(1, "x2") \
.output(0, "y") \
@@ -101,7 +100,6 @@ def test_aot_single_output_gpu():
add_cpu_info = CustomRegOp() \
.fusion_type("OPAQUE") \
.input(0, "x1") \
.input(1, "x2") \
.output(0, "y") \
@@ -128,7 +126,6 @@ def test_aot_single_output_cpu():
reorganize_gpu_info = CustomRegOp() \
.fusion_type("OPAQUE") \
.input(0, "x1") \
.input(1, "x2") \
.output(0, "y") \
@@ -166,7 +163,6 @@ def test_reorganize():
hetero_square_mul_gpu_info = CustomRegOp() \
.fusion_type("OPAQUE") \
.input(0, "x1") \
.input(1, "x2") \
.output(0, "y") \
@@ -215,7 +211,6 @@ class SquareGradNet(Cell):
square_gpu_info = CustomRegOp() \
.fusion_type("OPAQUE") \
.input(0, "x1") \
.output(0, "y") \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
@@ -224,7 +219,6 @@ square_gpu_info = CustomRegOp() \
square_bprop_gpu_info = CustomRegOp() \
.fusion_type("OPAQUE") \
.input(0, "x1") \
.input(1, "x2") \
.input(2, "x3") \
@@ -326,7 +320,6 @@ class AOTMultiOutputNet(Cell):
multioutput_gpu_info = CustomRegOp() \
.fusion_type("OPAQUE") \
.input(0, "x1") \
.input(1, "x2") \
.output(0, "y1") \
@@ -339,7 +332,6 @@ multioutput_gpu_info = CustomRegOp() \
multioutput_bprop_gpu_info = CustomRegOp() \
.fusion_type("OPAQUE") \
.input(0, "x1") \
.input(1, "x2") \
.input(2, "x3") \

View File

@@ -0,0 +1,193 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.op_info_register import DataType
from mindspore.ops.operations.custom_ops import Custom, CustomRegOp, custom_op_info_register
single_output_op_info = CustomRegOp() \
.input(0, "x1") \
.input(1, "x2") \
.output(0, "y") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.get_op_info()
@custom_op_info_register(single_output_op_info)
def func_single_output(x1, x2):
return x1 - x2
multi_output_op_info = CustomRegOp() \
.input(0, "x1") \
.input(1, "x2") \
.output(0, "y1") \
.output(1, "y2") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@custom_op_info_register(multi_output_op_info)
def func_multi_output(x1, x2):
return (x1 + x2), (x1 - x2)
no_output_op_info = CustomRegOp() \
.input(0, "x1") \
.input(1, "x2") \
.dtype_format(DataType.F32_Default, DataType.F32_Default)\
.get_op_info()
output = 0
@custom_op_info_register(no_output_op_info)
def func_no_output(x1, x2):
global output
output = x1 + x2
class PyFuncNet(nn.Cell):
def __init__(self, fn, out_shapes, out_types):
super().__init__()
self.func = Custom(fn, out_shapes, out_types, "pyfunc")
self.relu = P.ReLU()
def construct(self, x1, x2):
x = self.func(x1, x2)
return self.relu(x[0])
def func_with_dtype(ms_dtype, np_dtype):
shape = (40, 40)
np.random.seed(42)
x1 = np.random.randint(-5, 5, size=shape).astype(np_dtype)
x2 = np.random.randint(-5, 5, size=shape).astype(np_dtype)
expect = func_single_output(x1, x2)
expect = P.ReLU()(Tensor(expect))
net = PyFuncNet(func_single_output, (shape,), (ms_dtype,))
x = net(Tensor(x1), Tensor(x2))
assert np.allclose(x.asnumpy(), expect.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pyfunc_single_output():
"""
Feature: test case for Custom op with func_type="pyfunc"
Description: the net runs on GPU while the custom pyfunc operator runs on CPU; GRAPH_MODE; single output
Expectation: the result matches the numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
func_with_dtype(ms.float16, np.float16)
func_with_dtype(ms.float32, np.float32)
func_with_dtype(ms.float64, np.float64)
func_with_dtype(ms.int32, np.int32)
func_with_dtype(ms.int64, np.int64)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pyfunc_multi_output():
"""
Feature: test case for Custom op with func_type="pyfunc"
Description: the net runs on GPU while the custom pyfunc operator runs on CPU; GRAPH_MODE; multiple outputs
Expectation: the result matches the numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
shape = (40, 40)
dtype = ms.float32
np.random.seed(42)
x1 = np.random.randint(-5, 5, size=shape).astype(np.float32)
x2 = np.random.randint(-5, 5, size=shape).astype(np.float32)
expect, _ = func_multi_output(x1, x2)
expect = P.ReLU()(Tensor(expect))
net = PyFuncNet(func_multi_output, (shape, shape), (dtype, dtype))
x = net(Tensor(x1), Tensor(x2))
assert np.allclose(x.asnumpy(), expect.asnumpy())
class PyFuncGraph(nn.Cell):
def __init__(self, fn, out_shapes, out_types):
super().__init__()
self.func = Custom(fn, out_shapes, out_types, "pyfunc")
def construct(self, x1, x2):
return self.func(x1, x2)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pyfunc_no_output():
"""
Feature: test case for Custom op with func_type="pyfunc"
Description: the net runs on GPU while the custom pyfunc operator runs on CPU; GRAPH_MODE; no output
Expectation: the result matches the numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
shape = (40, 40)
np.random.seed(42)
x1 = np.random.randint(-5, 5, size=shape).astype(np.float32)
x2 = np.random.randint(-5, 5, size=shape).astype(np.float32)
func_no_output(x1, x2)
global output
expect = output
net = PyFuncGraph(func_no_output, (), ())
net(Tensor(x1), Tensor(x2))
net_output = output
assert np.allclose(net_output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pyfunc_scalar():
"""
Feature: test case for Custom op with func_type="pyfunc"
Description: the net runs on GPU while the custom pyfunc operator runs on CPU; GRAPH_MODE; scalar output
Expectation: the result matches the numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
shape = ()
ms_dtype = ms.int32
x1 = int(10)
x2 = int(5)
expect = func_single_output(x1, x2)
net = PyFuncGraph(func_single_output, shape, ms_dtype)
x = net(Tensor(x1), Tensor(x2))
assert np.allclose(x.asnumpy(), expect)