Codedex change and change some comment

This commit is contained in:
l00591931 2020-12-02 16:53:42 +08:00
parent 7dc96eb856
commit 256a93911d
6 changed files with 72 additions and 35 deletions

View File

@ -49,6 +49,9 @@ static int64_t GetRank() {
}
static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num) {
if (stage_num == 0) {
MS_LOG(EXCEPTION) << "stage_num is zero";
}
if (device_num % stage_num != 0) {
MS_LOG(EXCEPTION) << "Device_num must be divisible by the stage_num, got device_num: " << device_num
<< "stage_num: " << stage_num;

View File

@ -72,21 +72,22 @@ class Parameter(MetaTensor_):
>>> from mindspore import context
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.matmul = P.MatMul()
>>> self.weight = Parameter(Tensor(np.ones((1,2))), requires_grad=True)
>>>
>>> def construct(self, x):
>>> out = self.matmul(self.weight, x)
>>> return out
... def __init__(self):
... super(Net, self).__init__()
... self.matmul = P.MatMul()
... self.weight = Parameter(Tensor(np.ones((1,2))), name="w", requires_grad=True)
...
... def construct(self, x):
... out = self.matmul(self.weight, x)
... return out
>>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
>>> net = Net()
>>> x = Tensor(np.ones((2,1)))
>>> net(x)
>>> print(net(x))
[[2.]]
>>> net.weight.set_data(Tensor(np.zeros((1,2))))
>>> net(x)
Parameter (name=w)
>>> print(net(x))
[[0.]]
"""
__base_type__ = {}

View File

@ -46,6 +46,8 @@ class Tensor(Tensor_):
Tensor, with the same shape as `input_data`.
Examples:
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> # initialize a tensor with input data
>>> t1 = Tensor(np.zeros([1, 2, 3]), mindspore.float32)
>>> assert isinstance(t1, Tensor)
@ -381,17 +383,25 @@ class RowTensor:
RowTensor, composed of `indices`, `values`, and `dense_shape`.
Examples:
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> class Net(nn.Cell):
>>> def __init__(self, dense_shape):
>>> super(Net, self).__init__()
>>> self.dense_shape = dense_shape
>>> def construct(self, indices, values):
>>> x = RowTensor(indices, values, self.dense_shape)
>>> return x.values, x.indices, x.dense_shape
... def __init__(self, dense_shape):
... super(Net, self).__init__()
... self.dense_shape = dense_shape
... def construct(self, indices, values):
... x = RowTensor(indices, values, self.dense_shape)
... return x.values, x.indices, x.dense_shape
>>>
>>> indices = Tensor([0])
>>> values = Tensor([[1, 2]], dtype=ms.float32)
>>> Net((3, 2))(indices, values)
>>> out = Net((3, 2))(indices, values)
>>> print(out[0])
[[1. 2.]]
>>> print(out[1])
[0]
>>> print(out[2])
(3, 2)
"""
def __init__(self, indices, values, dense_shape):
@ -437,17 +447,26 @@ class SparseTensor:
SparseTensor, composed of `indices`, `values`, and `dense_shape`.
Examples:
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> class Net(nn.Cell):
>>> def __init__(self, dense_shape):
>>> super(Net, self).__init__()
>>> self.dense_shape = dense_shape
>>> def construct(self, indices, values):
>>> x = SparseTensor(indices, values, self.dense_shape)
>>> return x.values, x.indices, x.dense_shape
... def __init__(self, dense_shape):
... super(Net, self).__init__()
... self.dense_shape = dense_shape
... def construct(self, indices, values):
... x = SparseTensor(indices, values, self.dense_shape)
... return x.values, x.indices, x.dense_shape
>>>
>>> indices = Tensor([[0, 1], [1, 2]])
>>> values = Tensor([1, 2], dtype=ms.float32)
>>> Net((3, 4))(indices, values)
>>> out = Net((3, 4))(indices, values)
>>> print(out[0])
[1. 2.]
>>> print(out[1])
[[0 1]
[1 2]]
>>> print(out[2])
(3, 4)
"""
def __init__(self, indices, values, dense_shape):

View File

@ -811,6 +811,9 @@ class PConstant : public PBase<PConstant<T> > {
template <typename TM>
void CalcByOperator(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data,
int out_data_size, BinOperator bin_operator) const {
if (out_data_size <= 0) {
MS_EXCEPTION(ValueError) << "out_data_size should be greater than zeros";
}
TM *data_1 = reinterpret_cast<TM *>(in_data_1);
TM *data_2 = reinterpret_cast<TM *>(in_data_2);
TM *data_out = new TM[out_data_size];

View File

@ -41,9 +41,9 @@ class InplaceAssign(PrimitiveWithInfer):
Examples:
>>> def construct(self, x):
>>> val = x - 1.0
>>> ret = x + 2.0
>>> return InplaceAssign()(x, val, ret)
... val = x - 1.0
... ret = x + 2.0
... return InplaceAssign()(x, val, ret)
>>> x = Tensor([2.0], mindspore.float32)
>>> net = Net()
>>> net(x)
@ -225,8 +225,8 @@ class BiasAdd(GraphKernel):
Tensor, the sum of x and bias.
Example:
>>> layer = BiasGrad()
>>> output = BiasAdd(Tensor([1, 2, 3]), Tensor([1,]))
>>> layer = BiasAdd()
>>> output = layer(Tensor([1, 2, 3]), Tensor([1,]))
"""
def __init__(self):
@ -286,7 +286,8 @@ class EqualCount(GraphKernel):
>>> x = Tensor(np.array([1, 2, 3]), mindspore.int32)
>>> y = Tensor(np.array([1, 2, 4]), mindspore.int32)
>>> equal_count = EqualCount()
>>> equal_count(x, y)
>>> print(equal_count(x, y))
2
"""
def __init__(self):
super(EqualCount, self).__init__()
@ -744,7 +745,7 @@ class Tanh(GraphKernel):
>>> tanh = Tanh()
>>> result = tanh(input_x)
>>> print(result)
[0.7615941 0.9640276 0.9950548 0.9993293 0.99990916]
[0.7615941 0.9640276 0.9950548 0.9993293 0.99990916]
"""
def __init__(self):
super(Tanh, self).__init__()

View File

@ -570,14 +570,20 @@ class Pad(Cell):
@constexpr
def interpolate(shape, size, scale):
def interpolate(shape, size, scale, align_corners):
"""Check input and calculate shape"""
if not isinstance(align_corners, bool):
raise TypeError("align_corners should be type boolean")
if size is None and scale is None:
raise ValueError("size and scale both none")
if size is not None and scale is not None:
raise ValueError("size and scale both not none")
if size is not None:
if not isinstance(size, (tuple, list)):
raise ValueError("size must be tuple or list")
Validator.check_int(len(size), 2, Rel.EQ, "size", "interpolate")
Validator.check_int(size[0], 1, Rel.GE, "size[0]", "interpolate")
Validator.check_int(size[1], 1, Rel.GE, "size[1]", "interpolate")
return size
Validator.check_int(scale, 1, Rel.GE, "scale factor", "interpolate")
ret = (scale * shape[2], scale * shape[3])
@ -620,7 +626,7 @@ class Interpolate(Cell):
super(Interpolate, self).__init__()
def construct(self, x, size=None, scale_factor=None, align_corners=False):
shape = interpolate(x.shape, size, scale_factor)
shape = interpolate(x.shape, size, scale_factor, align_corners)
resize_bilinear = P.ResizeBilinear(shape, align_corners)
return resize_bilinear(x)
@ -715,10 +721,12 @@ class Tril(Cell):
super(Tril, self).__init__()
self.dtype = P.DType()
self.mul = P.Mul()
self.cast = P.Cast()
def construct(self, x, k=0):
assist = tril(x.shape, self.dtype(x), k)
return self.mul(x, assist)
result = self.mul(self.cast(x, mstype.int32), self.cast(assist, mstype.int32))
return self.cast(result, self.dtype(x))
@constexpr
@ -755,10 +763,12 @@ class Triu(Cell):
super(Triu, self).__init__()
self.dtype = P.DType()
self.mul = P.Mul()
self.cast = P.Cast()
def construct(self, x, k=0):
assist = triu(x.shape, self.dtype(x), k)
return self.mul(x, assist)
result = self.mul(self.cast(x, mstype.int32), self.cast(assist, mstype.int32))
return self.cast(result, self.dtype(x))
@constexpr