forked from mindspore-Ecosystem/mindspore
Modify normalize to support multiple of channels
This commit is contained in:
parent
da3d4dd1fd
commit
fd943bbdc0
|
@ -665,50 +665,93 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
|
template <typename T>
|
||||||
const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std) {
|
void Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, std::vector<float> mean,
|
||||||
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
|
std::vector<float> std) {
|
||||||
if (!input_cv->mat().data) {
|
auto itr_out = (*output)->begin<float>();
|
||||||
RETURN_STATUS_UNEXPECTED("Normalize: load image failed.");
|
auto itr = input->begin<T>();
|
||||||
}
|
auto end = input->end<T>();
|
||||||
if (input_cv->Rank() != 3) {
|
int64_t num_channels = (*output)->shape()[2];
|
||||||
RETURN_STATUS_UNEXPECTED("Normalize: image shape is not <H,W,C>.");
|
|
||||||
}
|
while (itr != end) {
|
||||||
cv::Mat in_image = input_cv->mat();
|
for (int64_t i = 0; i < num_channels; i++) {
|
||||||
std::shared_ptr<CVTensor> output_cv;
|
*itr_out = static_cast<float>(*itr) / std[i] - mean[i];
|
||||||
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), DataType(DataType::DE_FLOAT32), &output_cv));
|
++itr_out;
|
||||||
mean->Squeeze();
|
++itr;
|
||||||
if (mean->type() != DataType::DE_FLOAT32 || mean->Rank() != 1 || mean->shape()[0] != 3) {
|
|
||||||
std::string err_msg = "Normalize: mean should be of size 3 and type float.";
|
|
||||||
return Status(StatusCode::kMDShapeMisMatch, err_msg);
|
|
||||||
}
|
|
||||||
std->Squeeze();
|
|
||||||
if (std->type() != DataType::DE_FLOAT32 || std->Rank() != 1 || std->shape()[0] != 3) {
|
|
||||||
std::string err_msg = "Normalize: std tensor should be of size 3 and type float.";
|
|
||||||
return Status(StatusCode::kMDShapeMisMatch, err_msg);
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
// NOTE: We are assuming the input image is in RGB and the mean
|
|
||||||
// and std are in RGB
|
|
||||||
std::vector<cv::Mat> rgb;
|
|
||||||
cv::split(in_image, rgb);
|
|
||||||
if (rgb.size() != 3) {
|
|
||||||
RETURN_STATUS_UNEXPECTED("Normalize: input image is not in RGB.");
|
|
||||||
}
|
}
|
||||||
for (uint8_t i = 0; i < 3; i++) {
|
|
||||||
float mean_c, std_c;
|
|
||||||
RETURN_IF_NOT_OK(mean->GetItemAt<float>(&mean_c, {i}));
|
|
||||||
RETURN_IF_NOT_OK(std->GetItemAt<float>(&std_c, {i}));
|
|
||||||
rgb[i].convertTo(rgb[i], CV_32F, 1.0 / std_c, (-mean_c / std_c));
|
|
||||||
}
|
|
||||||
cv::merge(rgb, output_cv->mat());
|
|
||||||
*output = std::static_pointer_cast<Tensor>(output_cv);
|
|
||||||
return Status::OK();
|
|
||||||
} catch (const cv::Exception &e) {
|
|
||||||
RETURN_STATUS_UNEXPECTED("Normalize: " + std::string(e.what()));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, std::vector<float> mean,
|
||||||
|
std::vector<float> std) {
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), DataType(DataType::DE_FLOAT32), output));
|
||||||
|
if (input->Rank() == 2) {
|
||||||
|
RETURN_IF_NOT_OK((*output)->ExpandDim(2));
|
||||||
|
}
|
||||||
|
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED((*output)->Rank() == 3, "Normalize: image shape is not <H,W,C>.");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(std.size() == mean.size(), "Normalize: mean and std vectors are not of same size.");
|
||||||
|
|
||||||
|
// caller provided 1 mean/std value and there are more than one channel --> duplicate mean/std value
|
||||||
|
if (mean.size() == 1 && (*output)->shape()[2] != 1) {
|
||||||
|
std::vector<float> mean_t, std_t;
|
||||||
|
for (int64_t i = 0; i < (*output)->shape()[2] - 1; i++) {
|
||||||
|
mean.push_back(mean[0]);
|
||||||
|
std.push_back(std[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED((*output)->shape()[2] == mean.size(),
|
||||||
|
"Normalize: number of channels does not match the size of mean and std vectors.");
|
||||||
|
|
||||||
|
switch (input->type().value()) {
|
||||||
|
case DataType::DE_BOOL:
|
||||||
|
Normalize<bool>(input, output, mean, std);
|
||||||
|
break;
|
||||||
|
case DataType::DE_INT8:
|
||||||
|
Normalize<int8_t>(input, output, mean, std);
|
||||||
|
break;
|
||||||
|
case DataType::DE_UINT8:
|
||||||
|
Normalize<uint8_t>(input, output, mean, std);
|
||||||
|
break;
|
||||||
|
case DataType::DE_INT16:
|
||||||
|
Normalize<int16_t>(input, output, mean, std);
|
||||||
|
break;
|
||||||
|
case DataType::DE_UINT16:
|
||||||
|
Normalize<uint16_t>(input, output, mean, std);
|
||||||
|
break;
|
||||||
|
case DataType::DE_INT32:
|
||||||
|
Normalize<int32_t>(input, output, mean, std);
|
||||||
|
break;
|
||||||
|
case DataType::DE_UINT32:
|
||||||
|
Normalize<uint32_t>(input, output, mean, std);
|
||||||
|
break;
|
||||||
|
case DataType::DE_INT64:
|
||||||
|
Normalize<int64_t>(input, output, mean, std);
|
||||||
|
break;
|
||||||
|
case DataType::DE_UINT64:
|
||||||
|
Normalize<uint64_t>(input, output, mean, std);
|
||||||
|
break;
|
||||||
|
#ifndef ENABLE_MD_LITE_X86_64
|
||||||
|
case DataType::DE_FLOAT16:
|
||||||
|
Normalize<float16>(input, output, mean, std);
|
||||||
|
break;
|
||||||
|
#endif
|
||||||
|
case DataType::DE_FLOAT32:
|
||||||
|
Normalize<float>(input, output, mean, std);
|
||||||
|
break;
|
||||||
|
case DataType::DE_FLOAT64:
|
||||||
|
Normalize<double>(input, output, mean, std);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
RETURN_STATUS_UNEXPECTED("Normalize: unsupported type.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (input->Rank() == 2) {
|
||||||
|
(*output)->Squeeze();
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status NormalizePad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
|
Status NormalizePad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
|
||||||
const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std, const std::string &dtype) {
|
const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std, const std::string &dtype) {
|
||||||
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
|
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
|
||||||
|
|
|
@ -190,8 +190,8 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
|
||||||
/// \param mean: Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order
|
/// \param mean: Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order
|
||||||
/// \param std: Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order
|
/// \param std: Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order
|
||||||
/// \param output: Normalized image Tensor of same input shape and type DE_FLOAT32
|
/// \param output: Normalized image Tensor of same input shape and type DE_FLOAT32
|
||||||
Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
|
Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, std::vector<float> mean,
|
||||||
const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std);
|
std::vector<float> std);
|
||||||
|
|
||||||
/// \brief Returns Normalized and paded image
|
/// \brief Returns Normalized and paded image
|
||||||
/// \param input: Tensor of shape <H,W,C> in RGB order and any OpenCv compatible type, see CVTensor.
|
/// \param input: Tensor of shape <H,W,C> in RGB order and any OpenCv compatible type, see CVTensor.
|
||||||
|
|
|
@ -747,7 +747,7 @@ bool SubStractMeanNormalize(const LiteMat &src, LiteMat &dst, const std::vector<
|
||||||
uint32_t src_start = (h * src.width_ + w) * src.channel_;
|
uint32_t src_start = (h * src.width_ + w) * src.channel_;
|
||||||
for (int c = 0; c < src.channel_; c++) {
|
for (int c = 0; c < src.channel_; c++) {
|
||||||
uint32_t index = src_start + c;
|
uint32_t index = src_start + c;
|
||||||
dst_start_p[index] = (src_start_p[index] - mean[c]) / std[c];
|
dst_start_p[index] = (src_start_p[index] / std[c]) - mean[c];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -305,8 +305,8 @@ Status GetJpegImageInfo(const std::shared_ptr<Tensor> &input, int *img_width, in
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
|
Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, std::vector<float> vec_mean,
|
||||||
const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std) {
|
std::vector<float> vec_std) {
|
||||||
if (input->Rank() != 3) {
|
if (input->Rank() != 3) {
|
||||||
RETURN_STATUS_UNEXPECTED("Normalize: image shape is not <H,W,C>.");
|
RETURN_STATUS_UNEXPECTED("Normalize: image shape is not <H,W,C>.");
|
||||||
}
|
}
|
||||||
|
@ -315,28 +315,7 @@ Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
|
||||||
RETURN_STATUS_UNEXPECTED("Normalize: image datatype is not uint8 or float32.");
|
RETURN_STATUS_UNEXPECTED("Normalize: image datatype is not uint8 or float32.");
|
||||||
}
|
}
|
||||||
|
|
||||||
mean->Squeeze();
|
|
||||||
if (mean->type() != DataType::DE_FLOAT32 || mean->Rank() != 1 || mean->shape()[0] != 3) {
|
|
||||||
std::string err_msg = "Normalize: mean should be of size 3 and type float.";
|
|
||||||
return Status(StatusCode::kMDShapeMisMatch, err_msg);
|
|
||||||
}
|
|
||||||
std->Squeeze();
|
|
||||||
if (std->type() != DataType::DE_FLOAT32 || std->Rank() != 1 || std->shape()[0] != 3) {
|
|
||||||
std::string err_msg = "Normalize: std should be of size 3 and type float.";
|
|
||||||
return Status(StatusCode::kMDShapeMisMatch, err_msg);
|
|
||||||
}
|
|
||||||
// convert mean, std back to vector
|
|
||||||
std::vector<float> vec_mean;
|
|
||||||
std::vector<float> vec_std;
|
|
||||||
try {
|
try {
|
||||||
for (uint8_t i = 0; i < 3; i++) {
|
|
||||||
float mean_c, std_c;
|
|
||||||
RETURN_IF_NOT_OK(mean->GetItemAt<float>(&mean_c, {i}));
|
|
||||||
RETURN_IF_NOT_OK(std->GetItemAt<float>(&std_c, {i}));
|
|
||||||
vec_mean.push_back(mean_c);
|
|
||||||
vec_std.push_back(std_c);
|
|
||||||
}
|
|
||||||
|
|
||||||
LiteMat lite_mat_norm;
|
LiteMat lite_mat_norm;
|
||||||
bool ret = false;
|
bool ret = false;
|
||||||
LiteMat lite_mat_rgb(input->shape()[1], input->shape()[0], input->shape()[2],
|
LiteMat lite_mat_rgb(input->shape()[1], input->shape()[0], input->shape()[2],
|
||||||
|
|
|
@ -79,8 +79,8 @@ Status GetJpegImageInfo(const std::shared_ptr<Tensor> &input, int *img_width, in
|
||||||
/// \param[in] mean Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order
|
/// \param[in] mean Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order
|
||||||
/// \param[in] std Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order
|
/// \param[in] std Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order
|
||||||
/// \param[out] output Normalized image Tensor of same input shape and type DE_FLOAT32
|
/// \param[out] output Normalized image Tensor of same input shape and type DE_FLOAT32
|
||||||
Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
|
Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, std::vector<float> vec_mean,
|
||||||
const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std);
|
std::vector<float> vec_std);
|
||||||
|
|
||||||
/// \brief Returns Resized image.
|
/// \brief Returns Resized image.
|
||||||
/// \param[in] input
|
/// \param[in] input
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
#include "minddata/dataset/kernels/image/normalize_op.h"
|
#include "minddata/dataset/kernels/image/normalize_op.h"
|
||||||
|
|
||||||
#include <random>
|
#include <random>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
#ifndef ENABLE_ANDROID
|
||||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||||
|
@ -26,14 +27,10 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
NormalizeOp::NormalizeOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b) {
|
NormalizeOp::NormalizeOp(const std::vector<float> &mean, const std::vector<float> &std) : mean_(mean), std_(std) {
|
||||||
Status s = Tensor::CreateFromVector<float>({mean_r, mean_g, mean_b}, &mean_);
|
// pre-calculate normalized mean to be used later in each Compute
|
||||||
if (s.IsError()) {
|
for (int8_t i = 0; i < mean.size(); i++) {
|
||||||
MS_LOG(ERROR) << "Normalize: invalid mean value.";
|
mean_[i] = mean_[i] / std_[i];
|
||||||
}
|
|
||||||
s = Tensor::CreateFromVector<float>({std_r, std_g, std_b}, &std_);
|
|
||||||
if (s.IsError()) {
|
|
||||||
MS_LOG(ERROR) << "Normalize: invalid std value.";
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44,7 +41,15 @@ Status NormalizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_pt
|
||||||
}
|
}
|
||||||
|
|
||||||
void NormalizeOp::Print(std::ostream &out) const {
|
void NormalizeOp::Print(std::ostream &out) const {
|
||||||
out << "NormalizeOp, mean: " << *(mean_.get()) << std::endl << "std: " << *(std_.get()) << std::endl;
|
out << "NormalizeOp, mean: ";
|
||||||
|
for (const auto &m : mean_) {
|
||||||
|
out << m << ", ";
|
||||||
|
}
|
||||||
|
out << "}" << std::endl << "std: ";
|
||||||
|
for (const auto &s : std_) {
|
||||||
|
out << s << ", ";
|
||||||
|
}
|
||||||
|
out << "}" << std::endl;
|
||||||
}
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "minddata/dataset/core/tensor.h"
|
#include "minddata/dataset/core/tensor.h"
|
||||||
#include "minddata/dataset/kernels/tensor_op.h"
|
#include "minddata/dataset/kernels/tensor_op.h"
|
||||||
|
@ -27,7 +28,7 @@ namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
class NormalizeOp : public TensorOp {
|
class NormalizeOp : public TensorOp {
|
||||||
public:
|
public:
|
||||||
NormalizeOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b);
|
NormalizeOp(const std::vector<float> &mean, const std::vector<float> &std);
|
||||||
|
|
||||||
~NormalizeOp() override = default;
|
~NormalizeOp() override = default;
|
||||||
|
|
||||||
|
@ -38,8 +39,8 @@ class NormalizeOp : public TensorOp {
|
||||||
std::string Name() const override { return kNormalizeOp; }
|
std::string Name() const override { return kNormalizeOp; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<Tensor> mean_;
|
std::vector<float> mean_;
|
||||||
std::shared_ptr<Tensor> std_;
|
std::vector<float> std_;
|
||||||
};
|
};
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -73,13 +73,18 @@ Status ValidateVectorColorAttribute(const std::string &op_name, const std::strin
|
||||||
|
|
||||||
Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean,
|
Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean,
|
||||||
const std::vector<float> &std) {
|
const std::vector<float> &std) {
|
||||||
if (mean.size() != 3) {
|
if (mean.size() == 0) {
|
||||||
std::string err_msg = op_name + ": mean expecting size 3, got size: " + std::to_string(mean.size());
|
std::string err_msg = op_name + ": mean expecting non-empty vector";
|
||||||
MS_LOG(ERROR) << err_msg;
|
MS_LOG(ERROR) << err_msg;
|
||||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||||
}
|
}
|
||||||
if (std.size() != 3) {
|
if (std.size() == 0) {
|
||||||
std::string err_msg = op_name + ": std expecting size 3, got size: " + std::to_string(std.size());
|
std::string err_msg = op_name + ": std expecting non-empty vector";
|
||||||
|
MS_LOG(ERROR) << err_msg;
|
||||||
|
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||||
|
}
|
||||||
|
if (mean.size() != std.size()) {
|
||||||
|
std::string err_msg = op_name + ": mean and std vectors are expected to be of the same size";
|
||||||
MS_LOG(ERROR) << err_msg;
|
MS_LOG(ERROR) << err_msg;
|
||||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,9 +38,7 @@ Status NormalizeOperation::ValidateParams() {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<TensorOp> NormalizeOperation::Build() {
|
std::shared_ptr<TensorOp> NormalizeOperation::Build() { return std::make_shared<NormalizeOp>(mean_, std_); }
|
||||||
return std::make_shared<NormalizeOp>(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status NormalizeOperation::to_json(nlohmann::json *out_json) {
|
Status NormalizeOperation::to_json(nlohmann::json *out_json) {
|
||||||
nlohmann::json args;
|
nlohmann::json args;
|
||||||
|
|
|
@ -374,10 +374,6 @@ class Normalize(ImageTensorOperation):
|
||||||
|
|
||||||
@check_normalize_c
|
@check_normalize_c
|
||||||
def __init__(self, mean, std):
|
def __init__(self, mean, std):
|
||||||
if len(mean) == 1:
|
|
||||||
mean = [mean[0]] * 3
|
|
||||||
if len(std) == 1:
|
|
||||||
std = [std[0]] * 3
|
|
||||||
self.mean = mean
|
self.mean = mean
|
||||||
self.std = std
|
self.std = std
|
||||||
|
|
||||||
|
|
|
@ -35,11 +35,11 @@ TEST_F(MindDataTestNormalizeOP, TestOp) {
|
||||||
std::shared_ptr<Tensor> output_tensor;
|
std::shared_ptr<Tensor> output_tensor;
|
||||||
|
|
||||||
// Numbers are from the resnet50 model implementation
|
// Numbers are from the resnet50 model implementation
|
||||||
float mean[3] = {121.0, 115.0, 100.0};
|
std::vector<float> mean = {121.0, 115.0, 100.0};
|
||||||
float std[3] = {70.0, 68.0, 71.0};
|
std::vector<float> std = {70.0, 68.0, 71.0};
|
||||||
|
|
||||||
// Normalize Op
|
// Normalize Op
|
||||||
std::unique_ptr<NormalizeOp> op(new NormalizeOp(mean[0], mean[1], mean[2], std[0], std[1], std[2]));
|
std::unique_ptr<NormalizeOp> op(new NormalizeOp(mean, std));
|
||||||
EXPECT_TRUE(op->OneToOne());
|
EXPECT_TRUE(op->OneToOne());
|
||||||
Status s = op->Compute(input_tensor_, &output_tensor);
|
Status s = op->Compute(input_tensor_, &output_tensor);
|
||||||
EXPECT_TRUE(s.IsOk());
|
EXPECT_TRUE(s.IsOk());
|
||||||
|
|
Binary file not shown.
|
@ -340,6 +340,35 @@ def test_normalize_grayscale_exception():
|
||||||
assert "Input is not within the required range" in str(e)
|
assert "Input is not within the required range" in str(e)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_channels():
|
||||||
|
logger.info("test_multiple_channels")
|
||||||
|
|
||||||
|
def util_test(item, mean, std):
|
||||||
|
data = ds.NumpySlicesDataset([item], shuffle=False)
|
||||||
|
data = data.map(c_vision.Normalize(mean, std))
|
||||||
|
for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
|
||||||
|
actual = d[0]
|
||||||
|
mean = np.array(mean, dtype=item.dtype)
|
||||||
|
std = np.array(std, dtype=item.dtype)
|
||||||
|
expected = item
|
||||||
|
if len(item.shape) != 1 and len(mean) == 1:
|
||||||
|
mean = [mean[0]] * expected.shape[-1]
|
||||||
|
std = [std[0]] * expected.shape[-1]
|
||||||
|
if len(item.shape) == 2:
|
||||||
|
expected = np.expand_dims(expected, 2)
|
||||||
|
for c in range(expected.shape[-1]):
|
||||||
|
expected[:, :, c] = (expected[:, :, c] - mean[c]) / std[c]
|
||||||
|
expected = expected.squeeze()
|
||||||
|
|
||||||
|
np.testing.assert_almost_equal(actual, expected, decimal=6)
|
||||||
|
|
||||||
|
util_test(np.ones(shape=[2, 2, 3]), mean=[0.5, 0.6, 0.7], std=[0.1, 0.2, 0.3])
|
||||||
|
util_test(np.ones(shape=[20, 45, 3]) * 1.3, mean=[0.5, 0.6, 0.7], std=[0.1, 0.2, 0.3])
|
||||||
|
util_test(np.ones(shape=[20, 45, 4]) * 1.3, mean=[0.5, 0.6, 0.7, 0.8], std=[0.1, 0.2, 0.3, 0.4])
|
||||||
|
util_test(np.ones(shape=[2, 2]), mean=[0.5], std=[0.1])
|
||||||
|
util_test(np.ones(shape=[2, 2, 5]), mean=[0.5], std=[0.1])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_decode_op()
|
test_decode_op()
|
||||||
test_decode_normalize_op()
|
test_decode_normalize_op()
|
||||||
|
|
Loading…
Reference in New Issue