forked from mindspore-Ecosystem/mindspore
!2855 Fix issue with string typecasting for FillOp
Merge pull request !2855 from nhussain/fill_op_typecasting
This commit is contained in:
commit
9991df8676
|
@ -113,22 +113,27 @@ Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *ou
|
|||
}
|
||||
|
||||
Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, std::shared_ptr<Tensor> fill_value) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!((fill_value->type() == DataType::DE_STRING) && (input->type() != DataType::DE_STRING)),
|
||||
const DataType &fill_type = fill_value->type();
|
||||
const DataType &input_type = input->type();
|
||||
const TensorShape &input_shape = input->shape();
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!((fill_type == DataType::DE_STRING) && (input_type != DataType::DE_STRING)),
|
||||
"Types do not match");
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(fill_value->shape() == TensorShape({}), "fill_value is not a scalar");
|
||||
|
||||
std::shared_ptr<Tensor> out;
|
||||
std::shared_ptr<Tensor> out, fill_output;
|
||||
|
||||
const DataType &to = input->type();
|
||||
std::unique_ptr<TypeCastOp> op(new TypeCastOp(to));
|
||||
|
||||
std::shared_ptr<Tensor> fill_output;
|
||||
if (input_type != DataType::DE_STRING && fill_type != DataType::DE_STRING && input_type != fill_type) {
|
||||
std::unique_ptr<TypeCastOp> op(new TypeCastOp(input_type));
|
||||
RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output));
|
||||
} else {
|
||||
fill_output = fill_value;
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input->shape(), input->type()));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input_shape, input_type));
|
||||
|
||||
switch (input->type().value()) {
|
||||
switch (input_type.value()) {
|
||||
case DataType::DE_BOOL: {
|
||||
bool value = 0;
|
||||
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
|
||||
|
@ -206,10 +211,10 @@ Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output
|
|||
std::string_view fill_string_view;
|
||||
RETURN_IF_NOT_OK(fill_value->GetItemAt(&fill_string_view, {}));
|
||||
std::string fill_string = std::string(fill_string_view);
|
||||
for (int i = 0; i < input->shape().NumOfElements(); i++) {
|
||||
for (int i = 0; i < input_shape.NumOfElements(); i++) {
|
||||
strings.emplace_back(fill_string);
|
||||
}
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input->shape()));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input_shape));
|
||||
break;
|
||||
}
|
||||
case DataType::DE_UNKNOWN: {
|
||||
|
|
Loading…
Reference in New Issue