forked from mindspore-Ecosystem/mindspore
!3977 [MD] Refactor Concatenate Op
Merge pull request !3977 from nhussain/multi_dim_concat_2
This commit is contained in:
commit
7ec0b5857a
|
@ -526,16 +526,34 @@ Status Tensor::StartAddrOfIndex(std::vector<dsize_t> ind, uchar **start_addr_of_
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Tensor::InsertTensor(const std::vector<dsize_t> &ind, const std::shared_ptr<Tensor> &tensor) {
|
Status Tensor::InsertTensor(const std::vector<dsize_t> &ind, const std::shared_ptr<Tensor> &tensor,
|
||||||
|
const bool partial_insert) {
|
||||||
std::string err_msg;
|
std::string err_msg;
|
||||||
err_msg += (this->type() == DataType::DE_STRING) ? "[Tensor] Cannot batch tensors of type string\n" : "";
|
if (partial_insert) {
|
||||||
err_msg += (!this->shape().known() || !tensor->shape().known()) ? "[Tensor] unknown shape\n" : "";
|
err_msg += (ind.size() != 1)
|
||||||
err_msg += (ind.size() + tensor->Rank() != this->Rank()) ? "[Tensor] incorrect index\n" : "";
|
? "[Tensor] only supports 1D insertion of elements not along the full length of the axis\n"
|
||||||
err_msg += tensor->type().SizeInBytes() != this->type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : "";
|
: "";
|
||||||
|
err_msg +=
|
||||||
|
(ind.at(0) + tensor->shape().NumOfElements() > shape().NumOfElements()) ? "[Tensor] incorrect index\n" : "";
|
||||||
|
} else {
|
||||||
|
err_msg += (ind.size() + tensor->Rank() != Rank()) ? "[Tensor] incorrect index\n" : "";
|
||||||
|
}
|
||||||
|
err_msg += (type() == DataType::DE_STRING) ? "[Tensor] Cannot insert into a tensor of type string\n" : "";
|
||||||
|
err_msg += (!shape().known() || !tensor->shape().known()) ? "[Tensor] unknown shape\n" : "";
|
||||||
|
|
||||||
|
err_msg += tensor->type().SizeInBytes() != type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : "";
|
||||||
uchar *start_addr_of_ind = nullptr;
|
uchar *start_addr_of_ind = nullptr;
|
||||||
|
if (partial_insert) {
|
||||||
|
TensorShape remaining_shape = tensor->shape();
|
||||||
|
err_msg +=
|
||||||
|
(!StartAddrOfIndex(ind, &start_addr_of_ind, &remaining_shape).IsOk()) ? "[Tensor] incorrect index\n" : "";
|
||||||
|
} else {
|
||||||
TensorShape remaining_shape = TensorShape::CreateUnknownRankShape();
|
TensorShape remaining_shape = TensorShape::CreateUnknownRankShape();
|
||||||
err_msg += (!StartAddrOfIndex(ind, &start_addr_of_ind, &remaining_shape).IsOk()) ? "[Tensor] incorrect index\n" : "";
|
err_msg +=
|
||||||
|
(!StartAddrOfIndex(ind, &start_addr_of_ind, &remaining_shape).IsOk()) ? "[Tensor] incorrect index\n" : "";
|
||||||
err_msg += !(remaining_shape == tensor->shape()) ? "[Tensor] memory error\n" : "";
|
err_msg += !(remaining_shape == tensor->shape()) ? "[Tensor] memory error\n" : "";
|
||||||
|
}
|
||||||
|
|
||||||
if (!err_msg.empty()) {
|
if (!err_msg.empty()) {
|
||||||
MS_LOG(DEBUG) << "Insert tensor message: " << err_msg;
|
MS_LOG(DEBUG) << "Insert tensor message: " << err_msg;
|
||||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||||
|
@ -556,39 +574,6 @@ Status Tensor::InsertTensor(const std::vector<dsize_t> &ind, const std::shared_p
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Tensor::Concatenate(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &tensor) {
|
|
||||||
std::string err_msg;
|
|
||||||
err_msg += (index.size() != 1) ? "[Tensor] only supports 1d concatenation \n" : "";
|
|
||||||
err_msg += (type() == DataType::DE_STRING) ? "[Tensor] Cannot batch tensors of type string\n" : "";
|
|
||||||
err_msg += (!shape().known() || !tensor->shape().known()) ? "[Tensor] unknown shape\n" : "";
|
|
||||||
|
|
||||||
err_msg +=
|
|
||||||
(index.at(0) + tensor->shape().NumOfElements() > this->shape().NumOfElements()) ? "[Tensor] incorrect index\n" : "";
|
|
||||||
err_msg += tensor->type().SizeInBytes() != this->type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : "";
|
|
||||||
uchar *start_addr_of_ind = nullptr;
|
|
||||||
|
|
||||||
TensorShape remaining_shape = tensor->shape();
|
|
||||||
StartAddrOfIndex(index, &start_addr_of_ind, &remaining_shape);
|
|
||||||
err_msg += (start_addr_of_ind == nullptr) ? "Failed to create memory for Tensor.\n" : "";
|
|
||||||
|
|
||||||
if (!err_msg.empty()) {
|
|
||||||
MS_LOG(DEBUG) << "Insert tensor message: " << err_msg;
|
|
||||||
|
|
||||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
|
||||||
} else {
|
|
||||||
int ret_code =
|
|
||||||
memcpy_s(start_addr_of_ind, tensor->SizeInBytes(), tensor->GetMutableBuffer(), tensor->SizeInBytes());
|
|
||||||
|
|
||||||
if (ret_code == 0) {
|
|
||||||
return Status::OK();
|
|
||||||
} else {
|
|
||||||
err_msg += "[Tensor] error in memcpy_s when inserting tensor\n";
|
|
||||||
MS_LOG(DEBUG) << "Tensor message: " << err_msg;
|
|
||||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Status Tensor::ExpandDim(const dsize_t &axis) {
|
Status Tensor::ExpandDim(const dsize_t &axis) {
|
||||||
if (axis > Rank()) {
|
if (axis > Rank()) {
|
||||||
std::string err = "Axis is out of bound";
|
std::string err = "Axis is out of bound";
|
||||||
|
|
|
@ -330,8 +330,10 @@ class Tensor {
|
||||||
/// Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell.
|
/// Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell.
|
||||||
/// \param index
|
/// \param index
|
||||||
/// \param input
|
/// \param input
|
||||||
|
/// \param partial_insert: boolean to determine if insertion along the full axis is enforced
|
||||||
/// \return Status code
|
/// \return Status code
|
||||||
Status InsertTensor(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &input);
|
Status InsertTensor(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &input,
|
||||||
|
const bool partial_insert = false);
|
||||||
|
|
||||||
/// Find the address of the given index. Used in InsertTensor.
|
/// Find the address of the given index. Used in InsertTensor.
|
||||||
/// Example:
|
/// Example:
|
||||||
|
@ -393,9 +395,6 @@ class Tensor {
|
||||||
static Status GetBufferInfo(Tensor *t, py::buffer_info *out);
|
static Status GetBufferInfo(Tensor *t, py::buffer_info *out);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/// Concatenate based on given tensor, can fill in current tensor with a smaller one, unlike InsertTensor
|
|
||||||
Status Concatenate(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &input);
|
|
||||||
|
|
||||||
/// TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor
|
/// TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor
|
||||||
/// The order elements is as the memory layout (i.e., row-major) [[1,2,3],[4,5,6] --> 1,2,3,4,5,6
|
/// The order elements is as the memory layout (i.e., row-major) [[1,2,3],[4,5,6] --> 1,2,3,4,5,6
|
||||||
/// \tparam T type of values in the Tensor Iterator
|
/// \tparam T type of values in the Tensor Iterator
|
||||||
|
|
|
@ -330,8 +330,10 @@ class Tensor {
|
||||||
/// Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell.
|
/// Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell.
|
||||||
/// \param index
|
/// \param index
|
||||||
/// \param input
|
/// \param input
|
||||||
|
/// \param partial_insert: boolean to determine if insertion along the full axis is enforced
|
||||||
/// \return Status code
|
/// \return Status code
|
||||||
Status InsertTensor(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &input);
|
Status InsertTensor(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &input,
|
||||||
|
const bool partial_insert = false);
|
||||||
|
|
||||||
/// Find the address of the given index. Used in InsertTensor.
|
/// Find the address of the given index. Used in InsertTensor.
|
||||||
/// Example:
|
/// Example:
|
||||||
|
@ -393,9 +395,6 @@ class Tensor {
|
||||||
static Status GetBufferInfo(Tensor *t, py::buffer_info *out);
|
static Status GetBufferInfo(Tensor *t, py::buffer_info *out);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/// Concatenate based on given tensor, can fill in current tensor with a smaller one, unlike InsertTensor
|
|
||||||
Status Concatenate(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &input);
|
|
||||||
|
|
||||||
/// TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor
|
/// TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor
|
||||||
/// The order elements is as the memory layout (i.e., row-major) [[1,2,3],[4,5,6] --> 1,2,3,4,5,6
|
/// The order elements is as the memory layout (i.e., row-major) [[1,2,3],[4,5,6] --> 1,2,3,4,5,6
|
||||||
/// \tparam T type of values in the Tensor Iterator
|
/// \tparam T type of values in the Tensor Iterator
|
||||||
|
|
|
@ -580,77 +580,73 @@ Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu
|
||||||
|
|
||||||
Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend,
|
Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend,
|
||||||
std::shared_ptr<Tensor> append) {
|
std::shared_ptr<Tensor> append) {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Rank() == 1, "Only 1D tensors supported");
|
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(axis == 0 || axis == -1, "Only concatenation along the last dimension supported");
|
|
||||||
|
|
||||||
axis = Tensor::HandleNeg(axis, input[0]->shape().Rank());
|
axis = Tensor::HandleNeg(axis, input[0]->shape().Rank());
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(axis == 0, "Only axis=0 is supported");
|
CHECK_FAIL_RETURN_UNEXPECTED(axis == 0, "Only axis=0 is supported");
|
||||||
|
|
||||||
std::shared_ptr<Tensor> out;
|
TensorShape t = TensorShape::CreateScalar();
|
||||||
|
|
||||||
|
DataType first_dtype = input[0]->type();
|
||||||
|
|
||||||
|
TensorRow tensor_list;
|
||||||
|
|
||||||
if (prepend != nullptr) {
|
if (prepend != nullptr) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(first_dtype == prepend->type(), "Tensor types do not match");
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(prepend->shape().Rank() == 1, "Only 1D tensors supported");
|
CHECK_FAIL_RETURN_UNEXPECTED(prepend->shape().Rank() == 1, "Only 1D tensors supported");
|
||||||
RETURN_IF_NOT_OK(ConcatenateHelper(prepend, &out, axis, input[0]));
|
tensor_list.emplace_back(prepend);
|
||||||
} else {
|
|
||||||
out = input[0];
|
|
||||||
}
|
}
|
||||||
for (dsize_t i = 1; i < input.size(); i++) {
|
|
||||||
std::shared_ptr<Tensor> out_t;
|
for (dsize_t i = 0; i < input.size(); i++) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(first_dtype == input[i]->type(), "Tensor types do not match");
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(input[i]->shape().Rank() == 1, "Only 1D tensors supported");
|
CHECK_FAIL_RETURN_UNEXPECTED(input[i]->shape().Rank() == 1, "Only 1D tensors supported");
|
||||||
RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, input[i]));
|
tensor_list.emplace_back(input[i]);
|
||||||
out = out_t;
|
|
||||||
}
|
}
|
||||||
std::shared_ptr<Tensor> out_t;
|
|
||||||
if (append != nullptr) {
|
if (append != nullptr) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(first_dtype == append->type(), "Tensor types do not match");
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(append->shape().Rank() == 1, "Only 1D tensors supported");
|
CHECK_FAIL_RETURN_UNEXPECTED(append->shape().Rank() == 1, "Only 1D tensors supported");
|
||||||
RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, append));
|
tensor_list.emplace_back(append);
|
||||||
} else {
|
|
||||||
out_t = out;
|
|
||||||
}
|
}
|
||||||
output->push_back(out_t);
|
|
||||||
|
|
||||||
return Status::OK();
|
// create final shape
|
||||||
}
|
for (dsize_t i = 0; i < tensor_list[0]->shape().Rank(); i++) {
|
||||||
|
|
||||||
Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int8_t axis,
|
|
||||||
std::shared_ptr<Tensor> append) {
|
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == append->type(), "Tensor types do not match");
|
|
||||||
|
|
||||||
TensorShape t({});
|
|
||||||
|
|
||||||
for (dsize_t i = 0; i < input->shape().Rank(); i++) {
|
|
||||||
if (i != axis) {
|
if (i != axis) {
|
||||||
t = t.AppendDim(input->shape()[i]);
|
t = t.AppendDim(tensor_list[0]->shape()[i]);
|
||||||
} else {
|
} else {
|
||||||
dsize_t new_shape = input->shape()[i] + append->shape()[i];
|
dsize_t new_shape = 0;
|
||||||
|
for (dsize_t j = 0; j < tensor_list.size(); j++) {
|
||||||
|
new_shape = tensor_list[j]->shape()[i] + new_shape;
|
||||||
|
}
|
||||||
t = t.AppendDim(new_shape);
|
t = t.AppendDim(new_shape);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Tensor> out;
|
std::shared_ptr<Tensor> out;
|
||||||
|
|
||||||
if (input->type().IsNumeric()) {
|
if (input[0]->type().IsNumeric()) {
|
||||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(t, input->type(), &out));
|
RETURN_IF_NOT_OK(Tensor::CreateEmpty(t, tensor_list[0]->type(), &out));
|
||||||
|
std::vector<dsize_t> index(axis + 1, 0);
|
||||||
|
|
||||||
RETURN_IF_NOT_OK(out->Concatenate({0}, input));
|
int n = index.size() - 1;
|
||||||
RETURN_IF_NOT_OK(out->Concatenate({input->shape()[0]}, append));
|
for (dsize_t i = 0; i < tensor_list.size(); i++) {
|
||||||
*output = out;
|
RETURN_IF_NOT_OK(out->InsertTensor({index}, tensor_list[i], true));
|
||||||
|
index[n] = index[n] + tensor_list[i]->shape()[axis];
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
std::vector<std::string> strings;
|
std::vector<std::string> strings;
|
||||||
|
|
||||||
auto itr = input->begin<std::string_view>();
|
for (dsize_t i = 0; i < tensor_list.size(); i++) {
|
||||||
for (; itr != input->end<std::string_view>(); itr++) {
|
auto itr = tensor_list[i]->begin<std::string_view>();
|
||||||
|
for (; itr != tensor_list[i]->end<std::string_view>(); itr++) {
|
||||||
strings.emplace_back(*itr);
|
strings.emplace_back(*itr);
|
||||||
}
|
}
|
||||||
itr = append->begin<std::string_view>();
|
|
||||||
for (; itr != append->end<std::string_view>(); itr++) {
|
|
||||||
strings.emplace_back(*itr);
|
|
||||||
}
|
}
|
||||||
RETURN_IF_NOT_OK(Tensor::CreateFromVector(strings, t, &out));
|
RETURN_IF_NOT_OK(Tensor::CreateFromVector(strings, t, &out));
|
||||||
|
|
||||||
*output = out;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
output->push_back(out);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -152,11 +152,6 @@ Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu
|
||||||
|
|
||||||
Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend,
|
Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend,
|
||||||
std::shared_ptr<Tensor> append);
|
std::shared_ptr<Tensor> append);
|
||||||
|
|
||||||
// helper for concat, always append to the input, and pass that to the output
|
|
||||||
Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int8_t axis,
|
|
||||||
std::shared_ptr<Tensor> append);
|
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -28,9 +28,8 @@ class MindDataTestConcatenateOp : public UT::Common {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(MindDataTestConcatenateOp, TestOp) {
|
TEST_F(MindDataTestConcatenateOp, TestOp) {
|
||||||
MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp.";
|
MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp-SingleRowinput.";
|
||||||
std::vector<uint64_t> labels = {1, 1, 2};
|
std::vector<uint64_t> labels = {1, 1, 2};
|
||||||
TensorShape shape({3});
|
|
||||||
std::shared_ptr<Tensor> input;
|
std::shared_ptr<Tensor> input;
|
||||||
Tensor::CreateFromVector(labels, &input);
|
Tensor::CreateFromVector(labels, &input);
|
||||||
|
|
||||||
|
@ -57,3 +56,71 @@ TEST_F(MindDataTestConcatenateOp, TestOp) {
|
||||||
MS_LOG(DEBUG) << *expected << std::endl;
|
MS_LOG(DEBUG) << *expected << std::endl;
|
||||||
ASSERT_TRUE(*output == *expected);
|
ASSERT_TRUE(*output == *expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestConcatenateOp, TestOp2) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp2-MultiInput.";
|
||||||
|
std::vector<uint64_t> labels = {1, 12, 2};
|
||||||
|
std::shared_ptr<Tensor> row_1;
|
||||||
|
Tensor::CreateFromVector(labels, &row_1);
|
||||||
|
|
||||||
|
std::shared_ptr<Tensor> row_2;
|
||||||
|
Tensor::CreateFromVector(labels, &row_2);
|
||||||
|
|
||||||
|
std::vector<uint64_t> append_labels = {4, 4, 4};
|
||||||
|
std::shared_ptr<Tensor> append;
|
||||||
|
Tensor::CreateFromVector(append_labels, &append);
|
||||||
|
|
||||||
|
TensorRow tensor_list;
|
||||||
|
tensor_list.push_back(row_1);
|
||||||
|
tensor_list.push_back(row_2);
|
||||||
|
|
||||||
|
std::shared_ptr<Tensor> output;
|
||||||
|
std::unique_ptr<ConcatenateOp> op(new ConcatenateOp(0, nullptr, append));
|
||||||
|
|
||||||
|
TensorRow out_row;
|
||||||
|
Status s = op->Compute(tensor_list, &out_row);
|
||||||
|
std::vector<uint64_t> out = {1, 12, 2, 1, 12, 2, 4, 4, 4};
|
||||||
|
|
||||||
|
std::shared_ptr<Tensor> expected;
|
||||||
|
Tensor::CreateFromVector(out, &expected);
|
||||||
|
|
||||||
|
output = out_row[0];
|
||||||
|
EXPECT_TRUE(s.IsOk());
|
||||||
|
ASSERT_TRUE(output->shape() == expected->shape());
|
||||||
|
ASSERT_TRUE(output->type() == expected->type());
|
||||||
|
MS_LOG(DEBUG) << *output << std::endl;
|
||||||
|
MS_LOG(DEBUG) << *expected << std::endl;
|
||||||
|
ASSERT_TRUE(*output == *expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestConcatenateOp, TestOp3) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp3-Strings.";
|
||||||
|
std::vector<std::string> labels = {"hello", "bye"};
|
||||||
|
std::shared_ptr<Tensor> row_1;
|
||||||
|
Tensor::CreateFromVector(labels, &row_1);
|
||||||
|
|
||||||
|
std::vector<std::string> append_labels = {"1", "2", "3"};
|
||||||
|
std::shared_ptr<Tensor> append;
|
||||||
|
Tensor::CreateFromVector(append_labels, &append);
|
||||||
|
|
||||||
|
TensorRow tensor_list;
|
||||||
|
tensor_list.push_back(row_1);
|
||||||
|
|
||||||
|
std::shared_ptr<Tensor> output;
|
||||||
|
std::unique_ptr<ConcatenateOp> op(new ConcatenateOp(0, nullptr, append));
|
||||||
|
|
||||||
|
TensorRow out_row;
|
||||||
|
Status s = op->Compute(tensor_list, &out_row);
|
||||||
|
std::vector<std::string> out = {"hello", "bye", "1", "2", "3"};
|
||||||
|
|
||||||
|
std::shared_ptr<Tensor> expected;
|
||||||
|
Tensor::CreateFromVector(out, &expected);
|
||||||
|
|
||||||
|
output = out_row[0];
|
||||||
|
EXPECT_TRUE(s.IsOk());
|
||||||
|
ASSERT_TRUE(output->shape() == expected->shape());
|
||||||
|
ASSERT_TRUE(output->type() == expected->type());
|
||||||
|
MS_LOG(DEBUG) << *output << std::endl;
|
||||||
|
MS_LOG(DEBUG) << *expected << std::endl;
|
||||||
|
ASSERT_TRUE(*output == *expected);
|
||||||
|
}
|
||||||
|
|
|
@ -432,7 +432,7 @@ TEST_F(MindDataTestTensorDE, TensorSlice) {
|
||||||
ASSERT_EQ(*t2, *t);
|
ASSERT_EQ(*t2, *t);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(MindDataTestTensorDE, TensorConcatenate) {
|
TEST_F(MindDataTestTensorDE, TensorPartialInsert) {
|
||||||
std::vector<uint32_t> values1 = {1, 2, 3, 0, 0, 0};
|
std::vector<uint32_t> values1 = {1, 2, 3, 0, 0, 0};
|
||||||
std::vector<uint32_t> values2 = {4, 5, 6};
|
std::vector<uint32_t> values2 = {4, 5, 6};
|
||||||
std::vector<uint32_t> expected = {1, 2, 3, 4, 5, 6};
|
std::vector<uint32_t> expected = {1, 2, 3, 4, 5, 6};
|
||||||
|
@ -445,7 +445,7 @@ TEST_F(MindDataTestTensorDE, TensorConcatenate) {
|
||||||
|
|
||||||
std::shared_ptr<Tensor> out;
|
std::shared_ptr<Tensor> out;
|
||||||
Tensor::CreateFromVector(expected, &out);
|
Tensor::CreateFromVector(expected, &out);
|
||||||
Status s = t1->Concatenate({3}, t2);
|
Status s = t1->InsertTensor({3}, t2, true);
|
||||||
EXPECT_TRUE(s.IsOk());
|
EXPECT_TRUE(s.IsOk());
|
||||||
|
|
||||||
auto i = out->begin<uint32_t>();
|
auto i = out->begin<uint32_t>();
|
||||||
|
@ -455,7 +455,7 @@ TEST_F(MindDataTestTensorDE, TensorConcatenate) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// should fail if the concatenated vector is too large
|
// should fail if the concatenated vector is too large
|
||||||
s = t1->Concatenate({5}, t2);
|
s = t1->InsertTensor({5}, t2, true);
|
||||||
EXPECT_FALSE(s.IsOk());
|
EXPECT_FALSE(s.IsOk());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -130,7 +130,7 @@ def test_concatenate_op_incorrect_dim():
|
||||||
def gen():
|
def gen():
|
||||||
yield (np.array([["ss", "ad"], ["ss", "ad"]], dtype='S'),)
|
yield (np.array([["ss", "ad"], ["ss", "ad"]], dtype='S'),)
|
||||||
|
|
||||||
prepend_tensor = np.array([3, 5], dtype=np.float)
|
prepend_tensor = np.array(["ss", "ss"], dtype='S')
|
||||||
concatenate_op = data_trans.Concatenate(0, prepend_tensor)
|
concatenate_op = data_trans.Concatenate(0, prepend_tensor)
|
||||||
data = ds.GeneratorDataset(gen, column_names=["col"])
|
data = ds.GeneratorDataset(gen, column_names=["col"])
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue