fill down in functional

hw_hz 2023-07-14 17:06:41 +08:00
parent c5160d494c
commit e12f046c8a
51 changed files with 194 additions and 195 deletions
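The pattern repeated across the 51 files below: call sites that instantiate the P.Fill() primitive are rewritten to use the functional interface F.fill (or P.FillV2 where the fill value must be passed as a tensor). A minimal sketch of the two call styles, assuming a MindSpore build in which both interfaces are still available; the dtype, shape, and value are illustrative only:

# Old style: instantiate the Fill primitive, then call it with (dtype, shape, value).
# New style: call the functional API directly with the same arguments.
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F

ones_old = P.Fill()(mstype.float32, (2, 3), 1.0)   # primitive-based call being removed
ones_new = F.fill(mstype.float32, (2, 3), 1.0)     # functional replacement used in the hunks below

The functional form needs no primitive object created in each Cell's __init__, which is consistent with the self.fill = P.Fill() attributes being dropped from the constructors throughout this commit.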

View File

@ -284,7 +284,7 @@ class ZeroLikeFillZero : public AnfVisitor {
public:
ZeroLikeFillZero() {
py::gil_scoped_acquire gil;
PrimFill_ = prim::GetPythonOps("fill_", "mindspore.ops.functional")->cast<PrimitivePtr>();
PrimFill_ = prim::GetPythonOps("fill", "mindspore.ops.functional")->cast<PrimitivePtr>();
PrimShape_ = prim::GetPythonOps("shape_", "mindspore.ops.functional")->cast<PrimitivePtr>();
PrimDType_ = prim::GetPythonOps("dtype", "mindspore.ops.functional")->cast<PrimitivePtr>();
}
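The C++ optimizer pass above now resolves the Python name "fill" instead of "fill_" from mindspore.ops.functional, matching the removal of the fill_ alias later in this commit. A hedged one-line check that the looked-up name exists on the Python side (assuming an interactive MindSpore session):

from mindspore.ops import functional as F
assert callable(F.fill)   # "fill" is the name GetPythonOps() now resolves; "fill_" is gone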

View File

@ -83,17 +83,17 @@ class ImageGradients(Cell):
_check_input_4d(F.shape(images), "images", self.cls_name)
batch_size, depth, height, width = P.Shape()(images)
if height == 1:
dy = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
dy = F.fill(P.DType()(images), (batch_size, depth, 1, width), 0)
else:
dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
dy_last = F.fill(P.DType()(images), (batch_size, depth, 1, width), 0)
dy = P.Concat(2)((dy, dy_last))
if width == 1:
dx = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
dx = F.fill(P.DType()(images), (batch_size, depth, height, 1), 0)
else:
dx = images[:, :, :, 1:] - images[:, :, :, :width - 1]
dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
dx_last = F.fill(P.DType()(images), (batch_size, depth, height, 1), 0)
dx = P.Concat(3)((dx, dx_last))
return dy, dx

View File

@ -223,7 +223,6 @@ class LGamma(Cell):
self.abs = P.Abs()
self.shape = P.Shape()
self.dtype = P.DType()
self.fill = P.Fill()
self.floor = P.Floor()
self.equal = P.Equal()
self.greater = P.Greater()
@ -240,7 +239,7 @@ class LGamma(Cell):
if F.is_sequence_value_unknown(self.shape(x)):
infinity = self.ones_like(x) * F.cast(self.inf, input_dtype)
else:
infinity = self.fill(input_dtype, self.shape(x), self.inf)
infinity = F.fill(input_dtype, self.shape(x), self.inf)
need_to_reflect = self.less(x, 0.5)
neg_input = -x
@ -335,7 +334,6 @@ class DiGamma(Cell):
self.abs = P.Abs()
self.shape = P.Shape()
self.dtype = P.DType()
self.fill = P.Fill()
self.floor = P.Floor()
self.equal = P.Equal()
self.less = P.Less()
@ -371,7 +369,7 @@ class DiGamma(Cell):
reduced_input = x + self.abs(self.floor(x + 0.5))
reflection = y - self.pi * self.cos(self.pi * reduced_input) / self.sin(self.pi * reduced_input)
real_result = self.select(need_to_reflect, reflection, y)
nan = self.fill(self.dtype(x), self.shape(x), np.nan)
nan = F.fill(self.dtype(x), self.shape(x), np.nan)
return self.select(self.logicaland(self.less(x, 0), self.equal(x, self.floor(x))),
nan, real_result)
@ -391,7 +389,6 @@ def _igamma_series(ax, x, a, enabled):
logicaland = P.LogicalAnd()
greater = P.Greater()
fill = P.Fill()
shape = P.Shape()
dtype = P.DType()
select = P.Select()
@ -424,8 +421,8 @@ def _igamma_series(ax, x, a, enabled):
select(enabled, x, vals[4]), select(enabled, dc_da, vals[5]),
select(enabled, dans_da, vals[6]))
ones = fill(dtype(a), shape(a), 1)
zeros = fill(dtype(a), shape(a), 0)
ones = F.fill(dtype(a), shape(a), 1)
zeros = F.fill(dtype(a), shape(a), 0)
vals = (enabled, a, ones, ones, x, zeros, zeros)
vals = _while_helper_func(cond, body, vals)
@ -441,7 +438,6 @@ def _igammac_continued_fraction(ax, x, a, enabled):
greater = P.Greater()
less = P.Less()
notequal = P.NotEqual()
fill = P.Fill()
shape = P.Shape()
dtype = P.DType()
select = P.Select()
@ -482,7 +478,7 @@ def _igammac_continued_fraction(ax, x, a, enabled):
qk_is_nonzero = notequal(qk, 0)
r = pk / qk
t = select(qk_is_nonzero, abs_x((ans - r) / r), fill(dtype(t), shape(t), 1))
t = select(qk_is_nonzero, abs_x((ans - r) / r), F.fill(dtype(t), shape(t), 1))
ans = select(qk_is_nonzero, r, ans)
dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c
@ -490,7 +486,7 @@ def _igammac_continued_fraction(ax, x, a, enabled):
dans_da_new = select(qk_is_nonzero, (dpk_da - ans * dqk_da) / qk, dans_da)
grad_conditional = select(qk_is_nonzero,
abs_x(dans_da_new - dans_da),
fill(dtype(dans_da), shape(dans_da), 1))
F.fill(dtype(dans_da), shape(dans_da), 1))
pkm2 = pkm1
pkm1 = pk
@ -525,16 +521,16 @@ def _igammac_continued_fraction(ax, x, a, enabled):
y = 1 - a
z = x + y + 1
c = fill(dtype(x), shape(x), 0)
pkm2 = fill(dtype(x), shape(x), 1)
c = F.fill(dtype(x), shape(x), 0)
pkm2 = F.fill(dtype(x), shape(x), 1)
qkm2 = x
pkm1 = x + 1
qkm1 = z * x
ans = pkm1 / qkm1
t = fill(dtype(x), shape(x), 1)
dpkm2_da = fill(dtype(x), shape(x), 0)
dqkm2_da = fill(dtype(x), shape(x), 0)
dpkm1_da = fill(dtype(x), shape(x), 0)
t = F.fill(dtype(x), shape(x), 1)
dpkm2_da = F.fill(dtype(x), shape(x), 0)
dqkm2_da = F.fill(dtype(x), shape(x), 0)
dpkm1_da = F.fill(dtype(x), shape(x), 0)
dqkm1_da = -x
dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1
vals = (enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da)
@ -606,7 +602,6 @@ class IGamma(Cell):
self.exp = P.Exp()
self.select = P.Select()
self.zeroslike = P.ZerosLike()
self.fill = P.Fill()
self.shape = P.Shape()
self.dtype = P.DType()
self.lgamma = LGamma()
@ -633,7 +628,7 @@ class IGamma(Cell):
1 - _igammac_continued_fraction(ax, x, a, self.logicaland(enabled, use_igammac)),
_igamma_series(ax, x, a, self.logicaland(enabled, self.logicalnot(use_igammac))))
output = self.select(x_is_zero, self.zeroslike(output), output)
output = self.select(domain_error, self.fill(self.dtype(a), self.shape(a), np.nan), output)
output = self.select(domain_error, F.fill(self.dtype(a), self.shape(a), np.nan), output)
return output

View File

@ -220,7 +220,7 @@ class _ConstantPadNd(Cell):
output = ops.Pad(new_padding)(x)
mask = ops.Pad(new_padding)(mask)
ones = ops.OnesLike()(output)
value = ops.Fill()(output.dtype, output.shape, self.value)
value = ops.fill(output.dtype, output.shape, self.value)
output = ops.Add()(ops.Mul()(mask, output), ops.Mul()(ops.Sub()(ones, mask), value))
slice_op = ops.Slice()
begin, size = _get_begin_size(output.shape, start, end)

View File

@ -1529,7 +1529,6 @@ class MultiMarginLoss(LossBase):
"""Initialize MultiMarginLoss."""
super(MultiMarginLoss, self).__init__()
self.multi_margin_loss = MultiMarginLossOp(p=p, margin=margin, reduction=reduction)
self.generate_ones = ops.Fill()
self.weight = weight
def construct(self, x, target, weight=None):
@ -1541,7 +1540,7 @@ class MultiMarginLoss(LossBase):
if not weight_one:
_check_is_tensor('weight', weight, self.cls_name)
else:
weight = self.generate_ones(x.dtype, x.astype('float32')[0].shape, 1)
weight = F.fill(x.dtype, x.astype('float32')[0].shape, 1)
loss = self.multi_margin_loss(x, target, weight)
return loss

View File

@ -189,8 +189,8 @@ class Rprop(Optimizer):
self.prev = self._parameters.clone(prefix="prev", init='zeros')
self.step_size = self._parameters.clone(prefix="step_size", init='zeros')
self.fill = P.Fill()
self.sign = P.Sign()
self.fill = P.FillV2()
self.assign = P.Assign()
self.assignadd = P.AssignAdd()
self.cast = P.Cast()
@ -221,14 +221,26 @@ class Rprop(Optimizer):
param_fp32 = self.cast(param, mstype.float32)
sign = self.sign(gradient_fp32 * prev)
sign = self.select(sign > 0, self.fill(mstype.float32, sign.shape, self.etaplus), sign)
sign = self.select(sign < 0, self.fill(mstype.float32, sign.shape, self.etaminus), sign)
sign = self.select(sign == 0, self.fill(mstype.float32, sign.shape, 1.), sign)
sign = self.select(
sign > 0,
self.fill(sign.shape, self.cast(self.etaplus, mstype.float32)),
sign)
sign = self.select(
sign < 0,
self.fill(sign.shape, self.cast(self.etaminus,
mstype.float32)), sign)
sign = self.select(
sign == 0, self.fill(sign.shape,
self.cast(1., mstype.float32)), sign)
step_size_fp32 = ops.clip_by_value(step_size_fp32 * sign, self.step_size_min, self.step_size_max)
step_size_fp32 = ops.clip_by_value(step_size_fp32 * sign,
self.step_size_min,
self.step_size_max)
gradient_update = self.select(sign == self.etaminus, self.fill(mstype.float32, sign.shape, 0.),
gradient_fp32)
gradient_update = self.select(
sign == self.etaminus,
self.fill(sign.shape, self.cast(0., mstype.float32)),
gradient_fp32)
next_param = param_fp32 - self.sign(gradient_update) * step_size_fp32
self.assign(param, self.cast(next_param, param.dtype))
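Unlike most hunks in this commit, Rprop switches to the P.FillV2 primitive rather than F.fill, so the call signature changes: P.Fill takes (dtype, shape, scalar), while P.FillV2 takes (shape, value) with the value supplied as a tensor of the target dtype (hence the extra self.cast calls above). A sketch of the difference with illustrative values; the diff itself uses a Cast, while the sketch builds the value tensor directly:

from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

etaplus = 1.2  # illustrative scalar, standing in for the optimizer hyper-parameter
old = P.Fill()(mstype.float32, (4,), etaplus)             # removed form: (dtype, shape, scalar)
new = P.FillV2()((4,), Tensor(etaplus, mstype.float32))   # new form: (shape, value tensor)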

View File

@ -16,6 +16,7 @@
from mindspore import context
from mindspore.nn.cell import Cell
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor
@ -96,7 +97,6 @@ class Bijector(Cell):
self.cast_base = P.Cast()
self.dtype_base = P.DType()
self.shape_base = P.Shape()
self.fill_base = P.Fill()
self.sametypeshape_base = inner.SameTypeShape()
self.issubclass_base = inner.IsSubClass()
@ -140,13 +140,13 @@ class Bijector(Cell):
if self.issubclass_base(value_type, mstype.float_):
return value
return raise_type_error('input value of bijector', value_type, mstype.float_)
dtype_tensor = self.fill_base(self.dtype, self.shape_base(value), 0.0)
dtype_tensor = F.fill(self.dtype, self.shape_base(value), 0.0)
self.sametypeshape_base(value, dtype_tensor)
return value
def _shape_mapping(self, shape):
shape_tensor = self.fill_base(self.parameter_type, shape, 0.0)
dist_shape_tensor = self.fill_base(
shape_tensor = F.fill(self.parameter_type, shape, 0.0)
dist_shape_tensor = F.fill(
self.parameter_type, self.batch_shape, 0.0)
return (shape_tensor + dist_shape_tensor).shape

View File

@ -14,6 +14,7 @@
# ============================================================================
"""PowerTransform Bijector"""
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from ..distribution._utils.utils import check_greater_equal_zero
from ..distribution._utils.custom_ops import exp_generic, log_generic
from .bijector import Bijector
@ -68,10 +69,7 @@ class PowerTransform(Bijector):
>>> print(ans4.shape)
(3,)
"""
def __init__(self,
power=0.,
name='PowerTransform'):
def __init__(self, power=0., name='PowerTransform'):
param = dict(locals())
param['param_dict'] = {'power': power}
super(PowerTransform, self).__init__(name=name, param=param)
@ -84,7 +82,6 @@ class PowerTransform(Bijector):
self.equal_base = P.Equal()
self.exp = exp_generic
self.expm1 = P.Expm1()
self.fill = P.Fill()
self.log = log_generic
self.log1p = P.Log1p()
self.select_base = P.Select()
@ -116,17 +113,18 @@ class PowerTransform(Bijector):
power_local = self.cast_param_by_value(x, self.power)
# broadcast the value of x and power
ones = self.fill(self.dtypeop(power_local),
self.shape(x + power_local), 1.)
ones = F.fill(self.dtypeop(power_local), self.shape(x + power_local),
1.)
power_local = power_local * ones
x = x * ones
safe_power = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
ones,
power_local)
safe_power = self.select_base(
self.equal_base(power_local,
P.ZerosLike()(power_local)), ones, power_local)
forward_v = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
self.exp(x),
self.exp(self.log1p(x * safe_power) / safe_power))
forward_v = self.select_base(
self.equal_base(power_local,
P.ZerosLike()(power_local)), self.exp(x),
self.exp(self.log1p(x * safe_power) / safe_power))
return forward_v
def _inverse(self, y):
@ -137,17 +135,18 @@ class PowerTransform(Bijector):
power_local = self.cast_param_by_value(y, self.power)
# broadcast the value of x and power
ones = self.fill(self.dtypeop(power_local),
self.shape(y + power_local), 1.)
ones = F.fill(self.dtypeop(power_local), self.shape(y + power_local),
1.)
power_local = power_local * ones
y = y * ones
safe_power = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
ones,
power_local)
safe_power = self.select_base(
self.equal_base(power_local,
P.ZerosLike()(power_local)), ones, power_local)
inverse_v = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
self.log(y),
self.expm1(self.log(y) * safe_power) / safe_power)
inverse_v = self.select_base(
self.equal_base(power_local,
P.ZerosLike()(power_local)), self.log(y),
self.expm1(self.log(y) * safe_power) / safe_power)
return inverse_v
@ -167,14 +166,15 @@ class PowerTransform(Bijector):
power_local = self.cast_param_by_value(x, self.power)
# broadcast the value of x and power
ones = self.fill(self.dtypeop(power_local),
self.shape(x + power_local), 1.)
ones = F.fill(self.dtypeop(power_local), self.shape(x + power_local),
1.)
power_local = power_local * ones
x = x * ones
forward_log_j = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
x,
(1. / power_local - 1) * self.log1p(x * power_local))
forward_log_j = self.select_base(
self.equal_base(power_local,
P.ZerosLike()(power_local)), x,
(1. / power_local - 1) * self.log1p(x * power_local))
return forward_log_j

View File

@ -15,6 +15,7 @@
"""Softplus Bijector"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn.layer.activation import LogSigmoid
from ..distribution._utils.custom_ops import exp_generic, log_generic
from .bijector import Bijector
@ -84,7 +85,6 @@ class Softplus(Bijector):
self.abs = P.Abs()
self.dtypeop = P.DType()
self.cast = P.Cast()
self.fill = P.Fill()
self.greater = P.Greater()
self.less = P.Less()
self.log_sigmoid = LogSigmoid()
@ -103,7 +103,7 @@ class Softplus(Bijector):
too_large = self.greater(x, -self.threshold)
too_small_value = self.exp(x)
too_large_value = x
ones = self.fill(self.dtypeop(x), self.shape(x), 1.0)
ones = F.fill(self.dtypeop(x), self.shape(x), 1.0)
too_small_or_too_large = self.logicalor(too_small, too_large)
x = self.select(too_small_or_too_large, ones, x)
y = self.log(self.exp(x) + 1.0)
@ -119,7 +119,7 @@ class Softplus(Bijector):
too_large = self.greater(x, (-1) * self.threshold)
too_small_value = self.log(x)
too_large_value = x
ones = self.fill(self.dtypeop(x), self.shape(x), 1.0)
ones = F.fill(self.dtypeop(x), self.shape(x), 1.0)
too_small_or_too_large = self.logicalor(too_small, too_large)
x = self.select(too_small_or_too_large, ones, x)
y = x + self.log(self.abs(self.expm1((-1)*x)))

View File

@ -15,6 +15,7 @@
"""Utility functions to help distribution class."""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.primitive import constexpr
from mindspore.common import dtype as mstype
@ -52,7 +53,6 @@ def log_generic(input_x):
log = P.Log()
less = P.Less()
lessequal = P.LessEqual()
fill = P.Fill()
cast = P.Cast()
dtype = P.DType()
shape = P.Shape()
@ -61,8 +61,8 @@ def log_generic(input_x):
if not checktype(dtype(input_x), mstype.float_):
input_x = cast(input_x, mstype.float32)
nan = fill(dtype(input_x), shape(input_x), np.nan)
inf = fill(dtype(input_x), shape(input_x), np.inf)
nan = F.fill(dtype(input_x), shape(input_x), np.nan)
inf = F.fill(dtype(input_x), shape(input_x), np.inf)
neg_x = less(input_x, 0.0)
nonpos_x = lessequal(input_x, 0.0)
log_x = log(input_x)

View File

@ -15,6 +15,7 @@
"""Bernoulli Distribution"""
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import _checkparam as Validator
from .distribution import Distribution
@ -151,7 +152,6 @@ class Bernoulli(Distribution):
self.cast = P.Cast()
self.const = P.ScalarToTensor()
self.floor = P.Floor()
self.fill = P.Fill()
self.less = P.Less()
self.shape = P.Shape()
self.select = P.Select()
@ -200,8 +200,8 @@ class Bernoulli(Distribution):
MODE(B) = 1 if probs1 > 0.5 else 0
"""
probs1 = self._check_param_type(probs1)
zeros = self.fill(self.dtype, self.shape(probs1), 0.0)
ones = self.fill(self.dtype, self.shape(probs1), 1.0)
zeros = F.fill(self.dtype, self.shape(probs1), 0.0)
ones = F.fill(self.dtype, self.shape(probs1), 1.0)
comp = self.less(0.5, probs1)
return self.select(comp, ones, zeros)
@ -278,9 +278,9 @@ class Bernoulli(Distribution):
probs0 = self.broadcast((1.0 - probs1), broadcast_shape_tensor)
comp_zero = self.less(value, 0.0)
comp_one = self.less(value, 1.0)
zeros = self.fill(self.parameter_type, self.shape(
zeros = F.fill(self.parameter_type, self.shape(
broadcast_shape_tensor), 0.0)
ones = self.fill(self.parameter_type, self.shape(
ones = F.fill(self.parameter_type, self.shape(
broadcast_shape_tensor), 1.0)
less_than_zero = self.select(comp_zero, zeros, probs0)
return self.select(comp_one, less_than_zero, ones)

View File

@ -15,6 +15,7 @@
"""Beta Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
import mindspore.nn as nn
from mindspore import _checkparam as Validator
@ -186,7 +187,6 @@ class Beta(Distribution):
self.pow = P.Pow()
self.squeeze = P.Squeeze(0)
self.cast = P.Cast()
self.fill = P.Fill()
self.shape = P.Shape()
self.select = P.Select()
self.logicaland = P.LogicalAnd()
@ -266,7 +266,7 @@ class Beta(Distribution):
comp2 = self.greater(concentration0, 1.)
cond = self.logicaland(comp1, comp2)
batch_shape = self.shape(concentration1 + concentration0)
nan = self.fill(self.dtype, batch_shape, np.nan)
nan = F.fill(self.dtype, batch_shape, np.nan)
mode = (concentration1 - 1.) / (concentration1 + concentration0 - 2.)
return self.select(cond, mode, nan)
@ -379,7 +379,7 @@ class Beta(Distribution):
sample_shape = (1,)
else:
sample_shape = origin_shape
ones = self.fill(self.dtype, sample_shape, 1.0)
ones = F.fill(self.dtype, sample_shape, 1.0)
sample_gamma1 = C.gamma(
sample_shape, alpha=concentration1, beta=ones, seed=self.seed)
sample_gamma2 = C.gamma(

View File

@ -17,6 +17,7 @@ import numpy as np
from mindspore import context
from mindspore.common import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops.functional import stop_gradient
from mindspore.ops.operations import _inner_ops as inner
@ -149,7 +150,6 @@ class Categorical(Distribution):
self.dtypeop = P.DType()
self.exp = exp_generic
self.expand_dim = P.ExpandDims()
self.fill = P.Fill()
self.gather = P.GatherNd()
self.greater = P.Greater()
self.issubclass = inner.IsSubClass()
@ -292,7 +292,7 @@ class Categorical(Distribution):
# here we simulate casting to int but still keeping float dtype
value = self.cast(value, self.dtypeop(probs))
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
zeros = F.fill(self.dtypeop(value), self.shape(value), 0.0)
between_zero_neone = self.logicand(self.less(value, 0,),
self.greater(value, -1.))
value = self.select(between_zero_neone,
@ -338,8 +338,8 @@ class Categorical(Distribution):
# reshape into label shape N
logits_pmf = self.gather(self.reshape(
logits, (-1, num_classes)), index)
nan = self.fill(self.dtypeop(logits_pmf),
self.shape(logits_pmf), self.nan)
nan = F.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf),
self.nan)
logits_pmf = self.select(out_of_bound, nan, logits_pmf)
ans = self.reshape(logits_pmf, label_shape)
if drop_dim:
@ -359,7 +359,7 @@ class Categorical(Distribution):
value = self.cast(value, self.dtypeop(probs))
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
zeros = F.fill(self.dtypeop(value), self.shape(value), 0.0)
between_zero_neone = self.logicand(
self.less(value, 0,), self.greater(value, -1.))
value = self.select(between_zero_neone, zeros, P.Floor()(value))
@ -394,7 +394,7 @@ class Categorical(Distribution):
# reshape probs and fill less_than_zero places with 0
probs = self.reshape(probs, (-1, num_classes))
cdf = self.gather(self.cumsum(probs, 1), index)
zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
zeros = F.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
cdf = self.select(less_than_zero, zeros, cdf)
cdf = self.reshape(cdf, label_shape)
@ -425,7 +425,7 @@ class Categorical(Distribution):
sample_shape = (1,)
probs_2d = self.reshape(probs, (-1, num_classes))
sample_tensor = self.fill(self.dtype, shape, 1.0)
sample_tensor = F.fill(self.dtype, shape, 1.0)
sample_tensor = self.reshape(sample_tensor, (-1, 1))
num_sample = self.shape(sample_tensor)[0]
samples = C.multinomial(probs_2d, num_sample, seed=self.seed)

View File

@ -170,7 +170,6 @@ class Cauchy(Distribution):
self.const = P.ScalarToTensor()
self.dtypeop = P.DType()
self.exp = exp_generic
self.fill = P.Fill()
self.less = P.Less()
self.log = log_generic
self.log1p = log1p_generic

View File

@ -15,6 +15,7 @@
"""basic"""
from mindspore import context
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn.cell import Cell
from mindspore.ops.primitive import constexpr
from mindspore.ops.operations import _inner_ops as inner
@ -113,7 +114,6 @@ class Distribution(Cell):
# ops needed for the base class
self.cast_base = P.Cast()
self.dtype_base = P.DType()
self.fill_base = P.Fill()
self.sametypeshape_base = inner.SameTypeShape()
self.sq_base = P.Square()
self.sqrt_base = P.Sqrt()
@ -194,11 +194,11 @@ class Distribution(Cell):
if broadcast_shape is None:
broadcast_shape = self.shape_base(arg)
common_dtype = self.dtype_base(arg)
broadcast_shape_tensor = self.fill_base(
broadcast_shape_tensor = F.fill(
common_dtype, broadcast_shape, 1.0)
else:
broadcast_shape = self.shape_base(arg + broadcast_shape_tensor)
broadcast_shape_tensor = self.fill_base(
broadcast_shape_tensor = F.fill(
common_dtype, broadcast_shape, 1.0)
arg = self.broadcast(arg, broadcast_shape_tensor)
# check if the arguments have the same dtype

View File

@ -161,7 +161,6 @@ class Exponential(Distribution):
self.cast = P.Cast()
self.const = P.ScalarToTensor()
self.dtypeop = P.DType()
self.fill = P.Fill()
self.less = P.Less()
self.select = P.Select()
self.shape = P.Shape()
@ -209,7 +208,7 @@ class Exponential(Distribution):
MODE(EXP) = 0.
"""
rate = self._check_param_type(rate)
return self.fill(self.dtype, self.shape(rate), 0.)
return F.fill(self.dtype, self.shape(rate), 0.)
def _sd(self, rate=None):
r"""
@ -258,8 +257,8 @@ class Exponential(Distribution):
value = self.cast(value, self.dtype)
rate = self._check_param_type(rate)
prob = self.log(rate) - rate * value
zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
neginf = self.fill(self.dtypeop(prob), self.shape(prob), -np.inf)
zeros = F.fill(self.dtypeop(prob), self.shape(prob), 0.0)
neginf = F.fill(self.dtypeop(prob), self.shape(prob), -np.inf)
comp = self.less(value, zeros)
return self.select(comp, neginf, prob)
@ -281,7 +280,7 @@ class Exponential(Distribution):
value = self.cast(value, self.dtype)
rate = self._check_param_type(rate)
cdf = 1.0 - self.exp(-1. * rate * value)
zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
zeros = F.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
comp = self.less(value, zeros)
return self.select(comp, zeros, cdf)
@ -303,7 +302,7 @@ class Exponential(Distribution):
value = self.cast(value, self.dtype)
rate = self._check_param_type(rate)
sf = -1. * rate * value
zeros = self.fill(self.dtypeop(sf), self.shape(sf), 0.0)
zeros = F.fill(self.dtypeop(sf), self.shape(sf), 0.0)
comp = self.less(value, zeros)
return self.select(comp, zeros, sf)

View File

@ -15,6 +15,7 @@
"""Gamma Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
import mindspore.nn as nn
from mindspore import _checkparam as Validator
@ -185,7 +186,6 @@ class Gamma(Distribution):
self.squeeze = P.Squeeze(0)
self.cast = P.Cast()
self.dtypeop = P.DType()
self.fill = P.Fill()
self.shape = P.Shape()
self.select = P.Select()
self.greater = P.Greater()
@ -265,8 +265,8 @@ class Gamma(Distribution):
"""
concentration, rate = self._check_param_type(concentration, rate)
mode = (concentration - 1.) / rate
nan = self.fill(self.dtypeop(concentration),
self.shape(concentration), np.nan)
nan = F.fill(self.dtypeop(concentration), self.shape(concentration),
np.nan)
comp = self.greater(concentration, 1.)
return self.select(comp, mode, nan)

View File

@ -15,6 +15,7 @@
"""Geometric Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops import composite as C
from mindspore import _checkparam as Validator
@ -160,7 +161,6 @@ class Geometric(Distribution):
self.cast = P.Cast()
self.const = P.ScalarToTensor()
self.dtypeop = P.DType()
self.fill = P.Fill()
self.floor = P.Floor()
self.issubclass = inner.IsSubClass()
self.less = P.Less()
@ -212,7 +212,7 @@ class Geometric(Distribution):
MODE(Geo) = 0
"""
probs1 = self._check_param_type(probs1)
return self.fill(self.dtype, self.shape(probs1), 0.)
return F.fill(self.dtype, self.shape(probs1), 0.)
def _var(self, probs1=None):
r"""
@ -260,7 +260,7 @@ class Geometric(Distribution):
value = self.floor(value)
probs1 = self._check_param_type(probs1)
pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
zeros = self.fill(self.dtypeop(pmf), self.shape(pmf), 0.0)
zeros = F.fill(self.dtypeop(pmf), self.shape(pmf), 0.0)
comp = self.less(value, zeros)
return self.select(comp, zeros, pmf)
@ -283,7 +283,7 @@ class Geometric(Distribution):
probs1 = self._check_param_type(probs1)
probs0 = 1.0 - probs1
cdf = 1.0 - self.pow(probs0, value + 1.0)
zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
zeros = F.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
comp = self.less(value, zeros)
return self.select(comp, zeros, cdf)

View File

@ -15,6 +15,7 @@
"""Gumbel Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import _checkparam as Validator
from mindspore.common import dtype as mstype
import mindspore.nn.probability.bijector as msb
@ -101,7 +102,6 @@ class Gumbel(TransformedDistribution):
self.const = P.ScalarToTensor()
self.exp = exp_generic
self.expm1 = P.Expm1()
self.fill = P.Fill()
self.lgamma = P.Lgamma()
self.log = log_generic
self.shape = P.Shape()
@ -163,7 +163,7 @@ class Gumbel(TransformedDistribution):
"""
The mode of the distribution.
"""
return self.loc * self.fill(self.parameter_type, self.shape(self.scale), 1.0)
return self.loc * F.fill(self.parameter_type, self.shape(self.scale), 1.0)
def _sd(self):
r"""
@ -173,7 +173,7 @@ class Gumbel(TransformedDistribution):
STD(X) = \frac{\pi}{\sqrt{6}} * scale
"""
scale = self.scale * \
self.fill(self.parameter_type, self.broadcast_shape, 1.0)
F.fill(self.parameter_type, self.broadcast_shape, 1.0)
return scale * np.pi / self.sqrt(self.const(6., mstype.float32))
def _entropy(self):
@ -184,7 +184,7 @@ class Gumbel(TransformedDistribution):
H(X) = 1. + \log(scale) + Euler-Mascheroni_constant
"""
scale = self.scale * \
self.fill(self.parameter_type, self.broadcast_shape, 1.0)
F.fill(self.parameter_type, self.broadcast_shape, 1.0)
return 1. + self.log(scale) + np.euler_gamma
def _log_prob(self, value):

View File

@ -15,6 +15,7 @@
"""LogNormal Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
import mindspore.nn.probability.bijector as msb
import mindspore.nn.probability.distribution as msd
@ -101,7 +102,6 @@ class LogNormal(msd.TransformedDistribution):
self.expm1 = P.Expm1()
self.log = log_generic
self.erf = P.Erf()
self.fill = P.Fill()
self.greater = P.Greater()
self.select = P.Select()
self.shape = P.Shape()
@ -202,7 +202,7 @@ class LogNormal(msd.TransformedDistribution):
cdf = self.distribution("cdf", inverse_value, mean, sd)
# to increase numerical stability, set cdf = 0 when value <= 0
zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
zeros = F.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
return self.select(self.greater(value, 0.), cdf, zeros)

View File

@ -15,6 +15,7 @@
"""Logistic Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import _checkparam as Validator
from mindspore.common import dtype as mstype
@ -153,7 +154,6 @@ class Logistic(Distribution):
self.dtypeop = P.DType()
self.exp = exp_generic
self.expm1 = P.Expm1()
self.fill = P.Fill()
self.less = P.Less()
self.log = log_generic
self.log1p = P.Log1p()
@ -179,7 +179,7 @@ class Logistic(Distribution):
too_small_value = self.exp(x)
too_large_value = x
too_small_or_too_large = self.logicalor(too_small, too_large)
ones = self.fill(self.dtypeop(x), self.shape(x), 1.0)
ones = F.fill(self.dtypeop(x), self.shape(x), 1.0)
x = self.select(too_small_or_too_large, ones, x)
y = self.log(self.exp(x) + 1.0)
return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y))

View File

@ -15,6 +15,7 @@
"""Poisson Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import _checkparam as Validator
from mindspore.common import dtype as mstype
@ -149,7 +150,6 @@ class Poisson(Distribution):
self.floor = P.Floor()
self.dtypeop = P.DType()
self.shape = P.Shape()
self.fill = P.Fill()
self.less = P.Less()
self.equal = P.Equal()
self.select = P.Select()
@ -228,8 +228,8 @@ class Poisson(Distribution):
value = self.cast(value, self.dtype)
rate = self._check_param_type(rate)
log_rate = self.log(rate)
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
inf = self.fill(self.dtypeop(value), self.shape(value), np.inf)
zeros = F.fill(self.dtypeop(value), self.shape(value), 0.0)
inf = F.fill(self.dtypeop(value), self.shape(value), np.inf)
safe_x = self.select(self.less(value, zeros), zeros, value)
y = log_rate * safe_x - self.lgamma(safe_x + 1.)
comp = self.equal(value, safe_x)
@ -254,7 +254,7 @@ class Poisson(Distribution):
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
rate = self._check_param_type(rate)
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
zeros = F.fill(self.dtypeop(value), self.shape(value), 0.0)
comp = self.less(value, zeros)
safe_x = self.select(comp, zeros, value)
cdf = 1. - self.igamma(1. + safe_x, rate)

View File

@ -16,6 +16,7 @@
import numpy as np
from mindspore import _checkparam as validator
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
import mindspore.nn as nn
from .distribution import Distribution
@ -125,7 +126,6 @@ class TransformedDistribution(Distribution):
self.cast_base = P.Cast()
self.equal_base = P.Equal()
self.select_base = P.Select()
self.fill_base = P.Fill()
# broadcast bijector batch_shape and distribution batch_shape
self._broadcast_shape = self._broadcast_bijector_dist()
@ -176,9 +176,9 @@ class TransformedDistribution(Distribution):
"""
if self.batch_shape is None or self.bijector.batch_shape is None:
return None
bijector_shape_tensor = self.fill_base(
bijector_shape_tensor = F.fill(
self.dtype, self.bijector.batch_shape, 0.0)
dist_shape_tensor = self.fill_base(self.dtype, self.batch_shape, 0.0)
dist_shape_tensor = F.fill(self.dtype, self.batch_shape, 0.0)
return (bijector_shape_tensor + dist_shape_tensor).shape
def _cdf(self, value, *args, **kwargs):

View File

@ -14,6 +14,7 @@
# ============================================================================
"""Uniform Distribution"""
import numpy as np
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore import _checkparam as Validator
@ -170,7 +171,6 @@ class Uniform(Distribution):
self.cast = P.Cast()
self.const = P.ScalarToTensor()
self.dtypeop = P.DType()
self.fill = P.Fill()
self.less = P.Less()
self.lessequal = P.LessEqual()
self.logicaland = P.LogicalAnd()
@ -287,10 +287,10 @@ class Uniform(Distribution):
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
low, high = self._check_param_type(low, high)
neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
neg_ones = F.fill(self.dtype, self.shape(value), -1.0)
prob = self.exp(neg_ones * self.log(high - low))
broadcast_shape = self.shape(prob)
zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
zeros = F.fill(self.dtypeop(prob), broadcast_shape, 0.0)
comp_lo = self.less(value, low)
comp_hi = self.lessequal(value, high)
less_than_low = self.select(comp_lo, zeros, prob)
@ -316,7 +316,7 @@ class Uniform(Distribution):
kl = self.log(high_b - low_b) - self.log(high_a - low_a)
comp = self.logicaland(self.lessequal(
low_b, low_a), self.lessequal(high_a, high_b))
inf = self.fill(self.dtypeop(kl), self.shape(kl), np.inf)
inf = F.fill(self.dtypeop(kl), self.shape(kl), np.inf)
return self.select(comp, kl, inf)
def _cdf(self, value, low=None, high=None):
@ -338,8 +338,8 @@ class Uniform(Distribution):
low, high = self._check_param_type(low, high)
prob = (value - low) / (high - low)
broadcast_shape = self.shape(prob)
zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
ones = self.fill(self.dtypeop(prob), broadcast_shape, 1.0)
zeros = F.fill(self.dtypeop(prob), broadcast_shape, 0.0)
ones = F.fill(self.dtypeop(prob), broadcast_shape, 1.0)
comp_lo = self.less(value, low)
comp_hi = self.less(value, high)
less_than_low = self.select(comp_lo, zeros, prob)

View File

@ -690,7 +690,7 @@ class _TrainPipelineAccuStepCell(TrainOneStepCell):
if not self.sense_flag:
return self._no_sens_impl(*inputs)
loss = self.network(*inputs)
sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
sens = ops.fill(ops.DType()(loss), ops.Shape()(loss), self.sens)
grads = self.grad(self.network, self.weights)(*inputs, sens)
accu_grads = ops.depend(self.accu_grads, grads)
if self.opt_shard:

View File

@ -356,7 +356,7 @@ class DistributedGradReducer(Cell):
... def construct(self, *args):
... weights = self.weights
... loss = self.network(*args)
... sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
... sens = F.fill(ops.DType()(loss), ops.Shape()(loss), self.sens)
... grads = self.grad(self.network, weights)(*args, sens)
... if self.reducer_flag:
... # apply grad reducer on grads

View File

@ -398,7 +398,6 @@ def get_bprop_extract_volume_patches(self):
expend_dims = P.ExpandDims()
scatter_nd = P.ScatterNd()
slice_op = P.Slice()
fill = P.Fill()
dtype = P.DType()
cast = P.Cast()
matmul = P.MatMul()
@ -466,7 +465,7 @@ def get_bprop_extract_volume_patches(self):
idx_tensor = concat((expend_dims(x_idx_patched, -1), expend_dims(out_idx, -1)))
idx_map = P.Reshape()(idx_tensor, (-1, 2))
sp_shape = (x_indices_num, out_indices_num)
sp_mat_full = scatter_nd(idx_map, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
sp_mat_full = scatter_nd(idx_map, F.fill(dtype(dout), (out_indices_num,), 1), sp_shape)
sp_tensor = slice_op(sp_mat_full, (1, 0), (x_indices_num - 1, out_indices_num))
grad = P.Transpose()(dout, (0, 2, 3, 4, 1))

View File

@ -155,7 +155,7 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
_check(indices_shape)
indices_len = len(indices_shape)
if indices_len == 1:
prefix = P.Range()(Tensor(0, indices_dtype), P.Fill()(
prefix = P.Range()(Tensor(0, indices_dtype), F.fill(
indices_dtype, (), axis_size), Tensor(1, indices_dtype))
return prefix

View File

@ -398,10 +398,16 @@ def slice2indices(input_slice, shape):
return False
ndim = len(shape)
mesh = list()
grids = [P.Range()(P.Fill()(mstype.int64, (), start), P.Fill()(
mstype.int64, (), stop), P.Fill()(mstype.int64, (), step))]
grids += [P.Range()(Tensor(0, mstype.int64), P.Fill()(mstype.int64, (), dim_size),
Tensor(1, mstype.int64)) for dim_size in shape[1:]]
range_op = P.Range()
cast_op = P.Cast()
grids = [
range_op(cast_op(start, mstype.int64), cast_op(stop, mstype.int64),
cast_op(step, mstype.int64))
]
grids += [
range_op(Tensor(0, mstype.int64), cast_op(dim_size, mstype.int64),
Tensor(1, mstype.int64)) for dim_size in shape[1:]
]
for j, grid in enumerate(grids):
mesh.append(P.Reshape()(grid, tuple(
[grid.size if j == t else 1 for t in range(ndim)])))
@ -543,7 +549,7 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty
updates_shape = indices_shape + data_shape[1:]
else:
updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:]
return P.Fill()(data_dtype, updates_shape, value)
return P.FillV2()(updates_shape, P.Cast()(value, data_dtype))
def generate_updates_shape(data_shape, index_shape, op_type, is_dynamic):
@ -865,11 +871,11 @@ def int_to_index(i, shape):
_check(i, dim_size)
i = (i + dim_size) % dim_size
if len(shape) == 1:
return P.Fill()(mstype.int64, (1, 1), i)
return P.FillV2()((1, 1), P.Cast()(i, mstype.int64))
mesh = list()
ndim = len(shape) - 1
for j, size in enumerate(shape[1:]):
grid = P.Range()(Tensor(0, mstype.int64), P.Fill()(mstype.int64, (), size), Tensor(1, mstype.int64))
grid = P.Range()(Tensor(0, mstype.int64), P.Cast()(size, mstype.int64), Tensor(1, mstype.int64))
mesh.append(P.Reshape()(grid, tuple([size if j == t else 1 for t in range(ndim)])))
shapes = map(P.Shape(), mesh)
out_shape = infer_out_shape(*shapes)
@ -877,7 +883,8 @@ def int_to_index(i, shape):
for arr in mesh:
mesh_arrays.append(P.BroadcastTo(out_shape)(arr))
index = P.Stack(-1)(mesh_arrays)
return P.Concat(-1)((P.Fill()(mstype.int64, P.Shape()(index)[:-1] + (1,), i), index))
return P.Concat(-1)((P.FillV2()(P.Shape()(index)[:-1] + (1,),
P.Cast()(i, mstype.int64)), index))
@constexpr
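Two replacements recur in the tensor-index helpers above: a 0-d int64 tensor fed to P.Range() is now built with P.Cast() instead of P.Fill() with an empty shape, and small constant index tensors are built with P.FillV2(). A sketch with illustrative start/stop/step and index values:

from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

start, stop, step = 0, 5, 1
# old: P.Fill()(mstype.int64, (), stop) produced the 0-d limit tensor for Range
grid = P.Range()(Tensor(start, mstype.int64),
                 P.Cast()(stop, mstype.int64),
                 Tensor(step, mstype.int64))

i = 3
# old: P.Fill()(mstype.int64, (1, 1), i)
index = P.FillV2()((1, 1), P.Cast()(i, mstype.int64))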

View File

@ -34,7 +34,6 @@ from .array_func import (
matrix_band_part,
padding,
fill,
fill_,
full,
full_like,
chunk,

View File

@ -65,8 +65,6 @@ from mindspore.ops._utils.utils import ms_arrange
tuple_to_tensor_ = TupleToTensor()
eye_ = P.Eye()
fills_ = Fills()
fill_ = P.Fill()
fillv2_ = P.FillV2()
ones_ = P.Ones()
ones_like_ = P.OnesLike()
tile_ = P.Tile()
@ -746,7 +744,7 @@ def fill(type, shape, value): # pylint: disable=redefined-outer-name
if isinstance(shape, tuple):
shape = tuple_to_tensor_(shape, mstype.int32)
value = cast_(value, type)
return fillv2_(shape, value)
return _get_cache_prim(P.FillV2)()(shape, value)
def full(size, fill_value, *, dtype=None): # pylint: disable=redefined-outer-name
@ -791,7 +789,7 @@ def full(size, fill_value, *, dtype=None): # pylint: disable=redefined-outer-nam
raise TypeError(f"For 'ops.full', 'dtype' must be mindspore.type, but got {dtype}.")
if isinstance(size, list):
size = tuple(size)
return fill_(dtype, size, fill_value)
return F.fill(dtype, size, fill_value)
def full_like(input, fill_value, *, dtype=None):
@ -6893,18 +6891,17 @@ def diagonal(input, offset=0, dim1=0, dim2=1):
x_shape = input.shape
n, m = x_shape[-2:]
fill_op = _get_cache_prim(P.Fill)()
e = _get_cache_prim(P.Eye)()(n, m, dtype)
if offset >= m or offset <= -n:
e = fill_op(dtype, (n, m), 0)
e = F.fill(dtype, (n, m), 0)
elif offset != 0:
e = e.astype(mstype.float32)
if offset > 0:
e_left = fill_op(mstype.float32, (n, offset), 0)
e_left = F.fill(mstype.float32, (n, offset), 0)
e_right = e[..., 0:m - offset:1]
e = _get_cache_prim(P.Concat)(1)((e_left, e_right)).astype(dtype)
elif offset < 0:
e_upper = fill_op(mstype.float32, (-offset, m), 0)
e_upper = F.fill(mstype.float32, (-offset, m), 0)
e_lower = e[0:n + offset:1, ...]
e = _get_cache_prim(P.Concat)(0)((e_upper, e_lower)).astype(dtype)
e = F.broadcast_to(e, x_shape)
@ -7690,7 +7687,6 @@ __all__ = [
'matrix_band_part',
'padding',
'fill',
'fill_',
'fills',
'tile',
'size',
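The functional fill defined in this file keeps its (type, shape, value) signature but now dispatches to FillV2 through _get_cache_prim, and full() delegates to it; the separate fill_ primitive alias is dropped from the exports. A short usage sketch with illustrative arguments:

from mindspore import ops
from mindspore.common import dtype as mstype

x = ops.fill(mstype.float32, (2, 3), 0.5)        # 2x3 tensor of 0.5, backed by FillV2
y = ops.full((2, 3), 0.5, dtype=mstype.float32)  # full() now routes through fill()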

View File

@ -218,7 +218,6 @@ tensor_operator_registry.register('logdet', logdet)
tensor_operator_registry.register('log_matrix_determinant', log_matrix_determinant)
tensor_operator_registry.register('matrix_determinant', matrix_determinant)
tensor_operator_registry.register('ceil', P.Ceil)
tensor_operator_registry.register('fill', P.Fill)
tensor_operator_registry.register('fillv2', P.FillV2)
tensor_operator_registry.register('tile', P.Tile)
tensor_operator_registry.register('logit', logit)

View File

@ -322,7 +322,6 @@ class BertAttentionRelativePositionValues(nn.Cell):
max_relative_position=16,
initializer_range=initializer_range,
use_one_hot_embeddings=use_one_hot_embeddings)
self.fill = P.Fill()
self.multiply = P.Mul()
self.type = P.DType()
self.cast = P.Cast()
@ -358,7 +357,7 @@ class BertAttentionRelativePositionValues(nn.Cell):
context_layer = self.transpose(input_tensor, self.trans_shape)
context_layer = self.reshape(context_layer, self.shp_return)
# GE reshape should not be the return value directly; an operator is needed here
ones = self.cast(self.fill((1, 1), 1), self.type(context_layer))
ones = self.cast(F.fill((1, 1), 1), self.type(context_layer))
context_layer = self.multiply(context_layer, ones)
return relations_values_embedding, context_layer

View File

@ -32,7 +32,6 @@ class NpuFloatNet(nn.Cell):
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_status = P.NPUClearFloatStatus()
self.fill = P.Fill()
self.shape_op = P.Shape()
self.select = P.Select()
self.less = P.Less()
@ -52,8 +51,8 @@ class NpuFloatNet(nn.Cell):
# let reduce_sum depend on get_status
init = F.depend(init, get_status)
flag_sum = self.reduce_sum(init, (0,))
base = self.cast(self.fill(self.dtype(
res), self.shape_op(res), 0.0), self.dtype(flag_sum))
base = self.cast(F.fill(self.dtype(res), self.shape_op(res), 0.0),
self.dtype(flag_sum))
cond = self.less(base, flag_sum)
out = self.select(cond, self.cast(base, self.dtype(res)), res)
return out

View File

@ -794,9 +794,9 @@ class MultiTaskTrainOneStepCell(nn.Cell):
def construct(self, *inputs):
weights = self.weights
(loss, aloss, closs) = self.network(*inputs)
sens = (P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens),
P.Fill()(P.DType()(aloss), P.Shape()(aloss), 0.0),
P.Fill()(P.DType()(closs), P.Shape()(closs), 0.0))
sens = (F.fill(P.DType()(loss), P.Shape()(loss), self.sens),
F.fill(P.DType()(aloss), P.Shape()(aloss), 0.0),
F.fill(P.DType()(closs), P.Shape()(closs), 0.0))
grads = self.grad(self.network, weights)(*inputs, sens)
grads = self.grad_reducer(grads)
return (F.depend(loss, self.optimizer(grads)), aloss, closs)

View File

@ -20,6 +20,7 @@ from sklearn.metrics import roc_auc_score
import mindspore.common.dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn import Dropout
from mindspore.nn.optim import Adam
from mindspore.train import Metric
@ -315,7 +316,7 @@ class TrainStepWrap(nn.Cell):
def construct(self, batch_ids, batch_wts, label):
weights = self.weights
loss = self.network(batch_ids, batch_wts, label)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) #
sens = F.fill(P.DType()(loss), P.Shape()(loss), self.sens) #
grads = self.grad(self.network, weights)(batch_ids, batch_wts, label, sens)
if self.reducer_flag:
# apply grad reducer on grads

View File

@ -19,6 +19,7 @@ from mindspore import Parameter, ParameterTuple
import mindspore.common.dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn.optim import Adam, FTRL
from mindspore.common.initializer import Uniform, initializer
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
@ -302,8 +303,8 @@ class TrainStepWrap(nn.Cell):
weights_w = self.weights_w
weights_d = self.weights_d
loss_w, loss_d = self.network(batch_ids, batch_wts, label)
sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
sens_w = F.fill(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
sens_d = F.fill(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
grads_w = self.grad_w(self.loss_net_w, weights_w)(batch_ids, batch_wts,
label, sens_w)
grads_d = self.grad_d(self.loss_net_d, weights_d)(batch_ids, batch_wts,

View File

@ -530,7 +530,7 @@ class YoloLossBlock(nn.Cell):
true_xy = y_true[:, :, :, :, :2] * grid_shape - grid
true_wh = y_true[:, :, :, :, 2:4]
true_wh = P.Select()(P.Equal()(true_wh, 0.0),
P.Fill()(P.DType()(true_wh), P.Shape()(true_wh), 1.0),
F.fill(P.DType()(true_wh), P.Shape()(true_wh), 1.0),
true_wh)
true_wh = P.Log()(true_wh / self.anchors * self.input_shape)
box_loss_scale = 2 - y_true[:, :, :, :, 2:3] * y_true[:, :, :, :, 3:4]
@ -666,7 +666,7 @@ class TrainingWrapper(nn.Cell):
def construct(self, *args):
weights = self.weights
loss = self.network(*args)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
sens = F.fill(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*args, sens)
if self.reducer_flag:
# apply grad reducer on grads

View File

@ -314,9 +314,8 @@ class YoloLossBlock(nn.Cell):
true_xy = y_true[:, :, :, :, :2] * grid_shape - grid
true_wh = y_true[:, :, :, :, 2:4]
true_wh = P.Select()(P.Equal()(true_wh, 0.0),
P.Fill()(P.DType()(true_wh),
P.Shape()(true_wh), 1.0),
true_wh)
F.fill(P.DType()(true_wh),
P.Shape()(true_wh), 1.0), true_wh)
true_wh = P.Log()(true_wh / self.anchors * input_shape)
# 2 - w*h: large objects get a smaller scale, since small objects need more precision
box_loss_scale = 2 - y_true[:, :, :, :, 2:3] * y_true[:, :, :, :, 3:4]
@ -432,7 +431,7 @@ class TrainingWrapper(nn.Cell):
def construct(self, *args):
weights = self.weights
loss = self.network(*args)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
sens = F.fill(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*args, sens)
if self.reducer_flag:
grads = self.grad_reducer(grads)

View File

@ -36,7 +36,6 @@ class OhemLoss(nn.Cell):
self.not_equal = P.NotEqual()
self.equal = P.Equal()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.fill = P.Fill()
self.transpose = P.Transpose()
self.ignore_label = ignore_label
self.loss_weight = 1.0
@ -51,13 +50,13 @@ class OhemLoss(nn.Cell):
weights = self.cast(self.not_equal(labels, self.ignore_label), mstype.float32) * self.loss_weight
weighted_losses = self.mul(losses, weights)
loss = self.reduce_sum(weighted_losses, (0,))
zeros = self.fill(mstype.float32, self.shape(weights), 0.0)
ones = self.fill(mstype.float32, self.shape(weights), 1.0)
zeros = F.fill(mstype.float32, self.shape(weights), 0.0)
ones = F.fill(mstype.float32, self.shape(weights), 1.0)
present = self.select(self.equal(weights, zeros), zeros, ones)
present = self.reduce_sum(present, (0,))
zeros = self.fill(mstype.float32, self.shape(present), 0.0)
min_control = self.fill(mstype.float32, self.shape(present), 1.0)
zeros = F.fill(mstype.float32, self.shape(present), 0.0)
min_control = F.fill(mstype.float32, self.shape(present), 1.0)
present = self.select(self.equal(present, zeros), min_control, present)
loss = loss / present
return loss

View File

@ -88,7 +88,7 @@ class TrainForwardBackward(Cell):
def construct(self, *inputs):
weights = self.weights
loss = self.network(*inputs)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
sens = F.fill(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*inputs, sens)
self.hyper_map(F.partial(_sum_op), self.grad_sum, grads)
return loss

View File

@ -59,7 +59,6 @@ class LambGPUOrigin(nn.Cell):
self.op_norm = layer.Norm()
self.op_select = P.Select()
self.op_greater = P.Greater()
self.op_fill = P.Fill()
self.op_dtype = P.DType()
def construct(self, beta1, beta2, eps, global_step, lr, weight_decay, decay_flag):
@ -82,12 +81,12 @@ class LambGPUOrigin(nn.Cell):
g_norm_hat = self.op_norm(self.op_mul(next_mm, self.op_rsqrt(next_vv + eps)) + weight_decay * param_fp32)
zeros = F.zeros_like(w_norm)
ones = self.op_fill(self.op_dtype(w_norm), self.op_shape(w_norm), 1.0)
ones = F.fill(self.op_dtype(w_norm), self.op_shape(w_norm), 1.0)
trust_ratio = self.op_select(
self.op_greater(w_norm, zeros),
self.op_select(self.op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
ones)
tens = self.op_fill(self.op_dtype(trust_ratio), self.op_shape(trust_ratio), 10.0)
tens = F.fill(self.op_dtype(trust_ratio), self.op_shape(trust_ratio), 10.0)
trust_ratio = ops.clip_by_value(trust_ratio, zeros, tens)
update = next_mm / (self.op_sqrt(next_vv) + eps)

View File

@ -21,7 +21,7 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
context.set_context(mode=context.GRAPH_MODE)
@ -46,13 +46,12 @@ def test_dtype_and_shape_as_attr_to_new_tensor():
class Net(nn.Cell):
def __init__(self, value):
super(Net, self).__init__()
self.fill = P.Fill()
self.value = value
def construct(self, x):
dtype = x.dtype
shape = x.shape
y = self.fill(dtype, shape, self.value)
y = F.fill(dtype, shape, self.value)
return y
net = Net(2.2)

View File

@ -19,7 +19,6 @@ import mindspore
from mindspore import nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P
class Layer1(nn.Cell):
def __init__(self):
@ -61,7 +60,6 @@ class SwitchNet(nn.Cell):
self.layer2 = Layer2()
self.layer3 = Layer3()
self.layers = (self.layer1, self.layer2, self.layer3)
self.fill = P.Fill()
def construct(self, x, index):
y = self.layers[index](x)
@ -75,7 +73,6 @@ class MySwitchNet(nn.Cell):
self.layer2 = Layer2()
self.layer3 = Layer3()
self.layers = (self.layer1, self.layer2, self.layer3)
self.fill = P.Fill()
def construct(self, x, index):
y = self.layers[0](x)
@ -99,7 +96,6 @@ class MySwitchNetPynative(nn.Cell):
self.layer2 = Layer2()
self.layer3 = Layer3()
self.layers = (self.layer1, self.layer2, self.layer3)
self.fill = P.Fill()
def construct(self, x, index):
return self.layers[index](x)

View File

@ -284,7 +284,6 @@ class NpuFloatNet(nn.Cell):
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_status = P.NPUClearFloatStatus()
self.fill = P.Fill()
self.shape_op = P.Shape()
self.select = P.Select()
self.less = P.Less()
@ -303,7 +302,7 @@ class NpuFloatNet(nn.Cell):
get_status = self.get_status(init)
init = F.depend(init, get_status) # let reduce_sum depend on get_status
flag_sum = self.reduce_sum(init, (0,))
base = self.cast(self.fill(self.dtype(res), self.shape_op(res), 0.0), self.dtype(flag_sum))
base = self.cast(F.fill(self.dtype(res), self.shape_op(res), 0.0), self.dtype(flag_sum))
cond = self.less(base, flag_sum)
out = self.select(cond, self.cast(base, self.dtype(res)), res)
return out
@ -314,11 +313,10 @@ class DiagNet(nn.Cell):
def __init__(self):
super(DiagNet, self).__init__()
self.fill = P.Fill()
self.diag = P.Diag()
def construct(self, x):
return x - self.diag(self.fill(mstype.float32, (3,), 1.0))
return x - self.diag(F.fill(mstype.float32, (3,), 1.0))
class FmaxFunc(nn.Cell):

View File

@ -23,6 +23,7 @@ from mindspore.common.api import _cell_graph_executor
from mindspore.nn.optim import Adam, FTRL
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.parallel._utils import _reset_op_id as reset_op_id
@ -99,8 +100,8 @@ class TrainStepWarp(nn.Cell):
weights_w = self.weights_w
weights_d = self.weights_d
loss_w, loss_d = self.network(x)
sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
sens_w = F.fill(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
sens_d = F.fill(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
grads_w = self.grad_w(self.loss_net_w, weights_w)(x, sens_w)
self.optimizer_w(grads_w)
grads_d = self.grad_d(self.loss_net_d, weights_d)(x, sens_d)

View File

@ -22,6 +22,7 @@ from mindspore.common.api import _cell_graph_executor
from mindspore.nn.optim import Adam, FTRL
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.parallel._utils import _reset_op_id as reset_op_id
from mindspore.parallel._cost_model_context import _set_algo_single_loop
@ -117,8 +118,8 @@ class TrainStepWarp(nn.Cell):
weights_w = self.weights_w
weights_d = self.weights_d
loss_w, loss_d = self.network(x)
sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
sens_w = F.fill(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
sens_d = F.fill(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
grads_w = self.grad_w(self.loss_net_w, weights_w)(x, sens_w)
self.optimizer_w(grads_w)
grads_d = self.grad_d(self.loss_net_d, weights_d)(x, sens_d)

View File

@ -67,10 +67,10 @@ def test_two_matmul():
self.matmul2 = P.MatMul().shard(strategy2)
self.matmul3 = P.MatMul().shard(strategy3)
self.diag = P.Diag()
self.fill = P.Fill()
self.fillv2 = P.FillV2()
def construct(self, x, y):
fill = self.diag(self.fill(mstype.float32, (128,), 1.0))
fill = self.diag(self.fillv2((128,), Tensor(1.0, mstype.float32)))
out1 = self.matmul1(fill, x)
out2 = self.matmul2(y, fill)
out = self.matmul3(out1, out2)
@ -121,10 +121,10 @@ def test_two_matmul1():
self.matmul2 = P.MatMul().shard(strategy2)
self.matmul3 = P.MatMul().shard(strategy3)
self.diag = P.Diag()
self.fill = P.Fill()
self.fillv2 = P.FillV2()
def construct(self, x, y):
fill = self.diag(self.fill(mstype.float32, (128,), 1.0))
fill = self.diag(self.fillv2((128,), Tensor(1.0, mstype.float32)))
out1 = self.matmul1(fill, x)
out2 = self.matmul2(fill, y)
out = self.matmul3(out1, out2)

View File

@ -25,6 +25,7 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.parallel import set_algo_parameters
from mindspore.train import Model
@ -414,7 +415,7 @@ class TrainOneStepCell(nn.Cell):
def construct(self, data):
weights = self.weights
loss = self.network(data)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
sens = F.fill(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(data, sens)
self.optimizer(grads)

View File

@ -21,6 +21,7 @@ from mindspore.common.api import _cell_graph_executor
from mindspore.nn.optim import Adam, FTRL
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
def setup_function():
@ -98,8 +99,8 @@ class TrainStepWrap(nn.Cell):
weights_w = self.weights_w
weights_d = self.weights_d
loss_w, loss_d = self.network(x)
sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
sens_w = F.fill(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
sens_d = F.fill(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
grads_w = self.grad_w(self.loss_net_w, weights_w)(x, sens_w)
self.optimizer_w(grads_w)
grads_d = self.grad_d(self.loss_net_d, weights_d)(x, sens_d)

View File

@ -22,6 +22,7 @@ from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
def setup_function():
@ -48,7 +49,7 @@ class GradWrap2(nn.Cell):
def construct(self, x, y, b):
loss = self.network(x, y, b)
sens = P.Fill()(mstype.float32, P.Shape()(loss), 1.0)
sens = F.fill(mstype.float32, P.Shape()(loss), 1.0)
return grad_all_with_sens(self.network)(x, y, b, sens)