forked from mindspore-Ecosystem/mindspore
commit
97ccb349e6
|
@ -36,7 +36,7 @@ class DETensor : public mindspore::MSTensor::Impl {
|
|||
~DETensor() = default;
|
||||
explicit DETensor(std::shared_ptr<dataset::Tensor> tensor_impl);
|
||||
#ifndef ENABLE_ANDROID
|
||||
explicit DETensor(std::shared_ptr<dataset::DeviceTensor> device_tensor_impl, bool is_device);
|
||||
DETensor(std::shared_ptr<dataset::DeviceTensor> device_tensor_impl, bool is_device);
|
||||
#endif
|
||||
const std::string &Name() const override;
|
||||
|
||||
|
|
|
@ -89,6 +89,13 @@ Status MindDataNode::ValidateParams() {
|
|||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (shuffle_mode_ != ShuffleMode::kFalse && shuffle_mode_ != ShuffleMode::kFiles &&
|
||||
shuffle_mode_ != ShuffleMode::kGlobal && shuffle_mode_ != ShuffleMode::kInfile) {
|
||||
std::string err_msg = "TFRecordNode: Invalid ShuffleMode, check input value of enum.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
std::vector<std::string> dataset_file_vec =
|
||||
search_for_pattern_ ? std::vector<std::string>{dataset_file_} : dataset_files_;
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("MindDataNode", dataset_file_vec));
|
||||
|
|
|
@ -343,11 +343,6 @@ Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
|
|||
const_cast<void *>(reinterpret_cast<const void *>(input->GetBuffer())),
|
||||
GetLiteCVDataType(input->type()));
|
||||
|
||||
std::shared_ptr<Tensor> output_tensor;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), DataType(DataType::DE_FLOAT32), &output_tensor));
|
||||
|
||||
uint8_t *buffer = reinterpret_cast<uint8_t *>(&(*output_tensor->begin<uint8_t>()));
|
||||
|
||||
if (input->type() == DataType::DE_UINT8) {
|
||||
LiteMat lite_mat_float;
|
||||
// change input to float
|
||||
|
@ -359,6 +354,10 @@ Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
|
|||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(ret, "Normalize: normalize failed.");
|
||||
|
||||
std::shared_ptr<Tensor> output_tensor;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(input->shape(), DataType(DataType::DE_FLOAT32),
|
||||
static_cast<uchar *>(lite_mat_norm.data_ptr_), &output_tensor));
|
||||
|
||||
*output = output_tensor;
|
||||
} catch (std::runtime_error &e) {
|
||||
RETURN_STATUS_UNEXPECTED("Normalize: " + std::string(e.what()));
|
||||
|
|
Loading…
Reference in New Issue