!3615 Move nn/distribution to nn/probability/distribution

Merge pull request !3615 from XunDeng/pp_poc_v3
This commit is contained in:
mindspore-ci-bot 2020-07-29 10:02:11 +08:00 committed by Gitee
commit c2385e2ede
11 changed files with 38 additions and 13 deletions

View File

@ -17,14 +17,14 @@ Neural Networks Cells.
Pre-defined building blocks or computing units to construct Neural Networks.
"""
from . import layer, loss, optim, metrics, wrap, distribution
from . import layer, loss, optim, metrics, wrap, probability
from .cell import Cell, GraphKernel
from .layer import *
from .loss import *
from .optim import *
from .metrics import *
from .wrap import *
from .distribution import *
from .probability import *
__all__ = ["Cell", "GraphKernel"]
@ -33,7 +33,7 @@ __all__.extend(loss.__all__)
__all__.extend(optim.__all__)
__all__.extend(metrics.__all__)
__all__.extend(wrap.__all__)
__all__.extend(distribution.__all__)
__all__.extend(probability.__all__)
__all__.sort()

View File

@ -0,0 +1,24 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Probability.
The high-level components(Distributions) used to construct the probabilistic network.
"""
from .distribution import *
__all__ = []
__all__.extend(distribution.__all__)

View File

@ -16,9 +16,9 @@
"""Utitly functions to help distribution class."""
import numpy as np
from mindspore.ops import _utils as utils
from ....common.tensor import Tensor
from ....common.parameter import Parameter
from ....common import dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
def cast_to_tensor(t, dtype=mstype.float32):
"""

View File

@ -13,10 +13,10 @@
# limitations under the License.
# ============================================================================
"""Bernoulli Distribution"""
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob
from ...common import dtype as mstype
class Bernoulli(Distribution):
"""

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""basic"""
from ..cell import Cell
from mindspore.nn.cell import Cell
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param
class Distribution(Cell):

View File

@ -15,8 +15,8 @@
"""Exponential Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ...common import dtype as mstype
from ._utils.utils import cast_to_tensor, check_greater_zero
class Exponential(Distribution):

View File

@ -15,9 +15,9 @@
"""Geometric Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob
from ...common import dtype as mstype
class Geometric(Distribution):
"""

View File

@ -16,10 +16,11 @@
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from mindspore.context import get_context
from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater_equal_zero
from ...common import dtype as mstype
from ...context import get_context
class Normal(Distribution):
"""

View File

@ -14,8 +14,8 @@
# ============================================================================
"""Uniform Distribution"""
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ...common import dtype as mstype
from ._utils.utils import convert_to_batch, check_greater
class Uniform(Distribution):