Updated TensorRow conversion

This commit is contained in:
Mahdi 2020-09-29 14:26:04 -04:00
parent 2a799fe90e
commit e92d4edceb
7 changed files with 65 additions and 91 deletions

View File

@ -294,7 +294,7 @@ std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remai
// Function to create a BucketBatchByLength dataset // Function to create a BucketBatchByLength dataset
std::shared_ptr<BucketBatchByLengthDataset> Dataset::BucketBatchByLength( std::shared_ptr<BucketBatchByLengthDataset> Dataset::BucketBatchByLength(
const std::vector<std::string> &column_names, const std::vector<int32_t> &bucket_boundaries, const std::vector<std::string> &column_names, const std::vector<int32_t> &bucket_boundaries,
const std::vector<int32_t> &bucket_batch_sizes, TensorRow (*element_length_function)(TensorRow), const std::vector<int32_t> &bucket_batch_sizes, std::function<TensorRow(TensorRow)> element_length_function,
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary, const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
bool drop_remainder) { bool drop_remainder) {
auto ds = std::make_shared<BucketBatchByLengthDataset>(column_names, bucket_boundaries, bucket_batch_sizes, auto ds = std::make_shared<BucketBatchByLengthDataset>(column_names, bucket_boundaries, bucket_batch_sizes,
@ -1698,7 +1698,7 @@ bool BatchDataset::ValidateParams() {
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
BucketBatchByLengthDataset::BucketBatchByLengthDataset( BucketBatchByLengthDataset::BucketBatchByLengthDataset(
const std::vector<std::string> &column_names, const std::vector<int32_t> &bucket_boundaries, const std::vector<std::string> &column_names, const std::vector<int32_t> &bucket_boundaries,
const std::vector<int32_t> &bucket_batch_sizes, TensorRow (*element_length_function)(TensorRow), const std::vector<int32_t> &bucket_batch_sizes, std::function<TensorRow(TensorRow)> element_length_function,
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary, const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
bool drop_remainder) bool drop_remainder)
: column_names_(column_names), : column_names_(column_names),

View File

@ -174,6 +174,18 @@ class Tensor {
return CreateFromVector(items, TensorShape({static_cast<dsize_t>(items.size())}), out); return CreateFromVector(items, TensorShape({static_cast<dsize_t>(items.size())}), out);
} }
/// Create a 1D boolean Tensor from a given list of boolean values.
/// \param[in] items elements of the tensor
/// \param[in] shape shape of the output tensor
/// \param[out] out output argument to hold the created Tensor
/// \return Status Code
static Status CreateFromVector(const std::vector<bool> &items, const TensorShape &shape, TensorPtr *out) {
std::vector<uint8_t> temp(items.begin(), items.end());
RETURN_IF_NOT_OK(CreateFromVector(temp, shape, out));
(*out)->type_ = DataType(DataType::DE_BOOL);
return Status::OK();
}
/// Create a numeric scalar Tensor from the given value. /// Create a numeric scalar Tensor from the given value.
/// \tparam T type of value /// \tparam T type of value
/// \param[in] item value /// \param[in] item value

View File

@ -72,7 +72,7 @@ class TensorRow {
// Destructor // Destructor
~TensorRow() = default; ~TensorRow() = default;
/// Convert a vector of primitive types to a TensorRow consisting of n single data Tensors. /// Convert a vector of primitive types to a TensorRow consisting of one 1-D Tensor with the shape n.
/// \tparam `T` /// \tparam `T`
/// \param[in] o input vector /// \param[in] o input vector
/// \param[out] output TensorRow /// \param[out] output TensorRow
@ -85,14 +85,9 @@ class TensorRow {
if (data_type == DataType::DE_STRING) { if (data_type == DataType::DE_STRING) {
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported."); RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported.");
} }
for (int i = 0; i < o.size(); i++) {
std::shared_ptr<Tensor> tensor; std::shared_ptr<Tensor> tensor;
Tensor::CreateEmpty(TensorShape({1}), data_type, &tensor); RETURN_IF_NOT_OK(Tensor::CreateFromVector(o, &tensor));
std::string_view s;
tensor->SetItemAt({0}, o[i]);
output->push_back(tensor); output->push_back(tensor);
}
return Status::OK(); return Status::OK();
} }
@ -110,13 +105,12 @@ class TensorRow {
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported."); RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported.");
} }
std::shared_ptr<Tensor> tensor; std::shared_ptr<Tensor> tensor;
Tensor::CreateEmpty(TensorShape({1}), data_type, &tensor); RETURN_IF_NOT_OK(Tensor::CreateScalar(o, &tensor));
tensor->SetItemAt({0}, o);
output->push_back(tensor); output->push_back(tensor);
return Status::OK(); return Status::OK();
} }
/// Return the value in a TensorRow consiting of 1 single data Tensor /// Return the value in a TensorRow consisting of 1 single data Tensor.
/// \tparam `T` /// \tparam `T`
/// \param[in] input TensorRow /// \param[in] input TensorRow
/// \param[out] o the primitive variable /// \param[out] o the primitive variable
@ -127,23 +121,23 @@ class TensorRow {
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: Data type was not recognized."); RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: Data type was not recognized.");
} }
if (data_type == DataType::DE_STRING) { if (data_type == DataType::DE_STRING) {
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported."); RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: Data type string is not supported.");
} }
if (input.size() != 1) { if (input.size() != 1) {
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input TensorRow is empty."); RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input TensorRow must have exactly one tensor.");
} }
if (input.at(0)->type() != data_type) { if (input.at(0)->type() != data_type) {
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The output type doesn't match the input tensor type."); RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The output type doesn't match the input tensor type.");
} }
if (input.at(0)->shape() != TensorShape({1})) { if (input.at(0)->shape() != TensorShape({})) {
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensors must have a shape of {1}."); RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensors must be a scalar tensor.");
} }
return input.at(0)->GetItemAt(o, {0}); return input.at(0)->GetItemAt(o, {0});
} }
/// Convert a TensorRow consisting of n single data tensors to a vector of size n /// Convert a TensorRow consisting of one 1-D tensor to a vector of size n.
/// \tparam `T` /// \tparam `T`
/// \param[in] o TensorRow consisting of n single data tensors /// \param[in] o TensorRow consisting of one 1-D tensor
/// \param[out] o vector of primitive variable /// \param[out] o vector of primitive variable
template <typename T> template <typename T>
static Status ConvertFromTensorRow(const TensorRow &input, std::vector<T> *o) { static Status ConvertFromTensorRow(const TensorRow &input, std::vector<T> *o) {
@ -152,15 +146,15 @@ class TensorRow {
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: Data type was not recognized."); RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: Data type was not recognized.");
} }
if (data_type == DataType::DE_STRING) { if (data_type == DataType::DE_STRING) {
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported."); RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: Data type string is not supported.");
} }
for (int i = 0; i < input.size(); i++) { if (input.size() != 1) {
if (input.at(i)->shape() != TensorShape({1})) { RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input TensorRow must have exactly one tensor.");
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensor must have a shape of 1.");
} }
T item; if (input.at(0)->Rank() != 1)
RETURN_IF_NOT_OK(input.at(i)->GetItemAt(&item, {0})); RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensor must have a rank of 1.");
o->push_back(item); for (auto it = input.at(0)->begin<T>(); it != input.at(0)->end<T>(); it++) {
o->push_back(*it);
} }
return Status::OK(); return Status::OK();
} }

