forked from mindspore-Ecosystem/mindspore
!9380 Codedex change and comments change
From: @liangzhibo Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
493594e045
|
@ -49,6 +49,9 @@ static int64_t GetRank() {
|
||||||
}
|
}
|
||||||
|
|
||||||
static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num) {
|
static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num) {
|
||||||
|
if (stage_num == 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "stage_num is zero";
|
||||||
|
}
|
||||||
if (device_num % stage_num != 0) {
|
if (device_num % stage_num != 0) {
|
||||||
MS_LOG(EXCEPTION) << "Device_num must be divisible by the stage_num, got device_num: " << device_num
|
MS_LOG(EXCEPTION) << "Device_num must be divisible by the stage_num, got device_num: " << device_num
|
||||||
<< "stage_num: " << stage_num;
|
<< "stage_num: " << stage_num;
|
||||||
|
|
|
@ -72,21 +72,22 @@ class Parameter(MetaTensor_):
|
||||||
>>> from mindspore import context
|
>>> from mindspore import context
|
||||||
>>>
|
>>>
|
||||||
>>> class Net(Cell):
|
>>> class Net(Cell):
|
||||||
>>> def __init__(self):
|
... def __init__(self):
|
||||||
>>> super(Net, self).__init__()
|
... super(Net, self).__init__()
|
||||||
>>> self.matmul = P.MatMul()
|
... self.matmul = P.MatMul()
|
||||||
>>> self.weight = Parameter(Tensor(np.ones((1,2))), requires_grad=True)
|
... self.weight = Parameter(Tensor(np.ones((1,2))), name="w", requires_grad=True)
|
||||||
>>>
|
...
|
||||||
>>> def construct(self, x):
|
... def construct(self, x):
|
||||||
>>> out = self.matmul(self.weight, x)
|
... out = self.matmul(self.weight, x)
|
||||||
>>> return out
|
... return out
|
||||||
>>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
>>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||||
>>> net = Net()
|
>>> net = Net()
|
||||||
>>> x = Tensor(np.ones((2,1)))
|
>>> x = Tensor(np.ones((2,1)))
|
||||||
>>> net(x)
|
>>> print(net(x))
|
||||||
[[2.]]
|
[[2.]]
|
||||||
>>> net.weight.set_data(Tensor(np.zeros((1,2))))
|
>>> net.weight.set_data(Tensor(np.zeros((1,2))))
|
||||||
>>> net(x)
|
Parameter (name=w)
|
||||||
|
>>> print(net(x))
|
||||||
[[0.]]
|
[[0.]]
|
||||||
"""
|
"""
|
||||||
__base_type__ = {}
|
__base_type__ = {}
|
||||||
|
|
|
@ -46,6 +46,8 @@ class Tensor(Tensor_):
|
||||||
Tensor, with the same shape as `input_data`.
|
Tensor, with the same shape as `input_data`.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
>>> import mindspore as ms
|
||||||
|
>>> import mindspore.nn as nn
|
||||||
>>> # initialize a tensor with input data
|
>>> # initialize a tensor with input data
|
||||||
>>> t1 = Tensor(np.zeros([1, 2, 3]), mindspore.float32)
|
>>> t1 = Tensor(np.zeros([1, 2, 3]), mindspore.float32)
|
||||||
>>> assert isinstance(t1, Tensor)
|
>>> assert isinstance(t1, Tensor)
|
||||||
|
@ -381,17 +383,25 @@ class RowTensor:
|
||||||
RowTensor, composed of `indices`, `values`, and `dense_shape`.
|
RowTensor, composed of `indices`, `values`, and `dense_shape`.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
>>> import mindspore as ms
|
||||||
|
>>> import mindspore.nn as nn
|
||||||
>>> class Net(nn.Cell):
|
>>> class Net(nn.Cell):
|
||||||
>>> def __init__(self, dense_shape):
|
... def __init__(self, dense_shape):
|
||||||
>>> super(Net, self).__init__()
|
... super(Net, self).__init__()
|
||||||
>>> self.dense_shape = dense_shape
|
... self.dense_shape = dense_shape
|
||||||
>>> def construct(self, indices, values):
|
... def construct(self, indices, values):
|
||||||
>>> x = RowTensor(indices, values, self.dense_shape)
|
... x = RowTensor(indices, values, self.dense_shape)
|
||||||
>>> return x.values, x.indices, x.dense_shape
|
... return x.values, x.indices, x.dense_shape
|
||||||
>>>
|
>>>
|
||||||
>>> indices = Tensor([0])
|
>>> indices = Tensor([0])
|
||||||
>>> values = Tensor([[1, 2]], dtype=ms.float32)
|
>>> values = Tensor([[1, 2]], dtype=ms.float32)
|
||||||
>>> Net((3, 2))(indices, values)
|
>>> out = Net((3, 2))(indices, values)
|
||||||
|
>>> print(out[0])
|
||||||
|
[[1. 2.]]
|
||||||
|
>>> print(out[1])
|
||||||
|
[0]
|
||||||
|
>>> print(out[2])
|
||||||
|
(3, 2)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, indices, values, dense_shape):
|
def __init__(self, indices, values, dense_shape):
|
||||||
|
@ -437,17 +447,26 @@ class SparseTensor:
|
||||||
SparseTensor, composed of `indices`, `values`, and `dense_shape`.
|
SparseTensor, composed of `indices`, `values`, and `dense_shape`.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
>>> import mindspore as ms
|
||||||
|
>>> import mindspore.nn as nn
|
||||||
>>> class Net(nn.Cell):
|
>>> class Net(nn.Cell):
|
||||||
>>> def __init__(self, dense_shape):
|
... def __init__(self, dense_shape):
|
||||||
>>> super(Net, self).__init__()
|
... super(Net, self).__init__()
|
||||||
>>> self.dense_shape = dense_shape
|
... self.dense_shape = dense_shape
|
||||||
>>> def construct(self, indices, values):
|
... def construct(self, indices, values):
|
||||||
>>> x = SparseTensor(indices, values, self.dense_shape)
|
... x = SparseTensor(indices, values, self.dense_shape)
|
||||||
>>> return x.values, x.indices, x.dense_shape
|
... return x.values, x.indices, x.dense_shape
|
||||||
>>>
|
>>>
|
||||||
>>> indices = Tensor([[0, 1], [1, 2]])
|
>>> indices = Tensor([[0, 1], [1, 2]])
|
||||||
>>> values = Tensor([1, 2], dtype=ms.float32)
|
>>> values = Tensor([1, 2], dtype=ms.float32)
|
||||||
>>> Net((3, 4))(indices, values)
|
>>> out = Net((3, 4))(indices, values)
|
||||||
|
>>> print(out[0])
|
||||||
|
[1. 2.]
|
||||||
|
>>> print(out[1])
|
||||||
|
[[0 1]
|
||||||
|
[1 2]]
|
||||||
|
>>> print(out[2])
|
||||||
|
(3, 4)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, indices, values, dense_shape):
|
def __init__(self, indices, values, dense_shape):
|
||||||
|
|
|
@ -811,6 +811,9 @@ class PConstant : public PBase<PConstant<T> > {
|
||||||
template <typename TM>
|
template <typename TM>
|
||||||
void CalcByOperator(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data,
|
void CalcByOperator(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data,
|
||||||
int out_data_size, BinOperator bin_operator) const {
|
int out_data_size, BinOperator bin_operator) const {
|
||||||
|
if (out_data_size <= 0) {
|
||||||
|
MS_EXCEPTION(ValueError) << "out_data_size should be greater than zeros";
|
||||||
|
}
|
||||||
TM *data_1 = reinterpret_cast<TM *>(in_data_1);
|
TM *data_1 = reinterpret_cast<TM *>(in_data_1);
|
||||||
TM *data_2 = reinterpret_cast<TM *>(in_data_2);
|
TM *data_2 = reinterpret_cast<TM *>(in_data_2);
|
||||||
TM *data_out = new TM[out_data_size];
|
TM *data_out = new TM[out_data_size];
|
||||||
|
|
|
@ -41,9 +41,9 @@ class InplaceAssign(PrimitiveWithInfer):
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> def construct(self, x):
|
>>> def construct(self, x):
|
||||||
>>> val = x - 1.0
|
... val = x - 1.0
|
||||||
>>> ret = x + 2.0
|
... ret = x + 2.0
|
||||||
>>> return InplaceAssign()(x, val, ret)
|
... return InplaceAssign()(x, val, ret)
|
||||||
>>> x = Tensor([2.0], mindspore.float32)
|
>>> x = Tensor([2.0], mindspore.float32)
|
||||||
>>> net = Net()
|
>>> net = Net()
|
||||||
>>> net(x)
|
>>> net(x)
|
||||||
|
@ -225,8 +225,8 @@ class BiasAdd(GraphKernel):
|
||||||
Tensor, the sum of x and bias.
|
Tensor, the sum of x and bias.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> layer = BiasGrad()
|
>>> layer = BiasAdd()
|
||||||
>>> output = BiasAdd(Tensor([1, 2, 3]), Tensor([1,]))
|
>>> output = layer(Tensor([1, 2, 3]), Tensor([1,]))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -286,7 +286,8 @@ class EqualCount(GraphKernel):
|
||||||
>>> x = Tensor(np.array([1, 2, 3]), mindspore.int32)
|
>>> x = Tensor(np.array([1, 2, 3]), mindspore.int32)
|
||||||
>>> y = Tensor(np.array([1, 2, 4]), mindspore.int32)
|
>>> y = Tensor(np.array([1, 2, 4]), mindspore.int32)
|
||||||
>>> equal_count = EqualCount()
|
>>> equal_count = EqualCount()
|
||||||
>>> equal_count(x, y)
|
>>> print(equal_count(x, y))
|
||||||
|
2
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(EqualCount, self).__init__()
|
super(EqualCount, self).__init__()
|
||||||
|
@ -744,7 +745,7 @@ class Tanh(GraphKernel):
|
||||||
>>> tanh = Tanh()
|
>>> tanh = Tanh()
|
||||||
>>> result = tanh(input_x)
|
>>> result = tanh(input_x)
|
||||||
>>> print(result)
|
>>> print(result)
|
||||||
[0.7615941 0.9640276 0.9950548 0.9993293 0.99990916]
|
[0.7615941 0.9640276 0.9950548 0.9993293 0.99990916]
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Tanh, self).__init__()
|
super(Tanh, self).__init__()
|
||||||
|
|
|
@ -570,14 +570,20 @@ class Pad(Cell):
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def interpolate(shape, size, scale):
|
def interpolate(shape, size, scale, align_corners):
|
||||||
"""Check input and calculate shape"""
|
"""Check input and calculate shape"""
|
||||||
|
if not isinstance(align_corners, bool):
|
||||||
|
raise TypeError("align_corners should be type boolean")
|
||||||
if size is None and scale is None:
|
if size is None and scale is None:
|
||||||
raise ValueError("size and scale both none")
|
raise ValueError("size and scale both none")
|
||||||
if size is not None and scale is not None:
|
if size is not None and scale is not None:
|
||||||
raise ValueError("size and scale both not none")
|
raise ValueError("size and scale both not none")
|
||||||
if size is not None:
|
if size is not None:
|
||||||
|
if not isinstance(size, (tuple, list)):
|
||||||
|
raise ValueError("size must be tuple or list")
|
||||||
Validator.check_int(len(size), 2, Rel.EQ, "size", "interpolate")
|
Validator.check_int(len(size), 2, Rel.EQ, "size", "interpolate")
|
||||||
|
Validator.check_int(size[0], 1, Rel.GE, "size[0]", "interpolate")
|
||||||
|
Validator.check_int(size[1], 1, Rel.GE, "size[1]", "interpolate")
|
||||||
return size
|
return size
|
||||||
Validator.check_int(scale, 1, Rel.GE, "scale factor", "interpolate")
|
Validator.check_int(scale, 1, Rel.GE, "scale factor", "interpolate")
|
||||||
ret = (scale * shape[2], scale * shape[3])
|
ret = (scale * shape[2], scale * shape[3])
|
||||||
|
@ -620,7 +626,7 @@ class Interpolate(Cell):
|
||||||
super(Interpolate, self).__init__()
|
super(Interpolate, self).__init__()
|
||||||
|
|
||||||
def construct(self, x, size=None, scale_factor=None, align_corners=False):
|
def construct(self, x, size=None, scale_factor=None, align_corners=False):
|
||||||
shape = interpolate(x.shape, size, scale_factor)
|
shape = interpolate(x.shape, size, scale_factor, align_corners)
|
||||||
resize_bilinear = P.ResizeBilinear(shape, align_corners)
|
resize_bilinear = P.ResizeBilinear(shape, align_corners)
|
||||||
return resize_bilinear(x)
|
return resize_bilinear(x)
|
||||||
|
|
||||||
|
@ -715,10 +721,12 @@ class Tril(Cell):
|
||||||
super(Tril, self).__init__()
|
super(Tril, self).__init__()
|
||||||
self.dtype = P.DType()
|
self.dtype = P.DType()
|
||||||
self.mul = P.Mul()
|
self.mul = P.Mul()
|
||||||
|
self.cast = P.Cast()
|
||||||
|
|
||||||
def construct(self, x, k=0):
|
def construct(self, x, k=0):
|
||||||
assist = tril(x.shape, self.dtype(x), k)
|
assist = tril(x.shape, self.dtype(x), k)
|
||||||
return self.mul(x, assist)
|
result = self.mul(self.cast(x, mstype.int32), self.cast(assist, mstype.int32))
|
||||||
|
return self.cast(result, self.dtype(x))
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
|
@ -755,10 +763,12 @@ class Triu(Cell):
|
||||||
super(Triu, self).__init__()
|
super(Triu, self).__init__()
|
||||||
self.dtype = P.DType()
|
self.dtype = P.DType()
|
||||||
self.mul = P.Mul()
|
self.mul = P.Mul()
|
||||||
|
self.cast = P.Cast()
|
||||||
|
|
||||||
def construct(self, x, k=0):
|
def construct(self, x, k=0):
|
||||||
assist = triu(x.shape, self.dtype(x), k)
|
assist = triu(x.shape, self.dtype(x), k)
|
||||||
return self.mul(x, assist)
|
result = self.mul(self.cast(x, mstype.int32), self.cast(assist, mstype.int32))
|
||||||
|
return self.cast(result, self.dtype(x))
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
|
|
Loading…
Reference in New Issue