Modify normalize to support multiple of channels

This commit is contained in:
hesham 2021-05-07 12:31:33 -04:00
parent da3d4dd1fd
commit fd943bbdc0
13 changed files with 150 additions and 94 deletions

View File

@ -665,50 +665,93 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
return Status::OK(); return Status::OK();
} }
Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, template <typename T>
const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std) { void Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, std::vector<float> mean,
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input); std::vector<float> std) {
if (!input_cv->mat().data) { auto itr_out = (*output)->begin<float>();
RETURN_STATUS_UNEXPECTED("Normalize: load image failed."); auto itr = input->begin<T>();
} auto end = input->end<T>();
if (input_cv->Rank() != 3) { int64_t num_channels = (*output)->shape()[2];
RETURN_STATUS_UNEXPECTED("Normalize: image shape is not <H,W,C>.");
} while (itr != end) {
cv::Mat in_image = input_cv->mat(); for (int64_t i = 0; i < num_channels; i++) {
std::shared_ptr<CVTensor> output_cv; *itr_out = static_cast<float>(*itr) / std[i] - mean[i];
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), DataType(DataType::DE_FLOAT32), &output_cv)); ++itr_out;
mean->Squeeze(); ++itr;
if (mean->type() != DataType::DE_FLOAT32 || mean->Rank() != 1 || mean->shape()[0] != 3) {
std::string err_msg = "Normalize: mean should be of size 3 and type float.";
return Status(StatusCode::kMDShapeMisMatch, err_msg);
}
std->Squeeze();
if (std->type() != DataType::DE_FLOAT32 || std->Rank() != 1 || std->shape()[0] != 3) {
std::string err_msg = "Normalize: std tensor should be of size 3 and type float.";
return Status(StatusCode::kMDShapeMisMatch, err_msg);
}
try {
// NOTE: We are assuming the input image is in RGB and the mean
// and std are in RGB
std::vector<cv::Mat> rgb;
cv::split(in_image, rgb);
if (rgb.size() != 3) {
RETURN_STATUS_UNEXPECTED("Normalize: input image is not in RGB.");
} }
for (uint8_t i = 0; i < 3; i++) {
float mean_c, std_c;
RETURN_IF_NOT_OK(mean->GetItemAt<float>(&mean_c, {i}));
RETURN_IF_NOT_OK(std->GetItemAt<float>(&std_c, {i}));
rgb[i].convertTo(rgb[i], CV_32F, 1.0 / std_c, (-mean_c / std_c));
}
cv::merge(rgb, output_cv->mat());
*output = std::static_pointer_cast<Tensor>(output_cv);
return Status::OK();
} catch (const cv::Exception &e) {
RETURN_STATUS_UNEXPECTED("Normalize: " + std::string(e.what()));
} }
} }
Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, std::vector<float> mean,
std::vector<float> std) {
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), DataType(DataType::DE_FLOAT32), output));
if (input->Rank() == 2) {
RETURN_IF_NOT_OK((*output)->ExpandDim(2));
}
CHECK_FAIL_RETURN_UNEXPECTED((*output)->Rank() == 3, "Normalize: image shape is not <H,W,C>.");
CHECK_FAIL_RETURN_UNEXPECTED(std.size() == mean.size(), "Normalize: mean and std vectors are not of same size.");
// caller provided 1 mean/std value and there are more than one channel --> duplicate mean/std value
if (mean.size() == 1 && (*output)->shape()[2] != 1) {
std::vector<float> mean_t, std_t;
for (int64_t i = 0; i < (*output)->shape()[2] - 1; i++) {
mean.push_back(mean[0]);
std.push_back(std[0]);
}
}
CHECK_FAIL_RETURN_UNEXPECTED((*output)->shape()[2] == mean.size(),
"Normalize: number of channels does not match the size of mean and std vectors.");
switch (input->type().value()) {
case DataType::DE_BOOL:
Normalize<bool>(input, output, mean, std);
break;
case DataType::DE_INT8:
Normalize<int8_t>(input, output, mean, std);
break;
case DataType::DE_UINT8:
Normalize<uint8_t>(input, output, mean, std);
break;
case DataType::DE_INT16:
Normalize<int16_t>(input, output, mean, std);
break;
case DataType::DE_UINT16:
Normalize<uint16_t>(input, output, mean, std);
break;
case DataType::DE_INT32:
Normalize<int32_t>(input, output, mean, std);
break;
case DataType::DE_UINT32:
Normalize<uint32_t>(input, output, mean, std);
break;
case DataType::DE_INT64:
Normalize<int64_t>(input, output, mean, std);
break;
case DataType::DE_UINT64:
Normalize<uint64_t>(input, output, mean, std);
break;
#ifndef ENABLE_MD_LITE_X86_64
case DataType::DE_FLOAT16:
Normalize<float16>(input, output, mean, std);
break;
#endif
case DataType::DE_FLOAT32:
Normalize<float>(input, output, mean, std);
break;
case DataType::DE_FLOAT64:
Normalize<double>(input, output, mean, std);
break;
default:
RETURN_STATUS_UNEXPECTED("Normalize: unsupported type.");
}
if (input->Rank() == 2) {
(*output)->Squeeze();
}
return Status::OK();
}
Status NormalizePad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, Status NormalizePad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std, const std::string &dtype) { const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std, const std::string &dtype) {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input); std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);

