forked from mindspore-Ecosystem/mindspore
fix shape bug.
This commit is contained in:
parent
93daaf4de9
commit
4477fcfe19
|
@ -446,7 +446,9 @@ Tensor::Tensor(const Tensor &tensor)
|
|||
event_(tensor.event_),
|
||||
sync_status_(tensor.sync_status_),
|
||||
device_sync_(tensor.device_sync_),
|
||||
padding_type_(tensor.padding_type()) {}
|
||||
padding_type_(tensor.padding_type()) {
|
||||
CheckShape(tensor.shape_);
|
||||
}
|
||||
|
||||
Tensor::Tensor(const Tensor &tensor, TypeId data_type)
|
||||
: MetaTensor(data_type, tensor.shape_),
|
||||
|
@ -456,29 +458,43 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
|
|||
event_(tensor.event_),
|
||||
sync_status_(tensor.sync_status_),
|
||||
device_sync_(tensor.device_sync_),
|
||||
padding_type_(tensor.padding_type()) {}
|
||||
padding_type_(tensor.padding_type()) {
|
||||
CheckShape(tensor.shape_);
|
||||
}
|
||||
|
||||
Tensor::Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data)
|
||||
: MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {}
|
||||
: MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {
|
||||
CheckShape(shape);
|
||||
}
|
||||
|
||||
Tensor::Tensor(TypeId data_type, const ShapeVector &shape)
|
||||
: Tensor(data_type, shape, MakeTensorData(data_type, shape)) {}
|
||||
: Tensor(data_type, shape, MakeTensorData(data_type, shape)) {
|
||||
CheckShape(shape);
|
||||
}
|
||||
|
||||
Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len)
|
||||
: Tensor(data_type, shape, MakeTensorData(data_type, shape, data, data_len)) {}
|
||||
: Tensor(data_type, shape, MakeTensorData(data_type, shape, data, data_len)) {
|
||||
CheckShape(shape);
|
||||
}
|
||||
|
||||
Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId src_data_type)
|
||||
: Tensor(data_type, shape, MakeTensorData(data_type, shape, data, src_data_type)) {}
|
||||
: Tensor(data_type, shape, MakeTensorData(data_type, shape, data, src_data_type)) {
|
||||
CheckShape(shape);
|
||||
}
|
||||
|
||||
Tensor::Tensor(const std::vector<int64_t> &input, const TypePtr &data_type)
|
||||
: MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {static_cast<int>(input.size())}),
|
||||
data_(MakeTensorData(data_type_, shape_, input.data(), input.size())),
|
||||
id_(MakeId()) {}
|
||||
id_(MakeId()) {
|
||||
CheckShape(shape_);
|
||||
}
|
||||
|
||||
Tensor::Tensor(const std::vector<double> &input, const TypePtr &data_type)
|
||||
: MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {static_cast<int>(input.size())}),
|
||||
data_(MakeTensorData(data_type_, shape_, input.data(), input.size())),
|
||||
id_(MakeId()) {}
|
||||
id_(MakeId()) {
|
||||
CheckShape(shape_);
|
||||
}
|
||||
|
||||
Tensor::Tensor(int64_t input, const TypePtr &data_type)
|
||||
: MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {}),
|
||||
|
@ -497,6 +513,7 @@ bool Tensor::operator==(const Tensor &tensor) const {
|
|||
bool Tensor::ValueEqual(const Tensor &tensor) const {
|
||||
return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_)));
|
||||
}
|
||||
|
||||
// assgin value to this tensor
|
||||
Tensor &Tensor::AssignValue(const Tensor &tensor) {
|
||||
if (this != &tensor) {
|
||||
|
@ -573,6 +590,17 @@ std::string Tensor::ToStringRepr() const {
|
|||
return buf.str();
|
||||
}
|
||||
|
||||
void Tensor::CheckShape(const ShapeVector &shape) const {
|
||||
// Check tensor's shape, ignore one-dimensional tensor, including empty tensor with shape=(0,).
|
||||
if (shape.size() > 1) {
|
||||
for (const auto &s : shape) {
|
||||
if (s == 0) {
|
||||
MS_EXCEPTION(ValueError) << "Zero is not supported in the shape of Tensor !";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Tensor::data_sync(bool need_wait) const {
|
||||
if (need_wait) {
|
||||
Wait();
|
||||
|
|
|
@ -261,6 +261,8 @@ class Tensor : public MetaTensor {
|
|||
|
||||
std::string ToStringRepr() const;
|
||||
|
||||
void CheckShape(const ShapeVector &shape) const;
|
||||
|
||||
bool is_init() const { return init_flag_; }
|
||||
void set_init_flag(bool flag) { init_flag_ = flag; }
|
||||
|
||||
|
|
|
@ -33,11 +33,6 @@ class TensroAdd(nn.Cell):
|
|||
|
||||
self.add = P.TensorAdd()
|
||||
|
||||
self.x = Parameter(initializer(
|
||||
Tensor(np.random.randn(2, 0).astype(np.float32)), [2, 0]), name='x')
|
||||
self.y = Parameter(initializer(
|
||||
Tensor(np.random.randn(2, 1).astype(np.float32)), [2, 1]), name='y')
|
||||
|
||||
self.x1 = Parameter(initializer(
|
||||
Tensor(np.arange(3).reshape(3).astype(np.float32)), [3]), name='x1')
|
||||
self.y1 = Parameter(initializer(
|
||||
|
@ -55,20 +50,17 @@ class TensroAdd(nn.Cell):
|
|||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return (
|
||||
self.add(self.x, self.y), self.add(self.x1, self.y1), self.add(self.x2, self.y2),
|
||||
self.add(self.x3, self.y3))
|
||||
return (self.add(self.x1, self.y1), self.add(self.x2, self.y2), self.add(self.x3, self.y3))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_TensroAdd():
|
||||
def test_TensorAdd():
|
||||
add = TensroAdd()
|
||||
output = add()
|
||||
expect0 = np.array([])
|
||||
expect1 = np.array([2, 3, 4])
|
||||
expect2 = np.array(
|
||||
expect0 = np.array([2, 3, 4])
|
||||
expect1 = np.array(
|
||||
[[[[0., 2., 4.],
|
||||
[6., 8., 10.],
|
||||
[12., 14., 16.]],
|
||||
|
@ -96,7 +88,7 @@ def test_TensroAdd():
|
|||
[[144., 146., 148.],
|
||||
[150., 152., 154.],
|
||||
[156., 158., 160.]]]])
|
||||
expect3 = np.array(
|
||||
expect2 = np.array(
|
||||
[[[[0., 2., 4.],
|
||||
[6., 8., 10.],
|
||||
[12., 14., 16.]],
|
||||
|
@ -128,4 +120,26 @@ def test_TensroAdd():
|
|||
assert (output[0].asnumpy() == expect0).all()
|
||||
assert (output[1].asnumpy() == expect1).all()
|
||||
assert (output[2].asnumpy() == expect2).all()
|
||||
assert (output[3].asnumpy() == expect3).all()
|
||||
|
||||
|
||||
class TensorAdd2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(TensorAdd2, self).__init__()
|
||||
self.add = P.TensorAdd()
|
||||
self.x = Parameter(initializer(
|
||||
Tensor(np.random.randn(2, 0).astype(np.float32)), [2, 0]), name='x')
|
||||
self.y = Parameter(initializer(
|
||||
Tensor(np.random.randn(2, 1).astype(np.float32)), [2, 1]), name='y')
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return self.add(self.x, self.y)
|
||||
|
||||
|
||||
# Constructing a tensor with 0 in shape is not support, excluding empty tensor.
|
||||
@pytest.mark.skip(reason='0 in shape is not support')
|
||||
def test_TensorAdd_shape_has_zero():
|
||||
add = TensorAdd2()
|
||||
output = add()
|
||||
expect = np.array([])
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
|
|
@ -786,8 +786,8 @@ def test_tensor_assign_exception():
|
|||
t = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
|
||||
tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
|
||||
# Error for A[Slice] = Number
|
||||
# 1. A[Slice] = Number, Slice error
|
||||
with pytest.raises(IndexError):
|
||||
# 1. A[Slice] = Number, 0 in shape
|
||||
with pytest.raises(ValueError):
|
||||
net_e2(t, 2)
|
||||
|
||||
# Error for A[Slice] = U, U is a Tensor
|
||||
|
|
|
@ -68,6 +68,18 @@ def test_tensor():
|
|||
assert t4.dtype == ms.int64
|
||||
|
||||
|
||||
def test_tensor_empty():
|
||||
t = ms.Tensor(np.ones(0), ms.float32)
|
||||
assert isinstance(t, ms.Tensor)
|
||||
assert t.shape == (0,)
|
||||
|
||||
|
||||
def test_tensor_shape_has_zero():
|
||||
with pytest.raises(ValueError):
|
||||
t = ms.Tensor(np.ones((1, 0)), ms.float32)
|
||||
print(t)
|
||||
|
||||
|
||||
def test_tensor_type_float16():
|
||||
t_float16 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float16))
|
||||
assert isinstance(t_float16, ms.Tensor)
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
""" test ops """
|
||||
import functools
|
||||
import pytest
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -770,6 +771,8 @@ class StridedSliceNet(nn.Cell):
|
|||
return out_0, out_1, out_2, out_3
|
||||
|
||||
|
||||
# Constructing a tensor with 0 in shape is not support, excluding empty tensor.
|
||||
@pytest.mark.skip(reason='0 in shape is not support')
|
||||
def test_strided_slice_const():
|
||||
class StridedSLiceConstNet(nn.Cell):
|
||||
"""StridedSLiceConstNet net definition"""
|
||||
|
|
|
@ -464,8 +464,8 @@ def test_tensor_assign():
|
|||
net(Ta, b, Tck)
|
||||
net2(t, b, tck)
|
||||
# Error for A[Slice] = Number
|
||||
# 1. A[Slice] = Number, Slice error
|
||||
with pytest.raises(IndexError):
|
||||
# 1. A[Slice] = Number, 0 in shape
|
||||
with pytest.raises(ValueError):
|
||||
net_e2(t, Tensor(2, mstype.int32))
|
||||
|
||||
# Error for A[Slice] = U, U is a Tensor
|
||||
|
|
Loading…
Reference in New Issue