!21494 pyfunc primitive register

Merge pull request !21494 from chenweifeng/pyfunc-primitive-register
This commit is contained in:
i-robot 2021-08-10 13:43:46 +00:00 committed by Gitee
commit ace6e49e00
7 changed files with 259 additions and 3 deletions

View File

@ -135,7 +135,8 @@ void ScalarToRawMemory(const py::object &obj, const TypePtr &type, const Address
void ArrayToRawMemory(const py::array &array, const AddressPtr &address) {
if (static_cast<unsigned int>(array.flags()) & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_) {
const py::buffer_info &buf_info = array.request();
CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, buf_info.ptr, buf_info.size), EOK, "memcpy failed.");
CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, buf_info.ptr, buf_info.size * buf_info.itemsize), EOK,
"memcpy failed.");
} else {
// Transform numpy array to row major buffer.
Py_buffer pybuf;

View File

@ -65,3 +65,4 @@ from .pad import _pad_cpu
from .range import _range_cpu
from .tensor_copy_slices import _tensor_copy_slices_cpu
from .l2loss import _l2loss_cpu
from .pyfunc import _pyfunc_cpu

View File

@ -0,0 +1,29 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PyFunc op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
pyfunc_op_info = CpuRegOp("PyFunc") \
.input(0, "x", "dynamic") \
.output(0, "y", "dynamic") \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(pyfunc_op_info)
def _pyfunc_cpu():
"""PyFunc cpu register"""
return

View File

@ -47,3 +47,12 @@ class Registry(UserDict):
if key in self:
fn = self[prim_obj.name]
return fn
class PyFuncRegistry(UserDict):
def register(self, key, value):
self[key] = value
def get(self, key):
if key not in self:
raise ValueError(f"Python function with key{key} not registered.")
return self[key]

View File

@ -92,7 +92,7 @@ from ._quant_ops import *
from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode,
ConfusionMatrix, PopulationCount, UpdateState, Load,
CheckValid, Partial, Depend, identity, CheckBprop, Push, Pull, PullWeight, PushWeight,
StartFLJob, UpdateModel, GetModel)
StartFLJob, UpdateModel, GetModel, PyFunc)
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
CusMatMulCubeDenseRight,
@ -524,6 +524,7 @@ __all__ = [
"MDIterationGradientDescent",
"BondForceWithAtomEnergyAndVirial",
"ConstrainForceCycle",
"PyFunc"
]
__all__.sort()

View File

