forked from mindspore-Ecosystem/mindspore
!3615 Move nn/distribution to nn/probability/distribution
Merge pull request !3615 from XunDeng/pp_poc_v3
This commit is contained in:
commit
c2385e2ede
|
@ -17,14 +17,14 @@ Neural Networks Cells.
|
|||
|
||||
Pre-defined building blocks or computing units to construct Neural Networks.
|
||||
"""
|
||||
from . import layer, loss, optim, metrics, wrap, distribution
|
||||
from . import layer, loss, optim, metrics, wrap, probability
|
||||
from .cell import Cell, GraphKernel
|
||||
from .layer import *
|
||||
from .loss import *
|
||||
from .optim import *
|
||||
from .metrics import *
|
||||
from .wrap import *
|
||||
from .distribution import *
|
||||
from .probability import *
|
||||
|
||||
|
||||
__all__ = ["Cell", "GraphKernel"]
|
||||
|
@ -33,7 +33,7 @@ __all__.extend(loss.__all__)
|
|||
__all__.extend(optim.__all__)
|
||||
__all__.extend(metrics.__all__)
|
||||
__all__.extend(wrap.__all__)
|
||||
__all__.extend(distribution.__all__)
|
||||
__all__.extend(probability.__all__)
|
||||
|
||||
|
||||
__all__.sort()
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Probability.
|
||||
|
||||
The high-level components(Distributions) used to construct the probabilistic network.
|
||||
"""
|
||||
|
||||
from .distribution import *
|
||||
|
||||
__all__ = []
|
||||
__all__.extend(distribution.__all__)
|
|
@ -16,9 +16,9 @@
|
|||
"""Utitly functions to help distribution class."""
|
||||
import numpy as np
|
||||
from mindspore.ops import _utils as utils
|
||||
from ....common.tensor import Tensor
|
||||
from ....common.parameter import Parameter
|
||||
from ....common import dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
def cast_to_tensor(t, dtype=mstype.float32):
|
||||
"""
|
|
@ -13,10 +13,10 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Bernoulli Distribution"""
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import cast_to_tensor, check_prob
|
||||
from ...common import dtype as mstype
|
||||
|
||||
class Bernoulli(Distribution):
|
||||
"""
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""basic"""
|
||||
from ..cell import Cell
|
||||
from mindspore.nn.cell import Cell
|
||||
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param
|
||||
|
||||
class Distribution(Cell):
|
|
@ -15,8 +15,8 @@
|
|||
"""Exponential Distribution"""
|
||||
import numpy as np
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ...common import dtype as mstype
|
||||
from ._utils.utils import cast_to_tensor, check_greater_zero
|
||||
|
||||
class Exponential(Distribution):
|
|
@ -15,9 +15,9 @@
|
|||
"""Geometric Distribution"""
|
||||
import numpy as np
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import cast_to_tensor, check_prob
|
||||
from ...common import dtype as mstype
|
||||
|
||||
class Geometric(Distribution):
|
||||
"""
|
|
@ -16,10 +16,11 @@
|
|||
import numpy as np
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.context import get_context
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import convert_to_batch, check_greater_equal_zero
|
||||
from ...common import dtype as mstype
|
||||
from ...context import get_context
|
||||
|
||||
|
||||
class Normal(Distribution):
|
||||
"""
|
|
@ -14,8 +14,8 @@
|
|||
# ============================================================================
|
||||
"""Uniform Distribution"""
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ...common import dtype as mstype
|
||||
from ._utils.utils import convert_to_batch, check_greater
|
||||
|
||||
class Uniform(Distribution):
|
Loading…
Reference in New Issue