!5090 Update fix custom exp/log ops cast logic to r0.7

Merge pull request !5090 from zichun_ye/r0.7_fix_custom_ops
This commit is contained in:
mindspore-ci-bot 2020-08-25 09:41:54 +08:00 committed by Gitee
commit d8d7cebc8a
3 changed files with 34 additions and 25 deletions

View File

@ -17,6 +17,7 @@ import numpy as np
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
def exp_by_step(input_x):
"""
Log op on Ascend doesn't supprot int types.
@ -24,23 +25,18 @@ def exp_by_step(input_x):
"""
exp = P.Exp()
cast = P.Cast()
dtype = P.DType()
checktype = P.IsSubClass()
if checktype(dtype(input_x), mstype.int_):
input_x = cast(input_x, mstype.float32)
elif checktype(dtype(input_x), mstype.float_):
pass
else:
return None
input_x = cast(input_x, mstype.float32)
return exp(input_x)
def expm1_by_step(input_x):
"""
Expm1 ops under GPU context.
"""
return exp_by_step(input_x) - 1.0
def log_by_step(input_x):
"""
Log op on Ascend is calculated as log(abs(x)).
@ -56,14 +52,8 @@ def log_by_step(input_x):
dtype = P.DType()
shape = P.Shape()
select = P.Select()
checktype = P.IsSubClass()
if checktype(dtype(input_x), mstype.int_):
input_x = cast(input_x, mstype.float32)
elif checktype(dtype(input_x), mstype.float_):
pass
else:
return None
input_x = cast(input_x, mstype.float32)
nan = fill(dtype(input_x), shape(input_x), np.nan)
inf = fill(dtype(input_x), shape(input_x), np.inf)
neg_x = less(input_x, 0.0)
@ -72,6 +62,7 @@ def log_by_step(input_x):
result = select(nonpos_x, -inf, log_x)
return select(neg_x, nan, result)
def log1p_by_step(x):
"""
Log1p ops on GPU device or when device_target == GPU.

View File

@ -14,15 +14,15 @@
# ============================================================================
"""Utitly functions to help distribution class."""
import numpy as np
from mindspore.ops import _utils as utils
from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
from mindspore import context
from mindspore._checkparam import Validator as validator
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import _utils as utils
from mindspore.ops import composite as C
from mindspore import context
from mindspore.ops import operations as P
from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
import mindspore.nn as nn
import mindspore.nn.probability as msp
@ -82,6 +82,24 @@ def convert_to_batch(t, batch_shape, required_type):
return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type)
def cast_type_for_device(dtype):
"""
use the alternative dtype supported by the device.
Args:
dtype (mindspore.dtype): input dtype.
Returns:
mindspore.dtype.
"""
if context.get_context("device_target") == "GPU":
if dtype in mstype.uint_type or dtype == mstype.int8:
return mstype.int16
if dtype == mstype.int64:
return mstype.int32
if dtype == mstype.float64:
return mstype.float32
return dtype
def check_scalar_from_param(params):
"""
Check if params are all scalars.
@ -293,10 +311,10 @@ def raise_not_impl_error(name):
def check_distribution_name(name, expected_name):
if name is None:
raise ValueError(
f"Distribution should be a constant which is not None.")
f"Input dist should be a constant which is not None.")
if name != expected_name:
raise ValueError(
f"Expected distribution name is {expected_name}, but got {name}.")
f"Expected dist input is {expected_name}, but got {name}.")
class CheckTuple(PrimitiveWithInfer):

View File

@ -16,9 +16,10 @@
from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device
from ._utils.utils import CheckTuple, CheckTensor
class Distribution(Cell):
"""
Base class for all mathematical distributions.
@ -43,12 +44,12 @@ class Distribution(Cell):
new distribution specified by the dist_spec_args. But it won't change the
original distribuion.
"""
def __init__(self,
seed,
dtype,
name,
param):
"""
Constructor of distribution class.
"""
@ -58,7 +59,7 @@ class Distribution(Cell):
self._name = name
self._seed = seed
self._dtype = dtype
self._dtype = cast_type_for_device(dtype)
self._parameters = {}
# parsing parameters
for k in param.keys():
@ -436,7 +437,6 @@ class Distribution(Cell):
"""
return self._sample(*args, **kwargs)
def construct(self, name, *args, **kwargs):
"""
Override construct in Cell.