!4072 Fix a bug in Tensor::equals()
Merge pull request !4072 from hewei/fixbug_tensor_equal
This commit is contained in:
commit
3b7df4e512
|
@ -176,13 +176,8 @@ class TensorDataImpl : public TensorData {
|
|||
ssize_t ndim() const override { return static_cast<ssize_t>(ndim_); }
|
||||
|
||||
void *data() override {
|
||||
static T empty_data = static_cast<T>(0);
|
||||
if (data_size_ == 0) {
|
||||
// Prevent null pointer for empty shape.
|
||||
return &empty_data;
|
||||
}
|
||||
// Lazy allocation.
|
||||
if (data_ == nullptr) {
|
||||
// Lazy allocation.
|
||||
data_ = std::make_unique<T[]>(data_size_);
|
||||
}
|
||||
return data_.get();
|
||||
|
@ -193,8 +188,14 @@ class TensorDataImpl : public TensorData {
|
|||
if (ptr == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return (ptr == this) || ((ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) &&
|
||||
(std::equal(data_.get(), data_.get() + data_size_, ptr->data_.get())));
|
||||
if (ptr == this) {
|
||||
return true;
|
||||
}
|
||||
if (data_ == nullptr || ptr->data_ == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return (ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) &&
|
||||
std::equal(data_.get(), data_.get() + data_size_, ptr->data_.get());
|
||||
}
|
||||
|
||||
std::string ToString(const TypeId type, const std::vector<int> &shape) const override {
|
||||
|
|
|
@ -225,6 +225,27 @@ TEST_F(TestTensor, EqualTest) {
|
|||
ASSERT_EQ(TypeId::kNumberTypeFloat64, tensor_float64->data_type_c());
|
||||
}
|
||||
|
||||
TEST_F(TestTensor, ValueEqualTest) {
|
||||
py::tuple tuple = py::make_tuple(1, 2, 3, 4, 5, 6);
|
||||
TensorPtr t1 = TensorPy::MakeTensor(py::array(tuple), kInt32);
|
||||
TensorPtr t2 = TensorPy::MakeTensor(py::array(tuple), kInt32);
|
||||
ASSERT_TRUE(t1->ValueEqual(*t1));
|
||||
ASSERT_TRUE(t1->ValueEqual(*t2));
|
||||
|
||||
std::vector<int> shape = {6};
|
||||
TensorPtr t3 = std::make_shared<Tensor>(kInt32->type_id(), shape);
|
||||
TensorPtr t4 = std::make_shared<Tensor>(kInt32->type_id(), shape);
|
||||
ASSERT_TRUE(t3->ValueEqual(*t3));
|
||||
ASSERT_FALSE(t3->ValueEqual(*t4));
|
||||
ASSERT_FALSE(t3->ValueEqual(*t1));
|
||||
ASSERT_FALSE(t1->ValueEqual(*t3));
|
||||
|
||||
memcpy_s(t3->data_c(), t3->data().nbytes(), t1->data_c(), t1->data().nbytes());
|
||||
ASSERT_TRUE(t1->ValueEqual(*t3));
|
||||
ASSERT_FALSE(t3->ValueEqual(*t4));
|
||||
ASSERT_FALSE(t4->ValueEqual(*t3));
|
||||
}
|
||||
|
||||
TEST_F(TestTensor, PyArrayTest) {
|
||||
py::array_t<float, py::array::c_style> input({2, 3});
|
||||
auto array = input.mutable_unchecked();
|
||||
|
|
Loading…
Reference in New Issue