[feat][assistant][I4S2FM] Extend the Normalize operation to support high-dimensional input
This commit is contained in:
parent
d6e6872183
commit
d443cee406
|
@ -18,6 +18,7 @@
|
|||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/kernels/data/data_utils.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#else
|
||||
|
@ -33,11 +34,41 @@ NormalizeOp::NormalizeOp(const std::vector<float> &mean, const std::vector<float
|
|||
/// \brief Normalize an image tensor, or every image of a higher-rank tensor, with the configured mean and std.
/// \param[in] input Tensor of rank >= kMinImageRank. Tensors of rank <= kDefaultImageRank are normalized directly;
///     higher-rank tensors are treated as a batch of images formed by their last three dimensions.
/// \param[out] output Normalized tensor. For higher-rank input it is reshaped back to the input's shape.
/// \return Status OK on success, otherwise the first error reported by a validation or helper call.
Status NormalizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);
  // Doing the Normalization
  auto input_shape = input->shape();
  const dsize_t rank = input_shape.Rank();
  if (rank < kMinImageRank) {
    std::string err_msg = "Normalize: input tensor should have at least 2 dimensions, but got: " + std::to_string(rank);
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  if (rank <= kDefaultImageRank) {
    // [H, W] or [H, W, C]
#ifndef ENABLE_ANDROID
    return Normalize(input, output, mean_, std_, is_hwc_);
#else
    return Normalize(input, output, mean_, std_);
#endif
  }

  // reshape [..., H, W, C] to [N, H, W, C]
  const dsize_t image_size = input_shape[-3] * input_shape[-2] * input_shape[-1];
  if (image_size == 0) {
    // Guard the division below: an empty trailing dimension would otherwise divide by zero.
    std::string err_msg = "Normalize: the last three dimensions of the input tensor should not be zero.";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  const dsize_t num_batch = input->Size() / image_size;
  TensorShape new_shape({num_batch, input_shape[-3], input_shape[-2], input_shape[-1]});
  RETURN_IF_NOT_OK(input->Reshape(new_shape));

  // The batch work lives in a lambda so that every exit path, including early error returns,
  // falls through to the code below that gives the caller's tensor its original shape back.
  auto normalize_batch = [&]() -> Status {
    // split [N, H, W, C] to N [H, W, C], and normalize N [H, W, C]
    std::vector<std::shared_ptr<Tensor>> input_vector_hwc;
    RETURN_IF_NOT_OK(BatchTensorToTensorVector(input, &input_vector_hwc));
    std::vector<std::shared_ptr<Tensor>> output_vector_hwc;
    output_vector_hwc.reserve(input_vector_hwc.size());
    for (const auto &input_hwc : input_vector_hwc) {
      std::shared_ptr<Tensor> normalized;
#ifndef ENABLE_ANDROID
      RETURN_IF_NOT_OK(Normalize(input_hwc, &normalized, mean_, std_, is_hwc_));
#else
      RETURN_IF_NOT_OK(Normalize(input_hwc, &normalized, mean_, std_));
#endif
      output_vector_hwc.push_back(normalized);
    }
    // integrate N [H, W, C] to [N, H, W, C], and reshape [..., H, W, C]
    RETURN_IF_NOT_OK(TensorVectorToBatchTensor(output_vector_hwc, output));
    return (*output)->Reshape(input_shape);
  };
  const Status batch_rc = normalize_batch();

  // Undo the temporary reshape so the input tensor is not left modified for other holders of the pointer.
  const Status restore_rc = input->Reshape(input_shape);
  RETURN_IF_NOT_OK(batch_rc);
  return restore_rc;
}
|
||||
|
||||
void NormalizeOp::Print(std::ostream &out) const {
|
||||
|
|
|
@ -819,8 +819,8 @@ class Normalize(ImageTensorOperation):
|
|||
TypeError: If `mean` is not of type sequence.
|
||||
TypeError: If `std` is not of type sequence.
|
||||
ValueError: If `mean` is not in range [0.0, 255.0].
|
||||
ValueError: If `mean` is not in range (0.0, 255.0].
|
||||
RuntimeError: If given tensor shape is not <H, W> or <H, W, C>.
|
||||
ValueError: If `std` is not in range (0.0, 255.0].
|
||||
RuntimeError: If given tensor shape is not <H, W> or <..., H, W, C>.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU`` ``Ascend`` ``GPU``
|
||||
|
|
|
@ -1401,7 +1401,8 @@ class Normalize(ImageTensorOperation):
|
|||
TypeError: If `std` is not of type sequence.
|
||||
TypeError: If `is_hwc` is not of type bool.
|
||||
ValueError: If `mean` is not in range [0.0, 255.0].
|
||||
ValueError: If `mean` is not in range (0.0, 255.0].
|
||||
ValueError: If `std` is not in range (0.0, 255.0].
|
||||
RuntimeError: If given tensor shape is not <H, W> or <..., H, W, C>.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
#include "common/common.h"
|
||||
#include "common/cvop_common.h"
|
||||
#include "minddata/dataset/kernels/data/data_utils.h"
|
||||
#include "minddata/dataset/kernels/image/normalize_op.h"
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
@ -27,9 +28,9 @@ class MindDataTestNormalizeOP : public UT::CVOP::CVOpCommon {
|
|||
MindDataTestNormalizeOP() : CVOpCommon() {}
|
||||
};
|
||||
|
||||
/// Feature: Normalize op
|
||||
/// Description: Test input parameters at Runtime level
|
||||
/// Expectation: Results are successfully outputted.
|
||||
/// Feature: Normalize
|
||||
/// Description: Normalize the image and save
|
||||
/// Expectation: normalized image saves successfully
|
||||
TEST_F(MindDataTestNormalizeOP, TestOp) {
|
||||
MS_LOG(INFO) << "Doing TestNormalizeOp::TestOp2.";
|
||||
std::shared_ptr<Tensor> output_tensor;
|
||||
|
@ -55,3 +56,40 @@ TEST_F(MindDataTestNormalizeOP, TestOp) {
|
|||
cv::FileStorage file(output_filename, cv::FileStorage::WRITE);
|
||||
file << "imageData" << cv_output_image;
|
||||
}
|
||||
|
||||
/// Feature: Normalize
|
||||
/// Description: Test Normalize with 4 dimension tensor
|
||||
/// Expectation: The result is as expected
|
||||
TEST_F(MindDataTestNormalizeOP, TestOp4Dim) {
|
||||
MS_LOG(INFO) << "Doing TestNormalizeOp-TestOp4Dim.";
|
||||
std::shared_ptr<Tensor> output_tensor;
|
||||
|
||||
// construct a fake 4 dimension data
|
||||
std::shared_ptr<Tensor> input_tensor_cp;
|
||||
ASSERT_OK(Tensor::CreateFromTensor(input_tensor_, &input_tensor_cp));
|
||||
std::vector<std::shared_ptr<Tensor>> tensor_list;
|
||||
tensor_list.push_back(input_tensor_cp);
|
||||
tensor_list.push_back(input_tensor_cp);
|
||||
TensorShape shape = input_tensor_cp->shape();
|
||||
std::shared_ptr<Tensor> input_4d;
|
||||
ASSERT_OK(TensorVectorToBatchTensor(tensor_list, &input_4d));
|
||||
std::vector<float> mean = {121.0, 115.0, 100.0};
|
||||
std::vector<float> std = {70.0, 68.0, 71.0};
|
||||
|
||||
// Normalize Op
|
||||
std::unique_ptr<NormalizeOp> op = std::make_unique<NormalizeOp>(mean, std, true);
|
||||
EXPECT_TRUE(op->OneToOne());
|
||||
Status s = op->Compute(input_4d, &output_tensor);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
|
||||
std::string output_filename = GetFilename();
|
||||
output_filename.replace(output_filename.end() - 8, output_filename.end(), "imagefolder/normalizeOpVideoOut.yml");
|
||||
|
||||
std::shared_ptr<CVTensor> p = CVTensor::AsCVTensor(output_tensor);
|
||||
cv::Mat cv_output_video;
|
||||
cv_output_video = p->mat();
|
||||
|
||||
MS_LOG(DEBUG) << "Storing output file to : " << output_filename << std::endl;
|
||||
cv::FileStorage file(output_filename, cv::FileStorage::WRITE);
|
||||
file << "videoData" << cv_output_video;
|
||||
}
|
||||
|
|
|
@ -160,6 +160,65 @@ def test_normalize_op_chw(plot=False):
|
|||
num_iter += 1
|
||||
|
||||
|
||||
def _check_normalize_high_dim(shape):
    """
    Run vision.Normalize (is_hwc=True) on a fixed 12-element float32 array reshaped to `shape`
    and compare the result against a NumPy reference.

    Args:
        shape (tuple): Target input shape; its last dimension must be 3 to match `mean` and `std`.
    """
    mean = [121.0, 115.0, 100.0]
    std = [70.0, 68.0, 71.0]
    input_np_original = np.array([[87, 88, 232, 239],
                                  [11, 229, 22, 79],
                                  [250, 20, 173, 213]], dtype=np.float32).reshape(shape)
    # Derive the reference from the input itself: (x - mean) / std, broadcast over the trailing
    # channel axis. The previously hard-coded expectation did not correspond to this input.
    expect_output = (input_np_original - np.array(mean, dtype=np.float32)) / np.array(std, dtype=np.float32)

    # define operations
    normalize_op = vision.Normalize(mean, std, True)

    # doing the Normalization
    video_de_normalized = normalize_op(input_np_original)

    # The op must hand back the original high-dimensional layout, not the internal [N, H, W, C] view.
    assert video_de_normalized.shape == shape
    np.testing.assert_allclose(video_de_normalized, expect_output, rtol=1e-4, atol=1e-4)


def test_normalize_op_video():
    """
    Feature: Normalize op
    Description: Test NormalizeOp in Cpp transformation with 4 dimension input,
                 where the input tensor is <T, H, W, C>
    Expectation: The output has the input's shape and matches (input - mean) / std
    """
    logger.info("Test NormalizeOp in cpp transformations with 4 dimension input")
    _check_normalize_high_dim((2, 2, 1, 3))


def test_normalize_op_5d():
    """
    Feature: Normalize op
    Description: Test NormalizeOp in Cpp transformation with 5 dim input, where the input tensor is <..., T, H, W, C>
    Expectation: The output has the input's shape and matches (input - mean) / std
    """
    logger.info("Test NormalizeOp in cpp transformations with 5 dimension input")
    _check_normalize_high_dim((2, 1, 2, 1, 3))
|
||||
|
||||
|
||||
def test_decode_op():
|
||||
"""
|
||||
Feature: Decode op
|
||||
|
@ -466,6 +525,8 @@ if __name__ == "__main__":
|
|||
test_decode_normalize_op()
|
||||
test_normalize_op_hwc(plot=True)
|
||||
test_normalize_op_chw(plot=True)
|
||||
test_normalize_op_video()
|
||||
test_normalize_op_5d()
|
||||
test_normalize_md5_01()
|
||||
test_normalize_md5_02()
|
||||
test_normalize_exception_unequal_size_1()
|
||||
|
|
Loading…
Reference in New Issue