From 1aca3f640408f8be92f07112b7a9c71652c16297 Mon Sep 17 00:00:00 2001 From: nhussain Date: Fri, 3 Jul 2020 13:58:02 -0400 Subject: [PATCH] fix unneeded call to typecast op for string --- .../ccsrc/dataset/kernels/data/data_utils.cc | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/mindspore/ccsrc/dataset/kernels/data/data_utils.cc b/mindspore/ccsrc/dataset/kernels/data/data_utils.cc index 40eba1edf6b..8dd5a159394 100644 --- a/mindspore/ccsrc/dataset/kernels/data/data_utils.cc +++ b/mindspore/ccsrc/dataset/kernels/data/data_utils.cc @@ -113,22 +113,27 @@ Status OneHotEncoding(std::shared_ptr input, std::shared_ptr *ou } Status Fill(const std::shared_ptr input, std::shared_ptr *output, std::shared_ptr fill_value) { - CHECK_FAIL_RETURN_UNEXPECTED(!((fill_value->type() == DataType::DE_STRING) && (input->type() != DataType::DE_STRING)), + const DataType &fill_type = fill_value->type(); + const DataType &input_type = input->type(); + const TensorShape &input_shape = input->shape(); + + CHECK_FAIL_RETURN_UNEXPECTED(!((fill_type == DataType::DE_STRING) && (input_type != DataType::DE_STRING)), "Types do not match"); CHECK_FAIL_RETURN_UNEXPECTED(fill_value->shape() == TensorShape({}), "fill_value is not a scalar"); - std::shared_ptr out; + std::shared_ptr out, fill_output; - const DataType &to = input->type(); - std::unique_ptr op(new TypeCastOp(to)); + if (input_type != DataType::DE_STRING && fill_type != DataType::DE_STRING && input_type != fill_type) { + std::unique_ptr op(new TypeCastOp(input_type)); + RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output)); + } else { + fill_output = fill_value; + } - std::shared_ptr fill_output; - RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output)); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input_shape, input_type)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input->shape(), input->type())); - - switch (input->type().value()) { + switch (input_type.value()) { case DataType::DE_BOOL: { bool value = 0; RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); @@ -206,10 +211,10 @@ Status Fill(const std::shared_ptr input, std::shared_ptr *output std::string_view fill_string_view; RETURN_IF_NOT_OK(fill_value->GetItemAt(&fill_string_view, {})); std::string fill_string = std::string(fill_string_view); - for (int i = 0; i < input->shape().NumOfElements(); i++) { + for (int i = 0; i < input_shape.NumOfElements(); i++) { strings.emplace_back(fill_string); } - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input->shape())); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input_shape)); break; } case DataType::DE_UNKNOWN: {