!7049 Fixes to TensorRow conversion and Tensor
Merge pull request !7049 from MahdiRahmaniHanzaki/c_func
This commit is contained in:
commit
03f0e64af9
|
@ -294,7 +294,7 @@ std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remai
|
|||
// Function to create a BucketBatchByLength dataset
|
||||
std::shared_ptr<BucketBatchByLengthDataset> Dataset::BucketBatchByLength(
|
||||
const std::vector<std::string> &column_names, const std::vector<int32_t> &bucket_boundaries,
|
||||
const std::vector<int32_t> &bucket_batch_sizes, TensorRow (*element_length_function)(TensorRow),
|
||||
const std::vector<int32_t> &bucket_batch_sizes, std::function<TensorRow(TensorRow)> element_length_function,
|
||||
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
|
||||
bool drop_remainder) {
|
||||
auto ds = std::make_shared<BucketBatchByLengthDataset>(column_names, bucket_boundaries, bucket_batch_sizes,
|
||||
|
@ -1698,7 +1698,7 @@ bool BatchDataset::ValidateParams() {
|
|||
#ifndef ENABLE_ANDROID
|
||||
BucketBatchByLengthDataset::BucketBatchByLengthDataset(
|
||||
const std::vector<std::string> &column_names, const std::vector<int32_t> &bucket_boundaries,
|
||||
const std::vector<int32_t> &bucket_batch_sizes, TensorRow (*element_length_function)(TensorRow),
|
||||
const std::vector<int32_t> &bucket_batch_sizes, std::function<TensorRow(TensorRow)> element_length_function,
|
||||
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
|
||||
bool drop_remainder)
|
||||
: column_names_(column_names),
|
||||
|
|
|
@ -174,6 +174,18 @@ class Tensor {
|
|||
return CreateFromVector(items, TensorShape({static_cast<dsize_t>(items.size())}), out);
|
||||
}
|
||||
|
||||
/// Create a 1D boolean Tensor from a given list of boolean values.
|
||||
/// \param[in] items elements of the tensor
|
||||
/// \param[in] shape shape of the output tensor
|
||||
/// \param[out] out output argument to hold the created Tensor
|
||||
/// \return Status Code
|
||||
static Status CreateFromVector(const std::vector<bool> &items, const TensorShape &shape, TensorPtr *out) {
|
||||
std::vector<uint8_t> temp(items.begin(), items.end());
|
||||
RETURN_IF_NOT_OK(CreateFromVector(temp, shape, out));
|
||||
(*out)->type_ = DataType(DataType::DE_BOOL);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// Create a numeric scalar Tensor from the given value.
|
||||
/// \tparam T type of value
|
||||
/// \param[in] item value
|
||||
|
|
|
@ -72,7 +72,7 @@ class TensorRow {
|
|||
// Destructor
|
||||
~TensorRow() = default;
|
||||
|
||||
/// Convert a vector of primitive types to a TensorRow consisting of n single data Tensors.
|
||||
/// Convert a vector of primitive types to a TensorRow consisting of one 1-D Tensor with the shape n.
|
||||
/// \tparam `T`
|
||||
/// \param[in] o input vector
|
||||
/// \param[out] output TensorRow
|
||||
|
@ -85,14 +85,9 @@ class TensorRow {
|
|||
if (data_type == DataType::DE_STRING) {
|
||||
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported.");
|
||||
}
|
||||
|
||||
for (int i = 0; i < o.size(); i++) {
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
Tensor::CreateEmpty(TensorShape({1}), data_type, &tensor);
|
||||
std::string_view s;
|
||||
tensor->SetItemAt({0}, o[i]);
|
||||
output->push_back(tensor);
|
||||
}
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromVector(o, &tensor));
|
||||
output->push_back(tensor);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -110,13 +105,12 @@ class TensorRow {
|
|||
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported.");
|
||||
}
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
Tensor::CreateEmpty(TensorShape({1}), data_type, &tensor);
|
||||
tensor->SetItemAt({0}, o);
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar(o, &tensor));
|
||||
output->push_back(tensor);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// Return the value in a TensorRow consiting of 1 single data Tensor
|
||||
/// Return the value in a TensorRow consisting of 1 single data Tensor.
|
||||
/// \tparam `T`
|
||||
/// \param[in] input TensorRow
|
||||
/// \param[out] o the primitive variable
|
||||
|
@ -127,23 +121,23 @@ class TensorRow {
|
|||
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: Data type was not recognized.");
|
||||
}
|
||||
if (data_type == DataType::DE_STRING) {
|
||||
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported.");
|
||||
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: Data type string is not supported.");
|
||||
}
|
||||
if (input.size() != 1) {
|
||||
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input TensorRow is empty.");
|
||||
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input TensorRow must have exactly one tensor.");
|
||||
}
|
||||
if (input.at(0)->type() != data_type) {
|
||||
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The output type doesn't match the input tensor type.");
|
||||
}
|
||||
if (input.at(0)->shape() != TensorShape({1})) {
|
||||
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensors must have a shape of {1}.");
|
||||
if (input.at(0)->shape() != TensorShape({})) {
|
||||
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensors must be a scalar tensor.");
|
||||
}
|
||||
return input.at(0)->GetItemAt(o, {0});
|
||||
}
|
||||
|
||||
/// Convert a TensorRow consisting of n single data tensors to a vector of size n
|
||||
/// Convert a TensorRow consisting of one 1-D tensor to a vector of size n.
|
||||
/// \tparam `T`
|
||||
/// \param[in] o TensorRow consisting of n single data tensors
|
||||
/// \param[in] o TensorRow consisting of one 1-D tensor
|
||||
/// \param[out] o vector of primitive variable
|
||||
template <typename T>
|
||||
static Status ConvertFromTensorRow(const TensorRow &input, std::vector<T> *o) {
|
||||
|
@ -152,15 +146,15 @@ class TensorRow {
|
|||
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: Data type was not recognized.");
|
||||
}
|
||||
if (data_type == DataType::DE_STRING) {
|
||||
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported.");
|
||||
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: Data type string is not supported.");
|
||||
}
|
||||
for (int i = 0; i < input.size(); i++) {
|
||||
if (input.at(i)->shape() != TensorShape({1})) {
|
||||
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensor must have a shape of 1.");
|
||||
}
|
||||
T item;
|
||||
RETURN_IF_NOT_OK(input.at(i)->GetItemAt(&item, {0}));
|
||||
o->push_back(item);
|
||||
if (input.size() != 1) {
|
||||
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input TensorRow must have exactly one tensor.");
|
||||
}
|
||||
if (input.at(0)->Rank() != 1)
|
||||
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensor must have a rank of 1.");
|
||||
for (auto it = input.at(0)->begin<T>(); it != input.at(0)->end<T>(); it++) {
|
||||
o->push_back(*it);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -507,7 +507,8 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
/// \return Shared pointer to the current BucketBatchByLengthDataset
|
||||
std::shared_ptr<BucketBatchByLengthDataset> BucketBatchByLength(
|
||||
const std::vector<std::string> &column_names, const std::vector<int32_t> &bucket_boundaries,
|
||||
const std::vector<int32_t> &bucket_batch_sizes, TensorRow (*element_length_function)(TensorRow) = nullptr,
|
||||
const std::vector<int32_t> &bucket_batch_sizes,
|
||||
std::function<TensorRow(TensorRow)> element_length_function = nullptr,
|
||||
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {},
|
||||
bool pad_to_bucket_boundary = false, bool drop_remainder = false);
|
||||
|
||||
|
@ -1156,7 +1157,8 @@ class BucketBatchByLengthDataset : public Dataset {
|
|||
/// \brief Constructor
|
||||
BucketBatchByLengthDataset(
|
||||
const std::vector<std::string> &column_names, const std::vector<int32_t> &bucket_boundaries,
|
||||
const std::vector<int32_t> &bucket_batch_sizes, TensorRow (*element_length_function)(TensorRow) = nullptr,
|
||||
const std::vector<int32_t> &bucket_batch_sizes,
|
||||
std::function<TensorRow(TensorRow)> element_length_function = nullptr,
|
||||
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {},
|
||||
bool pad_to_bucket_boundary = false, bool drop_remainder = false);
|
||||
|
||||
|
@ -1175,7 +1177,7 @@ class BucketBatchByLengthDataset : public Dataset {
|
|||
std::vector<std::string> column_names_;
|
||||
std::vector<int32_t> bucket_boundaries_;
|
||||
std::vector<int32_t> bucket_batch_sizes_;
|
||||
TensorRow (*element_length_function_)(TensorRow);
|
||||
std::function<TensorRow(TensorRow)> element_length_function_;
|
||||
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_info_;
|
||||
bool pad_to_bucket_boundary_;
|
||||
bool drop_remainder_;
|
||||
|
|
|
@ -27,7 +27,7 @@ Status CFuncOp::Compute(const TensorRow &input, TensorRow *output) {
|
|||
IO_CHECK_VECTOR(input, output);
|
||||
Status ret = Status(StatusCode::kOK, "CFunc Call Succeed");
|
||||
try {
|
||||
*output = (*c_func_ptr_)(input);
|
||||
*output = c_func_ptr_(input);
|
||||
} catch (const std::exception &e) {
|
||||
RETURN_STATUS_UNEXPECTED("Unexpected error in CFuncOp");
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
class CFuncOp : public TensorOp {
|
||||
public:
|
||||
explicit CFuncOp(TensorRow (*func)(TensorRow)) : c_func_ptr_(func) {}
|
||||
explicit CFuncOp(std::function<TensorRow(TensorRow)> func) : c_func_ptr_(func) {}
|
||||
|
||||
~CFuncOp() override = default;
|
||||
|
||||
|
@ -42,7 +42,7 @@ class CFuncOp : public TensorOp {
|
|||
std::string Name() const override { return kCFuncOp; }
|
||||
|
||||
private:
|
||||
TensorRow (*c_func_ptr_)(TensorRow);
|
||||
std::function<TensorRow(TensorRow)> c_func_ptr_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -38,8 +38,7 @@ TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowBoolTest) {
|
|||
ASSERT_EQ(s, Status::OK());
|
||||
TensorRow expected_bool;
|
||||
std::shared_ptr<Tensor> expected_tensor;
|
||||
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_BOOL), &expected_tensor);
|
||||
expected_tensor->SetItemAt<bool>({0}, bool_value);
|
||||
Tensor::CreateScalar(bool_value, &expected_tensor);
|
||||
expected_bool.push_back(expected_tensor);
|
||||
ASSERT_EQ(*(bool_output.at(0)) == *(expected_bool.at(0)), true);
|
||||
}
|
||||
|
@ -52,8 +51,7 @@ TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowIntTest) {
|
|||
s = TensorRow::ConvertToTensorRow(int_value, &int_output);
|
||||
ASSERT_EQ(s, Status::OK());
|
||||
std::shared_ptr<Tensor> expected_tensor;
|
||||
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_INT32), &expected_tensor);
|
||||
expected_tensor->SetItemAt({0}, int_value);
|
||||
Tensor::CreateScalar(int_value, &expected_tensor);
|
||||
expected_int.push_back(expected_tensor);
|
||||
ASSERT_EQ(*(int_output.at(0)) == *(expected_int.at(0)), true);
|
||||
}
|
||||
|
@ -67,8 +65,7 @@ TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowFloatTest) {
|
|||
s = TensorRow::ConvertToTensorRow(float_value, &float_output);
|
||||
ASSERT_EQ(s, Status::OK());
|
||||
std::shared_ptr<Tensor> expected_tensor;
|
||||
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_FLOAT32), &expected_tensor);
|
||||
expected_tensor->SetItemAt({0}, float_value);
|
||||
Tensor::CreateScalar(float_value, &expected_tensor);
|
||||
expected_float.push_back(expected_tensor);
|
||||
ASSERT_EQ(*(float_output.at(0)) == *(expected_float.at(0)), true);
|
||||
}
|
||||
|
@ -80,15 +77,10 @@ TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowBoolVectorTest) {
|
|||
s = TensorRow::ConvertToTensorRow(bool_value, &bool_output);
|
||||
ASSERT_EQ(s, Status::OK());
|
||||
TensorRow expected_bool;
|
||||
std::shared_ptr<Tensor> expected_tensor, expected_tensor2;
|
||||
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_BOOL), &expected_tensor);
|
||||
expected_tensor->SetItemAt<bool>({0}, bool_value[0]);
|
||||
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_BOOL), &expected_tensor2);
|
||||
expected_tensor2->SetItemAt<bool>({0}, bool_value[1]);
|
||||
std::shared_ptr<Tensor> expected_tensor;
|
||||
Tensor::CreateFromVector<bool>(bool_value, &expected_tensor);
|
||||
expected_bool.push_back(expected_tensor);
|
||||
expected_bool.push_back(expected_tensor2);
|
||||
ASSERT_EQ(*(bool_output.at(0)) == *(expected_bool.at(0)), true);
|
||||
ASSERT_EQ(*(bool_output.at(1)) == *(expected_bool.at(1)), true);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowIntVectorTest) {
|
||||
|
@ -98,15 +90,10 @@ TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowIntVectorTest) {
|
|||
TensorRow expected_int;
|
||||
s = TensorRow::ConvertToTensorRow(int_value, &int_output);
|
||||
ASSERT_EQ(s, Status::OK());
|
||||
std::shared_ptr<Tensor> expected_tensor, expected_tensor2;
|
||||
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_UINT64), &expected_tensor);
|
||||
expected_tensor->SetItemAt({0}, int_value[0]);
|
||||
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_UINT64), &expected_tensor2);
|
||||
expected_tensor2->SetItemAt({0}, int_value[1]);
|
||||
std::shared_ptr<Tensor> expected_tensor;
|
||||
Tensor::CreateFromVector(int_value, &expected_tensor);
|
||||
expected_int.push_back(expected_tensor);
|
||||
expected_int.push_back(expected_tensor2);
|
||||
ASSERT_EQ(*(int_output.at(0)) == *(expected_int.at(0)), true);
|
||||
ASSERT_EQ(*(int_output.at(1)) == *(expected_int.at(1)), true);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowFloatVectorTest) {
|
||||
|
@ -116,15 +103,10 @@ TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowFloatVectorTest) {
|
|||
TensorRow expected_float;
|
||||
s = TensorRow::ConvertToTensorRow(float_value, &float_output);
|
||||
ASSERT_EQ(s, Status::OK());
|
||||
std::shared_ptr<Tensor> expected_tensor, expected_tensor2;
|
||||
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_FLOAT64), &expected_tensor);
|
||||
expected_tensor->SetItemAt({0}, float_value[0]);
|
||||
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_FLOAT64), &expected_tensor2);
|
||||
expected_tensor2->SetItemAt({0}, float_value[1]);
|
||||
std::shared_ptr<Tensor> expected_tensor;
|
||||
Tensor::CreateFromVector(float_value, &expected_tensor);
|
||||
expected_float.push_back(expected_tensor);
|
||||
expected_float.push_back(expected_tensor2);
|
||||
ASSERT_EQ(*(float_output.at(0)) == *(expected_float.at(0)), true);
|
||||
ASSERT_EQ(*(float_output.at(1)) == *(expected_float.at(1)), true);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowBoolTest) {
|
||||
|
@ -133,8 +115,7 @@ TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowBoolTest) {
|
|||
bool result;
|
||||
TensorRow input_tensor_row;
|
||||
std::shared_ptr<Tensor> input_tensor;
|
||||
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_BOOL), &input_tensor);
|
||||
input_tensor->SetItemAt<bool>({0}, bool_value);
|
||||
Tensor::CreateScalar(bool_value, &input_tensor);
|
||||
input_tensor_row.push_back(input_tensor);
|
||||
s = TensorRow::ConvertFromTensorRow(input_tensor_row, &result);
|
||||
ASSERT_EQ(s, Status::OK());
|
||||
|
@ -147,8 +128,7 @@ TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowIntTest) {
|
|||
int32_t result;
|
||||
TensorRow input_tensor_row;
|
||||
std::shared_ptr<Tensor> input_tensor;
|
||||
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_INT32), &input_tensor);
|
||||
input_tensor->SetItemAt({0}, int_value);
|
||||
Tensor::CreateScalar(int_value, &input_tensor);
|
||||
input_tensor_row.push_back(input_tensor);
|
||||
s = TensorRow::ConvertFromTensorRow(input_tensor_row, &result);
|
||||
ASSERT_EQ(s, Status::OK());
|
||||
|
@ -161,8 +141,7 @@ TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowFloatTest) {
|
|||
float result;
|
||||
TensorRow input_tensor_row;
|
||||
std::shared_ptr<Tensor> input_tensor;
|
||||
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_FLOAT32), &input_tensor);
|
||||
input_tensor->SetItemAt({0}, float_value);
|
||||
Tensor::CreateScalar(float_value, &input_tensor);
|
||||
input_tensor_row.push_back(input_tensor);
|
||||
s = TensorRow::ConvertFromTensorRow(input_tensor_row, &result);
|
||||
ASSERT_EQ(s, Status::OK());
|
||||
|
@ -174,13 +153,9 @@ TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowBoolVectorTest) {
|
|||
std::vector<bool> bool_value = {true, false};
|
||||
std::vector<bool> result;
|
||||
TensorRow input_tensor_row;
|
||||
std::shared_ptr<Tensor> input_tensor1, input_tensor2;
|
||||
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_BOOL), &input_tensor1);
|
||||
input_tensor1->SetItemAt<bool>({0}, bool_value[0]);
|
||||
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_BOOL), &input_tensor2);
|
||||
input_tensor2->SetItemAt<bool>({0}, bool_value[1]);
|
||||
input_tensor_row.push_back(input_tensor1);
|
||||
input_tensor_row.push_back(input_tensor2);
|
||||
std::shared_ptr<Tensor> input_tensor;
|
||||
Tensor::CreateFromVector<bool>(bool_value, &input_tensor);
|
||||
input_tensor_row.push_back(input_tensor);
|
||||
s = TensorRow::ConvertFromTensorRow(input_tensor_row, &result);
|
||||
ASSERT_EQ(s, Status::OK());
|
||||
ASSERT_EQ(result, bool_value);
|
||||
|
@ -191,13 +166,9 @@ TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowIntVectorTest) {
|
|||
std::vector<uint64_t> int_value = {12, 16};
|
||||
std::vector<uint64_t> result;
|
||||
TensorRow input_tensor_row;
|
||||
std::shared_ptr<Tensor> input_tensor1, input_tensor2;
|
||||
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_UINT64), &input_tensor1);
|
||||
input_tensor1->SetItemAt({0}, int_value[0]);
|
||||
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_UINT64), &input_tensor2);
|
||||
input_tensor2->SetItemAt({0}, int_value[1]);
|
||||
input_tensor_row.push_back(input_tensor1);
|
||||
input_tensor_row.push_back(input_tensor2);
|
||||
std::shared_ptr<Tensor> input_tensor;
|
||||
Tensor::CreateFromVector(int_value, &input_tensor);
|
||||
input_tensor_row.push_back(input_tensor);
|
||||
s = TensorRow::ConvertFromTensorRow(input_tensor_row, &result);
|
||||
ASSERT_EQ(s, Status::OK());
|
||||
ASSERT_EQ(result, int_value);
|
||||
|
@ -208,13 +179,9 @@ TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowFloatVectorTest) {
|
|||
std::vector<double> float_value = {12.57, 0.264};
|
||||
std::vector<double> result;
|
||||
TensorRow input_tensor_row;
|
||||
std::shared_ptr<Tensor> input_tensor1, input_tensor2;
|
||||
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_FLOAT64), &input_tensor1);
|
||||
input_tensor1->SetItemAt({0}, float_value[0]);
|
||||
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_FLOAT64), &input_tensor2);
|
||||
input_tensor2->SetItemAt({0}, float_value[1]);
|
||||
input_tensor_row.push_back(input_tensor1);
|
||||
input_tensor_row.push_back(input_tensor2);
|
||||
std::shared_ptr<Tensor> input_tensor;
|
||||
Tensor::CreateFromVector(float_value, &input_tensor);
|
||||
input_tensor_row.push_back(input_tensor);
|
||||
s = TensorRow::ConvertFromTensorRow(input_tensor_row, &result);
|
||||
ASSERT_EQ(s, Status::OK());
|
||||
ASSERT_EQ(result, float_value);
|
||||
|
@ -231,8 +198,7 @@ TEST_F(MindDataTestTensorRowDE, ConvertToTensorRowInvalidDataTest) {
|
|||
TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowTypeMismatchTest) {
|
||||
TensorRow input_tensor_row;
|
||||
std::shared_ptr<Tensor> input_tensor1;
|
||||
Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_BOOL), &input_tensor1);
|
||||
input_tensor1->SetItemAt({0}, false);
|
||||
Tensor::CreateScalar(false, &input_tensor1);
|
||||
input_tensor_row.push_back(input_tensor1);
|
||||
double output;
|
||||
ASSERT_FALSE(TensorRow::ConvertFromTensorRow(input_tensor_row, &output).IsOk());
|
||||
|
@ -243,7 +209,7 @@ TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowTypeMismatchTest) {
|
|||
TEST_F(MindDataTestTensorRowDE, ConvertFromTensorRowInvalidShapeTest) {
|
||||
TensorRow input_tensor_row;
|
||||
std::shared_ptr<Tensor> input_tensor1;
|
||||
Tensor::CreateEmpty(TensorShape({2}), DataType(DataType::DE_FLOAT64), &input_tensor1);
|
||||
Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_FLOAT64), &input_tensor1);
|
||||
input_tensor_row.push_back(input_tensor1);
|
||||
std::vector<double> output;
|
||||
ASSERT_FALSE(TensorRow::ConvertFromTensorRow(input_tensor_row, &output).IsOk());
|
||||
|
|
Loading…
Reference in New Issue