!4072 Fix a bug in Tensor::equals()

Merge pull request !4072 from hewei/fixbug_tensor_equal
This commit is contained in:
mindspore-ci-bot 2020-08-07 09:55:55 +08:00 committed by Gitee
commit 3b7df4e512
2 changed files with 30 additions and 8 deletions

View File

@ -176,13 +176,8 @@ class TensorDataImpl : public TensorData {
ssize_t ndim() const override { return static_cast<ssize_t>(ndim_); }
void *data() override {
static T empty_data = static_cast<T>(0);
if (data_size_ == 0) {
// Prevent null pointer for empty shape.
return &empty_data;
}
// Lazy allocation.
if (data_ == nullptr) {
// Lazy allocation.
data_ = std::make_unique<T[]>(data_size_);
}
return data_.get();
@ -193,8 +188,14 @@ class TensorDataImpl : public TensorData {
if (ptr == nullptr) {
return false;
}
return (ptr == this) || ((ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) &&
(std::equal(data_.get(), data_.get() + data_size_, ptr->data_.get())));
if (ptr == this) {
return true;
}
if (data_ == nullptr || ptr->data_ == nullptr) {
return false;
}
return (ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) &&
std::equal(data_.get(), data_.get() + data_size_, ptr->data_.get());
}
std::string ToString(const TypeId type, const std::vector<int> &shape) const override {

View File

@ -225,6 +225,27 @@ TEST_F(TestTensor, EqualTest) {
ASSERT_EQ(TypeId::kNumberTypeFloat64, tensor_float64->data_type_c());
}
TEST_F(TestTensor, ValueEqualTest) {
py::tuple tuple = py::make_tuple(1, 2, 3, 4, 5, 6);
TensorPtr t1 = TensorPy::MakeTensor(py::array(tuple), kInt32);
TensorPtr t2 = TensorPy::MakeTensor(py::array(tuple), kInt32);
ASSERT_TRUE(t1->ValueEqual(*t1));
ASSERT_TRUE(t1->ValueEqual(*t2));
std::vector<int> shape = {6};
TensorPtr t3 = std::make_shared<Tensor>(kInt32->type_id(), shape);
TensorPtr t4 = std::make_shared<Tensor>(kInt32->type_id(), shape);
ASSERT_TRUE(t3->ValueEqual(*t3));
ASSERT_FALSE(t3->ValueEqual(*t4));
ASSERT_FALSE(t3->ValueEqual(*t1));
ASSERT_FALSE(t1->ValueEqual(*t3));
memcpy_s(t3->data_c(), t3->data().nbytes(), t1->data_c(), t1->data().nbytes());
ASSERT_TRUE(t1->ValueEqual(*t3));
ASSERT_FALSE(t3->ValueEqual(*t4));
ASSERT_FALSE(t4->ValueEqual(*t3));
}
TEST_F(TestTensor, PyArrayTest) {
py::array_t<float, py::array::c_style> input({2, 3});
auto array = input.mutable_unchecked();