diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py index 00382cd5cd7..9f20c60af46 100644 --- a/mindspore/nn/probability/distribution/bernoulli.py +++ b/mindspore/nn/probability/distribution/bernoulli.py @@ -153,6 +153,16 @@ class Bernoulli(Distribution): """ return self._probs + def _get_dist_type(self): + return "Bernoulli" + + def _get_dist_args(self, probs1=None): + if probs1 is not None: + self.checktensor(probs1, 'probs') + else: + probs1 = self.probs + return (probs1,) + def _mean(self, probs1=None): r""" .. math:: diff --git a/mindspore/nn/probability/distribution/categorical.py b/mindspore/nn/probability/distribution/categorical.py index de85bbb7b6d..7546598810b 100644 --- a/mindspore/nn/probability/distribution/categorical.py +++ b/mindspore/nn/probability/distribution/categorical.py @@ -169,6 +169,16 @@ class Categorical(Distribution): """ return self._probs + def _get_dist_type(self): + return "Categorical" + + def _get_dist_args(self, probs=None): + if probs is not None: + self.checktensor(probs, 'probs') + else: + probs = self.probs + return (probs,) + def _mean(self, probs=None): r""" .. math:: diff --git a/mindspore/nn/probability/distribution/distribution.py b/mindspore/nn/probability/distribution/distribution.py index 39e1073e784..d218d885374 100644 --- a/mindspore/nn/probability/distribution/distribution.py +++ b/mindspore/nn/probability/distribution/distribution.py @@ -344,6 +344,33 @@ class Distribution(Cell): else: self._call_cross_entropy = self._raise_not_implemented_error('cross_entropy') + def _get_dist_args(self, *args, **kwargs): + return raise_not_implemented_util('get_dist_args', self.name, *args, **kwargs) + + def get_dist_args(self, *args, **kwargs): + """ + Check the availability and validity of default parameters and `dist_spec_args`. + + Args: + *args (list): the list of positional arguments forwarded to subclasses. 
+ **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. + + Note: + `dist_spec_args` must be passed in through list or dictionary. The order of `dist_spec_args` + should follow the initialization order of default parameters through `_add_parameter`. + If some `dist_spec_args` is None, the corresponding default parameter is returned. + """ + return self._get_dist_args(*args, **kwargs) + + def _get_dist_type(self, *args, **kwargs): + return raise_not_implemented_util('get_dist_type', self.name, *args, **kwargs) + + def get_dist_type(self, *args, **kwargs): + """ + Return the type of the distribution. + """ + return self._get_dist_type(*args, **kwargs) + def _raise_not_implemented_error(self, func_name): name = self.name def raise_error(*args, **kwargs): @@ -721,4 +748,8 @@ class Distribution(Cell): return self._call_cross_entropy(*args, **kwargs) if name == 'sample': return self._sample(*args, **kwargs) + if name == 'get_dist_args': + return self._get_dist_args(*args, **kwargs) + if name == 'get_dist_type': + return self._get_dist_type(*args, **kwargs) return raise_not_implemented_util(name, self.name, *args, **kwargs) diff --git a/mindspore/nn/probability/distribution/exponential.py b/mindspore/nn/probability/distribution/exponential.py index 378e2cba319..64e3a88363e 100644 --- a/mindspore/nn/probability/distribution/exponential.py +++ b/mindspore/nn/probability/distribution/exponential.py @@ -157,6 +157,16 @@ class Exponential(Distribution): """ return self._rate + def _get_dist_type(self): + return "Exponential" + + def _get_dist_args(self, rate=None): + if rate is not None: + self.checktensor(rate, 'rate') + else: + rate = self.rate + return (rate,) + def _mean(self, rate=None): r""" .. 
math:: diff --git a/mindspore/nn/probability/distribution/geometric.py b/mindspore/nn/probability/distribution/geometric.py index ad0eef12c44..a7f087771a4 100644 --- a/mindspore/nn/probability/distribution/geometric.py +++ b/mindspore/nn/probability/distribution/geometric.py @@ -162,6 +162,16 @@ class Geometric(Distribution): """ return self._probs + def _get_dist_type(self): + return "Geometric" + + def _get_dist_args(self, probs1=None): + if probs1 is not None: + self.checktensor(probs1, 'probs') + else: + probs1 = self.probs + return (probs1,) + def _mean(self, probs1=None): r""" .. math:: diff --git a/mindspore/nn/probability/distribution/gumbel.py b/mindspore/nn/probability/distribution/gumbel.py index c341db85376..fca438a7777 100644 --- a/mindspore/nn/probability/distribution/gumbel.py +++ b/mindspore/nn/probability/distribution/gumbel.py @@ -109,7 +109,7 @@ class Gumbel(TransformedDistribution): bijector=msb.Invert(gumbel_cdf), seed=seed, name=name) - self._parameter_type = gumbel_cdf.parameter_type + self.parameter_type = gumbel_cdf.parameter_type self._broadcast_shape = gumbel_cdf.event_shape if self._broadcast_shape != (): self._is_scalar_batch = False @@ -146,6 +146,20 @@ class Gumbel(TransformedDistribution): str_info = f'batch_shape = {self._broadcast_shape}' return str_info + def _get_dist_type(self): + return "Gumbel" + + def _get_dist_args(self, loc=None, scale=None): + if loc is not None: + self.checktensor(loc, 'loc') + else: + loc = self.loc + if scale is not None: + self.checktensor(scale, 'scale') + else: + scale = self.scale + return loc, scale + def _mean(self): r""" The mean of the distribution. 
diff --git a/mindspore/nn/probability/distribution/log_normal.py b/mindspore/nn/probability/distribution/log_normal.py index eb5782e10cf..64bc160ff2e 100644 --- a/mindspore/nn/probability/distribution/log_normal.py +++ b/mindspore/nn/probability/distribution/log_normal.py @@ -161,6 +161,20 @@ class LogNormal(msd.TransformedDistribution): """Distribution parameter for the pre-transformed standard deviation.""" return self.distribution("sd") + def _get_dist_type(self): + return "LogNormal" + + def _get_dist_args(self, loc=None, scale=None): + if loc is not None: + self.checktensor(loc, 'loc') + else: + loc = self.distribution("mean") + if scale is not None: + self.checktensor(scale, 'scale') + else: + scale = self.distribution("sd") + return loc, scale + def extend_repr(self): if self.is_scalar_batch: s = f'loc = {self._mean_value}, scale = {self._sd_value}' diff --git a/mindspore/nn/probability/distribution/logistic.py b/mindspore/nn/probability/distribution/logistic.py index de27eee046e..7dedc645151 100644 --- a/mindspore/nn/probability/distribution/logistic.py +++ b/mindspore/nn/probability/distribution/logistic.py @@ -175,6 +175,20 @@ class Logistic(Distribution): """ return self._scale + def _get_dist_type(self): + return "Logistic" + + def _get_dist_args(self, loc=None, scale=None): + if loc is not None: + self.checktensor(loc, 'loc') + else: + loc = self.loc + if scale is not None: + self.checktensor(scale, 'scale') + else: + scale = self.scale + return loc, scale + def _mean(self, loc=None, scale=None): """ The mean of the distribution. 
diff --git a/mindspore/nn/probability/distribution/normal.py b/mindspore/nn/probability/distribution/normal.py index 0059c74b225..189e4f36fcf 100644 --- a/mindspore/nn/probability/distribution/normal.py +++ b/mindspore/nn/probability/distribution/normal.py @@ -154,6 +154,20 @@ class Normal(Distribution): s = f'batch_shape = {self._broadcast_shape}' return s + def _get_dist_type(self): + return "Normal" + + def _get_dist_args(self, mean=None, sd=None): + if mean is not None: + self.checktensor(mean, 'mean') + else: + mean = self._mean_value + if sd is not None: + self.checktensor(sd, 'sd') + else: + sd = self._sd_value + return mean, sd + def _mean(self, mean=None, sd=None): """ The mean of the distribution. diff --git a/mindspore/nn/probability/distribution/uniform.py b/mindspore/nn/probability/distribution/uniform.py index 31f317d7867..1324121fb0c 100644 --- a/mindspore/nn/probability/distribution/uniform.py +++ b/mindspore/nn/probability/distribution/uniform.py @@ -173,6 +173,20 @@ class Uniform(Distribution): """ return self._high + def _get_dist_type(self): + return "Uniform" + + def _get_dist_args(self, low=None, high=None): + if low is not None: + self.checktensor(low, 'low') + else: + low = self.low + if high is not None: + self.checktensor(high, 'high') + else: + high = self.high + return low, high + def _range(self, low=None, high=None): r""" Return the range of the distribution. diff --git a/tests/st/probability/distribution/test_get_dist_args.py b/tests/st/probability/distribution/test_get_dist_args.py new file mode 100644 index 00000000000..4fc24ad86b3 --- /dev/null +++ b/tests/st/probability/distribution/test_get_dist_args.py @@ -0,0 +1,101 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for Normal distribution""" +import numpy as np +import mindspore.context as context +import mindspore.nn as nn +import mindspore.nn.probability.distribution as msd +from mindspore import Tensor +from mindspore import dtype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Net1(nn.Cell): + """ + Test class: Normal distribution. `dist_spec_args` are `mean`, `sd`. + """ + def __init__(self): + super(Net1, self).__init__() + self.normal = msd.Normal(dtype=dtype.float32) + self.normal1 = msd.Normal(0.0, 1.0, dtype=dtype.float32) + self.normal2 = msd.Normal(3.0, 4.0, dtype=dtype.float32) + + def construct(self, value, mean, sd, mean_a, sd_a): + args_list = self.normal.get_dist_args(mean, sd) + prob = self.normal1.prob(value, *args_list) + args_list1 = self.normal.get_dist_args() + prob1 = self.normal2.prob(value, *args_list1) + + args_list2 = self.normal1.get_dist_args() + dist_type = self.normal1.get_dist_type() + kl_loss = self.normal2.kl_loss(dist_type, *args_list2) + + args_list3 = self.normal.get_dist_args(mean_a, sd_a) + dist_type = self.normal1.get_dist_type() + kl_loss1 = self.normal2.kl_loss(dist_type, *args_list3) + return prob, prob1, kl_loss, kl_loss1 + +def test1(): + """ + Test Normal with two `dist_spec_args`. 
+ """ + net = Net1() + mean = Tensor(3.0, dtype=dtype.float32) + sd = Tensor(4.0, dtype=dtype.float32) + mean_a = Tensor(0.0, dtype=dtype.float32) + sd_a = Tensor(1.0, dtype=dtype.float32) + value = Tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) + ans, expected, ans1, expected1 = net(value, mean, sd, mean_a, sd_a) + tol = 1e-6 + assert (np.abs(ans.asnumpy() - expected.asnumpy()) < tol).all() + assert (np.abs(ans1.asnumpy() - expected1.asnumpy()) < tol).all() + +class Net2(nn.Cell): + """ + Test class: Exponential distribution. `dist_spec_args` is `rate`. + """ + def __init__(self): + super(Net2, self).__init__() + self.expon = msd.Exponential(dtype=dtype.float32) + self.expon1 = msd.Exponential(1.0, dtype=dtype.float32) + self.expon2 = msd.Exponential(2.0, dtype=dtype.float32) + + def construct(self, value, rate, rate1): + args_list = self.expon.get_dist_args(rate) + prob = self.expon1.prob(value, *args_list) + args_list1 = self.expon.get_dist_args() + prob1 = self.expon2.prob(value, *args_list1) + + args_list2 = self.expon1.get_dist_args() + dist_type = self.expon1.get_dist_type() + kl_loss = self.expon2.kl_loss(dist_type, *args_list2) + + args_list3 = self.expon.get_dist_args(rate1) + dist_type = self.expon.get_dist_type() + kl_loss1 = self.expon2.kl_loss(dist_type, *args_list3) + return prob, prob1, kl_loss, kl_loss1 + +def test2(): + """ + Test Exponential with single `dist_spec_args`. 
+ """ + net = Net2() + rate = Tensor(2.0, dtype=dtype.float32) + rate1 = Tensor(1.0, dtype=dtype.float32) + value = Tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) + ans, expected, ans1, expected1 = net(value, rate, rate1) + tol = 1e-6 + assert (np.abs(ans.asnumpy() - expected.asnumpy()) < tol).all() + assert (np.abs(ans1.asnumpy() - expected1.asnumpy()) < tol).all() diff --git a/tests/ut/python/nn/probability/distribution/test_gumbel.py b/tests/ut/python/nn/probability/distribution/test_gumbel.py index 11af5b70a95..3c815ca152e 100644 --- a/tests/ut/python/nn/probability/distribution/test_gumbel.py +++ b/tests/ut/python/nn/probability/distribution/test_gumbel.py @@ -98,6 +98,8 @@ def test_kl_cross_entropy(): """ Test kl_loss and cross_entropy. """ + from mindspore import context + context.set_context(device_target="Ascend") net = KL() loc_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) scale_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)