diff --git a/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/ops/_op_impl/aicpu/__init__.py index dc5fba988ce..1f2e05284f5 100644 --- a/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/ops/_op_impl/aicpu/__init__.py @@ -38,3 +38,4 @@ from .gamma import _gamma_aicpu from .poisson import _poisson_aicpu from .uniform_int import _uniform_int_aicpu from .uniform_real import _uniform_real_aicpu +from .laplace import _laplace_aicpu diff --git a/mindspore/ops/_op_impl/aicpu/laplace.py b/mindspore/ops/_op_impl/aicpu/laplace.py new file mode 100644 index 00000000000..6bb284ed07b --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/laplace.py @@ -0,0 +1,33 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""RandomLaplace op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +laplace_op_info = AiCPURegOp("Laplace") \ + .fusion_type("OPAQUE") \ + .input(0, "shape", "required") \ + .input(1, "mean", "required") \ + .input(2, "lambda_param", "required") \ + .output(0, "output", "required") \ + .attr("seed", "int") \ + .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \ + .get_op_info() + +@op_info_register(laplace_op_info) +def _laplace_aicpu(): + """RandomLaplace AiCPU register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 6ba138bc221..ad83ac346ad 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -55,7 +55,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps) from .random_ops import (RandomChoiceWithMask, Normal, Gamma, Poisson, UniformInt, UniformReal, - RandomCategorical) + RandomCategorical, Laplace) from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm, BiasAdd, Conv2D, DepthwiseConv2dNative, @@ -177,6 +177,7 @@ __all__ = [ 'Poisson', 'UniformInt', 'UniformReal', + 'Laplace', 'RandomCategorical', 'ResizeBilinear', 'ScalarSummary', diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index 251acf33f0f..4e60fb4003e 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -36,9 +36,9 @@ class Normal(PrimitiveWithInfer): Inputs: - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. - - **mean** (Tensor) - The mean μ distribution parameter, The mean specifies the location of the peak. - With float32 data type. - - **stddev** (Tensor) - the deviation σ distribution parameter. With float32 data type. + - **mean** (Tensor) - The mean μ distribution parameter, which specifies the location of the peak. + With float32 data type. + - **stddev** (Tensor) - The deviation σ distribution parameter. With float32 data type. Outputs: Tensor, has the shape 'shape' input and dtype as float32. @@ -75,6 +75,60 @@ class Normal(PrimitiveWithInfer): return out +class Laplace(PrimitiveWithInfer): + r""" + Generates random numbers according to the Laplace random number distribution. + It is defined as: + + .. math:: + \text{f}(x;μ,λ) = \frac{1}{2λ}\exp(-\frac{|x-μ|}{λ}), + + Args: + seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers. + Default: 0. + + Inputs: + - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. + - **mean** (Tensor) - The mean μ distribution parameter, which specifies the location of the peak. + With float32 data type. + - **lambda_param** (Tensor) - The parameter used for controling the variance of this random distribution. The + variance of Laplace distribution is equal to twice the square of lambda_param. With float32 data type. + + Outputs: + Tensor, has the shape 'shape' input and dtype as float32. + + Examples: + >>> shape = (4, 16) + >>> mean = Tensor(1.0, mstype.float32) + >>> lambda_param = Tensor(1.0, mstype.float32) + >>> laplace = P.Laplace(seed=2) + >>> output = laplace(shape, mean, lambda_param) + """ + + @prim_attr_register + def __init__(self, seed=0): + """Init Laplace""" + self.init_prim_io_names(inputs=['shape', 'mean', 'lambda_param'], outputs=['output']) + validator.check_value_type('seed', seed, [int], self.name) + + def __infer__(self, shape, mean, lambda_param): + shape_v = shape["value"] + if shape_v is None: + raise ValueError(f"For {self.name}, shape must be const.") + validator.check_value_type("shape", shape_v, [tuple], self.name) + for i, shape_i in enumerate(shape_v): + validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name) + validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name) + validator.check_tensor_type_same({"lambda_param": lambda_param["dtype"]}, [mstype.float32], self.name) + broadcast_shape = get_broadcast_shape(mean['shape'], lambda_param['shape'], self.name) + broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name) + out = { + 'shape': broadcast_shape, + 'dtype': mstype.float32, + 'value': None} + return out + + class Gamma(PrimitiveWithInfer): r""" Produces random positive floating-point values x, distributed according to probability density function: @@ -101,7 +155,7 @@ class Gamma(PrimitiveWithInfer): >>> alpha = Tensor(1.0, mstype.float32) >>> beta = Tensor(1.0, mstype.float32) >>> gamma = P.Gamma(seed=3) - >>> output = normal(shape, alpha, beta) + >>> output = Gamma(shape, alpha, beta) """ @prim_attr_register diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_laplace.py b/tests/st/ops/ascend/test_aicpu_ops/test_laplace.py new file mode 100644 index 00000000000..1cfd6c0ede2 --- /dev/null +++ b/tests/st/ops/ascend/test_aicpu_ops/test_laplace.py @@ -0,0 +1,57 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self, shape, seed=0): + super(Net, self).__init__() + self.laplace = P.Laplace(seed=seed) + self.shape = shape + + def construct(self, mean, lambda_param): + return self.laplace(self.shape, mean, lambda_param) + + +def test_net_1D(): + seed = 10 + shape = (3, 2, 4) + mean = 1.0 + lambda_param = 1.0 + net = Net(shape, seed) + tmean, tlambda_param = Tensor(mean, mstype.float32), Tensor(lambda_param, mstype.float32) + output = net(tmean, tlambda_param) + print(output.asnumpy()) + assert output.shape == (3, 2, 4) + + +def test_net_ND(): + seed = 10 + shape = (3, 1, 2) + mean = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32) + lambda_param = np.array([1.0]).astype(np.float32) + net = Net(shape, seed) + tmean, tlambda_param = Tensor(mean), Tensor(lambda_param) + output = net(tmean, tlambda_param) + print(output.asnumpy()) + assert output.shape == (3, 2, 2) \ No newline at end of file diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index c9dab27b664..4faafc38301 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -410,6 +410,17 @@ class NormalNet(nn.Cell): return out +class LaplaceNet(nn.Cell): + def __init__(self, shape=None, seed=0): + super(LaplaceNet, self).__init__() + self.laplace = P.Laplace(seed=seed) + self.shape = shape + + def construct(self, mean, lambda_param): + out = self.laplace(self.shape, mean, lambda_param) + return out + + class GammaNet(nn.Cell): def __init__(self, shape=None, seed=0): super(GammaNet, self).__init__() @@ -666,6 +677,10 @@ test_case_math_ops = [ 'block': NormalNet((3, 2, 4), 0), 'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(1.0, mstype.float32)], 'skip': ['backward']}), + ('Laplace', { + 'block': LaplaceNet((3, 2, 4), 0), + 'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(1.0, mstype.float32)], + 'skip': ['backward']}), ('Gamma', { 'block': GammaNet((3, 2, 4), 0), 'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(1.0, mstype.float32)],