forked from mindspore-Ecosystem/mindspore
!21494 pyfunc primitive register
Merge pull request !21494 from chenweifeng/pyfunc-primitive-register
Commit: ace6e49e00
@@ -135,7 +135,8 @@ void ScalarToRawMemory(const py::object &obj, const TypePtr &type, const Address
 void ArrayToRawMemory(const py::array &array, const AddressPtr &address) {
   if (static_cast<unsigned int>(array.flags()) & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_) {
     const py::buffer_info &buf_info = array.request();
-    CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, buf_info.ptr, buf_info.size), EOK, "memcpy failed.");
+    CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, buf_info.ptr, buf_info.size * buf_info.itemsize), EOK,
+                          "memcpy failed.");
   } else {
     // Transform numpy array to row major buffer.
     Py_buffer pybuf;
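The fix above addresses a byte-count bug: pybind11's `buffer_info::size` is the number of elements, not bytes, so the copy length passed to `memcpy_s` must be scaled by `itemsize`. A minimal sketch of the same element/byte distinction on the NumPy side (standalone illustration, not part of the patch):

import numpy as np

arr = np.ones((4, 4), dtype=np.float64)
# size counts elements; nbytes = size * itemsize counts bytes.
assert arr.size == 16
assert arr.itemsize == 8
assert arr.nbytes == arr.size * arr.itemsize == 128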
@@ -65,3 +65,4 @@ from .pad import _pad_cpu
 from .range import _range_cpu
 from .tensor_copy_slices import _tensor_copy_slices_cpu
 from .l2loss import _l2loss_cpu
+from .pyfunc import _pyfunc_cpu
@@ -0,0 +1,29 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""PyFunc op"""
+from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
+
+pyfunc_op_info = CpuRegOp("PyFunc") \
+    .input(0, "x", "dynamic") \
+    .output(0, "y", "dynamic") \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .get_op_info()
+
+
+@op_info_register(pyfunc_op_info)
+def _pyfunc_cpu():
+    """PyFunc cpu register"""
+    return
@@ -47,3 +47,12 @@ class Registry(UserDict):
         if key in self:
             fn = self[prim_obj.name]
         return fn
+
+
+class PyFuncRegistry(UserDict):
+    def register(self, key, value):
+        self[key] = value
+
+    def get(self, key):
+        if key not in self:
+            raise ValueError(f"Python function with key {key} not registered.")
+        return self[key]
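PyFuncRegistry keys Python callables by an integer id so the runtime can look the function back up at kernel-launch time. A standalone sketch of the round trip (illustration only, mirroring the class added above):

from collections import UserDict


class PyFuncRegistry(UserDict):
    """Registry mapping id(fn) -> fn, as added by this PR."""
    def register(self, key, value):
        self[key] = value

    def get(self, key):
        if key not in self:
            raise ValueError(f"Python function with key {key} not registered.")
        return self[key]


def add(a, b):
    return a + b

registry = PyFuncRegistry()
registry.register(id(add), add)          # the primitive registers fn under id(fn)
assert registry.get(id(add))(1, 2) == 3  # the kernel retrieves it by fn_id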
@@ -92,7 +92,7 @@ from ._quant_ops import *
 from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode,
                         ConfusionMatrix, PopulationCount, UpdateState, Load,
                         CheckValid, Partial, Depend, identity, CheckBprop, Push, Pull, PullWeight, PushWeight,
-                        StartFLJob, UpdateModel, GetModel)
+                        StartFLJob, UpdateModel, GetModel, PyFunc)
 from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
                         CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
                         CusMatMulCubeDenseRight,
@@ -524,6 +524,7 @@ __all__ = [
     "MDIterationGradientDescent",
     "BondForceWithAtomEnergyAndVirial",
     "ConstrainForceCycle",
+    "PyFunc"
 ]

 __all__.sort()
@@ -15,13 +15,14 @@

 """Other operators."""
 import functools
 from mindspore import log as logger
 from mindspore.common import monad
 from mindspore.common._decorator import deprecated
 from .. import signature as sig
 from ..._checkparam import Validator as validator, Rel
 from ...common import dtype as mstype
 from ..primitive import Primitive, PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register

+from .._register_for_op import PyFuncRegistry

 class Assign(Primitive):
     """
@@ -842,3 +843,80 @@ class identity(Primitive):

     def __call__(self, x):
         return x
+
+
+pyfunc_register = PyFuncRegistry()
+def get_pyfunc(fn_id):
+    return pyfunc_register.get(fn_id)
+
+
+class PyFunc(PrimitiveWithInfer):
+    r"""
+    Execute a Python function.
+
+    .. warning::
+        This is an experimental prototype that is subject to change and/or deletion.
+
+    Args:
+        fn (function): The Python function to execute. Its inputs and outputs should be
+            Python built-in scalars or numpy ndarrays.
+        in_types (list[:class:`mindspore.dtype`]): The types of the inputs.
+        in_shapes (list[tuple[int]]): The shapes of the inputs.
+        out_types (list[:class:`mindspore.dtype`]): The types of the outputs.
+        out_shapes (list[tuple[int]]): The shapes of the outputs.
+        stateful (bool): Whether the function is stateful. If True, the execution order
+            is the same as the model definition.
+
+    Inputs:
+        - **input_x** (Union(tuple[Tensor], list[Tensor])) - The input tuple or list
+          is made up of multiple tensors.
+
+    Outputs:
+        tuple[Tensor], the execution results of the Python function.
+
+    Raises:
+        TypeError: If the Python function execution fails.
+
+    Supported Platforms:
+        ``CPU``
+
+    Examples:
+        >>> def func(x1, x2):
+        ...     return x1 + x2
+        >>> x1 = Tensor(np.array([1, 2, 3]).astype(np.float32))
+        >>> x2 = Tensor(np.array([1, 2, 3]).astype(np.float32))
+        >>> op = P.PyFunc(func, [x1.dtype, x2.dtype], [x1.shape, x2.shape], [x1.dtype], [x1.shape])
+        >>> output = op((x1, x2))
+        >>> print(output[0].asnumpy())
+        [2. 4. 6.]
+    """
+
+    def __init__(self, fn, in_types, in_shapes, out_types, out_shapes, stateful=True):
+        super(PyFunc, self).__init__(self.__class__.__name__)
+        pyfunc_register.register(id(fn), fn)
+        self.add_prim_attr('fn_id', id(fn))
+        self.add_prim_attr('in_types', in_types)
+        self.add_prim_attr('in_shapes', in_shapes)
+        self.add_prim_attr('out_types', out_types)
+        self.add_prim_attr('out_shapes', out_shapes)
+        validator.check_value_type("in_types", in_types, [list, tuple], self.name)
+        validator.check_value_type("in_shapes", in_shapes, [list, tuple], self.name)
+        validator.check("in_types length", len(in_types), "in_shapes length", len(in_shapes), Rel.EQ, self.name)
+        validator.check_value_type("out_types", out_types, [list, tuple], self.name)
+        validator.check_value_type("out_shapes", out_shapes, [list, tuple], self.name)
+        validator.check("out_types length", len(out_types), "out_shapes length", len(out_shapes), Rel.EQ, self.name)
+        self.add_prim_attr("side_effect_io", stateful)
+        self.add_prim_attr("primitive_target", "CPU")
+
+    def infer_shape(self, *args):
+        if self.out_shapes:
+            return tuple(self.out_shapes)
+
+        logger.warning("The function output is an empty tuple. Adding a placeholder instead. "
+                       "Do not use the placeholder as it could be any uninitialized data.")
+        return ((1,),)
+
+    def infer_dtype(self, *args):
+        if self.out_shapes:
+            return tuple(self.out_types)
+
+        logger.warning("The function output is an empty tuple. Adding a placeholder instead. "
+                       "Do not use the placeholder as it could be any uninitialized data.")
+        return (mstype.int32,)
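Taken together: the primitive registers `fn` under `id(fn)` at construction time and the CPU kernel retrieves it through `get_pyfunc` at execution time. A minimal usage sketch, mirroring the docstring example (assumes a MindSpore build that includes this PR):

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

def func(x1, x2):
    # Runs as ordinary Python on numpy inputs when the kernel launches.
    return x1 + x2

x1 = Tensor(np.array([1, 2, 3]).astype(np.float32))
x2 = Tensor(np.array([1, 2, 3]).astype(np.float32))
op = P.PyFunc(func, [x1.dtype, x2.dtype], [x1.shape, x2.shape], [x1.dtype], [x1.shape])
output = op((x1, x2))
print(output[0].asnumpy())  # [2. 4. 6.]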
@@ -0,0 +1,137 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test pyfunc"""
+import numpy as np
+import pytest
+
+import mindspore as ms
+import mindspore.nn as nn
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+def func_single_output(x1, x2):
+    return x1 - x2
+
+
+def func_multi_output(x1, x2):
+    return (x1 + x2), (x1 - x2)
+
+
+output = 0
+
+
+def func_no_output(x1, x2):
+    global output
+    output = x1 + x2
+
+
+class PyFuncNet(nn.Cell):
+    def __init__(self, fn, in_types, in_shapes, out_types, out_shapes):
+        super().__init__()
+        self.func = P.PyFunc(fn, in_types, in_shapes, out_types, out_shapes)
+        self.relu = P.ReLU()
+
+    def construct(self, x1, x2):
+        x = self.func((x1, x2))
+        return self.relu(x[0])
+
+
+def func_with_dtype(ms_dtype, np_dtype):
+    shape = (40, 40)
+    np.random.seed(42)
+    x1 = np.random.randint(-5, 5, size=shape).astype(np_dtype)
+    x2 = np.random.randint(-5, 5, size=shape).astype(np_dtype)
+
+    expect = func_single_output(x1, x2)
+    expect = P.ReLU()(Tensor(expect))
+
+    net = PyFuncNet(func_single_output, [ms_dtype, ms_dtype], [shape, shape], [ms_dtype], [shape])
+    x = net(Tensor(x1), Tensor(x2))
+    assert np.allclose(x.asnumpy(), expect.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_pyfunc_single_output():
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    func_with_dtype(ms.float16, np.float16)
+    func_with_dtype(ms.float32, np.float32)
+    func_with_dtype(ms.float64, np.float64)
+    func_with_dtype(ms.int32, np.int32)
+    func_with_dtype(ms.int64, np.int64)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_pyfunc_multi_output():
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    shape = (40, 40)
+    dtype = ms.float32
+
+    np.random.seed(42)
+    x1 = np.random.randint(-5, 5, size=shape).astype(np.float32)
+    x2 = np.random.randint(-5, 5, size=shape).astype(np.float32)
+    expect, _ = func_multi_output(x1, x2)
+    expect = P.ReLU()(Tensor(expect))
+
+    net = PyFuncNet(func_multi_output, [dtype, dtype], [shape, shape], [dtype, dtype], [shape, shape])
+    x = net(Tensor(x1), Tensor(x2))
+
+    assert np.allclose(x.asnumpy(), expect.asnumpy())
+
+
+class PyFuncGraph(nn.Cell):
+    def __init__(self, fn, in_types, in_shapes, out_types, out_shapes):
+        super().__init__()
+        self.func = P.PyFunc(fn, in_types, in_shapes, out_types, out_shapes)
+
+    def construct(self, x1, x2):
+        return self.func((x1, x2))
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_pyfunc_no_output():
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    shape = (40, 40)
+    dtype = ms.float32
+
+    np.random.seed(42)
+    x1 = np.random.randint(-5, 5, size=shape).astype(np.float32)
+    x2 = np.random.randint(-5, 5, size=shape).astype(np.float32)
+    func_no_output(x1, x2)
+    global output
+    expect = output
+
+    net = PyFuncGraph(func_no_output, [dtype, dtype], [shape, shape], [], [])
+    net(Tensor(x1), Tensor(x2))
+    net_output = output
+
+    assert np.allclose(net_output, expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_pyfunc_scalar():
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    shape = ()
+    ms_dtype = ms.int32
+
+    x1 = int(10)
+    x2 = int(5)
+    expect = func_single_output(x1, x2)
+
+    net = PyFuncGraph(func_single_output, [ms_dtype, ms_dtype], [shape, shape], [ms_dtype], [shape])
+    x = net(Tensor(x1), Tensor(x2))
+    assert np.allclose(x[0].asnumpy(), expect)
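One detail worth noting from these tests: `func_no_output` returns nothing and runs only for its side effect, which works because `stateful` defaults to True and sets `side_effect_io`, keeping the call alive in the compiled graph. A minimal standalone sketch of the same pattern (hypothetical illustration; assumes a build with this PR applied):

import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

captured = {}

def record(x1, x2):
    # Executed for its side effect only; no return value.
    captured['sum'] = x1 + x2

class RecordNet(nn.Cell):
    def __init__(self):
        super().__init__()
        shape = (2,)
        # Empty out_types/out_shapes: infer_shape/infer_dtype fall back to a placeholder.
        self.func = P.PyFunc(record, [ms.float32, ms.float32], [shape, shape], [], [])

    def construct(self, x1, x2):
        return self.func((x1, x2))

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = RecordNet()
net(Tensor(np.ones(2, np.float32)), Tensor(np.ones(2, np.float32)))
assert np.allclose(captured['sum'], 2.0)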