forked from mindspore-Ecosystem/mindspore
!4072 Fix a bug in Tensor::equals()
Merge pull request !4072 from hewei/fixbug_tensor_equal
This commit is contained in:
commit
3b7df4e512
|
@ -176,13 +176,8 @@ class TensorDataImpl : public TensorData {
|
||||||
ssize_t ndim() const override { return static_cast<ssize_t>(ndim_); }
|
ssize_t ndim() const override { return static_cast<ssize_t>(ndim_); }
|
||||||
|
|
||||||
void *data() override {
|
void *data() override {
|
||||||
static T empty_data = static_cast<T>(0);
|
|
||||||
if (data_size_ == 0) {
|
|
||||||
// Prevent null pointer for empty shape.
|
|
||||||
return &empty_data;
|
|
||||||
}
|
|
||||||
// Lazy allocation.
|
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
|
// Lazy allocation.
|
||||||
data_ = std::make_unique<T[]>(data_size_);
|
data_ = std::make_unique<T[]>(data_size_);
|
||||||
}
|
}
|
||||||
return data_.get();
|
return data_.get();
|
||||||
|
@ -193,8 +188,14 @@ class TensorDataImpl : public TensorData {
|
||||||
if (ptr == nullptr) {
|
if (ptr == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return (ptr == this) || ((ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) &&
|
if (ptr == this) {
|
||||||
(std::equal(data_.get(), data_.get() + data_size_, ptr->data_.get())));
|
return true;
|
||||||
|
}
|
||||||
|
if (data_ == nullptr || ptr->data_ == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return (ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) &&
|
||||||
|
std::equal(data_.get(), data_.get() + data_size_, ptr->data_.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string ToString(const TypeId type, const std::vector<int> &shape) const override {
|
std::string ToString(const TypeId type, const std::vector<int> &shape) const override {
|
||||||
|
|
|
@ -225,6 +225,27 @@ TEST_F(TestTensor, EqualTest) {
|
||||||
ASSERT_EQ(TypeId::kNumberTypeFloat64, tensor_float64->data_type_c());
|
ASSERT_EQ(TypeId::kNumberTypeFloat64, tensor_float64->data_type_c());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TestTensor, ValueEqualTest) {
|
||||||
|
py::tuple tuple = py::make_tuple(1, 2, 3, 4, 5, 6);
|
||||||
|
TensorPtr t1 = TensorPy::MakeTensor(py::array(tuple), kInt32);
|
||||||
|
TensorPtr t2 = TensorPy::MakeTensor(py::array(tuple), kInt32);
|
||||||
|
ASSERT_TRUE(t1->ValueEqual(*t1));
|
||||||
|
ASSERT_TRUE(t1->ValueEqual(*t2));
|
||||||
|
|
||||||
|
std::vector<int> shape = {6};
|
||||||
|
TensorPtr t3 = std::make_shared<Tensor>(kInt32->type_id(), shape);
|
||||||
|
TensorPtr t4 = std::make_shared<Tensor>(kInt32->type_id(), shape);
|
||||||
|
ASSERT_TRUE(t3->ValueEqual(*t3));
|
||||||
|
ASSERT_FALSE(t3->ValueEqual(*t4));
|
||||||
|
ASSERT_FALSE(t3->ValueEqual(*t1));
|
||||||
|
ASSERT_FALSE(t1->ValueEqual(*t3));
|
||||||
|
|
||||||
|
memcpy_s(t3->data_c(), t3->data().nbytes(), t1->data_c(), t1->data().nbytes());
|
||||||
|
ASSERT_TRUE(t1->ValueEqual(*t3));
|
||||||
|
ASSERT_FALSE(t3->ValueEqual(*t4));
|
||||||
|
ASSERT_FALSE(t4->ValueEqual(*t3));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(TestTensor, PyArrayTest) {
|
TEST_F(TestTensor, PyArrayTest) {
|
||||||
py::array_t<float, py::array::c_style> input({2, 3});
|
py::array_t<float, py::array::c_style> input({2, 3});
|
||||||
auto array = input.mutable_unchecked();
|
auto array = input.mutable_unchecked();
|
||||||
|
|
Loading…
Reference in New Issue