@ -15,13 +15,14 @@
"""Other operators."""
import functools
from mindspore import log as logger
from mindspore.common import monad
from mindspore.common._decorator import deprecated
from .. import signature as sig
from ..._checkparam import Validator as validator, Rel
from ...common import dtype as mstype
from ..primitive import Primitive, PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
from .._register_for_op import PyFuncRegistry
class Assign(Primitive):
"""
@ -842,3 +843,80 @@ class identity(Primitive):
def __call__(self, x):
return x
pyfunc_register = PyFuncRegistry()
def get_pyfunc(fn_id):
return pyfunc_register.get(fn_id)
class PyFunc(PrimitiveWithInfer):
r"""
Execute Python function.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Args:
fn (function): Python function which inputs and outputs should be Python built-in scalar or numpy ndarray.
in_types (list[:class:`mindspore.dtype`]): The type of the inputs.
in_shapes (list[tuple[int]]): The dimensionality of the inputs.
out_types (list[:class:`mindspore.dtype`]): The type of the outputs.
out_shapes (list[tuple[int]]): The dimensionality of the outputs.
stateful (bool): Whether the function is stateful or not.
If True, the execution order are same with model definition.
Inputs:
- **input_x** (Union(tuple[Tensor], list[Tensor])) - The input tuple or list
is made up of multiple tensors.
Outputs:
tuple[Tensor], execution results Python functions.
Raises:
TypeError: If the Python function execution failed.
Supported Platforms:
``CPU``
Examples:
>>> def func(x1, x2):
>>> return x1 + x2
>>> x1 = Tensor(np.array([1, 2, 3]).astype(np.float32))
>>> x2 = Tensor(np.array([1, 2, 3]).astype(np.float32))
>>> op = P.PyFunc(func, [x1.dtype, x2.dtype], [x1.shape, x2.shape], [x1.dtype], [x1.dtype])
>>> output = op((x1, x2))
>>> print(output[0].asnumpy())
[2. 4. 6.]
"""
def __init__(self, fn, in_types, in_shapes, out_types, out_shapes, stateful=True):
super(PyFunc, self).__init__(self.__class__.__name__)
pyfunc_register.register(id(fn), fn)
self.add_prim_attr('fn_id', id(fn))
self.add_prim_attr('in_types', in_types)
self.add_prim_attr('in_shapes', in_shapes)
self.add_prim_attr('out_types', out_types)
self.add_prim_attr('out_shapes', out_shapes)
validator.check_value_type("in_types", in_types, [list, tuple], self.name)
validator.check_value_type("in_shapes", in_shapes, [list, tuple], self.name)
validator.check("in_types length", len(in_types), "in_shapes length", len(in_shapes), Rel.EQ, self.name)
validator.check_value_type("out_types", out_types, [list, tuple], self.name)
validator.check_value_type("out_shapes", out_shapes, [list, tuple], self.name)
validator.check("out_types length", len(out_types), "out_shapes length", len(out_shapes), Rel.EQ, self.name)
self.add_prim_attr("side_effect_io", stateful)
self.add_prim_attr("primitive_target", "CPU")
def infer_shape(self, *args):
if self.out_shapes:
return tuple(self.out_shapes)
logger.warning("The function output are empty tuple. Add a placeholder instead. "
"Do not use it as it could be any uninitialized data.")
return ((1,),)
def infer_dtype(self, *args):
if self.out_shapes:
return tuple(self.out_types)
logger.warning("The function output are empty tuple. Add a placeholder instead. "
"Do not use it as it could be any uninitialized data.")
return (mstype.int32,)

View File

@ -0,0 +1,137 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test loss """
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
def func_single_output(x1, x2):
return x1 - x2
def func_multi_output(x1, x2):
return (x1 + x2), (x1 - x2)
output = 0
def func_no_output(x1, x2):
global output
output = x1 + x2
class PyFuncNet(nn.Cell):
def __init__(self, fn, in_types, in_shapes, out_types, out_shapes):
super().__init__()
self.func = P.PyFunc(fn, in_types, in_shapes, out_types, out_shapes)
self.relu = P.ReLU()
def construct(self, x1, x2):
x = self.func((x1, x2))
return self.relu(x[0])
def func_with_dtype(ms_dtype, np_dtype):
shape = (40, 40)
np.random.seed(42)
x1 = np.random.randint(-5, 5, size=shape).astype(np_dtype)
x2 = np.random.randint(-5, 5, size=shape).astype(np_dtype)
expect = func_single_output(x1, x2)
expect = P.ReLU()(Tensor(expect))
net = PyFuncNet(func_single_output, [ms_dtype, ms_dtype], [shape, shape], [ms_dtype], [shape])
x = net(Tensor(x1), Tensor(x2))
assert np.allclose(x.asnumpy(), expect.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pyfunc_single_output():
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
func_with_dtype(ms.float16, np.float16)
func_with_dtype(ms.float32, np.float32)
func_with_dtype(ms.float64, np.float64)
func_with_dtype(ms.int32, np.int32)
func_with_dtype(ms.int64, np.int64)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pyfunc_multi_output():
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
shape = (40, 40)
dtype = ms.float32
np.random.seed(42)
x1 = np.random.randint(-5, 5, size=shape).astype(np.float32)
x2 = np.random.randint(-5, 5, size=shape).astype(np.float32)
expect, _ = func_multi_output(x1, x2)
expect = P.ReLU()(Tensor(expect))
net = PyFuncNet(func_multi_output, [dtype, dtype], [shape, shape], [dtype, dtype], [shape, shape])
x = net(Tensor(x1), Tensor(x2))
assert np.allclose(x.asnumpy(), expect.asnumpy())
class PyFuncGraph(nn.Cell):
def __init__(self, fn, in_types, in_shapes, out_types, out_shapes):
super().__init__()
self.func = P.PyFunc(fn, in_types, in_shapes, out_types, out_shapes)
def construct(self, x1, x2):
return self.func((x1, x2))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pyfunc_no_output():
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
shape = (40, 40)
dtype = ms.float32
np.random.seed(42)
x1 = np.random.randint(-5, 5, size=shape).astype(np.float32)
x2 = np.random.randint(-5, 5, size=shape).astype(np.float32)
func_no_output(x1, x2)
global output
expect = output
net = PyFuncGraph(func_no_output, [dtype, dtype], [shape, shape], [], [])
net(Tensor(x1), Tensor(x2))
net_output = output
assert np.allclose(net_output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pyfunc_scalar():
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
shape = ()
ms_dtype = ms.int32
x1 = int(10)
x2 = int(5)
expect = func_single_output(x1, x2)
net = PyFuncGraph(func_single_output, [ms_dtype, ms_dtype], [shape, shape], [ms_dtype], [shape])
x = net(Tensor(x1), Tensor(x2))
assert np.allclose(x[0].asnumpy(), expect)