forked from mindspore-Ecosystem/mindspore
modify custom_ops to pass pynative mode
update dtype for device delete used funcs
This commit is contained in:
parent
1b71d50953
commit
d29bd6862a
|
@ -17,6 +17,7 @@ import numpy as np
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
||||
def exp_by_step(input_x):
|
||||
"""
|
||||
Log op on Ascend doesn't supprot int types.
|
||||
|
@ -24,23 +25,18 @@ def exp_by_step(input_x):
|
|||
"""
|
||||
exp = P.Exp()
|
||||
cast = P.Cast()
|
||||
dtype = P.DType()
|
||||
checktype = P.IsSubClass()
|
||||
|
||||
if checktype(dtype(input_x), mstype.int_):
|
||||
input_x = cast(input_x, mstype.float32)
|
||||
elif checktype(dtype(input_x), mstype.float_):
|
||||
pass
|
||||
else:
|
||||
return None
|
||||
input_x = cast(input_x, mstype.float32)
|
||||
return exp(input_x)
|
||||
|
||||
|
||||
def expm1_by_step(input_x):
|
||||
"""
|
||||
Expm1 ops under GPU context.
|
||||
"""
|
||||
return exp_by_step(input_x) - 1.0
|
||||
|
||||
|
||||
def log_by_step(input_x):
|
||||
"""
|
||||
Log op on Ascend is calculated as log(abs(x)).
|
||||
|
@ -56,14 +52,8 @@ def log_by_step(input_x):
|
|||
dtype = P.DType()
|
||||
shape = P.Shape()
|
||||
select = P.Select()
|
||||
checktype = P.IsSubClass()
|
||||
|
||||
if checktype(dtype(input_x), mstype.int_):
|
||||
input_x = cast(input_x, mstype.float32)
|
||||
elif checktype(dtype(input_x), mstype.float_):
|
||||
pass
|
||||
else:
|
||||
return None
|
||||
input_x = cast(input_x, mstype.float32)
|
||||
nan = fill(dtype(input_x), shape(input_x), np.nan)
|
||||
inf = fill(dtype(input_x), shape(input_x), np.inf)
|
||||
neg_x = less(input_x, 0.0)
|
||||
|
@ -72,6 +62,7 @@ def log_by_step(input_x):
|
|||
result = select(nonpos_x, -inf, log_x)
|
||||
return select(neg_x, nan, result)
|
||||
|
||||
|
||||
def log1p_by_step(x):
|
||||
"""
|
||||
Log1p ops on GPU device or when device_target == GPU.
|
||||
|
|
|
@ -14,15 +14,15 @@
|
|||
# ============================================================================
|
||||
"""Utitly functions to help distribution class."""
|
||||
import numpy as np
|
||||
from mindspore.ops import _utils as utils
|
||||
from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
|
||||
from mindspore import context
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import _utils as utils
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore import context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
|
||||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability as msp
|
||||
|
||||
|
@ -82,6 +82,24 @@ def convert_to_batch(t, batch_shape, required_type):
|
|||
return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type)
|
||||
|
||||
|
||||
def cast_type_for_device(dtype):
|
||||
"""
|
||||
use the alternative dtype supported by the device.
|
||||
Args:
|
||||
dtype (mindspore.dtype): input dtype.
|
||||
Returns:
|
||||
mindspore.dtype.
|
||||
"""
|
||||
if context.get_context("device_target") == "GPU":
|
||||
if dtype in mstype.uint_type or dtype == mstype.int8:
|
||||
return mstype.int16
|
||||
if dtype == mstype.int64:
|
||||
return mstype.int32
|
||||
if dtype == mstype.float64:
|
||||
return mstype.float32
|
||||
return dtype
|
||||
|
||||
|
||||
def check_scalar_from_param(params):
|
||||
"""
|
||||
Check if params are all scalars.
|
||||
|
@ -293,10 +311,10 @@ def raise_not_impl_error(name):
|
|||
def check_distribution_name(name, expected_name):
|
||||
if name is None:
|
||||
raise ValueError(
|
||||
f"Distribution should be a constant which is not None.")
|
||||
f"Input dist should be a constant which is not None.")
|
||||
if name != expected_name:
|
||||
raise ValueError(
|
||||
f"Expected distribution name is {expected_name}, but got {name}.")
|
||||
f"Expected dist input is {expected_name}, but got {name}.")
|
||||
|
||||
|
||||
class CheckTuple(PrimitiveWithInfer):
|
||||
|
|
|
@ -16,9 +16,10 @@
|
|||
from mindspore.nn.cell import Cell
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param
|
||||
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device
|
||||
from ._utils.utils import CheckTuple, CheckTensor
|
||||
|
||||
|
||||
class Distribution(Cell):
|
||||
"""
|
||||
Base class for all mathematical distributions.
|
||||
|
@ -43,12 +44,12 @@ class Distribution(Cell):
|
|||
new distribution specified by the dist_spec_args. But it won't change the
|
||||
original distribuion.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
seed,
|
||||
dtype,
|
||||
name,
|
||||
param):
|
||||
|
||||
"""
|
||||
Constructor of distribution class.
|
||||
"""
|
||||
|
@ -58,7 +59,7 @@ class Distribution(Cell):
|
|||
|
||||
self._name = name
|
||||
self._seed = seed
|
||||
self._dtype = dtype
|
||||
self._dtype = cast_type_for_device(dtype)
|
||||
self._parameters = {}
|
||||
# parsing parameters
|
||||
for k in param.keys():
|
||||
|
@ -436,7 +437,6 @@ class Distribution(Cell):
|
|||
"""
|
||||
return self._sample(*args, **kwargs)
|
||||
|
||||
|
||||
def construct(self, name, *args, **kwargs):
|
||||
"""
|
||||
Override construct in Cell.
|
||||
|
|
Loading…
Reference in New Issue