!25726 Add Pynative Test Cases for Custom AOT, Pyfunc

Merge pull request !25726 from jiaoy1224/arithmetic
This commit is contained in:
i-robot 2021-11-09 06:31:46 +00:00 committed by Gitee
commit 9a3b52eda7
7 changed files with 98 additions and 40 deletions

View File

@ -212,7 +212,6 @@ void PyObjectToRawMemorys(const py::object &object, const PyFuncArgumentInfo &ou
} // namespace
void PyFuncCpuKernel::InitKernel(const CNodePtr &kernel_node) {
is_custom_ = IsPrimitiveCNode(kernel_node, prim::kPrimCustom);
func_id_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "fn_id");
fake_output_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "fake_output");
single_scalar_output_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "single_scalar_output");
@ -313,8 +312,7 @@ bool PyFuncCpuKernel::ExecuteKernel(const std::vector<AddressPtr> &inputs, const
py::function PyFuncCpuKernel::GetPythonFunc() {
py::gil_scoped_acquire gil_acquire;
static const std::string &module_name =
is_custom_ ? "mindspore.ops.operations.custom_ops" : "mindspore.ops.operations.other_ops";
static const std::string &module_name = "mindspore.ops.operations._pyfunc_registry";
static const std::string &entrance = "get_pyfunc";
py::module module = py::module::import(module_name.c_str());
py::object get_pyfunc_obj = module.attr(entrance.c_str());

View File

@ -42,8 +42,7 @@ struct PyFuncArgumentInfo {
class PyFuncCpuKernel : public CPUKernel {
public:
PyFuncCpuKernel()
: is_custom_(false), init_(false), fake_output_(false), single_scalar_output_(false), func_id_(-1) {}
PyFuncCpuKernel() : init_(false), fake_output_(false), single_scalar_output_(false), func_id_(-1) {}
~PyFuncCpuKernel() = default;
// Init kernel including analyse PyFunc input and output info.
@ -59,10 +58,6 @@ class PyFuncCpuKernel : public CPUKernel {
py::function GetPythonFunc();
bool ExecuteKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
// both mindspore.ops.operations.custom_ops.Custom and mindspore.ops.operations.PyFunc will launch
// this kernel (these two have similar features, will further be unified);if is_custom_ is true, then it's
// launched from Custom; if not, it's from PyFunc
bool is_custom_;
bool init_;
bool fake_output_;
bool single_scalar_output_;

View File

@ -0,0 +1,29 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Register pyfunc for py_func_cpu_kernel"""
from mindspore.ops._register_for_op import PyFuncRegistry
registered_py_id = PyFuncRegistry()
def add_pyfunc(fn_id, fn):
registered_py_id.register(fn_id, fn)
def get_pyfunc(fn_id):
return registered_py_id.get(fn_id)

View File

@ -21,9 +21,9 @@ import functools
from mindspore import ops
from mindspore import log as logger
from mindspore.ops.op_info_register import RegOp, DataType
from mindspore.ops._register_for_op import PyFuncRegistry
from mindspore.common import dtype as mstype
from mindspore._c_expression import Oplib
from ._pyfunc_registry import add_pyfunc
class CustomRegOp(RegOp):
@ -129,10 +129,6 @@ def custom_op_info_register(*reg_info):
return decorator
def get_pyfunc(fn):
return Custom.registered_py_id.get(fn)
class Custom(ops.PrimitiveWithInfer):
r"""
`Custom` primitive is used for user defined operators and is to enhance the expressive ability of built-in
@ -311,7 +307,6 @@ class Custom(ops.PrimitiveWithInfer):
registered_func = {}
attr_dict = {} # Save input_names and attr_names for func.
registered_py_id = PyFuncRegistry()
def __init__(self, func, out_shape, out_dtype, func_type, bprop=None, reg_info=None):
ops.PrimitiveWithInfer.__init__(self, "Custom")
@ -331,7 +326,7 @@ class Custom(ops.PrimitiveWithInfer):
self.fn_id = id(self.func)
self.uniq_name = self.name + "_" + self.func_name + "_" + str(self.fn_id)
if func_type == "pyfunc":
Custom.registered_py_id.register(self.fn_id, self.func)
add_pyfunc(self.fn_id, self.func)
elif isinstance(self.func, str):
self.func_name = self.func
self.uniq_name = self.name + "_" + self.func_name

View File

@ -22,7 +22,7 @@ from .. import signature as sig
from ..._checkparam import Validator as validator, Rel
from ...common import dtype as mstype
from ..primitive import Primitive, PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
from .._register_for_op import PyFuncRegistry
from ._pyfunc_registry import add_pyfunc
class Assign(Primitive):
@ -893,13 +893,6 @@ class identity(Primitive):
return x
pyfunc_register = PyFuncRegistry()
def get_pyfunc(fn_id):
return pyfunc_register.get(fn_id)
class PyFunc(PrimitiveWithInfer):
r"""
Execute Python function.
@ -951,7 +944,7 @@ class PyFunc(PrimitiveWithInfer):
def __init__(self, fn, in_types, in_shapes, out_types, out_shapes, stateful=True):
super(PyFunc, self).__init__(self.__class__.__name__)
pyfunc_register.register(id(fn), fn)
add_pyfunc(id(fn), fn)
self.add_prim_attr('fn_id', id(fn))
self.add_prim_attr('in_types', in_types)
self.add_prim_attr('in_shapes', in_shapes)

View File

@ -105,7 +105,7 @@ add_cpu_info = CustomRegOp() \
@pytest.mark.env_onecard
def test_aot_single_output_cpu():
"""
Feature: custom aot operator, multiple inputs, single output, CPU
Feature: custom aot operator, multiple inputs, single output, CPU, GRAPH_MODE
Description: pre-compile xxx.cc to xxx.so, custom operator launches xxx.so
Expectation: nn result matches numpy result
"""
@ -300,24 +300,15 @@ multioutput_bprop_gpu_info = CustomRegOp() \
.get_op_info()
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
def test_add_mul_div_bprop():
"""
Feature: custom aot operator with reg info, bprop(Cell), multiple outputs, GPU
Description: pre-compile xxx.cu to xxx.so, custom operator launches xxx.so
Expectation: nn result matches numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
def add_mul_div_bprop(source, execf, source_prop, execf_prop):
x = np.array([1.0, 4.0, 9.0]).astype(np.float32)
y = np.array([1.0, 1.0, 1.0]).astype(np.float32)
sens = np.array([1.0, 1.0, 1.0]).astype(np.float32)
expect_dx = np.array([5.0, 17.0, 37.0]).astype(np.float32)
expect_dy = np.array([-1.0, -16.0, -81.0]).astype(np.float32)
cmd_bprop, func_path_bprop = get_file_path_gpu("add_mul_div_bprop.cu", "add_mul_div_bprop.so")
check_exec_file(cmd_bprop, func_path_bprop, "add_mul_div_bprop.cu", "add_mul_div_bprop.so")
cmd_bprop, func_path_bprop = get_file_path_gpu(source_prop, execf_prop)
check_exec_file(cmd_bprop, func_path_bprop, source_prop, execf_prop)
try:
aot_bprop = Custom(func_path_bprop + ":CustomAddMulDivBprop",
([3], [3]), (mstype.float32, mstype.float32), "aot", reg_info=multioutput_bprop_gpu_info)
@ -330,8 +321,8 @@ def test_add_mul_div_bprop():
res = aot_bprop(x, y, dout[0], dout[1], dout[2])
return res
cmd, func_path = get_file_path_gpu("add_mul_div.cu", "add_mul_div.so")
check_exec_file(cmd, func_path, "add_mul_div.cu", "add_mul_div.so")
cmd, func_path = get_file_path_gpu(source, execf)
check_exec_file(cmd, func_path, source, execf)
try:
net = AOTMultiOutputNet(func_path + ":CustomAddMulDiv", ([3], [3], [3]),
(mstype.float32, mstype.float32, mstype.float32), bprop=bprop, reg=multioutput_gpu_info)
@ -349,3 +340,30 @@ def test_add_mul_div_bprop():
dy_np = dy.asnumpy()
assert np.allclose(expect_dx, dx_np, 0.0001, 0.0001)
assert np.allclose(expect_dy, dy_np, 0.0001, 0.0001)
@ pytest.mark.level0
@ pytest.mark.env_onecard
@ pytest.mark.platform_x86_gpu_training
def test_add_mul_div_bprop_graph():
"""
Feature: custom aot operator, bprop(Cell), multiple outputs, GPU, GRAPH_MODE
Description: pre-compile xxx.cu to xxx.so, custom operator launches xxx.so
Expectation: nn result matches numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
add_mul_div_bprop("add_mul_div.cu", "add_mul_div.so", "add_mul_div_bprop.cu", "add_mul_div_bprop.so")
@ pytest.mark.level0
@ pytest.mark.env_onecard
@ pytest.mark.platform_x86_gpu_training
def test_add_mul_div_bprop_pynative():
"""
Feature: custom aot operator, bprop(Cell), multiple outputs, GPU, PYNATIVE_MODE
Description: pre-compile xxx.cu to xxx.so, custom operator launches xxx.so
Expectation: nn result matches numpy result
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
add_mul_div_bprop("add_mul_div.cu", "add_mul_div_pynative.so",
"add_mul_div_bprop.cu", "add_mul_div_bprop_pynative.so")

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import platform
import numpy as np
import pytest
@ -161,3 +162,32 @@ def test_pyfunc_scalar():
net = PyFuncGraph(func_single_output, shape, ms_dtype)
x = net(Tensor(x1), Tensor(x2))
assert np.allclose(x.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pyfunc_pynative():
"""
Feature: test case for Custom op with func_type="pyfunc"
Description: the net runs on CPU; PYNATIVE_MODE
Expectation: the result match with numpy result
"""
sys = platform.system()
if sys == 'Windows':
pass
else:
context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
shape = (40, 40)
np.random.seed(42)
x1 = np.random.randint(-5, 5, size=shape).astype(np.float32)
x2 = np.random.randint(-5, 5, size=shape).astype(np.float32)
n1, n2 = func_multi_output(x1, x2)
net = Custom(func_multi_output, (shape, shape), (ms.float32, ms.float32), "pyfunc")
out = net(Tensor(x1), Tensor(x2))
add = P.Add()
res = add(out[0], out[1])
assert np.allclose(res.asnumpy(), n1+n2)