View File

@ -190,8 +190,8 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
/// \param mean: Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order /// \param mean: Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order
/// \param std: Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order /// \param std: Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order
/// \param output: Normalized image Tensor of same input shape and type DE_FLOAT32 /// \param output: Normalized image Tensor of same input shape and type DE_FLOAT32
Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, std::vector<float> mean,
const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std); std::vector<float> std);
/// \brief Returns Normalized and paded image /// \brief Returns Normalized and paded image
/// \param input: Tensor of shape <H,W,C> in RGB order and any OpenCv compatible type, see CVTensor. /// \param input: Tensor of shape <H,W,C> in RGB order and any OpenCv compatible type, see CVTensor.

View File

@ -747,7 +747,7 @@ bool SubStractMeanNormalize(const LiteMat &src, LiteMat &dst, const std::vector<
uint32_t src_start = (h * src.width_ + w) * src.channel_; uint32_t src_start = (h * src.width_ + w) * src.channel_;
for (int c = 0; c < src.channel_; c++) { for (int c = 0; c < src.channel_; c++) {
uint32_t index = src_start + c; uint32_t index = src_start + c;
dst_start_p[index] = (src_start_p[index] - mean[c]) / std[c]; dst_start_p[index] = (src_start_p[index] / std[c]) - mean[c];
} }
} }
} }

View File

