Support multi-channel input in NormalizePad

This commit is contained in:
harshvardhangupta 2022-06-23 16:33:45 -04:00
parent 1dd2f514e6
commit c96df0b9ad
10 changed files with 136 additions and 67 deletions

View File

@ -989,25 +989,8 @@ Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
return Status::OK();
}
Status NormalizePad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std, const std::string &dtype,
bool is_hwc) {
mean->Squeeze();
std->Squeeze();
std::vector<float> mean_v;
std::vector<float> std_v;
for (int j = 0; j < kDefaultImageChannel; j++) {
float mean_c, std_c;
RETURN_IF_NOT_OK(mean->GetItemAt<float>(&mean_c, {j}));
RETURN_IF_NOT_OK(std->GetItemAt<float>(&std_c, {j}));
if (std_c <= 0.0) {
RETURN_STATUS_UNEXPECTED("NormalizePad: std vector element must be greater than 0.0, got: " +
std::to_string(std_c));
}
mean_v.push_back(mean_c);
std_v.push_back(std_c);
}
Status NormalizePad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, std::vector<float> mean,
std::vector<float> std, const std::string &dtype, bool is_hwc) {
int64_t channel_index = kChannelIndexCHW;
if (is_hwc) {
channel_index = kChannelIndexHWC;
@ -1016,6 +999,7 @@ Status NormalizePad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
if (input->Rank() == kDefaultImageRank) {
channels = input->shape()[channel_index];
}
if (is_hwc) {
TensorShape new_shape = TensorShape({input->shape()[0], input->shape()[1], channels + 1});
Tensor::CreateEmpty(new_shape, DataType(dtype), output);
@ -1027,21 +1011,21 @@ Status NormalizePad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
}
// caller provided 1 mean/std value and there are more than one channel --> duplicate mean/std value
if (mean_v.size() == 1 && (*output)->shape()[channel_index] != 1) {
for (int64_t i = 0; i < (*output)->shape()[channel_index] - 1; i++) {
mean_v.push_back(mean_v[0]);
std_v.push_back(std_v[0]);
if (mean.size() == 1 && channels > 1) {
while (mean.size() < channels) {
mean.push_back(mean[0]);
std.push_back(std[0]);
}
}
CHECK_FAIL_RETURN_UNEXPECTED((*output)->shape()[channel_index] == mean_v.size() + 1,
CHECK_FAIL_RETURN_UNEXPECTED((*output)->shape()[channel_index] == mean.size() + 1,
"NormalizePad: number of channels does not match the size of mean and std vectors, got "
"channels: " +
std::to_string((*output)->shape()[channel_index]) +
", size of mean: " + std::to_string(mean_v.size()));
", size of mean: " + std::to_string(mean.size()));
if (dtype == "float16") {
RETURN_IF_NOT_OK(Normalize_caller<float16>(input, output, mean_v, std_v, is_hwc, true));
RETURN_IF_NOT_OK(Normalize_caller<float16>(input, output, mean, std, is_hwc, true));
} else {
RETURN_IF_NOT_OK(Normalize_caller<float>(input, output, mean_v, std_v, is_hwc, true));
RETURN_IF_NOT_OK(Normalize_caller<float>(input, output, mean, std, is_hwc, true));
}
if (input->Rank() == kMinImageRank) {
(*output)->Squeeze();

View File

@ -236,16 +236,15 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, std::vector<float> mean,
std::vector<float> std, bool is_hwc);
/// \brief Returns Normalized and paded image
/// \brief Returns Normalized and padded image
/// \param input: Tensor of shape <H,W,C> in RGB order and any OpenCv compatible type, see CVTensor.
/// \param mean: Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order
/// \param std: Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order
/// \param mean: vector of float values which are mean of each channel
/// \param std: vector of float values which are std of each channel
/// \param dtype: output dtype
/// \param is_hwc: Check if input is HWC/CHW format
/// \param output: Normalized image Tensor and pad an extra channel, return a dtype Tensor
Status NormalizePad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std, const std::string &dtype,
bool is_hwc);
Status NormalizePad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, std::vector<float> mean,
std::vector<float> std, const std::string &dtype, bool is_hwc);
/// \brief Returns image with adjusted brightness.
/// \param input: Tensor of shape <H,W,3> in RGB order and any OpenCv compatible type, see CVTensor.

View File

@ -22,21 +22,9 @@
namespace mindspore {
namespace dataset {
NormalizePadOp::NormalizePadOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b,
std::string dtype, bool is_hwc) {
Status s = Tensor::CreateFromVector<float>({mean_r, mean_g, mean_b}, &mean_);
if (s.IsError()) {
MS_LOG(ERROR) << "NormalizePad: invalid mean value, got: (" + std::to_string(mean_r) + std::to_string(mean_g) +
std::to_string(mean_b) + ").";
}
s = Tensor::CreateFromVector<float>({std_r, std_g, std_b}, &std_);
if (s.IsError()) {
MS_LOG(ERROR) << "NormalizePad: invalid std value, got: (" + std::to_string(std_r) + std::to_string(std_g) +
std::to_string(std_b) + ").";
}
dtype_ = dtype;
is_hwc_ = is_hwc;
}
/// \brief Construct the op from per-channel statistics.
/// \param mean: mean values, copied into the op
/// \param std: standard deviation values, copied into the op
/// \param dtype: name of the output dtype, copied into the op
/// \param is_hwc: layout flag, stored as-is
NormalizePadOp::NormalizePadOp(const std::vector<float> &mean, const std::vector<float> &std, std::string dtype,
                               bool is_hwc)
    : mean_(mean),
      std_(std),
      dtype_(dtype),
      is_hwc_(is_hwc) {}
Status NormalizePadOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
@ -45,7 +33,16 @@ Status NormalizePadOp::Compute(const std::shared_ptr<Tensor> &input, std::shared
}
/// \brief Write a human-readable description of the op (name, mean, std, is_hwc) to a stream.
/// \param out: stream that receives the description
void NormalizePadOp::Print(std::ostream &out) const {
  // Emit a float sequence as a balanced, comma-separated "{a, b, c}" list (no trailing separator).
  const auto print_values = [&out](const std::vector<float> &values) {
    out << "{";
    bool is_first = true;
    for (const auto &value : values) {
      if (!is_first) {
        out << ", ";
      }
      out << value;
      is_first = false;
    }
    out << "}";
  };

  out << "NormalizePadOp, mean: ";
  print_values(mean_);
  out << std::endl << "std: ";
  print_values(std_);
  out << std::endl << "is_hwc: " << is_hwc_ << std::endl;
}
} // namespace dataset
} // namespace mindspore

View File

@ -18,6 +18,7 @@
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
@ -27,8 +28,8 @@ namespace mindspore {
namespace dataset {
class NormalizePadOp : public TensorOp {
public:
NormalizePadOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b,
std::string dtype = "float32", bool is_hwc = true);
NormalizePadOp(const std::vector<float> &mean, const std::vector<float> &std, std::string dtype = "float32",
bool is_hwc = true);
~NormalizePadOp() override = default;
@ -39,8 +40,8 @@ class NormalizePadOp : public TensorOp {
std::string Name() const override { return kNormalizePadOp; }
private:
std::shared_ptr<Tensor> mean_;
std::shared_ptr<Tensor> std_;
std::vector<float> mean_;
std::vector<float> std_;
std::string dtype_;
bool is_hwc_;
};

View File

@ -47,12 +47,7 @@ Status NormalizePadOperation::ValidateParams() {
}
/// \brief Create the runtime NormalizePadOp from this operation's stored parameters.
/// \return shared pointer to the newly constructed op
std::shared_ptr<TensorOp> NormalizePadOperation::Build() {
  // Forward the whole mean_/std_ vectors instead of indexing elements 0..2: fixed indexing reads past the
  // end when fewer than three values were supplied and silently drops any values beyond the third.
  return std::make_shared<NormalizePadOp>(mean_, std_, dtype_, is_hwc_);
}
Status NormalizePadOperation::to_json(nlohmann::json *out_json) {

View File

@ -262,6 +262,11 @@ class Cutout(py_transforms.PyTensorOperation):
'img dimension should be 3. Got {}.'.format(np_img.ndim))
_, image_h, image_w = np_img.shape
if self.length > image_h or self.length > image_w:
raise ValueError(
f"Patch length is too large, got patch length: {self.length} and image height: {image_h}, image "
f"width: {image_w}")
scale = (self.length * self.length) / (image_h * image_w)
bounded = False

View File

@ -35,12 +35,11 @@ TEST_F(MindDataTestNormalizePadOP, TestFloat32) {
std::shared_ptr<Tensor> output_tensor;
// Numbers are from the resnet50 model implementation
float mean[3] = {121.0, 115.0, 100.0};
float std[3] = {70.0, 68.0, 71.0};
std::vector<float> mean = {121.0, 115.0, 100.0};
std::vector<float> std = {70.0, 68.0, 71.0};
// NormalizePad Op
std::unique_ptr<NormalizePadOp> op =
std::make_unique<NormalizePadOp>(mean[0], mean[1], mean[2], std[0], std[1], std[2], "float32", true);
std::unique_ptr<NormalizePadOp> op = std::make_unique<NormalizePadOp>(mean, std, "float32", true);
EXPECT_TRUE(op->OneToOne());
Status s = op->Compute(input_tensor_, &output_tensor);
EXPECT_TRUE(s.IsOk());
@ -54,12 +53,11 @@ TEST_F(MindDataTestNormalizePadOP, TestFloat16) {
std::shared_ptr<Tensor> output_tensor;
// Numbers are from the resnet50 model implementation
float mean[3] = {121.0, 115.0, 100.0};
float std[3] = {70.0, 68.0, 71.0};
std::vector<float> mean = {121.0, 115.0, 100.0};
std::vector<float> std = {70.0, 68.0, 71.0};
// NormalizePad Op
std::unique_ptr<NormalizePadOp> op =
std::make_unique<NormalizePadOp>(mean[0], mean[1], mean[2], std[0], std[1], std[2], "float16", true);
std::unique_ptr<NormalizePadOp> op = std::make_unique<NormalizePadOp>(mean, std, "float16", true);
EXPECT_TRUE(op->OneToOne());
Status s = op->Compute(input_tensor_, &output_tensor);
EXPECT_TRUE(s.IsOk());

View File

@ -16,6 +16,7 @@
Testing CutOut op in DE
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms
@ -288,6 +289,19 @@ def test_cutout_4channel_hwc():
op(image)
def test_cut_out_validation():
    """
    Feature: CutOut op
    Description: Test CutOut Op with patch length greater than image dimensions
    Expectation: Raises an exception
    """
    # Draw uint8 pixels directly: randn() yields negative floats, and casting those to uint8 is not well defined.
    image = np.random.randint(0, 256, size=(3, 1024, 856), dtype=np.uint8)
    # length (1500) exceeds both the image height (1024) and width (856).
    op = vision.CutOut(length=1500, num_patches=3, is_hwc=False)
    with pytest.raises(RuntimeError) as errinfo:
        op(image)
    assert 'box size is too large for image erase' in str(errinfo.value)
if __name__ == "__main__":
test_cut_out_op(plot=True)
test_cut_out_op_multicut(plot=True)
@ -296,3 +310,4 @@ if __name__ == "__main__":
test_cut_out_comp_chw()
test_cutout_4channel_chw()
test_cutout_4channel_hwc()
test_cut_out_validation()

View File

@ -24,6 +24,7 @@ from util import diff_mse, visualize_image
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
MNIST_DATA_DIR = "../data/dataset/testMnistData"
GENERATE_GOLDEN = False
@ -188,6 +189,62 @@ def test_decode_normalizepad_op():
num_iter += 1
def test_multi_channel_normalizepad_chw():
    """
    Feature: NormalizePad op
    Description: Test NormalizePad Op with multi-channel CHW input
    Expectation: Test succeeds and the output has one more channel than the input.
    """
    mean = [0.475, 0.45, 0.392, 0.5]
    std = [0.275, 0.267, 0.278, 0.3]
    # Draw uint8 pixels directly: randn() yields negative floats, and casting those to uint8 is not well defined.
    image = np.random.randint(0, 256, size=(4, 1024, 856), dtype=np.uint8)
    op = vision.NormalizePad(mean, std, is_hwc=False)
    output = op(image)
    # NormalizePad requires the output channel dimension to equal len(mean) + 1 (axis 0 for CHW).
    assert output.shape[0] == len(mean) + 1
def test_multi_channel_normalizepad_hwc():
    """
    Feature: NormalizePad op
    Description: Test NormalizePad Op with multi-channel HWC input
    Expectation: Test succeeds and the output has one more channel than the input.
    """
    mean = [0.475, 0.45, 0.392, 0.5]
    std = [0.275, 0.267, 0.278, 0.3]
    # Draw uint8 pixels directly: randn() yields negative floats, and casting those to uint8 is not well defined.
    image = np.random.randint(0, 256, size=(1024, 856, 4), dtype=np.uint8)
    op = vision.NormalizePad(mean, std, is_hwc=True)
    output = op(image)
    # For HWC input the op keeps height and width and allocates channels + 1 along the last axis.
    assert output.shape == (1024, 856, len(mean) + 1)
def test_normalizepad_op_1channel(plot=False):
    """
    Feature: NormalizePad op
    Description: Test NormalizePad Op with single channel input
    Expectation: Test succeeds. MSE difference is negligible.
    """
    logger.info("Test NormalizePad Single Channel with HWC")
    mean = [121.0]
    std = [70.0]
    normalizepad_op = vision.NormalizePad(mean, std, is_hwc=True)

    # The raw pipeline yields the untouched MnistDataset images; the mapped pipeline yields the op output.
    raw_dataset = ds.MnistDataset(MNIST_DATA_DIR, shuffle=False)
    normalized_dataset = raw_dataset.map(operations=normalizepad_op, input_columns=["image"])

    paired_rows = zip(normalized_dataset.create_dict_iterator(num_epochs=1, output_numpy=True),
                      raw_dataset.create_dict_iterator(num_epochs=1, output_numpy=True))

    num_iter = 0
    for normalized_row, raw_row in paired_rows:
        image_original = raw_row["image"]
        image_de_normalized = normalized_row["image"]
        # Compare the op output against the numpy reference computed from the same raw image.
        image_np_normalized = normalizepad_np(image_original, mean, std)
        mse = diff_mse(image_de_normalized, image_np_normalized)
        num_iter += 1
        logger.info("image_{}, mse: {}".format(num_iter, mse))
        assert mse < 0.01
        if plot:
            visualize_image(image_original, image_de_normalized, mse, image_np_normalized)

    assert num_iter == 10000
def test_normalizepad_exception_unequal_size_1():
"""
Feature: NormalizePad op
@ -259,6 +316,9 @@ if __name__ == "__main__":
test_normalizepad_op_chw(plot=True)
test_normalizepad_op_comp_chw()
test_decode_normalizepad_op()
test_multi_channel_normalizepad_chw()
test_multi_channel_normalizepad_hwc()
test_normalizepad_exception_unequal_size_1()
test_normalizepad_exception_unequal_size_2()
test_normalizepad_exception_invalid_range()
test_normalizepad_op_1channel()

View File

@ -16,6 +16,7 @@
Testing CutOut op in DE
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.py_transforms
@ -219,8 +220,22 @@ def test_cut_out_comp(plot=False):
visualize_list(image_list_1, image_list_2, visualize_mode=2)
def test_cut_out_validation():
    """
    Feature: CutOut op
    Description: Test CutOut Op with patch length greater than image dimensions
    Expectation: Raises an exception
    """
    # Draw uint8 pixels directly: randn() yields negative floats, and casting those to uint8 is not well defined.
    image = np.random.randint(0, 256, size=(3, 1024, 856), dtype=np.uint8)
    # length (1500) exceeds both the image height (1024) and width (856).
    op = f.Cutout(length=1500, num_patches=3)
    with pytest.raises(ValueError) as errinfo:
        op(image)
    assert 'Patch length is too large' in str(errinfo.value)
# Allow running this test module directly as a script: each test is invoked in turn,
# with plotting enabled for the tests that accept a plot argument.
if __name__ == "__main__":
test_cut_out_op(plot=True)
test_cut_out_op_multicut(plot=True)
test_cut_out_md5()
test_cut_out_comp(plot=True)
test_cut_out_validation()