!25726 Add Pynative Test Cases for Custom AOT, Pyfunc
Merge pull request !25726 from jiaoy1224/arithmetic
This commit is contained in:
commit
9a3b52eda7
|
@ -212,7 +212,6 @@ void PyObjectToRawMemorys(const py::object &object, const PyFuncArgumentInfo &ou
|
|||
} // namespace
|
||||
|
||||
void PyFuncCpuKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
is_custom_ = IsPrimitiveCNode(kernel_node, prim::kPrimCustom);
|
||||
func_id_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "fn_id");
|
||||
fake_output_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "fake_output");
|
||||
single_scalar_output_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "single_scalar_output");
|
||||
|
@ -313,8 +312,7 @@ bool PyFuncCpuKernel::ExecuteKernel(const std::vector<AddressPtr> &inputs, const
|
|||
|
||||
py::function PyFuncCpuKernel::GetPythonFunc() {
|
||||
py::gil_scoped_acquire gil_acquire;
|
||||
static const std::string &module_name =
|
||||
is_custom_ ? "mindspore.ops.operations.custom_ops" : "mindspore.ops.operations.other_ops";
|
||||
static const std::string &module_name = "mindspore.ops.operations._pyfunc_registry";
|
||||
static const std::string &entrance = "get_pyfunc";
|
||||
py::module module = py::module::import(module_name.c_str());
|
||||
py::object get_pyfunc_obj = module.attr(entrance.c_str());
|
||||
|
|
|
@ -42,8 +42,7 @@ struct PyFuncArgumentInfo {
|
|||
|
||||
class PyFuncCpuKernel : public CPUKernel {
|
||||
public:
|
||||
PyFuncCpuKernel()
|
||||
: is_custom_(false), init_(false), fake_output_(false), single_scalar_output_(false), func_id_(-1) {}
|
||||
PyFuncCpuKernel() : init_(false), fake_output_(false), single_scalar_output_(false), func_id_(-1) {}
|
||||
~PyFuncCpuKernel() = default;
|
||||
|
||||
// Init kernel including analyse PyFunc input and output info.
|
||||
|
@ -59,10 +58,6 @@ class PyFuncCpuKernel : public CPUKernel {
|
|||
py::function GetPythonFunc();
|
||||
bool ExecuteKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
// both mindspore.ops.operations.custom_ops.Custom and mindspore.ops.operations.PyFunc will launch
|
||||
// this kernel (these two have similar features, will further be unified);if is_custom_ is true, then it's
|
||||
// launched from Custom; if not, it's from PyFunc
|
||||
bool is_custom_;
|
||||
bool init_;
|
||||
bool fake_output_;
|
||||
bool single_scalar_output_;
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Register pyfunc for py_func_cpu_kernel"""
|
||||
|
||||
from mindspore.ops._register_for_op import PyFuncRegistry
|
||||
|
||||
|
||||
registered_py_id = PyFuncRegistry()
|
||||
|
||||
|
||||
def add_pyfunc(fn_id, fn):
|
||||
registered_py_id.register(fn_id, fn)
|
||||
|
||||
|
||||
def get_pyfunc(fn_id):
|
||||
return registered_py_id.get(fn_id)
|
|
@ -21,9 +21,9 @@ import functools
|
|||
from mindspore import ops
|
||||
from mindspore import log as logger
|
||||
from mindspore.ops.op_info_register import RegOp, DataType
|
||||
from mindspore.ops._register_for_op import PyFuncRegistry
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore._c_expression import Oplib
|
||||
from ._pyfunc_registry import add_pyfunc
|
||||
|
||||
|
||||
class CustomRegOp(RegOp):
|
||||
|
@ -129,10 +129,6 @@ def custom_op_info_register(*reg_info):
|
|||
return decorator
|
||||
|
||||
|
||||
def get_pyfunc(fn):
|
||||
return Custom.registered_py_id.get(fn)
|
||||
|
||||
|
||||
class Custom(ops.PrimitiveWithInfer):
|
||||
r"""
|
||||
`Custom` primitive is used for user defined operators and is to enhance the expressive ability of built-in
|
||||
|
@ -311,7 +307,6 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
|
||||
registered_func = {}
|
||||
attr_dict = {} # Save input_names and attr_names for func.
|
||||
registered_py_id = PyFuncRegistry()
|
||||
|
||||
def __init__(self, func, out_shape, out_dtype, func_type, bprop=None, reg_info=None):
|
||||
ops.PrimitiveWithInfer.__init__(self, "Custom")
|
||||
|
@ -331,7 +326,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
self.fn_id = id(self.func)
|
||||
self.uniq_name = self.name + "_" + self.func_name + "_" + str(self.fn_id)
|
||||
if func_type == "pyfunc":
|
||||
Custom.registered_py_id.register(self.fn_id, self.func)
|
||||
add_pyfunc(self.fn_id, self.func)
|
||||
elif isinstance(self.func, str):
|
||||
self.func_name = self.func
|
||||
self.uniq_name = self.name + "_" + self.func_name
|
||||
|
|
|
@ -22,7 +22,7 @@ from .. import signature as sig
|
|||
from ..._checkparam import Validator as validator, Rel
|
||||
from ...common import dtype as mstype
|
||||
from ..primitive import Primitive, PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
|
||||
from .._register_for_op import PyFuncRegistry
|
||||
from ._pyfunc_registry import add_pyfunc
|
||||
|
||||
|
||||
class Assign(Primitive):
|
||||
|
@ -893,13 +893,6 @@ class identity(Primitive):
|
|||
return x
|
||||
|
||||
|
||||
pyfunc_register = PyFuncRegistry()
|
||||
|
||||
|
||||
def get_pyfunc(fn_id):
|
||||
return pyfunc_register.get(fn_id)
|
||||
|
||||
|
||||
class PyFunc(PrimitiveWithInfer):
|
||||
r"""
|
||||
Execute Python function.
|
||||
|
@ -951,7 +944,7 @@ class PyFunc(PrimitiveWithInfer):
|
|||
|
||||
def __init__(self, fn, in_types, in_shapes, out_types, out_shapes, stateful=True):
|
||||
super(PyFunc, self).__init__(self.__class__.__name__)
|
||||
pyfunc_register.register(id(fn), fn)
|
||||
add_pyfunc(id(fn), fn)
|
||||
self.add_prim_attr('fn_id', id(fn))
|
||||
self.add_prim_attr('in_types', in_types)
|
||||
self.add_prim_attr('in_shapes', in_shapes)
|
||||
|
|
|
@ -105,7 +105,7 @@ add_cpu_info = CustomRegOp() \
|
|||
@pytest.mark.env_onecard
|
||||
def test_aot_single_output_cpu():
|
||||
"""
|
||||
Feature: custom aot operator, multiple inputs, single output, CPU
|
||||
Feature: custom aot operator, multiple inputs, single output, CPU, GRAPH_MODE
|
||||
Description: pre-compile xxx.cc to xxx.so, custom operator launches xxx.so
|
||||
Expectation: nn result matches numpy result
|
||||
"""
|
||||
|
@ -300,24 +300,15 @@ multioutput_bprop_gpu_info = CustomRegOp() \
|
|||
.get_op_info()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
def test_add_mul_div_bprop():
|
||||
"""
|
||||
Feature: custom aot operator with reg info, bprop(Cell), multiple outputs, GPU
|
||||
Description: pre-compile xxx.cu to xxx.so, custom operator launches xxx.so
|
||||
Expectation: nn result matches numpy result
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
def add_mul_div_bprop(source, execf, source_prop, execf_prop):
|
||||
x = np.array([1.0, 4.0, 9.0]).astype(np.float32)
|
||||
y = np.array([1.0, 1.0, 1.0]).astype(np.float32)
|
||||
sens = np.array([1.0, 1.0, 1.0]).astype(np.float32)
|
||||
expect_dx = np.array([5.0, 17.0, 37.0]).astype(np.float32)
|
||||
expect_dy = np.array([-1.0, -16.0, -81.0]).astype(np.float32)
|
||||
|
||||
cmd_bprop, func_path_bprop = get_file_path_gpu("add_mul_div_bprop.cu", "add_mul_div_bprop.so")
|
||||
check_exec_file(cmd_bprop, func_path_bprop, "add_mul_div_bprop.cu", "add_mul_div_bprop.so")
|
||||
cmd_bprop, func_path_bprop = get_file_path_gpu(source_prop, execf_prop)
|
||||
check_exec_file(cmd_bprop, func_path_bprop, source_prop, execf_prop)
|
||||
try:
|
||||
aot_bprop = Custom(func_path_bprop + ":CustomAddMulDivBprop",
|
||||
([3], [3]), (mstype.float32, mstype.float32), "aot", reg_info=multioutput_bprop_gpu_info)
|
||||
|
@ -330,8 +321,8 @@ def test_add_mul_div_bprop():
|
|||
res = aot_bprop(x, y, dout[0], dout[1], dout[2])
|
||||
return res
|
||||
|
||||
cmd, func_path = get_file_path_gpu("add_mul_div.cu", "add_mul_div.so")
|
||||
check_exec_file(cmd, func_path, "add_mul_div.cu", "add_mul_div.so")
|
||||
cmd, func_path = get_file_path_gpu(source, execf)
|
||||
check_exec_file(cmd, func_path, source, execf)
|
||||
try:
|
||||
net = AOTMultiOutputNet(func_path + ":CustomAddMulDiv", ([3], [3], [3]),
|
||||
(mstype.float32, mstype.float32, mstype.float32), bprop=bprop, reg=multioutput_gpu_info)
|
||||
|
@ -349,3 +340,30 @@ def test_add_mul_div_bprop():
|
|||
dy_np = dy.asnumpy()
|
||||
assert np.allclose(expect_dx, dx_np, 0.0001, 0.0001)
|
||||
assert np.allclose(expect_dy, dy_np, 0.0001, 0.0001)
|
||||
|
||||
|
||||
@ pytest.mark.level0
|
||||
@ pytest.mark.env_onecard
|
||||
@ pytest.mark.platform_x86_gpu_training
|
||||
def test_add_mul_div_bprop_graph():
|
||||
"""
|
||||
Feature: custom aot operator, bprop(Cell), multiple outputs, GPU, GRAPH_MODE
|
||||
Description: pre-compile xxx.cu to xxx.so, custom operator launches xxx.so
|
||||
Expectation: nn result matches numpy result
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
add_mul_div_bprop("add_mul_div.cu", "add_mul_div.so", "add_mul_div_bprop.cu", "add_mul_div_bprop.so")
|
||||
|
||||
|
||||
@ pytest.mark.level0
|
||||
@ pytest.mark.env_onecard
|
||||
@ pytest.mark.platform_x86_gpu_training
|
||||
def test_add_mul_div_bprop_pynative():
|
||||
"""
|
||||
Feature: custom aot operator, bprop(Cell), multiple outputs, GPU, PYNATIVE_MODE
|
||||
Description: pre-compile xxx.cu to xxx.so, custom operator launches xxx.so
|
||||
Expectation: nn result matches numpy result
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
add_mul_div_bprop("add_mul_div.cu", "add_mul_div_pynative.so",
|
||||
"add_mul_div_bprop.cu", "add_mul_div_bprop_pynative.so")
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import platform
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
@ -161,3 +162,32 @@ def test_pyfunc_scalar():
|
|||
net = PyFuncGraph(func_single_output, shape, ms_dtype)
|
||||
x = net(Tensor(x1), Tensor(x2))
|
||||
assert np.allclose(x.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_pyfunc_pynative():
|
||||
"""
|
||||
Feature: test case for Custom op with func_type="pyfunc"
|
||||
Description: the net runs on CPU; PYNATIVE_MODE
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
sys = platform.system()
|
||||
if sys == 'Windows':
|
||||
pass
|
||||
else:
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
|
||||
shape = (40, 40)
|
||||
|
||||
np.random.seed(42)
|
||||
x1 = np.random.randint(-5, 5, size=shape).astype(np.float32)
|
||||
x2 = np.random.randint(-5, 5, size=shape).astype(np.float32)
|
||||
n1, n2 = func_multi_output(x1, x2)
|
||||
|
||||
net = Custom(func_multi_output, (shape, shape), (ms.float32, ms.float32), "pyfunc")
|
||||
out = net(Tensor(x1), Tensor(x2))
|
||||
add = P.Add()
|
||||
res = add(out[0], out[1])
|
||||
|
||||
assert np.allclose(res.asnumpy(), n1+n2)
|
||||
|
|
Loading…
Reference in New Issue