@ -305,8 +305,8 @@ Status GetJpegImageInfo(const std::shared_ptr<Tensor> &input, int *img_width, in
return Status::OK(); return Status::OK();
} }
Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, std::vector<float> vec_mean,
const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std) { std::vector<float> vec_std) {
if (input->Rank() != 3) { if (input->Rank() != 3) {
RETURN_STATUS_UNEXPECTED("Normalize: image shape is not <H,W,C>."); RETURN_STATUS_UNEXPECTED("Normalize: image shape is not <H,W,C>.");
} }
@ -315,28 +315,7 @@ Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
RETURN_STATUS_UNEXPECTED("Normalize: image datatype is not uint8 or float32."); RETURN_STATUS_UNEXPECTED("Normalize: image datatype is not uint8 or float32.");
} }
mean->Squeeze();
if (mean->type() != DataType::DE_FLOAT32 || mean->Rank() != 1 || mean->shape()[0] != 3) {
std::string err_msg = "Normalize: mean should be of size 3 and type float.";
return Status(StatusCode::kMDShapeMisMatch, err_msg);
}
std->Squeeze();
if (std->type() != DataType::DE_FLOAT32 || std->Rank() != 1 || std->shape()[0] != 3) {
std::string err_msg = "Normalize: std should be of size 3 and type float.";
return Status(StatusCode::kMDShapeMisMatch, err_msg);
}
// convert mean, std back to vector
std::vector<float> vec_mean;
std::vector<float> vec_std;
try { try {
for (uint8_t i = 0; i < 3; i++) {
float mean_c, std_c;
RETURN_IF_NOT_OK(mean->GetItemAt<float>(&mean_c, {i}));
RETURN_IF_NOT_OK(std->GetItemAt<float>(&std_c, {i}));
vec_mean.push_back(mean_c);
vec_std.push_back(std_c);
}
LiteMat lite_mat_norm; LiteMat lite_mat_norm;
bool ret = false; bool ret = false;
LiteMat lite_mat_rgb(input->shape()[1], input->shape()[0], input->shape()[2], LiteMat lite_mat_rgb(input->shape()[1], input->shape()[0], input->shape()[2],

View File

@ -79,8 +79,8 @@ Status GetJpegImageInfo(const std::shared_ptr<Tensor> &input, int *img_width, in
/// \param[in] mean Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order /// \param[in] mean Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order
/// \param[in] std Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order /// \param[in] std Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order
/// \param[out] output Normalized image Tensor of same input shape and type DE_FLOAT32 /// \param[out] output Normalized image Tensor of same input shape and type DE_FLOAT32
Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, std::vector<float> vec_mean,
const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std); std::vector<float> vec_std);
/// \brief Returns Resized image. /// \brief Returns Resized image.
/// \param[in] input /// \param[in] input

View File

@ -16,6 +16,7 @@
#include "minddata/dataset/kernels/image/normalize_op.h" #include "minddata/dataset/kernels/image/normalize_op.h"
#include <random> #include <random>
#include <vector>
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/image_utils.h" #include "minddata/dataset/kernels/image/image_utils.h"
@ -26,14 +27,10 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
NormalizeOp::NormalizeOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b) { NormalizeOp::NormalizeOp(const std::vector<float> &mean, const std::vector<float> &std) : mean_(mean), std_(std) {
Status s = Tensor::CreateFromVector<float>({mean_r, mean_g, mean_b}, &mean_); // pre-calculate normalized mean to be used later in each Compute
if (s.IsError()) { for (int8_t i = 0; i < mean.size(); i++) {
MS_LOG(ERROR) << "Normalize: invalid mean value."; mean_[i] = mean_[i] / std_[i];
}
s = Tensor::CreateFromVector<float>({std_r, std_g, std_b}, &std_);
if (s.IsError()) {
MS_LOG(ERROR) << "Normalize: invalid std value.";
} }
} }
@ -44,7 +41,15 @@ Status NormalizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_pt
} }
void NormalizeOp::Print(std::ostream &out) const { void NormalizeOp::Print(std::ostream &out) const {
out << "NormalizeOp, mean: " << *(mean_.get()) << std::endl << "std: " << *(std_.get()) << std::endl; out << "NormalizeOp, mean: ";
for (const auto &m : mean_) {
out << m << ", ";
}
out << "}" << std::endl << "std: ";
for (const auto &s : std_) {
out << s << ", ";
}
out << "}" << std::endl;
} }
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector>
#include "minddata/dataset/core/tensor.h" #include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h" #include "minddata/dataset/kernels/tensor_op.h"
@ -27,7 +28,7 @@ namespace mindspore {
namespace dataset { namespace dataset {
class NormalizeOp : public TensorOp { class NormalizeOp : public TensorOp {
public: public:
NormalizeOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b); NormalizeOp(const std::vector<float> &mean, const std::vector<float> &std);
~NormalizeOp() override = default; ~NormalizeOp() override = default;
@ -38,8 +39,8 @@ class NormalizeOp : public TensorOp {
std::string Name() const override { return kNormalizeOp; } std::string Name() const override { return kNormalizeOp; }
private: private:
std::shared_ptr<Tensor> mean_; std::vector<float> mean_;
std::shared_ptr<Tensor> std_; std::vector<float> std_;
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -73,13 +73,18 @@ Status ValidateVectorColorAttribute(const std::string &op_name, const std::strin
Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean, Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean,
const std::vector<float> &std) { const std::vector<float> &std) {
if (mean.size() != 3) { if (mean.size() == 0) {
std::string err_msg = op_name + ": mean expecting size 3, got size: " + std::to_string(mean.size()); std::string err_msg = op_name + ": mean expecting non-empty vector";
MS_LOG(ERROR) << err_msg; MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg); RETURN_STATUS_SYNTAX_ERROR(err_msg);
} }
if (std.size() != 3) { if (std.size() == 0) {
std::string err_msg = op_name + ": std expecting size 3, got size: " + std::to_string(std.size()); std::string err_msg = op_name + ": std expecting non-empty vector";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (mean.size() != std.size()) {
std::string err_msg = op_name + ": mean and std vectors are expected to be of the same size";
MS_LOG(ERROR) << err_msg; MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg); RETURN_STATUS_SYNTAX_ERROR(err_msg);
} }

View File

@ -38,9 +38,7 @@ Status NormalizeOperation::ValidateParams() {
return Status::OK(); return Status::OK();
} }
std::shared_ptr<TensorOp> NormalizeOperation::Build() { std::shared_ptr<TensorOp> NormalizeOperation::Build() { return std::make_shared<NormalizeOp>(mean_, std_); }
return std::make_shared<NormalizeOp>(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]);
}
Status NormalizeOperation::to_json(nlohmann::json *out_json) { Status NormalizeOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args; nlohmann::json args;

View File

@ -374,10 +374,6 @@ class Normalize(ImageTensorOperation):
@check_normalize_c @check_normalize_c
def __init__(self, mean, std): def __init__(self, mean, std):
if len(mean) == 1:
mean = [mean[0]] * 3
if len(std) == 1:
std = [std[0]] * 3
self.mean = mean self.mean = mean
self.std = std self.std = std

View File

@ -35,11 +35,11 @@ TEST_F(MindDataTestNormalizeOP, TestOp) {
std::shared_ptr<Tensor> output_tensor; std::shared_ptr<Tensor> output_tensor;
// Numbers are from the resnet50 model implementation // Numbers are from the resnet50 model implementation
float mean[3] = {121.0, 115.0, 100.0}; std::vector<float> mean = {121.0, 115.0, 100.0};
float std[3] = {70.0, 68.0, 71.0}; std::vector<float> std = {70.0, 68.0, 71.0};
// Normalize Op // Normalize Op
std::unique_ptr<NormalizeOp> op(new NormalizeOp(mean[0], mean[1], mean[2], std[0], std[1], std[2])); std::unique_ptr<NormalizeOp> op(new NormalizeOp(mean, std));
EXPECT_TRUE(op->OneToOne()); EXPECT_TRUE(op->OneToOne());
Status s = op->Compute(input_tensor_, &output_tensor); Status s = op->Compute(input_tensor_, &output_tensor);
EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(s.IsOk());

View File

@ -340,6 +340,35 @@ def test_normalize_grayscale_exception():
assert "Input is not within the required range" in str(e) assert "Input is not within the required range" in str(e)
def test_multiple_channels():
logger.info("test_multiple_channels")
def util_test(item, mean, std):
data = ds.NumpySlicesDataset([item], shuffle=False)
data = data.map(c_vision.Normalize(mean, std))
for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
actual = d[0]
mean = np.array(mean, dtype=item.dtype)
std = np.array(std, dtype=item.dtype)
expected = item
if len(item.shape) != 1 and len(mean) == 1:
mean = [mean[0]] * expected.shape[-1]
std = [std[0]] * expected.shape[-1]
if len(item.shape) == 2:
expected = np.expand_dims(expected, 2)
for c in range(expected.shape[-1]):
expected[:, :, c] = (expected[:, :, c] - mean[c]) / std[c]
expected = expected.squeeze()
np.testing.assert_almost_equal(actual, expected, decimal=6)
util_test(np.ones(shape=[2, 2, 3]), mean=[0.5, 0.6, 0.7], std=[0.1, 0.2, 0.3])
util_test(np.ones(shape=[20, 45, 3]) * 1.3, mean=[0.5, 0.6, 0.7], std=[0.1, 0.2, 0.3])
util_test(np.ones(shape=[20, 45, 4]) * 1.3, mean=[0.5, 0.6, 0.7, 0.8], std=[0.1, 0.2, 0.3, 0.4])
util_test(np.ones(shape=[2, 2]), mean=[0.5], std=[0.1])
util_test(np.ones(shape=[2, 2, 5]), mean=[0.5], std=[0.1])
if __name__ == "__main__": if __name__ == "__main__":
test_decode_op() test_decode_op()
test_decode_normalize_op() test_decode_normalize_op()