diff --git a/mindspore/nn/probability/bijector/bijector.py b/mindspore/nn/probability/bijector/bijector.py index 79fe5d129d7..4760efd4b65 100644 --- a/mindspore/nn/probability/bijector/bijector.py +++ b/mindspore/nn/probability/bijector/bijector.py @@ -50,6 +50,8 @@ class Bijector(Cell): self._parameters = {} # parsing parameters for k in param.keys(): + if k == 'param': + continue if not(k == 'self' or k.startswith('_')): self._parameters[k] = param[k] self._is_constant_jacobian = is_constant_jacobian diff --git a/mindspore/nn/probability/bijector/power_transform.py b/mindspore/nn/probability/bijector/power_transform.py index e67f676238b..696749692d8 100644 --- a/mindspore/nn/probability/bijector/power_transform.py +++ b/mindspore/nn/probability/bijector/power_transform.py @@ -35,6 +35,9 @@ class PowerTransform(Bijector): Args: power (int or float): scale factor. Default: 0. name (str): name of the bijector. Default: 'PowerTransform'. + param (dict): parameters used to initialize the bijector. This is only used when other bijectors that inherits + from powertransform passing in parameters. In this case the derived bijector may overwrite the param args. + Default: None. Examples: >>> # To initialize a PowerTransform bijector of power 0.5