View File

@ -507,7 +507,8 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Shared pointer to the current BucketBatchByLengthDataset /// \return Shared pointer to the current BucketBatchByLengthDataset
std::shared_ptr<BucketBatchByLengthDataset> BucketBatchByLength( std::shared_ptr<BucketBatchByLengthDataset> BucketBatchByLength(
const std::vector<std::string> &column_names, const std::vector<int32_t> &bucket_boundaries, const std::vector<std::string> &column_names, const std::vector<int32_t> &bucket_boundaries,
const std::vector<int32_t> &bucket_batch_sizes, TensorRow (*element_length_function)(TensorRow) = nullptr, const std::vector<int32_t> &bucket_batch_sizes,
std::function<TensorRow(TensorRow)> element_length_function = nullptr,
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {}, const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {},
bool pad_to_bucket_boundary = false, bool drop_remainder = false); bool pad_to_bucket_boundary = false, bool drop_remainder = false);
@ -1156,7 +1157,8 @@ class BucketBatchByLengthDataset : public Dataset {
/// \brief Constructor /// \brief Constructor
BucketBatchByLengthDataset( BucketBatchByLengthDataset(
const std::vector<std::string> &column_names, const std::vector<int32_t> &bucket_boundaries, const std::vector<std::string> &column_names, const std::vector<int32_t> &bucket_boundaries,
const std::vector<int32_t> &bucket_batch_sizes, TensorRow (*element_length_function)(TensorRow) = nullptr, const std::vector<int32_t> &bucket_batch_sizes,
std::function<TensorRow(TensorRow)> element_length_function = nullptr,
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {}, const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {},
bool pad_to_bucket_boundary = false, bool drop_remainder = false); bool pad_to_bucket_boundary = false, bool drop_remainder = false);
@ -1175,7 +1177,7 @@ class BucketBatchByLengthDataset : public Dataset {
std::vector<std::string> column_names_; std::vector<std::string> column_names_;
std::vector<int32_t> bucket_boundaries_; std::vector<int32_t> bucket_boundaries_;
std::vector<int32_t> bucket_batch_sizes_; std::vector<int32_t> bucket_batch_sizes_;
TensorRow (*element_length_function_)(TensorRow); std::function<TensorRow(TensorRow)> element_length_function_;
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_info_; std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_info_;
bool pad_to_bucket_boundary_; bool pad_to_bucket_boundary_;
bool drop_remainder_; bool drop_remainder_;

View File

@ -27,7 +27,7 @@ Status CFuncOp::Compute(const TensorRow &input, TensorRow *output) {
IO_CHECK_VECTOR(input, output); IO_CHECK_VECTOR(input, output);
Status ret = Status(StatusCode::kOK, "CFunc Call Succeed"); Status ret = Status(StatusCode::kOK, "CFunc Call Succeed");
try { try {
*output = (*c_func_ptr_)(input); *output = c_func_ptr_(input);
} catch (const std::exception &e) { } catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED("Unexpected error in CFuncOp"); RETURN_STATUS_UNEXPECTED("Unexpected error in CFuncOp");
} }

View File

@ -29,7 +29,7 @@ namespace mindspore {
namespace dataset { namespace dataset {
class CFuncOp : public TensorOp { class CFuncOp : public TensorOp {
public: public:
explicit CFuncOp(TensorRow (*func)(TensorRow)) : c_func_ptr_(func) {} explicit CFuncOp(std::function<TensorRow(TensorRow)> func) : c_func_ptr_(func) {}
~CFuncOp() override = default; ~CFuncOp() override = default;
@ -42,7 +42,7 @@ class CFuncOp : public TensorOp {
std::string Name() const override { return kCFuncOp; } std::string Name() const override { return kCFuncOp; }
private: private:
TensorRow (*c_func_ptr_)(TensorRow); std::function<TensorRow(TensorRow)> c_func_ptr_;
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -38,8 +38,7 @@ TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowBoolTest) {
ASSERT_EQ(s, Status::OK()); ASSERT_EQ(s, Status::OK());
TensorRow expected_bool; TensorRow expected_bool;
std::shared_ptr<Tensor> expected_tensor; std::shared_ptr<Tensor> expected_tensor;
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_BOOL), &expected_tensor); Tensor::CreateScalar(bool_value, &expected_tensor);
expected_tensor->SetItemAt<bool>({0}, bool_value);
expected_bool.push_back(expected_tensor); expected_bool.push_back(expected_tensor);
ASSERT_EQ(*(bool_output.at(0)) == *(expected_bool.at(0)), true); ASSERT_EQ(*(bool_output.at(0)) == *(expected_bool.at(0)), true);
} }
@ -52,8 +51,7 @@ TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowIntTest) {
s = TensorRow::ConvertToTensorRow(int_value, &int_output); s = TensorRow::ConvertToTensorRow(int_value, &int_output);
ASSERT_EQ(s, Status::OK()); ASSERT_EQ(s, Status::OK());
std::shared_ptr<Tensor> expected_tensor; std::shared_ptr<Tensor> expected_tensor;
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_INT32), &expected_tensor); Tensor::CreateScalar(int_value, &expected_tensor);
expected_tensor->SetItemAt({0}, int_value);
expected_int.push_back(expected_tensor); expected_int.push_back(expected_tensor);
ASSERT_EQ(*(int_output.at(0)) == *(expected_int.at(0)), true); ASSERT_EQ(*(int_output.at(0)) == *(expected_int.at(0)), true);
} }
@ -67,8 +65,7 @@ TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowFloatTest) {
s = TensorRow::ConvertToTensorRow(float_value, &float_output); s = TensorRow::ConvertToTensorRow(float_value, &float_output);
ASSERT_EQ(s, Status::OK()); ASSERT_EQ(s, Status::OK());
std::shared_ptr<Tensor> expected_tensor; std::shared_ptr<Tensor> expected_tensor;
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_FLOAT32), &expected_tensor); Tensor::CreateScalar(float_value, &expected_tensor);
expected_tensor->SetItemAt({0}, float_value);
expected_float.push_back(expected_tensor); expected_float.push_back(expected_tensor);
ASSERT_EQ(*(float_output.at(0)) == *(expected_float.at(0)), true); ASSERT_EQ(*(float_output.at(0)) == *(expected_float.at(0)), true);
} }
@ -80,15 +77,10 @@ TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowBoolVectorTest) {
s = TensorRow::ConvertToTensorRow(bool_value, &bool_output); s = TensorRow::ConvertToTensorRow(bool_value, &bool_output);
ASSERT_EQ(s, Status::OK()); ASSERT_EQ(s, Status::OK());
TensorRow expected_bool; TensorRow expected_bool;
std::shared_ptr<Tensor> expected_tensor, expected_tensor2; std::shared_ptr<Tensor> expected_tensor;
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_BOOL), &expected_tensor); Tensor::CreateFromVector<bool>(bool_value, &expected_tensor);
expected_tensor->SetItemAt<bool>({0}, bool_value[0]);
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_BOOL), &expected_tensor2);
expected_tensor2->SetItemAt<bool>({0}, bool_value[1]);
expected_bool.push_back(expected_tensor); expected_bool.push_back(expected_tensor);
expected_bool.push_back(expected_tensor2);
ASSERT_EQ(*(bool_output.at(0)) == *(expected_bool.at(0)), true); ASSERT_EQ(*(bool_output.at(0)) == *(expected_bool.at(0)), true);
ASSERT_EQ(*(bool_output.at(1)) == *(expected_bool.at(1)), true);
} }
TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowIntVectorTest) { TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowIntVectorTest) {
@ -98,15 +90,10 @@ TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowIntVectorTest) {
TensorRow expected_int; TensorRow expected_int;
s = TensorRow::ConvertToTensorRow(int_value, &int_output); s = TensorRow::ConvertToTensorRow(int_value, &int_output);
ASSERT_EQ(s, Status::OK()); ASSERT_EQ(s, Status::OK());
std::shared_ptr<Tensor> expected_tensor, expected_tensor2; std::shared_ptr<Tensor> expected_tensor;
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_UINT64), &expected_tensor); Tensor::CreateFromVector(int_value, &expected_tensor);
expected_tensor->SetItemAt({0}, int_value[0]);
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_UINT64), &expected_tensor2);
expected_tensor2->SetItemAt({0}, int_value[1]);
expected_int.push_back(expected_tensor); expected_int.push_back(expected_tensor);
expected_int.push_back(expected_tensor2);
ASSERT_EQ(*(int_output.at(0)) == *(expected_int.at(0)), true); ASSERT_EQ(*(int_output.at(0)) == *(expected_int.at(0)), true);
ASSERT_EQ(*(int_output.at(1)) == *(expected_int.at(1)), true);
} }
TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowFloatVectorTest) { TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowFloatVectorTest) {
@ -116,15 +103,10 @@ TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowFloatVectorTest) {
TensorRow expected_float; TensorRow expected_float;
s = TensorRow::ConvertToTensorRow(float_value, &float_output); s = TensorRow::ConvertToTensorRow(float_value, &float_output);
ASSERT_EQ(s, Status::OK()); ASSERT_EQ(s, Status::OK());
std::shared_ptr<Tensor> expected_tensor, expected_tensor2; std::shared_ptr<Tensor> expected_tensor;
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_FLOAT64), &expected_tensor); Tensor::CreateFromVector(float_value, &expected_tensor);
expected_tensor->SetItemAt({0}, float_value[0]);
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_FLOAT64), &expected_tensor2);
expected_tensor2->SetItemAt({0}, float_value[1]);
expected_float.push_back(expected_tensor); expected_float.push_back(expected_tensor);
expected_float.push_back(expected_tensor2);
ASSERT_EQ(*(float_output.at(0)) == *(expected_float.at(0)), true); ASSERT_EQ(*(float_output.at(0)) == *(expected_float.at(0)), true);
ASSERT_EQ(*(float_output.at(1)) == *(expected_float.at(1)), true);
} }
TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowBoolTest) { TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowBoolTest) {
@ -133,8 +115,7 @@ TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowBoolTest) {
bool result; bool result;
TensorRow input_tensor_row; TensorRow input_tensor_row;
std::shared_ptr<Tensor> input_tensor; std::shared_ptr<Tensor> input_tensor;
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_BOOL), &input_tensor); Tensor::CreateScalar(bool_value, &input_tensor);
input_tensor->SetItemAt<bool>({0}, bool_value);
input_tensor_row.push_back(input_tensor); input_tensor_row.push_back(input_tensor);
s = TensorRow::ConvertFromTensorRow(input_tensor_row, &result); s = TensorRow::ConvertFromTensorRow(input_tensor_row, &result);
ASSERT_EQ(s, Status::OK()); ASSERT_EQ(s, Status::OK());
@ -147,8 +128,7 @@ TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowIntTest) {
int32_t result; int32_t result;
TensorRow input_tensor_row; TensorRow input_tensor_row;
std::shared_ptr<Tensor> input_tensor; std::shared_ptr<Tensor> input_tensor;
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_INT32), &input_tensor); Tensor::CreateScalar(int_value, &input_tensor);
input_tensor->SetItemAt({0}, int_value);
input_tensor_row.push_back(input_tensor); input_tensor_row.push_back(input_tensor);
s = TensorRow::ConvertFromTensorRow(input_tensor_row, &result); s = TensorRow::ConvertFromTensorRow(input_tensor_row, &result);
ASSERT_EQ(s, Status::OK()); ASSERT_EQ(s, Status::OK());
@ -161,8 +141,7 @@ TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowFloatTest) {
float result; float result;
TensorRow input_tensor_row; TensorRow input_tensor_row;
std::shared_ptr<Tensor> input_tensor; std::shared_ptr<Tensor> input_tensor;
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_FLOAT32), &input_tensor); Tensor::CreateScalar(float_value, &input_tensor);
input_tensor->SetItemAt({0}, float_value);
input_tensor_row.push_back(input_tensor); input_tensor_row.push_back(input_tensor);
s = TensorRow::ConvertFromTensorRow(input_tensor_row, &result); s = TensorRow::ConvertFromTensorRow(input_tensor_row, &result);
ASSERT_EQ(s, Status::OK()); ASSERT_EQ(s, Status::OK());
@ -174,13 +153,9 @@ TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowBoolVectorTest) {
std::vector<bool> bool_value = {true, false}; std::vector<bool> bool_value = {true, false};
std::vector<bool> result; std::vector<bool> result;
TensorRow input_tensor_row; TensorRow input_tensor_row;
std::shared_ptr<Tensor> input_tensor1, input_tensor2; std::shared_ptr<Tensor> input_tensor;
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_BOOL), &input_tensor1); Tensor::CreateFromVector<bool>(bool_value, &input_tensor);
input_tensor1->SetItemAt<bool>({0}, bool_value[0]); input_tensor_row.push_back(input_tensor);
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_BOOL), &input_tensor2);
input_tensor2->SetItemAt<bool>({0}, bool_value[1]);
input_tensor_row.push_back(input_tensor1);
input_tensor_row.push_back(input_tensor2);
s = TensorRow::ConvertFromTensorRow(input_tensor_row, &result); s = TensorRow::ConvertFromTensorRow(input_tensor_row, &result);
ASSERT_EQ(s, Status::OK()); ASSERT_EQ(s, Status::OK());
ASSERT_EQ(result, bool_value); ASSERT_EQ(result, bool_value);
@ -191,13 +166,9 @@ TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowIntVectorTest) {
std::vector<uint64_t> int_value = {12, 16}; std::vector<uint64_t> int_value = {12, 16};
std::vector<uint64_t> result; std::vector<uint64_t> result;
TensorRow input_tensor_row; TensorRow input_tensor_row;
std::shared_ptr<Tensor> input_tensor1, input_tensor2; std::shared_ptr<Tensor> input_tensor;
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_UINT64), &input_tensor1); Tensor::CreateFromVector(int_value, &input_tensor);
input_tensor1->SetItemAt({0}, int_value[0]); input_tensor_row.push_back(input_tensor);
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_UINT64), &input_tensor2);
input_tensor2->SetItemAt({0}, int_value[1]);
input_tensor_row.push_back(input_tensor1);
input_tensor_row.push_back(input_tensor2);
s = TensorRow::ConvertFromTensorRow(input_tensor_row, &result); s = TensorRow::ConvertFromTensorRow(input_tensor_row, &result);
ASSERT_EQ(s, Status::OK()); ASSERT_EQ(s, Status::OK());
ASSERT_EQ(result, int_value); ASSERT_EQ(result, int_value);
@ -208,13 +179,9 @@ TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowFloatVectorTest) {
std::vector<double> float_value = {12.57, 0.264}; std::vector<double> float_value = {12.57, 0.264};
std::vector<double> result; std::vector<double> result;
TensorRow input_tensor_row; TensorRow input_tensor_row;
std::shared_ptr<Tensor> input_tensor1, input_tensor2; std::shared_ptr<Tensor> input_tensor;
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_FLOAT64), &input_tensor1); Tensor::CreateFromVector(float_value, &input_tensor);
input_tensor1->SetItemAt({0}, float_value[0]); input_tensor_row.push_back(input_tensor);
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_FLOAT64), &input_tensor2);
input_tensor2->SetItemAt({0}, float_value[1]);
input_tensor_row.push_back(input_tensor1);
input_tensor_row.push_back(input_tensor2);
s = TensorRow::ConvertFromTensorRow(input_tensor_row, &result); s = TensorRow::ConvertFromTensorRow(input_tensor_row, &result);
ASSERT_EQ(s, Status::OK()); ASSERT_EQ(s, Status::OK());
ASSERT_EQ(result, float_value); ASSERT_EQ(result, float_value);
@ -231,8 +198,7 @@ TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowInvalidDataTest) {
TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowTypeMismatchTest) { TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowTypeMismatchTest) {
TensorRow input_tensor_row; TensorRow input_tensor_row;
std::shared_ptr<Tensor> input_tensor1; std::shared_ptr<Tensor> input_tensor1;
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_BOOL), &input_tensor1); Tensor::CreateScalar(false, &input_tensor1);
input_tensor1->SetItemAt({0}, false);
input_tensor_row.push_back(input_tensor1); input_tensor_row.push_back(input_tensor1);
double output; double output;
ASSERT_FALSE(TensorRow::ConvertFromTensorRow(input_tensor_row, &output).IsOk()); ASSERT_FALSE(TensorRow::ConvertFromTensorRow(input_tensor_row, &output).IsOk());
@ -243,7 +209,7 @@ TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowTypeMismatchTest) {
TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowInvalidShapeTest) { TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowInvalidShapeTest) {
TensorRow input_tensor_row; TensorRow input_tensor_row;
std::shared_ptr<Tensor> input_tensor1; std::shared_ptr<Tensor> input_tensor1;
Tensor::CreateEmpty(TensorShape({2}), DataType(DataType::DE_FLOAT64), &input_tensor1); Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_FLOAT64), &input_tensor1);
input_tensor_row.push_back(input_tensor1); input_tensor_row.push_back(input_tensor1);
std::vector<double> output; std::vector<double> output;
ASSERT_FALSE(TensorRow::ConvertFromTensorRow(input_tensor_row, &output).IsOk()); ASSERT_FALSE(TensorRow::ConvertFromTensorRow(input_tensor_row, &output).IsOk());